126be7fe2SLei Zhang //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===// 226be7fe2SLei Zhang // 326be7fe2SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 426be7fe2SLei Zhang // See https://llvm.org/LICENSE.txt for license information. 526be7fe2SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 626be7fe2SLei Zhang // 726be7fe2SLei Zhang //===----------------------------------------------------------------------===// 826be7fe2SLei Zhang // 926be7fe2SLei Zhang // This file implements patterns to convert MemRef dialect to SPIR-V dialect. 1026be7fe2SLei Zhang // 1126be7fe2SLei Zhang //===----------------------------------------------------------------------===// 1226be7fe2SLei Zhang 1378e172fcSLei Zhang #include "mlir/Dialect/Arith/IR/Arith.h" 1426be7fe2SLei Zhang #include "mlir/Dialect/MemRef/IR/MemRef.h" 158fd0bce4SJakub Kuderski #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 168854b736SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 1726be7fe2SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 180c7f3d6cSJakub Kuderski #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 1926be7fe2SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 208fd0bce4SJakub Kuderski #include "mlir/IR/BuiltinAttributes.h" 2178e172fcSLei Zhang #include "mlir/IR/BuiltinTypes.h" 228fd0bce4SJakub Kuderski #include "mlir/IR/MLIRContext.h" 238fd0bce4SJakub Kuderski #include "mlir/IR/Visitors.h" 2426be7fe2SLei Zhang #include "llvm/Support/Debug.h" 258fd0bce4SJakub Kuderski #include <cassert> 26a1fe1f5fSKazu Hirata #include <optional> 2726be7fe2SLei Zhang 2826be7fe2SLei Zhang #define DEBUG_TYPE "memref-to-spirv-pattern" 2926be7fe2SLei Zhang 3026be7fe2SLei Zhang using namespace mlir; 3126be7fe2SLei Zhang 3226be7fe2SLei Zhang //===----------------------------------------------------------------------===// 3326be7fe2SLei Zhang // Utility functions 3426be7fe2SLei Zhang //===----------------------------------------------------------------------===// 3526be7fe2SLei Zhang 3626be7fe2SLei Zhang /// Returns the offset of the value in `targetBits` representation. 3726be7fe2SLei Zhang /// 3826be7fe2SLei Zhang /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`. 3926be7fe2SLei Zhang /// It's assumed to be non-negative. 4026be7fe2SLei Zhang /// 4126be7fe2SLei Zhang /// When accessing an element in the array treating as having elements of 4226be7fe2SLei Zhang /// `targetBits`, multiple values are loaded in the same time. The method 4326be7fe2SLei Zhang /// returns the offset where the `srcIdx` locates in the value. For example, if 4426be7fe2SLei Zhang /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is 4526be7fe2SLei Zhang /// located at (x % 4) * 8. Because there are four elements in one i32, and one 4626be7fe2SLei Zhang /// element has 8 bits. 4726be7fe2SLei Zhang static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, 4826be7fe2SLei Zhang int targetBits, OpBuilder &builder) { 4926be7fe2SLei Zhang assert(targetBits % sourceBits == 0); 504ffc63abSLei Zhang Type type = srcIdx.getType(); 514ffc63abSLei Zhang IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits); 5238f8a3cfSFinn Plummer auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr); 534ffc63abSLei Zhang IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits); 5438f8a3cfSFinn Plummer auto srcBitsValue = 5538f8a3cfSFinn Plummer builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr); 5638f8a3cfSFinn Plummer auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx); 5738f8a3cfSFinn Plummer return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue); 5826be7fe2SLei Zhang } 5926be7fe2SLei Zhang 6026be7fe2SLei Zhang /// Returns an adjusted spirv::AccessChainOp. Based on the 6126be7fe2SLei Zhang /// extension/capabilities, certain integer bitwidths `sourceBits` might not be 6226be7fe2SLei Zhang /// supported. During conversion if a memref of an unsupported type is used, 6326be7fe2SLei Zhang /// load/stores to this memref need to be modified to use a supported higher 6426be7fe2SLei Zhang /// bitwidth `targetBits` and extracting the required bits. For an accessing a 654ffc63abSLei Zhang /// 1D array (spirv.array or spirv.rtarray), the last index is modified to load 665ab6ef75SJakub Kuderski /// the bits needed. The extraction of the actual bits needed are handled 6726be7fe2SLei Zhang /// separately. Note that this only works for a 1-D tensor. 68ce254598SMatthias Springer static Value 69ce254598SMatthias Springer adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, 70ce254598SMatthias Springer spirv::AccessChainOp op, int sourceBits, 71ce254598SMatthias Springer int targetBits, OpBuilder &builder) { 7226be7fe2SLei Zhang assert(targetBits % sourceBits == 0); 7326be7fe2SLei Zhang const auto loc = op.getLoc(); 744ffc63abSLei Zhang Value lastDim = op->getOperand(op.getNumOperands() - 1); 754ffc63abSLei Zhang Type type = lastDim.getType(); 764ffc63abSLei Zhang IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits); 7738f8a3cfSFinn Plummer auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr); 7890a1632dSJakub Kuderski auto indices = llvm::to_vector<4>(op.getIndices()); 7926be7fe2SLei Zhang // There are two elements if this is a 1-D tensor. 8026be7fe2SLei Zhang assert(indices.size() == 2); 8138f8a3cfSFinn Plummer indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx); 8290a1632dSJakub Kuderski Type t = typeConverter.convertType(op.getComponentPtr().getType()); 8390a1632dSJakub Kuderski return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices); 8426be7fe2SLei Zhang } 8526be7fe2SLei Zhang 86887e1aa3SJakub Kuderski /// Casts the given `srcBool` into an integer of `dstType`. 87887e1aa3SJakub Kuderski static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, 88887e1aa3SJakub Kuderski OpBuilder &builder) { 89887e1aa3SJakub Kuderski assert(srcBool.getType().isInteger(1)); 90887e1aa3SJakub Kuderski if (dstType.isInteger(1)) 91887e1aa3SJakub Kuderski return srcBool; 92887e1aa3SJakub Kuderski Value zero = spirv::ConstantOp::getZero(dstType, loc, builder); 93887e1aa3SJakub Kuderski Value one = spirv::ConstantOp::getOne(dstType, loc, builder); 9438f8a3cfSFinn Plummer return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one, 9538f8a3cfSFinn Plummer zero); 96887e1aa3SJakub Kuderski } 97887e1aa3SJakub Kuderski 98887e1aa3SJakub Kuderski /// Returns the `targetBits`-bit value shifted by the given `offset`, and cast 99887e1aa3SJakub Kuderski /// to the type destination type, and masked. 10026be7fe2SLei Zhang static Value shiftValue(Location loc, Value value, Value offset, Value mask, 101887e1aa3SJakub Kuderski OpBuilder &builder) { 102887e1aa3SJakub Kuderski IntegerType dstType = cast<IntegerType>(mask.getType()); 103887e1aa3SJakub Kuderski int targetBits = static_cast<int>(dstType.getWidth()); 104887e1aa3SJakub Kuderski int valueBits = value.getType().getIntOrFloatBitWidth(); 105887e1aa3SJakub Kuderski assert(valueBits <= targetBits); 106887e1aa3SJakub Kuderski 107887e1aa3SJakub Kuderski if (valueBits == 1) { 108887e1aa3SJakub Kuderski value = castBoolToIntN(loc, value, dstType, builder); 109887e1aa3SJakub Kuderski } else { 110887e1aa3SJakub Kuderski if (valueBits < targetBits) { 111887e1aa3SJakub Kuderski value = builder.create<spirv::UConvertOp>( 112887e1aa3SJakub Kuderski loc, builder.getIntegerType(targetBits), value); 113887e1aa3SJakub Kuderski } 114887e1aa3SJakub Kuderski 11538f8a3cfSFinn Plummer value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask); 116887e1aa3SJakub Kuderski } 11738f8a3cfSFinn Plummer return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(), 11838f8a3cfSFinn Plummer value, offset); 11926be7fe2SLei Zhang } 12026be7fe2SLei Zhang 1218854b736SLei Zhang /// Returns true if the allocations of memref `type` generated from `allocOp` 1228854b736SLei Zhang /// can be lowered to SPIR-V. 1238854b736SLei Zhang static bool isAllocationSupported(Operation *allocOp, MemRefType type) { 1248854b736SLei Zhang if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) { 1255550c821STres Popp auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace()); 12689b595e1SLei Zhang if (!sc || sc.getValue() != spirv::StorageClass::Workgroup) 12726be7fe2SLei Zhang return false; 1288854b736SLei Zhang } else if (isa<memref::AllocaOp>(allocOp)) { 1295550c821STres Popp auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace()); 13089b595e1SLei Zhang if (!sc || sc.getValue() != spirv::StorageClass::Function) 1318854b736SLei Zhang return false; 1328854b736SLei Zhang } else { 1338854b736SLei Zhang return false; 1348854b736SLei Zhang } 1358854b736SLei Zhang 1368854b736SLei Zhang // Currently only support static shape and int or float or vector of int or 1378854b736SLei Zhang // float element type. 1388854b736SLei Zhang if (!type.hasStaticShape()) 1398854b736SLei Zhang return false; 1408854b736SLei Zhang 1418854b736SLei Zhang Type elementType = type.getElementType(); 1425550c821STres Popp if (auto vecType = dyn_cast<VectorType>(elementType)) 14326be7fe2SLei Zhang elementType = vecType.getElementType(); 14426be7fe2SLei Zhang return elementType.isIntOrFloat(); 14526be7fe2SLei Zhang } 14626be7fe2SLei Zhang 14726be7fe2SLei Zhang /// Returns the scope to use for atomic operations use for emulating store 14826be7fe2SLei Zhang /// operations of unsupported integer bitwidths, based on the memref 14970c73d1bSKazu Hirata /// type. Returns std::nullopt on failure. 1500a81ace0SKazu Hirata static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) { 1515550c821STres Popp auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace()); 15289b595e1SLei Zhang switch (sc.getValue()) { 15326be7fe2SLei Zhang case spirv::StorageClass::StorageBuffer: 15426be7fe2SLei Zhang return spirv::Scope::Device; 15526be7fe2SLei Zhang case spirv::StorageClass::Workgroup: 15626be7fe2SLei Zhang return spirv::Scope::Workgroup; 157713f85d5SLei Zhang default: 158713f85d5SLei Zhang break; 15926be7fe2SLei Zhang } 16026be7fe2SLei Zhang return {}; 16126be7fe2SLei Zhang } 16226be7fe2SLei Zhang 1630065bd2aSLei Zhang /// Casts the given `srcInt` into a boolean value. 1640065bd2aSLei Zhang static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) { 1650065bd2aSLei Zhang if (srcInt.getType().isInteger(1)) 1660065bd2aSLei Zhang return srcInt; 1670065bd2aSLei Zhang 16827158edaSDmitriy Smirnov auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder); 16927158edaSDmitriy Smirnov return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one); 1700065bd2aSLei Zhang } 1710065bd2aSLei Zhang 17226be7fe2SLei Zhang //===----------------------------------------------------------------------===// 17326be7fe2SLei Zhang // Operation conversion 17426be7fe2SLei Zhang //===----------------------------------------------------------------------===// 17526be7fe2SLei Zhang 17626be7fe2SLei Zhang // Note that DRR cannot be used for the patterns in this file: we may need to 17726be7fe2SLei Zhang // convert type along the way, which requires ConversionPattern. DRR generates 17826be7fe2SLei Zhang // normal RewritePattern. 17926be7fe2SLei Zhang 18026be7fe2SLei Zhang namespace { 18126be7fe2SLei Zhang 1828854b736SLei Zhang /// Converts memref.alloca to SPIR-V Function variables. 1838854b736SLei Zhang class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> { 1848854b736SLei Zhang public: 1858854b736SLei Zhang using OpConversionPattern<memref::AllocaOp>::OpConversionPattern; 1868854b736SLei Zhang 1878854b736SLei Zhang LogicalResult 1888854b736SLei Zhang matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor, 1898854b736SLei Zhang ConversionPatternRewriter &rewriter) const override; 1908854b736SLei Zhang }; 1918854b736SLei Zhang 19226be7fe2SLei Zhang /// Converts an allocation operation to SPIR-V. Currently only supports lowering 19326be7fe2SLei Zhang /// to Workgroup memory when the size is constant. Note that this pattern needs 1945ab6ef75SJakub Kuderski /// to be applied in a pass that runs at least at spirv.module scope since it 1955ab6ef75SJakub Kuderski /// wil ladd global variables into the spirv.module. 19626be7fe2SLei Zhang class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> { 19726be7fe2SLei Zhang public: 19826be7fe2SLei Zhang using OpConversionPattern<memref::AllocOp>::OpConversionPattern; 19926be7fe2SLei Zhang 20026be7fe2SLei Zhang LogicalResult 201b54c724bSRiver Riddle matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, 20226be7fe2SLei Zhang ConversionPatternRewriter &rewriter) const override; 20326be7fe2SLei Zhang }; 20426be7fe2SLei Zhang 20578e172fcSLei Zhang /// Converts memref.automic_rmw operations to SPIR-V atomic operations. 20678e172fcSLei Zhang class AtomicRMWOpPattern final 20778e172fcSLei Zhang : public OpConversionPattern<memref::AtomicRMWOp> { 20878e172fcSLei Zhang public: 20978e172fcSLei Zhang using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern; 21078e172fcSLei Zhang 21178e172fcSLei Zhang LogicalResult 21278e172fcSLei Zhang matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, 21378e172fcSLei Zhang ConversionPatternRewriter &rewriter) const override; 21478e172fcSLei Zhang }; 21578e172fcSLei Zhang 21626be7fe2SLei Zhang /// Removed a deallocation if it is a supported allocation. Currently only 21726be7fe2SLei Zhang /// removes deallocation if the memory space is workgroup memory. 21826be7fe2SLei Zhang class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> { 21926be7fe2SLei Zhang public: 22026be7fe2SLei Zhang using OpConversionPattern<memref::DeallocOp>::OpConversionPattern; 22126be7fe2SLei Zhang 22226be7fe2SLei Zhang LogicalResult 223b54c724bSRiver Riddle matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor, 22426be7fe2SLei Zhang ConversionPatternRewriter &rewriter) const override; 22526be7fe2SLei Zhang }; 22626be7fe2SLei Zhang 2275ab6ef75SJakub Kuderski /// Converts memref.load to spirv.Load + spirv.AccessChain on integers. 22826be7fe2SLei Zhang class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> { 22926be7fe2SLei Zhang public: 23026be7fe2SLei Zhang using OpConversionPattern<memref::LoadOp>::OpConversionPattern; 23126be7fe2SLei Zhang 23226be7fe2SLei Zhang LogicalResult 233b54c724bSRiver Riddle matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 23426be7fe2SLei Zhang ConversionPatternRewriter &rewriter) const override; 23526be7fe2SLei Zhang }; 23626be7fe2SLei Zhang 2375ab6ef75SJakub Kuderski /// Converts memref.load to spirv.Load + spirv.AccessChain. 23826be7fe2SLei Zhang class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> { 23926be7fe2SLei Zhang public: 24026be7fe2SLei Zhang using OpConversionPattern<memref::LoadOp>::OpConversionPattern; 24126be7fe2SLei Zhang 24226be7fe2SLei Zhang LogicalResult 243b54c724bSRiver Riddle matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 24426be7fe2SLei Zhang ConversionPatternRewriter &rewriter) const override; 24526be7fe2SLei Zhang }; 24626be7fe2SLei Zhang 2475ab6ef75SJakub Kuderski /// Converts memref.store to spirv.Store on integers. 24826be7fe2SLei Zhang class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> { 24926be7fe2SLei Zhang public: 25026be7fe2SLei Zhang using OpConversionPattern<memref::StoreOp>::OpConversionPattern; 25126be7fe2SLei Zhang 25226be7fe2SLei Zhang LogicalResult 253b54c724bSRiver Riddle matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, 25426be7fe2SLei Zhang ConversionPatternRewriter &rewriter) const override; 25526be7fe2SLei Zhang }; 25626be7fe2SLei Zhang 2577fb9bbe5SKrzysztof Drewniak /// Converts memref.memory_space_cast to the appropriate spirv cast operations. 2587fb9bbe5SKrzysztof Drewniak class MemorySpaceCastOpPattern final 2597fb9bbe5SKrzysztof Drewniak : public OpConversionPattern<memref::MemorySpaceCastOp> { 2607fb9bbe5SKrzysztof Drewniak public: 2617fb9bbe5SKrzysztof Drewniak using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern; 2627fb9bbe5SKrzysztof Drewniak 2637fb9bbe5SKrzysztof Drewniak LogicalResult 2647fb9bbe5SKrzysztof Drewniak matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor, 2657fb9bbe5SKrzysztof Drewniak ConversionPatternRewriter &rewriter) const override; 2667fb9bbe5SKrzysztof Drewniak }; 2677fb9bbe5SKrzysztof Drewniak 2685ab6ef75SJakub Kuderski /// Converts memref.store to spirv.Store. 26926be7fe2SLei Zhang class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> { 27026be7fe2SLei Zhang public: 27126be7fe2SLei Zhang using OpConversionPattern<memref::StoreOp>::OpConversionPattern; 27226be7fe2SLei Zhang 27326be7fe2SLei Zhang LogicalResult 274b54c724bSRiver Riddle matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, 27526be7fe2SLei Zhang ConversionPatternRewriter &rewriter) const override; 27626be7fe2SLei Zhang }; 27726be7fe2SLei Zhang 2785fca4ce1SIvan Butygin class ReinterpretCastPattern final 2795fca4ce1SIvan Butygin : public OpConversionPattern<memref::ReinterpretCastOp> { 2805fca4ce1SIvan Butygin public: 2815fca4ce1SIvan Butygin using OpConversionPattern::OpConversionPattern; 2825fca4ce1SIvan Butygin 2835fca4ce1SIvan Butygin LogicalResult 2845fca4ce1SIvan Butygin matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, 2855fca4ce1SIvan Butygin ConversionPatternRewriter &rewriter) const override; 2865fca4ce1SIvan Butygin }; 2875fca4ce1SIvan Butygin 288c50f335bSIvan Butygin class CastPattern final : public OpConversionPattern<memref::CastOp> { 289c50f335bSIvan Butygin public: 290c50f335bSIvan Butygin using OpConversionPattern::OpConversionPattern; 291c50f335bSIvan Butygin 292c50f335bSIvan Butygin LogicalResult 293c50f335bSIvan Butygin matchAndRewrite(memref::CastOp op, OpAdaptor adaptor, 294c50f335bSIvan Butygin ConversionPatternRewriter &rewriter) const override { 295c50f335bSIvan Butygin Value src = adaptor.getSource(); 296c50f335bSIvan Butygin Type srcType = src.getType(); 297c50f335bSIvan Butygin 298ce254598SMatthias Springer const TypeConverter *converter = getTypeConverter(); 299c50f335bSIvan Butygin Type dstType = converter->convertType(op.getType()); 300c50f335bSIvan Butygin if (srcType != dstType) 301c50f335bSIvan Butygin return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 302c50f335bSIvan Butygin diag << "types doesn't match: " << srcType << " and " << dstType; 303c50f335bSIvan Butygin }); 304c50f335bSIvan Butygin 305c50f335bSIvan Butygin rewriter.replaceOp(op, src); 306c50f335bSIvan Butygin return success(); 307c50f335bSIvan Butygin } 308c50f335bSIvan Butygin }; 309c50f335bSIvan Butygin 31026be7fe2SLei Zhang } // namespace 31126be7fe2SLei Zhang 31226be7fe2SLei Zhang //===----------------------------------------------------------------------===// 3138854b736SLei Zhang // AllocaOp 3148854b736SLei Zhang //===----------------------------------------------------------------------===// 3158854b736SLei Zhang 3168854b736SLei Zhang LogicalResult 3178854b736SLei Zhang AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor, 3188854b736SLei Zhang ConversionPatternRewriter &rewriter) const { 3198854b736SLei Zhang MemRefType allocType = allocaOp.getType(); 3208854b736SLei Zhang if (!isAllocationSupported(allocaOp, allocType)) 3218854b736SLei Zhang return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type"); 3228854b736SLei Zhang 3238854b736SLei Zhang // Get the SPIR-V type for the allocation. 3248854b736SLei Zhang Type spirvType = getTypeConverter()->convertType(allocType); 325370a7eaeSJakub Kuderski if (!spirvType) 326370a7eaeSJakub Kuderski return rewriter.notifyMatchFailure(allocaOp, "type conversion failed"); 327370a7eaeSJakub Kuderski 3288854b736SLei Zhang rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType, 3298854b736SLei Zhang spirv::StorageClass::Function, 3308854b736SLei Zhang /*initializer=*/nullptr); 3318854b736SLei Zhang return success(); 3328854b736SLei Zhang } 3338854b736SLei Zhang 3348854b736SLei Zhang //===----------------------------------------------------------------------===// 33526be7fe2SLei Zhang // AllocOp 33626be7fe2SLei Zhang //===----------------------------------------------------------------------===// 33726be7fe2SLei Zhang 33826be7fe2SLei Zhang LogicalResult 339b54c724bSRiver Riddle AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, 34026be7fe2SLei Zhang ConversionPatternRewriter &rewriter) const { 34126be7fe2SLei Zhang MemRefType allocType = operation.getType(); 3428854b736SLei Zhang if (!isAllocationSupported(operation, allocType)) 3438854b736SLei Zhang return rewriter.notifyMatchFailure(operation, "unhandled allocation type"); 34426be7fe2SLei Zhang 34526be7fe2SLei Zhang // Get the SPIR-V type for the allocation. 34626be7fe2SLei Zhang Type spirvType = getTypeConverter()->convertType(allocType); 347370a7eaeSJakub Kuderski if (!spirvType) 348370a7eaeSJakub Kuderski return rewriter.notifyMatchFailure(operation, "type conversion failed"); 34926be7fe2SLei Zhang 3505ab6ef75SJakub Kuderski // Insert spirv.GlobalVariable for this allocation. 35126be7fe2SLei Zhang Operation *parent = 35226be7fe2SLei Zhang SymbolTable::getNearestSymbolTable(operation->getParentOp()); 35326be7fe2SLei Zhang if (!parent) 35426be7fe2SLei Zhang return failure(); 35526be7fe2SLei Zhang Location loc = operation.getLoc(); 35626be7fe2SLei Zhang spirv::GlobalVariableOp varOp; 35726be7fe2SLei Zhang { 35826be7fe2SLei Zhang OpBuilder::InsertionGuard guard(rewriter); 35926be7fe2SLei Zhang Block &entryBlock = *parent->getRegion(0).begin(); 36026be7fe2SLei Zhang rewriter.setInsertionPointToStart(&entryBlock); 36126be7fe2SLei Zhang auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>(); 36226be7fe2SLei Zhang std::string varName = 36326be7fe2SLei Zhang std::string("__workgroup_mem__") + 36426be7fe2SLei Zhang std::to_string(std::distance(varOps.begin(), varOps.end())); 36526be7fe2SLei Zhang varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName, 36626be7fe2SLei Zhang /*initializer=*/nullptr); 36726be7fe2SLei Zhang } 36826be7fe2SLei Zhang 36926be7fe2SLei Zhang // Get pointer to global variable at the current scope. 37026be7fe2SLei Zhang rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp); 37126be7fe2SLei Zhang return success(); 37226be7fe2SLei Zhang } 37326be7fe2SLei Zhang 37426be7fe2SLei Zhang //===----------------------------------------------------------------------===// 37578e172fcSLei Zhang // AllocOp 37678e172fcSLei Zhang //===----------------------------------------------------------------------===// 37778e172fcSLei Zhang 37878e172fcSLei Zhang LogicalResult 37978e172fcSLei Zhang AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp, 38078e172fcSLei Zhang OpAdaptor adaptor, 38178e172fcSLei Zhang ConversionPatternRewriter &rewriter) const { 3825550c821STres Popp if (isa<FloatType>(atomicOp.getType())) 38378e172fcSLei Zhang return rewriter.notifyMatchFailure(atomicOp, 38478e172fcSLei Zhang "unimplemented floating-point case"); 38578e172fcSLei Zhang 3865550c821STres Popp auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType()); 38778e172fcSLei Zhang std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType); 38878e172fcSLei Zhang if (!scope) 38978e172fcSLei Zhang return rewriter.notifyMatchFailure(atomicOp, 39078e172fcSLei Zhang "unsupported memref memory space"); 39178e172fcSLei Zhang 39278e172fcSLei Zhang auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 39378e172fcSLei Zhang Type resultType = typeConverter.convertType(atomicOp.getType()); 39478e172fcSLei Zhang if (!resultType) 39578e172fcSLei Zhang return rewriter.notifyMatchFailure(atomicOp, 39678e172fcSLei Zhang "failed to convert result type"); 39778e172fcSLei Zhang 39878e172fcSLei Zhang auto loc = atomicOp.getLoc(); 39978e172fcSLei Zhang Value ptr = 40078e172fcSLei Zhang spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(), 40178e172fcSLei Zhang adaptor.getIndices(), loc, rewriter); 40278e172fcSLei Zhang 40378e172fcSLei Zhang if (!ptr) 40478e172fcSLei Zhang return failure(); 40578e172fcSLei Zhang 40678e172fcSLei Zhang #define ATOMIC_CASE(kind, spirvOp) \ 40778e172fcSLei Zhang case arith::AtomicRMWKind::kind: \ 40878e172fcSLei Zhang rewriter.replaceOpWithNewOp<spirv::spirvOp>( \ 40978e172fcSLei Zhang atomicOp, resultType, ptr, *scope, \ 41078e172fcSLei Zhang spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \ 41178e172fcSLei Zhang break 41278e172fcSLei Zhang 41378e172fcSLei Zhang switch (atomicOp.getKind()) { 41478e172fcSLei Zhang ATOMIC_CASE(addi, AtomicIAddOp); 41578e172fcSLei Zhang ATOMIC_CASE(maxs, AtomicSMaxOp); 41678e172fcSLei Zhang ATOMIC_CASE(maxu, AtomicUMaxOp); 41778e172fcSLei Zhang ATOMIC_CASE(mins, AtomicSMinOp); 41878e172fcSLei Zhang ATOMIC_CASE(minu, AtomicUMinOp); 41978e172fcSLei Zhang ATOMIC_CASE(ori, AtomicOrOp); 42078e172fcSLei Zhang ATOMIC_CASE(andi, AtomicAndOp); 42178e172fcSLei Zhang default: 42278e172fcSLei Zhang return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind"); 42378e172fcSLei Zhang } 42478e172fcSLei Zhang 42578e172fcSLei Zhang #undef ATOMIC_CASE 42678e172fcSLei Zhang 42778e172fcSLei Zhang return success(); 42878e172fcSLei Zhang } 42978e172fcSLei Zhang 43078e172fcSLei Zhang //===----------------------------------------------------------------------===// 43126be7fe2SLei Zhang // DeallocOp 43226be7fe2SLei Zhang //===----------------------------------------------------------------------===// 43326be7fe2SLei Zhang 43426be7fe2SLei Zhang LogicalResult 43526be7fe2SLei Zhang DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation, 436b54c724bSRiver Riddle OpAdaptor adaptor, 43726be7fe2SLei Zhang ConversionPatternRewriter &rewriter) const { 4385550c821STres Popp MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType()); 4398854b736SLei Zhang if (!isAllocationSupported(operation, deallocType)) 4408854b736SLei Zhang return rewriter.notifyMatchFailure(operation, "unhandled allocation type"); 44126be7fe2SLei Zhang rewriter.eraseOp(operation); 44226be7fe2SLei Zhang return success(); 44326be7fe2SLei Zhang } 44426be7fe2SLei Zhang 44526be7fe2SLei Zhang //===----------------------------------------------------------------------===// 44626be7fe2SLei Zhang // LoadOp 44726be7fe2SLei Zhang //===----------------------------------------------------------------------===// 44826be7fe2SLei Zhang 449def16bcaSArtem Tyurin struct MemoryRequirements { 450def16bcaSArtem Tyurin spirv::MemoryAccessAttr memoryAccess; 451def16bcaSArtem Tyurin IntegerAttr alignment; 452def16bcaSArtem Tyurin }; 4538fd0bce4SJakub Kuderski 4548fd0bce4SJakub Kuderski /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if 4558fd0bce4SJakub Kuderski /// any. 456def16bcaSArtem Tyurin static FailureOr<MemoryRequirements> 457def16bcaSArtem Tyurin calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { 458def16bcaSArtem Tyurin MLIRContext *ctx = accessedPtr.getContext(); 459def16bcaSArtem Tyurin 460def16bcaSArtem Tyurin auto memoryAccess = spirv::MemoryAccess::None; 461def16bcaSArtem Tyurin if (isNontemporal) { 462def16bcaSArtem Tyurin memoryAccess = spirv::MemoryAccess::Nontemporal; 463def16bcaSArtem Tyurin } 464def16bcaSArtem Tyurin 4658fd0bce4SJakub Kuderski auto ptrType = cast<spirv::PointerType>(accessedPtr.getType()); 466def16bcaSArtem Tyurin if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) { 467def16bcaSArtem Tyurin if (memoryAccess == spirv::MemoryAccess::None) { 468def16bcaSArtem Tyurin return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}}; 469def16bcaSArtem Tyurin } 470def16bcaSArtem Tyurin return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess), 471def16bcaSArtem Tyurin IntegerAttr{}}; 472def16bcaSArtem Tyurin } 4738fd0bce4SJakub Kuderski 4748fd0bce4SJakub Kuderski // PhysicalStorageBuffers require the `Aligned` attribute. 4758fd0bce4SJakub Kuderski auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType()); 4768fd0bce4SJakub Kuderski if (!pointeeType) 4778fd0bce4SJakub Kuderski return failure(); 4788fd0bce4SJakub Kuderski 4798fd0bce4SJakub Kuderski // For scalar types, the alignment is determined by their size. 4808fd0bce4SJakub Kuderski std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes(); 4818fd0bce4SJakub Kuderski if (!sizeInBytes.has_value()) 4828fd0bce4SJakub Kuderski return failure(); 4838fd0bce4SJakub Kuderski 484def16bcaSArtem Tyurin memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; 485def16bcaSArtem Tyurin auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess); 4868fd0bce4SJakub Kuderski auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes); 487def16bcaSArtem Tyurin return MemoryRequirements{memAccessAttr, alignment}; 4888fd0bce4SJakub Kuderski } 4898fd0bce4SJakub Kuderski 4908fd0bce4SJakub Kuderski /// Given an accessed SPIR-V pointer and the original memref load/store 4918fd0bce4SJakub Kuderski /// `memAccess` op, calculates the alignment requirements, if any. Takes into 4928fd0bce4SJakub Kuderski /// account the alignment attributes applied to the load/store op. 493def16bcaSArtem Tyurin template <class LoadOrStoreOp> 494def16bcaSArtem Tyurin static FailureOr<MemoryRequirements> 495def16bcaSArtem Tyurin calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) { 496def16bcaSArtem Tyurin static_assert( 497def16bcaSArtem Tyurin llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value, 498def16bcaSArtem Tyurin "Must be called on either memref::LoadOp or memref::StoreOp"); 4998fd0bce4SJakub Kuderski 500def16bcaSArtem Tyurin Operation *memrefAccessOp = loadOrStoreOp.getOperation(); 5018fd0bce4SJakub Kuderski auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>( 5028fd0bce4SJakub Kuderski spirv::attributeName<spirv::MemoryAccess>()); 5038fd0bce4SJakub Kuderski auto memrefAlignment = 5048fd0bce4SJakub Kuderski memrefAccessOp->getAttrOfType<IntegerAttr>("alignment"); 5058fd0bce4SJakub Kuderski if (memrefMemAccess && memrefAlignment) 506def16bcaSArtem Tyurin return MemoryRequirements{memrefMemAccess, memrefAlignment}; 5078fd0bce4SJakub Kuderski 508def16bcaSArtem Tyurin return calculateMemoryRequirements(accessedPtr, 509def16bcaSArtem Tyurin loadOrStoreOp.getNontemporal()); 5108fd0bce4SJakub Kuderski } 5118fd0bce4SJakub Kuderski 51226be7fe2SLei Zhang LogicalResult 513b54c724bSRiver Riddle IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 51426be7fe2SLei Zhang ConversionPatternRewriter &rewriter) const { 51526be7fe2SLei Zhang auto loc = loadOp.getLoc(); 5165550c821STres Popp auto memrefType = cast<MemRefType>(loadOp.getMemref().getType()); 51726be7fe2SLei Zhang if (!memrefType.getElementType().isSignlessInteger()) 51826be7fe2SLei Zhang return failure(); 51926be7fe2SLei Zhang 520ce254598SMatthias Springer const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 5219c3a73a5SStanley Winata Value accessChain = 522136d746eSJacques Pienaar spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(), 523136d746eSJacques Pienaar adaptor.getIndices(), loc, rewriter); 52426be7fe2SLei Zhang 5259c3a73a5SStanley Winata if (!accessChain) 5261e9799e2SButygin return failure(); 5271e9799e2SButygin 52826be7fe2SLei Zhang int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); 52926be7fe2SLei Zhang bool isBool = srcBits == 1; 53026be7fe2SLei Zhang if (isBool) 53126be7fe2SLei Zhang srcBits = typeConverter.getOptions().boolNumBits; 5320c7f3d6cSJakub Kuderski 5330c7f3d6cSJakub Kuderski auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType); 5340c7f3d6cSJakub Kuderski if (!pointerType) 5350c7f3d6cSJakub Kuderski return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type"); 5360c7f3d6cSJakub Kuderski 5370c7f3d6cSJakub Kuderski Type pointeeType = pointerType.getPointeeType(); 53826be7fe2SLei Zhang Type dstType; 5399c3a73a5SStanley Winata if (typeConverter.allows(spirv::Capability::Kernel)) { 5405550c821STres Popp if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType)) 541d6de6ddeSNirvedh Meshram dstType = arrayType.getElementType(); 542d6de6ddeSNirvedh Meshram else 5439c3a73a5SStanley Winata dstType = pointeeType; 5449c3a73a5SStanley Winata } else { 5459c3a73a5SStanley Winata // For Vulkan we need to extract element from wrapping struct and array. 5469c3a73a5SStanley Winata Type structElemType = 5475550c821STres Popp cast<spirv::StructType>(pointeeType).getElementType(0); 5485550c821STres Popp if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType)) 54926be7fe2SLei Zhang dstType = arrayType.getElementType(); 55026be7fe2SLei Zhang else 5515550c821STres Popp dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType(); 5529c3a73a5SStanley Winata } 55326be7fe2SLei Zhang int dstBits = dstType.getIntOrFloatBitWidth(); 55426be7fe2SLei Zhang assert(dstBits % srcBits == 0); 55526be7fe2SLei Zhang 556370a7eaeSJakub Kuderski // If the rewritten load op has the same bit width, use the loading value 55726be7fe2SLei Zhang // directly. 55826be7fe2SLei Zhang if (srcBits == dstBits) { 559def16bcaSArtem Tyurin auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp); 560def16bcaSArtem Tyurin if (failed(memoryRequirements)) 5618fd0bce4SJakub Kuderski return rewriter.notifyMatchFailure( 562def16bcaSArtem Tyurin loadOp, "failed to determine memory requirements"); 5638fd0bce4SJakub Kuderski 564def16bcaSArtem Tyurin auto [memoryAccess, alignment] = *memoryRequirements; 5658fd0bce4SJakub Kuderski Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain, 5668fd0bce4SJakub Kuderski memoryAccess, alignment); 5670065bd2aSLei Zhang if (isBool) 5680065bd2aSLei Zhang loadVal = castIntNToBool(loc, loadVal, rewriter); 5690065bd2aSLei Zhang rewriter.replaceOp(loadOp, loadVal); 57026be7fe2SLei Zhang return success(); 57126be7fe2SLei Zhang } 57226be7fe2SLei Zhang 5739c3a73a5SStanley Winata // Bitcasting is currently unsupported for Kernel capability / 5745ab6ef75SJakub Kuderski // spirv.PtrAccessChain. 5759c3a73a5SStanley Winata if (typeConverter.allows(spirv::Capability::Kernel)) 5769c3a73a5SStanley Winata return failure(); 5779c3a73a5SStanley Winata 5789c3a73a5SStanley Winata auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>(); 5799c3a73a5SStanley Winata if (!accessChainOp) 5809c3a73a5SStanley Winata return failure(); 5819c3a73a5SStanley Winata 58226be7fe2SLei Zhang // Assume that getElementPtr() works linearizely. If it's a scalar, the method 58326be7fe2SLei Zhang // still returns a linearized accessing. If the accessing is not linearized, 58426be7fe2SLei Zhang // there will be offset issues. 58590a1632dSJakub Kuderski assert(accessChainOp.getIndices().size() == 2); 58626be7fe2SLei Zhang Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, 58726be7fe2SLei Zhang srcBits, dstBits, rewriter); 588def16bcaSArtem Tyurin auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp); 589def16bcaSArtem Tyurin if (failed(memoryRequirements)) 5908fd0bce4SJakub Kuderski return rewriter.notifyMatchFailure( 591def16bcaSArtem Tyurin loadOp, "failed to determine memory requirements"); 5928fd0bce4SJakub Kuderski 593def16bcaSArtem Tyurin auto [memoryAccess, alignment] = *memoryRequirements; 5948fd0bce4SJakub Kuderski Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr, 5958fd0bce4SJakub Kuderski memoryAccess, alignment); 59626be7fe2SLei Zhang 59726be7fe2SLei Zhang // Shift the bits to the rightmost. 59826be7fe2SLei Zhang // ____XXXX________ -> ____________XXXX 59926be7fe2SLei Zhang Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); 60026be7fe2SLei Zhang Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); 60138f8a3cfSFinn Plummer Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>( 60226be7fe2SLei Zhang loc, spvLoadOp.getType(), spvLoadOp, offset); 60326be7fe2SLei Zhang 60426be7fe2SLei Zhang // Apply the mask to extract corresponding bits. 60538f8a3cfSFinn Plummer Value mask = rewriter.createOrFold<spirv::ConstantOp>( 60626be7fe2SLei Zhang loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); 60738f8a3cfSFinn Plummer result = 60838f8a3cfSFinn Plummer rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask); 60926be7fe2SLei Zhang 61026be7fe2SLei Zhang // Apply sign extension on the loading value unconditionally. The signedness 61126be7fe2SLei Zhang // semantic is carried in the operator itself, we relies other pattern to 61226be7fe2SLei Zhang // handle the casting. 61326be7fe2SLei Zhang IntegerAttr shiftValueAttr = 61426be7fe2SLei Zhang rewriter.getIntegerAttr(dstType, dstBits - srcBits); 61526be7fe2SLei Zhang Value shiftValue = 61638f8a3cfSFinn Plummer rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr); 61738f8a3cfSFinn Plummer result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType, 61838f8a3cfSFinn Plummer result, shiftValue); 61938f8a3cfSFinn Plummer result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>( 62038f8a3cfSFinn Plummer loc, dstType, result, shiftValue); 62126be7fe2SLei Zhang 62226be7fe2SLei Zhang rewriter.replaceOp(loadOp, result); 62326be7fe2SLei Zhang 62426be7fe2SLei Zhang assert(accessChainOp.use_empty()); 62526be7fe2SLei Zhang rewriter.eraseOp(accessChainOp); 62626be7fe2SLei Zhang 62726be7fe2SLei Zhang return success(); 62826be7fe2SLei Zhang } 62926be7fe2SLei Zhang 63026be7fe2SLei Zhang LogicalResult 631b54c724bSRiver Riddle LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, 63226be7fe2SLei Zhang ConversionPatternRewriter &rewriter) const { 6335550c821STres Popp auto memrefType = cast<MemRefType>(loadOp.getMemref().getType()); 63426be7fe2SLei Zhang if (memrefType.getElementType().isSignlessInteger()) 63526be7fe2SLei Zhang return failure(); 6368fd0bce4SJakub Kuderski Value loadPtr = spirv::getElementPtr( 637136d746eSJacques Pienaar *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(), 638136d746eSJacques Pienaar adaptor.getIndices(), loadOp.getLoc(), rewriter); 6391e9799e2SButygin 6401e9799e2SButygin if (!loadPtr) 6411e9799e2SButygin return failure(); 6421e9799e2SButygin 643def16bcaSArtem Tyurin auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp); 644def16bcaSArtem Tyurin if (failed(memoryRequirements)) 6458fd0bce4SJakub Kuderski return rewriter.notifyMatchFailure( 646def16bcaSArtem Tyurin loadOp, "failed to determine memory requirements"); 6478fd0bce4SJakub Kuderski 648def16bcaSArtem Tyurin auto [memoryAccess, alignment] = *memoryRequirements; 649def16bcaSArtem Tyurin rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess, 6508fd0bce4SJakub Kuderski alignment); 65126be7fe2SLei Zhang return success(); 65226be7fe2SLei Zhang } 65326be7fe2SLei Zhang 65426be7fe2SLei Zhang LogicalResult 655b54c724bSRiver Riddle IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, 65626be7fe2SLei Zhang ConversionPatternRewriter &rewriter) const { 6575550c821STres Popp auto memrefType = cast<MemRefType>(storeOp.getMemref().getType()); 65826be7fe2SLei Zhang if (!memrefType.getElementType().isSignlessInteger()) 659887e1aa3SJakub Kuderski return rewriter.notifyMatchFailure(storeOp, 660887e1aa3SJakub Kuderski "element type is not a signless int"); 66126be7fe2SLei Zhang 66226be7fe2SLei Zhang auto loc = storeOp.getLoc(); 66326be7fe2SLei Zhang auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 6649c3a73a5SStanley Winata Value accessChain = 665136d746eSJacques Pienaar spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(), 666136d746eSJacques Pienaar adaptor.getIndices(), loc, rewriter); 6671e9799e2SButygin 6689c3a73a5SStanley Winata if (!accessChain) 669887e1aa3SJakub Kuderski return rewriter.notifyMatchFailure( 670887e1aa3SJakub Kuderski storeOp, "failed to convert element pointer type"); 6711e9799e2SButygin 67226be7fe2SLei Zhang int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); 67326be7fe2SLei Zhang 67426be7fe2SLei Zhang bool isBool = srcBits == 1; 67526be7fe2SLei Zhang if (isBool) 67626be7fe2SLei Zhang srcBits = typeConverter.getOptions().boolNumBits; 6779f5300c8SLei Zhang 6780c7f3d6cSJakub Kuderski auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType); 6790c7f3d6cSJakub Kuderski if (!pointerType) 6800c7f3d6cSJakub Kuderski return rewriter.notifyMatchFailure(storeOp, 6810c7f3d6cSJakub Kuderski "failed to convert memref type"); 6820c7f3d6cSJakub Kuderski 6830c7f3d6cSJakub Kuderski Type pointeeType = pointerType.getPointeeType(); 684887e1aa3SJakub Kuderski IntegerType dstType; 6859c3a73a5SStanley Winata if (typeConverter.allows(spirv::Capability::Kernel)) { 6865550c821STres Popp if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType)) 687887e1aa3SJakub Kuderski dstType = dyn_cast<IntegerType>(arrayType.getElementType()); 688d6de6ddeSNirvedh Meshram else 689887e1aa3SJakub Kuderski dstType = dyn_cast<IntegerType>(pointeeType); 6909c3a73a5SStanley Winata } else { 6919c3a73a5SStanley Winata // For Vulkan we need to extract element from wrapping struct and array. 6929c3a73a5SStanley Winata Type structElemType = 6935550c821STres Popp cast<spirv::StructType>(pointeeType).getElementType(0); 6945550c821STres Popp if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType)) 695887e1aa3SJakub Kuderski dstType = dyn_cast<IntegerType>(arrayType.getElementType()); 69626be7fe2SLei Zhang else 697887e1aa3SJakub Kuderski dstType = dyn_cast<IntegerType>( 698887e1aa3SJakub Kuderski cast<spirv::RuntimeArrayType>(structElemType).getElementType()); 6999c3a73a5SStanley Winata } 70026be7fe2SLei Zhang 701887e1aa3SJakub Kuderski if (!dstType) 702887e1aa3SJakub Kuderski return rewriter.notifyMatchFailure( 703887e1aa3SJakub Kuderski storeOp, "failed to determine destination element type"); 704887e1aa3SJakub Kuderski 705887e1aa3SJakub Kuderski int dstBits = static_cast<int>(dstType.getWidth()); 70626be7fe2SLei Zhang assert(dstBits % srcBits == 0); 70726be7fe2SLei Zhang 70826be7fe2SLei Zhang if (srcBits == dstBits) { 709def16bcaSArtem Tyurin auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp); 710def16bcaSArtem Tyurin if (failed(memoryRequirements)) 7118fd0bce4SJakub Kuderski return rewriter.notifyMatchFailure( 712def16bcaSArtem Tyurin storeOp, "failed to determine memory requirements"); 7138fd0bce4SJakub Kuderski 714def16bcaSArtem Tyurin auto [memoryAccess, alignment] = *memoryRequirements; 715136d746eSJacques Pienaar Value storeVal = adaptor.getValue(); 7169f5300c8SLei Zhang if (isBool) 7179f5300c8SLei Zhang storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); 7188fd0bce4SJakub Kuderski rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal, 719def16bcaSArtem Tyurin memoryAccess, alignment); 72026be7fe2SLei Zhang return success(); 72126be7fe2SLei Zhang } 72226be7fe2SLei Zhang 7239c3a73a5SStanley Winata // Bitcasting is currently unsupported for Kernel capability / 7245ab6ef75SJakub Kuderski // spirv.PtrAccessChain. 7259c3a73a5SStanley Winata if (typeConverter.allows(spirv::Capability::Kernel)) 7269c3a73a5SStanley Winata return failure(); 7279c3a73a5SStanley Winata 7289c3a73a5SStanley Winata auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>(); 7299c3a73a5SStanley Winata if (!accessChainOp) 7309c3a73a5SStanley Winata return failure(); 7319c3a73a5SStanley Winata 732887e1aa3SJakub Kuderski // Since there are multiple threads in the processing, the emulation will be 733887e1aa3SJakub Kuderski // done with atomic operations. E.g., if the stored value is i8, rewrite the 734887e1aa3SJakub Kuderski // StoreOp to: 73526be7fe2SLei Zhang // 1) load a 32-bit integer 736887e1aa3SJakub Kuderski // 2) clear 8 bits in the loaded value 737887e1aa3SJakub Kuderski // 3) set 8 bits in the loaded value 738887e1aa3SJakub Kuderski // 4) store 32-bit value back 739887e1aa3SJakub Kuderski // 740887e1aa3SJakub Kuderski // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the 741887e1aa3SJakub Kuderski // loaded 32-bit value and the shifted 8-bit store value) as another atomic 742887e1aa3SJakub Kuderski // step. 74390a1632dSJakub Kuderski assert(accessChainOp.getIndices().size() == 2); 74426be7fe2SLei Zhang Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); 74526be7fe2SLei Zhang Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); 74626be7fe2SLei Zhang 74726be7fe2SLei Zhang // Create a mask to clear the destination. E.g., if it is the second i8 in 74826be7fe2SLei Zhang // i32, 0xFFFF00FF is created. 74938f8a3cfSFinn Plummer Value mask = rewriter.createOrFold<spirv::ConstantOp>( 75026be7fe2SLei Zhang loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); 75138f8a3cfSFinn Plummer Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>( 75238f8a3cfSFinn Plummer loc, dstType, mask, offset); 75338f8a3cfSFinn Plummer clearBitsMask = 75438f8a3cfSFinn Plummer rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask); 75526be7fe2SLei Zhang 756887e1aa3SJakub Kuderski Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter); 75726be7fe2SLei Zhang Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, 75826be7fe2SLei Zhang srcBits, dstBits, rewriter); 7590a81ace0SKazu Hirata std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType); 76026be7fe2SLei Zhang if (!scope) 761887e1aa3SJakub Kuderski return rewriter.notifyMatchFailure(storeOp, "atomic scope not available"); 762887e1aa3SJakub Kuderski 76326be7fe2SLei Zhang Value result = rewriter.create<spirv::AtomicAndOp>( 76426be7fe2SLei Zhang loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, 76526be7fe2SLei Zhang clearBitsMask); 76626be7fe2SLei Zhang result = rewriter.create<spirv::AtomicOrOp>( 76726be7fe2SLei Zhang loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, 76826be7fe2SLei Zhang storeVal); 76926be7fe2SLei Zhang 77026be7fe2SLei Zhang // The AtomicOrOp has no side effect. Since it is already inserted, we can 77126be7fe2SLei Zhang // just remove the original StoreOp. Note that rewriter.replaceOp() 77226be7fe2SLei Zhang // doesn't work because it only accepts that the numbers of result are the 77326be7fe2SLei Zhang // same. 77426be7fe2SLei Zhang rewriter.eraseOp(storeOp); 77526be7fe2SLei Zhang 77626be7fe2SLei Zhang assert(accessChainOp.use_empty()); 77726be7fe2SLei Zhang rewriter.eraseOp(accessChainOp); 77826be7fe2SLei Zhang 77926be7fe2SLei Zhang return success(); 78026be7fe2SLei Zhang } 78126be7fe2SLei Zhang 7827fb9bbe5SKrzysztof Drewniak //===----------------------------------------------------------------------===// 7837fb9bbe5SKrzysztof Drewniak // MemorySpaceCastOp 7847fb9bbe5SKrzysztof Drewniak //===----------------------------------------------------------------------===// 7857fb9bbe5SKrzysztof Drewniak 7867fb9bbe5SKrzysztof Drewniak LogicalResult MemorySpaceCastOpPattern::matchAndRewrite( 7877fb9bbe5SKrzysztof Drewniak memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor, 7887fb9bbe5SKrzysztof Drewniak ConversionPatternRewriter &rewriter) const { 7897fb9bbe5SKrzysztof Drewniak Location loc = addrCastOp.getLoc(); 7907fb9bbe5SKrzysztof Drewniak auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); 7917fb9bbe5SKrzysztof Drewniak if (!typeConverter.allows(spirv::Capability::Kernel)) 7927fb9bbe5SKrzysztof Drewniak return rewriter.notifyMatchFailure( 7937fb9bbe5SKrzysztof Drewniak loc, "address space casts require kernel capability"); 7947fb9bbe5SKrzysztof Drewniak 7955550c821STres Popp auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType()); 7967fb9bbe5SKrzysztof Drewniak if (!sourceType) 7977fb9bbe5SKrzysztof Drewniak return rewriter.notifyMatchFailure( 7987fb9bbe5SKrzysztof Drewniak loc, "SPIR-V lowering requires ranked memref types"); 7995550c821STres Popp auto resultType = cast<MemRefType>(addrCastOp.getResult().getType()); 8007fb9bbe5SKrzysztof Drewniak 8017fb9bbe5SKrzysztof Drewniak auto sourceStorageClassAttr = 8025550c821STres Popp dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace()); 8037fb9bbe5SKrzysztof Drewniak if (!sourceStorageClassAttr) 8047fb9bbe5SKrzysztof Drewniak return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) { 8057fb9bbe5SKrzysztof Drewniak diag << "source address space " << sourceType.getMemorySpace() 8067fb9bbe5SKrzysztof Drewniak << " must be a SPIR-V storage class"; 8077fb9bbe5SKrzysztof Drewniak }); 8087fb9bbe5SKrzysztof Drewniak auto resultStorageClassAttr = 8095550c821STres Popp dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace()); 8107fb9bbe5SKrzysztof Drewniak if (!resultStorageClassAttr) 8117fb9bbe5SKrzysztof Drewniak return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) { 8127fb9bbe5SKrzysztof Drewniak diag << "result address space " << resultType.getMemorySpace() 8137fb9bbe5SKrzysztof Drewniak << " must be a SPIR-V storage class"; 8147fb9bbe5SKrzysztof Drewniak }); 8157fb9bbe5SKrzysztof Drewniak 8167fb9bbe5SKrzysztof Drewniak spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue(); 8177fb9bbe5SKrzysztof Drewniak spirv::StorageClass resultSc = resultStorageClassAttr.getValue(); 8187fb9bbe5SKrzysztof Drewniak 8197fb9bbe5SKrzysztof Drewniak Value result = adaptor.getSource(); 8207fb9bbe5SKrzysztof Drewniak Type resultPtrType = typeConverter.convertType(resultType); 821370a7eaeSJakub Kuderski if (!resultPtrType) 822370a7eaeSJakub Kuderski return rewriter.notifyMatchFailure(addrCastOp, 823370a7eaeSJakub Kuderski "failed to convert memref type"); 824370a7eaeSJakub Kuderski 8257fb9bbe5SKrzysztof Drewniak Type genericPtrType = resultPtrType; 8267fb9bbe5SKrzysztof Drewniak // SPIR-V doesn't have a general address space cast operation. Instead, it has 8277fb9bbe5SKrzysztof Drewniak // conversions to and from generic pointers. To implement the general case, 8287fb9bbe5SKrzysztof Drewniak // we use specific-to-generic conversions when the source class is not 8297fb9bbe5SKrzysztof Drewniak // generic. Then when the result storage class is not generic, we convert the 830370a7eaeSJakub Kuderski // generic pointer (either the input on ar intermediate result) to that 8317fb9bbe5SKrzysztof Drewniak // class. This also means that we'll need the intermediate generic pointer 8327fb9bbe5SKrzysztof Drewniak // type if neither the source or destination have it. 8337fb9bbe5SKrzysztof Drewniak if (sourceSc != spirv::StorageClass::Generic && 8347fb9bbe5SKrzysztof Drewniak resultSc != spirv::StorageClass::Generic) { 8357fb9bbe5SKrzysztof Drewniak Type intermediateType = 8367fb9bbe5SKrzysztof Drewniak MemRefType::get(sourceType.getShape(), sourceType.getElementType(), 8377fb9bbe5SKrzysztof Drewniak sourceType.getLayout(), 8387fb9bbe5SKrzysztof Drewniak rewriter.getAttr<spirv::StorageClassAttr>( 8397fb9bbe5SKrzysztof Drewniak spirv::StorageClass::Generic)); 8407fb9bbe5SKrzysztof Drewniak genericPtrType = typeConverter.convertType(intermediateType); 8417fb9bbe5SKrzysztof Drewniak } 8427fb9bbe5SKrzysztof Drewniak if (sourceSc != spirv::StorageClass::Generic) { 8437fb9bbe5SKrzysztof Drewniak result = 8447fb9bbe5SKrzysztof Drewniak rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result); 8457fb9bbe5SKrzysztof Drewniak } 8467fb9bbe5SKrzysztof Drewniak if (resultSc != spirv::StorageClass::Generic) { 8477fb9bbe5SKrzysztof Drewniak result = 8487fb9bbe5SKrzysztof Drewniak rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result); 8497fb9bbe5SKrzysztof Drewniak } 8507fb9bbe5SKrzysztof Drewniak rewriter.replaceOp(addrCastOp, result); 8517fb9bbe5SKrzysztof Drewniak return success(); 8527fb9bbe5SKrzysztof Drewniak } 8537fb9bbe5SKrzysztof Drewniak 85426be7fe2SLei Zhang LogicalResult 855b54c724bSRiver Riddle StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, 85626be7fe2SLei Zhang ConversionPatternRewriter &rewriter) const { 8575550c821STres Popp auto memrefType = cast<MemRefType>(storeOp.getMemref().getType()); 85826be7fe2SLei Zhang if (memrefType.getElementType().isSignlessInteger()) 859887e1aa3SJakub Kuderski return rewriter.notifyMatchFailure(storeOp, "signless int"); 860b54c724bSRiver Riddle auto storePtr = spirv::getElementPtr( 861136d746eSJacques Pienaar *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(), 862136d746eSJacques Pienaar adaptor.getIndices(), storeOp.getLoc(), rewriter); 8631e9799e2SButygin 8641e9799e2SButygin if (!storePtr) 865887e1aa3SJakub Kuderski return rewriter.notifyMatchFailure(storeOp, "type conversion failed"); 8661e9799e2SButygin 867def16bcaSArtem Tyurin auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp); 868def16bcaSArtem Tyurin if (failed(memoryRequirements)) 8698fd0bce4SJakub Kuderski return rewriter.notifyMatchFailure( 870def16bcaSArtem Tyurin storeOp, "failed to determine memory requirements"); 8718fd0bce4SJakub Kuderski 872def16bcaSArtem Tyurin auto [memoryAccess, alignment] = *memoryRequirements; 8738fd0bce4SJakub Kuderski rewriter.replaceOpWithNewOp<spirv::StoreOp>( 874def16bcaSArtem Tyurin storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment); 87526be7fe2SLei Zhang return success(); 87626be7fe2SLei Zhang } 87726be7fe2SLei Zhang 8785fca4ce1SIvan Butygin LogicalResult ReinterpretCastPattern::matchAndRewrite( 8795fca4ce1SIvan Butygin memref::ReinterpretCastOp op, OpAdaptor adaptor, 8805fca4ce1SIvan Butygin ConversionPatternRewriter &rewriter) const { 8815fca4ce1SIvan Butygin Value src = adaptor.getSource(); 8825fca4ce1SIvan Butygin auto srcType = dyn_cast<spirv::PointerType>(src.getType()); 8835fca4ce1SIvan Butygin 8845fca4ce1SIvan Butygin if (!srcType) 8855fca4ce1SIvan Butygin return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 8865fca4ce1SIvan Butygin diag << "invalid src type " << src.getType(); 8875fca4ce1SIvan Butygin }); 8885fca4ce1SIvan Butygin 889ce254598SMatthias Springer const TypeConverter *converter = getTypeConverter(); 8905fca4ce1SIvan Butygin 8915fca4ce1SIvan Butygin auto dstType = converter->convertType<spirv::PointerType>(op.getType()); 8925fca4ce1SIvan Butygin if (dstType != srcType) 8935fca4ce1SIvan Butygin return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 8945fca4ce1SIvan Butygin diag << "invalid dst type " << op.getType(); 8955fca4ce1SIvan Butygin }); 8965fca4ce1SIvan Butygin 8975fca4ce1SIvan Butygin OpFoldResult offset = 8985fca4ce1SIvan Butygin getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter) 8995fca4ce1SIvan Butygin .front(); 9005fca4ce1SIvan Butygin if (isConstantIntValue(offset, 0)) { 9015fca4ce1SIvan Butygin rewriter.replaceOp(op, src); 9025fca4ce1SIvan Butygin return success(); 9035fca4ce1SIvan Butygin } 9045fca4ce1SIvan Butygin 9055fca4ce1SIvan Butygin Type intType = converter->convertType(rewriter.getIndexType()); 9065fca4ce1SIvan Butygin if (!intType) 9075fca4ce1SIvan Butygin return rewriter.notifyMatchFailure(op, "failed to convert index type"); 9085fca4ce1SIvan Butygin 9095fca4ce1SIvan Butygin Location loc = op.getLoc(); 9105fca4ce1SIvan Butygin auto offsetValue = [&]() -> Value { 9115fca4ce1SIvan Butygin if (auto val = dyn_cast<Value>(offset)) 9125fca4ce1SIvan Butygin return val; 9135fca4ce1SIvan Butygin 914*129ec845SKazu Hirata int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt(); 9155fca4ce1SIvan Butygin Attribute attr = rewriter.getIntegerAttr(intType, attrVal); 91638f8a3cfSFinn Plummer return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr); 9175fca4ce1SIvan Butygin }(); 9185fca4ce1SIvan Butygin 9195fca4ce1SIvan Butygin rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>( 9205fca4ce1SIvan Butygin op, src, offsetValue, std::nullopt); 9215fca4ce1SIvan Butygin return success(); 9225fca4ce1SIvan Butygin } 9235fca4ce1SIvan Butygin 92426be7fe2SLei Zhang //===----------------------------------------------------------------------===// 92526be7fe2SLei Zhang // Pattern population 92626be7fe2SLei Zhang //===----------------------------------------------------------------------===// 92726be7fe2SLei Zhang 92826be7fe2SLei Zhang namespace mlir { 929206fad0eSMatthias Springer void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, 93026be7fe2SLei Zhang RewritePatternSet &patterns) { 931c50f335bSIvan Butygin patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern, 932c50f335bSIvan Butygin DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, 933c50f335bSIvan Butygin LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern, 934c50f335bSIvan Butygin ReinterpretCastPattern, CastPattern>(typeConverter, 935c50f335bSIvan Butygin patterns.getContext()); 93626be7fe2SLei Zhang } 93726be7fe2SLei Zhang } // namespace mlir 938