xref: /llvm-project/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp (revision fab2bb8bfda865bd438dee981d7be7df8017b76d)
1 //===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass that unifies access of multiple aliased resources
10 // into access of one single resource.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinAttributes.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/SymbolTable.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/Debug.h"
28 #include <algorithm>
29 #include <iterator>
30 
31 namespace mlir {
32 namespace spirv {
33 #define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
34 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
35 } // namespace spirv
36 } // namespace mlir
37 
38 #define DEBUG_TYPE "spirv-unify-aliased-resource"
39 
40 using namespace mlir;
41 
42 //===----------------------------------------------------------------------===//
43 // Utility functions
44 //===----------------------------------------------------------------------===//
45 
46 using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
47 using AliasedResourceMap =
48     DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
49 
50 /// Collects all aliased resources in the given SPIR-V `moduleOp`.
collectAliasedResources(spirv::ModuleOp moduleOp)51 static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
52   AliasedResourceMap aliasedResources;
53   moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
54     if (varOp->getAttrOfType<UnitAttr>("aliased")) {
55       std::optional<uint32_t> set = varOp.getDescriptorSet();
56       std::optional<uint32_t> binding = varOp.getBinding();
57       if (set && binding)
58         aliasedResources[{*set, *binding}].push_back(varOp);
59     }
60   });
61   return aliasedResources;
62 }
63 
64 /// Returns the element type if the given `type` is a runtime array resource:
65 /// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type
66 /// otherwise.
getRuntimeArrayElementType(Type type)67 static Type getRuntimeArrayElementType(Type type) {
68   auto ptrType = dyn_cast<spirv::PointerType>(type);
69   if (!ptrType)
70     return {};
71 
72   auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
73   if (!structType || structType.getNumElements() != 1)
74     return {};
75 
76   auto rtArrayType =
77       dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0));
78   if (!rtArrayType)
79     return {};
80 
81   return rtArrayType.getElementType();
82 }
83 
/// Given a list of resource element `types`, returns the index of the canonical
/// resource that all resources should be unified into. Returns std::nullopt if
/// unable to unify.
///
/// All types must already be scalars or vectors (asserted below). The chosen
/// canonical type's bitwidth must evenly divide every other type's bitwidth so
/// that access indices can later be rescaled by an integer ratio.
static std::optional<int>
deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
  // scalarNumBits: contains all resources' scalar types' bit counts.
  // vectorNumBits: only contains resources whose element types are vectors.
  // vectorIndices: each vector's original index in `types`.
  SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
  scalarNumBits.reserve(types.size());
  vectorNumBits.reserve(types.size());
  vectorIndices.reserve(types.size());

  for (const auto &indexedTypes : llvm::enumerate(types)) {
    spirv::SPIRVType type = indexedTypes.value();
    assert(type.isScalarOrVector());
    if (auto vectorType = dyn_cast<VectorType>(type)) {
      if (vectorType.getNumElements() % 2 != 0)
        return std::nullopt; // Odd-sized vector has special layout
                             // requirements.

      std::optional<int64_t> numBytes = type.getSizeInBytes();
      if (!numBytes)
        return std::nullopt;

      // Record the vector's scalar bitwidth alongside its total bitwidth.
      scalarNumBits.push_back(
          vectorType.getElementType().getIntOrFloatBitWidth());
      vectorNumBits.push_back(*numBytes * 8);
      vectorIndices.push_back(indexedTypes.index());
    } else {
      scalarNumBits.push_back(type.getIntOrFloatBitWidth());
    }
  }

  if (!vectorNumBits.empty()) {
    // Choose the *vector* with the smallest bitwidth as the canonical resource,
    // so that we can still keep vectorized load/store and avoid partial updates
    // to large vectors.
    auto *minVal = llvm::min_element(vectorNumBits);
    // Make sure that the canonical resource's bitwidth is divisible by others.
    // Without this, we cannot properly adjust the index later.
    if (llvm::any_of(vectorNumBits,
                     [&](int bits) { return bits % *minVal != 0; }))
      return std::nullopt;

    // Require all scalar type bit counts to be a multiple of the chosen
    // vector's primitive type to avoid reading/writing subcomponents.
    int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
    int baseNumBits = scalarNumBits[index];
    if (llvm::any_of(scalarNumBits,
                     [&](int bits) { return bits % baseNumBits != 0; }))
      return std::nullopt;

    return index;
  }

  // All element types are scalars. Then choose the smallest bitwidth as the
  // canonical resource to avoid subcomponent load/store.
  auto *minVal = llvm::min_element(scalarNumBits);
  if (llvm::any_of(scalarNumBits,
                   [minVal](int64_t bit) { return bit % *minVal != 0; }))
    return std::nullopt;
  return std::distance(scalarNumBits.begin(), minVal);
}
148 
areSameBitwidthScalarType(Type a,Type b)149 static bool areSameBitwidthScalarType(Type a, Type b) {
150   return a.isIntOrFloat() && b.isIntOrFloat() &&
151          a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
152 }
153 
154 //===----------------------------------------------------------------------===//
155 // Analysis
156 //===----------------------------------------------------------------------===//
157 
namespace {
/// A class for analyzing aliased resources.
///
/// Resources are expected to be spirv.GlobalVariable ops that have a
/// descriptor set and binding number. Such resources are of the type
/// `!spirv.ptr<!spirv.struct<...>>` per Vulkan requirements.
///
/// Right now, we only support the case that there is a single runtime array
/// inside the struct.
class ResourceAliasAnalysis {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)

