101178654SLei Zhang //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===// 201178654SLei Zhang // 301178654SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 401178654SLei Zhang // See https://llvm.org/LICENSE.txt for license information. 501178654SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 601178654SLei Zhang // 701178654SLei Zhang //===----------------------------------------------------------------------===// 801178654SLei Zhang // 901178654SLei Zhang // This file implements utilities used to lower to SPIR-V dialect. 1001178654SLei Zhang // 1101178654SLei Zhang //===----------------------------------------------------------------------===// 1201178654SLei Zhang 1301178654SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 146867e49fSAngel Zhang #include "mlir/Dialect/Arith/IR/Arith.h" 1536550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 1601178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 1789b595e1SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 1801178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 19fce33e11SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 2097f3bb73SLei Zhang #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 216867e49fSAngel Zhang #include "mlir/Dialect/Utils/IndexingUtils.h" 226867e49fSAngel Zhang #include "mlir/Dialect/Vector/IR/VectorOps.h" 23f83950abSAngel Zhang #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 246867e49fSAngel Zhang #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 2569a3c9cdSLei Zhang #include "mlir/IR/BuiltinTypes.h" 266867e49fSAngel Zhang #include "mlir/IR/Operation.h" 276867e49fSAngel Zhang #include "mlir/IR/PatternMatch.h" 28f83950abSAngel Zhang #include "mlir/Pass/Pass.h" 296867e49fSAngel Zhang #include "mlir/Support/LLVM.h" 307c3ae48fSLei Zhang #include "mlir/Transforms/DialectConversion.h" 31f83950abSAngel Zhang #include 
"mlir/Transforms/GreedyPatternRewriteDriver.h" 326867e49fSAngel Zhang #include "mlir/Transforms/OneToNTypeConversion.h" 336867e49fSAngel Zhang #include "llvm/ADT/STLExtras.h" 346867e49fSAngel Zhang #include "llvm/ADT/SmallVector.h" 3501178654SLei Zhang #include "llvm/ADT/StringExtras.h" 3601178654SLei Zhang #include "llvm/Support/Debug.h" 37f83950abSAngel Zhang #include "llvm/Support/LogicalResult.h" 38f772dcbbSLei Zhang #include "llvm/Support/MathExtras.h" 3901178654SLei Zhang 4001178654SLei Zhang #include <functional> 4197f3bb73SLei Zhang #include <optional> 4201178654SLei Zhang 4301178654SLei Zhang #define DEBUG_TYPE "mlir-spirv-conversion" 4401178654SLei Zhang 4501178654SLei Zhang using namespace mlir; 4601178654SLei Zhang 479527d77aSAngel Zhang namespace { 489527d77aSAngel Zhang 4901178654SLei Zhang //===----------------------------------------------------------------------===// 5001178654SLei Zhang // Utility functions 5101178654SLei Zhang //===----------------------------------------------------------------------===// 5201178654SLei Zhang 536867e49fSAngel Zhang static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) { 546867e49fSAngel Zhang LLVM_DEBUG(llvm::dbgs() << "Get target shape\n"); 556867e49fSAngel Zhang if (vecType.isScalable()) { 566867e49fSAngel Zhang LLVM_DEBUG(llvm::dbgs() 576867e49fSAngel Zhang << "--scalable vectors are not supported -> BAIL\n"); 586867e49fSAngel Zhang return std::nullopt; 596867e49fSAngel Zhang } 606867e49fSAngel Zhang SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape()); 61f83950abSAngel Zhang std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>( 62f83950abSAngel Zhang 1, mlir::spirv::getComputeVectorSize(vecType.getShape().back())); 636867e49fSAngel Zhang if (!targetShape) { 646867e49fSAngel Zhang LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n"); 656867e49fSAngel Zhang return std::nullopt; 666867e49fSAngel Zhang } 676867e49fSAngel Zhang auto 
maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape); 686867e49fSAngel Zhang if (!maybeShapeRatio) { 696867e49fSAngel Zhang LLVM_DEBUG(llvm::dbgs() 706867e49fSAngel Zhang << "--could not compute integral shape ratio -> BAIL\n"); 716867e49fSAngel Zhang return std::nullopt; 726867e49fSAngel Zhang } 736867e49fSAngel Zhang if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) { 746867e49fSAngel Zhang LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n"); 756867e49fSAngel Zhang return std::nullopt; 766867e49fSAngel Zhang } 776867e49fSAngel Zhang LLVM_DEBUG(llvm::dbgs() 786867e49fSAngel Zhang << "--found an integral shape ratio to unroll to -> SUCCESS\n"); 796867e49fSAngel Zhang return targetShape; 806867e49fSAngel Zhang } 816867e49fSAngel Zhang 8201178654SLei Zhang /// Checks that `candidates` extension requirements are possible to be satisfied 8301178654SLei Zhang /// with the given `targetEnv`. 8401178654SLei Zhang /// 8501178654SLei Zhang /// `candidates` is a vector of vector for extension requirements following 8601178654SLei Zhang /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) 8701178654SLei Zhang /// convention. 
8801178654SLei Zhang template <typename LabelT> 8901178654SLei Zhang static LogicalResult checkExtensionRequirements( 9001178654SLei Zhang LabelT label, const spirv::TargetEnv &targetEnv, 9101178654SLei Zhang const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { 9201178654SLei Zhang for (const auto &ors : candidates) { 9301178654SLei Zhang if (targetEnv.allows(ors)) 9401178654SLei Zhang continue; 9501178654SLei Zhang 96fd91f81cSLei Zhang LLVM_DEBUG({ 97fd91f81cSLei Zhang SmallVector<StringRef> extStrings; 9801178654SLei Zhang for (spirv::Extension ext : ors) 9901178654SLei Zhang extStrings.push_back(spirv::stringifyExtension(ext)); 10001178654SLei Zhang 101fd91f81cSLei Zhang llvm::dbgs() << label << " illegal: requires at least one extension in [" 10201178654SLei Zhang << llvm::join(extStrings, ", ") 103fd91f81cSLei Zhang << "] but none allowed in target environment\n"; 104fd91f81cSLei Zhang }); 10501178654SLei Zhang return failure(); 10601178654SLei Zhang } 10701178654SLei Zhang return success(); 10801178654SLei Zhang } 10901178654SLei Zhang 11001178654SLei Zhang /// Checks that `candidates`capability requirements are possible to be satisfied 11101178654SLei Zhang /// with the given `isAllowedFn`. 11201178654SLei Zhang /// 11301178654SLei Zhang /// `candidates` is a vector of vector for capability requirements following 11401178654SLei Zhang /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) 11501178654SLei Zhang /// convention. 
11601178654SLei Zhang template <typename LabelT> 11701178654SLei Zhang static LogicalResult checkCapabilityRequirements( 11801178654SLei Zhang LabelT label, const spirv::TargetEnv &targetEnv, 11901178654SLei Zhang const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { 12001178654SLei Zhang for (const auto &ors : candidates) { 12101178654SLei Zhang if (targetEnv.allows(ors)) 12201178654SLei Zhang continue; 12301178654SLei Zhang 124fd91f81cSLei Zhang LLVM_DEBUG({ 125fd91f81cSLei Zhang SmallVector<StringRef> capStrings; 12601178654SLei Zhang for (spirv::Capability cap : ors) 12701178654SLei Zhang capStrings.push_back(spirv::stringifyCapability(cap)); 12801178654SLei Zhang 129fd91f81cSLei Zhang llvm::dbgs() << label << " illegal: requires at least one capability in [" 13001178654SLei Zhang << llvm::join(capStrings, ", ") 131fd91f81cSLei Zhang << "] but none allowed in target environment\n"; 132fd91f81cSLei Zhang }); 13301178654SLei Zhang return failure(); 13401178654SLei Zhang } 13501178654SLei Zhang return success(); 13601178654SLei Zhang } 13701178654SLei Zhang 1385b15fe93SLei Zhang /// Returns true if the given `storageClass` needs explicit layout when used in 1395b15fe93SLei Zhang /// Shader environments. 1405b15fe93SLei Zhang static bool needsExplicitLayout(spirv::StorageClass storageClass) { 1415b15fe93SLei Zhang switch (storageClass) { 1425b15fe93SLei Zhang case spirv::StorageClass::PhysicalStorageBuffer: 1435b15fe93SLei Zhang case spirv::StorageClass::PushConstant: 1445b15fe93SLei Zhang case spirv::StorageClass::StorageBuffer: 1455b15fe93SLei Zhang case spirv::StorageClass::Uniform: 1465b15fe93SLei Zhang return true; 1475b15fe93SLei Zhang default: 1485b15fe93SLei Zhang return false; 1495b15fe93SLei Zhang } 1505b15fe93SLei Zhang } 1515b15fe93SLei Zhang 1525b15fe93SLei Zhang /// Wraps the given `elementType` in a struct and gets the pointer to the 1535b15fe93SLei Zhang /// struct. This is used to satisfy Vulkan interface requirements. 
1545b15fe93SLei Zhang static spirv::PointerType 1555b15fe93SLei Zhang wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) { 1565b15fe93SLei Zhang auto structType = needsExplicitLayout(storageClass) 1575b15fe93SLei Zhang ? spirv::StructType::get(elementType, /*offsetInfo=*/0) 1585b15fe93SLei Zhang : spirv::StructType::get(elementType); 1595b15fe93SLei Zhang return spirv::PointerType::get(structType, storageClass); 1605b15fe93SLei Zhang } 1615b15fe93SLei Zhang 16201178654SLei Zhang //===----------------------------------------------------------------------===// 16301178654SLei Zhang // Type Conversion 16401178654SLei Zhang //===----------------------------------------------------------------------===// 16501178654SLei Zhang 166c29fc69eSJakub Kuderski static spirv::ScalarType getIndexType(MLIRContext *ctx, 167c29fc69eSJakub Kuderski const SPIRVConversionOptions &options) { 168c29fc69eSJakub Kuderski return cast<spirv::ScalarType>( 169c29fc69eSJakub Kuderski IntegerType::get(ctx, options.use64bitIndex ? 64 : 32)); 170c29fc69eSJakub Kuderski } 171c29fc69eSJakub Kuderski 1725299843cSLei Zhang // TODO: This is a utility function that should probably be exposed by the 1735299843cSLei Zhang // SPIR-V dialect. Keeping it local till the use case arises. 1740de16fafSRamkumar Ramachandra static std::optional<int64_t> 1750de16fafSRamkumar Ramachandra getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { 1765550c821STres Popp if (isa<spirv::ScalarType>(type)) { 1775299843cSLei Zhang auto bitWidth = type.getIntOrFloatBitWidth(); 17801178654SLei Zhang // According to the SPIR-V spec: 17901178654SLei Zhang // "There is no physical size or bit pattern defined for values with boolean 18001178654SLei Zhang // type. 
If they are stored (in conjunction with OpVariable), they can only 18101178654SLei Zhang // be used with logical addressing operations, not physical, and only with 18201178654SLei Zhang // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, 18301178654SLei Zhang // Private, Function, Input, and Output." 1845299843cSLei Zhang if (bitWidth == 1) 1851a36588eSKazu Hirata return std::nullopt; 18601178654SLei Zhang return bitWidth / 8; 18701178654SLei Zhang } 1886dd07fa5SLei Zhang 1895550c821STres Popp if (auto complexType = dyn_cast<ComplexType>(type)) { 19097f3bb73SLei Zhang auto elementSize = getTypeNumBytes(options, complexType.getElementType()); 19197f3bb73SLei Zhang if (!elementSize) 19297f3bb73SLei Zhang return std::nullopt; 19397f3bb73SLei Zhang return 2 * *elementSize; 19497f3bb73SLei Zhang } 19597f3bb73SLei Zhang 1965550c821STres Popp if (auto vecType = dyn_cast<VectorType>(type)) { 1975299843cSLei Zhang auto elementSize = getTypeNumBytes(options, vecType.getElementType()); 19801178654SLei Zhang if (!elementSize) 1991a36588eSKazu Hirata return std::nullopt; 2006d5fc1e3SKazu Hirata return vecType.getNumElements() * *elementSize; 20101178654SLei Zhang } 2026dd07fa5SLei Zhang 2035550c821STres Popp if (auto memRefType = dyn_cast<MemRefType>(type)) { 20401178654SLei Zhang // TODO: Layout should also be controlled by the ABI attributes. For now 20501178654SLei Zhang // using the layout from MemRef. 20601178654SLei Zhang int64_t offset; 20701178654SLei Zhang SmallVector<int64_t, 4> strides; 20801178654SLei Zhang if (!memRefType.hasStaticShape() || 209*6aaa8f25SMatthias Springer failed(memRefType.getStridesAndOffset(strides, offset))) 2101a36588eSKazu Hirata return std::nullopt; 2115299843cSLei Zhang 21201178654SLei Zhang // To get the size of the memref object in memory, the total size is the 21301178654SLei Zhang // max(stride * dimension-size) computed for all dimensions times the size 21401178654SLei Zhang // of the element. 
2155299843cSLei Zhang auto elementSize = getTypeNumBytes(options, memRefType.getElementType()); 2165299843cSLei Zhang if (!elementSize) 2171a36588eSKazu Hirata return std::nullopt; 2185299843cSLei Zhang 2195299843cSLei Zhang if (memRefType.getRank() == 0) 22001178654SLei Zhang return elementSize; 2215299843cSLei Zhang 22201178654SLei Zhang auto dims = memRefType.getShape(); 223399638f9SAliia Khasanova if (llvm::is_contained(dims, ShapedType::kDynamic) || 224399638f9SAliia Khasanova ShapedType::isDynamic(offset) || 225399638f9SAliia Khasanova llvm::is_contained(strides, ShapedType::kDynamic)) 2261a36588eSKazu Hirata return std::nullopt; 2275299843cSLei Zhang 22801178654SLei Zhang int64_t memrefSize = -1; 229e4853be2SMehdi Amini for (const auto &shape : enumerate(dims)) 23001178654SLei Zhang memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); 2315299843cSLei Zhang 2326d5fc1e3SKazu Hirata return (offset + memrefSize) * *elementSize; 2336dd07fa5SLei Zhang } 2346dd07fa5SLei Zhang 2355550c821STres Popp if (auto tensorType = dyn_cast<TensorType>(type)) { 2365299843cSLei Zhang if (!tensorType.hasStaticShape()) 2371a36588eSKazu Hirata return std::nullopt; 2385299843cSLei Zhang 2395299843cSLei Zhang auto elementSize = getTypeNumBytes(options, tensorType.getElementType()); 2405299843cSLei Zhang if (!elementSize) 2411a36588eSKazu Hirata return std::nullopt; 2425299843cSLei Zhang 2436d5fc1e3SKazu Hirata int64_t size = *elementSize; 2445299843cSLei Zhang for (auto shape : tensorType.getShape()) 24501178654SLei Zhang size *= shape; 2465299843cSLei Zhang 24701178654SLei Zhang return size; 24801178654SLei Zhang } 2496dd07fa5SLei Zhang 25001178654SLei Zhang // TODO: Add size computation for other types. 2511a36588eSKazu Hirata return std::nullopt; 25201178654SLei Zhang } 25301178654SLei Zhang 25401178654SLei Zhang /// Converts a scalar `type` to a suitable type under the given `targetEnv`. 
2550de16fafSRamkumar Ramachandra static Type 2560de16fafSRamkumar Ramachandra convertScalarType(const spirv::TargetEnv &targetEnv, 2570de16fafSRamkumar Ramachandra const SPIRVConversionOptions &options, spirv::ScalarType type, 2580de16fafSRamkumar Ramachandra std::optional<spirv::StorageClass> storageClass = {}) { 25901178654SLei Zhang // Get extension and capability requirements for the given type. 26001178654SLei Zhang SmallVector<ArrayRef<spirv::Extension>, 1> extensions; 26101178654SLei Zhang SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; 26201178654SLei Zhang type.getExtensions(extensions, storageClass); 26301178654SLei Zhang type.getCapabilities(capabilities, storageClass); 26401178654SLei Zhang 26501178654SLei Zhang // If all requirements are met, then we can accept this type as-is. 26601178654SLei Zhang if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && 26701178654SLei Zhang succeeded(checkExtensionRequirements(type, targetEnv, extensions))) 26801178654SLei Zhang return type; 26901178654SLei Zhang 27001178654SLei Zhang // Otherwise we need to adjust the type, which really means adjusting the 27101178654SLei Zhang // bitwidth given this is a scalar type. 272c0645454SJakub Kuderski if (!options.emulateLT32BitScalarTypes) 2735299843cSLei Zhang return nullptr; 27401178654SLei Zhang 275c0645454SJakub Kuderski // We only emulate narrower scalar types here and do not truncate results. 
276c0645454SJakub Kuderski if (type.getIntOrFloatBitWidth() > 32) { 277c0645454SJakub Kuderski LLVM_DEBUG(llvm::dbgs() 278c0645454SJakub Kuderski << type 279c0645454SJakub Kuderski << " not converted to 32-bit for SPIR-V to avoid truncation\n"); 280c0645454SJakub Kuderski return nullptr; 281c0645454SJakub Kuderski } 282c0645454SJakub Kuderski 2835550c821STres Popp if (auto floatType = dyn_cast<FloatType>(type)) { 28401178654SLei Zhang LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 28501178654SLei Zhang return Builder(targetEnv.getContext()).getF32Type(); 28601178654SLei Zhang } 28701178654SLei Zhang 2885550c821STres Popp auto intType = cast<IntegerType>(type); 28901178654SLei Zhang LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 2901b97cdf8SRiver Riddle return IntegerType::get(targetEnv.getContext(), /*width=*/32, 2911b97cdf8SRiver Riddle intType.getSignedness()); 29201178654SLei Zhang } 29301178654SLei Zhang 294f772dcbbSLei Zhang /// Converts a sub-byte integer `type` to i32 regardless of target environment. 295bb9bb686SJakub Kuderski /// Returns a nullptr for unsupported integer types, including non sub-byte 296bb9bb686SJakub Kuderski /// types. 297f772dcbbSLei Zhang /// 298f772dcbbSLei Zhang /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use 299f772dcbbSLei Zhang /// the above given that these sub-byte types are not supported at all in 300f772dcbbSLei Zhang /// SPIR-V; there are no compute/storage capability for them like other 301f772dcbbSLei Zhang /// supported integer types. 
302f772dcbbSLei Zhang static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, 303f772dcbbSLei Zhang IntegerType type) { 304bb9bb686SJakub Kuderski if (type.getWidth() > 8) { 305bb9bb686SJakub Kuderski LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n"); 306bb9bb686SJakub Kuderski return nullptr; 307bb9bb686SJakub Kuderski } 308f772dcbbSLei Zhang if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) { 309f772dcbbSLei Zhang LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n"); 310f772dcbbSLei Zhang return nullptr; 311f772dcbbSLei Zhang } 312f772dcbbSLei Zhang 313f772dcbbSLei Zhang if (!llvm::isPowerOf2_32(type.getWidth())) { 314f772dcbbSLei Zhang LLVM_DEBUG(llvm::dbgs() 315f772dcbbSLei Zhang << "unsupported non-power-of-two bitwidth in sub-byte" << type 316f772dcbbSLei Zhang << "\n"); 317f772dcbbSLei Zhang return nullptr; 318f772dcbbSLei Zhang } 319f772dcbbSLei Zhang 320f772dcbbSLei Zhang LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); 321f772dcbbSLei Zhang return IntegerType::get(type.getContext(), /*width=*/32, 322f772dcbbSLei Zhang type.getSignedness()); 323f772dcbbSLei Zhang } 324f772dcbbSLei Zhang 325c29fc69eSJakub Kuderski /// Returns a type with the same shape but with any index element type converted 326c29fc69eSJakub Kuderski /// to the matching integer type. This is a noop when the element type is not 327c29fc69eSJakub Kuderski /// the index type. 
static ShapedType
convertIndexElementType(ShapedType type,
                        const SPIRVConversionOptions &options) {
  Type indexType = dyn_cast<IndexType>(type.getElementType());
  if (!indexType)
    return type;

  // Keep the shape; swap only the element type for the configured index int.
  return type.clone(getIndexType(type.getContext(), options));
}

