xref: /llvm-project/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/Dialect/Vector/IR/VectorOps.h"
23 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
24 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Support/LLVM.h"
30 #include "mlir/Transforms/DialectConversion.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 #include "mlir/Transforms/OneToNTypeConversion.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/LogicalResult.h"
38 #include "llvm/Support/MathExtras.h"
39 
40 #include <functional>
41 #include <optional>
42 
43 #define DEBUG_TYPE "mlir-spirv-conversion"
44 
45 using namespace mlir;
46 
47 namespace {
48 
49 //===----------------------------------------------------------------------===//
50 // Utility functions
51 //===----------------------------------------------------------------------===//
52 
53 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
54   LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
55   if (vecType.isScalable()) {
56     LLVM_DEBUG(llvm::dbgs()
57                << "--scalable vectors are not supported -> BAIL\n");
58     return std::nullopt;
59   }
60   SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
61   std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
62       1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
63   if (!targetShape) {
64     LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
65     return std::nullopt;
66   }
67   auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
68   if (!maybeShapeRatio) {
69     LLVM_DEBUG(llvm::dbgs()
70                << "--could not compute integral shape ratio -> BAIL\n");
71     return std::nullopt;
72   }
73   if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
74     LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
75     return std::nullopt;
76   }
77   LLVM_DEBUG(llvm::dbgs()
78              << "--found an integral shape ratio to unroll to -> SUCCESS\n");
79   return targetShape;
80 }
81 
82 /// Checks that `candidates` extension requirements are possible to be satisfied
83 /// with the given `targetEnv`.
84 ///
85 ///  `candidates` is a vector of vector for extension requirements following
86 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
87 /// convention.
88 template <typename LabelT>
89 static LogicalResult checkExtensionRequirements(
90     LabelT label, const spirv::TargetEnv &targetEnv,
91     const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
92   for (const auto &ors : candidates) {
93     if (targetEnv.allows(ors))
94       continue;
95 
96     LLVM_DEBUG({
97       SmallVector<StringRef> extStrings;
98       for (spirv::Extension ext : ors)
99         extStrings.push_back(spirv::stringifyExtension(ext));
100 
101       llvm::dbgs() << label << " illegal: requires at least one extension in ["
102                    << llvm::join(extStrings, ", ")
103                    << "] but none allowed in target environment\n";
104     });
105     return failure();
106   }
107   return success();
108 }
109 
110 /// Checks that `candidates`capability requirements are possible to be satisfied
111 /// with the given `isAllowedFn`.
112 ///
113 ///  `candidates` is a vector of vector for capability requirements following
114 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
115 /// convention.
116 template <typename LabelT>
117 static LogicalResult checkCapabilityRequirements(
118     LabelT label, const spirv::TargetEnv &targetEnv,
119     const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
120   for (const auto &ors : candidates) {
121     if (targetEnv.allows(ors))
122       continue;
123 
124     LLVM_DEBUG({
125       SmallVector<StringRef> capStrings;
126       for (spirv::Capability cap : ors)
127         capStrings.push_back(spirv::stringifyCapability(cap));
128 
129       llvm::dbgs() << label << " illegal: requires at least one capability in ["
130                    << llvm::join(capStrings, ", ")
131                    << "] but none allowed in target environment\n";
132     });
133     return failure();
134   }
135   return success();
136 }
137 
138 /// Returns true if the given `storageClass` needs explicit layout when used in
139 /// Shader environments.
140 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
141   switch (storageClass) {
142   case spirv::StorageClass::PhysicalStorageBuffer:
143   case spirv::StorageClass::PushConstant:
144   case spirv::StorageClass::StorageBuffer:
145   case spirv::StorageClass::Uniform:
146     return true;
147   default:
148     return false;
149   }
150 }
151 
152 /// Wraps the given `elementType` in a struct and gets the pointer to the
153 /// struct. This is used to satisfy Vulkan interface requirements.
154 static spirv::PointerType
155 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
156   auto structType = needsExplicitLayout(storageClass)
157                         ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
158                         : spirv::StructType::get(elementType);
159   return spirv::PointerType::get(structType, storageClass);
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // Type Conversion
164 //===----------------------------------------------------------------------===//
165 
166 static spirv::ScalarType getIndexType(MLIRContext *ctx,
167                                       const SPIRVConversionOptions &options) {
168   return cast<spirv::ScalarType>(
169       IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
170 }
171 
// TODO: This is a utility function that should probably be exposed by the
// SPIR-V dialect. Keeping it local till the use case arises.
/// Returns the storage size of `type` in bytes, or std::nullopt when the size
/// cannot be computed (booleans, dynamic shapes/strides, unsupported types).
static std::optional<int64_t>
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
  if (isa<spirv::ScalarType>(type)) {
    auto bitWidth = type.getIntOrFloatBitWidth();
    // According to the SPIR-V spec:
    // "There is no physical size or bit pattern defined for values with boolean
    // type. If they are stored (in conjunction with OpVariable), they can only
    // be used with logical addressing operations, not physical, and only with
    // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
    // Private, Function, Input, and Output."
    if (bitWidth == 1)
      return std::nullopt;
    return bitWidth / 8;
  }

  if (auto complexType = dyn_cast<ComplexType>(type)) {
    // A complex value is stored as two consecutive elements.
    auto elementSize = getTypeNumBytes(options, complexType.getElementType());
    if (!elementSize)
      return std::nullopt;
    return 2 * *elementSize;
  }

  if (auto vecType = dyn_cast<VectorType>(type)) {
    // Vectors are densely packed: element size times element count.
    auto elementSize = getTypeNumBytes(options, vecType.getElementType());
    if (!elementSize)
      return std::nullopt;
    return vecType.getNumElements() * *elementSize;
  }

  if (auto memRefType = dyn_cast<MemRefType>(type)) {
    // TODO: Layout should also be controlled by the ABI attributes. For now
    // using the layout from MemRef.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (!memRefType.hasStaticShape() ||
        failed(memRefType.getStridesAndOffset(strides, offset)))
      return std::nullopt;

    // To get the size of the memref object in memory, the total size is the
    // max(stride * dimension-size) computed for all dimensions times the size
    // of the element.
    auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
    if (!elementSize)
      return std::nullopt;

    // A rank-0 memref holds exactly one element.
    if (memRefType.getRank() == 0)
      return elementSize;

    // Dynamic dimensions, offsets, or strides make the size unknowable.
    auto dims = memRefType.getShape();
    if (llvm::is_contained(dims, ShapedType::kDynamic) ||
        ShapedType::isDynamic(offset) ||
        llvm::is_contained(strides, ShapedType::kDynamic))
      return std::nullopt;

    int64_t memrefSize = -1;
    for (const auto &shape : enumerate(dims))
      memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);

    return (offset + memrefSize) * *elementSize;
  }

  if (auto tensorType = dyn_cast<TensorType>(type)) {
    if (!tensorType.hasStaticShape())
      return std::nullopt;

    auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
    if (!elementSize)
      return std::nullopt;

    // Tensors are densely packed: element size times total element count.
    int64_t size = *elementSize;
    for (auto shape : tensorType.getShape())
      size *= shape;

    return size;
  }

  // TODO: Add size computation for other types.
  return std::nullopt;
}
253 
254 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
255 static Type
256 convertScalarType(const spirv::TargetEnv &targetEnv,
257                   const SPIRVConversionOptions &options, spirv::ScalarType type,
258                   std::optional<spirv::StorageClass> storageClass = {}) {
259   // Get extension and capability requirements for the given type.
260   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
261   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
262   type.getExtensions(extensions, storageClass);
263   type.getCapabilities(capabilities, storageClass);
264 
265   // If all requirements are met, then we can accept this type as-is.
266   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
267       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
268     return type;
269 
270   // Otherwise we need to adjust the type, which really means adjusting the
271   // bitwidth given this is a scalar type.
272   if (!options.emulateLT32BitScalarTypes)
273     return nullptr;
274 
275   // We only emulate narrower scalar types here and do not truncate results.
276   if (type.getIntOrFloatBitWidth() > 32) {
277     LLVM_DEBUG(llvm::dbgs()
278                << type
279                << " not converted to 32-bit for SPIR-V to avoid truncation\n");
280     return nullptr;
281   }
282 
283   if (auto floatType = dyn_cast<FloatType>(type)) {
284     LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
285     return Builder(targetEnv.getContext()).getF32Type();
286   }
287 
288   auto intType = cast<IntegerType>(type);
289   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
290   return IntegerType::get(targetEnv.getContext(), /*width=*/32,
291                           intType.getSignedness());
292 }
293 
294 /// Converts a sub-byte integer `type` to i32 regardless of target environment.
295 /// Returns a nullptr for unsupported integer types, including non sub-byte
296 /// types.
297 ///
298 /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
299 /// the above given that these sub-byte types are not supported at all in
300 /// SPIR-V; there are no compute/storage capability for them like other
301 /// supported integer types.
302 static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
303                                       IntegerType type) {
304   if (type.getWidth() > 8) {
305     LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
306     return nullptr;
307   }
308   if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
309     LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
310     return nullptr;
311   }
312 
313   if (!llvm::isPowerOf2_32(type.getWidth())) {
314     LLVM_DEBUG(llvm::dbgs()
315                << "unsupported non-power-of-two bitwidth in sub-byte" << type
316                << "\n");
317     return nullptr;
318   }
319 
320   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
321   return IntegerType::get(type.getContext(), /*width=*/32,
322                           type.getSignedness());
323 }
324 
325 /// Returns a type with the same shape but with any index element type converted
326 /// to the matching integer type. This is a noop when the element type is not
327 /// the index type.
328 static ShapedType
329 convertIndexElementType(ShapedType type,
330                         const SPIRVConversionOptions &options) {
331   Type indexType = dyn_cast<IndexType>(type.getElementType());
332   if (!indexType)
333     return type;
334 
335   return type.clone(getIndexType(type.getContext(), options));
336 }
337 
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
static Type
convertVectorType(const spirv::TargetEnv &targetEnv,
                  const SPIRVConversionOptions &options, VectorType type,
                  std::optional<spirv::StorageClass> storageClass = {}) {
  // Normalize index elements to the configured integer width first.
  type = cast<VectorType>(convertIndexElementType(type, options));
  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
  if (!scalarType) {
    // If this is not a spec allowed scalar type, try to handle sub-byte integer
    // types.
    auto intType = dyn_cast<IntegerType>(type.getElementType());
    if (!intType) {
      LLVM_DEBUG(llvm::dbgs()
                 << type
                 << " illegal: cannot convert non-scalar element type\n");
      return nullptr;
    }

    Type elementType = convertSubByteIntegerType(options, intType);
    if (!elementType)
      return nullptr;

    // Single-element vectors lower to their (converted) element type directly.
    if (type.getRank() <= 1 && type.getNumElements() == 1)
      return elementType;

    if (type.getNumElements() > 4) {
      LLVM_DEBUG(llvm::dbgs()
                 << type << " illegal: > 4-element unimplemented\n");
      return nullptr;
    }

    return VectorType::get(type.getShape(), elementType);
  }

  // Single-element vectors are converted like their scalar element type.
  if (type.getRank() <= 1 && type.getNumElements() == 1)
    return convertScalarType(targetEnv, options, scalarType, storageClass);

  if (!spirv::CompositeType::isValid(type)) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: not a valid composite type\n");
    return nullptr;
  }

  // Get extension and capability requirements for the given type.
  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
  cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
  cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);

  // If all requirements are met, then we can accept this type as-is.
  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
    return type;

  // Otherwise, try emulating via a converted (e.g. widened) element type.
  auto elementType =
      convertScalarType(targetEnv, options, scalarType, storageClass);
  if (elementType)
    return VectorType::get(type.getShape(), elementType);
  return nullptr;
}
398 
399 static Type
400 convertComplexType(const spirv::TargetEnv &targetEnv,
401                    const SPIRVConversionOptions &options, ComplexType type,
402                    std::optional<spirv::StorageClass> storageClass = {}) {
403   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
404   if (!scalarType) {
405     LLVM_DEBUG(llvm::dbgs()
406                << type << " illegal: cannot convert non-scalar element type\n");
407     return nullptr;
408   }
409 
410   auto elementType =
411       convertScalarType(targetEnv, options, scalarType, storageClass);
412   if (!elementType)
413     return nullptr;
414   if (elementType != type.getElementType()) {
415     LLVM_DEBUG(llvm::dbgs()
416                << type << " illegal: complex type emulation unsupported\n");
417     return nullptr;
418   }
419 
420   return VectorType::get(2, elementType);
421 }
422 
/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
///
/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
/// create composite constants with OpConstantComposite to embed relative large
/// constant values and use OpCompositeExtract and OpCompositeInsert to
/// manipulate, like what we do for vectors.
static Type convertTensorType(const spirv::TargetEnv &targetEnv,
                              const SPIRVConversionOptions &options,
                              TensorType type) {
  // TODO: Handle dynamic shapes.
  if (!type.hasStaticShape()) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: dynamic shape unimplemented\n");
    return nullptr;
  }

  // Normalize index elements to the configured integer width first.
  type = cast<TensorType>(convertIndexElementType(type, options));
  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
  if (!scalarType) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert non-scalar element type\n");
    return nullptr;
  }

  // Derive the array element count from the byte sizes of the whole tensor
  // and of a single scalar element.
  std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
  std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
  if (!scalarSize || !tensorSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return nullptr;
  }

  int64_t arrayElemCount = *tensorSize / *scalarSize;
  if (arrayElemCount == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot handle zero-element tensors\n");
    return nullptr;
  }

  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
  if (!arrayElemType)
    return nullptr;
  std::optional<int64_t> arrayElemSize =
      getTypeNumBytes(options, arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return nullptr;
  }

  // Lower the tensor to a SPIR-V array of the converted element type.
  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
}
475 
476 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
477                                   const SPIRVConversionOptions &options,
478                                   MemRefType type,
479                                   spirv::StorageClass storageClass) {
480   unsigned numBoolBits = options.boolNumBits;
481   if (numBoolBits != 8) {
482     LLVM_DEBUG(llvm::dbgs()
483                << "using non-8-bit storage for bool types unimplemented");
484     return nullptr;
485   }
486   auto elementType = dyn_cast<spirv::ScalarType>(
487       IntegerType::get(type.getContext(), numBoolBits));
488   if (!elementType)
489     return nullptr;
490   Type arrayElemType =
491       convertScalarType(targetEnv, options, elementType, storageClass);
492   if (!arrayElemType)
493     return nullptr;
494   std::optional<int64_t> arrayElemSize =
495       getTypeNumBytes(options, arrayElemType);
496   if (!arrayElemSize) {
497     LLVM_DEBUG(llvm::dbgs()
498                << type << " illegal: cannot deduce converted element size\n");
499     return nullptr;
500   }
501 
502   if (!type.hasStaticShape()) {
503     // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
504     // to the element.
505     if (targetEnv.allows(spirv::Capability::Kernel))
506       return spirv::PointerType::get(arrayElemType, storageClass);
507     int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
508     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
509     // For Vulkan we need extra wrapping struct and array to satisfy interface
510     // needs.
511     return wrapInStructAndGetPointer(arrayType, storageClass);
512   }
513 
514   if (type.getNumElements() == 0) {
515     LLVM_DEBUG(llvm::dbgs()
516                << type << " illegal: zero-element memrefs are not supported\n");
517     return nullptr;
518   }
519 
520   int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
521   int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
522   int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
523   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
524   if (targetEnv.allows(spirv::Capability::Kernel))
525     return spirv::PointerType::get(arrayType, storageClass);
526   return wrapInStructAndGetPointer(arrayType, storageClass);
527 }
528 
/// Converts a memref of sub-byte integer elements to a SPIR-V pointer type
/// with the elements packed into i32 storage, under the given `targetEnv`.
static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
                                     const SPIRVConversionOptions &options,
                                     MemRefType type,
                                     spirv::StorageClass storageClass) {
  IntegerType elementType = cast<IntegerType>(type.getElementType());
  Type arrayElemType = convertSubByteIntegerType(options, elementType);
  if (!arrayElemType)
    return nullptr;
  // The unchecked dereference is safe here: convertSubByteIntegerType only
  // returns i32, for which getTypeNumBytes always yields a value (4).
  int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);

  if (!type.hasStaticShape()) {
    // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
    // to the element.
    if (targetEnv.allows(spirv::Capability::Kernel))
      return spirv::PointerType::get(arrayElemType, storageClass);
    int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
    // For Vulkan we need extra wrapping struct and array to satisfy interface
    // needs.
    return wrapInStructAndGetPointer(arrayType, storageClass);
  }

  if (type.getNumElements() == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: zero-element memrefs are not supported\n");
    return nullptr;
  }

  // Sub-byte elements are packed: compute the total byte size, then the
  // number of i32 array elements needed to hold it.
  int64_t memrefSize =
      llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
  int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
  if (targetEnv.allows(spirv::Capability::Kernel))
    return spirv::PointerType::get(arrayType, storageClass);
  return wrapInStructAndGetPointer(arrayType, storageClass);
}
566 
/// Converts a memref `type` to a suitable SPIR-V pointer type under the given
/// `targetEnv`, dispatching on the element type (bool, sub-byte integer,
/// vector, complex, scalar, or index).
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
                              const SPIRVConversionOptions &options,
                              MemRefType type) {
  // The memory space must already be a SPIR-V storage class attribute.
  auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!attr) {
    LLVM_DEBUG(
        llvm::dbgs()
        << type
        << " illegal: expected memory space to be a SPIR-V storage class "
           "attribute; please use MemorySpaceToStorageClassConverter to map "
           "numeric memory spaces beforehand\n");
    return nullptr;
  }
  spirv::StorageClass storageClass = attr.getValue();

  // i1 and sub-byte integer elements need dedicated packing logic.
  if (isa<IntegerType>(type.getElementType())) {
    if (type.getElementTypeBitWidth() == 1)
      return convertBoolMemrefType(targetEnv, options, type, storageClass);
    if (type.getElementTypeBitWidth() < 8)
      return convertSubByteMemrefType(targetEnv, options, type, storageClass);
  }

  // Convert the element type according to its kind.
  Type arrayElemType;
  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType)) {
    arrayElemType =
        convertVectorType(targetEnv, options, vecType, storageClass);
  } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
    arrayElemType =
        convertComplexType(targetEnv, options, complexType, storageClass);
  } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
    arrayElemType =
        convertScalarType(targetEnv, options, scalarType, storageClass);
  } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
    type = cast<MemRefType>(convertIndexElementType(type, options));
    arrayElemType = type.getElementType();
  } else {
    LLVM_DEBUG(
        llvm::dbgs()
        << type
        << " unhandled: can only convert scalar or vector element type\n");
    return nullptr;
  }
  if (!arrayElemType)
    return nullptr;

  std::optional<int64_t> arrayElemSize =
      getTypeNumBytes(options, arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return nullptr;
  }

  if (!type.hasStaticShape()) {
    // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
    // to the element.
    if (targetEnv.allows(spirv::Capability::Kernel))
      return spirv::PointerType::get(arrayElemType, storageClass);
    int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
    // For Vulkan we need extra wrapping struct and array to satisfy interface
    // needs.
    return wrapInStructAndGetPointer(arrayType, storageClass);
  }

  std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
  if (!memrefSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return nullptr;
  }

  if (*memrefSize == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: zero-element memrefs are not supported\n");
    return nullptr;
  }

  // Static shape: lower to a fixed-size array, optionally wrapped for Vulkan.
  int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
  if (targetEnv.allows(spirv::Capability::Kernel))
    return spirv::PointerType::get(arrayType, storageClass);
  return wrapInStructAndGetPointer(arrayType, storageClass);
}
653 
654 //===----------------------------------------------------------------------===//
655 // Type casting materialization
656 //===----------------------------------------------------------------------===//
657 
658 /// Converts the given `inputs` to the original source `type` considering the
659 /// `targetEnv`'s capabilities.
660 ///
661 /// This function is meant to be used for source materialization in type
662 /// converters. When the type converter needs to materialize a cast op back
663 /// to some original source type, we need to check whether the original source
664 /// type is supported in the target environment. If so, we can insert legal
665 /// SPIR-V cast ops accordingly.
666 ///
667 /// Note that in SPIR-V the capabilities for storage and compute are separate.
668 /// This function is meant to handle the **compute** side; so it does not
669 /// involve storage classes in its logic. The storage side is expected to be
670 /// handled by MemRef conversion logic.
671 static Value castToSourceType(const spirv::TargetEnv &targetEnv,
672                               OpBuilder &builder, Type type, ValueRange inputs,
673                               Location loc) {
674   // We can only cast one value in SPIR-V.
675   if (inputs.size() != 1) {
676     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
677     return castOp.getResult(0);
678   }
679   Value input = inputs.front();
680 
681   // Only support integer types for now. Floating point types to be implemented.
682   if (!isa<IntegerType>(type)) {
683     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
684     return castOp.getResult(0);
685   }
686   auto inputType = cast<IntegerType>(input.getType());
687 
688   auto scalarType = dyn_cast<spirv::ScalarType>(type);
689   if (!scalarType) {
690     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
691     return castOp.getResult(0);
692   }
693 
694   // Only support source type with a smaller bitwidth. This would mean we are
695   // truncating to go back so we don't need to worry about the signedness.
696   // For extension, we cannot have enough signal here to decide which op to use.
697   if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
698     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
699     return castOp.getResult(0);
700   }
701 
702   // Boolean values would need to use different ops than normal integer values.
703   if (type.isInteger(1)) {
704     Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
705     return builder.create<spirv::IEqualOp>(loc, input, one);
706   }
707 
708   // Check that the source integer type is supported by the environment.
709   SmallVector<ArrayRef<spirv::Extension>, 1> exts;
710   SmallVector<ArrayRef<spirv::Capability>, 2> caps;
711   scalarType.getExtensions(exts);
712   scalarType.getCapabilities(caps);
713   if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
714       failed(checkExtensionRequirements(type, targetEnv, exts))) {
715     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
716     return castOp.getResult(0);
717   }
718 
719   // We've already made sure this is truncating previously, so we don't need to
720   // care about signedness here. Still try to use a corresponding op for better
721   // consistency though.
722   if (type.isSignedInteger()) {
723     return builder.create<spirv::SConvertOp>(loc, type, input);
724   }
725   return builder.create<spirv::UConvertOp>(loc, type, input);
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // Builtin Variables
730 //===----------------------------------------------------------------------===//
731 
732 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
733                                                   spirv::BuiltIn builtin) {
734   // Look through all global variables in the given `body` block and check if
735   // there is a spirv.GlobalVariable that has the same `builtin` attribute.
736   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
737     if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
738             spirv::SPIRVDialect::getAttributeName(
739                 spirv::Decoration::BuiltIn))) {
740       auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
741       if (varBuiltIn && *varBuiltIn == builtin) {
742         return varOp;
743       }
744     }
745   }
746   return nullptr;
747 }
748 
749 /// Gets name of global variable for a builtin.
750 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
751                               StringRef suffix) {
752   return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
753 }
754 
755 /// Gets or inserts a global variable for a builtin within `body` block.
756 static spirv::GlobalVariableOp
757 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
758                            Type integerType, OpBuilder &builder,
759                            StringRef prefix, StringRef suffix) {
760   if (auto varOp = getBuiltinVariable(body, builtin))
761     return varOp;
762 
763   OpBuilder::InsertionGuard guard(builder);
764   builder.setInsertionPointToStart(&body);
765 
766   spirv::GlobalVariableOp newVarOp;
767   switch (builtin) {
768   case spirv::BuiltIn::NumWorkgroups:
769   case spirv::BuiltIn::WorkgroupSize:
770   case spirv::BuiltIn::WorkgroupId:
771   case spirv::BuiltIn::LocalInvocationId:
772   case spirv::BuiltIn::GlobalInvocationId: {
773     auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
774                                            spirv::StorageClass::Input);
775     std::string name = getBuiltinVarName(builtin, prefix, suffix);
776     newVarOp =
777         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
778     break;
779   }
780   case spirv::BuiltIn::SubgroupId:
781   case spirv::BuiltIn::NumSubgroups:
782   case spirv::BuiltIn::SubgroupSize: {
783     auto ptrType =
784         spirv::PointerType::get(integerType, spirv::StorageClass::Input);
785     std::string name = getBuiltinVarName(builtin, prefix, suffix);
786     newVarOp =
787         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
788     break;
789   }
790   default:
791     emitError(loc, "unimplemented builtin variable generation for ")
792         << stringifyBuiltIn(builtin);
793   }
794   return newVarOp;
795 }
796 
797 //===----------------------------------------------------------------------===//
798 // Push constant storage
799 //===----------------------------------------------------------------------===//
800 
801 /// Returns the pointer type for the push constant storage containing
802 /// `elementCount` 32-bit integer values.
803 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
804                                                      Builder &builder,
805                                                      Type indexType) {
806   auto arrayType = spirv::ArrayType::get(indexType, elementCount,
807                                          /*stride=*/4);
808   auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
809   return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
810 }
811 
812 /// Returns the push constant varible containing `elementCount` 32-bit integer
813 /// values in `body`. Returns null op if such an op does not exit.
814 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
815                                                        unsigned elementCount) {
816   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
817     auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
818     if (!ptrType)
819       continue;
820 
821     // Note that Vulkan requires "There must be no more than one push constant
822     // block statically used per shader entry point." So we should always reuse
823     // the existing one.
824     if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
825       auto numElements = cast<spirv::ArrayType>(
826                              cast<spirv::StructType>(ptrType.getPointeeType())
827                                  .getElementType(0))
828                              .getNumElements();
829       if (numElements == elementCount)
830         return varOp;
831     }
832   }
833   return nullptr;
834 }
835 
836 /// Gets or inserts a global variable for push constant storage containing
837 /// `elementCount` 32-bit integer values in `block`.
838 static spirv::GlobalVariableOp
839 getOrInsertPushConstantVariable(Location loc, Block &block,
840                                 unsigned elementCount, OpBuilder &b,
841                                 Type indexType) {
842   if (auto varOp = getPushConstantVariable(block, elementCount))
843     return varOp;
844 
845   auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
846   auto type = getPushConstantStorageType(elementCount, builder, indexType);
847   const char *name = "__push_constant_var__";
848   return builder.create<spirv::GlobalVariableOp>(loc, type, name,
849                                                  /*initializer=*/nullptr);
850 }
851 
852 //===----------------------------------------------------------------------===//
853 // func::FuncOp Conversion Patterns
854 //===----------------------------------------------------------------------===//
855 
856 /// A pattern for rewriting function signature to convert arguments of functions
857 /// to be of valid SPIR-V types.
858 struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
859   using OpConversionPattern<func::FuncOp>::OpConversionPattern;
860 
861   LogicalResult
862   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
863                   ConversionPatternRewriter &rewriter) const override {
864     FunctionType fnType = funcOp.getFunctionType();
865     if (fnType.getNumResults() > 1)
866       return failure();
867 
868     TypeConverter::SignatureConversion signatureConverter(
869         fnType.getNumInputs());
870     for (const auto &argType : enumerate(fnType.getInputs())) {
871       auto convertedType = getTypeConverter()->convertType(argType.value());
872       if (!convertedType)
873         return failure();
874       signatureConverter.addInputs(argType.index(), convertedType);
875     }
876 
877     Type resultType;
878     if (fnType.getNumResults() == 1) {
879       resultType = getTypeConverter()->convertType(fnType.getResult(0));
880       if (!resultType)
881         return failure();
882     }
883 
884     // Create the converted spirv.func op.
885     auto newFuncOp = rewriter.create<spirv::FuncOp>(
886         funcOp.getLoc(), funcOp.getName(),
887         rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
888                                  resultType ? TypeRange(resultType)
889                                             : TypeRange()));
890 
891     // Copy over all attributes other than the function name and type.
892     for (const auto &namedAttr : funcOp->getAttrs()) {
893       if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
894           namedAttr.getName() != SymbolTable::getSymbolAttrName())
895         newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
896     }
897 
898     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
899                                 newFuncOp.end());
900     if (failed(rewriter.convertRegionTypes(
901             &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
902       return failure();
903     rewriter.eraseOp(funcOp);
904     return success();
905   }
906 };
907 
/// A pattern for rewriting function signature to convert vector arguments of
/// functions to be of valid types.
///
/// Each vector argument exceeding the target shape is replaced by multiple
/// arguments of the native vector size; the original value is reconstructed
/// inside the entry block with `vector.insert_strided_slice` ops. Because the
/// entry block's argument list can only be rebuilt after all argument types
/// are known, temporary constant ops are created as placeholders and patched
/// to reference the final block arguments in a second pass.
struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    FunctionType fnType = funcOp.getFunctionType();

    // TODO: Handle declarations.
    if (funcOp.isDeclaration()) {
      LLVM_DEBUG(llvm::dbgs()
                 << fnType << " illegal: declarations are unsupported\n");
      return failure();
    }

    // Create a new func op with the original type and copy the function body.
    auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
                                                   funcOp.getName(), fnType);
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());

    Location loc = newFuncOp.getBody().getLoc();

    // All placeholder / reconstruction ops are inserted at the very start of
    // the entry block, before any pre-existing operations.
    Block &entryBlock = newFuncOp.getBlocks().front();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&entryBlock);

    // Tracks, per original argument, the list of types it maps to in the new
    // signature (1 type for legal arguments, N for unrolled ones).
    OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());

    // For arguments that are of illegal types and require unrolling.
    // `unrolledInputNums` stores the indices of arguments that result from
    // unrolling in the new function signature. `newInputNo` is a counter.
    SmallVector<size_t> unrolledInputNums;
    size_t newInputNo = 0;

    // For arguments that are of legal types and do not require unrolling.
    // `tmpOps` stores a mapping from temporary operations that serve as
    // placeholders for new arguments that will be added later. These operations
    // will be erased once the entry block's argument list is updated.
    llvm::SmallDenseMap<Operation *, size_t> tmpOps;

    // This counts the number of new operations created.
    size_t newOpCount = 0;

    // Enumerate through the arguments.
    for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
      // Check whether the argument is of vector type.
      auto origVecType = dyn_cast<VectorType>(origType);
      if (!origVecType) {
        // We need a placeholder for the old argument that will be erased later.
        Value result = rewriter.create<arith::ConstantOp>(
            loc, origType, rewriter.getZeroAttr(origType));
        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
        tmpOps.insert({result.getDefiningOp(), newInputNo});
        oneToNTypeMapping.addInputs(origInputNo, origType);
        ++newInputNo;
        ++newOpCount;
        continue;
      }
      // Check whether the vector needs unrolling.
      auto targetShape = getTargetShape(origVecType);
      if (!targetShape) {
        // We need a placeholder for the old argument that will be erased later.
        Value result = rewriter.create<arith::ConstantOp>(
            loc, origType, rewriter.getZeroAttr(origType));
        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
        tmpOps.insert({result.getDefiningOp(), newInputNo});
        oneToNTypeMapping.addInputs(origInputNo, origType);
        ++newInputNo;
        ++newOpCount;
        continue;
      }
      VectorType unrolledType =
          VectorType::get(*targetShape, origVecType.getElementType());
      auto originalShape =
          llvm::to_vector_of<int64_t, 4>(origVecType.getShape());

      // Prepare the result vector. Starts as an all-zero constant into which
      // the unrolled pieces are inserted tile by tile.
      Value result = rewriter.create<arith::ConstantOp>(
          loc, origVecType, rewriter.getZeroAttr(origVecType));
      ++newOpCount;
      // Prepare the placeholder for the new arguments that will be added later.
      Value dummy = rewriter.create<arith::ConstantOp>(
          loc, unrolledType, rewriter.getZeroAttr(unrolledType));
      ++newOpCount;

      // Create the `vector.insert_strided_slice` ops, one per tile of the
      // original vector. Their source operand (`dummy` for now) is patched to
      // the corresponding new block argument in the second pass below.
      SmallVector<int64_t> strides(targetShape->size(), 1);
      SmallVector<Type> newTypes;
      for (SmallVector<int64_t> offsets :
           StaticTileOffsetRange(originalShape, *targetShape)) {
        result = rewriter.create<vector::InsertStridedSliceOp>(
            loc, dummy, result, offsets, strides);
        newTypes.push_back(unrolledType);
        unrolledInputNums.push_back(newInputNo);
        ++newInputNo;
        ++newOpCount;
      }
      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
      oneToNTypeMapping.addInputs(origInputNo, newTypes);
    }

    // Change the function signature.
    auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
    auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
    rewriter.modifyOpInPlace(newFuncOp,
                             [&] { newFuncOp.setFunctionType(newFnType); });

    // Update the arguments in the entry block.
    entryBlock.eraseArguments(0, fnType.getNumInputs());
    SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
    entryBlock.addArguments(convertedTypes, locs);

    // Replace the placeholder values with the new arguments. We assume there is
    // only one block for now.
    size_t unrolledInputIdx = 0;
    for (auto [count, op] : enumerate(entryBlock.getOperations())) {
      // We first look for operands that are placeholders for initially legal
      // arguments.
      Operation &curOp = op;
      for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
        Operation *operandOp = operandVal.getDefiningOp();
        if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
          size_t idx = operandIdx;
          rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
            curOp.setOperand(idx, newFuncOp.getArgument(it->second));
          });
        }
      }
      // Since all newly created operations are in the beginning, reaching the
      // end of them means that any later `vector.insert_strided_slice` should
      // not be touched. NOTE(review): this relies on the insertion order
      // matching `unrolledInputNums`; changing the op creation order above
      // would break this pairing.
      if (count >= newOpCount)
        continue;
      if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
        size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
        rewriter.modifyOpInPlace(&curOp, [&] {
          curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
        });
        ++unrolledInputIdx;
      }
    }

    // Erase the original funcOp. The `tmpOps` do not need to be erased since
    // they have no uses and will be handled by dead-code elimination.
    rewriter.eraseOp(funcOp);
    return success();
  }
};
1058 
1059 //===----------------------------------------------------------------------===//
1060 // func::ReturnOp Conversion Patterns
1061 //===----------------------------------------------------------------------===//
1062 
/// A pattern for rewriting function signature and the return op to convert
/// vectors to be of valid types.
///
/// Each returned vector exceeding the target shape is decomposed into
/// multiple native-size vectors via `vector.extract_strided_slice`, and the
/// enclosing function's result types are updated in place to match.
struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
                                PatternRewriter &rewriter) const override {
    // Check whether the parent funcOp is valid.
    auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
    if (!funcOp)
      return failure();

    FunctionType fnType = funcOp.getFunctionType();
    // Tracks, per original result, the list of types it maps to in the new
    // signature (1 type for legal results, N for unrolled ones).
    OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
    Location loc = returnOp.getLoc();

    // For the new return op.
    SmallVector<Value> newOperands;

    // Enumerate through the results.
    for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
      // Check whether the argument is of vector type.
      auto origVecType = dyn_cast<VectorType>(origType);
      if (!origVecType) {
        // Non-vector results are passed through unchanged.
        oneToNTypeMapping.addInputs(origResultNo, origType);
        newOperands.push_back(returnOp.getOperand(origResultNo));
        continue;
      }
      // Check whether the vector needs unrolling.
      auto targetShape = getTargetShape(origVecType);
      if (!targetShape) {
        // The original argument can be used.
        oneToNTypeMapping.addInputs(origResultNo, origType);
        newOperands.push_back(returnOp.getOperand(origResultNo));
        continue;
      }
      VectorType unrolledType =
          VectorType::get(*targetShape, origVecType.getElementType());

      // Create `vector.extract_strided_slice` ops to form legal vectors from
      // the original operand of illegal type. Each tile covers the innermost
      // dimension only; leading dimensions of the extract shape are 1.
      auto originalShape =
          llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
      SmallVector<int64_t> strides(originalShape.size(), 1);
      SmallVector<int64_t> extractShape(originalShape.size(), 1);
      extractShape.back() = targetShape->back();
      SmallVector<Type> newTypes;
      Value returnValue = returnOp.getOperand(origResultNo);
      for (SmallVector<int64_t> offsets :
           StaticTileOffsetRange(originalShape, *targetShape)) {
        Value result = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, returnValue, offsets, extractShape, strides);
        // For multi-dimensional vectors, collapse the leading unit dims of
        // the extracted tile to obtain a rank-1 native vector.
        if (originalShape.size() > 1) {
          SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
          result =
              rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
        }
        newOperands.push_back(result);
        newTypes.push_back(unrolledType);
      }
      oneToNTypeMapping.addInputs(origResultNo, newTypes);
    }

    // Change the function signature.
    auto newFnType =
        FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
                          TypeRange(oneToNTypeMapping.getConvertedTypes()));
    rewriter.modifyOpInPlace(funcOp,
                             [&] { funcOp.setFunctionType(newFnType); });

    // Replace the return op using the new operands. This will automatically
    // update the entry block as well.
    rewriter.replaceOp(returnOp,
                       rewriter.create<func::ReturnOp>(loc, newOperands));

    return success();
  }
};
1141 
1142 } // namespace
1143 
1144 //===----------------------------------------------------------------------===//
1145 // Public function for builtin variables
1146 //===----------------------------------------------------------------------===//
1147 
1148 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
1149                                            spirv::BuiltIn builtin,
1150                                            Type integerType, OpBuilder &builder,
1151                                            StringRef prefix, StringRef suffix) {
1152   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
1153   if (!parent) {
1154     op->emitError("expected operation to be within a module-like op");
1155     return nullptr;
1156   }
1157 
1158   spirv::GlobalVariableOp varOp =
1159       getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
1160                                  builtin, integerType, builder, prefix, suffix);
1161   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
1162   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
1163 }
1164 
1165 //===----------------------------------------------------------------------===//
1166 // Public function for pushing constant storage
1167 //===----------------------------------------------------------------------===//
1168 
1169 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
1170                                   unsigned offset, Type integerType,
1171                                   OpBuilder &builder) {
1172   Location loc = op->getLoc();
1173   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
1174   if (!parent) {
1175     op->emitError("expected operation to be within a module-like op");
1176     return nullptr;
1177   }
1178 
1179   spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1180       loc, parent->getRegion(0).front(), elementCount, builder, integerType);
1181 
1182   Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1183   Value offsetOp = builder.create<spirv::ConstantOp>(
1184       loc, integerType, builder.getI32IntegerAttr(offset));
1185   auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
1186   auto acOp = builder.create<spirv::AccessChainOp>(
1187       loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
1188   return builder.create<spirv::LoadOp>(loc, acOp);
1189 }
1190 
1191 //===----------------------------------------------------------------------===//
1192 // Public functions for index calculation
1193 //===----------------------------------------------------------------------===//
1194 
1195 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
1196                                   int64_t offset, Type integerType,
1197                                   Location loc, OpBuilder &builder) {
1198   assert(indices.size() == strides.size() &&
1199          "must provide indices for all dimensions");
1200 
1201   // TODO: Consider moving to use affine.apply and patterns converting
1202   // affine.apply to standard ops. This needs converting to SPIR-V passes to be
1203   // broken down into progressive small steps so we can have intermediate steps
1204   // using other dialects. At the moment SPIR-V is the final sink.
1205 
1206   Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
1207       loc, integerType, IntegerAttr::get(integerType, offset));
1208   for (const auto &index : llvm::enumerate(indices)) {
1209     Value strideVal = builder.createOrFold<spirv::ConstantOp>(
1210         loc, integerType,
1211         IntegerAttr::get(integerType, strides[index.index()]));
1212     Value update =
1213         builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1214     linearizedIndex =
1215         builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1216   }
1217   return linearizedIndex;
1218 }
1219 
1220 Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
1221                                        MemRefType baseType, Value basePtr,
1222                                        ValueRange indices, Location loc,
1223                                        OpBuilder &builder) {
1224   // Get base and offset of the MemRefType and verify they are static.
1225 
1226   int64_t offset;
1227   SmallVector<int64_t, 4> strides;
1228   if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1229       llvm::is_contained(strides, ShapedType::kDynamic) ||
1230       ShapedType::isDynamic(offset)) {
1231     return nullptr;
1232   }
1233 
1234   auto indexType = typeConverter.getIndexType();
1235 
1236   SmallVector<Value, 2> linearizedIndices;
1237   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1238 
1239   // Add a '0' at the start to index into the struct.
1240   linearizedIndices.push_back(zero);
1241 
1242   if (baseType.getRank() == 0) {
1243     linearizedIndices.push_back(zero);
1244   } else {
1245     linearizedIndices.push_back(
1246         linearizeIndex(indices, strides, offset, indexType, loc, builder));
1247   }
1248   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
1249 }
1250 
1251 Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
1252                                        MemRefType baseType, Value basePtr,
1253                                        ValueRange indices, Location loc,
1254                                        OpBuilder &builder) {
1255   // Get base and offset of the MemRefType and verify they are static.
1256 
1257   int64_t offset;
1258   SmallVector<int64_t, 4> strides;
1259   if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1260       llvm::is_contained(strides, ShapedType::kDynamic) ||
1261       ShapedType::isDynamic(offset)) {
1262     return nullptr;
1263   }
1264 
1265   auto indexType = typeConverter.getIndexType();
1266 
1267   SmallVector<Value, 2> linearizedIndices;
1268   Value linearIndex;
1269   if (baseType.getRank() == 0) {
1270     linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1271   } else {
1272     linearIndex =
1273         linearizeIndex(indices, strides, offset, indexType, loc, builder);
1274   }
1275   Type pointeeType =
1276       cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
1277   if (isa<spirv::ArrayType>(pointeeType)) {
1278     linearizedIndices.push_back(linearIndex);
1279     return builder.create<spirv::AccessChainOp>(loc, basePtr,
1280                                                 linearizedIndices);
1281   }
1282   return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
1283                                                  linearizedIndices);
1284 }
1285 
1286 Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
1287                                  MemRefType baseType, Value basePtr,
1288                                  ValueRange indices, Location loc,
1289                                  OpBuilder &builder) {
1290 
1291   if (typeConverter.allows(spirv::Capability::Kernel)) {
1292     return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1293                                builder);
1294   }
1295 
1296   return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1297                              builder);
1298 }
1299 
1300 //===----------------------------------------------------------------------===//
1301 // Public functions for vector unrolling
1302 //===----------------------------------------------------------------------===//
1303 
1304 int mlir::spirv::getComputeVectorSize(int64_t size) {
1305   for (int i : {4, 3, 2}) {
1306     if (size % i == 0)
1307       return i;
1308   }
1309   return 1;
1310 }
1311 
1312 SmallVector<int64_t>
1313 mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
1314   VectorType srcVectorType = op.getSourceVectorType();
1315   assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1316   int64_t vectorSize =
1317       mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
1318   return {vectorSize};
1319 }
1320 
1321 SmallVector<int64_t>
1322 mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
1323   VectorType vectorType = op.getResultVectorType();
1324   SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
1325   nativeSize.back() =
1326       mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
1327   return nativeSize;
1328 }
1329 
1330 std::optional<SmallVector<int64_t>>
1331 mlir::spirv::getNativeVectorShape(Operation *op) {
1332   if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
1333     if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
1334       SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
1335       nativeSize.back() =
1336           mlir::spirv::getComputeVectorSize(vecType.getShape().back());
1337       return nativeSize;
1338     }
1339   }
1340 
1341   return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
1342       .Case<vector::ReductionOp, vector::TransposeOp>(
1343           [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
1344       .Default([](Operation *) { return std::nullopt; });
1345 }
1346 
1347 LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
1348   MLIRContext *context = op->getContext();
1349   RewritePatternSet patterns(context);
1350   populateFuncOpVectorRewritePatterns(patterns);
1351   populateReturnOpVectorRewritePatterns(patterns);
1352   // We only want to apply signature conversion once to the existing func ops.
1353   // Without specifying strictMode, the greedy pattern rewriter will keep
1354   // looking for newly created func ops.
1355   GreedyRewriteConfig config;
1356   config.strictMode = GreedyRewriteStrictness::ExistingOps;
1357   return applyPatternsGreedily(op, std::move(patterns), config);
1358 }
1359 
1360 LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
1361   MLIRContext *context = op->getContext();
1362 
1363   // Unroll vectors in function bodies to native vector size.
1364   {
1365     RewritePatternSet patterns(context);
1366     auto options = vector::UnrollVectorOptions().setNativeShapeFn(
1367         [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1368     populateVectorUnrollPatterns(patterns, options);
1369     if (failed(applyPatternsGreedily(op, std::move(patterns))))
1370       return failure();
1371   }
1372 
1373   // Convert transpose ops into extract and insert pairs, in preparation of
1374   // further transformations to canonicalize/cancel.
1375   {
1376     RewritePatternSet patterns(context);
1377     auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
1378         vector::VectorTransposeLowering::EltWise);
1379     vector::populateVectorTransposeLoweringPatterns(patterns, options);
1380     vector::populateVectorShapeCastLoweringPatterns(patterns);
1381     if (failed(applyPatternsGreedily(op, std::move(patterns))))
1382       return failure();
1383   }
1384 
1385   // Run canonicalization to cast away leading size-1 dimensions.
1386   {
1387     RewritePatternSet patterns(context);
1388 
1389     // We need to pull in casting way leading one dims.
1390     vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1391     vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1392     vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1393 
1394     // Decompose different rank insert_strided_slice and n-D
1395     // extract_slided_slice.
1396     vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1397         patterns);
1398     vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1399     vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1400 
1401     // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1402     // them up.
1403     vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1404     vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1405 
1406     if (failed(applyPatternsGreedily(op, std::move(patterns))))
1407       return failure();
1408   }
1409   return success();
1410 }
1411 
1412 //===----------------------------------------------------------------------===//
1413 // SPIR-V TypeConverter
1414 //===----------------------------------------------------------------------===//
1415 
/// Constructs a type converter that maps builtin types (index, integer,
/// float, complex, vector, tensor, memref) to SPIR-V types legal in the
/// given target environment, honoring the provided conversion options.
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
                                       const SPIRVConversionOptions &options)
    : targetEnv(targetAttr), options(options) {
  // Add conversions. The order matters here: later ones will be tried earlier.

  // Allow all SPIR-V dialect specific types. This assumes all builtin types
  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
  // were tried before.
  //
  // TODO: This assumes that the SPIR-V types are valid to use in the given
  // target environment, which should be the case if the whole pipeline is
  // driven by the same target environment. Still, we probably still want to
  // validate and convert to be safe.
  addConversion([](spirv::SPIRVType type) { return type; });

  // Index is lowered to an integer type whose width is picked from the
  // conversion options (see getIndexType()).
  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });

  // Integers that are valid SPIR-V scalar types go through the generic
  // scalar conversion; sub-byte widths (< 8 bits) have dedicated handling.
  // Returning a null Type (wrapped in an engaged optional) marks the type
  // as unconvertible without trying earlier conversions.
  addConversion([this](IntegerType intType) -> std::optional<Type> {
    if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
      return convertScalarType(this->targetEnv, this->options, scalarType);
    if (intType.getWidth() < 8)
      return convertSubByteIntegerType(this->options, intType);
    return Type();
  });

  // Floats that are valid SPIR-V scalar types go through the generic scalar
  // conversion; all others are rejected.
  addConversion([this](FloatType floatType) -> std::optional<Type> {
    if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
      return convertScalarType(this->targetEnv, this->options, scalarType);
    return Type();
  });

  addConversion([this](ComplexType complexType) {
    return convertComplexType(this->targetEnv, this->options, complexType);
  });

  addConversion([this](VectorType vectorType) {
    return convertVectorType(this->targetEnv, this->options, vectorType);
  });

  addConversion([this](TensorType tensorType) {
    return convertTensorType(this->targetEnv, this->options, tensorType);
  });

  addConversion([this](MemRefType memRefType) {
    return convertMemrefType(this->targetEnv, this->options, memRefType);
  });

  // Register some last line of defense casting logic.
  addSourceMaterialization(
      [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
        return castToSourceType(this->targetEnv, builder, type, inputs, loc);
      });
  // Target materialization just inserts an unrealized_conversion_cast, to be
  // resolved (or reported as an error) later in the conversion process.
  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
                              Location loc) {
    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return cast.getResult(0);
  });
}
1474 
1475 Type SPIRVTypeConverter::getIndexType() const {
1476   return ::getIndexType(getContext(), options);
1477 }
1478 
1479 MLIRContext *SPIRVTypeConverter::getContext() const {
1480   return targetEnv.getAttr().getContext();
1481 }
1482 
1483 bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
1484   return targetEnv.allows(capability);
1485 }
1486 
1487 //===----------------------------------------------------------------------===//
1488 // SPIR-V ConversionTarget
1489 //===----------------------------------------------------------------------===//
1490 
std::unique_ptr<SPIRVConversionTarget>
SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
  std::unique_ptr<SPIRVConversionTarget> target(
      // std::make_unique does not work here because the constructor is private.
      new SPIRVConversionTarget(targetAttr));
  SPIRVConversionTarget *targetPtr = target.get();
  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
      // Capture the raw pointer, not the unique_ptr local: the pointed-to
      // object stays alive (ownership is transferred to the caller), while
      // the local `target` variable itself goes away when this function
      // returns, so capturing it by reference would dangle.
      [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
  return target;
}
1503 
/// Private constructor; use SPIRVConversionTarget::get() to create instances.
SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
    : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1506 