  /// Analyzes the given `root` op, which must be a spirv.module.
  explicit ResourceAliasAnalysis(Operation *);

  /// Returns true if the given `op` can be rewritten to use a canonical
  /// resource.
  bool shouldUnify(Operation *op) const;

  /// Returns all descriptors and their corresponding aliased resources.
  const AliasedResourceMap &getResourceMap() const { return resourceMap; }

  /// Returns the canonical resource for the given descriptor/variable.
  /// Returns a null op if the descriptor/variable was not deemed unifiable.
  spirv::GlobalVariableOp
  getCanonicalResource(const Descriptor &descriptor) const;
  spirv::GlobalVariableOp
  getCanonicalResource(spirv::GlobalVariableOp varOp) const;

  /// Returns the element type for the given variable; null if unknown.
  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;

private:
  /// Given the descriptor and aliased resources bound to it, analyze whether we
  /// can unify them and record if so.
  void recordIfUnifiable(const Descriptor &descriptor,
                         ArrayRef<spirv::GlobalVariableOp> resources);

  /// Mapping from a descriptor to all aliased resources bound to it.
  /// Only populated for descriptors whose resources were deemed unifiable.
  AliasedResourceMap resourceMap;

  /// Mapping from a descriptor to the chosen canonical resource.
  DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;

  /// Mapping from an aliased resource to its descriptor.
  DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;