/// Converts a vector `type` to a suitable type under the given `targetEnv`.
///
/// `index` elements are first rewritten to the integer type selected by
/// `options`. Single-element vectors of rank <= 1 become their (converted)
/// element type. Non-SPIR-V-scalar element types are accepted only when they
/// are integers that `convertSubByteIntegerType` can widen, and then only for
/// at most 4 elements. Returns nullptr when no legal conversion exists.
static Type
convertVectorType(const spirv::TargetEnv &targetEnv,
                  const SPIRVConversionOptions &options, VectorType type,
                  std::optional<spirv::StorageClass> storageClass = {}) {
  type = cast<VectorType>(convertIndexElementType(type, options));
  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
  if (!scalarType) {
    // If this is not a spec allowed scalar type, try to handle sub-byte integer
    // types.
    auto intType = dyn_cast<IntegerType>(type.getElementType());
    if (!intType) {
      LLVM_DEBUG(llvm::dbgs()
                 << type
                 << " illegal: cannot convert non-scalar element type\n");
      return nullptr;
    }

    Type elementType = convertSubByteIntegerType(options, intType);
    if (!elementType)
      return nullptr;

    // A single-element vector degenerates to its (widened) element type.
    if (type.getRank() <= 1 && type.getNumElements() == 1)
      return elementType;

    if (type.getNumElements() > 4) {
      LLVM_DEBUG(llvm::dbgs()
                 << type << " illegal: > 4-element unimplemented\n");
      return nullptr;
    }

    // NOTE(review): the shape is not re-validated with
    // spirv::CompositeType::isValid on this path (e.g. rank > 1) — confirm
    // this is intended.
    return VectorType::get(type.getShape(), elementType);
  }

  // A single-element vector degenerates to its (converted) scalar type.
  if (type.getRank() <= 1 && type.getNumElements() == 1)
    return convertScalarType(targetEnv, options, scalarType, storageClass);

  if (!spirv::CompositeType::isValid(type)) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: not a valid composite type\n");
    return nullptr;
  }

  // Get extension and capability requirements for the given type.
  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
  cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
  cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);

  // If all requirements are met, then we can accept this type as-is.
  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
    return type;

  // Otherwise try converting just the element type and keep the shape.
  auto elementType =
      convertScalarType(targetEnv, options, scalarType, storageClass);
  if (elementType)
    return VectorType::get(type.getShape(), elementType);
  return nullptr;
}

/// Converts a complex `type` into a 2-element vector of its element type.
///
/// Returns nullptr if the element type is not a SPIR-V scalar type, or if
/// `convertScalarType` would have to change (emulate) the element type.
static Type
convertComplexType(const spirv::TargetEnv &targetEnv,
                   const SPIRVConversionOptions &options, ComplexType type,
                   std::optional<spirv::StorageClass> storageClass = {}) {
  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
  if (!scalarType) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert non-scalar element type\n");
    return nullptr;
  }

  auto elementType =
      convertScalarType(targetEnv, options, scalarType, storageClass);
  if (!elementType)
    return nullptr;
  // Emulating the element type (e.g. widening) is not supported for complex.
  if (elementType != type.getElementType()) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: complex type emulation unsupported\n");
    return nullptr;
  }

  return VectorType::get(2, elementType);
}