/// Returns true if the given op is legal in the target environment: its
/// version/extension/capability requirements are satisfied, and all of its
/// operand/result types are SPIR-V types allowed by the environment.
bool SPIRVConversionTarget::isLegalOp(Operation *op) {
  // Make sure this op is available at the given version. Ops not implementing
  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
  // SPIR-V versions.
  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
    std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
    if (minVersion && *minVersion > this->targetEnv.getVersion()) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring min version "
                 << spirv::stringifyVersion(*minVersion) << "\n");
      return false;
    }
  }
  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
    std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
    if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring max version "
                 << spirv::stringifyVersion(*maxVersion) << "\n");
      return false;
    }
  }

  // Make sure this op's required extensions are allowed to use. Ops not
  // implementing QueryExtensionInterface do not require extensions to be
  // available.
  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          extensions.getExtensions())))
      return false;

  // Make sure this op's required capabilities are allowed to use. Ops not
  // implementing QueryCapabilityInterface do not require capabilities to be
  // available.
  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           capabilities.getCapabilities())))
      return false;

  // Gather all operand and result types for type-based legality checks.
  SmallVector<Type, 4> valueTypes;
  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
  valueTypes.append(op->result_type_begin(), op->result_type_end());

  // Ensure that all types have been converted to SPIRV types.
  if (llvm::any_of(valueTypes,
                   [](Type t) { return !isa<spirv::SPIRVType>(t); }))
    return false;

  // Special treatment for global variables, whose type requirements are
  // conveyed by type attributes.
  // NOTE: this is appended after the SPIRVType check above; the global
  // variable's type attribute is assumed to already be a SPIR-V type (the
  // casts in the loop below would assert otherwise).
  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
    valueTypes.push_back(globalVar.getType());

  // Make sure the op's operands/results use types that are allowed by the
  // target environment.
  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
  for (Type valueType : valueTypes) {
    typeExtensions.clear();
    cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          typeExtensions)))
      return false;

    typeCapabilities.clear();
    cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           typeCapabilities)))
      return false;
  }

  return true;
}
1580 
1581 //===----------------------------------------------------------------------===//
1582 // Public functions for populating patterns
1583 //===----------------------------------------------------------------------===//
1584 
1585 void mlir::populateBuiltinFuncToSPIRVPatterns(
1586     const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1587   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
1588 }
1589 
1590 void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
1591   patterns.add<FuncOpVectorUnroll>(patterns.getContext());
1592 }
1593 
1594 void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
1595   patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
1596 }
1597