  /// Mapping from an aliased resource to its element (scalar/vector) type.
  DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
};
} // namespace
208 
ResourceAliasAnalysis(Operation * root)209 ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
210   // Collect all aliased resources first and put them into different sets
211   // according to the descriptor.
212   AliasedResourceMap aliasedResources =
213       collectAliasedResources(cast<spirv::ModuleOp>(root));
214 
215   // For each resource set, analyze whether we can unify; if so, try to identify
216   // a canonical resource, whose element type has the largest bitwidth.
217   for (const auto &descriptorResource : aliasedResources) {
218     recordIfUnifiable(descriptorResource.first, descriptorResource.second);
219   }
220 }
221 
shouldUnify(Operation * op) const222 bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
223   if (!op)
224     return false;
225 
226   if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
227     auto canonicalOp = getCanonicalResource(varOp);
228     return canonicalOp && varOp != canonicalOp;
229   }
230   if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
231     auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
232     auto *varOp =
233         SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable());
234     return shouldUnify(varOp);
235   }
236 
237   if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
238     return shouldUnify(acOp.getBasePtr().getDefiningOp());
239   if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
240     return shouldUnify(loadOp.getPtr().getDefiningOp());
241   if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
242     return shouldUnify(storeOp.getPtr().getDefiningOp());
243 
244   return false;
245 }
246 
getCanonicalResource(const Descriptor & descriptor) const247 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
248     const Descriptor &descriptor) const {
249   auto varIt = canonicalResourceMap.find(descriptor);
250   if (varIt == canonicalResourceMap.end())
251     return {};
252   return varIt->second;
253 }
254 
getCanonicalResource(spirv::GlobalVariableOp varOp) const255 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
256     spirv::GlobalVariableOp varOp) const {
257   auto descriptorIt = descriptorMap.find(varOp);
258   if (descriptorIt == descriptorMap.end())
259     return {};
260   return getCanonicalResource(descriptorIt->second);
261 }
262 
263 spirv::SPIRVType
getElementType(spirv::GlobalVariableOp varOp) const264 ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
265   auto it = elementTypeMap.find(varOp);
266   if (it == elementTypeMap.end())
267     return {};
268   return it->second;
269 }
270 
recordIfUnifiable(const Descriptor & descriptor,ArrayRef<spirv::GlobalVariableOp> resources)271 void ResourceAliasAnalysis::recordIfUnifiable(
272     const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
273   // Collect the element types for all resources in the current set.
274   SmallVector<spirv::SPIRVType> elementTypes;
275   for (spirv::GlobalVariableOp resource : resources) {
276     Type elementType = getRuntimeArrayElementType(resource.getType());
277     if (!elementType)
278       return; // Unexpected resource variable type.
279 
280     auto type = cast<spirv::SPIRVType>(elementType);
281     if (!type.isScalarOrVector())
282       return; // Unexpected resource element type.
283 
284     elementTypes.push_back(type);
285   }
286 
287   std::optional<int> index = deduceCanonicalResource(elementTypes);
288   if (!index)
289     return;
290 
291   // Update internal data structures for later use.
292   resourceMap[descriptor].assign(resources.begin(), resources.end());
293   canonicalResourceMap[descriptor] = resources[*index];
294   for (const auto &resource : llvm::enumerate(resources)) {
295     descriptorMap[resource.value()] = descriptor;
296     elementTypeMap[resource.value()] = elementTypes[resource.index()];
297   }
298 }
299 
300 //===----------------------------------------------------------------------===//
301 // Patterns
302 //===----------------------------------------------------------------------===//
303 
/// Base class for patterns rewriting ops that touch a non-canonical aliased
/// resource so they use the canonical resource instead. Carries a reference to
/// the alias analysis shared by all concrete patterns.
template <typename OpTy>
class ConvertAliasResource : public OpConversionPattern<OpTy> {
public:
  ConvertAliasResource(const ResourceAliasAnalysis &analysis,
                       MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}

protected:
  /// Alias analysis results; must outlive the patterns referencing it.
  const ResourceAliasAnalysis &analysis;
};
314 
315 struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
316   using ConvertAliasResource::ConvertAliasResource;
317 
318   LogicalResult
matchAndRewriteConvertVariable319   matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
320                   ConversionPatternRewriter &rewriter) const override {
321     // Just remove the aliased resource. Users will be rewritten to use the
322     // canonical one.
323     rewriter.eraseOp(varOp);
324     return success();
325   }
326 };
327 
328 struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
329   using ConvertAliasResource::ConvertAliasResource;
330 
331   LogicalResult
matchAndRewriteConvertAddressOf332   matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
333                   ConversionPatternRewriter &rewriter) const override {
334     // Rewrite the AddressOf op to get the address of the canoncical resource.
335     auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
336     auto srcVarOp = cast<spirv::GlobalVariableOp>(
337         SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
338     auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
339     rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
340     return success();
341   }
342 };
343 
/// Rewrites spirv.AccessChain ops into the canonical resource, rescaling the
/// last index when source and destination element bitwidths differ.
struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
    if (!addressOp)
      return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");

    // Resolve both the aliased resource and the canonical one it maps to.
    auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
    auto srcVarOp = cast<spirv::GlobalVariableOp>(
        SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
    auto dstVarOp = analysis.getCanonicalResource(srcVarOp);

    spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
    spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);

    if (srcElemType == dstElemType ||
        areSameBitwidthScalarType(srcElemType, dstElemType)) {
      // We have the same bitwidth for source and destination element types.
      // The indices stay the same.
      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), adaptor.getIndices());
      return success();
    }

    Location loc = acOp.getLoc();

    if (srcElemType.isIntOrFloat() && isa<VectorType>(dstElemType)) {
      // The source indices are for a buffer with scalar element types. Rewrite
      // them into a buffer with vector element types. We need to scale the last
      // index for the vector as a whole, then add one level of index for inside
      // the vector.
      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      // The analysis guarantees divisibility; see deduceCanonicalResource.
      assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);

      auto indices = llvm::to_vector<4>(acOp.getIndices());
      Value oldIndex = indices.back();
      Type indexType = oldIndex.getType();

      // scalar index i becomes (i / ratio)[i % ratio] in the vector buffer.
      int ratio = dstNumBytes / srcNumBytes;
      auto ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));

      indices.back() =
          rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
      indices.push_back(
          rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));

      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), indices);
      return success();
    }

    if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
        (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
      // The source indices are for a buffer with larger bitwidth scalar/vector
      // element types. Rewrite them into a buffer with smaller bitwidth element
      // types. We only need to scale the last index.
      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      // The analysis guarantees divisibility; see deduceCanonicalResource.
      assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);

      auto indices = llvm::to_vector<4>(acOp.getIndices());
      Value oldIndex = indices.back();
      Type indexType = oldIndex.getType();

      // One source element spans `ratio` destination elements.
      int ratio = srcNumBytes / dstNumBytes;
      auto ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));

      indices.back() =
          rewriter.create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);

      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), indices);
      return success();
    }

    return rewriter.notifyMatchFailure(
        acOp, "unsupported src/dst types for spirv.AccessChain");
  }
};
429 
/// Rewrites spirv.Load ops on aliased resources. When the canonical resource
/// has a smaller element bitwidth, performs several loads and reassembles the
/// original wider value via composite construction and bitcasts.
struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // src: element type seen by the original load; dst: element type of the
    // already-converted (canonical) pointer.
    auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
    auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
    auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
    auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());

    Location loc = loadOp.getLoc();
    auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
    if (srcElemType == dstElemType) {
      // Same element type: the new load is a drop-in replacement.
      rewriter.replaceOp(loadOp, newLoadOp->getResults());
      return success();
    }

    if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
      // Same bitwidth scalars (e.g. i32 vs f32): one bitcast suffices.
      auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
                                                      newLoadOp.getValue());
      rewriter.replaceOp(loadOp, castOp->getResults());

      return success();
    }

    if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
        (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
      // The source and destination have scalar types of different bitwidths, or
      // vector types of different component counts. For such cases, we load
      // multiple smaller bitwidth values and construct a larger bitwidth one.

      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
      int ratio = srcNumBytes / dstNumBytes;
      // SPIR-V vectors have at most 4 components, which bounds the
      // reassembly below.
      if (ratio > 4)
        return rewriter.notifyMatchFailure(loadOp, "more than 4 components");

      SmallVector<Value> components;
      components.reserve(ratio);
      components.push_back(newLoadOp);

      // We need the access chain to derive pointers to the remaining
      // components of this element.
      auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
      if (!acOp)
        return rewriter.notifyMatchFailure(loadOp, "ptr not spirv.AccessChain");

      auto i32Type = rewriter.getI32Type();
      Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
      auto indices = llvm::to_vector<4>(acOp.getIndices());
      for (int i = 1; i < ratio; ++i) {
        // Load all subsequent components belonging to this element.
        indices.back() = rewriter.create<spirv::IAddOp>(
            loc, i32Type, indices.back(), oneValue);
        auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
            loc, acOp.getBasePtr(), indices);
        // Assuming little endian, this reads lower-ordered bits of the number
        // to lower-numbered components of the vector.
        components.push_back(
            rewriter.create<spirv::LoadOp>(loc, componentAcOp));
      }

      // Create a vector of the components and then cast back to the larger
      // bitwidth element type. For spirv.bitcast, the lower-numbered components
      // of the vector map to lower-ordered bits of the larger bitwidth element
      // type.

      Type vectorType = srcElemType;
      if (!isa<VectorType>(srcElemType))
        vectorType = VectorType::get({ratio}, dstElemType);

      // If both the source and destination are vector types, we need to make
      // sure the scalar type is the same for composite construction later.
      if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
        if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
          if (srcElemVecType.getElementType() !=
              dstElemVecType.getElementType()) {
            int64_t count =
                dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);

            // Make sure not to create 1-element vectors, which are illegal in
            // SPIR-V.
            Type castType = srcElemVecType.getElementType();
            if (count > 1)
              castType = VectorType::get({count}, castType);

            for (Value &c : components)
              c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
          }
        }
      Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
          loc, vectorType, components);

      // Scalar sources need a final cast from the assembled vector back to
      // the scalar type.
      if (!isa<VectorType>(srcElemType))
        vectorValue =
            rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
      rewriter.replaceOp(loadOp, vectorValue);
      return success();
    }

    return rewriter.notifyMatchFailure(
        loadOp, "unsupported src/dst types for spirv.Load");
  }
};
534 
535 struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
536   using ConvertAliasResource::ConvertAliasResource;
537 
538   LogicalResult
matchAndRewriteConvertStore539   matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
540                   ConversionPatternRewriter &rewriter) const override {
541     auto srcElemType =
542         cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
543     auto dstElemType =
544         cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
545     if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
546       return rewriter.notifyMatchFailure(storeOp, "not scalar type");
547     if (!areSameBitwidthScalarType(srcElemType, dstElemType))
548       return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
549 
550     Location loc = storeOp.getLoc();
551     Value value = adaptor.getValue();
552     if (srcElemType != dstElemType)
553       value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
554     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(),
555                                                 value, storeOp->getAttrs());
556     return success();
557   }
558 };
559 
560 //===----------------------------------------------------------------------===//
561 // Pass
562 //===----------------------------------------------------------------------===//
563 
namespace {
/// Pass that unifies access of multiple aliased resources bound to the same
/// descriptor into access of one single canonical resource.
class UnifyAliasedResourcePass final
    : public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
          UnifyAliasedResourcePass> {
public:
  /// `getTargetEnv` may be null; when provided, it is queried per module and
  /// the pass only runs for client APIs that require unification.
  explicit UnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv)
      : getTargetEnvFn(std::move(getTargetEnv)) {}