/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
///
/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
/// create composite constants with OpConstantComposite to embed relative large
/// constant values and use OpCompositeExtract and OpCompositeInsert to
/// manipulate, like what we do for vectors.
///
/// The result is a flat spirv::ArrayType with one entry per tensor element.
/// Returns nullptr for dynamic shapes, non-scalar element types, and
/// zero-element tensors.
static Type convertTensorType(const spirv::TargetEnv &targetEnv,
                              const SPIRVConversionOptions &options,
                              TensorType type) {
  // TODO: Handle dynamic shapes.
  if (!type.hasStaticShape()) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: dynamic shape unimplemented\n");
    return nullptr;
  }

  type = cast<TensorType>(convertIndexElementType(type, options));
  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
  if (!scalarType) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert non-scalar element type\n");
    return nullptr;
  }

  // The element count is derived from byte sizes of the original types.
  std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
  std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
  if (!scalarSize || !tensorSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return nullptr;
  }

  int64_t arrayElemCount = *tensorSize / *scalarSize;
  if (arrayElemCount == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot handle zero-element tensors\n");
    return nullptr;
  }

  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
  if (!arrayElemType)
    return nullptr;
  // Only used to verify the converted element has a known size.
  std::optional<int64_t> arrayElemSize =
      getTypeNumBytes(options, arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return nullptr;
  }

  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
}

/// Converts a memref of i1 in `storageClass` to a SPIR-V pointer type.
///
/// Booleans are stored as `options.boolNumBits`-bit integers; only 8 is
/// currently supported. The storage integer goes through convertScalarType,
/// so it may be widened. Under the Kernel capability the result is a plain
/// pointer (to the element for dynamic shapes, to an array otherwise).
/// Otherwise the (runtime) array is wrapped in a struct. Zero-element static
/// memrefs are rejected.
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
                                  const SPIRVConversionOptions &options,
                                  MemRefType type,
                                  spirv::StorageClass storageClass) {
  unsigned numBoolBits = options.boolNumBits;
  if (numBoolBits != 8) {
    LLVM_DEBUG(llvm::dbgs()
               << "using non-8-bit storage for bool types unimplemented");
    return nullptr;
  }
  auto elementType = dyn_cast<spirv::ScalarType>(
      IntegerType::get(type.getContext(), numBoolBits));
  if (!elementType)
    return nullptr;
  Type arrayElemType =
      convertScalarType(targetEnv, options, elementType, storageClass);
  if (!arrayElemType)
    return nullptr;
  std::optional<int64_t> arrayElemSize =
      getTypeNumBytes(options, arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return nullptr;
  }

  if (!type.hasStaticShape()) {
    // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
    // to the element.
    if (targetEnv.allows(spirv::Capability::Kernel))
      return spirv::PointerType::get(arrayElemType, storageClass);
    int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
    // For Vulkan we need extra wrapping struct and array to satisfy interface
    // needs.
    return wrapInStructAndGetPointer(arrayType, storageClass);
  }

  if (type.getNumElements() == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: zero-element memrefs are not supported\n");
    return nullptr;
  }

  // Total bytes of packed bool storage, then how many (possibly widened)
  // array elements are needed to hold it, rounding up.
  int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
  int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
  if (targetEnv.allows(spirv::Capability::Kernel))
    return spirv::PointerType::get(arrayType, storageClass);
  return wrapInStructAndGetPointer(arrayType, storageClass);
}

/// Converts a memref whose element type is a sub-byte integer to a SPIR-V
/// pointer type. Elements are packed into the i32 words produced by
/// convertSubByteIntegerType; the array holds enough words for
/// `numElements * bitwidth` bits. Kernel/non-Kernel handling mirrors
/// convertBoolMemrefType. Zero-element static memrefs are rejected.
static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
                                     const SPIRVConversionOptions &options,
                                     MemRefType type,
                                     spirv::StorageClass storageClass) {
  IntegerType elementType = cast<IntegerType>(type.getElementType());
  Type arrayElemType = convertSubByteIntegerType(options, elementType);
  if (!arrayElemType)
    return nullptr;
  // convertSubByteIntegerType only ever returns i32, whose size is known.
  int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);

  if (!type.hasStaticShape()) {
    // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
    // to the element.
    if (targetEnv.allows(spirv::Capability::Kernel))
      return spirv::PointerType::get(arrayElemType, storageClass);
    int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
    // For Vulkan we need extra wrapping struct and array to satisfy interface
    // needs.
    return wrapInStructAndGetPointer(arrayType, storageClass);
  }

  if (type.getNumElements() == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: zero-element memrefs are not supported\n");
    return nullptr;
  }

  // Bytes needed for the packed sub-byte elements, rounded up, then the
  // number of 32-bit words needed to hold them, rounded up.
  int64_t memrefSize =
      llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
  int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
  if (targetEnv.allows(spirv::Capability::Kernel))
    return spirv::PointerType::get(arrayType, storageClass);
  return wrapInStructAndGetPointer(arrayType, storageClass);
}

