xref: /llvm-project/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
101178654SLei Zhang //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
201178654SLei Zhang //
301178654SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
401178654SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
501178654SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
601178654SLei Zhang //
701178654SLei Zhang //===----------------------------------------------------------------------===//
801178654SLei Zhang //
901178654SLei Zhang // This file implements utilities used to lower to SPIR-V dialect.
1001178654SLei Zhang //
1101178654SLei Zhang //===----------------------------------------------------------------------===//
1201178654SLei Zhang 
1301178654SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
146867e49fSAngel Zhang #include "mlir/Dialect/Arith/IR/Arith.h"
1536550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1601178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1789b595e1SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1801178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19fce33e11SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
2097f3bb73SLei Zhang #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
216867e49fSAngel Zhang #include "mlir/Dialect/Utils/IndexingUtils.h"
226867e49fSAngel Zhang #include "mlir/Dialect/Vector/IR/VectorOps.h"
23f83950abSAngel Zhang #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
246867e49fSAngel Zhang #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2569a3c9cdSLei Zhang #include "mlir/IR/BuiltinTypes.h"
266867e49fSAngel Zhang #include "mlir/IR/Operation.h"
276867e49fSAngel Zhang #include "mlir/IR/PatternMatch.h"
28f83950abSAngel Zhang #include "mlir/Pass/Pass.h"
296867e49fSAngel Zhang #include "mlir/Support/LLVM.h"
307c3ae48fSLei Zhang #include "mlir/Transforms/DialectConversion.h"
31f83950abSAngel Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
326867e49fSAngel Zhang #include "mlir/Transforms/OneToNTypeConversion.h"
336867e49fSAngel Zhang #include "llvm/ADT/STLExtras.h"
346867e49fSAngel Zhang #include "llvm/ADT/SmallVector.h"
3501178654SLei Zhang #include "llvm/ADT/StringExtras.h"
3601178654SLei Zhang #include "llvm/Support/Debug.h"
37f83950abSAngel Zhang #include "llvm/Support/LogicalResult.h"
38f772dcbbSLei Zhang #include "llvm/Support/MathExtras.h"
3901178654SLei Zhang 
4001178654SLei Zhang #include <functional>
4197f3bb73SLei Zhang #include <optional>
4201178654SLei Zhang 
4301178654SLei Zhang #define DEBUG_TYPE "mlir-spirv-conversion"
4401178654SLei Zhang 
4501178654SLei Zhang using namespace mlir;
4601178654SLei Zhang 
479527d77aSAngel Zhang namespace {
489527d77aSAngel Zhang 
4901178654SLei Zhang //===----------------------------------------------------------------------===//
5001178654SLei Zhang // Utility functions
5101178654SLei Zhang //===----------------------------------------------------------------------===//
5201178654SLei Zhang 
536867e49fSAngel Zhang static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
546867e49fSAngel Zhang   LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
556867e49fSAngel Zhang   if (vecType.isScalable()) {
566867e49fSAngel Zhang     LLVM_DEBUG(llvm::dbgs()
576867e49fSAngel Zhang                << "--scalable vectors are not supported -> BAIL\n");
586867e49fSAngel Zhang     return std::nullopt;
596867e49fSAngel Zhang   }
606867e49fSAngel Zhang   SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
61f83950abSAngel Zhang   std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
62f83950abSAngel Zhang       1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
636867e49fSAngel Zhang   if (!targetShape) {
646867e49fSAngel Zhang     LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
656867e49fSAngel Zhang     return std::nullopt;
666867e49fSAngel Zhang   }
676867e49fSAngel Zhang   auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
686867e49fSAngel Zhang   if (!maybeShapeRatio) {
696867e49fSAngel Zhang     LLVM_DEBUG(llvm::dbgs()
706867e49fSAngel Zhang                << "--could not compute integral shape ratio -> BAIL\n");
716867e49fSAngel Zhang     return std::nullopt;
726867e49fSAngel Zhang   }
736867e49fSAngel Zhang   if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
746867e49fSAngel Zhang     LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
756867e49fSAngel Zhang     return std::nullopt;
766867e49fSAngel Zhang   }
776867e49fSAngel Zhang   LLVM_DEBUG(llvm::dbgs()
786867e49fSAngel Zhang              << "--found an integral shape ratio to unroll to -> SUCCESS\n");
796867e49fSAngel Zhang   return targetShape;
806867e49fSAngel Zhang }
816867e49fSAngel Zhang 
8201178654SLei Zhang /// Checks that `candidates` extension requirements are possible to be satisfied
8301178654SLei Zhang /// with the given `targetEnv`.
8401178654SLei Zhang ///
8501178654SLei Zhang ///  `candidates` is a vector of vector for extension requirements following
8601178654SLei Zhang /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
8701178654SLei Zhang /// convention.
8801178654SLei Zhang template <typename LabelT>
8901178654SLei Zhang static LogicalResult checkExtensionRequirements(
9001178654SLei Zhang     LabelT label, const spirv::TargetEnv &targetEnv,
9101178654SLei Zhang     const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
9201178654SLei Zhang   for (const auto &ors : candidates) {
9301178654SLei Zhang     if (targetEnv.allows(ors))
9401178654SLei Zhang       continue;
9501178654SLei Zhang 
96fd91f81cSLei Zhang     LLVM_DEBUG({
97fd91f81cSLei Zhang       SmallVector<StringRef> extStrings;
9801178654SLei Zhang       for (spirv::Extension ext : ors)
9901178654SLei Zhang         extStrings.push_back(spirv::stringifyExtension(ext));
10001178654SLei Zhang 
101fd91f81cSLei Zhang       llvm::dbgs() << label << " illegal: requires at least one extension in ["
10201178654SLei Zhang                    << llvm::join(extStrings, ", ")
103fd91f81cSLei Zhang                    << "] but none allowed in target environment\n";
104fd91f81cSLei Zhang     });
10501178654SLei Zhang     return failure();
10601178654SLei Zhang   }
10701178654SLei Zhang   return success();
10801178654SLei Zhang }
10901178654SLei Zhang 
11001178654SLei Zhang /// Checks that `candidates`capability requirements are possible to be satisfied
11101178654SLei Zhang /// with the given `isAllowedFn`.
11201178654SLei Zhang ///
11301178654SLei Zhang ///  `candidates` is a vector of vector for capability requirements following
11401178654SLei Zhang /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
11501178654SLei Zhang /// convention.
11601178654SLei Zhang template <typename LabelT>
11701178654SLei Zhang static LogicalResult checkCapabilityRequirements(
11801178654SLei Zhang     LabelT label, const spirv::TargetEnv &targetEnv,
11901178654SLei Zhang     const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
12001178654SLei Zhang   for (const auto &ors : candidates) {
12101178654SLei Zhang     if (targetEnv.allows(ors))
12201178654SLei Zhang       continue;
12301178654SLei Zhang 
124fd91f81cSLei Zhang     LLVM_DEBUG({
125fd91f81cSLei Zhang       SmallVector<StringRef> capStrings;
12601178654SLei Zhang       for (spirv::Capability cap : ors)
12701178654SLei Zhang         capStrings.push_back(spirv::stringifyCapability(cap));
12801178654SLei Zhang 
129fd91f81cSLei Zhang       llvm::dbgs() << label << " illegal: requires at least one capability in ["
13001178654SLei Zhang                    << llvm::join(capStrings, ", ")
131fd91f81cSLei Zhang                    << "] but none allowed in target environment\n";
132fd91f81cSLei Zhang     });
13301178654SLei Zhang     return failure();
13401178654SLei Zhang   }
13501178654SLei Zhang   return success();
13601178654SLei Zhang }
13701178654SLei Zhang 
1385b15fe93SLei Zhang /// Returns true if the given `storageClass` needs explicit layout when used in
1395b15fe93SLei Zhang /// Shader environments.
1405b15fe93SLei Zhang static bool needsExplicitLayout(spirv::StorageClass storageClass) {
1415b15fe93SLei Zhang   switch (storageClass) {
1425b15fe93SLei Zhang   case spirv::StorageClass::PhysicalStorageBuffer:
1435b15fe93SLei Zhang   case spirv::StorageClass::PushConstant:
1445b15fe93SLei Zhang   case spirv::StorageClass::StorageBuffer:
1455b15fe93SLei Zhang   case spirv::StorageClass::Uniform:
1465b15fe93SLei Zhang     return true;
1475b15fe93SLei Zhang   default:
1485b15fe93SLei Zhang     return false;
1495b15fe93SLei Zhang   }
1505b15fe93SLei Zhang }
1515b15fe93SLei Zhang 
1525b15fe93SLei Zhang /// Wraps the given `elementType` in a struct and gets the pointer to the
1535b15fe93SLei Zhang /// struct. This is used to satisfy Vulkan interface requirements.
1545b15fe93SLei Zhang static spirv::PointerType
1555b15fe93SLei Zhang wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
1565b15fe93SLei Zhang   auto structType = needsExplicitLayout(storageClass)
1575b15fe93SLei Zhang                         ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
1585b15fe93SLei Zhang                         : spirv::StructType::get(elementType);
1595b15fe93SLei Zhang   return spirv::PointerType::get(structType, storageClass);
1605b15fe93SLei Zhang }
1615b15fe93SLei Zhang 
16201178654SLei Zhang //===----------------------------------------------------------------------===//
16301178654SLei Zhang // Type Conversion
16401178654SLei Zhang //===----------------------------------------------------------------------===//
16501178654SLei Zhang 
166c29fc69eSJakub Kuderski static spirv::ScalarType getIndexType(MLIRContext *ctx,
167c29fc69eSJakub Kuderski                                       const SPIRVConversionOptions &options) {
168c29fc69eSJakub Kuderski   return cast<spirv::ScalarType>(
169c29fc69eSJakub Kuderski       IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
170c29fc69eSJakub Kuderski }
171c29fc69eSJakub Kuderski 
1725299843cSLei Zhang // TODO: This is a utility function that should probably be exposed by the
1735299843cSLei Zhang // SPIR-V dialect. Keeping it local till the use case arises.
1740de16fafSRamkumar Ramachandra static std::optional<int64_t>
1750de16fafSRamkumar Ramachandra getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
1765550c821STres Popp   if (isa<spirv::ScalarType>(type)) {
1775299843cSLei Zhang     auto bitWidth = type.getIntOrFloatBitWidth();
17801178654SLei Zhang     // According to the SPIR-V spec:
17901178654SLei Zhang     // "There is no physical size or bit pattern defined for values with boolean
18001178654SLei Zhang     // type. If they are stored (in conjunction with OpVariable), they can only
18101178654SLei Zhang     // be used with logical addressing operations, not physical, and only with
18201178654SLei Zhang     // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
18301178654SLei Zhang     // Private, Function, Input, and Output."
1845299843cSLei Zhang     if (bitWidth == 1)
1851a36588eSKazu Hirata       return std::nullopt;
18601178654SLei Zhang     return bitWidth / 8;
18701178654SLei Zhang   }
1886dd07fa5SLei Zhang 
1895550c821STres Popp   if (auto complexType = dyn_cast<ComplexType>(type)) {
19097f3bb73SLei Zhang     auto elementSize = getTypeNumBytes(options, complexType.getElementType());
19197f3bb73SLei Zhang     if (!elementSize)
19297f3bb73SLei Zhang       return std::nullopt;
19397f3bb73SLei Zhang     return 2 * *elementSize;
19497f3bb73SLei Zhang   }
19597f3bb73SLei Zhang 
1965550c821STres Popp   if (auto vecType = dyn_cast<VectorType>(type)) {
1975299843cSLei Zhang     auto elementSize = getTypeNumBytes(options, vecType.getElementType());
19801178654SLei Zhang     if (!elementSize)
1991a36588eSKazu Hirata       return std::nullopt;
2006d5fc1e3SKazu Hirata     return vecType.getNumElements() * *elementSize;
20101178654SLei Zhang   }
2026dd07fa5SLei Zhang 
2035550c821STres Popp   if (auto memRefType = dyn_cast<MemRefType>(type)) {
20401178654SLei Zhang     // TODO: Layout should also be controlled by the ABI attributes. For now
20501178654SLei Zhang     // using the layout from MemRef.
20601178654SLei Zhang     int64_t offset;
20701178654SLei Zhang     SmallVector<int64_t, 4> strides;
20801178654SLei Zhang     if (!memRefType.hasStaticShape() ||
209*6aaa8f25SMatthias Springer         failed(memRefType.getStridesAndOffset(strides, offset)))
2101a36588eSKazu Hirata       return std::nullopt;
2115299843cSLei Zhang 
21201178654SLei Zhang     // To get the size of the memref object in memory, the total size is the
21301178654SLei Zhang     // max(stride * dimension-size) computed for all dimensions times the size
21401178654SLei Zhang     // of the element.
2155299843cSLei Zhang     auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
2165299843cSLei Zhang     if (!elementSize)
2171a36588eSKazu Hirata       return std::nullopt;
2185299843cSLei Zhang 
2195299843cSLei Zhang     if (memRefType.getRank() == 0)
22001178654SLei Zhang       return elementSize;
2215299843cSLei Zhang 
22201178654SLei Zhang     auto dims = memRefType.getShape();
223399638f9SAliia Khasanova     if (llvm::is_contained(dims, ShapedType::kDynamic) ||
224399638f9SAliia Khasanova         ShapedType::isDynamic(offset) ||
225399638f9SAliia Khasanova         llvm::is_contained(strides, ShapedType::kDynamic))
2261a36588eSKazu Hirata       return std::nullopt;
2275299843cSLei Zhang 
22801178654SLei Zhang     int64_t memrefSize = -1;
229e4853be2SMehdi Amini     for (const auto &shape : enumerate(dims))
23001178654SLei Zhang       memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
2315299843cSLei Zhang 
2326d5fc1e3SKazu Hirata     return (offset + memrefSize) * *elementSize;
2336dd07fa5SLei Zhang   }
2346dd07fa5SLei Zhang 
2355550c821STres Popp   if (auto tensorType = dyn_cast<TensorType>(type)) {
2365299843cSLei Zhang     if (!tensorType.hasStaticShape())
2371a36588eSKazu Hirata       return std::nullopt;
2385299843cSLei Zhang 
2395299843cSLei Zhang     auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
2405299843cSLei Zhang     if (!elementSize)
2411a36588eSKazu Hirata       return std::nullopt;
2425299843cSLei Zhang 
2436d5fc1e3SKazu Hirata     int64_t size = *elementSize;
2445299843cSLei Zhang     for (auto shape : tensorType.getShape())
24501178654SLei Zhang       size *= shape;
2465299843cSLei Zhang 
24701178654SLei Zhang     return size;
24801178654SLei Zhang   }
2496dd07fa5SLei Zhang 
25001178654SLei Zhang   // TODO: Add size computation for other types.
2511a36588eSKazu Hirata   return std::nullopt;
25201178654SLei Zhang }
25301178654SLei Zhang 
25401178654SLei Zhang /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
2550de16fafSRamkumar Ramachandra static Type
2560de16fafSRamkumar Ramachandra convertScalarType(const spirv::TargetEnv &targetEnv,
2570de16fafSRamkumar Ramachandra                   const SPIRVConversionOptions &options, spirv::ScalarType type,
2580de16fafSRamkumar Ramachandra                   std::optional<spirv::StorageClass> storageClass = {}) {
25901178654SLei Zhang   // Get extension and capability requirements for the given type.
26001178654SLei Zhang   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
26101178654SLei Zhang   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
26201178654SLei Zhang   type.getExtensions(extensions, storageClass);
26301178654SLei Zhang   type.getCapabilities(capabilities, storageClass);
26401178654SLei Zhang 
26501178654SLei Zhang   // If all requirements are met, then we can accept this type as-is.
26601178654SLei Zhang   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
26701178654SLei Zhang       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
26801178654SLei Zhang     return type;
26901178654SLei Zhang 
27001178654SLei Zhang   // Otherwise we need to adjust the type, which really means adjusting the
27101178654SLei Zhang   // bitwidth given this is a scalar type.
272c0645454SJakub Kuderski   if (!options.emulateLT32BitScalarTypes)
2735299843cSLei Zhang     return nullptr;
27401178654SLei Zhang 
275c0645454SJakub Kuderski   // We only emulate narrower scalar types here and do not truncate results.
276c0645454SJakub Kuderski   if (type.getIntOrFloatBitWidth() > 32) {
277c0645454SJakub Kuderski     LLVM_DEBUG(llvm::dbgs()
278c0645454SJakub Kuderski                << type
279c0645454SJakub Kuderski                << " not converted to 32-bit for SPIR-V to avoid truncation\n");
280c0645454SJakub Kuderski     return nullptr;
281c0645454SJakub Kuderski   }
282c0645454SJakub Kuderski 
2835550c821STres Popp   if (auto floatType = dyn_cast<FloatType>(type)) {
28401178654SLei Zhang     LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
28501178654SLei Zhang     return Builder(targetEnv.getContext()).getF32Type();
28601178654SLei Zhang   }
28701178654SLei Zhang 
2885550c821STres Popp   auto intType = cast<IntegerType>(type);
28901178654SLei Zhang   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
2901b97cdf8SRiver Riddle   return IntegerType::get(targetEnv.getContext(), /*width=*/32,
2911b97cdf8SRiver Riddle                           intType.getSignedness());
29201178654SLei Zhang }
29301178654SLei Zhang 
294f772dcbbSLei Zhang /// Converts a sub-byte integer `type` to i32 regardless of target environment.
295bb9bb686SJakub Kuderski /// Returns a nullptr for unsupported integer types, including non sub-byte
296bb9bb686SJakub Kuderski /// types.
297f772dcbbSLei Zhang ///
298f772dcbbSLei Zhang /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
299f772dcbbSLei Zhang /// the above given that these sub-byte types are not supported at all in
300f772dcbbSLei Zhang /// SPIR-V; there are no compute/storage capability for them like other
301f772dcbbSLei Zhang /// supported integer types.
302f772dcbbSLei Zhang static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
303f772dcbbSLei Zhang                                       IntegerType type) {
304bb9bb686SJakub Kuderski   if (type.getWidth() > 8) {
305bb9bb686SJakub Kuderski     LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
306bb9bb686SJakub Kuderski     return nullptr;
307bb9bb686SJakub Kuderski   }
308f772dcbbSLei Zhang   if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
309f772dcbbSLei Zhang     LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
310f772dcbbSLei Zhang     return nullptr;
311f772dcbbSLei Zhang   }
312f772dcbbSLei Zhang 
313f772dcbbSLei Zhang   if (!llvm::isPowerOf2_32(type.getWidth())) {
314f772dcbbSLei Zhang     LLVM_DEBUG(llvm::dbgs()
315f772dcbbSLei Zhang                << "unsupported non-power-of-two bitwidth in sub-byte" << type
316f772dcbbSLei Zhang                << "\n");
317f772dcbbSLei Zhang     return nullptr;
318f772dcbbSLei Zhang   }
319f772dcbbSLei Zhang 
320f772dcbbSLei Zhang   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
321f772dcbbSLei Zhang   return IntegerType::get(type.getContext(), /*width=*/32,
322f772dcbbSLei Zhang                           type.getSignedness());
323f772dcbbSLei Zhang }
324f772dcbbSLei Zhang 
325c29fc69eSJakub Kuderski /// Returns a type with the same shape but with any index element type converted
326c29fc69eSJakub Kuderski /// to the matching integer type. This is a noop when the element type is not
327c29fc69eSJakub Kuderski /// the index type.
328c29fc69eSJakub Kuderski static ShapedType
329c29fc69eSJakub Kuderski convertIndexElementType(ShapedType type,
330c29fc69eSJakub Kuderski                         const SPIRVConversionOptions &options) {
331c29fc69eSJakub Kuderski   Type indexType = dyn_cast<IndexType>(type.getElementType());
332c29fc69eSJakub Kuderski   if (!indexType)
333c29fc69eSJakub Kuderski     return type;
334c29fc69eSJakub Kuderski 
335c29fc69eSJakub Kuderski   return type.clone(getIndexType(type.getContext(), options));
336c29fc69eSJakub Kuderski }
337c29fc69eSJakub Kuderski 
33801178654SLei Zhang /// Converts a vector `type` to a suitable type under the given `targetEnv`.
3390de16fafSRamkumar Ramachandra static Type
3400de16fafSRamkumar Ramachandra convertVectorType(const spirv::TargetEnv &targetEnv,
3410de16fafSRamkumar Ramachandra                   const SPIRVConversionOptions &options, VectorType type,
3420de16fafSRamkumar Ramachandra                   std::optional<spirv::StorageClass> storageClass = {}) {
343c29fc69eSJakub Kuderski   type = cast<VectorType>(convertIndexElementType(type, options));
344c29fc69eSJakub Kuderski   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
345c29fc69eSJakub Kuderski   if (!scalarType) {
34669a3c9cdSLei Zhang     // If this is not a spec allowed scalar type, try to handle sub-byte integer
34769a3c9cdSLei Zhang     // types.
34869a3c9cdSLei Zhang     auto intType = dyn_cast<IntegerType>(type.getElementType());
34969a3c9cdSLei Zhang     if (!intType) {
350c29fc69eSJakub Kuderski       LLVM_DEBUG(llvm::dbgs()
35169a3c9cdSLei Zhang                  << type
35269a3c9cdSLei Zhang                  << " illegal: cannot convert non-scalar element type\n");
353c29fc69eSJakub Kuderski       return nullptr;
354c29fc69eSJakub Kuderski     }
355c29fc69eSJakub Kuderski 
35669a3c9cdSLei Zhang     Type elementType = convertSubByteIntegerType(options, intType);
357bb9bb686SJakub Kuderski     if (!elementType)
358bb9bb686SJakub Kuderski       return nullptr;
359bb9bb686SJakub Kuderski 
36069a3c9cdSLei Zhang     if (type.getRank() <= 1 && type.getNumElements() == 1)
36169a3c9cdSLei Zhang       return elementType;
36269a3c9cdSLei Zhang 
36369a3c9cdSLei Zhang     if (type.getNumElements() > 4) {
36469a3c9cdSLei Zhang       LLVM_DEBUG(llvm::dbgs()
36569a3c9cdSLei Zhang                  << type << " illegal: > 4-element unimplemented\n");
36669a3c9cdSLei Zhang       return nullptr;
36769a3c9cdSLei Zhang     }
36869a3c9cdSLei Zhang 
36969a3c9cdSLei Zhang     return VectorType::get(type.getShape(), elementType);
37069a3c9cdSLei Zhang   }
37169a3c9cdSLei Zhang 
37253dac098SLei Zhang   if (type.getRank() <= 1 && type.getNumElements() == 1)
373fce33e11SLei Zhang     return convertScalarType(targetEnv, options, scalarType, storageClass);
3749f622b3dSLei Zhang 
37501178654SLei Zhang   if (!spirv::CompositeType::isValid(type)) {
37669a3c9cdSLei Zhang     LLVM_DEBUG(llvm::dbgs()
37769a3c9cdSLei Zhang                << type << " illegal: not a valid composite type\n");
378004f29c0SLei Zhang     return nullptr;
37901178654SLei Zhang   }
38001178654SLei Zhang 
38101178654SLei Zhang   // Get extension and capability requirements for the given type.
38201178654SLei Zhang   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
38301178654SLei Zhang   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
3845550c821STres Popp   cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
3855550c821STres Popp   cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
38601178654SLei Zhang 
38701178654SLei Zhang   // If all requirements are met, then we can accept this type as-is.
38801178654SLei Zhang   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
38901178654SLei Zhang       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
39001178654SLei Zhang     return type;
39101178654SLei Zhang 
392fce33e11SLei Zhang   auto elementType =
393fce33e11SLei Zhang       convertScalarType(targetEnv, options, scalarType, storageClass);
39401178654SLei Zhang   if (elementType)
395004f29c0SLei Zhang     return VectorType::get(type.getShape(), elementType);
396004f29c0SLei Zhang   return nullptr;
39701178654SLei Zhang }
39801178654SLei Zhang 
39997f3bb73SLei Zhang static Type
40097f3bb73SLei Zhang convertComplexType(const spirv::TargetEnv &targetEnv,
40197f3bb73SLei Zhang                    const SPIRVConversionOptions &options, ComplexType type,
40297f3bb73SLei Zhang                    std::optional<spirv::StorageClass> storageClass = {}) {
40397f3bb73SLei Zhang   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
40497f3bb73SLei Zhang   if (!scalarType) {
40597f3bb73SLei Zhang     LLVM_DEBUG(llvm::dbgs()
40697f3bb73SLei Zhang                << type << " illegal: cannot convert non-scalar element type\n");
40797f3bb73SLei Zhang     return nullptr;
40897f3bb73SLei Zhang   }
40997f3bb73SLei Zhang 
41097f3bb73SLei Zhang   auto elementType =
41197f3bb73SLei Zhang       convertScalarType(targetEnv, options, scalarType, storageClass);
41297f3bb73SLei Zhang   if (!elementType)
41397f3bb73SLei Zhang     return nullptr;
41497f3bb73SLei Zhang   if (elementType != type.getElementType()) {
41597f3bb73SLei Zhang     LLVM_DEBUG(llvm::dbgs()
41697f3bb73SLei Zhang                << type << " illegal: complex type emulation unsupported\n");
41797f3bb73SLei Zhang     return nullptr;
41897f3bb73SLei Zhang   }
41997f3bb73SLei Zhang 
42097f3bb73SLei Zhang   return VectorType::get(2, elementType);
42197f3bb73SLei Zhang }
42297f3bb73SLei Zhang 
42301178654SLei Zhang /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
42401178654SLei Zhang ///
42501178654SLei Zhang /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
42601178654SLei Zhang /// create composite constants with OpConstantComposite to embed relative large
42701178654SLei Zhang /// constant values and use OpCompositeExtract and OpCompositeInsert to
42801178654SLei Zhang /// manipulate, like what we do for vectors.
429004f29c0SLei Zhang static Type convertTensorType(const spirv::TargetEnv &targetEnv,
430dfaebd3dSLei Zhang                               const SPIRVConversionOptions &options,
43101178654SLei Zhang                               TensorType type) {
43201178654SLei Zhang   // TODO: Handle dynamic shapes.
43301178654SLei Zhang   if (!type.hasStaticShape()) {
43401178654SLei Zhang     LLVM_DEBUG(llvm::dbgs()
43501178654SLei Zhang                << type << " illegal: dynamic shape unimplemented\n");
436004f29c0SLei Zhang     return nullptr;
43701178654SLei Zhang   }
43801178654SLei Zhang 
439c29fc69eSJakub Kuderski   type = cast<TensorType>(convertIndexElementType(type, options));
440c29fc69eSJakub Kuderski   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
44101178654SLei Zhang   if (!scalarType) {
44201178654SLei Zhang     LLVM_DEBUG(llvm::dbgs()
44301178654SLei Zhang                << type << " illegal: cannot convert non-scalar element type\n");
444004f29c0SLei Zhang     return nullptr;
44501178654SLei Zhang   }
44601178654SLei Zhang 
4470de16fafSRamkumar Ramachandra   std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
4480de16fafSRamkumar Ramachandra   std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
44901178654SLei Zhang   if (!scalarSize || !tensorSize) {
45001178654SLei Zhang     LLVM_DEBUG(llvm::dbgs()
45101178654SLei Zhang                << type << " illegal: cannot deduce element count\n");
452004f29c0SLei Zhang     return nullptr;
45301178654SLei Zhang   }
45401178654SLei Zhang 
4552ad297dbSJakub Kuderski   int64_t arrayElemCount = *tensorSize / *scalarSize;
4562ad297dbSJakub Kuderski   if (arrayElemCount == 0) {
4572ad297dbSJakub Kuderski     LLVM_DEBUG(llvm::dbgs()
4582ad297dbSJakub Kuderski                << type << " illegal: cannot handle zero-element tensors\n");
4592ad297dbSJakub Kuderski     return nullptr;
4602ad297dbSJakub Kuderski   }
4612ad297dbSJakub Kuderski 
4622ad297dbSJakub Kuderski   Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
46301178654SLei Zhang   if (!arrayElemType)
464004f29c0SLei Zhang     return nullptr;
4650de16fafSRamkumar Ramachandra   std::optional<int64_t> arrayElemSize =
4660de16fafSRamkumar Ramachandra       getTypeNumBytes(options, arrayElemType);
46701178654SLei Zhang   if (!arrayElemSize) {
46801178654SLei Zhang     LLVM_DEBUG(llvm::dbgs()
46901178654SLei Zhang                << type << " illegal: cannot deduce converted element size\n");
470004f29c0SLei Zhang     return nullptr;
47101178654SLei Zhang   }
47201178654SLei Zhang 
473bbffece3SLei Zhang   return spirv::ArrayType::get(arrayElemType, arrayElemCount);
47401178654SLei Zhang }
47501178654SLei Zhang 
476c3614358SHanhan Wang static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
477dfaebd3dSLei Zhang                                   const SPIRVConversionOptions &options,
47889b595e1SLei Zhang                                   MemRefType type,
47989b595e1SLei Zhang                                   spirv::StorageClass storageClass) {
480c3614358SHanhan Wang   unsigned numBoolBits = options.boolNumBits;
481c3614358SHanhan Wang   if (numBoolBits != 8) {
482c3614358SHanhan Wang     LLVM_DEBUG(llvm::dbgs()
483c3614358SHanhan Wang                << "using non-8-bit storage for bool types unimplemented");
484c3614358SHanhan Wang     return nullptr;
485c3614358SHanhan Wang   }
4865550c821STres Popp   auto elementType = dyn_cast<spirv::ScalarType>(
4875550c821STres Popp       IntegerType::get(type.getContext(), numBoolBits));
488c3614358SHanhan Wang   if (!elementType)
489c3614358SHanhan Wang     return nullptr;
490c3614358SHanhan Wang   Type arrayElemType =
491c3614358SHanhan Wang       convertScalarType(targetEnv, options, elementType, storageClass);
492c3614358SHanhan Wang   if (!arrayElemType)
493c3614358SHanhan Wang     return nullptr;
4940de16fafSRamkumar Ramachandra   std::optional<int64_t> arrayElemSize =
4950de16fafSRamkumar Ramachandra       getTypeNumBytes(options, arrayElemType);
496c3614358SHanhan Wang   if (!arrayElemSize) {
497c3614358SHanhan Wang     LLVM_DEBUG(llvm::dbgs()
498c3614358SHanhan Wang                << type << " illegal: cannot deduce converted element size\n");
499c3614358SHanhan Wang     return nullptr;
500c3614358SHanhan Wang   }
501c3614358SHanhan Wang 
502d6de6ddeSNirvedh Meshram   if (!type.hasStaticShape()) {
503d6de6ddeSNirvedh Meshram     // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
504d6de6ddeSNirvedh Meshram     // to the element.
5059c3a73a5SStanley Winata     if (targetEnv.allows(spirv::Capability::Kernel))
5069c3a73a5SStanley Winata       return spirv::PointerType::get(arrayElemType, storageClass);
50789b595e1SLei Zhang     int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
508bbffece3SLei Zhang     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
509d6de6ddeSNirvedh Meshram     // For Vulkan we need extra wrapping struct and array to satisfy interface
510d6de6ddeSNirvedh Meshram     // needs.
51189b595e1SLei Zhang     return wrapInStructAndGetPointer(arrayType, storageClass);
5127c4de2e9SHanhan Wang   }
5137c4de2e9SHanhan Wang 
5146ba60390SJakub Kuderski   if (type.getNumElements() == 0) {
5156ba60390SJakub Kuderski     LLVM_DEBUG(llvm::dbgs()
5166ba60390SJakub Kuderski                << type << " illegal: zero-element memrefs are not supported\n");
5176ba60390SJakub Kuderski     return nullptr;
5186ba60390SJakub Kuderski   }
5196ba60390SJakub Kuderski 
520f772dcbbSLei Zhang   int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
521f772dcbbSLei Zhang   int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
52289b595e1SLei Zhang   int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
523bbffece3SLei Zhang   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
524d6de6ddeSNirvedh Meshram   if (targetEnv.allows(spirv::Capability::Kernel))
525d6de6ddeSNirvedh Meshram     return spirv::PointerType::get(arrayType, storageClass);
52689b595e1SLei Zhang   return wrapInStructAndGetPointer(arrayType, storageClass);
527c3614358SHanhan Wang }
528c3614358SHanhan Wang 
529f772dcbbSLei Zhang static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
530f772dcbbSLei Zhang                                      const SPIRVConversionOptions &options,
531f772dcbbSLei Zhang                                      MemRefType type,
532f772dcbbSLei Zhang                                      spirv::StorageClass storageClass) {
533f772dcbbSLei Zhang   IntegerType elementType = cast<IntegerType>(type.getElementType());
534f772dcbbSLei Zhang   Type arrayElemType = convertSubByteIntegerType(options, elementType);
535f772dcbbSLei Zhang   if (!arrayElemType)
536f772dcbbSLei Zhang     return nullptr;
537f772dcbbSLei Zhang   int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
538f772dcbbSLei Zhang 
539f772dcbbSLei Zhang   if (!type.hasStaticShape()) {
540f772dcbbSLei Zhang     // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
541f772dcbbSLei Zhang     // to the element.
542f772dcbbSLei Zhang     if (targetEnv.allows(spirv::Capability::Kernel))
543f772dcbbSLei Zhang       return spirv::PointerType::get(arrayElemType, storageClass);
544f772dcbbSLei Zhang     int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
545f772dcbbSLei Zhang     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
546f772dcbbSLei Zhang     // For Vulkan we need extra wrapping struct and array to satisfy interface
547f772dcbbSLei Zhang     // needs.
548f772dcbbSLei Zhang     return wrapInStructAndGetPointer(arrayType, storageClass);
549f772dcbbSLei Zhang   }
550f772dcbbSLei Zhang 
5516ba60390SJakub Kuderski   if (type.getNumElements() == 0) {
5526ba60390SJakub Kuderski     LLVM_DEBUG(llvm::dbgs()
5536ba60390SJakub Kuderski                << type << " illegal: zero-element memrefs are not supported\n");
5546ba60390SJakub Kuderski     return nullptr;
5556ba60390SJakub Kuderski   }
5566ba60390SJakub Kuderski 
557f772dcbbSLei Zhang   int64_t memrefSize =
558f772dcbbSLei Zhang       llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
559f772dcbbSLei Zhang   int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
560f772dcbbSLei Zhang   int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
561f772dcbbSLei Zhang   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
562f772dcbbSLei Zhang   if (targetEnv.allows(spirv::Capability::Kernel))
563f772dcbbSLei Zhang     return spirv::PointerType::get(arrayType, storageClass);
564f772dcbbSLei Zhang   return wrapInStructAndGetPointer(arrayType, storageClass);
565f772dcbbSLei Zhang }
566f772dcbbSLei Zhang 
567004f29c0SLei Zhang static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
568dfaebd3dSLei Zhang                               const SPIRVConversionOptions &options,
56901178654SLei Zhang                               MemRefType type) {
5705550c821STres Popp   auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
57189b595e1SLei Zhang   if (!attr) {
57289b595e1SLei Zhang     LLVM_DEBUG(
57389b595e1SLei Zhang         llvm::dbgs()
57489b595e1SLei Zhang         << type
57589b595e1SLei Zhang         << " illegal: expected memory space to be a SPIR-V storage class "
57689b595e1SLei Zhang            "attribute; please use MemorySpaceToStorageClassConverter to map "
57789b595e1SLei Zhang            "numeric memory spaces beforehand\n");
57889b595e1SLei Zhang     return nullptr;
57989b595e1SLei Zhang   }
58089b595e1SLei Zhang   spirv::StorageClass storageClass = attr.getValue();
58189b595e1SLei Zhang 
5825550c821STres Popp   if (isa<IntegerType>(type.getElementType())) {
583f772dcbbSLei Zhang     if (type.getElementTypeBitWidth() == 1)
58489b595e1SLei Zhang       return convertBoolMemrefType(targetEnv, options, type, storageClass);
585f772dcbbSLei Zhang     if (type.getElementTypeBitWidth() < 8)
586f772dcbbSLei Zhang       return convertSubByteMemrefType(targetEnv, options, type, storageClass);
58701178654SLei Zhang   }
58801178654SLei Zhang 
589004f29c0SLei Zhang   Type arrayElemType;
59001178654SLei Zhang   Type elementType = type.getElementType();
5915550c821STres Popp   if (auto vecType = dyn_cast<VectorType>(elementType)) {
5925299843cSLei Zhang     arrayElemType =
5935299843cSLei Zhang         convertVectorType(targetEnv, options, vecType, storageClass);
5945550c821STres Popp   } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
59597f3bb73SLei Zhang     arrayElemType =
59697f3bb73SLei Zhang         convertComplexType(targetEnv, options, complexType, storageClass);
5975550c821STres Popp   } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
5985299843cSLei Zhang     arrayElemType =
5995299843cSLei Zhang         convertScalarType(targetEnv, options, scalarType, storageClass);
6005550c821STres Popp   } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
6015550c821STres Popp     type = cast<MemRefType>(convertIndexElementType(type, options));
602c29fc69eSJakub Kuderski     arrayElemType = type.getElementType();
60301178654SLei Zhang   } else {
60401178654SLei Zhang     LLVM_DEBUG(
60501178654SLei Zhang         llvm::dbgs()
60601178654SLei Zhang         << type
60701178654SLei Zhang         << " unhandled: can only convert scalar or vector element type\n");
608004f29c0SLei Zhang     return nullptr;
60901178654SLei Zhang   }
61001178654SLei Zhang   if (!arrayElemType)
611004f29c0SLei Zhang     return nullptr;
61201178654SLei Zhang 
6130de16fafSRamkumar Ramachandra   std::optional<int64_t> arrayElemSize =
6140de16fafSRamkumar Ramachandra       getTypeNumBytes(options, arrayElemType);
61523b8264bSLei Zhang   if (!arrayElemSize) {
61623b8264bSLei Zhang     LLVM_DEBUG(llvm::dbgs()
61723b8264bSLei Zhang                << type << " illegal: cannot deduce converted element size\n");
61823b8264bSLei Zhang     return nullptr;
61923b8264bSLei Zhang   }
62023b8264bSLei Zhang 
621d6de6ddeSNirvedh Meshram   if (!type.hasStaticShape()) {
622d6de6ddeSNirvedh Meshram     // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
623d6de6ddeSNirvedh Meshram     // to the element.
6249c3a73a5SStanley Winata     if (targetEnv.allows(spirv::Capability::Kernel))
6259c3a73a5SStanley Winata       return spirv::PointerType::get(arrayElemType, storageClass);
62689b595e1SLei Zhang     int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
627bbffece3SLei Zhang     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
628d6de6ddeSNirvedh Meshram     // For Vulkan we need extra wrapping struct and array to satisfy interface
629d6de6ddeSNirvedh Meshram     // needs.
63089b595e1SLei Zhang     return wrapInStructAndGetPointer(arrayType, storageClass);
63101178654SLei Zhang   }
63201178654SLei Zhang 
6330de16fafSRamkumar Ramachandra   std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
63401178654SLei Zhang   if (!memrefSize) {
63501178654SLei Zhang     LLVM_DEBUG(llvm::dbgs()
63601178654SLei Zhang                << type << " illegal: cannot deduce element count\n");
637004f29c0SLei Zhang     return nullptr;
63801178654SLei Zhang   }
63901178654SLei Zhang 
640370a7eaeSJakub Kuderski   if (*memrefSize == 0) {
641370a7eaeSJakub Kuderski     LLVM_DEBUG(llvm::dbgs()
642370a7eaeSJakub Kuderski                << type << " illegal: zero-element memrefs are not supported\n");
643370a7eaeSJakub Kuderski     return nullptr;
644370a7eaeSJakub Kuderski   }
645370a7eaeSJakub Kuderski 
646f772dcbbSLei Zhang   int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
64789b595e1SLei Zhang   int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
648bbffece3SLei Zhang   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
649d6de6ddeSNirvedh Meshram   if (targetEnv.allows(spirv::Capability::Kernel))
650d6de6ddeSNirvedh Meshram     return spirv::PointerType::get(arrayType, storageClass);
65189b595e1SLei Zhang   return wrapInStructAndGetPointer(arrayType, storageClass);
65201178654SLei Zhang }
65301178654SLei Zhang 
65418ad80eaSLei Zhang //===----------------------------------------------------------------------===//
65518ad80eaSLei Zhang // Type casting materialization
65618ad80eaSLei Zhang //===----------------------------------------------------------------------===//
65718ad80eaSLei Zhang 
65818ad80eaSLei Zhang /// Converts the given `inputs` to the original source `type` considering the
65918ad80eaSLei Zhang /// `targetEnv`'s capabilities.
66018ad80eaSLei Zhang ///
66118ad80eaSLei Zhang /// This function is meant to be used for source materialization in type
66218ad80eaSLei Zhang /// converters. When the type converter needs to materialize a cast op back
66318ad80eaSLei Zhang /// to some original source type, we need to check whether the original source
66418ad80eaSLei Zhang /// type is supported in the target environment. If so, we can insert legal
66518ad80eaSLei Zhang /// SPIR-V cast ops accordingly.
66618ad80eaSLei Zhang ///
66718ad80eaSLei Zhang /// Note that in SPIR-V the capabilities for storage and compute are separate.
66818ad80eaSLei Zhang /// This function is meant to handle the **compute** side; so it does not
66918ad80eaSLei Zhang /// involve storage classes in its logic. The storage side is expected to be
67018ad80eaSLei Zhang /// handled by MemRef conversion logic.
671f18c3e4eSMatthias Springer static Value castToSourceType(const spirv::TargetEnv &targetEnv,
672f18c3e4eSMatthias Springer                               OpBuilder &builder, Type type, ValueRange inputs,
673f18c3e4eSMatthias Springer                               Location loc) {
67418ad80eaSLei Zhang   // We can only cast one value in SPIR-V.
67518ad80eaSLei Zhang   if (inputs.size() != 1) {
67618ad80eaSLei Zhang     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
67718ad80eaSLei Zhang     return castOp.getResult(0);
67818ad80eaSLei Zhang   }
67918ad80eaSLei Zhang   Value input = inputs.front();
68018ad80eaSLei Zhang 
68118ad80eaSLei Zhang   // Only support integer types for now. Floating point types to be implemented.
68218ad80eaSLei Zhang   if (!isa<IntegerType>(type)) {
68318ad80eaSLei Zhang     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
68418ad80eaSLei Zhang     return castOp.getResult(0);
68518ad80eaSLei Zhang   }
68618ad80eaSLei Zhang   auto inputType = cast<IntegerType>(input.getType());
68718ad80eaSLei Zhang 
68818ad80eaSLei Zhang   auto scalarType = dyn_cast<spirv::ScalarType>(type);
68918ad80eaSLei Zhang   if (!scalarType) {
69018ad80eaSLei Zhang     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
69118ad80eaSLei Zhang     return castOp.getResult(0);
69218ad80eaSLei Zhang   }
69318ad80eaSLei Zhang 
69418ad80eaSLei Zhang   // Only support source type with a smaller bitwidth. This would mean we are
69518ad80eaSLei Zhang   // truncating to go back so we don't need to worry about the signedness.
69618ad80eaSLei Zhang   // For extension, we cannot have enough signal here to decide which op to use.
69718ad80eaSLei Zhang   if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
69818ad80eaSLei Zhang     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
69918ad80eaSLei Zhang     return castOp.getResult(0);
70018ad80eaSLei Zhang   }
70118ad80eaSLei Zhang 
70218ad80eaSLei Zhang   // Boolean values would need to use different ops than normal integer values.
70318ad80eaSLei Zhang   if (type.isInteger(1)) {
70418ad80eaSLei Zhang     Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
70518ad80eaSLei Zhang     return builder.create<spirv::IEqualOp>(loc, input, one);
70618ad80eaSLei Zhang   }
70718ad80eaSLei Zhang 
70818ad80eaSLei Zhang   // Check that the source integer type is supported by the environment.
70918ad80eaSLei Zhang   SmallVector<ArrayRef<spirv::Extension>, 1> exts;
71018ad80eaSLei Zhang   SmallVector<ArrayRef<spirv::Capability>, 2> caps;
71118ad80eaSLei Zhang   scalarType.getExtensions(exts);
71218ad80eaSLei Zhang   scalarType.getCapabilities(caps);
71318ad80eaSLei Zhang   if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
71418ad80eaSLei Zhang       failed(checkExtensionRequirements(type, targetEnv, exts))) {
71518ad80eaSLei Zhang     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
71618ad80eaSLei Zhang     return castOp.getResult(0);
71718ad80eaSLei Zhang   }
71818ad80eaSLei Zhang 
71918ad80eaSLei Zhang   // We've already made sure this is truncating previously, so we don't need to
72018ad80eaSLei Zhang   // care about signedness here. Still try to use a corresponding op for better
72118ad80eaSLei Zhang   // consistency though.
72218ad80eaSLei Zhang   if (type.isSignedInteger()) {
72318ad80eaSLei Zhang     return builder.create<spirv::SConvertOp>(loc, type, input);
72418ad80eaSLei Zhang   }
72518ad80eaSLei Zhang   return builder.create<spirv::UConvertOp>(loc, type, input);
72618ad80eaSLei Zhang }
72718ad80eaSLei Zhang 
72818ad80eaSLei Zhang //===----------------------------------------------------------------------===//
7299527d77aSAngel Zhang // Builtin Variables
73018ad80eaSLei Zhang //===----------------------------------------------------------------------===//
73118ad80eaSLei Zhang 
7329527d77aSAngel Zhang static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
7339527d77aSAngel Zhang                                                   spirv::BuiltIn builtin) {
7349527d77aSAngel Zhang   // Look through all global variables in the given `body` block and check if
7359527d77aSAngel Zhang   // there is a spirv.GlobalVariable that has the same `builtin` attribute.
7369527d77aSAngel Zhang   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
7379527d77aSAngel Zhang     if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
7389527d77aSAngel Zhang             spirv::SPIRVDialect::getAttributeName(
7399527d77aSAngel Zhang                 spirv::Decoration::BuiltIn))) {
7409527d77aSAngel Zhang       auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
7419527d77aSAngel Zhang       if (varBuiltIn && *varBuiltIn == builtin) {
7429527d77aSAngel Zhang         return varOp;
7439527d77aSAngel Zhang       }
7449527d77aSAngel Zhang     }
7459527d77aSAngel Zhang   }
7469527d77aSAngel Zhang   return nullptr;
7479527d77aSAngel Zhang }
74801178654SLei Zhang 
7499527d77aSAngel Zhang /// Gets name of global variable for a builtin.
7509527d77aSAngel Zhang std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
7519527d77aSAngel Zhang                               StringRef suffix) {
7529527d77aSAngel Zhang   return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
7539527d77aSAngel Zhang }
75401178654SLei Zhang 
7559527d77aSAngel Zhang /// Gets or inserts a global variable for a builtin within `body` block.
7569527d77aSAngel Zhang static spirv::GlobalVariableOp
7579527d77aSAngel Zhang getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
7589527d77aSAngel Zhang                            Type integerType, OpBuilder &builder,
7599527d77aSAngel Zhang                            StringRef prefix, StringRef suffix) {
7609527d77aSAngel Zhang   if (auto varOp = getBuiltinVariable(body, builtin))
7619527d77aSAngel Zhang     return varOp;
76201178654SLei Zhang 
7639527d77aSAngel Zhang   OpBuilder::InsertionGuard guard(builder);
7649527d77aSAngel Zhang   builder.setInsertionPointToStart(&body);
76501178654SLei Zhang 
7669527d77aSAngel Zhang   spirv::GlobalVariableOp newVarOp;
7679527d77aSAngel Zhang   switch (builtin) {
7689527d77aSAngel Zhang   case spirv::BuiltIn::NumWorkgroups:
7699527d77aSAngel Zhang   case spirv::BuiltIn::WorkgroupSize:
7709527d77aSAngel Zhang   case spirv::BuiltIn::WorkgroupId:
7719527d77aSAngel Zhang   case spirv::BuiltIn::LocalInvocationId:
7729527d77aSAngel Zhang   case spirv::BuiltIn::GlobalInvocationId: {
7739527d77aSAngel Zhang     auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
7749527d77aSAngel Zhang                                            spirv::StorageClass::Input);
7759527d77aSAngel Zhang     std::string name = getBuiltinVarName(builtin, prefix, suffix);
7769527d77aSAngel Zhang     newVarOp =
7779527d77aSAngel Zhang         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
7789527d77aSAngel Zhang     break;
7799527d77aSAngel Zhang   }
7809527d77aSAngel Zhang   case spirv::BuiltIn::SubgroupId:
7819527d77aSAngel Zhang   case spirv::BuiltIn::NumSubgroups:
7829527d77aSAngel Zhang   case spirv::BuiltIn::SubgroupSize: {
7839527d77aSAngel Zhang     auto ptrType =
7849527d77aSAngel Zhang         spirv::PointerType::get(integerType, spirv::StorageClass::Input);
7859527d77aSAngel Zhang     std::string name = getBuiltinVarName(builtin, prefix, suffix);
7869527d77aSAngel Zhang     newVarOp =
7879527d77aSAngel Zhang         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
7889527d77aSAngel Zhang     break;
7899527d77aSAngel Zhang   }
7909527d77aSAngel Zhang   default:
7919527d77aSAngel Zhang     emitError(loc, "unimplemented builtin variable generation for ")
7929527d77aSAngel Zhang         << stringifyBuiltIn(builtin);
7939527d77aSAngel Zhang   }
7949527d77aSAngel Zhang   return newVarOp;
7959527d77aSAngel Zhang }
79601178654SLei Zhang 
7979527d77aSAngel Zhang //===----------------------------------------------------------------------===//
7989527d77aSAngel Zhang // Push constant storage
7999527d77aSAngel Zhang //===----------------------------------------------------------------------===//
80097f3bb73SLei Zhang 
8019527d77aSAngel Zhang /// Returns the pointer type for the push constant storage containing
8029527d77aSAngel Zhang /// `elementCount` 32-bit integer values.
8039527d77aSAngel Zhang static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
8049527d77aSAngel Zhang                                                      Builder &builder,
8059527d77aSAngel Zhang                                                      Type indexType) {
8069527d77aSAngel Zhang   auto arrayType = spirv::ArrayType::get(indexType, elementCount,
8079527d77aSAngel Zhang                                          /*stride=*/4);
8089527d77aSAngel Zhang   auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
8099527d77aSAngel Zhang   return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
8109527d77aSAngel Zhang }
81101178654SLei Zhang 
8129527d77aSAngel Zhang /// Returns the push constant varible containing `elementCount` 32-bit integer
8139527d77aSAngel Zhang /// values in `body`. Returns null op if such an op does not exit.
8149527d77aSAngel Zhang static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
8159527d77aSAngel Zhang                                                        unsigned elementCount) {
8169527d77aSAngel Zhang   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
8179527d77aSAngel Zhang     auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
8189527d77aSAngel Zhang     if (!ptrType)
8199527d77aSAngel Zhang       continue;
82001178654SLei Zhang 
8219527d77aSAngel Zhang     // Note that Vulkan requires "There must be no more than one push constant
8229527d77aSAngel Zhang     // block statically used per shader entry point." So we should always reuse
8239527d77aSAngel Zhang     // the existing one.
8249527d77aSAngel Zhang     if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
8259527d77aSAngel Zhang       auto numElements = cast<spirv::ArrayType>(
8269527d77aSAngel Zhang                              cast<spirv::StructType>(ptrType.getPointeeType())
8279527d77aSAngel Zhang                                  .getElementType(0))
8289527d77aSAngel Zhang                              .getNumElements();
8299527d77aSAngel Zhang       if (numElements == elementCount)
8309527d77aSAngel Zhang         return varOp;
8319527d77aSAngel Zhang     }
8329527d77aSAngel Zhang   }
8339527d77aSAngel Zhang   return nullptr;
8349527d77aSAngel Zhang }
83518ad80eaSLei Zhang 
8369527d77aSAngel Zhang /// Gets or inserts a global variable for push constant storage containing
8379527d77aSAngel Zhang /// `elementCount` 32-bit integer values in `block`.
8389527d77aSAngel Zhang static spirv::GlobalVariableOp
8399527d77aSAngel Zhang getOrInsertPushConstantVariable(Location loc, Block &block,
8409527d77aSAngel Zhang                                 unsigned elementCount, OpBuilder &b,
8419527d77aSAngel Zhang                                 Type indexType) {
8429527d77aSAngel Zhang   if (auto varOp = getPushConstantVariable(block, elementCount))
8439527d77aSAngel Zhang     return varOp;
8449527d77aSAngel Zhang 
8459527d77aSAngel Zhang   auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
8469527d77aSAngel Zhang   auto type = getPushConstantStorageType(elementCount, builder, indexType);
8479527d77aSAngel Zhang   const char *name = "__push_constant_var__";
8489527d77aSAngel Zhang   return builder.create<spirv::GlobalVariableOp>(loc, type, name,
8499527d77aSAngel Zhang                                                  /*initializer=*/nullptr);
85001178654SLei Zhang }
85101178654SLei Zhang 
85201178654SLei Zhang //===----------------------------------------------------------------------===//
85358ceae95SRiver Riddle // func::FuncOp Conversion Patterns
85401178654SLei Zhang //===----------------------------------------------------------------------===//
85501178654SLei Zhang 
85601178654SLei Zhang /// A pattern for rewriting function signature to convert arguments of functions
85701178654SLei Zhang /// to be of valid SPIR-V types.
8589527d77aSAngel Zhang struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
85958ceae95SRiver Riddle   using OpConversionPattern<func::FuncOp>::OpConversionPattern;
86001178654SLei Zhang 
86101178654SLei Zhang   LogicalResult
86258ceae95SRiver Riddle   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
8639527d77aSAngel Zhang                   ConversionPatternRewriter &rewriter) const override {
8649527d77aSAngel Zhang     FunctionType fnType = funcOp.getFunctionType();
86542980a78SLei Zhang     if (fnType.getNumResults() > 1)
86601178654SLei Zhang       return failure();
86701178654SLei Zhang 
8689527d77aSAngel Zhang     TypeConverter::SignatureConversion signatureConverter(
8699527d77aSAngel Zhang         fnType.getNumInputs());
870e4853be2SMehdi Amini     for (const auto &argType : enumerate(fnType.getInputs())) {
8717c3ae48fSLei Zhang       auto convertedType = getTypeConverter()->convertType(argType.value());
87201178654SLei Zhang       if (!convertedType)
87301178654SLei Zhang         return failure();
87401178654SLei Zhang       signatureConverter.addInputs(argType.index(), convertedType);
87501178654SLei Zhang     }
87601178654SLei Zhang 
87742980a78SLei Zhang     Type resultType;
8785299843cSLei Zhang     if (fnType.getNumResults() == 1) {
8797c3ae48fSLei Zhang       resultType = getTypeConverter()->convertType(fnType.getResult(0));
8805299843cSLei Zhang       if (!resultType)
8815299843cSLei Zhang         return failure();
8825299843cSLei Zhang     }
88342980a78SLei Zhang 
8845ab6ef75SJakub Kuderski     // Create the converted spirv.func op.
88501178654SLei Zhang     auto newFuncOp = rewriter.create<spirv::FuncOp>(
88601178654SLei Zhang         funcOp.getLoc(), funcOp.getName(),
88701178654SLei Zhang         rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
88842980a78SLei Zhang                                  resultType ? TypeRange(resultType)
88942980a78SLei Zhang                                             : TypeRange()));
89001178654SLei Zhang 
89101178654SLei Zhang     // Copy over all attributes other than the function name and type.
89256774bddSMarius Brehler     for (const auto &namedAttr : funcOp->getAttrs()) {
89353406427SJeff Niu       if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
8940c7890c8SRiver Riddle           namedAttr.getName() != SymbolTable::getSymbolAttrName())
8950c7890c8SRiver Riddle         newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
89601178654SLei Zhang     }
89701178654SLei Zhang 
89801178654SLei Zhang     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
89901178654SLei Zhang                                 newFuncOp.end());
9007c3ae48fSLei Zhang     if (failed(rewriter.convertRegionTypes(
9017c3ae48fSLei Zhang             &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
90201178654SLei Zhang       return failure();
90301178654SLei Zhang     rewriter.eraseOp(funcOp);
90401178654SLei Zhang     return success();
90501178654SLei Zhang   }
9069527d77aSAngel Zhang };
90701178654SLei Zhang 
9086867e49fSAngel Zhang /// A pattern for rewriting function signature to convert vector arguments of
9096867e49fSAngel Zhang /// functions to be of valid types
9106867e49fSAngel Zhang struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
9116867e49fSAngel Zhang   using OpRewritePattern::OpRewritePattern;
9126867e49fSAngel Zhang 
9136867e49fSAngel Zhang   LogicalResult matchAndRewrite(func::FuncOp funcOp,
9146867e49fSAngel Zhang                                 PatternRewriter &rewriter) const override {
9156867e49fSAngel Zhang     FunctionType fnType = funcOp.getFunctionType();
9166867e49fSAngel Zhang 
9176867e49fSAngel Zhang     // TODO: Handle declarations.
9186867e49fSAngel Zhang     if (funcOp.isDeclaration()) {
9196867e49fSAngel Zhang       LLVM_DEBUG(llvm::dbgs()
9206867e49fSAngel Zhang                  << fnType << " illegal: declarations are unsupported\n");
9216867e49fSAngel Zhang       return failure();
9226867e49fSAngel Zhang     }
9236867e49fSAngel Zhang 
9246867e49fSAngel Zhang     // Create a new func op with the original type and copy the function body.
9256867e49fSAngel Zhang     auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
9266867e49fSAngel Zhang                                                    funcOp.getName(), fnType);
9276867e49fSAngel Zhang     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
9286867e49fSAngel Zhang                                 newFuncOp.end());
9296867e49fSAngel Zhang 
9306867e49fSAngel Zhang     Location loc = newFuncOp.getBody().getLoc();
9316867e49fSAngel Zhang 
9326867e49fSAngel Zhang     Block &entryBlock = newFuncOp.getBlocks().front();
9336867e49fSAngel Zhang     OpBuilder::InsertionGuard guard(rewriter);
9346867e49fSAngel Zhang     rewriter.setInsertionPointToStart(&entryBlock);
9356867e49fSAngel Zhang 
9366867e49fSAngel Zhang     OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
9376867e49fSAngel Zhang 
9386867e49fSAngel Zhang     // For arguments that are of illegal types and require unrolling.
9396867e49fSAngel Zhang     // `unrolledInputNums` stores the indices of arguments that result from
9406867e49fSAngel Zhang     // unrolling in the new function signature. `newInputNo` is a counter.
9416867e49fSAngel Zhang     SmallVector<size_t> unrolledInputNums;
9426867e49fSAngel Zhang     size_t newInputNo = 0;
9436867e49fSAngel Zhang 
9446867e49fSAngel Zhang     // For arguments that are of legal types and do not require unrolling.
9456867e49fSAngel Zhang     // `tmpOps` stores a mapping from temporary operations that serve as
9466867e49fSAngel Zhang     // placeholders for new arguments that will be added later. These operations
9476867e49fSAngel Zhang     // will be erased once the entry block's argument list is updated.
9486867e49fSAngel Zhang     llvm::SmallDenseMap<Operation *, size_t> tmpOps;
9496867e49fSAngel Zhang 
9506867e49fSAngel Zhang     // This counts the number of new operations created.
9516867e49fSAngel Zhang     size_t newOpCount = 0;
9526867e49fSAngel Zhang 
9536867e49fSAngel Zhang     // Enumerate through the arguments.
9546867e49fSAngel Zhang     for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
9556867e49fSAngel Zhang       // Check whether the argument is of vector type.
9566867e49fSAngel Zhang       auto origVecType = dyn_cast<VectorType>(origType);
9576867e49fSAngel Zhang       if (!origVecType) {
9586867e49fSAngel Zhang         // We need a placeholder for the old argument that will be erased later.
9596867e49fSAngel Zhang         Value result = rewriter.create<arith::ConstantOp>(
9606867e49fSAngel Zhang             loc, origType, rewriter.getZeroAttr(origType));
9616867e49fSAngel Zhang         rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
9626867e49fSAngel Zhang         tmpOps.insert({result.getDefiningOp(), newInputNo});
9636867e49fSAngel Zhang         oneToNTypeMapping.addInputs(origInputNo, origType);
9646867e49fSAngel Zhang         ++newInputNo;
9656867e49fSAngel Zhang         ++newOpCount;
9666867e49fSAngel Zhang         continue;
9676867e49fSAngel Zhang       }
9686867e49fSAngel Zhang       // Check whether the vector needs unrolling.
9696867e49fSAngel Zhang       auto targetShape = getTargetShape(origVecType);
9706867e49fSAngel Zhang       if (!targetShape) {
9716867e49fSAngel Zhang         // We need a placeholder for the old argument that will be erased later.
9726867e49fSAngel Zhang         Value result = rewriter.create<arith::ConstantOp>(
9736867e49fSAngel Zhang             loc, origType, rewriter.getZeroAttr(origType));
9746867e49fSAngel Zhang         rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
9756867e49fSAngel Zhang         tmpOps.insert({result.getDefiningOp(), newInputNo});
9766867e49fSAngel Zhang         oneToNTypeMapping.addInputs(origInputNo, origType);
9776867e49fSAngel Zhang         ++newInputNo;
9786867e49fSAngel Zhang         ++newOpCount;
9796867e49fSAngel Zhang         continue;
9806867e49fSAngel Zhang       }
9816867e49fSAngel Zhang       VectorType unrolledType =
9826867e49fSAngel Zhang           VectorType::get(*targetShape, origVecType.getElementType());
9836867e49fSAngel Zhang       auto originalShape =
9846867e49fSAngel Zhang           llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
9856867e49fSAngel Zhang 
9866867e49fSAngel Zhang       // Prepare the result vector.
9876867e49fSAngel Zhang       Value result = rewriter.create<arith::ConstantOp>(
9886867e49fSAngel Zhang           loc, origVecType, rewriter.getZeroAttr(origVecType));
9896867e49fSAngel Zhang       ++newOpCount;
9906867e49fSAngel Zhang       // Prepare the placeholder for the new arguments that will be added later.
9916867e49fSAngel Zhang       Value dummy = rewriter.create<arith::ConstantOp>(
9926867e49fSAngel Zhang           loc, unrolledType, rewriter.getZeroAttr(unrolledType));
9936867e49fSAngel Zhang       ++newOpCount;
9946867e49fSAngel Zhang 
9956867e49fSAngel Zhang       // Create the `vector.insert_strided_slice` ops.
9966867e49fSAngel Zhang       SmallVector<int64_t> strides(targetShape->size(), 1);
9976867e49fSAngel Zhang       SmallVector<Type> newTypes;
9986867e49fSAngel Zhang       for (SmallVector<int64_t> offsets :
9996867e49fSAngel Zhang            StaticTileOffsetRange(originalShape, *targetShape)) {
10006867e49fSAngel Zhang         result = rewriter.create<vector::InsertStridedSliceOp>(
10016867e49fSAngel Zhang             loc, dummy, result, offsets, strides);
10026867e49fSAngel Zhang         newTypes.push_back(unrolledType);
10036867e49fSAngel Zhang         unrolledInputNums.push_back(newInputNo);
10046867e49fSAngel Zhang         ++newInputNo;
10056867e49fSAngel Zhang         ++newOpCount;
10066867e49fSAngel Zhang       }
10076867e49fSAngel Zhang       rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
10086867e49fSAngel Zhang       oneToNTypeMapping.addInputs(origInputNo, newTypes);
10096867e49fSAngel Zhang     }
10106867e49fSAngel Zhang 
10116867e49fSAngel Zhang     // Change the function signature.
10126867e49fSAngel Zhang     auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
10136867e49fSAngel Zhang     auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
10146867e49fSAngel Zhang     rewriter.modifyOpInPlace(newFuncOp,
10156867e49fSAngel Zhang                              [&] { newFuncOp.setFunctionType(newFnType); });
10166867e49fSAngel Zhang 
10176867e49fSAngel Zhang     // Update the arguments in the entry block.
10186867e49fSAngel Zhang     entryBlock.eraseArguments(0, fnType.getNumInputs());
10196867e49fSAngel Zhang     SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
10206867e49fSAngel Zhang     entryBlock.addArguments(convertedTypes, locs);
10216867e49fSAngel Zhang 
10226867e49fSAngel Zhang     // Replace the placeholder values with the new arguments. We assume there is
10236867e49fSAngel Zhang     // only one block for now.
10246867e49fSAngel Zhang     size_t unrolledInputIdx = 0;
10256867e49fSAngel Zhang     for (auto [count, op] : enumerate(entryBlock.getOperations())) {
10266867e49fSAngel Zhang       // We first look for operands that are placeholders for initially legal
10276867e49fSAngel Zhang       // arguments.
10286867e49fSAngel Zhang       Operation &curOp = op;
10296867e49fSAngel Zhang       for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
10306867e49fSAngel Zhang         Operation *operandOp = operandVal.getDefiningOp();
10316867e49fSAngel Zhang         if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
10326867e49fSAngel Zhang           size_t idx = operandIdx;
10336867e49fSAngel Zhang           rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
10346867e49fSAngel Zhang             curOp.setOperand(idx, newFuncOp.getArgument(it->second));
10356867e49fSAngel Zhang           });
10366867e49fSAngel Zhang         }
10376867e49fSAngel Zhang       }
10386867e49fSAngel Zhang       // Since all newly created operations are in the beginning, reaching the
10396867e49fSAngel Zhang       // end of them means that any later `vector.insert_strided_slice` should
10406867e49fSAngel Zhang       // not be touched.
10416867e49fSAngel Zhang       if (count >= newOpCount)
10426867e49fSAngel Zhang         continue;
10436867e49fSAngel Zhang       if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
10446867e49fSAngel Zhang         size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
10456867e49fSAngel Zhang         rewriter.modifyOpInPlace(&curOp, [&] {
10466867e49fSAngel Zhang           curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
10476867e49fSAngel Zhang         });
10486867e49fSAngel Zhang         ++unrolledInputIdx;
10496867e49fSAngel Zhang       }
10506867e49fSAngel Zhang     }
10516867e49fSAngel Zhang 
10526867e49fSAngel Zhang     // Erase the original funcOp. The `tmpOps` do not need to be erased since
10536867e49fSAngel Zhang     // they have no uses and will be handled by dead-code elimination.
10546867e49fSAngel Zhang     rewriter.eraseOp(funcOp);
10556867e49fSAngel Zhang     return success();
10566867e49fSAngel Zhang   }
10576867e49fSAngel Zhang };
10586867e49fSAngel Zhang 
10596867e49fSAngel Zhang //===----------------------------------------------------------------------===//
10606867e49fSAngel Zhang // func::ReturnOp Conversion Patterns
10616867e49fSAngel Zhang //===----------------------------------------------------------------------===//
10626867e49fSAngel Zhang 
10636867e49fSAngel Zhang /// A pattern for rewriting function signature and the return op to convert
10646867e49fSAngel Zhang /// vectors to be of valid types.
10656867e49fSAngel Zhang struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
10666867e49fSAngel Zhang   using OpRewritePattern::OpRewritePattern;
10676867e49fSAngel Zhang 
10686867e49fSAngel Zhang   LogicalResult matchAndRewrite(func::ReturnOp returnOp,
10696867e49fSAngel Zhang                                 PatternRewriter &rewriter) const override {
10706867e49fSAngel Zhang     // Check whether the parent funcOp is valid.
10716867e49fSAngel Zhang     auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
10726867e49fSAngel Zhang     if (!funcOp)
10736867e49fSAngel Zhang       return failure();
10746867e49fSAngel Zhang 
10756867e49fSAngel Zhang     FunctionType fnType = funcOp.getFunctionType();
10766867e49fSAngel Zhang     OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
10776867e49fSAngel Zhang     Location loc = returnOp.getLoc();
10786867e49fSAngel Zhang 
10796867e49fSAngel Zhang     // For the new return op.
10806867e49fSAngel Zhang     SmallVector<Value> newOperands;
10816867e49fSAngel Zhang 
10826867e49fSAngel Zhang     // Enumerate through the results.
10836867e49fSAngel Zhang     for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
10846867e49fSAngel Zhang       // Check whether the argument is of vector type.
10856867e49fSAngel Zhang       auto origVecType = dyn_cast<VectorType>(origType);
10866867e49fSAngel Zhang       if (!origVecType) {
10876867e49fSAngel Zhang         oneToNTypeMapping.addInputs(origResultNo, origType);
10886867e49fSAngel Zhang         newOperands.push_back(returnOp.getOperand(origResultNo));
10896867e49fSAngel Zhang         continue;
10906867e49fSAngel Zhang       }
10916867e49fSAngel Zhang       // Check whether the vector needs unrolling.
10926867e49fSAngel Zhang       auto targetShape = getTargetShape(origVecType);
10936867e49fSAngel Zhang       if (!targetShape) {
10946867e49fSAngel Zhang         // The original argument can be used.
10956867e49fSAngel Zhang         oneToNTypeMapping.addInputs(origResultNo, origType);
10966867e49fSAngel Zhang         newOperands.push_back(returnOp.getOperand(origResultNo));
10976867e49fSAngel Zhang         continue;
10986867e49fSAngel Zhang       }
10996867e49fSAngel Zhang       VectorType unrolledType =
11006867e49fSAngel Zhang           VectorType::get(*targetShape, origVecType.getElementType());
11016867e49fSAngel Zhang 
11026867e49fSAngel Zhang       // Create `vector.extract_strided_slice` ops to form legal vectors from
11036867e49fSAngel Zhang       // the original operand of illegal type.
11046867e49fSAngel Zhang       auto originalShape =
11056867e49fSAngel Zhang           llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1106f83950abSAngel Zhang       SmallVector<int64_t> strides(originalShape.size(), 1);
1107f83950abSAngel Zhang       SmallVector<int64_t> extractShape(originalShape.size(), 1);
1108f83950abSAngel Zhang       extractShape.back() = targetShape->back();
11096867e49fSAngel Zhang       SmallVector<Type> newTypes;
11106867e49fSAngel Zhang       Value returnValue = returnOp.getOperand(origResultNo);
11116867e49fSAngel Zhang       for (SmallVector<int64_t> offsets :
11126867e49fSAngel Zhang            StaticTileOffsetRange(originalShape, *targetShape)) {
11136867e49fSAngel Zhang         Value result = rewriter.create<vector::ExtractStridedSliceOp>(
1114f83950abSAngel Zhang             loc, returnValue, offsets, extractShape, strides);
1115f83950abSAngel Zhang         if (originalShape.size() > 1) {
1116f83950abSAngel Zhang           SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
1117f83950abSAngel Zhang           result =
1118f83950abSAngel Zhang               rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
1119f83950abSAngel Zhang         }
11206867e49fSAngel Zhang         newOperands.push_back(result);
11216867e49fSAngel Zhang         newTypes.push_back(unrolledType);
11226867e49fSAngel Zhang       }
11236867e49fSAngel Zhang       oneToNTypeMapping.addInputs(origResultNo, newTypes);
11246867e49fSAngel Zhang     }
11256867e49fSAngel Zhang 
11266867e49fSAngel Zhang     // Change the function signature.
11276867e49fSAngel Zhang     auto newFnType =
11286867e49fSAngel Zhang         FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
11296867e49fSAngel Zhang                           TypeRange(oneToNTypeMapping.getConvertedTypes()));
11306867e49fSAngel Zhang     rewriter.modifyOpInPlace(funcOp,
11316867e49fSAngel Zhang                              [&] { funcOp.setFunctionType(newFnType); });
11326867e49fSAngel Zhang 
11336867e49fSAngel Zhang     // Replace the return op using the new operands. This will automatically
11346867e49fSAngel Zhang     // update the entry block as well.
11356867e49fSAngel Zhang     rewriter.replaceOp(returnOp,
11366867e49fSAngel Zhang                        rewriter.create<func::ReturnOp>(loc, newOperands));
11376867e49fSAngel Zhang 
11386867e49fSAngel Zhang     return success();
11396867e49fSAngel Zhang   }
11406867e49fSAngel Zhang };
11419527d77aSAngel Zhang 
11426867e49fSAngel Zhang } // namespace
11436867e49fSAngel Zhang 
11446867e49fSAngel Zhang //===----------------------------------------------------------------------===//
11459527d77aSAngel Zhang // Public function for builtin variables
114601178654SLei Zhang //===----------------------------------------------------------------------===//
114701178654SLei Zhang 
114801178654SLei Zhang Value mlir::spirv::getBuiltinVariableValue(Operation *op,
114901178654SLei Zhang                                            spirv::BuiltIn builtin,
11509da4b6dbSVictor Perez                                            Type integerType, OpBuilder &builder,
11519da4b6dbSVictor Perez                                            StringRef prefix, StringRef suffix) {
115201178654SLei Zhang   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
115301178654SLei Zhang   if (!parent) {
115401178654SLei Zhang     op->emitError("expected operation to be within a module-like op");
115501178654SLei Zhang     return nullptr;
115601178654SLei Zhang   }
115701178654SLei Zhang 
11581e35a769SButygin   spirv::GlobalVariableOp varOp =
11591e35a769SButygin       getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
11609da4b6dbSVictor Perez                                  builtin, integerType, builder, prefix, suffix);
116101178654SLei Zhang   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
116201178654SLei Zhang   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
116301178654SLei Zhang }
116401178654SLei Zhang 
116501178654SLei Zhang //===----------------------------------------------------------------------===//
// Public function for push constant storage
11676dd07fa5SLei Zhang //===----------------------------------------------------------------------===//
11686dd07fa5SLei Zhang 
11696dd07fa5SLei Zhang Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
11701e35a769SButygin                                   unsigned offset, Type integerType,
11711e35a769SButygin                                   OpBuilder &builder) {
11726dd07fa5SLei Zhang   Location loc = op->getLoc();
11736dd07fa5SLei Zhang   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
11746dd07fa5SLei Zhang   if (!parent) {
11756dd07fa5SLei Zhang     op->emitError("expected operation to be within a module-like op");
11766dd07fa5SLei Zhang     return nullptr;
11776dd07fa5SLei Zhang   }
11786dd07fa5SLei Zhang 
11796dd07fa5SLei Zhang   spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
11801e35a769SButygin       loc, parent->getRegion(0).front(), elementCount, builder, integerType);
11816dd07fa5SLei Zhang 
11821e35a769SButygin   Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
11836dd07fa5SLei Zhang   Value offsetOp = builder.create<spirv::ConstantOp>(
11841e35a769SButygin       loc, integerType, builder.getI32IntegerAttr(offset));
11856dd07fa5SLei Zhang   auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
11866dd07fa5SLei Zhang   auto acOp = builder.create<spirv::AccessChainOp>(
1187984b800aSserge-sans-paille       loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
11886dd07fa5SLei Zhang   return builder.create<spirv::LoadOp>(loc, acOp);
11896dd07fa5SLei Zhang }
11906dd07fa5SLei Zhang 
11916dd07fa5SLei Zhang //===----------------------------------------------------------------------===//
11929527d77aSAngel Zhang // Public functions for index calculation
119301178654SLei Zhang //===----------------------------------------------------------------------===//
119401178654SLei Zhang 
1195bb6f5c83SLei Zhang Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
11961e35a769SButygin                                   int64_t offset, Type integerType,
11971e35a769SButygin                                   Location loc, OpBuilder &builder) {
1198bb6f5c83SLei Zhang   assert(indices.size() == strides.size() &&
1199bb6f5c83SLei Zhang          "must provide indices for all dimensions");
1200bb6f5c83SLei Zhang 
1201bb6f5c83SLei Zhang   // TODO: Consider moving to use affine.apply and patterns converting
1202bb6f5c83SLei Zhang   // affine.apply to standard ops. This needs converting to SPIR-V passes to be
1203bb6f5c83SLei Zhang   // broken down into progressive small steps so we can have intermediate steps
1204bb6f5c83SLei Zhang   // using other dialects. At the moment SPIR-V is the final sink.
1205bb6f5c83SLei Zhang 
120638f8a3cfSFinn Plummer   Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
12071e35a769SButygin       loc, integerType, IntegerAttr::get(integerType, offset));
1208e4853be2SMehdi Amini   for (const auto &index : llvm::enumerate(indices)) {
120938f8a3cfSFinn Plummer     Value strideVal = builder.createOrFold<spirv::ConstantOp>(
12101e35a769SButygin         loc, integerType,
12111e35a769SButygin         IntegerAttr::get(integerType, strides[index.index()]));
121238f8a3cfSFinn Plummer     Value update =
121338f8a3cfSFinn Plummer         builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1214bb6f5c83SLei Zhang     linearizedIndex =
121538f8a3cfSFinn Plummer         builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1216bb6f5c83SLei Zhang   }
1217bb6f5c83SLei Zhang   return linearizedIndex;
1218bb6f5c83SLei Zhang }
1219bb6f5c83SLei Zhang 
1220ce254598SMatthias Springer Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
12219c3a73a5SStanley Winata                                        MemRefType baseType, Value basePtr,
12229c3a73a5SStanley Winata                                        ValueRange indices, Location loc,
12239c3a73a5SStanley Winata                                        OpBuilder &builder) {
122401178654SLei Zhang   // Get base and offset of the MemRefType and verify they are static.
122501178654SLei Zhang 
122601178654SLei Zhang   int64_t offset;
122701178654SLei Zhang   SmallVector<int64_t, 4> strides;
1228*6aaa8f25SMatthias Springer   if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1229399638f9SAliia Khasanova       llvm::is_contained(strides, ShapedType::kDynamic) ||
1230399638f9SAliia Khasanova       ShapedType::isDynamic(offset)) {
123101178654SLei Zhang     return nullptr;
123201178654SLei Zhang   }
123301178654SLei Zhang 
12341e35a769SButygin   auto indexType = typeConverter.getIndexType();
123501178654SLei Zhang 
123601178654SLei Zhang   SmallVector<Value, 2> linearizedIndices;
123701178654SLei Zhang   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1238bb6f5c83SLei Zhang 
1239bb6f5c83SLei Zhang   // Add a '0' at the start to index into the struct.
124001178654SLei Zhang   linearizedIndices.push_back(zero);
124101178654SLei Zhang 
124201178654SLei Zhang   if (baseType.getRank() == 0) {
124301178654SLei Zhang     linearizedIndices.push_back(zero);
124401178654SLei Zhang   } else {
1245bb6f5c83SLei Zhang     linearizedIndices.push_back(
12461e35a769SButygin         linearizeIndex(indices, strides, offset, indexType, loc, builder));
124701178654SLei Zhang   }
124801178654SLei Zhang   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
124901178654SLei Zhang }
125001178654SLei Zhang 
1251ce254598SMatthias Springer Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
12529c3a73a5SStanley Winata                                        MemRefType baseType, Value basePtr,
12539c3a73a5SStanley Winata                                        ValueRange indices, Location loc,
12549c3a73a5SStanley Winata                                        OpBuilder &builder) {
12559c3a73a5SStanley Winata   // Get base and offset of the MemRefType and verify they are static.
12569c3a73a5SStanley Winata 
12579c3a73a5SStanley Winata   int64_t offset;
12589c3a73a5SStanley Winata   SmallVector<int64_t, 4> strides;
1259*6aaa8f25SMatthias Springer   if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1260399638f9SAliia Khasanova       llvm::is_contained(strides, ShapedType::kDynamic) ||
1261399638f9SAliia Khasanova       ShapedType::isDynamic(offset)) {
12629c3a73a5SStanley Winata     return nullptr;
12639c3a73a5SStanley Winata   }
12649c3a73a5SStanley Winata 
12659c3a73a5SStanley Winata   auto indexType = typeConverter.getIndexType();
12669c3a73a5SStanley Winata 
12679c3a73a5SStanley Winata   SmallVector<Value, 2> linearizedIndices;
12689c3a73a5SStanley Winata   Value linearIndex;
12699c3a73a5SStanley Winata   if (baseType.getRank() == 0) {
1270d6de6ddeSNirvedh Meshram     linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
12719c3a73a5SStanley Winata   } else {
12729c3a73a5SStanley Winata     linearIndex =
12739c3a73a5SStanley Winata         linearizeIndex(indices, strides, offset, indexType, loc, builder);
12749c3a73a5SStanley Winata   }
1275d6de6ddeSNirvedh Meshram   Type pointeeType =
12765550c821STres Popp       cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
12775550c821STres Popp   if (isa<spirv::ArrayType>(pointeeType)) {
1278d6de6ddeSNirvedh Meshram     linearizedIndices.push_back(linearIndex);
1279d6de6ddeSNirvedh Meshram     return builder.create<spirv::AccessChainOp>(loc, basePtr,
1280d6de6ddeSNirvedh Meshram                                                 linearizedIndices);
1281d6de6ddeSNirvedh Meshram   }
12829c3a73a5SStanley Winata   return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
12839c3a73a5SStanley Winata                                                  linearizedIndices);
12849c3a73a5SStanley Winata }
12859c3a73a5SStanley Winata 
1286ce254598SMatthias Springer Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
12879c3a73a5SStanley Winata                                  MemRefType baseType, Value basePtr,
12889c3a73a5SStanley Winata                                  ValueRange indices, Location loc,
12899c3a73a5SStanley Winata                                  OpBuilder &builder) {
12909c3a73a5SStanley Winata 
12919c3a73a5SStanley Winata   if (typeConverter.allows(spirv::Capability::Kernel)) {
12929c3a73a5SStanley Winata     return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
12939c3a73a5SStanley Winata                                builder);
12949c3a73a5SStanley Winata   }
12959c3a73a5SStanley Winata 
12969c3a73a5SStanley Winata   return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
12979c3a73a5SStanley Winata                              builder);
12989c3a73a5SStanley Winata }
12999c3a73a5SStanley Winata 
130001178654SLei Zhang //===----------------------------------------------------------------------===//
1301f83950abSAngel Zhang // Public functions for vector unrolling
1302f83950abSAngel Zhang //===----------------------------------------------------------------------===//
1303f83950abSAngel Zhang 
1304f83950abSAngel Zhang int mlir::spirv::getComputeVectorSize(int64_t size) {
1305f83950abSAngel Zhang   for (int i : {4, 3, 2}) {
1306f83950abSAngel Zhang     if (size % i == 0)
1307f83950abSAngel Zhang       return i;
1308f83950abSAngel Zhang   }
1309f83950abSAngel Zhang   return 1;
1310f83950abSAngel Zhang }
1311f83950abSAngel Zhang 
1312f83950abSAngel Zhang SmallVector<int64_t>
1313f83950abSAngel Zhang mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
1314f83950abSAngel Zhang   VectorType srcVectorType = op.getSourceVectorType();
1315f83950abSAngel Zhang   assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1316f83950abSAngel Zhang   int64_t vectorSize =
1317f83950abSAngel Zhang       mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
1318f83950abSAngel Zhang   return {vectorSize};
1319f83950abSAngel Zhang }
1320f83950abSAngel Zhang 
1321f83950abSAngel Zhang SmallVector<int64_t>
1322f83950abSAngel Zhang mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
1323f83950abSAngel Zhang   VectorType vectorType = op.getResultVectorType();
1324f83950abSAngel Zhang   SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
1325f83950abSAngel Zhang   nativeSize.back() =
1326f83950abSAngel Zhang       mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
1327f83950abSAngel Zhang   return nativeSize;
1328f83950abSAngel Zhang }
1329f83950abSAngel Zhang 
1330f83950abSAngel Zhang std::optional<SmallVector<int64_t>>
1331f83950abSAngel Zhang mlir::spirv::getNativeVectorShape(Operation *op) {
1332f83950abSAngel Zhang   if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
1333f83950abSAngel Zhang     if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
1334f83950abSAngel Zhang       SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
1335f83950abSAngel Zhang       nativeSize.back() =
1336f83950abSAngel Zhang           mlir::spirv::getComputeVectorSize(vecType.getShape().back());
1337f83950abSAngel Zhang       return nativeSize;
1338f83950abSAngel Zhang     }
1339f83950abSAngel Zhang   }
1340f83950abSAngel Zhang 
1341f83950abSAngel Zhang   return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
1342f83950abSAngel Zhang       .Case<vector::ReductionOp, vector::TransposeOp>(
1343f83950abSAngel Zhang           [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
1344f83950abSAngel Zhang       .Default([](Operation *) { return std::nullopt; });
1345f83950abSAngel Zhang }
1346f83950abSAngel Zhang 
1347f83950abSAngel Zhang LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
1348f83950abSAngel Zhang   MLIRContext *context = op->getContext();
1349f83950abSAngel Zhang   RewritePatternSet patterns(context);
1350f83950abSAngel Zhang   populateFuncOpVectorRewritePatterns(patterns);
1351f83950abSAngel Zhang   populateReturnOpVectorRewritePatterns(patterns);
1352f83950abSAngel Zhang   // We only want to apply signature conversion once to the existing func ops.
1353f83950abSAngel Zhang   // Without specifying strictMode, the greedy pattern rewriter will keep
1354f83950abSAngel Zhang   // looking for newly created func ops.
1355f83950abSAngel Zhang   GreedyRewriteConfig config;
1356f83950abSAngel Zhang   config.strictMode = GreedyRewriteStrictness::ExistingOps;
135709dfc571SJacques Pienaar   return applyPatternsGreedily(op, std::move(patterns), config);
1358f83950abSAngel Zhang }
1359f83950abSAngel Zhang 
1360f83950abSAngel Zhang LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
1361f83950abSAngel Zhang   MLIRContext *context = op->getContext();
1362f83950abSAngel Zhang 
1363f83950abSAngel Zhang   // Unroll vectors in function bodies to native vector size.
1364f83950abSAngel Zhang   {
1365f83950abSAngel Zhang     RewritePatternSet patterns(context);
1366f83950abSAngel Zhang     auto options = vector::UnrollVectorOptions().setNativeShapeFn(
1367f83950abSAngel Zhang         [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1368f83950abSAngel Zhang     populateVectorUnrollPatterns(patterns, options);
136909dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(op, std::move(patterns))))
1370f83950abSAngel Zhang       return failure();
1371f83950abSAngel Zhang   }
1372f83950abSAngel Zhang 
1373f83950abSAngel Zhang   // Convert transpose ops into extract and insert pairs, in preparation of
1374f83950abSAngel Zhang   // further transformations to canonicalize/cancel.
1375f83950abSAngel Zhang   {
1376f83950abSAngel Zhang     RewritePatternSet patterns(context);
1377f83950abSAngel Zhang     auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
1378f83950abSAngel Zhang         vector::VectorTransposeLowering::EltWise);
1379f83950abSAngel Zhang     vector::populateVectorTransposeLoweringPatterns(patterns, options);
1380f83950abSAngel Zhang     vector::populateVectorShapeCastLoweringPatterns(patterns);
138109dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(op, std::move(patterns))))
1382f83950abSAngel Zhang       return failure();
1383f83950abSAngel Zhang   }
1384f83950abSAngel Zhang 
1385f83950abSAngel Zhang   // Run canonicalization to cast away leading size-1 dimensions.
1386f83950abSAngel Zhang   {
1387f83950abSAngel Zhang     RewritePatternSet patterns(context);
1388f83950abSAngel Zhang 
1389f83950abSAngel Zhang     // We need to pull in casting way leading one dims.
1390f83950abSAngel Zhang     vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1391f83950abSAngel Zhang     vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1392f83950abSAngel Zhang     vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1393f83950abSAngel Zhang 
1394f83950abSAngel Zhang     // Decompose different rank insert_strided_slice and n-D
1395f83950abSAngel Zhang     // extract_slided_slice.
1396f83950abSAngel Zhang     vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1397f83950abSAngel Zhang         patterns);
1398f83950abSAngel Zhang     vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1399f83950abSAngel Zhang     vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1400f83950abSAngel Zhang 
1401f83950abSAngel Zhang     // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1402f83950abSAngel Zhang     // them up.
1403f83950abSAngel Zhang     vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1404f83950abSAngel Zhang     vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1405f83950abSAngel Zhang 
140609dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(op, std::move(patterns))))
1407f83950abSAngel Zhang       return failure();
1408f83950abSAngel Zhang   }
1409f83950abSAngel Zhang   return success();
1410f83950abSAngel Zhang }
1411f83950abSAngel Zhang 
1412f83950abSAngel Zhang //===----------------------------------------------------------------------===//
14139527d77aSAngel Zhang // SPIR-V TypeConverter
14149527d77aSAngel Zhang //===----------------------------------------------------------------------===//
14159527d77aSAngel Zhang 
14169527d77aSAngel Zhang SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
14179527d77aSAngel Zhang                                        const SPIRVConversionOptions &options)
14189527d77aSAngel Zhang     : targetEnv(targetAttr), options(options) {
14199527d77aSAngel Zhang   // Add conversions. The order matters here: later ones will be tried earlier.
14209527d77aSAngel Zhang 
14219527d77aSAngel Zhang   // Allow all SPIR-V dialect specific types. This assumes all builtin types
14229527d77aSAngel Zhang   // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
14239527d77aSAngel Zhang   // were tried before.
14249527d77aSAngel Zhang   //
14259527d77aSAngel Zhang   // TODO: This assumes that the SPIR-V types are valid to use in the given
14269527d77aSAngel Zhang   // target environment, which should be the case if the whole pipeline is
14279527d77aSAngel Zhang   // driven by the same target environment. Still, we probably still want to
14289527d77aSAngel Zhang   // validate and convert to be safe.
14299527d77aSAngel Zhang   addConversion([](spirv::SPIRVType type) { return type; });
14309527d77aSAngel Zhang 
14319527d77aSAngel Zhang   addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
14329527d77aSAngel Zhang 
14339527d77aSAngel Zhang   addConversion([this](IntegerType intType) -> std::optional<Type> {
14349527d77aSAngel Zhang     if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
14359527d77aSAngel Zhang       return convertScalarType(this->targetEnv, this->options, scalarType);
14369527d77aSAngel Zhang     if (intType.getWidth() < 8)
14379527d77aSAngel Zhang       return convertSubByteIntegerType(this->options, intType);
14389527d77aSAngel Zhang     return Type();
14399527d77aSAngel Zhang   });
14409527d77aSAngel Zhang 
14419527d77aSAngel Zhang   addConversion([this](FloatType floatType) -> std::optional<Type> {
14429527d77aSAngel Zhang     if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
14439527d77aSAngel Zhang       return convertScalarType(this->targetEnv, this->options, scalarType);
14449527d77aSAngel Zhang     return Type();
14459527d77aSAngel Zhang   });
14469527d77aSAngel Zhang 
14479527d77aSAngel Zhang   addConversion([this](ComplexType complexType) {
14489527d77aSAngel Zhang     return convertComplexType(this->targetEnv, this->options, complexType);
14499527d77aSAngel Zhang   });
14509527d77aSAngel Zhang 
14519527d77aSAngel Zhang   addConversion([this](VectorType vectorType) {
14529527d77aSAngel Zhang     return convertVectorType(this->targetEnv, this->options, vectorType);
14539527d77aSAngel Zhang   });
14549527d77aSAngel Zhang 
14559527d77aSAngel Zhang   addConversion([this](TensorType tensorType) {
14569527d77aSAngel Zhang     return convertTensorType(this->targetEnv, this->options, tensorType);
14579527d77aSAngel Zhang   });
14589527d77aSAngel Zhang 
14599527d77aSAngel Zhang   addConversion([this](MemRefType memRefType) {
14609527d77aSAngel Zhang     return convertMemrefType(this->targetEnv, this->options, memRefType);
14619527d77aSAngel Zhang   });
14629527d77aSAngel Zhang 
14639527d77aSAngel Zhang   // Register some last line of defense casting logic.
14649527d77aSAngel Zhang   addSourceMaterialization(
14659527d77aSAngel Zhang       [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
14669527d77aSAngel Zhang         return castToSourceType(this->targetEnv, builder, type, inputs, loc);
14679527d77aSAngel Zhang       });
14689527d77aSAngel Zhang   addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
14699527d77aSAngel Zhang                               Location loc) {
14709527d77aSAngel Zhang     auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
1471f18c3e4eSMatthias Springer     return cast.getResult(0);
14729527d77aSAngel Zhang   });
14739527d77aSAngel Zhang }
14749527d77aSAngel Zhang 
14759527d77aSAngel Zhang Type SPIRVTypeConverter::getIndexType() const {
14769527d77aSAngel Zhang   return ::getIndexType(getContext(), options);
14779527d77aSAngel Zhang }
14789527d77aSAngel Zhang 
14799527d77aSAngel Zhang MLIRContext *SPIRVTypeConverter::getContext() const {
14809527d77aSAngel Zhang   return targetEnv.getAttr().getContext();
14819527d77aSAngel Zhang }
14829527d77aSAngel Zhang 
14839527d77aSAngel Zhang bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
14849527d77aSAngel Zhang   return targetEnv.allows(capability);
14859527d77aSAngel Zhang }
14869527d77aSAngel Zhang 
14879527d77aSAngel Zhang //===----------------------------------------------------------------------===//
148801178654SLei Zhang // SPIR-V ConversionTarget
148901178654SLei Zhang //===----------------------------------------------------------------------===//
149001178654SLei Zhang 
14916dd07fa5SLei Zhang std::unique_ptr<SPIRVConversionTarget>
14926dd07fa5SLei Zhang SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
149301178654SLei Zhang   std::unique_ptr<SPIRVConversionTarget> target(
149401178654SLei Zhang       // std::make_unique does not work here because the constructor is private.
149501178654SLei Zhang       new SPIRVConversionTarget(targetAttr));
149601178654SLei Zhang   SPIRVConversionTarget *targetPtr = target.get();
14976dd07fa5SLei Zhang   target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
149801178654SLei Zhang       // We need to capture the raw pointer here because it is stable:
149901178654SLei Zhang       // target will be destroyed once this function is returned.
150001178654SLei Zhang       [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
150101178654SLei Zhang   return target;
150201178654SLei Zhang }
150301178654SLei Zhang 
15046dd07fa5SLei Zhang SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
150501178654SLei Zhang     : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
150601178654SLei Zhang 
15076dd07fa5SLei Zhang bool SPIRVConversionTarget::isLegalOp(Operation *op) {
150801178654SLei Zhang   // Make sure this op is available at the given version. Ops not implementing
150901178654SLei Zhang   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
151001178654SLei Zhang   // SPIR-V versions.
1511cb395f66SLei Zhang   if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1512e8bcc37fSRamkumar Ramachandra     std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1513cb395f66SLei Zhang     if (minVersion && *minVersion > this->targetEnv.getVersion()) {
151401178654SLei Zhang       LLVM_DEBUG(llvm::dbgs()
151501178654SLei Zhang                  << op->getName() << " illegal: requiring min version "
1516cb395f66SLei Zhang                  << spirv::stringifyVersion(*minVersion) << "\n");
151701178654SLei Zhang       return false;
151801178654SLei Zhang     }
1519cb395f66SLei Zhang   }
1520cb395f66SLei Zhang   if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1521e8bcc37fSRamkumar Ramachandra     std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1522cb395f66SLei Zhang     if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
152301178654SLei Zhang       LLVM_DEBUG(llvm::dbgs()
152401178654SLei Zhang                  << op->getName() << " illegal: requiring max version "
1525cb395f66SLei Zhang                  << spirv::stringifyVersion(*maxVersion) << "\n");
152601178654SLei Zhang       return false;
152701178654SLei Zhang     }
1528cb395f66SLei Zhang   }
152901178654SLei Zhang 
153001178654SLei Zhang   // Make sure this op's required extensions are allowed to use. Ops not
153101178654SLei Zhang   // implementing QueryExtensionInterface do not require extensions to be
153201178654SLei Zhang   // available.
153301178654SLei Zhang   if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
153401178654SLei Zhang     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
153501178654SLei Zhang                                           extensions.getExtensions())))
153601178654SLei Zhang       return false;
153701178654SLei Zhang 
153801178654SLei Zhang   // Make sure this op's required extensions are allowed to use. Ops not
153901178654SLei Zhang   // implementing QueryCapabilityInterface do not require capabilities to be
154001178654SLei Zhang   // available.
154101178654SLei Zhang   if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
154201178654SLei Zhang     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
154301178654SLei Zhang                                            capabilities.getCapabilities())))
154401178654SLei Zhang       return false;
154501178654SLei Zhang 
154601178654SLei Zhang   SmallVector<Type, 4> valueTypes;
154701178654SLei Zhang   valueTypes.append(op->operand_type_begin(), op->operand_type_end());
154801178654SLei Zhang   valueTypes.append(op->result_type_begin(), op->result_type_end());
154901178654SLei Zhang 
15507c5ecc8bSMogball   // Ensure that all types have been converted to SPIRV types.
15517c5ecc8bSMogball   if (llvm::any_of(valueTypes,
15525550c821STres Popp                    [](Type t) { return !isa<spirv::SPIRVType>(t); }))
15537c5ecc8bSMogball     return false;
15547c5ecc8bSMogball 
155501178654SLei Zhang   // Special treatment for global variables, whose type requirements are
155601178654SLei Zhang   // conveyed by type attributes.
155701178654SLei Zhang   if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
155890a1632dSJakub Kuderski     valueTypes.push_back(globalVar.getType());
155901178654SLei Zhang 
156001178654SLei Zhang   // Make sure the op's operands/results use types that are allowed by the
156101178654SLei Zhang   // target environment.
156201178654SLei Zhang   SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
156301178654SLei Zhang   SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
156401178654SLei Zhang   for (Type valueType : valueTypes) {
156501178654SLei Zhang     typeExtensions.clear();
15665550c821STres Popp     cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
156701178654SLei Zhang     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
156801178654SLei Zhang                                           typeExtensions)))
156901178654SLei Zhang       return false;
157001178654SLei Zhang 
157101178654SLei Zhang     typeCapabilities.clear();
15725550c821STres Popp     cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
157301178654SLei Zhang     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
157401178654SLei Zhang                                            typeCapabilities)))
157501178654SLei Zhang       return false;
157601178654SLei Zhang   }
157701178654SLei Zhang 
157801178654SLei Zhang   return true;
157901178654SLei Zhang }
15809527d77aSAngel Zhang 
15819527d77aSAngel Zhang //===----------------------------------------------------------------------===//
15829527d77aSAngel Zhang // Public functions for populating patterns
15839527d77aSAngel Zhang //===----------------------------------------------------------------------===//
15849527d77aSAngel Zhang 
1585206fad0eSMatthias Springer void mlir::populateBuiltinFuncToSPIRVPatterns(
1586206fad0eSMatthias Springer     const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
15879527d77aSAngel Zhang   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
15889527d77aSAngel Zhang }
15899527d77aSAngel Zhang 
15909527d77aSAngel Zhang void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
15919527d77aSAngel Zhang   patterns.add<FuncOpVectorUnroll>(patterns.getContext());
15929527d77aSAngel Zhang }
15939527d77aSAngel Zhang 
15949527d77aSAngel Zhang void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
15959527d77aSAngel Zhang   patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
15969527d77aSAngel Zhang }
1597