  void runOnOperation() override;

private:
  /// Optional callback to fetch the target environment for the module.
  spirv::GetTargetEnvFn getTargetEnvFn;
};
577 
runOnOperation()578 void UnifyAliasedResourcePass::runOnOperation() {
579   spirv::ModuleOp moduleOp = getOperation();
580   MLIRContext *context = &getContext();
581 
582   if (getTargetEnvFn) {
583     // This pass is only needed for targeting WebGPU, Metal, or layering
584     // Vulkan on Metal via MoltenVK, where we need to translate SPIR-V into
585     // WGSL or MSL. The translation has limitations.
586     spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
587     spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
588     bool isVulkanOnAppleDevices =
589         clientAPI == spirv::ClientAPI::Vulkan &&
590         targetEnv.getVendorID() == spirv::Vendor::Apple;
591     if (clientAPI != spirv::ClientAPI::WebGPU &&
592         clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
593       return;
594   }
595 
596   // Analyze aliased resources first.
597   ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
598 
599   ConversionTarget target(*context);
600   target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
601                                spirv::AccessChainOp, spirv::LoadOp,
602                                spirv::StoreOp>(
603       [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
604   target.addLegalDialect<spirv::SPIRVDialect>();
605 
606   // Run patterns to rewrite usages of non-canonical resources.
607   RewritePatternSet patterns(context);
608   patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
609                ConvertLoad, ConvertStore>(analysis, context);
610   if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
611     return signalPassFailure();
612 
613   // Drop aliased attribute if we only have one single bound resource for a
614   // descriptor. We need to re-collect the map here given in the above the
615   // conversion is best effort; certain sets may not be converted.
616   AliasedResourceMap resourceMap =
617       collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
618   for (const auto &dr : resourceMap) {
619     const auto &resources = dr.second;
620     if (resources.size() == 1)
621       resources.front()->removeAttr("aliased");
622   }
623 }
624 } // namespace
625 
626 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv)627 spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) {
628   return std::make_unique<UnifyAliasedResourcePass>(std::move(getTargetEnv));
629 }
630