/// Converts a memref `type` to a SPIR-V pointer type. The memref's memory
/// space must already be a SPIR-V storage class attribute. i1 and sub-byte
/// integer element types are dispatched to dedicated helpers; other element
/// kinds are handled further below.
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
                              const SPIRVConversionOptions &options,
                              MemRefType type) {
  auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!attr) {
    LLVM_DEBUG(
        llvm::dbgs()
        << type
        << " illegal: expected memory space to be a SPIR-V storage class "
           "attribute; please use MemorySpaceToStorageClassConverter to map "
           "numeric memory spaces beforehand\n");
    return nullptr;
  }
  spirv::StorageClass storageClass = attr.getValue();

  if (isa<IntegerType>(type.getElementType())) {
    if (type.getElementTypeBitWidth() == 1)
      return convertBoolMemrefType(targetEnv, options, type, storageClass);
    if (type.getElementTypeBitWidth() < 8)
      return convertSubByteMemrefType(targetEnv, options, type, storageClass);
  }

  Type arrayElemType;
  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType)) {
    arrayElemType =
        convertVectorType(targetEnv, options, vecType, storageClass);
  } else if (auto complexType =
dyn_cast<ComplexType>(elementType)) { 59597f3bb73SLei Zhang arrayElemType = 59697f3bb73SLei Zhang convertComplexType(targetEnv, options, complexType, storageClass); 5975550c821STres Popp } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) { 5985299843cSLei Zhang arrayElemType = 5995299843cSLei Zhang convertScalarType(targetEnv, options, scalarType, storageClass); 6005550c821STres Popp } else if (auto indexType = dyn_cast<IndexType>(elementType)) { 6015550c821STres Popp type = cast<MemRefType>(convertIndexElementType(type, options)); 602c29fc69eSJakub Kuderski arrayElemType = type.getElementType(); 60301178654SLei Zhang } else { 60401178654SLei Zhang LLVM_DEBUG( 60501178654SLei Zhang llvm::dbgs() 60601178654SLei Zhang << type 60701178654SLei Zhang << " unhandled: can only convert scalar or vector element type\n"); 608004f29c0SLei Zhang return nullptr; 60901178654SLei Zhang } 61001178654SLei Zhang if (!arrayElemType) 611004f29c0SLei Zhang return nullptr; 61201178654SLei Zhang 6130de16fafSRamkumar Ramachandra std::optional<int64_t> arrayElemSize = 6140de16fafSRamkumar Ramachandra getTypeNumBytes(options, arrayElemType); 61523b8264bSLei Zhang if (!arrayElemSize) { 61623b8264bSLei Zhang LLVM_DEBUG(llvm::dbgs() 61723b8264bSLei Zhang << type << " illegal: cannot deduce converted element size\n"); 61823b8264bSLei Zhang return nullptr; 61923b8264bSLei Zhang } 62023b8264bSLei Zhang 621d6de6ddeSNirvedh Meshram if (!type.hasStaticShape()) { 622d6de6ddeSNirvedh Meshram // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing 623d6de6ddeSNirvedh Meshram // to the element. 6249c3a73a5SStanley Winata if (targetEnv.allows(spirv::Capability::Kernel)) 6259c3a73a5SStanley Winata return spirv::PointerType::get(arrayElemType, storageClass); 62689b595e1SLei Zhang int64_t stride = needsExplicitLayout(storageClass) ? 
*arrayElemSize : 0; 627bbffece3SLei Zhang auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride); 628d6de6ddeSNirvedh Meshram // For Vulkan we need extra wrapping struct and array to satisfy interface 629d6de6ddeSNirvedh Meshram // needs. 63089b595e1SLei Zhang return wrapInStructAndGetPointer(arrayType, storageClass); 63101178654SLei Zhang } 63201178654SLei Zhang 6330de16fafSRamkumar Ramachandra std::optional<int64_t> memrefSize = getTypeNumBytes(options, type); 63401178654SLei Zhang if (!memrefSize) { 63501178654SLei Zhang LLVM_DEBUG(llvm::dbgs() 63601178654SLei Zhang << type << " illegal: cannot deduce element count\n"); 637004f29c0SLei Zhang return nullptr; 63801178654SLei Zhang } 63901178654SLei Zhang 640370a7eaeSJakub Kuderski if (*memrefSize == 0) { 641370a7eaeSJakub Kuderski LLVM_DEBUG(llvm::dbgs() 642370a7eaeSJakub Kuderski << type << " illegal: zero-element memrefs are not supported\n"); 643370a7eaeSJakub Kuderski return nullptr; 644370a7eaeSJakub Kuderski } 645370a7eaeSJakub Kuderski 646f772dcbbSLei Zhang int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize); 64789b595e1SLei Zhang int64_t stride = needsExplicitLayout(storageClass) ? 
*arrayElemSize : 0; 648bbffece3SLei Zhang auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); 649d6de6ddeSNirvedh Meshram if (targetEnv.allows(spirv::Capability::Kernel)) 650d6de6ddeSNirvedh Meshram return spirv::PointerType::get(arrayType, storageClass); 65189b595e1SLei Zhang return wrapInStructAndGetPointer(arrayType, storageClass); 65201178654SLei Zhang } 65301178654SLei Zhang 65418ad80eaSLei Zhang //===----------------------------------------------------------------------===// 65518ad80eaSLei Zhang // Type casting materialization 65618ad80eaSLei Zhang //===----------------------------------------------------------------------===// 65718ad80eaSLei Zhang 65818ad80eaSLei Zhang /// Converts the given `inputs` to the original source `type` considering the 65918ad80eaSLei Zhang /// `targetEnv`'s capabilities. 66018ad80eaSLei Zhang /// 66118ad80eaSLei Zhang /// This function is meant to be used for source materialization in type 66218ad80eaSLei Zhang /// converters. When the type converter needs to materialize a cast op back 66318ad80eaSLei Zhang /// to some original source type, we need to check whether the original source 66418ad80eaSLei Zhang /// type is supported in the target environment. If so, we can insert legal 66518ad80eaSLei Zhang /// SPIR-V cast ops accordingly. 66618ad80eaSLei Zhang /// 66718ad80eaSLei Zhang /// Note that in SPIR-V the capabilities for storage and compute are separate. 66818ad80eaSLei Zhang /// This function is meant to handle the **compute** side; so it does not 66918ad80eaSLei Zhang /// involve storage classes in its logic. The storage side is expected to be 67018ad80eaSLei Zhang /// handled by MemRef conversion logic. 671f18c3e4eSMatthias Springer static Value castToSourceType(const spirv::TargetEnv &targetEnv, 672f18c3e4eSMatthias Springer OpBuilder &builder, Type type, ValueRange inputs, 673f18c3e4eSMatthias Springer Location loc) { 67418ad80eaSLei Zhang // We can only cast one value in SPIR-V. 
67518ad80eaSLei Zhang if (inputs.size() != 1) { 67618ad80eaSLei Zhang auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); 67718ad80eaSLei Zhang return castOp.getResult(0); 67818ad80eaSLei Zhang } 67918ad80eaSLei Zhang Value input = inputs.front(); 68018ad80eaSLei Zhang 68118ad80eaSLei Zhang // Only support integer types for now. Floating point types to be implemented. 68218ad80eaSLei Zhang if (!isa<IntegerType>(type)) { 68318ad80eaSLei Zhang auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); 68418ad80eaSLei Zhang return castOp.getResult(0); 68518ad80eaSLei Zhang } 68618ad80eaSLei Zhang auto inputType = cast<IntegerType>(input.getType()); 68718ad80eaSLei Zhang 68818ad80eaSLei Zhang auto scalarType = dyn_cast<spirv::ScalarType>(type); 68918ad80eaSLei Zhang if (!scalarType) { 69018ad80eaSLei Zhang auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); 69118ad80eaSLei Zhang return castOp.getResult(0); 69218ad80eaSLei Zhang } 69318ad80eaSLei Zhang 69418ad80eaSLei Zhang // Only support source type with a smaller bitwidth. This would mean we are 69518ad80eaSLei Zhang // truncating to go back so we don't need to worry about the signedness. 69618ad80eaSLei Zhang // For extension, we cannot have enough signal here to decide which op to use. 69718ad80eaSLei Zhang if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) { 69818ad80eaSLei Zhang auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); 69918ad80eaSLei Zhang return castOp.getResult(0); 70018ad80eaSLei Zhang } 70118ad80eaSLei Zhang 70218ad80eaSLei Zhang // Boolean values would need to use different ops than normal integer values. 
70318ad80eaSLei Zhang if (type.isInteger(1)) { 70418ad80eaSLei Zhang Value one = spirv::ConstantOp::getOne(inputType, loc, builder); 70518ad80eaSLei Zhang return builder.create<spirv::IEqualOp>(loc, input, one); 70618ad80eaSLei Zhang } 70718ad80eaSLei Zhang 70818ad80eaSLei Zhang // Check that the source integer type is supported by the environment. 70918ad80eaSLei Zhang SmallVector<ArrayRef<spirv::Extension>, 1> exts; 71018ad80eaSLei Zhang SmallVector<ArrayRef<spirv::Capability>, 2> caps; 71118ad80eaSLei Zhang scalarType.getExtensions(exts); 71218ad80eaSLei Zhang scalarType.getCapabilities(caps); 71318ad80eaSLei Zhang if (failed(checkCapabilityRequirements(type, targetEnv, caps)) || 71418ad80eaSLei Zhang failed(checkExtensionRequirements(type, targetEnv, exts))) { 71518ad80eaSLei Zhang auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); 71618ad80eaSLei Zhang return castOp.getResult(0); 71718ad80eaSLei Zhang } 71818ad80eaSLei Zhang 71918ad80eaSLei Zhang // We've already made sure this is truncating previously, so we don't need to 72018ad80eaSLei Zhang // care about signedness here. Still try to use a corresponding op for better 72118ad80eaSLei Zhang // consistency though. 
72218ad80eaSLei Zhang if (type.isSignedInteger()) { 72318ad80eaSLei Zhang return builder.create<spirv::SConvertOp>(loc, type, input); 72418ad80eaSLei Zhang } 72518ad80eaSLei Zhang return builder.create<spirv::UConvertOp>(loc, type, input); 72618ad80eaSLei Zhang } 72718ad80eaSLei Zhang 72818ad80eaSLei Zhang //===----------------------------------------------------------------------===// 7299527d77aSAngel Zhang // Builtin Variables 73018ad80eaSLei Zhang //===----------------------------------------------------------------------===// 73118ad80eaSLei Zhang 7329527d77aSAngel Zhang static spirv::GlobalVariableOp getBuiltinVariable(Block &body, 7339527d77aSAngel Zhang spirv::BuiltIn builtin) { 7349527d77aSAngel Zhang // Look through all global variables in the given `body` block and check if 7359527d77aSAngel Zhang // there is a spirv.GlobalVariable that has the same `builtin` attribute. 7369527d77aSAngel Zhang for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 7379527d77aSAngel Zhang if (auto builtinAttr = varOp->getAttrOfType<StringAttr>( 7389527d77aSAngel Zhang spirv::SPIRVDialect::getAttributeName( 7399527d77aSAngel Zhang spirv::Decoration::BuiltIn))) { 7409527d77aSAngel Zhang auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); 7419527d77aSAngel Zhang if (varBuiltIn && *varBuiltIn == builtin) { 7429527d77aSAngel Zhang return varOp; 7439527d77aSAngel Zhang } 7449527d77aSAngel Zhang } 7459527d77aSAngel Zhang } 7469527d77aSAngel Zhang return nullptr; 7479527d77aSAngel Zhang } 74801178654SLei Zhang 7499527d77aSAngel Zhang /// Gets name of global variable for a builtin. 7509527d77aSAngel Zhang std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix, 7519527d77aSAngel Zhang StringRef suffix) { 7529527d77aSAngel Zhang return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str(); 7539527d77aSAngel Zhang } 75401178654SLei Zhang 7559527d77aSAngel Zhang /// Gets or inserts a global variable for a builtin within `body` block. 
7569527d77aSAngel Zhang static spirv::GlobalVariableOp 7579527d77aSAngel Zhang getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, 7589527d77aSAngel Zhang Type integerType, OpBuilder &builder, 7599527d77aSAngel Zhang StringRef prefix, StringRef suffix) { 7609527d77aSAngel Zhang if (auto varOp = getBuiltinVariable(body, builtin)) 7619527d77aSAngel Zhang return varOp; 76201178654SLei Zhang 7639527d77aSAngel Zhang OpBuilder::InsertionGuard guard(builder); 7649527d77aSAngel Zhang builder.setInsertionPointToStart(&body); 76501178654SLei Zhang 7669527d77aSAngel Zhang spirv::GlobalVariableOp newVarOp; 7679527d77aSAngel Zhang switch (builtin) { 7689527d77aSAngel Zhang case spirv::BuiltIn::NumWorkgroups: 7699527d77aSAngel Zhang case spirv::BuiltIn::WorkgroupSize: 7709527d77aSAngel Zhang case spirv::BuiltIn::WorkgroupId: 7719527d77aSAngel Zhang case spirv::BuiltIn::LocalInvocationId: 7729527d77aSAngel Zhang case spirv::BuiltIn::GlobalInvocationId: { 7739527d77aSAngel Zhang auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType), 7749527d77aSAngel Zhang spirv::StorageClass::Input); 7759527d77aSAngel Zhang std::string name = getBuiltinVarName(builtin, prefix, suffix); 7769527d77aSAngel Zhang newVarOp = 7779527d77aSAngel Zhang builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 7789527d77aSAngel Zhang break; 7799527d77aSAngel Zhang } 7809527d77aSAngel Zhang case spirv::BuiltIn::SubgroupId: 7819527d77aSAngel Zhang case spirv::BuiltIn::NumSubgroups: 7829527d77aSAngel Zhang case spirv::BuiltIn::SubgroupSize: { 7839527d77aSAngel Zhang auto ptrType = 7849527d77aSAngel Zhang spirv::PointerType::get(integerType, spirv::StorageClass::Input); 7859527d77aSAngel Zhang std::string name = getBuiltinVarName(builtin, prefix, suffix); 7869527d77aSAngel Zhang newVarOp = 7879527d77aSAngel Zhang builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); 7889527d77aSAngel Zhang break; 7899527d77aSAngel Zhang } 
7909527d77aSAngel Zhang default: 7919527d77aSAngel Zhang emitError(loc, "unimplemented builtin variable generation for ") 7929527d77aSAngel Zhang << stringifyBuiltIn(builtin); 7939527d77aSAngel Zhang } 7949527d77aSAngel Zhang return newVarOp; 7959527d77aSAngel Zhang } 79601178654SLei Zhang 7979527d77aSAngel Zhang //===----------------------------------------------------------------------===// 7989527d77aSAngel Zhang // Push constant storage 7999527d77aSAngel Zhang //===----------------------------------------------------------------------===// 80097f3bb73SLei Zhang 8019527d77aSAngel Zhang /// Returns the pointer type for the push constant storage containing 8029527d77aSAngel Zhang /// `elementCount` 32-bit integer values. 8039527d77aSAngel Zhang static spirv::PointerType getPushConstantStorageType(unsigned elementCount, 8049527d77aSAngel Zhang Builder &builder, 8059527d77aSAngel Zhang Type indexType) { 8069527d77aSAngel Zhang auto arrayType = spirv::ArrayType::get(indexType, elementCount, 8079527d77aSAngel Zhang /*stride=*/4); 8089527d77aSAngel Zhang auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); 8099527d77aSAngel Zhang return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); 8109527d77aSAngel Zhang } 81101178654SLei Zhang 8129527d77aSAngel Zhang /// Returns the push constant varible containing `elementCount` 32-bit integer 8139527d77aSAngel Zhang /// values in `body`. Returns null op if such an op does not exit. 
8149527d77aSAngel Zhang static spirv::GlobalVariableOp getPushConstantVariable(Block &body, 8159527d77aSAngel Zhang unsigned elementCount) { 8169527d77aSAngel Zhang for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { 8179527d77aSAngel Zhang auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType()); 8189527d77aSAngel Zhang if (!ptrType) 8199527d77aSAngel Zhang continue; 82001178654SLei Zhang 8219527d77aSAngel Zhang // Note that Vulkan requires "There must be no more than one push constant 8229527d77aSAngel Zhang // block statically used per shader entry point." So we should always reuse 8239527d77aSAngel Zhang // the existing one. 8249527d77aSAngel Zhang if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { 8259527d77aSAngel Zhang auto numElements = cast<spirv::ArrayType>( 8269527d77aSAngel Zhang cast<spirv::StructType>(ptrType.getPointeeType()) 8279527d77aSAngel Zhang .getElementType(0)) 8289527d77aSAngel Zhang .getNumElements(); 8299527d77aSAngel Zhang if (numElements == elementCount) 8309527d77aSAngel Zhang return varOp; 8319527d77aSAngel Zhang } 8329527d77aSAngel Zhang } 8339527d77aSAngel Zhang return nullptr; 8349527d77aSAngel Zhang } 83518ad80eaSLei Zhang 8369527d77aSAngel Zhang /// Gets or inserts a global variable for push constant storage containing 8379527d77aSAngel Zhang /// `elementCount` 32-bit integer values in `block`. 
8389527d77aSAngel Zhang static spirv::GlobalVariableOp 8399527d77aSAngel Zhang getOrInsertPushConstantVariable(Location loc, Block &block, 8409527d77aSAngel Zhang unsigned elementCount, OpBuilder &b, 8419527d77aSAngel Zhang Type indexType) { 8429527d77aSAngel Zhang if (auto varOp = getPushConstantVariable(block, elementCount)) 8439527d77aSAngel Zhang return varOp; 8449527d77aSAngel Zhang 8459527d77aSAngel Zhang auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); 8469527d77aSAngel Zhang auto type = getPushConstantStorageType(elementCount, builder, indexType); 8479527d77aSAngel Zhang const char *name = "__push_constant_var__"; 8489527d77aSAngel Zhang return builder.create<spirv::GlobalVariableOp>(loc, type, name, 8499527d77aSAngel Zhang /*initializer=*/nullptr); 85001178654SLei Zhang } 85101178654SLei Zhang 85201178654SLei Zhang //===----------------------------------------------------------------------===// 85358ceae95SRiver Riddle // func::FuncOp Conversion Patterns 85401178654SLei Zhang //===----------------------------------------------------------------------===// 85501178654SLei Zhang 85601178654SLei Zhang /// A pattern for rewriting function signature to convert arguments of functions 85701178654SLei Zhang /// to be of valid SPIR-V types. 
8589527d77aSAngel Zhang struct FuncOpConversion final : OpConversionPattern<func::FuncOp> { 85958ceae95SRiver Riddle using OpConversionPattern<func::FuncOp>::OpConversionPattern; 86001178654SLei Zhang 86101178654SLei Zhang LogicalResult 86258ceae95SRiver Riddle matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, 8639527d77aSAngel Zhang ConversionPatternRewriter &rewriter) const override { 8649527d77aSAngel Zhang FunctionType fnType = funcOp.getFunctionType(); 86542980a78SLei Zhang if (fnType.getNumResults() > 1) 86601178654SLei Zhang return failure(); 86701178654SLei Zhang 8689527d77aSAngel Zhang TypeConverter::SignatureConversion signatureConverter( 8699527d77aSAngel Zhang fnType.getNumInputs()); 870e4853be2SMehdi Amini for (const auto &argType : enumerate(fnType.getInputs())) { 8717c3ae48fSLei Zhang auto convertedType = getTypeConverter()->convertType(argType.value()); 87201178654SLei Zhang if (!convertedType) 87301178654SLei Zhang return failure(); 87401178654SLei Zhang signatureConverter.addInputs(argType.index(), convertedType); 87501178654SLei Zhang } 87601178654SLei Zhang 87742980a78SLei Zhang Type resultType; 8785299843cSLei Zhang if (fnType.getNumResults() == 1) { 8797c3ae48fSLei Zhang resultType = getTypeConverter()->convertType(fnType.getResult(0)); 8805299843cSLei Zhang if (!resultType) 8815299843cSLei Zhang return failure(); 8825299843cSLei Zhang } 88342980a78SLei Zhang 8845ab6ef75SJakub Kuderski // Create the converted spirv.func op. 88501178654SLei Zhang auto newFuncOp = rewriter.create<spirv::FuncOp>( 88601178654SLei Zhang funcOp.getLoc(), funcOp.getName(), 88701178654SLei Zhang rewriter.getFunctionType(signatureConverter.getConvertedTypes(), 88842980a78SLei Zhang resultType ? TypeRange(resultType) 88942980a78SLei Zhang : TypeRange())); 89001178654SLei Zhang 89101178654SLei Zhang // Copy over all attributes other than the function name and type. 
89256774bddSMarius Brehler for (const auto &namedAttr : funcOp->getAttrs()) { 89353406427SJeff Niu if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && 8940c7890c8SRiver Riddle namedAttr.getName() != SymbolTable::getSymbolAttrName()) 8950c7890c8SRiver Riddle newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); 89601178654SLei Zhang } 89701178654SLei Zhang 89801178654SLei Zhang rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 89901178654SLei Zhang newFuncOp.end()); 9007c3ae48fSLei Zhang if (failed(rewriter.convertRegionTypes( 9017c3ae48fSLei Zhang &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) 90201178654SLei Zhang return failure(); 90301178654SLei Zhang rewriter.eraseOp(funcOp); 90401178654SLei Zhang return success(); 90501178654SLei Zhang } 9069527d77aSAngel Zhang }; 90701178654SLei Zhang 9086867e49fSAngel Zhang /// A pattern for rewriting function signature to convert vector arguments of 9096867e49fSAngel Zhang /// functions to be of valid types 9106867e49fSAngel Zhang struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> { 9116867e49fSAngel Zhang using OpRewritePattern::OpRewritePattern; 9126867e49fSAngel Zhang 9136867e49fSAngel Zhang LogicalResult matchAndRewrite(func::FuncOp funcOp, 9146867e49fSAngel Zhang PatternRewriter &rewriter) const override { 9156867e49fSAngel Zhang FunctionType fnType = funcOp.getFunctionType(); 9166867e49fSAngel Zhang 9176867e49fSAngel Zhang // TODO: Handle declarations. 9186867e49fSAngel Zhang if (funcOp.isDeclaration()) { 9196867e49fSAngel Zhang LLVM_DEBUG(llvm::dbgs() 9206867e49fSAngel Zhang << fnType << " illegal: declarations are unsupported\n"); 9216867e49fSAngel Zhang return failure(); 9226867e49fSAngel Zhang } 9236867e49fSAngel Zhang 9246867e49fSAngel Zhang // Create a new func op with the original type and copy the function body. 
9256867e49fSAngel Zhang auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(), 9266867e49fSAngel Zhang funcOp.getName(), fnType); 9276867e49fSAngel Zhang rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 9286867e49fSAngel Zhang newFuncOp.end()); 9296867e49fSAngel Zhang 9306867e49fSAngel Zhang Location loc = newFuncOp.getBody().getLoc(); 9316867e49fSAngel Zhang 9326867e49fSAngel Zhang Block &entryBlock = newFuncOp.getBlocks().front(); 9336867e49fSAngel Zhang OpBuilder::InsertionGuard guard(rewriter); 9346867e49fSAngel Zhang rewriter.setInsertionPointToStart(&entryBlock); 9356867e49fSAngel Zhang 9366867e49fSAngel Zhang OneToNTypeMapping oneToNTypeMapping(fnType.getInputs()); 9376867e49fSAngel Zhang 9386867e49fSAngel Zhang // For arguments that are of illegal types and require unrolling. 9396867e49fSAngel Zhang // `unrolledInputNums` stores the indices of arguments that result from 9406867e49fSAngel Zhang // unrolling in the new function signature. `newInputNo` is a counter. 9416867e49fSAngel Zhang SmallVector<size_t> unrolledInputNums; 9426867e49fSAngel Zhang size_t newInputNo = 0; 9436867e49fSAngel Zhang 9446867e49fSAngel Zhang // For arguments that are of legal types and do not require unrolling. 9456867e49fSAngel Zhang // `tmpOps` stores a mapping from temporary operations that serve as 9466867e49fSAngel Zhang // placeholders for new arguments that will be added later. These operations 9476867e49fSAngel Zhang // will be erased once the entry block's argument list is updated. 9486867e49fSAngel Zhang llvm::SmallDenseMap<Operation *, size_t> tmpOps; 9496867e49fSAngel Zhang 9506867e49fSAngel Zhang // This counts the number of new operations created. 9516867e49fSAngel Zhang size_t newOpCount = 0; 9526867e49fSAngel Zhang 9536867e49fSAngel Zhang // Enumerate through the arguments. 9546867e49fSAngel Zhang for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) { 9556867e49fSAngel Zhang // Check whether the argument is of vector type. 
9566867e49fSAngel Zhang auto origVecType = dyn_cast<VectorType>(origType); 9576867e49fSAngel Zhang if (!origVecType) { 9586867e49fSAngel Zhang // We need a placeholder for the old argument that will be erased later. 9596867e49fSAngel Zhang Value result = rewriter.create<arith::ConstantOp>( 9606867e49fSAngel Zhang loc, origType, rewriter.getZeroAttr(origType)); 9616867e49fSAngel Zhang rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result); 9626867e49fSAngel Zhang tmpOps.insert({result.getDefiningOp(), newInputNo}); 9636867e49fSAngel Zhang oneToNTypeMapping.addInputs(origInputNo, origType); 9646867e49fSAngel Zhang ++newInputNo; 9656867e49fSAngel Zhang ++newOpCount; 9666867e49fSAngel Zhang continue; 9676867e49fSAngel Zhang } 9686867e49fSAngel Zhang // Check whether the vector needs unrolling. 9696867e49fSAngel Zhang auto targetShape = getTargetShape(origVecType); 9706867e49fSAngel Zhang if (!targetShape) { 9716867e49fSAngel Zhang // We need a placeholder for the old argument that will be erased later. 9726867e49fSAngel Zhang Value result = rewriter.create<arith::ConstantOp>( 9736867e49fSAngel Zhang loc, origType, rewriter.getZeroAttr(origType)); 9746867e49fSAngel Zhang rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result); 9756867e49fSAngel Zhang tmpOps.insert({result.getDefiningOp(), newInputNo}); 9766867e49fSAngel Zhang oneToNTypeMapping.addInputs(origInputNo, origType); 9776867e49fSAngel Zhang ++newInputNo; 9786867e49fSAngel Zhang ++newOpCount; 9796867e49fSAngel Zhang continue; 9806867e49fSAngel Zhang } 9816867e49fSAngel Zhang VectorType unrolledType = 9826867e49fSAngel Zhang VectorType::get(*targetShape, origVecType.getElementType()); 9836867e49fSAngel Zhang auto originalShape = 9846867e49fSAngel Zhang llvm::to_vector_of<int64_t, 4>(origVecType.getShape()); 9856867e49fSAngel Zhang 9866867e49fSAngel Zhang // Prepare the result vector. 
9876867e49fSAngel Zhang Value result = rewriter.create<arith::ConstantOp>( 9886867e49fSAngel Zhang loc, origVecType, rewriter.getZeroAttr(origVecType)); 9896867e49fSAngel Zhang ++newOpCount; 9906867e49fSAngel Zhang // Prepare the placeholder for the new arguments that will be added later. 9916867e49fSAngel Zhang Value dummy = rewriter.create<arith::ConstantOp>( 9926867e49fSAngel Zhang loc, unrolledType, rewriter.getZeroAttr(unrolledType)); 9936867e49fSAngel Zhang ++newOpCount; 9946867e49fSAngel Zhang 9956867e49fSAngel Zhang // Create the `vector.insert_strided_slice` ops. 9966867e49fSAngel Zhang SmallVector<int64_t> strides(targetShape->size(), 1); 9976867e49fSAngel Zhang SmallVector<Type> newTypes; 9986867e49fSAngel Zhang for (SmallVector<int64_t> offsets : 9996867e49fSAngel Zhang StaticTileOffsetRange(originalShape, *targetShape)) { 10006867e49fSAngel Zhang result = rewriter.create<vector::InsertStridedSliceOp>( 10016867e49fSAngel Zhang loc, dummy, result, offsets, strides); 10026867e49fSAngel Zhang newTypes.push_back(unrolledType); 10036867e49fSAngel Zhang unrolledInputNums.push_back(newInputNo); 10046867e49fSAngel Zhang ++newInputNo; 10056867e49fSAngel Zhang ++newOpCount; 10066867e49fSAngel Zhang } 10076867e49fSAngel Zhang rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result); 10086867e49fSAngel Zhang oneToNTypeMapping.addInputs(origInputNo, newTypes); 10096867e49fSAngel Zhang } 10106867e49fSAngel Zhang 10116867e49fSAngel Zhang // Change the function signature. 10126867e49fSAngel Zhang auto convertedTypes = oneToNTypeMapping.getConvertedTypes(); 10136867e49fSAngel Zhang auto newFnType = fnType.clone(convertedTypes, fnType.getResults()); 10146867e49fSAngel Zhang rewriter.modifyOpInPlace(newFuncOp, 10156867e49fSAngel Zhang [&] { newFuncOp.setFunctionType(newFnType); }); 10166867e49fSAngel Zhang 10176867e49fSAngel Zhang // Update the arguments in the entry block. 
10186867e49fSAngel Zhang entryBlock.eraseArguments(0, fnType.getNumInputs()); 10196867e49fSAngel Zhang SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc()); 10206867e49fSAngel Zhang entryBlock.addArguments(convertedTypes, locs); 10216867e49fSAngel Zhang 10226867e49fSAngel Zhang // Replace the placeholder values with the new arguments. We assume there is 10236867e49fSAngel Zhang // only one block for now. 10246867e49fSAngel Zhang size_t unrolledInputIdx = 0; 10256867e49fSAngel Zhang for (auto [count, op] : enumerate(entryBlock.getOperations())) { 10266867e49fSAngel Zhang // We first look for operands that are placeholders for initially legal 10276867e49fSAngel Zhang // arguments. 10286867e49fSAngel Zhang Operation &curOp = op; 10296867e49fSAngel Zhang for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) { 10306867e49fSAngel Zhang Operation *operandOp = operandVal.getDefiningOp(); 10316867e49fSAngel Zhang if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) { 10326867e49fSAngel Zhang size_t idx = operandIdx; 10336867e49fSAngel Zhang rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] { 10346867e49fSAngel Zhang curOp.setOperand(idx, newFuncOp.getArgument(it->second)); 10356867e49fSAngel Zhang }); 10366867e49fSAngel Zhang } 10376867e49fSAngel Zhang } 10386867e49fSAngel Zhang // Since all newly created operations are in the beginning, reaching the 10396867e49fSAngel Zhang // end of them means that any later `vector.insert_strided_slice` should 10406867e49fSAngel Zhang // not be touched. 
10416867e49fSAngel Zhang if (count >= newOpCount) 10426867e49fSAngel Zhang continue; 10436867e49fSAngel Zhang if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) { 10446867e49fSAngel Zhang size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx]; 10456867e49fSAngel Zhang rewriter.modifyOpInPlace(&curOp, [&] { 10466867e49fSAngel Zhang curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo)); 10476867e49fSAngel Zhang }); 10486867e49fSAngel Zhang ++unrolledInputIdx; 10496867e49fSAngel Zhang } 10506867e49fSAngel Zhang } 10516867e49fSAngel Zhang 10526867e49fSAngel Zhang // Erase the original funcOp. The `tmpOps` do not need to be erased since 10536867e49fSAngel Zhang // they have no uses and will be handled by dead-code elimination. 10546867e49fSAngel Zhang rewriter.eraseOp(funcOp); 10556867e49fSAngel Zhang return success(); 10566867e49fSAngel Zhang } 10576867e49fSAngel Zhang }; 10586867e49fSAngel Zhang 10596867e49fSAngel Zhang //===----------------------------------------------------------------------===// 10606867e49fSAngel Zhang // func::ReturnOp Conversion Patterns 10616867e49fSAngel Zhang //===----------------------------------------------------------------------===// 10626867e49fSAngel Zhang 10636867e49fSAngel Zhang /// A pattern for rewriting function signature and the return op to convert 10646867e49fSAngel Zhang /// vectors to be of valid types. 10656867e49fSAngel Zhang struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> { 10666867e49fSAngel Zhang using OpRewritePattern::OpRewritePattern; 10676867e49fSAngel Zhang 10686867e49fSAngel Zhang LogicalResult matchAndRewrite(func::ReturnOp returnOp, 10696867e49fSAngel Zhang PatternRewriter &rewriter) const override { 10706867e49fSAngel Zhang // Check whether the parent funcOp is valid. 
10716867e49fSAngel Zhang auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp()); 10726867e49fSAngel Zhang if (!funcOp) 10736867e49fSAngel Zhang return failure(); 10746867e49fSAngel Zhang 10756867e49fSAngel Zhang FunctionType fnType = funcOp.getFunctionType(); 10766867e49fSAngel Zhang OneToNTypeMapping oneToNTypeMapping(fnType.getResults()); 10776867e49fSAngel Zhang Location loc = returnOp.getLoc(); 10786867e49fSAngel Zhang 10796867e49fSAngel Zhang // For the new return op. 10806867e49fSAngel Zhang SmallVector<Value> newOperands; 10816867e49fSAngel Zhang 10826867e49fSAngel Zhang // Enumerate through the results. 10836867e49fSAngel Zhang for (auto [origResultNo, origType] : enumerate(fnType.getResults())) { 10846867e49fSAngel Zhang // Check whether the argument is of vector type. 10856867e49fSAngel Zhang auto origVecType = dyn_cast<VectorType>(origType); 10866867e49fSAngel Zhang if (!origVecType) { 10876867e49fSAngel Zhang oneToNTypeMapping.addInputs(origResultNo, origType); 10886867e49fSAngel Zhang newOperands.push_back(returnOp.getOperand(origResultNo)); 10896867e49fSAngel Zhang continue; 10906867e49fSAngel Zhang } 10916867e49fSAngel Zhang // Check whether the vector needs unrolling. 10926867e49fSAngel Zhang auto targetShape = getTargetShape(origVecType); 10936867e49fSAngel Zhang if (!targetShape) { 10946867e49fSAngel Zhang // The original argument can be used. 10956867e49fSAngel Zhang oneToNTypeMapping.addInputs(origResultNo, origType); 10966867e49fSAngel Zhang newOperands.push_back(returnOp.getOperand(origResultNo)); 10976867e49fSAngel Zhang continue; 10986867e49fSAngel Zhang } 10996867e49fSAngel Zhang VectorType unrolledType = 11006867e49fSAngel Zhang VectorType::get(*targetShape, origVecType.getElementType()); 11016867e49fSAngel Zhang 11026867e49fSAngel Zhang // Create `vector.extract_strided_slice` ops to form legal vectors from 11036867e49fSAngel Zhang // the original operand of illegal type. 
11046867e49fSAngel Zhang auto originalShape = 11056867e49fSAngel Zhang llvm::to_vector_of<int64_t, 4>(origVecType.getShape()); 1106f83950abSAngel Zhang SmallVector<int64_t> strides(originalShape.size(), 1); 1107f83950abSAngel Zhang SmallVector<int64_t> extractShape(originalShape.size(), 1); 1108f83950abSAngel Zhang extractShape.back() = targetShape->back(); 11096867e49fSAngel Zhang SmallVector<Type> newTypes; 11106867e49fSAngel Zhang Value returnValue = returnOp.getOperand(origResultNo); 11116867e49fSAngel Zhang for (SmallVector<int64_t> offsets : 11126867e49fSAngel Zhang StaticTileOffsetRange(originalShape, *targetShape)) { 11136867e49fSAngel Zhang Value result = rewriter.create<vector::ExtractStridedSliceOp>( 1114f83950abSAngel Zhang loc, returnValue, offsets, extractShape, strides); 1115f83950abSAngel Zhang if (originalShape.size() > 1) { 1116f83950abSAngel Zhang SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0); 1117f83950abSAngel Zhang result = 1118f83950abSAngel Zhang rewriter.create<vector::ExtractOp>(loc, result, extractIndices); 1119f83950abSAngel Zhang } 11206867e49fSAngel Zhang newOperands.push_back(result); 11216867e49fSAngel Zhang newTypes.push_back(unrolledType); 11226867e49fSAngel Zhang } 11236867e49fSAngel Zhang oneToNTypeMapping.addInputs(origResultNo, newTypes); 11246867e49fSAngel Zhang } 11256867e49fSAngel Zhang 11266867e49fSAngel Zhang // Change the function signature. 11276867e49fSAngel Zhang auto newFnType = 11286867e49fSAngel Zhang FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()), 11296867e49fSAngel Zhang TypeRange(oneToNTypeMapping.getConvertedTypes())); 11306867e49fSAngel Zhang rewriter.modifyOpInPlace(funcOp, 11316867e49fSAngel Zhang [&] { funcOp.setFunctionType(newFnType); }); 11326867e49fSAngel Zhang 11336867e49fSAngel Zhang // Replace the return op using the new operands. This will automatically 11346867e49fSAngel Zhang // update the entry block as well. 
11356867e49fSAngel Zhang rewriter.replaceOp(returnOp, 11366867e49fSAngel Zhang rewriter.create<func::ReturnOp>(loc, newOperands)); 11376867e49fSAngel Zhang 11386867e49fSAngel Zhang return success(); 11396867e49fSAngel Zhang } 11406867e49fSAngel Zhang }; 11419527d77aSAngel Zhang 11426867e49fSAngel Zhang } // namespace 11436867e49fSAngel Zhang 11446867e49fSAngel Zhang //===----------------------------------------------------------------------===// 11459527d77aSAngel Zhang // Public function for builtin variables 114601178654SLei Zhang //===----------------------------------------------------------------------===// 114701178654SLei Zhang 114801178654SLei Zhang Value mlir::spirv::getBuiltinVariableValue(Operation *op, 114901178654SLei Zhang spirv::BuiltIn builtin, 11509da4b6dbSVictor Perez Type integerType, OpBuilder &builder, 11519da4b6dbSVictor Perez StringRef prefix, StringRef suffix) { 115201178654SLei Zhang Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 115301178654SLei Zhang if (!parent) { 115401178654SLei Zhang op->emitError("expected operation to be within a module-like op"); 115501178654SLei Zhang return nullptr; 115601178654SLei Zhang } 115701178654SLei Zhang 11581e35a769SButygin spirv::GlobalVariableOp varOp = 11591e35a769SButygin getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(), 11609da4b6dbSVictor Perez builtin, integerType, builder, prefix, suffix); 116101178654SLei Zhang Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp); 116201178654SLei Zhang return builder.create<spirv::LoadOp>(op->getLoc(), ptr); 116301178654SLei Zhang } 116401178654SLei Zhang 116501178654SLei Zhang //===----------------------------------------------------------------------===// 11669527d77aSAngel Zhang // Public function for pushing constant storage 11676dd07fa5SLei Zhang //===----------------------------------------------------------------------===// 11686dd07fa5SLei Zhang 11696dd07fa5SLei Zhang Value 
spirv::getPushConstantValue(Operation *op, unsigned elementCount, 11701e35a769SButygin unsigned offset, Type integerType, 11711e35a769SButygin OpBuilder &builder) { 11726dd07fa5SLei Zhang Location loc = op->getLoc(); 11736dd07fa5SLei Zhang Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); 11746dd07fa5SLei Zhang if (!parent) { 11756dd07fa5SLei Zhang op->emitError("expected operation to be within a module-like op"); 11766dd07fa5SLei Zhang return nullptr; 11776dd07fa5SLei Zhang } 11786dd07fa5SLei Zhang 11796dd07fa5SLei Zhang spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable( 11801e35a769SButygin loc, parent->getRegion(0).front(), elementCount, builder, integerType); 11816dd07fa5SLei Zhang 11821e35a769SButygin Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder); 11836dd07fa5SLei Zhang Value offsetOp = builder.create<spirv::ConstantOp>( 11841e35a769SButygin loc, integerType, builder.getI32IntegerAttr(offset)); 11856dd07fa5SLei Zhang auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp); 11866dd07fa5SLei Zhang auto acOp = builder.create<spirv::AccessChainOp>( 1187984b800aSserge-sans-paille loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp})); 11886dd07fa5SLei Zhang return builder.create<spirv::LoadOp>(loc, acOp); 11896dd07fa5SLei Zhang } 11906dd07fa5SLei Zhang 11916dd07fa5SLei Zhang //===----------------------------------------------------------------------===// 11929527d77aSAngel Zhang // Public functions for index calculation 119301178654SLei Zhang //===----------------------------------------------------------------------===// 119401178654SLei Zhang 1195bb6f5c83SLei Zhang Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides, 11961e35a769SButygin int64_t offset, Type integerType, 11971e35a769SButygin Location loc, OpBuilder &builder) { 1198bb6f5c83SLei Zhang assert(indices.size() == strides.size() && 1199bb6f5c83SLei Zhang "must provide indices for all dimensions"); 1200bb6f5c83SLei 
Zhang 1201bb6f5c83SLei Zhang // TODO: Consider moving to use affine.apply and patterns converting 1202bb6f5c83SLei Zhang // affine.apply to standard ops. This needs converting to SPIR-V passes to be 1203bb6f5c83SLei Zhang // broken down into progressive small steps so we can have intermediate steps 1204bb6f5c83SLei Zhang // using other dialects. At the moment SPIR-V is the final sink. 1205bb6f5c83SLei Zhang 120638f8a3cfSFinn Plummer Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>( 12071e35a769SButygin loc, integerType, IntegerAttr::get(integerType, offset)); 1208e4853be2SMehdi Amini for (const auto &index : llvm::enumerate(indices)) { 120938f8a3cfSFinn Plummer Value strideVal = builder.createOrFold<spirv::ConstantOp>( 12101e35a769SButygin loc, integerType, 12111e35a769SButygin IntegerAttr::get(integerType, strides[index.index()])); 121238f8a3cfSFinn Plummer Value update = 121338f8a3cfSFinn Plummer builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal); 1214bb6f5c83SLei Zhang linearizedIndex = 121538f8a3cfSFinn Plummer builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex); 1216bb6f5c83SLei Zhang } 1217bb6f5c83SLei Zhang return linearizedIndex; 1218bb6f5c83SLei Zhang } 1219bb6f5c83SLei Zhang 1220ce254598SMatthias Springer Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, 12219c3a73a5SStanley Winata MemRefType baseType, Value basePtr, 12229c3a73a5SStanley Winata ValueRange indices, Location loc, 12239c3a73a5SStanley Winata OpBuilder &builder) { 122401178654SLei Zhang // Get base and offset of the MemRefType and verify they are static. 
122501178654SLei Zhang 122601178654SLei Zhang int64_t offset; 122701178654SLei Zhang SmallVector<int64_t, 4> strides; 1228*6aaa8f25SMatthias Springer if (failed(baseType.getStridesAndOffset(strides, offset)) || 1229399638f9SAliia Khasanova llvm::is_contained(strides, ShapedType::kDynamic) || 1230399638f9SAliia Khasanova ShapedType::isDynamic(offset)) { 123101178654SLei Zhang return nullptr; 123201178654SLei Zhang } 123301178654SLei Zhang 12341e35a769SButygin auto indexType = typeConverter.getIndexType(); 123501178654SLei Zhang 123601178654SLei Zhang SmallVector<Value, 2> linearizedIndices; 123701178654SLei Zhang auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); 1238bb6f5c83SLei Zhang 1239bb6f5c83SLei Zhang // Add a '0' at the start to index into the struct. 124001178654SLei Zhang linearizedIndices.push_back(zero); 124101178654SLei Zhang 124201178654SLei Zhang if (baseType.getRank() == 0) { 124301178654SLei Zhang linearizedIndices.push_back(zero); 124401178654SLei Zhang } else { 1245bb6f5c83SLei Zhang linearizedIndices.push_back( 12461e35a769SButygin linearizeIndex(indices, strides, offset, indexType, loc, builder)); 124701178654SLei Zhang } 124801178654SLei Zhang return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices); 124901178654SLei Zhang } 125001178654SLei Zhang 1251ce254598SMatthias Springer Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, 12529c3a73a5SStanley Winata MemRefType baseType, Value basePtr, 12539c3a73a5SStanley Winata ValueRange indices, Location loc, 12549c3a73a5SStanley Winata OpBuilder &builder) { 12559c3a73a5SStanley Winata // Get base and offset of the MemRefType and verify they are static. 
12569c3a73a5SStanley Winata 12579c3a73a5SStanley Winata int64_t offset; 12589c3a73a5SStanley Winata SmallVector<int64_t, 4> strides; 1259*6aaa8f25SMatthias Springer if (failed(baseType.getStridesAndOffset(strides, offset)) || 1260399638f9SAliia Khasanova llvm::is_contained(strides, ShapedType::kDynamic) || 1261399638f9SAliia Khasanova ShapedType::isDynamic(offset)) { 12629c3a73a5SStanley Winata return nullptr; 12639c3a73a5SStanley Winata } 12649c3a73a5SStanley Winata 12659c3a73a5SStanley Winata auto indexType = typeConverter.getIndexType(); 12669c3a73a5SStanley Winata 12679c3a73a5SStanley Winata SmallVector<Value, 2> linearizedIndices; 12689c3a73a5SStanley Winata Value linearIndex; 12699c3a73a5SStanley Winata if (baseType.getRank() == 0) { 1270d6de6ddeSNirvedh Meshram linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder); 12719c3a73a5SStanley Winata } else { 12729c3a73a5SStanley Winata linearIndex = 12739c3a73a5SStanley Winata linearizeIndex(indices, strides, offset, indexType, loc, builder); 12749c3a73a5SStanley Winata } 1275d6de6ddeSNirvedh Meshram Type pointeeType = 12765550c821STres Popp cast<spirv::PointerType>(basePtr.getType()).getPointeeType(); 12775550c821STres Popp if (isa<spirv::ArrayType>(pointeeType)) { 1278d6de6ddeSNirvedh Meshram linearizedIndices.push_back(linearIndex); 1279d6de6ddeSNirvedh Meshram return builder.create<spirv::AccessChainOp>(loc, basePtr, 1280d6de6ddeSNirvedh Meshram linearizedIndices); 1281d6de6ddeSNirvedh Meshram } 12829c3a73a5SStanley Winata return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex, 12839c3a73a5SStanley Winata linearizedIndices); 12849c3a73a5SStanley Winata } 12859c3a73a5SStanley Winata 1286ce254598SMatthias Springer Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter, 12879c3a73a5SStanley Winata MemRefType baseType, Value basePtr, 12889c3a73a5SStanley Winata ValueRange indices, Location loc, 12899c3a73a5SStanley Winata OpBuilder &builder) { 12909c3a73a5SStanley 
Winata 12919c3a73a5SStanley Winata if (typeConverter.allows(spirv::Capability::Kernel)) { 12929c3a73a5SStanley Winata return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc, 12939c3a73a5SStanley Winata builder); 12949c3a73a5SStanley Winata } 12959c3a73a5SStanley Winata 12969c3a73a5SStanley Winata return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc, 12979c3a73a5SStanley Winata builder); 12989c3a73a5SStanley Winata } 12999c3a73a5SStanley Winata 130001178654SLei Zhang //===----------------------------------------------------------------------===// 1301f83950abSAngel Zhang // Public functions for vector unrolling 1302f83950abSAngel Zhang //===----------------------------------------------------------------------===// 1303f83950abSAngel Zhang 1304f83950abSAngel Zhang int mlir::spirv::getComputeVectorSize(int64_t size) { 1305f83950abSAngel Zhang for (int i : {4, 3, 2}) { 1306f83950abSAngel Zhang if (size % i == 0) 1307f83950abSAngel Zhang return i; 1308f83950abSAngel Zhang } 1309f83950abSAngel Zhang return 1; 1310f83950abSAngel Zhang } 1311f83950abSAngel Zhang 1312f83950abSAngel Zhang SmallVector<int64_t> 1313f83950abSAngel Zhang mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) { 1314f83950abSAngel Zhang VectorType srcVectorType = op.getSourceVectorType(); 1315f83950abSAngel Zhang assert(srcVectorType.getRank() == 1); // Guaranteed by semantics 1316f83950abSAngel Zhang int64_t vectorSize = 1317f83950abSAngel Zhang mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0)); 1318f83950abSAngel Zhang return {vectorSize}; 1319f83950abSAngel Zhang } 1320f83950abSAngel Zhang 1321f83950abSAngel Zhang SmallVector<int64_t> 1322f83950abSAngel Zhang mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) { 1323f83950abSAngel Zhang VectorType vectorType = op.getResultVectorType(); 1324f83950abSAngel Zhang SmallVector<int64_t> nativeSize(vectorType.getRank(), 1); 1325f83950abSAngel Zhang nativeSize.back() = 
1326f83950abSAngel Zhang mlir::spirv::getComputeVectorSize(vectorType.getShape().back()); 1327f83950abSAngel Zhang return nativeSize; 1328f83950abSAngel Zhang } 1329f83950abSAngel Zhang 1330f83950abSAngel Zhang std::optional<SmallVector<int64_t>> 1331f83950abSAngel Zhang mlir::spirv::getNativeVectorShape(Operation *op) { 1332f83950abSAngel Zhang if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) { 1333f83950abSAngel Zhang if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) { 1334f83950abSAngel Zhang SmallVector<int64_t> nativeSize(vecType.getRank(), 1); 1335f83950abSAngel Zhang nativeSize.back() = 1336f83950abSAngel Zhang mlir::spirv::getComputeVectorSize(vecType.getShape().back()); 1337f83950abSAngel Zhang return nativeSize; 1338f83950abSAngel Zhang } 1339f83950abSAngel Zhang } 1340f83950abSAngel Zhang 1341f83950abSAngel Zhang return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op) 1342f83950abSAngel Zhang .Case<vector::ReductionOp, vector::TransposeOp>( 1343f83950abSAngel Zhang [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); }) 1344f83950abSAngel Zhang .Default([](Operation *) { return std::nullopt; }); 1345f83950abSAngel Zhang } 1346f83950abSAngel Zhang 1347f83950abSAngel Zhang LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) { 1348f83950abSAngel Zhang MLIRContext *context = op->getContext(); 1349f83950abSAngel Zhang RewritePatternSet patterns(context); 1350f83950abSAngel Zhang populateFuncOpVectorRewritePatterns(patterns); 1351f83950abSAngel Zhang populateReturnOpVectorRewritePatterns(patterns); 1352f83950abSAngel Zhang // We only want to apply signature conversion once to the existing func ops. 1353f83950abSAngel Zhang // Without specifying strictMode, the greedy pattern rewriter will keep 1354f83950abSAngel Zhang // looking for newly created func ops. 
1355f83950abSAngel Zhang GreedyRewriteConfig config; 1356f83950abSAngel Zhang config.strictMode = GreedyRewriteStrictness::ExistingOps; 135709dfc571SJacques Pienaar return applyPatternsGreedily(op, std::move(patterns), config); 1358f83950abSAngel Zhang } 1359f83950abSAngel Zhang 1360f83950abSAngel Zhang LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { 1361f83950abSAngel Zhang MLIRContext *context = op->getContext(); 1362f83950abSAngel Zhang 1363f83950abSAngel Zhang // Unroll vectors in function bodies to native vector size. 1364f83950abSAngel Zhang { 1365f83950abSAngel Zhang RewritePatternSet patterns(context); 1366f83950abSAngel Zhang auto options = vector::UnrollVectorOptions().setNativeShapeFn( 1367f83950abSAngel Zhang [](auto op) { return mlir::spirv::getNativeVectorShape(op); }); 1368f83950abSAngel Zhang populateVectorUnrollPatterns(patterns, options); 136909dfc571SJacques Pienaar if (failed(applyPatternsGreedily(op, std::move(patterns)))) 1370f83950abSAngel Zhang return failure(); 1371f83950abSAngel Zhang } 1372f83950abSAngel Zhang 1373f83950abSAngel Zhang // Convert transpose ops into extract and insert pairs, in preparation of 1374f83950abSAngel Zhang // further transformations to canonicalize/cancel. 1375f83950abSAngel Zhang { 1376f83950abSAngel Zhang RewritePatternSet patterns(context); 1377f83950abSAngel Zhang auto options = vector::VectorTransformsOptions().setVectorTransposeLowering( 1378f83950abSAngel Zhang vector::VectorTransposeLowering::EltWise); 1379f83950abSAngel Zhang vector::populateVectorTransposeLoweringPatterns(patterns, options); 1380f83950abSAngel Zhang vector::populateVectorShapeCastLoweringPatterns(patterns); 138109dfc571SJacques Pienaar if (failed(applyPatternsGreedily(op, std::move(patterns)))) 1382f83950abSAngel Zhang return failure(); 1383f83950abSAngel Zhang } 1384f83950abSAngel Zhang 1385f83950abSAngel Zhang // Run canonicalization to cast away leading size-1 dimensions. 
1386f83950abSAngel Zhang { 1387f83950abSAngel Zhang RewritePatternSet patterns(context); 1388f83950abSAngel Zhang 1389f83950abSAngel Zhang // We need to pull in casting way leading one dims. 1390f83950abSAngel Zhang vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); 1391f83950abSAngel Zhang vector::ReductionOp::getCanonicalizationPatterns(patterns, context); 1392f83950abSAngel Zhang vector::TransposeOp::getCanonicalizationPatterns(patterns, context); 1393f83950abSAngel Zhang 1394f83950abSAngel Zhang // Decompose different rank insert_strided_slice and n-D 1395f83950abSAngel Zhang // extract_slided_slice. 1396f83950abSAngel Zhang vector::populateVectorInsertExtractStridedSliceDecompositionPatterns( 1397f83950abSAngel Zhang patterns); 1398f83950abSAngel Zhang vector::InsertOp::getCanonicalizationPatterns(patterns, context); 1399f83950abSAngel Zhang vector::ExtractOp::getCanonicalizationPatterns(patterns, context); 1400f83950abSAngel Zhang 1401f83950abSAngel Zhang // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean 1402f83950abSAngel Zhang // them up. 
1403f83950abSAngel Zhang vector::BroadcastOp::getCanonicalizationPatterns(patterns, context); 1404f83950abSAngel Zhang vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context); 1405f83950abSAngel Zhang 140609dfc571SJacques Pienaar if (failed(applyPatternsGreedily(op, std::move(patterns)))) 1407f83950abSAngel Zhang return failure(); 1408f83950abSAngel Zhang } 1409f83950abSAngel Zhang return success(); 1410f83950abSAngel Zhang } 1411f83950abSAngel Zhang 1412f83950abSAngel Zhang //===----------------------------------------------------------------------===// 14139527d77aSAngel Zhang // SPIR-V TypeConverter 14149527d77aSAngel Zhang //===----------------------------------------------------------------------===// 14159527d77aSAngel Zhang 14169527d77aSAngel Zhang SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, 14179527d77aSAngel Zhang const SPIRVConversionOptions &options) 14189527d77aSAngel Zhang : targetEnv(targetAttr), options(options) { 14199527d77aSAngel Zhang // Add conversions. The order matters here: later ones will be tried earlier. 14209527d77aSAngel Zhang 14219527d77aSAngel Zhang // Allow all SPIR-V dialect specific types. This assumes all builtin types 14229527d77aSAngel Zhang // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType) 14239527d77aSAngel Zhang // were tried before. 14249527d77aSAngel Zhang // 14259527d77aSAngel Zhang // TODO: This assumes that the SPIR-V types are valid to use in the given 14269527d77aSAngel Zhang // target environment, which should be the case if the whole pipeline is 14279527d77aSAngel Zhang // driven by the same target environment. Still, we probably still want to 14289527d77aSAngel Zhang // validate and convert to be safe. 
14299527d77aSAngel Zhang addConversion([](spirv::SPIRVType type) { return type; }); 14309527d77aSAngel Zhang 14319527d77aSAngel Zhang addConversion([this](IndexType /*indexType*/) { return getIndexType(); }); 14329527d77aSAngel Zhang 14339527d77aSAngel Zhang addConversion([this](IntegerType intType) -> std::optional<Type> { 14349527d77aSAngel Zhang if (auto scalarType = dyn_cast<spirv::ScalarType>(intType)) 14359527d77aSAngel Zhang return convertScalarType(this->targetEnv, this->options, scalarType); 14369527d77aSAngel Zhang if (intType.getWidth() < 8) 14379527d77aSAngel Zhang return convertSubByteIntegerType(this->options, intType); 14389527d77aSAngel Zhang return Type(); 14399527d77aSAngel Zhang }); 14409527d77aSAngel Zhang 14419527d77aSAngel Zhang addConversion([this](FloatType floatType) -> std::optional<Type> { 14429527d77aSAngel Zhang if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType)) 14439527d77aSAngel Zhang return convertScalarType(this->targetEnv, this->options, scalarType); 14449527d77aSAngel Zhang return Type(); 14459527d77aSAngel Zhang }); 14469527d77aSAngel Zhang 14479527d77aSAngel Zhang addConversion([this](ComplexType complexType) { 14489527d77aSAngel Zhang return convertComplexType(this->targetEnv, this->options, complexType); 14499527d77aSAngel Zhang }); 14509527d77aSAngel Zhang 14519527d77aSAngel Zhang addConversion([this](VectorType vectorType) { 14529527d77aSAngel Zhang return convertVectorType(this->targetEnv, this->options, vectorType); 14539527d77aSAngel Zhang }); 14549527d77aSAngel Zhang 14559527d77aSAngel Zhang addConversion([this](TensorType tensorType) { 14569527d77aSAngel Zhang return convertTensorType(this->targetEnv, this->options, tensorType); 14579527d77aSAngel Zhang }); 14589527d77aSAngel Zhang 14599527d77aSAngel Zhang addConversion([this](MemRefType memRefType) { 14609527d77aSAngel Zhang return convertMemrefType(this->targetEnv, this->options, memRefType); 14619527d77aSAngel Zhang }); 14629527d77aSAngel Zhang 
14639527d77aSAngel Zhang // Register some last line of defense casting logic. 14649527d77aSAngel Zhang addSourceMaterialization( 14659527d77aSAngel Zhang [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 14669527d77aSAngel Zhang return castToSourceType(this->targetEnv, builder, type, inputs, loc); 14679527d77aSAngel Zhang }); 14689527d77aSAngel Zhang addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs, 14699527d77aSAngel Zhang Location loc) { 14709527d77aSAngel Zhang auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); 1471f18c3e4eSMatthias Springer return cast.getResult(0); 14729527d77aSAngel Zhang }); 14739527d77aSAngel Zhang } 14749527d77aSAngel Zhang 14759527d77aSAngel Zhang Type SPIRVTypeConverter::getIndexType() const { 14769527d77aSAngel Zhang return ::getIndexType(getContext(), options); 14779527d77aSAngel Zhang } 14789527d77aSAngel Zhang 14799527d77aSAngel Zhang MLIRContext *SPIRVTypeConverter::getContext() const { 14809527d77aSAngel Zhang return targetEnv.getAttr().getContext(); 14819527d77aSAngel Zhang } 14829527d77aSAngel Zhang 14839527d77aSAngel Zhang bool SPIRVTypeConverter::allows(spirv::Capability capability) const { 14849527d77aSAngel Zhang return targetEnv.allows(capability); 14859527d77aSAngel Zhang } 14869527d77aSAngel Zhang 14879527d77aSAngel Zhang //===----------------------------------------------------------------------===// 148801178654SLei Zhang // SPIR-V ConversionTarget 148901178654SLei Zhang //===----------------------------------------------------------------------===// 149001178654SLei Zhang 14916dd07fa5SLei Zhang std::unique_ptr<SPIRVConversionTarget> 14926dd07fa5SLei Zhang SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { 149301178654SLei Zhang std::unique_ptr<SPIRVConversionTarget> target( 149401178654SLei Zhang // std::make_unique does not work here because the constructor is private. 
149501178654SLei Zhang new SPIRVConversionTarget(targetAttr)); 149601178654SLei Zhang SPIRVConversionTarget *targetPtr = target.get(); 14976dd07fa5SLei Zhang target->addDynamicallyLegalDialect<spirv::SPIRVDialect>( 149801178654SLei Zhang // We need to capture the raw pointer here because it is stable: 149901178654SLei Zhang // target will be destroyed once this function is returned. 150001178654SLei Zhang [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); }); 150101178654SLei Zhang return target; 150201178654SLei Zhang } 150301178654SLei Zhang 15046dd07fa5SLei Zhang SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr) 150501178654SLei Zhang : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {} 150601178654SLei Zhang 15076dd07fa5SLei Zhang bool SPIRVConversionTarget::isLegalOp(Operation *op) { 150801178654SLei Zhang // Make sure this op is available at the given version. Ops not implementing 150901178654SLei Zhang // QueryMinVersionInterface/QueryMaxVersionInterface are available to all 151001178654SLei Zhang // SPIR-V versions. 
1511cb395f66SLei Zhang if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) { 1512e8bcc37fSRamkumar Ramachandra std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion(); 1513cb395f66SLei Zhang if (minVersion && *minVersion > this->targetEnv.getVersion()) { 151401178654SLei Zhang LLVM_DEBUG(llvm::dbgs() 151501178654SLei Zhang << op->getName() << " illegal: requiring min version " 1516cb395f66SLei Zhang << spirv::stringifyVersion(*minVersion) << "\n"); 151701178654SLei Zhang return false; 151801178654SLei Zhang } 1519cb395f66SLei Zhang } 1520cb395f66SLei Zhang if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) { 1521e8bcc37fSRamkumar Ramachandra std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion(); 1522cb395f66SLei Zhang if (maxVersion && *maxVersion < this->targetEnv.getVersion()) { 152301178654SLei Zhang LLVM_DEBUG(llvm::dbgs() 152401178654SLei Zhang << op->getName() << " illegal: requiring max version " 1525cb395f66SLei Zhang << spirv::stringifyVersion(*maxVersion) << "\n"); 152601178654SLei Zhang return false; 152701178654SLei Zhang } 1528cb395f66SLei Zhang } 152901178654SLei Zhang 153001178654SLei Zhang // Make sure this op's required extensions are allowed to use. Ops not 153101178654SLei Zhang // implementing QueryExtensionInterface do not require extensions to be 153201178654SLei Zhang // available. 153301178654SLei Zhang if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) 153401178654SLei Zhang if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 153501178654SLei Zhang extensions.getExtensions()))) 153601178654SLei Zhang return false; 153701178654SLei Zhang 153801178654SLei Zhang // Make sure this op's required extensions are allowed to use. Ops not 153901178654SLei Zhang // implementing QueryCapabilityInterface do not require capabilities to be 154001178654SLei Zhang // available. 
154101178654SLei Zhang if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) 154201178654SLei Zhang if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 154301178654SLei Zhang capabilities.getCapabilities()))) 154401178654SLei Zhang return false; 154501178654SLei Zhang 154601178654SLei Zhang SmallVector<Type, 4> valueTypes; 154701178654SLei Zhang valueTypes.append(op->operand_type_begin(), op->operand_type_end()); 154801178654SLei Zhang valueTypes.append(op->result_type_begin(), op->result_type_end()); 154901178654SLei Zhang 15507c5ecc8bSMogball // Ensure that all types have been converted to SPIRV types. 15517c5ecc8bSMogball if (llvm::any_of(valueTypes, 15525550c821STres Popp [](Type t) { return !isa<spirv::SPIRVType>(t); })) 15537c5ecc8bSMogball return false; 15547c5ecc8bSMogball 155501178654SLei Zhang // Special treatment for global variables, whose type requirements are 155601178654SLei Zhang // conveyed by type attributes. 155701178654SLei Zhang if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) 155890a1632dSJakub Kuderski valueTypes.push_back(globalVar.getType()); 155901178654SLei Zhang 156001178654SLei Zhang // Make sure the op's operands/results use types that are allowed by the 156101178654SLei Zhang // target environment. 
156201178654SLei Zhang SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; 156301178654SLei Zhang SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; 156401178654SLei Zhang for (Type valueType : valueTypes) { 156501178654SLei Zhang typeExtensions.clear(); 15665550c821STres Popp cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions); 156701178654SLei Zhang if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, 156801178654SLei Zhang typeExtensions))) 156901178654SLei Zhang return false; 157001178654SLei Zhang 157101178654SLei Zhang typeCapabilities.clear(); 15725550c821STres Popp cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities); 157301178654SLei Zhang if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, 157401178654SLei Zhang typeCapabilities))) 157501178654SLei Zhang return false; 157601178654SLei Zhang } 157701178654SLei Zhang 157801178654SLei Zhang return true; 157901178654SLei Zhang } 15809527d77aSAngel Zhang 15819527d77aSAngel Zhang //===----------------------------------------------------------------------===// 15829527d77aSAngel Zhang // Public functions for populating patterns 15839527d77aSAngel Zhang //===----------------------------------------------------------------------===// 15849527d77aSAngel Zhang 1585206fad0eSMatthias Springer void mlir::populateBuiltinFuncToSPIRVPatterns( 1586206fad0eSMatthias Springer const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 15879527d77aSAngel Zhang patterns.add<FuncOpConversion>(typeConverter, patterns.getContext()); 15889527d77aSAngel Zhang } 15899527d77aSAngel Zhang 15909527d77aSAngel Zhang void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) { 15919527d77aSAngel Zhang patterns.add<FuncOpVectorUnroll>(patterns.getContext()); 15929527d77aSAngel Zhang } 15939527d77aSAngel Zhang 15949527d77aSAngel Zhang void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) { 15959527d77aSAngel 
Zhang patterns.add<ReturnOpVectorUnroll>(patterns.getContext()); 15969527d77aSAngel Zhang } 1597