xref: /llvm-project/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (revision 129ec845749fe117970f71c330945b5709e1d220)
126be7fe2SLei Zhang //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
226be7fe2SLei Zhang //
326be7fe2SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
426be7fe2SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
526be7fe2SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
626be7fe2SLei Zhang //
726be7fe2SLei Zhang //===----------------------------------------------------------------------===//
826be7fe2SLei Zhang //
926be7fe2SLei Zhang // This file implements patterns to convert MemRef dialect to SPIR-V dialect.
1026be7fe2SLei Zhang //
1126be7fe2SLei Zhang //===----------------------------------------------------------------------===//
1226be7fe2SLei Zhang 
1378e172fcSLei Zhang #include "mlir/Dialect/Arith/IR/Arith.h"
1426be7fe2SLei Zhang #include "mlir/Dialect/MemRef/IR/MemRef.h"
158fd0bce4SJakub Kuderski #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
168854b736SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1726be7fe2SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
180c7f3d6cSJakub Kuderski #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1926be7fe2SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
208fd0bce4SJakub Kuderski #include "mlir/IR/BuiltinAttributes.h"
2178e172fcSLei Zhang #include "mlir/IR/BuiltinTypes.h"
228fd0bce4SJakub Kuderski #include "mlir/IR/MLIRContext.h"
238fd0bce4SJakub Kuderski #include "mlir/IR/Visitors.h"
2426be7fe2SLei Zhang #include "llvm/Support/Debug.h"
258fd0bce4SJakub Kuderski #include <cassert>
26a1fe1f5fSKazu Hirata #include <optional>
2726be7fe2SLei Zhang 
2826be7fe2SLei Zhang #define DEBUG_TYPE "memref-to-spirv-pattern"
2926be7fe2SLei Zhang 
3026be7fe2SLei Zhang using namespace mlir;
3126be7fe2SLei Zhang 
3226be7fe2SLei Zhang //===----------------------------------------------------------------------===//
3326be7fe2SLei Zhang // Utility functions
3426be7fe2SLei Zhang //===----------------------------------------------------------------------===//
3526be7fe2SLei Zhang 
3626be7fe2SLei Zhang /// Returns the offset of the value in `targetBits` representation.
3726be7fe2SLei Zhang ///
3826be7fe2SLei Zhang /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
3926be7fe2SLei Zhang /// It's assumed to be non-negative.
4026be7fe2SLei Zhang ///
4126be7fe2SLei Zhang /// When accessing an element in the array treating as having elements of
4226be7fe2SLei Zhang /// `targetBits`, multiple values are loaded in the same time. The method
4326be7fe2SLei Zhang /// returns the offset where the `srcIdx` locates in the value. For example, if
4426be7fe2SLei Zhang /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
4526be7fe2SLei Zhang /// located at (x % 4) * 8. Because there are four elements in one i32, and one
4626be7fe2SLei Zhang /// element has 8 bits.
4726be7fe2SLei Zhang static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
4826be7fe2SLei Zhang                                   int targetBits, OpBuilder &builder) {
4926be7fe2SLei Zhang   assert(targetBits % sourceBits == 0);
504ffc63abSLei Zhang   Type type = srcIdx.getType();
514ffc63abSLei Zhang   IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
5238f8a3cfSFinn Plummer   auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
534ffc63abSLei Zhang   IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
5438f8a3cfSFinn Plummer   auto srcBitsValue =
5538f8a3cfSFinn Plummer       builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
5638f8a3cfSFinn Plummer   auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
5738f8a3cfSFinn Plummer   return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
5826be7fe2SLei Zhang }
5926be7fe2SLei Zhang 
6026be7fe2SLei Zhang /// Returns an adjusted spirv::AccessChainOp. Based on the
6126be7fe2SLei Zhang /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
6226be7fe2SLei Zhang /// supported. During conversion if a memref of an unsupported type is used,
6326be7fe2SLei Zhang /// load/stores to this memref need to be modified to use a supported higher
6426be7fe2SLei Zhang /// bitwidth `targetBits` and extracting the required bits. For an accessing a
654ffc63abSLei Zhang /// 1D array (spirv.array or spirv.rtarray), the last index is modified to load
665ab6ef75SJakub Kuderski /// the bits needed. The extraction of the actual bits needed are handled
6726be7fe2SLei Zhang /// separately. Note that this only works for a 1-D tensor.
68ce254598SMatthias Springer static Value
69ce254598SMatthias Springer adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
70ce254598SMatthias Springer                              spirv::AccessChainOp op, int sourceBits,
71ce254598SMatthias Springer                              int targetBits, OpBuilder &builder) {
7226be7fe2SLei Zhang   assert(targetBits % sourceBits == 0);
7326be7fe2SLei Zhang   const auto loc = op.getLoc();
744ffc63abSLei Zhang   Value lastDim = op->getOperand(op.getNumOperands() - 1);
754ffc63abSLei Zhang   Type type = lastDim.getType();
764ffc63abSLei Zhang   IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
7738f8a3cfSFinn Plummer   auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
7890a1632dSJakub Kuderski   auto indices = llvm::to_vector<4>(op.getIndices());
7926be7fe2SLei Zhang   // There are two elements if this is a 1-D tensor.
8026be7fe2SLei Zhang   assert(indices.size() == 2);
8138f8a3cfSFinn Plummer   indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
8290a1632dSJakub Kuderski   Type t = typeConverter.convertType(op.getComponentPtr().getType());
8390a1632dSJakub Kuderski   return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
8426be7fe2SLei Zhang }
8526be7fe2SLei Zhang 
86887e1aa3SJakub Kuderski /// Casts the given `srcBool` into an integer of `dstType`.
87887e1aa3SJakub Kuderski static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
88887e1aa3SJakub Kuderski                             OpBuilder &builder) {
89887e1aa3SJakub Kuderski   assert(srcBool.getType().isInteger(1));
90887e1aa3SJakub Kuderski   if (dstType.isInteger(1))
91887e1aa3SJakub Kuderski     return srcBool;
92887e1aa3SJakub Kuderski   Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
93887e1aa3SJakub Kuderski   Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
9438f8a3cfSFinn Plummer   return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
9538f8a3cfSFinn Plummer                                                zero);
96887e1aa3SJakub Kuderski }
97887e1aa3SJakub Kuderski 
98887e1aa3SJakub Kuderski /// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
99887e1aa3SJakub Kuderski /// to the type destination type, and masked.
10026be7fe2SLei Zhang static Value shiftValue(Location loc, Value value, Value offset, Value mask,
101887e1aa3SJakub Kuderski                         OpBuilder &builder) {
102887e1aa3SJakub Kuderski   IntegerType dstType = cast<IntegerType>(mask.getType());
103887e1aa3SJakub Kuderski   int targetBits = static_cast<int>(dstType.getWidth());
104887e1aa3SJakub Kuderski   int valueBits = value.getType().getIntOrFloatBitWidth();
105887e1aa3SJakub Kuderski   assert(valueBits <= targetBits);
106887e1aa3SJakub Kuderski 
107887e1aa3SJakub Kuderski   if (valueBits == 1) {
108887e1aa3SJakub Kuderski     value = castBoolToIntN(loc, value, dstType, builder);
109887e1aa3SJakub Kuderski   } else {
110887e1aa3SJakub Kuderski     if (valueBits < targetBits) {
111887e1aa3SJakub Kuderski       value = builder.create<spirv::UConvertOp>(
112887e1aa3SJakub Kuderski           loc, builder.getIntegerType(targetBits), value);
113887e1aa3SJakub Kuderski     }
114887e1aa3SJakub Kuderski 
11538f8a3cfSFinn Plummer     value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
116887e1aa3SJakub Kuderski   }
11738f8a3cfSFinn Plummer   return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
11838f8a3cfSFinn Plummer                                                          value, offset);
11926be7fe2SLei Zhang }
12026be7fe2SLei Zhang 
1218854b736SLei Zhang /// Returns true if the allocations of memref `type` generated from `allocOp`
1228854b736SLei Zhang /// can be lowered to SPIR-V.
1238854b736SLei Zhang static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
1248854b736SLei Zhang   if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
1255550c821STres Popp     auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
12689b595e1SLei Zhang     if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
12726be7fe2SLei Zhang       return false;
1288854b736SLei Zhang   } else if (isa<memref::AllocaOp>(allocOp)) {
1295550c821STres Popp     auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
13089b595e1SLei Zhang     if (!sc || sc.getValue() != spirv::StorageClass::Function)
1318854b736SLei Zhang       return false;
1328854b736SLei Zhang   } else {
1338854b736SLei Zhang     return false;
1348854b736SLei Zhang   }
1358854b736SLei Zhang 
1368854b736SLei Zhang   // Currently only support static shape and int or float or vector of int or
1378854b736SLei Zhang   // float element type.
1388854b736SLei Zhang   if (!type.hasStaticShape())
1398854b736SLei Zhang     return false;
1408854b736SLei Zhang 
1418854b736SLei Zhang   Type elementType = type.getElementType();
1425550c821STres Popp   if (auto vecType = dyn_cast<VectorType>(elementType))
14326be7fe2SLei Zhang     elementType = vecType.getElementType();
14426be7fe2SLei Zhang   return elementType.isIntOrFloat();
14526be7fe2SLei Zhang }
14626be7fe2SLei Zhang 
14726be7fe2SLei Zhang /// Returns the scope to use for atomic operations use for emulating store
14826be7fe2SLei Zhang /// operations of unsupported integer bitwidths, based on the memref
14970c73d1bSKazu Hirata /// type. Returns std::nullopt on failure.
1500a81ace0SKazu Hirata static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
1515550c821STres Popp   auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
15289b595e1SLei Zhang   switch (sc.getValue()) {
15326be7fe2SLei Zhang   case spirv::StorageClass::StorageBuffer:
15426be7fe2SLei Zhang     return spirv::Scope::Device;
15526be7fe2SLei Zhang   case spirv::StorageClass::Workgroup:
15626be7fe2SLei Zhang     return spirv::Scope::Workgroup;
157713f85d5SLei Zhang   default:
158713f85d5SLei Zhang     break;
15926be7fe2SLei Zhang   }
16026be7fe2SLei Zhang   return {};
16126be7fe2SLei Zhang }
16226be7fe2SLei Zhang 
1630065bd2aSLei Zhang /// Casts the given `srcInt` into a boolean value.
1640065bd2aSLei Zhang static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
1650065bd2aSLei Zhang   if (srcInt.getType().isInteger(1))
1660065bd2aSLei Zhang     return srcInt;
1670065bd2aSLei Zhang 
16827158edaSDmitriy Smirnov   auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
16927158edaSDmitriy Smirnov   return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
1700065bd2aSLei Zhang }
1710065bd2aSLei Zhang 
17226be7fe2SLei Zhang //===----------------------------------------------------------------------===//
17326be7fe2SLei Zhang // Operation conversion
17426be7fe2SLei Zhang //===----------------------------------------------------------------------===//
17526be7fe2SLei Zhang 
17626be7fe2SLei Zhang // Note that DRR cannot be used for the patterns in this file: we may need to
17726be7fe2SLei Zhang // convert type along the way, which requires ConversionPattern. DRR generates
17826be7fe2SLei Zhang // normal RewritePattern.
17926be7fe2SLei Zhang 
18026be7fe2SLei Zhang namespace {
18126be7fe2SLei Zhang 
1828854b736SLei Zhang /// Converts memref.alloca to SPIR-V Function variables.
1838854b736SLei Zhang class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
1848854b736SLei Zhang public:
1858854b736SLei Zhang   using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;
1868854b736SLei Zhang 
1878854b736SLei Zhang   LogicalResult
1888854b736SLei Zhang   matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
1898854b736SLei Zhang                   ConversionPatternRewriter &rewriter) const override;
1908854b736SLei Zhang };
1918854b736SLei Zhang 
19226be7fe2SLei Zhang /// Converts an allocation operation to SPIR-V. Currently only supports lowering
19326be7fe2SLei Zhang /// to Workgroup memory when the size is constant.  Note that this pattern needs
1945ab6ef75SJakub Kuderski /// to be applied in a pass that runs at least at spirv.module scope since it
1955ab6ef75SJakub Kuderski /// wil ladd global variables into the spirv.module.
19626be7fe2SLei Zhang class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
19726be7fe2SLei Zhang public:
19826be7fe2SLei Zhang   using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
19926be7fe2SLei Zhang 
20026be7fe2SLei Zhang   LogicalResult
201b54c724bSRiver Riddle   matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
20226be7fe2SLei Zhang                   ConversionPatternRewriter &rewriter) const override;
20326be7fe2SLei Zhang };
20426be7fe2SLei Zhang 
20578e172fcSLei Zhang /// Converts memref.automic_rmw operations to SPIR-V atomic operations.
20678e172fcSLei Zhang class AtomicRMWOpPattern final
20778e172fcSLei Zhang     : public OpConversionPattern<memref::AtomicRMWOp> {
20878e172fcSLei Zhang public:
20978e172fcSLei Zhang   using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;
21078e172fcSLei Zhang 
21178e172fcSLei Zhang   LogicalResult
21278e172fcSLei Zhang   matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
21378e172fcSLei Zhang                   ConversionPatternRewriter &rewriter) const override;
21478e172fcSLei Zhang };
21578e172fcSLei Zhang 
21626be7fe2SLei Zhang /// Removed a deallocation if it is a supported allocation. Currently only
21726be7fe2SLei Zhang /// removes deallocation if the memory space is workgroup memory.
21826be7fe2SLei Zhang class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
21926be7fe2SLei Zhang public:
22026be7fe2SLei Zhang   using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
22126be7fe2SLei Zhang 
22226be7fe2SLei Zhang   LogicalResult
223b54c724bSRiver Riddle   matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
22426be7fe2SLei Zhang                   ConversionPatternRewriter &rewriter) const override;
22526be7fe2SLei Zhang };
22626be7fe2SLei Zhang 
2275ab6ef75SJakub Kuderski /// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
22826be7fe2SLei Zhang class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
22926be7fe2SLei Zhang public:
23026be7fe2SLei Zhang   using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
23126be7fe2SLei Zhang 
23226be7fe2SLei Zhang   LogicalResult
233b54c724bSRiver Riddle   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
23426be7fe2SLei Zhang                   ConversionPatternRewriter &rewriter) const override;
23526be7fe2SLei Zhang };
23626be7fe2SLei Zhang 
2375ab6ef75SJakub Kuderski /// Converts memref.load to spirv.Load + spirv.AccessChain.
23826be7fe2SLei Zhang class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
23926be7fe2SLei Zhang public:
24026be7fe2SLei Zhang   using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
24126be7fe2SLei Zhang 
24226be7fe2SLei Zhang   LogicalResult
243b54c724bSRiver Riddle   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
24426be7fe2SLei Zhang                   ConversionPatternRewriter &rewriter) const override;
24526be7fe2SLei Zhang };
24626be7fe2SLei Zhang 
2475ab6ef75SJakub Kuderski /// Converts memref.store to spirv.Store on integers.
24826be7fe2SLei Zhang class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
24926be7fe2SLei Zhang public:
25026be7fe2SLei Zhang   using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
25126be7fe2SLei Zhang 
25226be7fe2SLei Zhang   LogicalResult
253b54c724bSRiver Riddle   matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
25426be7fe2SLei Zhang                   ConversionPatternRewriter &rewriter) const override;
25526be7fe2SLei Zhang };
25626be7fe2SLei Zhang 
2577fb9bbe5SKrzysztof Drewniak /// Converts memref.memory_space_cast to the appropriate spirv cast operations.
2587fb9bbe5SKrzysztof Drewniak class MemorySpaceCastOpPattern final
2597fb9bbe5SKrzysztof Drewniak     : public OpConversionPattern<memref::MemorySpaceCastOp> {
2607fb9bbe5SKrzysztof Drewniak public:
2617fb9bbe5SKrzysztof Drewniak   using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;
2627fb9bbe5SKrzysztof Drewniak 
2637fb9bbe5SKrzysztof Drewniak   LogicalResult
2647fb9bbe5SKrzysztof Drewniak   matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
2657fb9bbe5SKrzysztof Drewniak                   ConversionPatternRewriter &rewriter) const override;
2667fb9bbe5SKrzysztof Drewniak };
2677fb9bbe5SKrzysztof Drewniak 
2685ab6ef75SJakub Kuderski /// Converts memref.store to spirv.Store.
26926be7fe2SLei Zhang class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
27026be7fe2SLei Zhang public:
27126be7fe2SLei Zhang   using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
27226be7fe2SLei Zhang 
27326be7fe2SLei Zhang   LogicalResult
274b54c724bSRiver Riddle   matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
27526be7fe2SLei Zhang                   ConversionPatternRewriter &rewriter) const override;
27626be7fe2SLei Zhang };
27726be7fe2SLei Zhang 
2785fca4ce1SIvan Butygin class ReinterpretCastPattern final
2795fca4ce1SIvan Butygin     : public OpConversionPattern<memref::ReinterpretCastOp> {
2805fca4ce1SIvan Butygin public:
2815fca4ce1SIvan Butygin   using OpConversionPattern::OpConversionPattern;
2825fca4ce1SIvan Butygin 
2835fca4ce1SIvan Butygin   LogicalResult
2845fca4ce1SIvan Butygin   matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
2855fca4ce1SIvan Butygin                   ConversionPatternRewriter &rewriter) const override;
2865fca4ce1SIvan Butygin };
2875fca4ce1SIvan Butygin 
288c50f335bSIvan Butygin class CastPattern final : public OpConversionPattern<memref::CastOp> {
289c50f335bSIvan Butygin public:
290c50f335bSIvan Butygin   using OpConversionPattern::OpConversionPattern;
291c50f335bSIvan Butygin 
292c50f335bSIvan Butygin   LogicalResult
293c50f335bSIvan Butygin   matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
294c50f335bSIvan Butygin                   ConversionPatternRewriter &rewriter) const override {
295c50f335bSIvan Butygin     Value src = adaptor.getSource();
296c50f335bSIvan Butygin     Type srcType = src.getType();
297c50f335bSIvan Butygin 
298ce254598SMatthias Springer     const TypeConverter *converter = getTypeConverter();
299c50f335bSIvan Butygin     Type dstType = converter->convertType(op.getType());
300c50f335bSIvan Butygin     if (srcType != dstType)
301c50f335bSIvan Butygin       return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
302c50f335bSIvan Butygin         diag << "types doesn't match: " << srcType << " and " << dstType;
303c50f335bSIvan Butygin       });
304c50f335bSIvan Butygin 
305c50f335bSIvan Butygin     rewriter.replaceOp(op, src);
306c50f335bSIvan Butygin     return success();
307c50f335bSIvan Butygin   }
308c50f335bSIvan Butygin };
309c50f335bSIvan Butygin 
31026be7fe2SLei Zhang } // namespace
31126be7fe2SLei Zhang 
31226be7fe2SLei Zhang //===----------------------------------------------------------------------===//
3138854b736SLei Zhang // AllocaOp
3148854b736SLei Zhang //===----------------------------------------------------------------------===//
3158854b736SLei Zhang 
3168854b736SLei Zhang LogicalResult
3178854b736SLei Zhang AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
3188854b736SLei Zhang                                  ConversionPatternRewriter &rewriter) const {
3198854b736SLei Zhang   MemRefType allocType = allocaOp.getType();
3208854b736SLei Zhang   if (!isAllocationSupported(allocaOp, allocType))
3218854b736SLei Zhang     return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
3228854b736SLei Zhang 
3238854b736SLei Zhang   // Get the SPIR-V type for the allocation.
3248854b736SLei Zhang   Type spirvType = getTypeConverter()->convertType(allocType);
325370a7eaeSJakub Kuderski   if (!spirvType)
326370a7eaeSJakub Kuderski     return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
327370a7eaeSJakub Kuderski 
3288854b736SLei Zhang   rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
3298854b736SLei Zhang                                                  spirv::StorageClass::Function,
3308854b736SLei Zhang                                                  /*initializer=*/nullptr);
3318854b736SLei Zhang   return success();
3328854b736SLei Zhang }
3338854b736SLei Zhang 
3348854b736SLei Zhang //===----------------------------------------------------------------------===//
33526be7fe2SLei Zhang // AllocOp
33626be7fe2SLei Zhang //===----------------------------------------------------------------------===//
33726be7fe2SLei Zhang 
33826be7fe2SLei Zhang LogicalResult
339b54c724bSRiver Riddle AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
34026be7fe2SLei Zhang                                 ConversionPatternRewriter &rewriter) const {
34126be7fe2SLei Zhang   MemRefType allocType = operation.getType();
3428854b736SLei Zhang   if (!isAllocationSupported(operation, allocType))
3438854b736SLei Zhang     return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
34426be7fe2SLei Zhang 
34526be7fe2SLei Zhang   // Get the SPIR-V type for the allocation.
34626be7fe2SLei Zhang   Type spirvType = getTypeConverter()->convertType(allocType);
347370a7eaeSJakub Kuderski   if (!spirvType)
348370a7eaeSJakub Kuderski     return rewriter.notifyMatchFailure(operation, "type conversion failed");
34926be7fe2SLei Zhang 
3505ab6ef75SJakub Kuderski   // Insert spirv.GlobalVariable for this allocation.
35126be7fe2SLei Zhang   Operation *parent =
35226be7fe2SLei Zhang       SymbolTable::getNearestSymbolTable(operation->getParentOp());
35326be7fe2SLei Zhang   if (!parent)
35426be7fe2SLei Zhang     return failure();
35526be7fe2SLei Zhang   Location loc = operation.getLoc();
35626be7fe2SLei Zhang   spirv::GlobalVariableOp varOp;
35726be7fe2SLei Zhang   {
35826be7fe2SLei Zhang     OpBuilder::InsertionGuard guard(rewriter);
35926be7fe2SLei Zhang     Block &entryBlock = *parent->getRegion(0).begin();
36026be7fe2SLei Zhang     rewriter.setInsertionPointToStart(&entryBlock);
36126be7fe2SLei Zhang     auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
36226be7fe2SLei Zhang     std::string varName =
36326be7fe2SLei Zhang         std::string("__workgroup_mem__") +
36426be7fe2SLei Zhang         std::to_string(std::distance(varOps.begin(), varOps.end()));
36526be7fe2SLei Zhang     varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
36626be7fe2SLei Zhang                                                      /*initializer=*/nullptr);
36726be7fe2SLei Zhang   }
36826be7fe2SLei Zhang 
36926be7fe2SLei Zhang   // Get pointer to global variable at the current scope.
37026be7fe2SLei Zhang   rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
37126be7fe2SLei Zhang   return success();
37226be7fe2SLei Zhang }
37326be7fe2SLei Zhang 
37426be7fe2SLei Zhang //===----------------------------------------------------------------------===//
// AtomicRMWOp
37678e172fcSLei Zhang //===----------------------------------------------------------------------===//
37778e172fcSLei Zhang 
37878e172fcSLei Zhang LogicalResult
37978e172fcSLei Zhang AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
38078e172fcSLei Zhang                                     OpAdaptor adaptor,
38178e172fcSLei Zhang                                     ConversionPatternRewriter &rewriter) const {
3825550c821STres Popp   if (isa<FloatType>(atomicOp.getType()))
38378e172fcSLei Zhang     return rewriter.notifyMatchFailure(atomicOp,
38478e172fcSLei Zhang                                        "unimplemented floating-point case");
38578e172fcSLei Zhang 
3865550c821STres Popp   auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
38778e172fcSLei Zhang   std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
38878e172fcSLei Zhang   if (!scope)
38978e172fcSLei Zhang     return rewriter.notifyMatchFailure(atomicOp,
39078e172fcSLei Zhang                                        "unsupported memref memory space");
39178e172fcSLei Zhang 
39278e172fcSLei Zhang   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
39378e172fcSLei Zhang   Type resultType = typeConverter.convertType(atomicOp.getType());
39478e172fcSLei Zhang   if (!resultType)
39578e172fcSLei Zhang     return rewriter.notifyMatchFailure(atomicOp,
39678e172fcSLei Zhang                                        "failed to convert result type");
39778e172fcSLei Zhang 
39878e172fcSLei Zhang   auto loc = atomicOp.getLoc();
39978e172fcSLei Zhang   Value ptr =
40078e172fcSLei Zhang       spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
40178e172fcSLei Zhang                            adaptor.getIndices(), loc, rewriter);
40278e172fcSLei Zhang 
40378e172fcSLei Zhang   if (!ptr)
40478e172fcSLei Zhang     return failure();
40578e172fcSLei Zhang 
40678e172fcSLei Zhang #define ATOMIC_CASE(kind, spirvOp)                                             \
40778e172fcSLei Zhang   case arith::AtomicRMWKind::kind:                                             \
40878e172fcSLei Zhang     rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
40978e172fcSLei Zhang         atomicOp, resultType, ptr, *scope,                                     \
41078e172fcSLei Zhang         spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
41178e172fcSLei Zhang     break
41278e172fcSLei Zhang 
41378e172fcSLei Zhang   switch (atomicOp.getKind()) {
41478e172fcSLei Zhang     ATOMIC_CASE(addi, AtomicIAddOp);
41578e172fcSLei Zhang     ATOMIC_CASE(maxs, AtomicSMaxOp);
41678e172fcSLei Zhang     ATOMIC_CASE(maxu, AtomicUMaxOp);
41778e172fcSLei Zhang     ATOMIC_CASE(mins, AtomicSMinOp);
41878e172fcSLei Zhang     ATOMIC_CASE(minu, AtomicUMinOp);
41978e172fcSLei Zhang     ATOMIC_CASE(ori, AtomicOrOp);
42078e172fcSLei Zhang     ATOMIC_CASE(andi, AtomicAndOp);
42178e172fcSLei Zhang   default:
42278e172fcSLei Zhang     return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
42378e172fcSLei Zhang   }
42478e172fcSLei Zhang 
42578e172fcSLei Zhang #undef ATOMIC_CASE
42678e172fcSLei Zhang 
42778e172fcSLei Zhang   return success();
42878e172fcSLei Zhang }
42978e172fcSLei Zhang 
43078e172fcSLei Zhang //===----------------------------------------------------------------------===//
43126be7fe2SLei Zhang // DeallocOp
43226be7fe2SLei Zhang //===----------------------------------------------------------------------===//
43326be7fe2SLei Zhang 
43426be7fe2SLei Zhang LogicalResult
43526be7fe2SLei Zhang DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
436b54c724bSRiver Riddle                                   OpAdaptor adaptor,
43726be7fe2SLei Zhang                                   ConversionPatternRewriter &rewriter) const {
4385550c821STres Popp   MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
4398854b736SLei Zhang   if (!isAllocationSupported(operation, deallocType))
4408854b736SLei Zhang     return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
44126be7fe2SLei Zhang   rewriter.eraseOp(operation);
44226be7fe2SLei Zhang   return success();
44326be7fe2SLei Zhang }
44426be7fe2SLei Zhang 
44526be7fe2SLei Zhang //===----------------------------------------------------------------------===//
44626be7fe2SLei Zhang // LoadOp
44726be7fe2SLei Zhang //===----------------------------------------------------------------------===//
44826be7fe2SLei Zhang 
449def16bcaSArtem Tyurin struct MemoryRequirements {
450def16bcaSArtem Tyurin   spirv::MemoryAccessAttr memoryAccess;
451def16bcaSArtem Tyurin   IntegerAttr alignment;
452def16bcaSArtem Tyurin };
4538fd0bce4SJakub Kuderski 
4548fd0bce4SJakub Kuderski /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
4558fd0bce4SJakub Kuderski /// any.
456def16bcaSArtem Tyurin static FailureOr<MemoryRequirements>
457def16bcaSArtem Tyurin calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
458def16bcaSArtem Tyurin   MLIRContext *ctx = accessedPtr.getContext();
459def16bcaSArtem Tyurin 
460def16bcaSArtem Tyurin   auto memoryAccess = spirv::MemoryAccess::None;
461def16bcaSArtem Tyurin   if (isNontemporal) {
462def16bcaSArtem Tyurin     memoryAccess = spirv::MemoryAccess::Nontemporal;
463def16bcaSArtem Tyurin   }
464def16bcaSArtem Tyurin 
4658fd0bce4SJakub Kuderski   auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
466def16bcaSArtem Tyurin   if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
467def16bcaSArtem Tyurin     if (memoryAccess == spirv::MemoryAccess::None) {
468def16bcaSArtem Tyurin       return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
469def16bcaSArtem Tyurin     }
470def16bcaSArtem Tyurin     return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
471def16bcaSArtem Tyurin                               IntegerAttr{}};
472def16bcaSArtem Tyurin   }
4738fd0bce4SJakub Kuderski 
4748fd0bce4SJakub Kuderski   // PhysicalStorageBuffers require the `Aligned` attribute.
4758fd0bce4SJakub Kuderski   auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
4768fd0bce4SJakub Kuderski   if (!pointeeType)
4778fd0bce4SJakub Kuderski     return failure();
4788fd0bce4SJakub Kuderski 
4798fd0bce4SJakub Kuderski   // For scalar types, the alignment is determined by their size.
4808fd0bce4SJakub Kuderski   std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
4818fd0bce4SJakub Kuderski   if (!sizeInBytes.has_value())
4828fd0bce4SJakub Kuderski     return failure();
4838fd0bce4SJakub Kuderski 
484def16bcaSArtem Tyurin   memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
485def16bcaSArtem Tyurin   auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
4868fd0bce4SJakub Kuderski   auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
487def16bcaSArtem Tyurin   return MemoryRequirements{memAccessAttr, alignment};
4888fd0bce4SJakub Kuderski }
4898fd0bce4SJakub Kuderski 
4908fd0bce4SJakub Kuderski /// Given an accessed SPIR-V pointer and the original memref load/store
4918fd0bce4SJakub Kuderski /// `memAccess` op, calculates the alignment requirements, if any. Takes into
4928fd0bce4SJakub Kuderski /// account the alignment attributes applied to the load/store op.
493def16bcaSArtem Tyurin template <class LoadOrStoreOp>
494def16bcaSArtem Tyurin static FailureOr<MemoryRequirements>
495def16bcaSArtem Tyurin calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
496def16bcaSArtem Tyurin   static_assert(
497def16bcaSArtem Tyurin       llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
498def16bcaSArtem Tyurin       "Must be called on either memref::LoadOp or memref::StoreOp");
4998fd0bce4SJakub Kuderski 
500def16bcaSArtem Tyurin   Operation *memrefAccessOp = loadOrStoreOp.getOperation();
5018fd0bce4SJakub Kuderski   auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
5028fd0bce4SJakub Kuderski       spirv::attributeName<spirv::MemoryAccess>());
5038fd0bce4SJakub Kuderski   auto memrefAlignment =
5048fd0bce4SJakub Kuderski       memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
5058fd0bce4SJakub Kuderski   if (memrefMemAccess && memrefAlignment)
506def16bcaSArtem Tyurin     return MemoryRequirements{memrefMemAccess, memrefAlignment};
5078fd0bce4SJakub Kuderski 
508def16bcaSArtem Tyurin   return calculateMemoryRequirements(accessedPtr,
509def16bcaSArtem Tyurin                                      loadOrStoreOp.getNontemporal());
5108fd0bce4SJakub Kuderski }
5118fd0bce4SJakub Kuderski 
51226be7fe2SLei Zhang LogicalResult
513b54c724bSRiver Riddle IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
51426be7fe2SLei Zhang                                   ConversionPatternRewriter &rewriter) const {
51526be7fe2SLei Zhang   auto loc = loadOp.getLoc();
5165550c821STres Popp   auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
51726be7fe2SLei Zhang   if (!memrefType.getElementType().isSignlessInteger())
51826be7fe2SLei Zhang     return failure();
51926be7fe2SLei Zhang 
520ce254598SMatthias Springer   const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
5219c3a73a5SStanley Winata   Value accessChain =
522136d746eSJacques Pienaar       spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
523136d746eSJacques Pienaar                            adaptor.getIndices(), loc, rewriter);
52426be7fe2SLei Zhang 
5259c3a73a5SStanley Winata   if (!accessChain)
5261e9799e2SButygin     return failure();
5271e9799e2SButygin 
52826be7fe2SLei Zhang   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
52926be7fe2SLei Zhang   bool isBool = srcBits == 1;
53026be7fe2SLei Zhang   if (isBool)
53126be7fe2SLei Zhang     srcBits = typeConverter.getOptions().boolNumBits;
5320c7f3d6cSJakub Kuderski 
5330c7f3d6cSJakub Kuderski   auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
5340c7f3d6cSJakub Kuderski   if (!pointerType)
5350c7f3d6cSJakub Kuderski     return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
5360c7f3d6cSJakub Kuderski 
5370c7f3d6cSJakub Kuderski   Type pointeeType = pointerType.getPointeeType();
53826be7fe2SLei Zhang   Type dstType;
5399c3a73a5SStanley Winata   if (typeConverter.allows(spirv::Capability::Kernel)) {
5405550c821STres Popp     if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
541d6de6ddeSNirvedh Meshram       dstType = arrayType.getElementType();
542d6de6ddeSNirvedh Meshram     else
5439c3a73a5SStanley Winata       dstType = pointeeType;
5449c3a73a5SStanley Winata   } else {
5459c3a73a5SStanley Winata     // For Vulkan we need to extract element from wrapping struct and array.
5469c3a73a5SStanley Winata     Type structElemType =
5475550c821STres Popp         cast<spirv::StructType>(pointeeType).getElementType(0);
5485550c821STres Popp     if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
54926be7fe2SLei Zhang       dstType = arrayType.getElementType();
55026be7fe2SLei Zhang     else
5515550c821STres Popp       dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
5529c3a73a5SStanley Winata   }
55326be7fe2SLei Zhang   int dstBits = dstType.getIntOrFloatBitWidth();
55426be7fe2SLei Zhang   assert(dstBits % srcBits == 0);
55526be7fe2SLei Zhang 
556370a7eaeSJakub Kuderski   // If the rewritten load op has the same bit width, use the loading value
55726be7fe2SLei Zhang   // directly.
55826be7fe2SLei Zhang   if (srcBits == dstBits) {
559def16bcaSArtem Tyurin     auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
560def16bcaSArtem Tyurin     if (failed(memoryRequirements))
5618fd0bce4SJakub Kuderski       return rewriter.notifyMatchFailure(
562def16bcaSArtem Tyurin           loadOp, "failed to determine memory requirements");
5638fd0bce4SJakub Kuderski 
564def16bcaSArtem Tyurin     auto [memoryAccess, alignment] = *memoryRequirements;
5658fd0bce4SJakub Kuderski     Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
5668fd0bce4SJakub Kuderski                                                    memoryAccess, alignment);
5670065bd2aSLei Zhang     if (isBool)
5680065bd2aSLei Zhang       loadVal = castIntNToBool(loc, loadVal, rewriter);
5690065bd2aSLei Zhang     rewriter.replaceOp(loadOp, loadVal);
57026be7fe2SLei Zhang     return success();
57126be7fe2SLei Zhang   }
57226be7fe2SLei Zhang 
5739c3a73a5SStanley Winata   // Bitcasting is currently unsupported for Kernel capability /
5745ab6ef75SJakub Kuderski   // spirv.PtrAccessChain.
5759c3a73a5SStanley Winata   if (typeConverter.allows(spirv::Capability::Kernel))
5769c3a73a5SStanley Winata     return failure();
5779c3a73a5SStanley Winata 
5789c3a73a5SStanley Winata   auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
5799c3a73a5SStanley Winata   if (!accessChainOp)
5809c3a73a5SStanley Winata     return failure();
5819c3a73a5SStanley Winata 
58226be7fe2SLei Zhang   // Assume that getElementPtr() works linearizely. If it's a scalar, the method
58326be7fe2SLei Zhang   // still returns a linearized accessing. If the accessing is not linearized,
58426be7fe2SLei Zhang   // there will be offset issues.
58590a1632dSJakub Kuderski   assert(accessChainOp.getIndices().size() == 2);
58626be7fe2SLei Zhang   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
58726be7fe2SLei Zhang                                                    srcBits, dstBits, rewriter);
588def16bcaSArtem Tyurin   auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
589def16bcaSArtem Tyurin   if (failed(memoryRequirements))
5908fd0bce4SJakub Kuderski     return rewriter.notifyMatchFailure(
591def16bcaSArtem Tyurin         loadOp, "failed to determine memory requirements");
5928fd0bce4SJakub Kuderski 
593def16bcaSArtem Tyurin   auto [memoryAccess, alignment] = *memoryRequirements;
5948fd0bce4SJakub Kuderski   Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
5958fd0bce4SJakub Kuderski                                                    memoryAccess, alignment);
59626be7fe2SLei Zhang 
59726be7fe2SLei Zhang   // Shift the bits to the rightmost.
59826be7fe2SLei Zhang   // ____XXXX________ -> ____________XXXX
59926be7fe2SLei Zhang   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
60026be7fe2SLei Zhang   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
60138f8a3cfSFinn Plummer   Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
60226be7fe2SLei Zhang       loc, spvLoadOp.getType(), spvLoadOp, offset);
60326be7fe2SLei Zhang 
60426be7fe2SLei Zhang   // Apply the mask to extract corresponding bits.
60538f8a3cfSFinn Plummer   Value mask = rewriter.createOrFold<spirv::ConstantOp>(
60626be7fe2SLei Zhang       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
60738f8a3cfSFinn Plummer   result =
60838f8a3cfSFinn Plummer       rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
60926be7fe2SLei Zhang 
61026be7fe2SLei Zhang   // Apply sign extension on the loading value unconditionally. The signedness
61126be7fe2SLei Zhang   // semantic is carried in the operator itself, we relies other pattern to
61226be7fe2SLei Zhang   // handle the casting.
61326be7fe2SLei Zhang   IntegerAttr shiftValueAttr =
61426be7fe2SLei Zhang       rewriter.getIntegerAttr(dstType, dstBits - srcBits);
61526be7fe2SLei Zhang   Value shiftValue =
61638f8a3cfSFinn Plummer       rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
61738f8a3cfSFinn Plummer   result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
61838f8a3cfSFinn Plummer                                                             result, shiftValue);
61938f8a3cfSFinn Plummer   result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
62038f8a3cfSFinn Plummer       loc, dstType, result, shiftValue);
62126be7fe2SLei Zhang 
62226be7fe2SLei Zhang   rewriter.replaceOp(loadOp, result);
62326be7fe2SLei Zhang 
62426be7fe2SLei Zhang   assert(accessChainOp.use_empty());
62526be7fe2SLei Zhang   rewriter.eraseOp(accessChainOp);
62626be7fe2SLei Zhang 
62726be7fe2SLei Zhang   return success();
62826be7fe2SLei Zhang }
62926be7fe2SLei Zhang 
63026be7fe2SLei Zhang LogicalResult
631b54c724bSRiver Riddle LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
63226be7fe2SLei Zhang                                ConversionPatternRewriter &rewriter) const {
6335550c821STres Popp   auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
63426be7fe2SLei Zhang   if (memrefType.getElementType().isSignlessInteger())
63526be7fe2SLei Zhang     return failure();
6368fd0bce4SJakub Kuderski   Value loadPtr = spirv::getElementPtr(
637136d746eSJacques Pienaar       *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
638136d746eSJacques Pienaar       adaptor.getIndices(), loadOp.getLoc(), rewriter);
6391e9799e2SButygin 
6401e9799e2SButygin   if (!loadPtr)
6411e9799e2SButygin     return failure();
6421e9799e2SButygin 
643def16bcaSArtem Tyurin   auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
644def16bcaSArtem Tyurin   if (failed(memoryRequirements))
6458fd0bce4SJakub Kuderski     return rewriter.notifyMatchFailure(
646def16bcaSArtem Tyurin         loadOp, "failed to determine memory requirements");
6478fd0bce4SJakub Kuderski 
648def16bcaSArtem Tyurin   auto [memoryAccess, alignment] = *memoryRequirements;
649def16bcaSArtem Tyurin   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
6508fd0bce4SJakub Kuderski                                              alignment);
65126be7fe2SLei Zhang   return success();
65226be7fe2SLei Zhang }
65326be7fe2SLei Zhang 
65426be7fe2SLei Zhang LogicalResult
655b54c724bSRiver Riddle IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
65626be7fe2SLei Zhang                                    ConversionPatternRewriter &rewriter) const {
6575550c821STres Popp   auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
65826be7fe2SLei Zhang   if (!memrefType.getElementType().isSignlessInteger())
659887e1aa3SJakub Kuderski     return rewriter.notifyMatchFailure(storeOp,
660887e1aa3SJakub Kuderski                                        "element type is not a signless int");
66126be7fe2SLei Zhang 
66226be7fe2SLei Zhang   auto loc = storeOp.getLoc();
66326be7fe2SLei Zhang   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
6649c3a73a5SStanley Winata   Value accessChain =
665136d746eSJacques Pienaar       spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
666136d746eSJacques Pienaar                            adaptor.getIndices(), loc, rewriter);
6671e9799e2SButygin 
6689c3a73a5SStanley Winata   if (!accessChain)
669887e1aa3SJakub Kuderski     return rewriter.notifyMatchFailure(
670887e1aa3SJakub Kuderski         storeOp, "failed to convert element pointer type");
6711e9799e2SButygin 
67226be7fe2SLei Zhang   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
67326be7fe2SLei Zhang 
67426be7fe2SLei Zhang   bool isBool = srcBits == 1;
67526be7fe2SLei Zhang   if (isBool)
67626be7fe2SLei Zhang     srcBits = typeConverter.getOptions().boolNumBits;
6779f5300c8SLei Zhang 
6780c7f3d6cSJakub Kuderski   auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
6790c7f3d6cSJakub Kuderski   if (!pointerType)
6800c7f3d6cSJakub Kuderski     return rewriter.notifyMatchFailure(storeOp,
6810c7f3d6cSJakub Kuderski                                        "failed to convert memref type");
6820c7f3d6cSJakub Kuderski 
6830c7f3d6cSJakub Kuderski   Type pointeeType = pointerType.getPointeeType();
684887e1aa3SJakub Kuderski   IntegerType dstType;
6859c3a73a5SStanley Winata   if (typeConverter.allows(spirv::Capability::Kernel)) {
6865550c821STres Popp     if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
687887e1aa3SJakub Kuderski       dstType = dyn_cast<IntegerType>(arrayType.getElementType());
688d6de6ddeSNirvedh Meshram     else
689887e1aa3SJakub Kuderski       dstType = dyn_cast<IntegerType>(pointeeType);
6909c3a73a5SStanley Winata   } else {
6919c3a73a5SStanley Winata     // For Vulkan we need to extract element from wrapping struct and array.
6929c3a73a5SStanley Winata     Type structElemType =
6935550c821STres Popp         cast<spirv::StructType>(pointeeType).getElementType(0);
6945550c821STres Popp     if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
695887e1aa3SJakub Kuderski       dstType = dyn_cast<IntegerType>(arrayType.getElementType());
69626be7fe2SLei Zhang     else
697887e1aa3SJakub Kuderski       dstType = dyn_cast<IntegerType>(
698887e1aa3SJakub Kuderski           cast<spirv::RuntimeArrayType>(structElemType).getElementType());
6999c3a73a5SStanley Winata   }
70026be7fe2SLei Zhang 
701887e1aa3SJakub Kuderski   if (!dstType)
702887e1aa3SJakub Kuderski     return rewriter.notifyMatchFailure(
703887e1aa3SJakub Kuderski         storeOp, "failed to determine destination element type");
704887e1aa3SJakub Kuderski 
705887e1aa3SJakub Kuderski   int dstBits = static_cast<int>(dstType.getWidth());
70626be7fe2SLei Zhang   assert(dstBits % srcBits == 0);
70726be7fe2SLei Zhang 
70826be7fe2SLei Zhang   if (srcBits == dstBits) {
709def16bcaSArtem Tyurin     auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
710def16bcaSArtem Tyurin     if (failed(memoryRequirements))
7118fd0bce4SJakub Kuderski       return rewriter.notifyMatchFailure(
712def16bcaSArtem Tyurin           storeOp, "failed to determine memory requirements");
7138fd0bce4SJakub Kuderski 
714def16bcaSArtem Tyurin     auto [memoryAccess, alignment] = *memoryRequirements;
715136d746eSJacques Pienaar     Value storeVal = adaptor.getValue();
7169f5300c8SLei Zhang     if (isBool)
7179f5300c8SLei Zhang       storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
7188fd0bce4SJakub Kuderski     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
719def16bcaSArtem Tyurin                                                 memoryAccess, alignment);
72026be7fe2SLei Zhang     return success();
72126be7fe2SLei Zhang   }
72226be7fe2SLei Zhang 
7239c3a73a5SStanley Winata   // Bitcasting is currently unsupported for Kernel capability /
7245ab6ef75SJakub Kuderski   // spirv.PtrAccessChain.
7259c3a73a5SStanley Winata   if (typeConverter.allows(spirv::Capability::Kernel))
7269c3a73a5SStanley Winata     return failure();
7279c3a73a5SStanley Winata 
7289c3a73a5SStanley Winata   auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
7299c3a73a5SStanley Winata   if (!accessChainOp)
7309c3a73a5SStanley Winata     return failure();
7319c3a73a5SStanley Winata 
732887e1aa3SJakub Kuderski   // Since there are multiple threads in the processing, the emulation will be
733887e1aa3SJakub Kuderski   // done with atomic operations. E.g., if the stored value is i8, rewrite the
734887e1aa3SJakub Kuderski   // StoreOp to:
73526be7fe2SLei Zhang   // 1) load a 32-bit integer
736887e1aa3SJakub Kuderski   // 2) clear 8 bits in the loaded value
737887e1aa3SJakub Kuderski   // 3) set 8 bits in the loaded value
738887e1aa3SJakub Kuderski   // 4) store 32-bit value back
739887e1aa3SJakub Kuderski   //
740887e1aa3SJakub Kuderski   // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
741887e1aa3SJakub Kuderski   // loaded 32-bit value and the shifted 8-bit store value) as another atomic
742887e1aa3SJakub Kuderski   // step.
74390a1632dSJakub Kuderski   assert(accessChainOp.getIndices().size() == 2);
74426be7fe2SLei Zhang   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
74526be7fe2SLei Zhang   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
74626be7fe2SLei Zhang 
74726be7fe2SLei Zhang   // Create a mask to clear the destination. E.g., if it is the second i8 in
74826be7fe2SLei Zhang   // i32, 0xFFFF00FF is created.
74938f8a3cfSFinn Plummer   Value mask = rewriter.createOrFold<spirv::ConstantOp>(
75026be7fe2SLei Zhang       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
75138f8a3cfSFinn Plummer   Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
75238f8a3cfSFinn Plummer       loc, dstType, mask, offset);
75338f8a3cfSFinn Plummer   clearBitsMask =
75438f8a3cfSFinn Plummer       rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
75526be7fe2SLei Zhang 
756887e1aa3SJakub Kuderski   Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
75726be7fe2SLei Zhang   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
75826be7fe2SLei Zhang                                                    srcBits, dstBits, rewriter);
7590a81ace0SKazu Hirata   std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
76026be7fe2SLei Zhang   if (!scope)
761887e1aa3SJakub Kuderski     return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
762887e1aa3SJakub Kuderski 
76326be7fe2SLei Zhang   Value result = rewriter.create<spirv::AtomicAndOp>(
76426be7fe2SLei Zhang       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
76526be7fe2SLei Zhang       clearBitsMask);
76626be7fe2SLei Zhang   result = rewriter.create<spirv::AtomicOrOp>(
76726be7fe2SLei Zhang       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
76826be7fe2SLei Zhang       storeVal);
76926be7fe2SLei Zhang 
77026be7fe2SLei Zhang   // The AtomicOrOp has no side effect. Since it is already inserted, we can
77126be7fe2SLei Zhang   // just remove the original StoreOp. Note that rewriter.replaceOp()
77226be7fe2SLei Zhang   // doesn't work because it only accepts that the numbers of result are the
77326be7fe2SLei Zhang   // same.
77426be7fe2SLei Zhang   rewriter.eraseOp(storeOp);
77526be7fe2SLei Zhang 
77626be7fe2SLei Zhang   assert(accessChainOp.use_empty());
77726be7fe2SLei Zhang   rewriter.eraseOp(accessChainOp);
77826be7fe2SLei Zhang 
77926be7fe2SLei Zhang   return success();
78026be7fe2SLei Zhang }
78126be7fe2SLei Zhang 
7827fb9bbe5SKrzysztof Drewniak //===----------------------------------------------------------------------===//
7837fb9bbe5SKrzysztof Drewniak // MemorySpaceCastOp
7847fb9bbe5SKrzysztof Drewniak //===----------------------------------------------------------------------===//
7857fb9bbe5SKrzysztof Drewniak 
7867fb9bbe5SKrzysztof Drewniak LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
7877fb9bbe5SKrzysztof Drewniak     memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
7887fb9bbe5SKrzysztof Drewniak     ConversionPatternRewriter &rewriter) const {
7897fb9bbe5SKrzysztof Drewniak   Location loc = addrCastOp.getLoc();
7907fb9bbe5SKrzysztof Drewniak   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
7917fb9bbe5SKrzysztof Drewniak   if (!typeConverter.allows(spirv::Capability::Kernel))
7927fb9bbe5SKrzysztof Drewniak     return rewriter.notifyMatchFailure(
7937fb9bbe5SKrzysztof Drewniak         loc, "address space casts require kernel capability");
7947fb9bbe5SKrzysztof Drewniak 
7955550c821STres Popp   auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
7967fb9bbe5SKrzysztof Drewniak   if (!sourceType)
7977fb9bbe5SKrzysztof Drewniak     return rewriter.notifyMatchFailure(
7987fb9bbe5SKrzysztof Drewniak         loc, "SPIR-V lowering requires ranked memref types");
7995550c821STres Popp   auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
8007fb9bbe5SKrzysztof Drewniak 
8017fb9bbe5SKrzysztof Drewniak   auto sourceStorageClassAttr =
8025550c821STres Popp       dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
8037fb9bbe5SKrzysztof Drewniak   if (!sourceStorageClassAttr)
8047fb9bbe5SKrzysztof Drewniak     return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
8057fb9bbe5SKrzysztof Drewniak       diag << "source address space " << sourceType.getMemorySpace()
8067fb9bbe5SKrzysztof Drewniak            << " must be a SPIR-V storage class";
8077fb9bbe5SKrzysztof Drewniak     });
8087fb9bbe5SKrzysztof Drewniak   auto resultStorageClassAttr =
8095550c821STres Popp       dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
8107fb9bbe5SKrzysztof Drewniak   if (!resultStorageClassAttr)
8117fb9bbe5SKrzysztof Drewniak     return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
8127fb9bbe5SKrzysztof Drewniak       diag << "result address space " << resultType.getMemorySpace()
8137fb9bbe5SKrzysztof Drewniak            << " must be a SPIR-V storage class";
8147fb9bbe5SKrzysztof Drewniak     });
8157fb9bbe5SKrzysztof Drewniak 
8167fb9bbe5SKrzysztof Drewniak   spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
8177fb9bbe5SKrzysztof Drewniak   spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
8187fb9bbe5SKrzysztof Drewniak 
8197fb9bbe5SKrzysztof Drewniak   Value result = adaptor.getSource();
8207fb9bbe5SKrzysztof Drewniak   Type resultPtrType = typeConverter.convertType(resultType);
821370a7eaeSJakub Kuderski   if (!resultPtrType)
822370a7eaeSJakub Kuderski     return rewriter.notifyMatchFailure(addrCastOp,
823370a7eaeSJakub Kuderski                                        "failed to convert memref type");
824370a7eaeSJakub Kuderski 
8257fb9bbe5SKrzysztof Drewniak   Type genericPtrType = resultPtrType;
8267fb9bbe5SKrzysztof Drewniak   // SPIR-V doesn't have a general address space cast operation. Instead, it has
8277fb9bbe5SKrzysztof Drewniak   // conversions to and from generic pointers. To implement the general case,
8287fb9bbe5SKrzysztof Drewniak   // we use specific-to-generic conversions when the source class is not
8297fb9bbe5SKrzysztof Drewniak   // generic. Then when the result storage class is not generic, we convert the
830370a7eaeSJakub Kuderski   // generic pointer (either the input on ar intermediate result) to that
8317fb9bbe5SKrzysztof Drewniak   // class. This also means that we'll need the intermediate generic pointer
8327fb9bbe5SKrzysztof Drewniak   // type if neither the source or destination have it.
8337fb9bbe5SKrzysztof Drewniak   if (sourceSc != spirv::StorageClass::Generic &&
8347fb9bbe5SKrzysztof Drewniak       resultSc != spirv::StorageClass::Generic) {
8357fb9bbe5SKrzysztof Drewniak     Type intermediateType =
8367fb9bbe5SKrzysztof Drewniak         MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
8377fb9bbe5SKrzysztof Drewniak                         sourceType.getLayout(),
8387fb9bbe5SKrzysztof Drewniak                         rewriter.getAttr<spirv::StorageClassAttr>(
8397fb9bbe5SKrzysztof Drewniak                             spirv::StorageClass::Generic));
8407fb9bbe5SKrzysztof Drewniak     genericPtrType = typeConverter.convertType(intermediateType);
8417fb9bbe5SKrzysztof Drewniak   }
8427fb9bbe5SKrzysztof Drewniak   if (sourceSc != spirv::StorageClass::Generic) {
8437fb9bbe5SKrzysztof Drewniak     result =
8447fb9bbe5SKrzysztof Drewniak         rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
8457fb9bbe5SKrzysztof Drewniak   }
8467fb9bbe5SKrzysztof Drewniak   if (resultSc != spirv::StorageClass::Generic) {
8477fb9bbe5SKrzysztof Drewniak     result =
8487fb9bbe5SKrzysztof Drewniak         rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
8497fb9bbe5SKrzysztof Drewniak   }
8507fb9bbe5SKrzysztof Drewniak   rewriter.replaceOp(addrCastOp, result);
8517fb9bbe5SKrzysztof Drewniak   return success();
8527fb9bbe5SKrzysztof Drewniak }
8537fb9bbe5SKrzysztof Drewniak 
85426be7fe2SLei Zhang LogicalResult
855b54c724bSRiver Riddle StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
85626be7fe2SLei Zhang                                 ConversionPatternRewriter &rewriter) const {
8575550c821STres Popp   auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
85826be7fe2SLei Zhang   if (memrefType.getElementType().isSignlessInteger())
859887e1aa3SJakub Kuderski     return rewriter.notifyMatchFailure(storeOp, "signless int");
860b54c724bSRiver Riddle   auto storePtr = spirv::getElementPtr(
861136d746eSJacques Pienaar       *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
862136d746eSJacques Pienaar       adaptor.getIndices(), storeOp.getLoc(), rewriter);
8631e9799e2SButygin 
8641e9799e2SButygin   if (!storePtr)
865887e1aa3SJakub Kuderski     return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
8661e9799e2SButygin 
867def16bcaSArtem Tyurin   auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
868def16bcaSArtem Tyurin   if (failed(memoryRequirements))
8698fd0bce4SJakub Kuderski     return rewriter.notifyMatchFailure(
870def16bcaSArtem Tyurin         storeOp, "failed to determine memory requirements");
8718fd0bce4SJakub Kuderski 
872def16bcaSArtem Tyurin   auto [memoryAccess, alignment] = *memoryRequirements;
8738fd0bce4SJakub Kuderski   rewriter.replaceOpWithNewOp<spirv::StoreOp>(
874def16bcaSArtem Tyurin       storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
87526be7fe2SLei Zhang   return success();
87626be7fe2SLei Zhang }
87726be7fe2SLei Zhang 
8785fca4ce1SIvan Butygin LogicalResult ReinterpretCastPattern::matchAndRewrite(
8795fca4ce1SIvan Butygin     memref::ReinterpretCastOp op, OpAdaptor adaptor,
8805fca4ce1SIvan Butygin     ConversionPatternRewriter &rewriter) const {
8815fca4ce1SIvan Butygin   Value src = adaptor.getSource();
8825fca4ce1SIvan Butygin   auto srcType = dyn_cast<spirv::PointerType>(src.getType());
8835fca4ce1SIvan Butygin 
8845fca4ce1SIvan Butygin   if (!srcType)
8855fca4ce1SIvan Butygin     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
8865fca4ce1SIvan Butygin       diag << "invalid src type " << src.getType();
8875fca4ce1SIvan Butygin     });
8885fca4ce1SIvan Butygin 
889ce254598SMatthias Springer   const TypeConverter *converter = getTypeConverter();
8905fca4ce1SIvan Butygin 
8915fca4ce1SIvan Butygin   auto dstType = converter->convertType<spirv::PointerType>(op.getType());
8925fca4ce1SIvan Butygin   if (dstType != srcType)
8935fca4ce1SIvan Butygin     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
8945fca4ce1SIvan Butygin       diag << "invalid dst type " << op.getType();
8955fca4ce1SIvan Butygin     });
8965fca4ce1SIvan Butygin 
8975fca4ce1SIvan Butygin   OpFoldResult offset =
8985fca4ce1SIvan Butygin       getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
8995fca4ce1SIvan Butygin           .front();
9005fca4ce1SIvan Butygin   if (isConstantIntValue(offset, 0)) {
9015fca4ce1SIvan Butygin     rewriter.replaceOp(op, src);
9025fca4ce1SIvan Butygin     return success();
9035fca4ce1SIvan Butygin   }
9045fca4ce1SIvan Butygin 
9055fca4ce1SIvan Butygin   Type intType = converter->convertType(rewriter.getIndexType());
9065fca4ce1SIvan Butygin   if (!intType)
9075fca4ce1SIvan Butygin     return rewriter.notifyMatchFailure(op, "failed to convert index type");
9085fca4ce1SIvan Butygin 
9095fca4ce1SIvan Butygin   Location loc = op.getLoc();
9105fca4ce1SIvan Butygin   auto offsetValue = [&]() -> Value {
9115fca4ce1SIvan Butygin     if (auto val = dyn_cast<Value>(offset))
9125fca4ce1SIvan Butygin       return val;
9135fca4ce1SIvan Butygin 
914*129ec845SKazu Hirata     int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
9155fca4ce1SIvan Butygin     Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
91638f8a3cfSFinn Plummer     return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
9175fca4ce1SIvan Butygin   }();
9185fca4ce1SIvan Butygin 
9195fca4ce1SIvan Butygin   rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
9205fca4ce1SIvan Butygin       op, src, offsetValue, std::nullopt);
9215fca4ce1SIvan Butygin   return success();
9225fca4ce1SIvan Butygin }
9235fca4ce1SIvan Butygin 
92426be7fe2SLei Zhang //===----------------------------------------------------------------------===//
92526be7fe2SLei Zhang // Pattern population
92626be7fe2SLei Zhang //===----------------------------------------------------------------------===//
92726be7fe2SLei Zhang 
92826be7fe2SLei Zhang namespace mlir {
929206fad0eSMatthias Springer void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
93026be7fe2SLei Zhang                                    RewritePatternSet &patterns) {
931c50f335bSIvan Butygin   patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
932c50f335bSIvan Butygin                DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
933c50f335bSIvan Butygin                LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
934c50f335bSIvan Butygin                ReinterpretCastPattern, CastPattern>(typeConverter,
935c50f335bSIvan Butygin                                                     patterns.getContext());
93626be7fe2SLei Zhang }
93726be7fe2SLei Zhang } // namespace mlir
938