1930c74f1SLei Zhang //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===// 2930c74f1SLei Zhang // 3930c74f1SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4930c74f1SLei Zhang // See https://llvm.org/LICENSE.txt for license information. 5930c74f1SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6930c74f1SLei Zhang // 7930c74f1SLei Zhang //===----------------------------------------------------------------------===// 8930c74f1SLei Zhang // 9930c74f1SLei Zhang // This file implements patterns to convert SPIR-V dialect to LLVM dialect. 10930c74f1SLei Zhang // 11930c74f1SLei Zhang //===----------------------------------------------------------------------===// 12930c74f1SLei Zhang 13930c74f1SLei Zhang #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" 1475e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h" 1575e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 16d45de800SVictor Perez #include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h" 17930c74f1SLei Zhang #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 18930c74f1SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 19a29fffc4SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 20930c74f1SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 21930c74f1SLei Zhang #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" 22930c74f1SLei Zhang #include "mlir/IR/BuiltinOps.h" 23930c74f1SLei Zhang #include "mlir/IR/PatternMatch.h" 24930c74f1SLei Zhang #include "mlir/Transforms/DialectConversion.h" 256ade03d7SLukas Sommer #include "llvm/ADT/TypeSwitch.h" 26930c74f1SLei Zhang #include "llvm/Support/Debug.h" 27930c74f1SLei Zhang #include "llvm/Support/FormatVariadic.h" 28930c74f1SLei Zhang 29930c74f1SLei Zhang #define DEBUG_TYPE "spirv-to-llvm-pattern" 30930c74f1SLei Zhang 31930c74f1SLei Zhang using namespace mlir; 32930c74f1SLei Zhang 33930c74f1SLei Zhang //===----------------------------------------------------------------------===// 34930c74f1SLei Zhang // Utility functions 35930c74f1SLei Zhang //===----------------------------------------------------------------------===// 36930c74f1SLei Zhang 37930c74f1SLei Zhang /// Returns true if the given type is a signed integer or vector type. 38930c74f1SLei Zhang static bool isSignedIntegerOrVector(Type type) { 39930c74f1SLei Zhang if (type.isSignedInteger()) 40930c74f1SLei Zhang return true; 415550c821STres Popp if (auto vecType = dyn_cast<VectorType>(type)) 42930c74f1SLei Zhang return vecType.getElementType().isSignedInteger(); 43930c74f1SLei Zhang return false; 44930c74f1SLei Zhang } 45930c74f1SLei Zhang 46930c74f1SLei Zhang /// Returns true if the given type is an unsigned integer or vector type 47930c74f1SLei Zhang static bool isUnsignedIntegerOrVector(Type type) { 48930c74f1SLei Zhang if (type.isUnsignedInteger()) 49930c74f1SLei Zhang return true; 505550c821STres Popp if (auto vecType = dyn_cast<VectorType>(type)) 51930c74f1SLei Zhang return vecType.getElementType().isUnsignedInteger(); 52930c74f1SLei Zhang return false; 53930c74f1SLei Zhang } 54930c74f1SLei Zhang 5500b12a94SThéo Degioanni /// Returns the width of an integer or of the element type of an integer vector, 5600b12a94SThéo Degioanni /// if applicable. 5700b12a94SThéo Degioanni static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) { 5800b12a94SThéo Degioanni if (auto intType = dyn_cast<IntegerType>(type)) 5900b12a94SThéo Degioanni return intType.getWidth(); 6000b12a94SThéo Degioanni if (auto vecType = dyn_cast<VectorType>(type)) 6100b12a94SThéo Degioanni if (auto intType = dyn_cast<IntegerType>(vecType.getElementType())) 6200b12a94SThéo Degioanni return intType.getWidth(); 6300b12a94SThéo Degioanni return std::nullopt; 6400b12a94SThéo Degioanni } 6500b12a94SThéo Degioanni 66930c74f1SLei Zhang /// Returns the bit width of integer, float or vector of float or integer values 67930c74f1SLei Zhang static unsigned getBitWidth(Type type) { 685550c821STres Popp assert((type.isIntOrFloat() || isa<VectorType>(type)) && 69930c74f1SLei Zhang "bitwidth is not supported for this type"); 70930c74f1SLei Zhang if (type.isIntOrFloat()) 71930c74f1SLei Zhang return type.getIntOrFloatBitWidth(); 725550c821STres Popp auto vecType = dyn_cast<VectorType>(type); 73930c74f1SLei Zhang auto elementType = vecType.getElementType(); 74930c74f1SLei Zhang assert(elementType.isIntOrFloat() && 75930c74f1SLei Zhang "only integers and floats have a bitwidth"); 76930c74f1SLei Zhang return elementType.getIntOrFloatBitWidth(); 77930c74f1SLei Zhang } 78930c74f1SLei Zhang 79930c74f1SLei Zhang /// Returns the bit width of LLVMType integer or vector. 80c69c9e0fSAlex Zinenko static unsigned getLLVMTypeBitWidth(Type type) { 815550c821STres Popp return cast<IntegerType>((LLVM::isCompatibleVectorType(type) 825550c821STres Popp ? LLVM::getVectorElementType(type) 835550c821STres Popp : type)) 842230bf99SAlex Zinenko .getWidth(); 85930c74f1SLei Zhang } 86930c74f1SLei Zhang 87930c74f1SLei Zhang /// Creates `IntegerAttribute` with all bits set for given type 88930c74f1SLei Zhang static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { 895550c821STres Popp if (auto vecType = dyn_cast<VectorType>(type)) { 905550c821STres Popp auto integerType = cast<IntegerType>(vecType.getElementType()); 91930c74f1SLei Zhang return builder.getIntegerAttr(integerType, -1); 92930c74f1SLei Zhang } 935550c821STres Popp auto integerType = cast<IntegerType>(type); 94930c74f1SLei Zhang return builder.getIntegerAttr(integerType, -1); 95930c74f1SLei Zhang } 96930c74f1SLei Zhang 97930c74f1SLei Zhang /// Creates `llvm.mlir.constant` with all bits set for the given type. 98930c74f1SLei Zhang static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, 99930c74f1SLei Zhang PatternRewriter &rewriter) { 1005550c821STres Popp if (isa<VectorType>(srcType)) { 101930c74f1SLei Zhang return rewriter.create<LLVM::ConstantOp>( 102930c74f1SLei Zhang loc, dstType, 1035550c821STres Popp SplatElementsAttr::get(cast<ShapedType>(srcType), 104930c74f1SLei Zhang minusOneIntegerAttribute(srcType, rewriter))); 105930c74f1SLei Zhang } 106930c74f1SLei Zhang return rewriter.create<LLVM::ConstantOp>( 107930c74f1SLei Zhang loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); 108930c74f1SLei Zhang } 109930c74f1SLei Zhang 110930c74f1SLei Zhang /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. 111930c74f1SLei Zhang static Value createFPConstant(Location loc, Type srcType, Type dstType, 112930c74f1SLei Zhang PatternRewriter &rewriter, double value) { 1135550c821STres Popp if (auto vecType = dyn_cast<VectorType>(srcType)) { 1145550c821STres Popp auto floatType = cast<FloatType>(vecType.getElementType()); 115930c74f1SLei Zhang return rewriter.create<LLVM::ConstantOp>( 116930c74f1SLei Zhang loc, dstType, 117930c74f1SLei Zhang SplatElementsAttr::get(vecType, 118930c74f1SLei Zhang rewriter.getFloatAttr(floatType, value))); 119930c74f1SLei Zhang } 1205550c821STres Popp auto floatType = cast<FloatType>(srcType); 121930c74f1SLei Zhang return rewriter.create<LLVM::ConstantOp>( 122930c74f1SLei Zhang loc, dstType, rewriter.getFloatAttr(floatType, value)); 123930c74f1SLei Zhang } 124930c74f1SLei Zhang 125930c74f1SLei Zhang /// Utility function for bitfield ops: 126930c74f1SLei Zhang /// - `BitFieldInsert` 127930c74f1SLei Zhang /// - `BitFieldSExtract` 128930c74f1SLei Zhang /// - `BitFieldUExtract` 129930c74f1SLei Zhang /// Truncates or extends the value. If the bitwidth of the value is the same as 130c69c9e0fSAlex Zinenko /// `llvmType` bitwidth, the value remains unchanged. 131c69c9e0fSAlex Zinenko static Value optionallyTruncateOrExtend(Location loc, Value value, 132c69c9e0fSAlex Zinenko Type llvmType, 133930c74f1SLei Zhang PatternRewriter &rewriter) { 134930c74f1SLei Zhang auto srcType = value.getType(); 135930c74f1SLei Zhang unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType); 136c69c9e0fSAlex Zinenko unsigned valueBitWidth = LLVM::isCompatibleType(srcType) 137c69c9e0fSAlex Zinenko ? getLLVMTypeBitWidth(srcType) 138930c74f1SLei Zhang : getBitWidth(srcType); 139930c74f1SLei Zhang 140930c74f1SLei Zhang if (valueBitWidth < targetBitWidth) 141930c74f1SLei Zhang return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value); 142930c74f1SLei Zhang // If the bit widths of `Count` and `Offset` are greater than the bit width 143930c74f1SLei Zhang // of the target type, they are truncated. Truncation is safe since `Count` 144930c74f1SLei Zhang // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, 145930c74f1SLei Zhang // both values can be expressed in 8 bits. 146930c74f1SLei Zhang if (valueBitWidth > targetBitWidth) 147930c74f1SLei Zhang return rewriter.create<LLVM::TruncOp>(loc, llvmType, value); 148930c74f1SLei Zhang return value; 149930c74f1SLei Zhang } 150930c74f1SLei Zhang 151930c74f1SLei Zhang /// Broadcasts the value to vector with `numElements` number of elements. 152930c74f1SLei Zhang static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, 15398723e65SMatthias Springer const TypeConverter &typeConverter, 154930c74f1SLei Zhang ConversionPatternRewriter &rewriter) { 155930c74f1SLei Zhang auto vectorType = VectorType::get(numElements, toBroadcast.getType()); 156930c74f1SLei Zhang auto llvmVectorType = typeConverter.convertType(vectorType); 157930c74f1SLei Zhang auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); 158930c74f1SLei Zhang Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType); 159930c74f1SLei Zhang for (unsigned i = 0; i < numElements; ++i) { 160930c74f1SLei Zhang auto index = rewriter.create<LLVM::ConstantOp>( 161930c74f1SLei Zhang loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); 162930c74f1SLei Zhang broadcasted = rewriter.create<LLVM::InsertElementOp>( 163930c74f1SLei Zhang loc, llvmVectorType, broadcasted, toBroadcast, index); 164930c74f1SLei Zhang } 165930c74f1SLei Zhang return broadcasted; 166930c74f1SLei Zhang } 167930c74f1SLei Zhang 168930c74f1SLei Zhang /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged. 169930c74f1SLei Zhang static Value optionallyBroadcast(Location loc, Value value, Type srcType, 17098723e65SMatthias Springer const TypeConverter &typeConverter, 171930c74f1SLei Zhang ConversionPatternRewriter &rewriter) { 1725550c821STres Popp if (auto vectorType = dyn_cast<VectorType>(srcType)) { 173930c74f1SLei Zhang unsigned numElements = vectorType.getNumElements(); 174930c74f1SLei Zhang return broadcast(loc, value, numElements, typeConverter, rewriter); 175930c74f1SLei Zhang } 176930c74f1SLei Zhang return value; 177930c74f1SLei Zhang } 178930c74f1SLei Zhang 179930c74f1SLei Zhang /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and 180930c74f1SLei Zhang /// `BitFieldUExtract`. 181930c74f1SLei Zhang /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of 182930c74f1SLei Zhang /// a vector type, construct a vector that has: 183930c74f1SLei Zhang /// - same number of elements as `Base` 184930c74f1SLei Zhang /// - each element has the type that is the same as the type of `Offset` or 185930c74f1SLei Zhang /// `Count` 186930c74f1SLei Zhang /// - each element has the same value as `Offset` or `Count` 187930c74f1SLei Zhang /// Then cast `Offset` and `Count` if their bit width is different 188930c74f1SLei Zhang /// from `Base` bit width. 189930c74f1SLei Zhang static Value processCountOrOffset(Location loc, Value value, Type srcType, 19098723e65SMatthias Springer Type dstType, const TypeConverter &converter, 191930c74f1SLei Zhang ConversionPatternRewriter &rewriter) { 192930c74f1SLei Zhang Value broadcasted = 193930c74f1SLei Zhang optionallyBroadcast(loc, value, srcType, converter, rewriter); 194930c74f1SLei Zhang return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter); 195930c74f1SLei Zhang } 196930c74f1SLei Zhang 197930c74f1SLei Zhang /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`) 198930c74f1SLei Zhang /// offset to LLVM struct. Otherwise, the conversion is not supported. 19930ca16ecSPierre van Houtryve static Type convertStructTypeWithOffset(spirv::StructType type, 20098723e65SMatthias Springer const TypeConverter &converter) { 201930c74f1SLei Zhang if (type != VulkanLayoutUtils::decorateType(type)) 20230ca16ecSPierre van Houtryve return nullptr; 203930c74f1SLei Zhang 20430ca16ecSPierre van Houtryve SmallVector<Type> elementsVector; 20530ca16ecSPierre van Houtryve if (failed(converter.convertTypes(type.getElementTypes(), elementsVector))) 20630ca16ecSPierre van Houtryve return nullptr; 207930c74f1SLei Zhang return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, 208930c74f1SLei Zhang /*isPacked=*/false); 209930c74f1SLei Zhang } 210930c74f1SLei Zhang 211930c74f1SLei Zhang /// Converts SPIR-V struct with no offset to packed LLVM struct. 212930c74f1SLei Zhang static Type convertStructTypePacked(spirv::StructType type, 21398723e65SMatthias Springer const TypeConverter &converter) { 21430ca16ecSPierre van Houtryve SmallVector<Type> elementsVector; 21530ca16ecSPierre van Houtryve if (failed(converter.convertTypes(type.getElementTypes(), elementsVector))) 21630ca16ecSPierre van Houtryve return nullptr; 217930c74f1SLei Zhang return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, 218930c74f1SLei Zhang /*isPacked=*/true); 219930c74f1SLei Zhang } 220930c74f1SLei Zhang 221930c74f1SLei Zhang /// Creates LLVM dialect constant with the given value. 222930c74f1SLei Zhang static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, 223930c74f1SLei Zhang unsigned value) { 224930c74f1SLei Zhang return rewriter.create<LLVM::ConstantOp>( 2252230bf99SAlex Zinenko loc, IntegerType::get(rewriter.getContext(), 32), 226930c74f1SLei Zhang rewriter.getIntegerAttr(rewriter.getI32Type(), value)); 227930c74f1SLei Zhang } 228930c74f1SLei Zhang 2295ab6ef75SJakub Kuderski /// Utility for `spirv.Load` and `spirv.Store` conversion. 23098723e65SMatthias Springer static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, 23198723e65SMatthias Springer ConversionPatternRewriter &rewriter, 23298723e65SMatthias Springer const TypeConverter &typeConverter, 23398723e65SMatthias Springer unsigned alignment, bool isVolatile, 234930c74f1SLei Zhang bool isNonTemporal) { 235930c74f1SLei Zhang if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) { 236930c74f1SLei Zhang auto dstType = typeConverter.convertType(loadOp.getType()); 237930c74f1SLei Zhang if (!dstType) 238c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 239930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::LoadOp>( 24090a1632dSJakub Kuderski loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment, 241015192c6SRiver Riddle isVolatile, isNonTemporal); 242930c74f1SLei Zhang return success(); 243930c74f1SLei Zhang } 244930c74f1SLei Zhang auto storeOp = cast<spirv::StoreOp>(op); 245015192c6SRiver Riddle spirv::StoreOpAdaptor adaptor(operands); 24690a1632dSJakub Kuderski rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(), 24790a1632dSJakub Kuderski adaptor.getPtr(), alignment, 248930c74f1SLei Zhang isVolatile, isNonTemporal); 249930c74f1SLei Zhang return success(); 250930c74f1SLei Zhang } 251930c74f1SLei Zhang 252930c74f1SLei Zhang //===----------------------------------------------------------------------===// 253930c74f1SLei Zhang // Type conversion 254930c74f1SLei Zhang //===----------------------------------------------------------------------===// 255930c74f1SLei Zhang 256930c74f1SLei Zhang /// Converts SPIR-V array type to LLVM array. Natural stride (according to 257930c74f1SLei Zhang /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected 258930c74f1SLei Zhang /// when converting ops that manipulate array types. 2590de16fafSRamkumar Ramachandra static std::optional<Type> convertArrayType(spirv::ArrayType type, 260930c74f1SLei Zhang TypeConverter &converter) { 261930c74f1SLei Zhang unsigned stride = type.getArrayStride(); 262930c74f1SLei Zhang Type elementType = type.getElementType(); 2635550c821STres Popp auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes(); 2646ed4310eSMehdi Amini if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride)) 2651a36588eSKazu Hirata return std::nullopt; 266930c74f1SLei Zhang 267c69c9e0fSAlex Zinenko auto llvmElementType = converter.convertType(elementType); 268930c74f1SLei Zhang unsigned numElements = type.getNumElements(); 269930c74f1SLei Zhang return LLVM::LLVMArrayType::get(llvmElementType, numElements); 270930c74f1SLei Zhang } 271930c74f1SLei Zhang 272930c74f1SLei Zhang /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not 273930c74f1SLei Zhang /// modelled at the moment. 274930c74f1SLei Zhang static Type convertPointerType(spirv::PointerType type, 27598723e65SMatthias Springer const TypeConverter &converter, 27612ce9fd1SVictor Perez spirv::ClientAPI clientAPI) { 277d45de800SVictor Perez unsigned addressSpace = 278d45de800SVictor Perez storageClassToAddressSpace(clientAPI, type.getStorageClass()); 27997a238e8SChristian Ulmann return LLVM::LLVMPointerType::get(type.getContext(), addressSpace); 280930c74f1SLei Zhang } 281930c74f1SLei Zhang 282930c74f1SLei Zhang /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over 283930c74f1SLei Zhang /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is 284930c74f1SLei Zhang /// no modelling of array stride at the moment. 2850de16fafSRamkumar Ramachandra static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type, 286930c74f1SLei Zhang TypeConverter &converter) { 287930c74f1SLei Zhang if (type.getArrayStride() != 0) 2881a36588eSKazu Hirata return std::nullopt; 289c69c9e0fSAlex Zinenko auto elementType = converter.convertType(type.getElementType()); 290930c74f1SLei Zhang return LLVM::LLVMArrayType::get(elementType, 0); 291930c74f1SLei Zhang } 292930c74f1SLei Zhang 293930c74f1SLei Zhang /// Converts SPIR-V struct to LLVM struct. There is no support of structs with 294930c74f1SLei Zhang /// member decorations. Also, only natural offset is supported. 29530ca16ecSPierre van Houtryve static Type convertStructType(spirv::StructType type, 29698723e65SMatthias Springer const TypeConverter &converter) { 297930c74f1SLei Zhang SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 298930c74f1SLei Zhang type.getMemberDecorations(memberDecorations); 299930c74f1SLei Zhang if (!memberDecorations.empty()) 30030ca16ecSPierre van Houtryve return nullptr; 301930c74f1SLei Zhang if (type.hasOffset()) 302930c74f1SLei Zhang return convertStructTypeWithOffset(type, converter); 303930c74f1SLei Zhang return convertStructTypePacked(type, converter); 304930c74f1SLei Zhang } 305930c74f1SLei Zhang 306930c74f1SLei Zhang //===----------------------------------------------------------------------===// 307930c74f1SLei Zhang // Operation conversion 308930c74f1SLei Zhang //===----------------------------------------------------------------------===// 309930c74f1SLei Zhang 310930c74f1SLei Zhang namespace { 311930c74f1SLei Zhang 312930c74f1SLei Zhang class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> { 313930c74f1SLei Zhang public: 314930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion; 315930c74f1SLei Zhang 316930c74f1SLei Zhang LogicalResult 317b54c724bSRiver Riddle matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor, 318930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 319206fad0eSMatthias Springer auto dstType = 320206fad0eSMatthias Springer getTypeConverter()->convertType(op.getComponentPtr().getType()); 321930c74f1SLei Zhang if (!dstType) 322c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 323930c74f1SLei Zhang // To use GEP we need to add a first 0 index to go through the pointer. 32490a1632dSJakub Kuderski auto indices = llvm::to_vector<4>(adaptor.getIndices()); 32590a1632dSJakub Kuderski Type indexType = op.getIndices().front().getType(); 326206fad0eSMatthias Springer auto llvmIndexType = getTypeConverter()->convertType(indexType); 327930c74f1SLei Zhang if (!llvmIndexType) 328c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 329930c74f1SLei Zhang Value zero = rewriter.create<LLVM::ConstantOp>( 330930c74f1SLei Zhang op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); 331930c74f1SLei Zhang indices.insert(indices.begin(), zero); 332c19436eeSKohei Yamaguchi 333206fad0eSMatthias Springer auto elementType = getTypeConverter()->convertType( 334c19436eeSKohei Yamaguchi cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType()); 335c19436eeSKohei Yamaguchi if (!elementType) 336c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 337c19436eeSKohei Yamaguchi rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType, 33819b1e27fSMarkus Böck adaptor.getBasePtr(), indices); 339930c74f1SLei Zhang return success(); 340930c74f1SLei Zhang } 341930c74f1SLei Zhang }; 342930c74f1SLei Zhang 343930c74f1SLei Zhang class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> { 344930c74f1SLei Zhang public: 345930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion; 346930c74f1SLei Zhang 347930c74f1SLei Zhang LogicalResult 348b54c724bSRiver Riddle matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor, 349930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 350206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(op.getPointer().getType()); 351930c74f1SLei Zhang if (!dstType) 352c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 353162f7572SMahesh Ravishankar rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, 354162f7572SMahesh Ravishankar op.getVariable()); 355930c74f1SLei Zhang return success(); 356930c74f1SLei Zhang } 357930c74f1SLei Zhang }; 358930c74f1SLei Zhang 359930c74f1SLei Zhang class BitFieldInsertPattern 360930c74f1SLei Zhang : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> { 361930c74f1SLei Zhang public: 362930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion; 363930c74f1SLei Zhang 364930c74f1SLei Zhang LogicalResult 365b54c724bSRiver Riddle matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor, 366930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 367930c74f1SLei Zhang auto srcType = op.getType(); 368206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(srcType); 369930c74f1SLei Zhang if (!dstType) 370c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 371930c74f1SLei Zhang Location loc = op.getLoc(); 372930c74f1SLei Zhang 373930c74f1SLei Zhang // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 37490a1632dSJakub Kuderski Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, 375206fad0eSMatthias Springer *getTypeConverter(), rewriter); 37690a1632dSJakub Kuderski Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, 377206fad0eSMatthias Springer *getTypeConverter(), rewriter); 378930c74f1SLei Zhang 379930c74f1SLei Zhang // Create a mask with bits set outside [Offset, Offset + Count - 1]. 380930c74f1SLei Zhang Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); 381930c74f1SLei Zhang Value maskShiftedByCount = 382930c74f1SLei Zhang rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); 383930c74f1SLei Zhang Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType, 384930c74f1SLei Zhang maskShiftedByCount, minusOne); 385930c74f1SLei Zhang Value maskShiftedByCountAndOffset = 386930c74f1SLei Zhang rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset); 387930c74f1SLei Zhang Value mask = rewriter.create<LLVM::XOrOp>( 388930c74f1SLei Zhang loc, dstType, maskShiftedByCountAndOffset, minusOne); 389930c74f1SLei Zhang 390930c74f1SLei Zhang // Extract unchanged bits from the `Base` that are outside of 391930c74f1SLei Zhang // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. 392930c74f1SLei Zhang Value baseAndMask = 39390a1632dSJakub Kuderski rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask); 394930c74f1SLei Zhang Value insertShiftedByOffset = 39590a1632dSJakub Kuderski rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset); 396930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, 397930c74f1SLei Zhang insertShiftedByOffset); 398930c74f1SLei Zhang return success(); 399930c74f1SLei Zhang } 400930c74f1SLei Zhang }; 401930c74f1SLei Zhang 402930c74f1SLei Zhang /// Converts SPIR-V ConstantOp with scalar or vector type. 403930c74f1SLei Zhang class ConstantScalarAndVectorPattern 404930c74f1SLei Zhang : public SPIRVToLLVMConversion<spirv::ConstantOp> { 405930c74f1SLei Zhang public: 406930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion; 407930c74f1SLei Zhang 408930c74f1SLei Zhang LogicalResult 409b54c724bSRiver Riddle matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor, 410930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 411930c74f1SLei Zhang auto srcType = constOp.getType(); 4125550c821STres Popp if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat()) 413930c74f1SLei Zhang return failure(); 414930c74f1SLei Zhang 415206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(srcType); 416930c74f1SLei Zhang if (!dstType) 417c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(constOp, "type conversion failed"); 418930c74f1SLei Zhang 419930c74f1SLei Zhang // SPIR-V constant can be a signed/unsigned integer, which has to be 420930c74f1SLei Zhang // casted to signless integer when converting to LLVM dialect. Removing the 421930c74f1SLei Zhang // sign bit may have unexpected behaviour. However, it is better to handle 422930c74f1SLei Zhang // it case-by-case, given that the purpose of the conversion is not to 423930c74f1SLei Zhang // cover all possible corner cases. 424930c74f1SLei Zhang if (isSignedIntegerOrVector(srcType) || 425930c74f1SLei Zhang isUnsignedIntegerOrVector(srcType)) { 426930c74f1SLei Zhang auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); 427930c74f1SLei Zhang 4285550c821STres Popp if (isa<VectorType>(srcType)) { 4295550c821STres Popp auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue()); 430930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( 431930c74f1SLei Zhang constOp, dstType, 432930c74f1SLei Zhang dstElementsAttr.mapValues( 433930c74f1SLei Zhang signlessType, [&](const APInt &value) { return value; })); 434930c74f1SLei Zhang return success(); 435930c74f1SLei Zhang } 4365550c821STres Popp auto srcAttr = cast<IntegerAttr>(constOp.getValue()); 437930c74f1SLei Zhang auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); 438930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr); 439930c74f1SLei Zhang return success(); 440930c74f1SLei Zhang } 441b54c724bSRiver Riddle rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( 442b54c724bSRiver Riddle constOp, dstType, adaptor.getOperands(), constOp->getAttrs()); 443930c74f1SLei Zhang return success(); 444930c74f1SLei Zhang } 445930c74f1SLei Zhang }; 446930c74f1SLei Zhang 447930c74f1SLei Zhang class BitFieldSExtractPattern 448930c74f1SLei Zhang : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> { 449930c74f1SLei Zhang public: 450930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion; 451930c74f1SLei Zhang 452930c74f1SLei Zhang LogicalResult 453b54c724bSRiver Riddle matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor, 454930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 455930c74f1SLei Zhang auto srcType = op.getType(); 456206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(srcType); 457930c74f1SLei Zhang if (!dstType) 458c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 459930c74f1SLei Zhang Location loc = op.getLoc(); 460930c74f1SLei Zhang 461930c74f1SLei Zhang // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 46290a1632dSJakub Kuderski Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, 463206fad0eSMatthias Springer *getTypeConverter(), rewriter); 46490a1632dSJakub Kuderski Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, 465206fad0eSMatthias Springer *getTypeConverter(), rewriter); 466930c74f1SLei Zhang 467930c74f1SLei Zhang // Create a constant that holds the size of the `Base`. 468930c74f1SLei Zhang IntegerType integerType; 4695550c821STres Popp if (auto vecType = dyn_cast<VectorType>(srcType)) 4705550c821STres Popp integerType = cast<IntegerType>(vecType.getElementType()); 471930c74f1SLei Zhang else 4725550c821STres Popp integerType = cast<IntegerType>(srcType); 473930c74f1SLei Zhang 474930c74f1SLei Zhang auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); 475930c74f1SLei Zhang Value size = 4765550c821STres Popp isa<VectorType>(srcType) 477930c74f1SLei Zhang ? rewriter.create<LLVM::ConstantOp>( 478930c74f1SLei Zhang loc, dstType, 4795550c821STres Popp SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize)) 480930c74f1SLei Zhang : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize); 481930c74f1SLei Zhang 482930c74f1SLei Zhang // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit 483930c74f1SLei Zhang // at Offset + Count - 1 is the most significant bit now. 484930c74f1SLei Zhang Value countPlusOffset = 485930c74f1SLei Zhang rewriter.create<LLVM::AddOp>(loc, dstType, count, offset); 486930c74f1SLei Zhang Value amountToShiftLeft = 487930c74f1SLei Zhang rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset); 488930c74f1SLei Zhang Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>( 48990a1632dSJakub Kuderski loc, dstType, op.getBase(), amountToShiftLeft); 490930c74f1SLei Zhang 491930c74f1SLei Zhang // Shift the result right, filling the bits with the sign bit. 492930c74f1SLei Zhang Value amountToShiftRight = 493930c74f1SLei Zhang rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft); 494930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft, 495930c74f1SLei Zhang amountToShiftRight); 496930c74f1SLei Zhang return success(); 497930c74f1SLei Zhang } 498930c74f1SLei Zhang }; 499930c74f1SLei Zhang 500930c74f1SLei Zhang class BitFieldUExtractPattern 501930c74f1SLei Zhang : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> { 502930c74f1SLei Zhang public: 503930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion; 504930c74f1SLei Zhang 505930c74f1SLei Zhang LogicalResult 506b54c724bSRiver Riddle matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor, 507930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 508930c74f1SLei Zhang auto srcType = op.getType(); 509206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(srcType); 510930c74f1SLei Zhang if (!dstType) 511c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 512930c74f1SLei Zhang Location loc = op.getLoc(); 513930c74f1SLei Zhang 514930c74f1SLei Zhang // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 51590a1632dSJakub Kuderski Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, 516206fad0eSMatthias Springer *getTypeConverter(), rewriter); 51790a1632dSJakub Kuderski Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, 518206fad0eSMatthias Springer *getTypeConverter(), rewriter); 519930c74f1SLei Zhang 520930c74f1SLei Zhang // Create a mask with bits set at [0, Count - 1]. 521930c74f1SLei Zhang Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); 522930c74f1SLei Zhang Value maskShiftedByCount = 523930c74f1SLei Zhang rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); 524930c74f1SLei Zhang Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount, 525930c74f1SLei Zhang minusOne); 526930c74f1SLei Zhang 527930c74f1SLei Zhang // Shift `Base` by `Offset` and apply the mask on it. 528930c74f1SLei Zhang Value shiftedBase = 52990a1632dSJakub Kuderski rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset); 530930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask); 531930c74f1SLei Zhang return success(); 532930c74f1SLei Zhang } 533930c74f1SLei Zhang }; 534930c74f1SLei Zhang 535930c74f1SLei Zhang class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> { 536930c74f1SLei Zhang public: 537930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion; 538930c74f1SLei Zhang 539930c74f1SLei Zhang LogicalResult 540b54c724bSRiver Riddle matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor, 541930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 542b54c724bSRiver Riddle rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(), 543930c74f1SLei Zhang branchOp.getTarget()); 544930c74f1SLei Zhang return success(); 545930c74f1SLei Zhang } 546930c74f1SLei Zhang }; 547930c74f1SLei Zhang 548930c74f1SLei Zhang class BranchConditionalConversionPattern 549930c74f1SLei Zhang : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> { 550930c74f1SLei Zhang public: 551930c74f1SLei Zhang using SPIRVToLLVMConversion< 552930c74f1SLei Zhang spirv::BranchConditionalOp>::SPIRVToLLVMConversion; 553930c74f1SLei Zhang 554930c74f1SLei Zhang LogicalResult 555b54c724bSRiver Riddle matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor, 556930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 557930c74f1SLei Zhang // If branch weights exist, map them to 32-bit integer vector. 55810fa2770STobias Gysi DenseI32ArrayAttr branchWeights = nullptr; 55990a1632dSJakub Kuderski if (auto weights = op.getBranchWeights()) { 56010fa2770STobias Gysi SmallVector<int32_t> weightValues; 56110fa2770STobias Gysi for (auto weight : weights->getAsRange<IntegerAttr>()) 56210fa2770STobias Gysi weightValues.push_back(weight.getInt()); 56310fa2770STobias Gysi branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues); 564930c74f1SLei Zhang } 565930c74f1SLei Zhang 566930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 56790a1632dSJakub Kuderski op, op.getCondition(), op.getTrueBlockArguments(), 568930c74f1SLei Zhang op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), 569930c74f1SLei Zhang op.getFalseBlock()); 570930c74f1SLei Zhang return success(); 571930c74f1SLei Zhang } 572930c74f1SLei Zhang }; 573930c74f1SLei Zhang 5745ab6ef75SJakub Kuderski /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container 5755ab6ef75SJakub Kuderski /// type is an aggregate type (struct or array). Otherwise, converts to 576930c74f1SLei Zhang /// `llvm.extractelement` that operates on vectors. 577930c74f1SLei Zhang class CompositeExtractPattern 578930c74f1SLei Zhang : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> { 579930c74f1SLei Zhang public: 580930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion; 581930c74f1SLei Zhang 582930c74f1SLei Zhang LogicalResult 583b54c724bSRiver Riddle matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor, 584930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 585206fad0eSMatthias Springer auto dstType = this->getTypeConverter()->convertType(op.getType()); 586930c74f1SLei Zhang if (!dstType) 587c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 588930c74f1SLei Zhang 58990a1632dSJakub Kuderski Type containerType = op.getComposite().getType(); 5905550c821STres Popp if (isa<VectorType>(containerType)) { 591930c74f1SLei Zhang Location loc = op.getLoc(); 5925550c821STres Popp IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]); 593930c74f1SLei Zhang Value index = createI32ConstantOf(loc, rewriter, value.getInt()); 594930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 59590a1632dSJakub Kuderski op, dstType, adaptor.getComposite(), index); 596930c74f1SLei Zhang return success(); 597930c74f1SLei Zhang } 5985c5af910SJeff Niu 599930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( 600162f7572SMahesh Ravishankar op, adaptor.getComposite(), 601162f7572SMahesh Ravishankar LLVM::convertArrayToIndices(op.getIndices())); 602930c74f1SLei Zhang return success(); 603930c74f1SLei Zhang } 604930c74f1SLei Zhang }; 605930c74f1SLei Zhang 6065ab6ef75SJakub Kuderski /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container 6075ab6ef75SJakub Kuderski /// type is an aggregate type (struct or array). Otherwise, converts to 608930c74f1SLei Zhang /// `llvm.insertelement` that operates on vectors. 609930c74f1SLei Zhang class CompositeInsertPattern 610930c74f1SLei Zhang : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> { 611930c74f1SLei Zhang public: 612930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion; 613930c74f1SLei Zhang 614930c74f1SLei Zhang LogicalResult 615b54c724bSRiver Riddle matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor, 616930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 617206fad0eSMatthias Springer auto dstType = this->getTypeConverter()->convertType(op.getType()); 618930c74f1SLei Zhang if (!dstType) 619c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 620930c74f1SLei Zhang 62190a1632dSJakub Kuderski Type containerType = op.getComposite().getType(); 6225550c821STres Popp if (isa<VectorType>(containerType)) { 623930c74f1SLei Zhang Location loc = op.getLoc(); 6245550c821STres Popp IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]); 625930c74f1SLei Zhang Value index = createI32ConstantOf(loc, rewriter, value.getInt()); 626930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 62790a1632dSJakub Kuderski op, dstType, adaptor.getComposite(), adaptor.getObject(), index); 628930c74f1SLei Zhang return success(); 629930c74f1SLei Zhang } 6305c5af910SJeff Niu 631930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>( 63290a1632dSJakub Kuderski op, adaptor.getComposite(), adaptor.getObject(), 63390a1632dSJakub Kuderski LLVM::convertArrayToIndices(op.getIndices())); 634930c74f1SLei Zhang return success(); 635930c74f1SLei Zhang } 636930c74f1SLei Zhang }; 637930c74f1SLei Zhang 638930c74f1SLei Zhang /// Converts SPIR-V operations that have straightforward LLVM equivalent 639930c74f1SLei Zhang /// into LLVM dialect operations. 640930c74f1SLei Zhang template <typename SPIRVOp, typename LLVMOp> 641930c74f1SLei Zhang class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> { 642930c74f1SLei Zhang public: 643930c74f1SLei Zhang using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 644930c74f1SLei Zhang 645930c74f1SLei Zhang LogicalResult 646c19436eeSKohei Yamaguchi matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 647930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 648206fad0eSMatthias Springer auto dstType = this->getTypeConverter()->convertType(op.getType()); 649930c74f1SLei Zhang if (!dstType) 650c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 651b54c724bSRiver Riddle rewriter.template replaceOpWithNewOp<LLVMOp>( 652c19436eeSKohei Yamaguchi op, dstType, adaptor.getOperands(), op->getAttrs()); 653930c74f1SLei Zhang return success(); 654930c74f1SLei Zhang } 655930c74f1SLei Zhang }; 656930c74f1SLei Zhang 6575ab6ef75SJakub Kuderski /// Converts `spirv.ExecutionMode` into a global struct constant that holds 658930c74f1SLei Zhang /// execution mode information. 659930c74f1SLei Zhang class ExecutionModePattern 660930c74f1SLei Zhang : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> { 661930c74f1SLei Zhang public: 662930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion; 663930c74f1SLei Zhang 664930c74f1SLei Zhang LogicalResult 665b54c724bSRiver Riddle matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor, 666930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 667930c74f1SLei Zhang // First, create the global struct's name that would be associated with 668930c74f1SLei Zhang // this entry point's execution mode. We set it to be: 669c0a6381eSWeiwei Li // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode} 670930c74f1SLei Zhang ModuleOp module = op->getParentOfType<ModuleOp>(); 67190a1632dSJakub Kuderski spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr(); 672930c74f1SLei Zhang std::string moduleName; 673491d2701SKazu Hirata if (module.getName().has_value()) 67490a1632dSJakub Kuderski moduleName = "_" + module.getName()->str(); 675930c74f1SLei Zhang else 676930c74f1SLei Zhang moduleName = ""; 677a29fffc4SLei Zhang std::string executionModeInfoName = llvm::formatv( 67890a1632dSJakub Kuderski "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(), 679a29fffc4SLei Zhang static_cast<uint32_t>(executionModeAttr.getValue())); 680930c74f1SLei Zhang 681930c74f1SLei Zhang MLIRContext *context = rewriter.getContext(); 682930c74f1SLei Zhang OpBuilder::InsertionGuard guard(rewriter); 683930c74f1SLei Zhang rewriter.setInsertionPointToStart(module.getBody()); 684930c74f1SLei Zhang 685930c74f1SLei Zhang // Create a struct type, corresponding to the C struct below. 686930c74f1SLei Zhang // struct { 687930c74f1SLei Zhang // int32_t executionMode; 688930c74f1SLei Zhang // int32_t values[]; // optional values 689930c74f1SLei Zhang // }; 6902230bf99SAlex Zinenko auto llvmI32Type = IntegerType::get(context, 32); 691c69c9e0fSAlex Zinenko SmallVector<Type, 2> fields; 692930c74f1SLei Zhang fields.push_back(llvmI32Type); 69390a1632dSJakub Kuderski ArrayAttr values = op.getValues(); 694930c74f1SLei Zhang if (!values.empty()) { 695930c74f1SLei Zhang auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size()); 696930c74f1SLei Zhang fields.push_back(arrayType); 697930c74f1SLei Zhang } 698930c74f1SLei Zhang auto structType = LLVM::LLVMStructType::getLiteral(context, fields); 699930c74f1SLei Zhang 700930c74f1SLei Zhang // Create `llvm.mlir.global` with initializer region containing one block. 701930c74f1SLei Zhang auto global = rewriter.create<LLVM::GlobalOp>( 702930c74f1SLei Zhang UnknownLoc::get(context), structType, /*isConstant=*/true, 7039a0ea599SDumitru Potop LLVM::Linkage::External, executionModeInfoName, Attribute(), 7049a0ea599SDumitru Potop /*alignment=*/0); 705930c74f1SLei Zhang Location loc = global.getLoc(); 706930c74f1SLei Zhang Region ®ion = global.getInitializerRegion(); 707930c74f1SLei Zhang Block *block = rewriter.createBlock(®ion); 708930c74f1SLei Zhang 709930c74f1SLei Zhang // Initialize the struct and set the execution mode value. 710b613a540SMatthias Springer rewriter.setInsertionPointToStart(block); 711930c74f1SLei Zhang Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType); 712a29fffc4SLei Zhang Value executionMode = rewriter.create<LLVM::ConstantOp>( 713a29fffc4SLei Zhang loc, llvmI32Type, 714a29fffc4SLei Zhang rewriter.getI32IntegerAttr( 715a29fffc4SLei Zhang static_cast<uint32_t>(executionModeAttr.getValue()))); 7165c5af910SJeff Niu structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue, 7175c5af910SJeff Niu executionMode, 0); 718930c74f1SLei Zhang 719930c74f1SLei Zhang // Insert extra operands if they exist into execution mode info struct. 720930c74f1SLei Zhang for (unsigned i = 0, e = values.size(); i < e; ++i) { 721930c74f1SLei Zhang auto attr = values.getValue()[i]; 722930c74f1SLei Zhang Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr); 723930c74f1SLei Zhang structValue = rewriter.create<LLVM::InsertValueOp>( 72482973067SUday Bondhugula loc, structValue, entry, ArrayRef<int64_t>({1, i})); 725930c74f1SLei Zhang } 726930c74f1SLei Zhang rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue})); 727930c74f1SLei Zhang rewriter.eraseOp(op); 728930c74f1SLei Zhang return success(); 729930c74f1SLei Zhang } 730930c74f1SLei Zhang }; 731930c74f1SLei Zhang 7325ab6ef75SJakub Kuderski /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V 7335ab6ef75SJakub Kuderski /// global returns a pointer, whereas in LLVM dialect the global holds an actual 7345ab6ef75SJakub Kuderski /// value. This difference is handled by `spirv.mlir.addressof` and 735930c74f1SLei Zhang /// `llvm.mlir.addressof`ops that both return a pointer. 736930c74f1SLei Zhang class GlobalVariablePattern 737930c74f1SLei Zhang : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> { 738930c74f1SLei Zhang public: 73912ce9fd1SVictor Perez template <typename... Args> 74012ce9fd1SVictor Perez GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args) 74112ce9fd1SVictor Perez : SPIRVToLLVMConversion<spirv::GlobalVariableOp>( 74212ce9fd1SVictor Perez std::forward<Args>(args)...), 74312ce9fd1SVictor Perez clientAPI(clientAPI) {} 744930c74f1SLei Zhang 745930c74f1SLei Zhang LogicalResult 746b54c724bSRiver Riddle matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor, 747930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 748930c74f1SLei Zhang // Currently, there is no support of initialization with a constant value in 749930c74f1SLei Zhang // SPIR-V dialect. Specialization constants are not considered as well. 75090a1632dSJakub Kuderski if (op.getInitializer()) 751930c74f1SLei Zhang return failure(); 752930c74f1SLei Zhang 7535550c821STres Popp auto srcType = cast<spirv::PointerType>(op.getType()); 754206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(srcType.getPointeeType()); 755930c74f1SLei Zhang if (!dstType) 756c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 757930c74f1SLei Zhang 758930c74f1SLei Zhang // Limit conversion to the current invocation only or `StorageBuffer` 759930c74f1SLei Zhang // required by SPIR-V runner. 760930c74f1SLei Zhang // This is okay because multiple invocations are not supported yet. 761930c74f1SLei Zhang auto storageClass = srcType.getStorageClass(); 7621e4cfe5eSWeiwei Li switch (storageClass) { 7631e4cfe5eSWeiwei Li case spirv::StorageClass::Input: 7641e4cfe5eSWeiwei Li case spirv::StorageClass::Private: 7651e4cfe5eSWeiwei Li case spirv::StorageClass::Output: 7661e4cfe5eSWeiwei Li case spirv::StorageClass::StorageBuffer: 7671e4cfe5eSWeiwei Li case spirv::StorageClass::UniformConstant: 7681e4cfe5eSWeiwei Li break; 7691e4cfe5eSWeiwei Li default: 770930c74f1SLei Zhang return failure(); 771930c74f1SLei Zhang } 772930c74f1SLei Zhang 773930c74f1SLei Zhang // LLVM dialect spec: "If the global value is a constant, storing into it is 7741e4cfe5eSWeiwei Li // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant' 7751e4cfe5eSWeiwei Li // storage class that is read-only. 7761e4cfe5eSWeiwei Li bool isConstant = (storageClass == spirv::StorageClass::Input) || 7771e4cfe5eSWeiwei Li (storageClass == spirv::StorageClass::UniformConstant); 778930c74f1SLei Zhang // SPIR-V spec: "By default, functions and global variables are private to a 779930c74f1SLei Zhang // module and cannot be accessed by other modules. However, a module may be 780930c74f1SLei Zhang // written to export or import functions and global (module scope) 781930c74f1SLei Zhang // variables.". Therefore, map 'Private' storage class to private linkage, 782930c74f1SLei Zhang // 'Input' and 'Output' to external linkage. 783930c74f1SLei Zhang auto linkage = storageClass == spirv::StorageClass::Private 784930c74f1SLei Zhang ? LLVM::Linkage::Private 785930c74f1SLei Zhang : LLVM::Linkage::External; 7861e4cfe5eSWeiwei Li auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( 78790a1632dSJakub Kuderski op, dstType, isConstant, linkage, op.getSymName(), Attribute(), 788d45de800SVictor Perez /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass)); 7891e4cfe5eSWeiwei Li 7901e4cfe5eSWeiwei Li // Attach location attribute if applicable 79190a1632dSJakub Kuderski if (op.getLocationAttr()) 79290a1632dSJakub Kuderski newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr()); 7931e4cfe5eSWeiwei Li 794930c74f1SLei Zhang return success(); 795930c74f1SLei Zhang } 79612ce9fd1SVictor Perez 79712ce9fd1SVictor Perez private: 79812ce9fd1SVictor Perez spirv::ClientAPI clientAPI; 799930c74f1SLei Zhang }; 800930c74f1SLei Zhang 801930c74f1SLei Zhang /// Converts SPIR-V cast ops that do not have straightforward LLVM 802930c74f1SLei Zhang /// equivalent in LLVM dialect. 803930c74f1SLei Zhang template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp> 804930c74f1SLei Zhang class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> { 805930c74f1SLei Zhang public: 806930c74f1SLei Zhang using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 807930c74f1SLei Zhang 808930c74f1SLei Zhang LogicalResult 809c19436eeSKohei Yamaguchi matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 810930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 811930c74f1SLei Zhang 812c19436eeSKohei Yamaguchi Type fromType = op.getOperand().getType(); 813c19436eeSKohei Yamaguchi Type toType = op.getType(); 814930c74f1SLei Zhang 815206fad0eSMatthias Springer auto dstType = this->getTypeConverter()->convertType(toType); 816930c74f1SLei Zhang if (!dstType) 817c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 818930c74f1SLei Zhang 819930c74f1SLei Zhang if (getBitWidth(fromType) < getBitWidth(toType)) { 820c19436eeSKohei Yamaguchi rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType, 821b54c724bSRiver Riddle adaptor.getOperands()); 822930c74f1SLei Zhang return success(); 823930c74f1SLei Zhang } 824930c74f1SLei Zhang if (getBitWidth(fromType) > getBitWidth(toType)) { 825c19436eeSKohei Yamaguchi rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType, 826b54c724bSRiver Riddle adaptor.getOperands()); 827930c74f1SLei Zhang return success(); 828930c74f1SLei Zhang } 829930c74f1SLei Zhang return failure(); 830930c74f1SLei Zhang } 831930c74f1SLei Zhang }; 832930c74f1SLei Zhang 833930c74f1SLei Zhang class FunctionCallPattern 834930c74f1SLei Zhang : public SPIRVToLLVMConversion<spirv::FunctionCallOp> { 835930c74f1SLei Zhang public: 836930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion; 837930c74f1SLei Zhang 838930c74f1SLei Zhang LogicalResult 839b54c724bSRiver Riddle matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor, 840930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 841930c74f1SLei Zhang if (callOp.getNumResults() == 0) { 842fde3c16aSSirui Mu auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( 8431a36588eSKazu Hirata callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs()); 844fde3c16aSSirui Mu newOp.getProperties().operandSegmentSizes = { 845fde3c16aSSirui Mu static_cast<int32_t>(adaptor.getOperands().size()), 0}; 846fde3c16aSSirui Mu newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); 847930c74f1SLei Zhang return success(); 848930c74f1SLei Zhang } 849930c74f1SLei Zhang 850930c74f1SLei Zhang // Function returns a single result. 851206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(callOp.getType(0)); 852c19436eeSKohei Yamaguchi if (!dstType) 853c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(callOp, "type conversion failed"); 854fde3c16aSSirui Mu auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( 855b54c724bSRiver Riddle callOp, dstType, adaptor.getOperands(), callOp->getAttrs()); 856fde3c16aSSirui Mu newOp.getProperties().operandSegmentSizes = { 857fde3c16aSSirui Mu static_cast<int32_t>(adaptor.getOperands().size()), 0}; 858fde3c16aSSirui Mu newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); 859930c74f1SLei Zhang return success(); 860930c74f1SLei Zhang } 861930c74f1SLei Zhang }; 862930c74f1SLei Zhang 863930c74f1SLei Zhang /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" 864930c74f1SLei Zhang template <typename SPIRVOp, LLVM::FCmpPredicate predicate> 865930c74f1SLei Zhang class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { 866930c74f1SLei Zhang public: 867930c74f1SLei Zhang using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 868930c74f1SLei Zhang 869930c74f1SLei Zhang LogicalResult 870c19436eeSKohei Yamaguchi matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 871930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 872930c74f1SLei Zhang 873206fad0eSMatthias Springer auto dstType = this->getTypeConverter()->convertType(op.getType()); 874930c74f1SLei Zhang if (!dstType) 875c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 876930c74f1SLei Zhang 877930c74f1SLei Zhang rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>( 878c19436eeSKohei Yamaguchi op, dstType, predicate, op.getOperand1(), op.getOperand2()); 879930c74f1SLei Zhang return success(); 880930c74f1SLei Zhang } 881930c74f1SLei Zhang }; 882930c74f1SLei Zhang 883930c74f1SLei Zhang /// Converts SPIR-V integer comparisons to llvm.icmp "predicate" 884930c74f1SLei Zhang template <typename SPIRVOp, LLVM::ICmpPredicate predicate> 885930c74f1SLei Zhang class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { 886930c74f1SLei Zhang public: 887930c74f1SLei Zhang using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 888930c74f1SLei Zhang 889930c74f1SLei Zhang LogicalResult 890c19436eeSKohei Yamaguchi matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 891930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 892930c74f1SLei Zhang 893206fad0eSMatthias Springer auto dstType = this->getTypeConverter()->convertType(op.getType()); 894930c74f1SLei Zhang if (!dstType) 895c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 896930c74f1SLei Zhang 897930c74f1SLei Zhang rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( 898c19436eeSKohei Yamaguchi op, dstType, predicate, op.getOperand1(), op.getOperand2()); 899930c74f1SLei Zhang return success(); 900930c74f1SLei Zhang } 901930c74f1SLei Zhang }; 902930c74f1SLei Zhang 903930c74f1SLei Zhang class InverseSqrtPattern 90452b630daSJakub Kuderski : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> { 905930c74f1SLei Zhang public: 90652b630daSJakub Kuderski using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion; 907930c74f1SLei Zhang 908930c74f1SLei Zhang LogicalResult 90952b630daSJakub Kuderski matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor, 910930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 911930c74f1SLei Zhang auto srcType = op.getType(); 912206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(srcType); 913930c74f1SLei Zhang if (!dstType) 914c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 915930c74f1SLei Zhang 916930c74f1SLei Zhang Location loc = op.getLoc(); 917930c74f1SLei Zhang Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); 91890a1632dSJakub Kuderski Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand()); 919930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); 920930c74f1SLei Zhang return success(); 921930c74f1SLei Zhang } 922930c74f1SLei Zhang }; 923930c74f1SLei Zhang 9245ab6ef75SJakub Kuderski /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect. 925b54c724bSRiver Riddle template <typename SPIRVOp> 926b54c724bSRiver Riddle class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> { 927930c74f1SLei Zhang public: 928b54c724bSRiver Riddle using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 929930c74f1SLei Zhang 930930c74f1SLei Zhang LogicalResult 931b54c724bSRiver Riddle matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 932930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 93390a1632dSJakub Kuderski if (!op.getMemoryAccess()) { 934015192c6SRiver Riddle return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, 935206fad0eSMatthias Springer *this->getTypeConverter(), /*alignment=*/0, 936015192c6SRiver Riddle /*isVolatile=*/false, 937015192c6SRiver Riddle /*isNonTemporal=*/false); 938930c74f1SLei Zhang } 93990a1632dSJakub Kuderski auto memoryAccess = *op.getMemoryAccess(); 940930c74f1SLei Zhang switch (memoryAccess) { 941930c74f1SLei Zhang case spirv::MemoryAccess::Aligned: 942930c74f1SLei Zhang case spirv::MemoryAccess::None: 943930c74f1SLei Zhang case spirv::MemoryAccess::Nontemporal: 944930c74f1SLei Zhang case spirv::MemoryAccess::Volatile: { 945930c74f1SLei Zhang unsigned alignment = 94690a1632dSJakub Kuderski memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0; 947930c74f1SLei Zhang bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; 948930c74f1SLei Zhang bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; 949015192c6SRiver Riddle return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, 950206fad0eSMatthias Springer *this->getTypeConverter(), alignment, 951206fad0eSMatthias Springer isVolatile, isNonTemporal); 952930c74f1SLei Zhang } 953930c74f1SLei Zhang default: 954930c74f1SLei Zhang // There is no support of other memory access attributes. 955930c74f1SLei Zhang return failure(); 956930c74f1SLei Zhang } 957930c74f1SLei Zhang } 958930c74f1SLei Zhang }; 959930c74f1SLei Zhang 9605ab6ef75SJakub Kuderski /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect. 961930c74f1SLei Zhang template <typename SPIRVOp> 962930c74f1SLei Zhang class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> { 963930c74f1SLei Zhang public: 964930c74f1SLei Zhang using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 965930c74f1SLei Zhang 966930c74f1SLei Zhang LogicalResult 967b54c724bSRiver Riddle matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor, 968930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 969930c74f1SLei Zhang auto srcType = notOp.getType(); 970206fad0eSMatthias Springer auto dstType = this->getTypeConverter()->convertType(srcType); 971930c74f1SLei Zhang if (!dstType) 972c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(notOp, "type conversion failed"); 973930c74f1SLei Zhang 974930c74f1SLei Zhang Location loc = notOp.getLoc(); 975930c74f1SLei Zhang IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); 9765550c821STres Popp auto mask = 9775550c821STres Popp isa<VectorType>(srcType) 978930c74f1SLei Zhang ? rewriter.create<LLVM::ConstantOp>( 979930c74f1SLei Zhang loc, dstType, 9805550c821STres Popp SplatElementsAttr::get(cast<VectorType>(srcType), minusOne)) 981930c74f1SLei Zhang : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne); 982930c74f1SLei Zhang rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType, 98390a1632dSJakub Kuderski notOp.getOperand(), mask); 984930c74f1SLei Zhang return success(); 985930c74f1SLei Zhang } 986930c74f1SLei Zhang }; 987930c74f1SLei Zhang 988930c74f1SLei Zhang /// A template pattern that erases the given `SPIRVOp`. 989930c74f1SLei Zhang template <typename SPIRVOp> 990930c74f1SLei Zhang class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> { 991930c74f1SLei Zhang public: 992930c74f1SLei Zhang using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 993930c74f1SLei Zhang 994930c74f1SLei Zhang LogicalResult 995b54c724bSRiver Riddle matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 996930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 997930c74f1SLei Zhang rewriter.eraseOp(op); 998930c74f1SLei Zhang return success(); 999930c74f1SLei Zhang } 1000930c74f1SLei Zhang }; 1001930c74f1SLei Zhang 1002930c74f1SLei Zhang class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> { 1003930c74f1SLei Zhang public: 1004930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion; 1005930c74f1SLei Zhang 1006930c74f1SLei Zhang LogicalResult 1007b54c724bSRiver Riddle matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor, 1008930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 1009930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(), 1010930c74f1SLei Zhang ArrayRef<Value>()); 1011930c74f1SLei Zhang return success(); 1012930c74f1SLei Zhang } 1013930c74f1SLei Zhang }; 1014930c74f1SLei Zhang 1015930c74f1SLei Zhang class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> { 1016930c74f1SLei Zhang public: 1017930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion; 1018930c74f1SLei Zhang 1019930c74f1SLei Zhang LogicalResult 1020b54c724bSRiver Riddle matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor, 1021930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 1022930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(), 1023b54c724bSRiver Riddle adaptor.getOperands()); 1024930c74f1SLei Zhang return success(); 1025930c74f1SLei Zhang } 1026930c74f1SLei Zhang }; 1027930c74f1SLei Zhang 10281775b98dSFinlay static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, 10291775b98dSFinlay StringRef name, 10301775b98dSFinlay ArrayRef<Type> paramTypes, 10316ade03d7SLukas Sommer Type resultType, 10326ade03d7SLukas Sommer bool convergent = true) { 10331775b98dSFinlay auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>( 10341775b98dSFinlay SymbolTable::lookupSymbolIn(symbolTable, name)); 10351775b98dSFinlay if (func) 10361775b98dSFinlay return func; 10371775b98dSFinlay 10381775b98dSFinlay OpBuilder b(symbolTable->getRegion(0)); 10391775b98dSFinlay func = b.create<LLVM::LLVMFuncOp>( 10401775b98dSFinlay symbolTable->getLoc(), name, 10411775b98dSFinlay LLVM::LLVMFunctionType::get(resultType, paramTypes)); 10421775b98dSFinlay func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); 10436ade03d7SLukas Sommer func.setConvergent(convergent); 10441775b98dSFinlay func.setNoUnwind(true); 10451775b98dSFinlay func.setWillReturn(true); 10461775b98dSFinlay return func; 10471775b98dSFinlay } 10481775b98dSFinlay 10491775b98dSFinlay static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder, 10501775b98dSFinlay LLVM::LLVMFuncOp func, 10511775b98dSFinlay ValueRange args) { 10521775b98dSFinlay auto call = builder.create<LLVM::CallOp>(loc, func, args); 10531775b98dSFinlay call.setCConv(func.getCConv()); 10541775b98dSFinlay call.setConvergentAttr(func.getConvergentAttr()); 10551775b98dSFinlay call.setNoUnwindAttr(func.getNoUnwindAttr()); 10561775b98dSFinlay call.setWillReturnAttr(func.getWillReturnAttr()); 10571775b98dSFinlay return call; 10581775b98dSFinlay } 10591775b98dSFinlay 106005fcdd55SVictor Perez template <typename BarrierOpTy> 106105fcdd55SVictor Perez class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> { 10621775b98dSFinlay public: 106305fcdd55SVictor Perez using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor; 106405fcdd55SVictor Perez 106505fcdd55SVictor Perez using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion; 106605fcdd55SVictor Perez 106705fcdd55SVictor Perez static constexpr StringRef getFuncName(); 10681775b98dSFinlay 10691775b98dSFinlay LogicalResult 107005fcdd55SVictor Perez matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor, 10711775b98dSFinlay ConversionPatternRewriter &rewriter) const override { 107205fcdd55SVictor Perez constexpr StringRef funcName = getFuncName(); 10731775b98dSFinlay Operation *symbolTable = 107405fcdd55SVictor Perez controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>(); 10751775b98dSFinlay 10761775b98dSFinlay Type i32 = rewriter.getI32Type(); 10771775b98dSFinlay 10781775b98dSFinlay Type voidTy = rewriter.getType<LLVM::LLVMVoidType>(); 10791775b98dSFinlay LLVM::LLVMFuncOp func = 10801775b98dSFinlay lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy); 10811775b98dSFinlay 10821775b98dSFinlay Location loc = controlBarrierOp->getLoc(); 10831775b98dSFinlay Value execution = rewriter.create<LLVM::ConstantOp>( 10841775b98dSFinlay loc, i32, static_cast<int32_t>(adaptor.getExecutionScope())); 10851775b98dSFinlay Value memory = rewriter.create<LLVM::ConstantOp>( 10861775b98dSFinlay loc, i32, static_cast<int32_t>(adaptor.getMemoryScope())); 10871775b98dSFinlay Value semantics = rewriter.create<LLVM::ConstantOp>( 10881775b98dSFinlay loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics())); 10891775b98dSFinlay 10901775b98dSFinlay auto call = createSPIRVBuiltinCall(loc, rewriter, func, 10911775b98dSFinlay {execution, memory, semantics}); 10921775b98dSFinlay 10931775b98dSFinlay rewriter.replaceOp(controlBarrierOp, call); 10941775b98dSFinlay return success(); 10951775b98dSFinlay } 10961775b98dSFinlay }; 10971775b98dSFinlay 10986ade03d7SLukas Sommer namespace { 10996ade03d7SLukas Sommer 11006ade03d7SLukas Sommer StringRef getTypeMangling(Type type, bool isSigned) { 11016ade03d7SLukas Sommer return llvm::TypeSwitch<Type, StringRef>(type) 11026ade03d7SLukas Sommer .Case<Float16Type>([](auto) { return "Dh"; }) 11036ade03d7SLukas Sommer .Case<Float32Type>([](auto) { return "f"; }) 11046ade03d7SLukas Sommer .Case<Float64Type>([](auto) { return "d"; }) 11056ade03d7SLukas Sommer .Case<IntegerType>([isSigned](IntegerType intTy) { 11066ade03d7SLukas Sommer switch (intTy.getWidth()) { 11076ade03d7SLukas Sommer case 1: 11086ade03d7SLukas Sommer return "b"; 11096ade03d7SLukas Sommer case 8: 11106ade03d7SLukas Sommer return (isSigned) ? "a" : "c"; 11116ade03d7SLukas Sommer case 16: 11126ade03d7SLukas Sommer return (isSigned) ? "s" : "t"; 11136ade03d7SLukas Sommer case 32: 11146ade03d7SLukas Sommer return (isSigned) ? "i" : "j"; 11156ade03d7SLukas Sommer case 64: 11166ade03d7SLukas Sommer return (isSigned) ? "l" : "m"; 11176ade03d7SLukas Sommer default: 11186ade03d7SLukas Sommer llvm_unreachable("Unsupported integer width"); 11196ade03d7SLukas Sommer } 11206ade03d7SLukas Sommer }) 11216ade03d7SLukas Sommer .Default([](auto) { 11226ade03d7SLukas Sommer llvm_unreachable("No mangling defined"); 11236ade03d7SLukas Sommer return ""; 11246ade03d7SLukas Sommer }); 11256ade03d7SLukas Sommer } 11266ade03d7SLukas Sommer 11276ade03d7SLukas Sommer template <typename ReduceOp> 11286ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName(); 11296ade03d7SLukas Sommer 11306ade03d7SLukas Sommer template <> 11316ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() { 11326ade03d7SLukas Sommer return "_Z17__spirv_GroupIAddii"; 11336ade03d7SLukas Sommer } 11346ade03d7SLukas Sommer template <> 11356ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() { 11366ade03d7SLukas Sommer return "_Z17__spirv_GroupFAddii"; 11376ade03d7SLukas Sommer } 11386ade03d7SLukas Sommer template <> 11396ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() { 11406ade03d7SLukas Sommer return "_Z17__spirv_GroupSMinii"; 11416ade03d7SLukas Sommer } 11426ade03d7SLukas Sommer template <> 11436ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() { 11446ade03d7SLukas Sommer return "_Z17__spirv_GroupUMinii"; 11456ade03d7SLukas Sommer } 11466ade03d7SLukas Sommer template <> 11476ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() { 11486ade03d7SLukas Sommer return "_Z17__spirv_GroupFMinii"; 11496ade03d7SLukas Sommer } 11506ade03d7SLukas Sommer template <> 11516ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() { 11526ade03d7SLukas Sommer return "_Z17__spirv_GroupSMaxii"; 11536ade03d7SLukas Sommer } 11546ade03d7SLukas Sommer template <> 11556ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() { 11566ade03d7SLukas Sommer return "_Z17__spirv_GroupUMaxii"; 11576ade03d7SLukas Sommer } 11586ade03d7SLukas Sommer template <> 11596ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() { 11606ade03d7SLukas Sommer return "_Z17__spirv_GroupFMaxii"; 11616ade03d7SLukas Sommer } 11626ade03d7SLukas Sommer template <> 11636ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() { 11646ade03d7SLukas Sommer return "_Z27__spirv_GroupNonUniformIAddii"; 11656ade03d7SLukas Sommer } 11666ade03d7SLukas Sommer template <> 11676ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() { 11686ade03d7SLukas Sommer return "_Z27__spirv_GroupNonUniformFAddii"; 11696ade03d7SLukas Sommer } 11706ade03d7SLukas Sommer template <> 11716ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() { 11726ade03d7SLukas Sommer return "_Z27__spirv_GroupNonUniformIMulii"; 11736ade03d7SLukas Sommer } 11746ade03d7SLukas Sommer template <> 11756ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() { 11766ade03d7SLukas Sommer return "_Z27__spirv_GroupNonUniformFMulii"; 11776ade03d7SLukas Sommer } 11786ade03d7SLukas Sommer template <> 11796ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() { 11806ade03d7SLukas Sommer return "_Z27__spirv_GroupNonUniformSMinii"; 11816ade03d7SLukas Sommer } 11826ade03d7SLukas Sommer template <> 11836ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() { 11846ade03d7SLukas Sommer return "_Z27__spirv_GroupNonUniformUMinii"; 11856ade03d7SLukas Sommer } 11866ade03d7SLukas Sommer template <> 11876ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() { 11886ade03d7SLukas Sommer return "_Z27__spirv_GroupNonUniformFMinii"; 11896ade03d7SLukas Sommer } 11906ade03d7SLukas Sommer template <> 11916ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() { 11926ade03d7SLukas Sommer return "_Z27__spirv_GroupNonUniformSMaxii"; 11936ade03d7SLukas Sommer } 11946ade03d7SLukas Sommer template <> 11956ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() { 11966ade03d7SLukas Sommer return "_Z27__spirv_GroupNonUniformUMaxii"; 11976ade03d7SLukas Sommer } 11986ade03d7SLukas Sommer template <> 11996ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() { 12006ade03d7SLukas Sommer return "_Z27__spirv_GroupNonUniformFMaxii"; 12016ade03d7SLukas Sommer } 12026ade03d7SLukas Sommer template <> 12036ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() { 12046ade03d7SLukas Sommer return "_Z33__spirv_GroupNonUniformBitwiseAndii"; 12056ade03d7SLukas Sommer } 12066ade03d7SLukas Sommer template <> 12076ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() { 12086ade03d7SLukas Sommer return "_Z32__spirv_GroupNonUniformBitwiseOrii"; 12096ade03d7SLukas Sommer } 12106ade03d7SLukas Sommer template <> 12116ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() { 12126ade03d7SLukas Sommer return "_Z33__spirv_GroupNonUniformBitwiseXorii"; 12136ade03d7SLukas Sommer } 12146ade03d7SLukas Sommer template <> 12156ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() { 12166ade03d7SLukas Sommer return "_Z33__spirv_GroupNonUniformLogicalAndii"; 12176ade03d7SLukas Sommer } 12186ade03d7SLukas Sommer template <> 12196ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() { 12206ade03d7SLukas Sommer return "_Z32__spirv_GroupNonUniformLogicalOrii"; 12216ade03d7SLukas Sommer } 12226ade03d7SLukas Sommer template <> 12236ade03d7SLukas Sommer constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() { 12246ade03d7SLukas Sommer return "_Z33__spirv_GroupNonUniformLogicalXorii"; 12256ade03d7SLukas Sommer } 12266ade03d7SLukas Sommer } // namespace 12276ade03d7SLukas Sommer 12286ade03d7SLukas Sommer template <typename ReduceOp, bool Signed = false, bool NonUniform = false> 12296ade03d7SLukas Sommer class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> { 12306ade03d7SLukas Sommer public: 12316ade03d7SLukas Sommer using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion; 12326ade03d7SLukas Sommer 12336ade03d7SLukas Sommer LogicalResult 12346ade03d7SLukas Sommer matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor, 12356ade03d7SLukas Sommer ConversionPatternRewriter &rewriter) const override { 12366ade03d7SLukas Sommer 12376ade03d7SLukas Sommer Type retTy = op.getResult().getType(); 12386ade03d7SLukas Sommer if (!retTy.isIntOrFloat()) { 12396ade03d7SLukas Sommer return failure(); 12406ade03d7SLukas Sommer } 12416ade03d7SLukas Sommer SmallString<36> funcName = getGroupFuncName<ReduceOp>(); 12426ade03d7SLukas Sommer funcName += getTypeMangling(retTy, false); 12436ade03d7SLukas Sommer 12446ade03d7SLukas Sommer Type i32Ty = rewriter.getI32Type(); 12456ade03d7SLukas Sommer SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy}; 12466ade03d7SLukas Sommer if constexpr (NonUniform) { 12476ade03d7SLukas Sommer if (adaptor.getClusterSize()) { 12486ade03d7SLukas Sommer funcName += "j"; 12496ade03d7SLukas Sommer paramTypes.push_back(i32Ty); 12506ade03d7SLukas Sommer } 12516ade03d7SLukas Sommer } 12526ade03d7SLukas Sommer 12536ade03d7SLukas Sommer Operation *symbolTable = 12546ade03d7SLukas Sommer op->template getParentWithTrait<OpTrait::SymbolTable>(); 12556ade03d7SLukas Sommer 1256*4adeb6cfSLukas Sommer LLVM::LLVMFuncOp func = 1257*4adeb6cfSLukas Sommer lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy); 12586ade03d7SLukas Sommer 12596ade03d7SLukas Sommer Location loc = op.getLoc(); 12606ade03d7SLukas Sommer Value scope = rewriter.create<LLVM::ConstantOp>( 12616ade03d7SLukas Sommer loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope())); 12626ade03d7SLukas Sommer Value groupOp = rewriter.create<LLVM::ConstantOp>( 12636ade03d7SLukas Sommer loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation())); 12646ade03d7SLukas Sommer SmallVector<Value> operands{scope, groupOp}; 12656ade03d7SLukas Sommer operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); 12666ade03d7SLukas Sommer 12676ade03d7SLukas Sommer auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands); 12686ade03d7SLukas Sommer rewriter.replaceOp(op, call); 12696ade03d7SLukas Sommer return success(); 12706ade03d7SLukas Sommer } 12716ade03d7SLukas Sommer }; 12726ade03d7SLukas Sommer 127305fcdd55SVictor Perez template <> 127405fcdd55SVictor Perez constexpr StringRef 127505fcdd55SVictor Perez ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() { 127605fcdd55SVictor Perez return "_Z22__spirv_ControlBarrieriii"; 127705fcdd55SVictor Perez } 127805fcdd55SVictor Perez 127905fcdd55SVictor Perez template <> 128005fcdd55SVictor Perez constexpr StringRef 128105fcdd55SVictor Perez ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() { 128205fcdd55SVictor Perez return "_Z33__spirv_ControlBarrierArriveINTELiii"; 128305fcdd55SVictor Perez } 128405fcdd55SVictor Perez 128505fcdd55SVictor Perez template <> 128605fcdd55SVictor Perez constexpr StringRef 128705fcdd55SVictor Perez ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() { 128805fcdd55SVictor Perez return "_Z31__spirv_ControlBarrierWaitINTELiii"; 128905fcdd55SVictor Perez } 129005fcdd55SVictor Perez 12915ab6ef75SJakub Kuderski /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection 12925ab6ef75SJakub Kuderski /// should be reachable for conversion to succeed. The structure of the loop in 12935ab6ef75SJakub Kuderski /// LLVM dialect will be the following: 1294930c74f1SLei Zhang /// 1295930c74f1SLei Zhang /// +------------------------------------+ 12965ab6ef75SJakub Kuderski /// | <code before spirv.mlir.loop> | 1297930c74f1SLei Zhang /// | llvm.br ^header | 1298930c74f1SLei Zhang /// +------------------------------------+ 1299930c74f1SLei Zhang /// | 1300930c74f1SLei Zhang /// +----------------+ | 1301930c74f1SLei Zhang /// | | | 1302930c74f1SLei Zhang /// | V V 1303930c74f1SLei Zhang /// | +------------------------------------+ 1304930c74f1SLei Zhang /// | | ^header: | 1305930c74f1SLei Zhang /// | | <header code> | 1306930c74f1SLei Zhang /// | | llvm.cond_br %cond, ^body, ^exit | 1307930c74f1SLei Zhang /// | +------------------------------------+ 1308930c74f1SLei Zhang /// | | 1309930c74f1SLei Zhang /// | |----------------------+ 1310930c74f1SLei Zhang /// | | | 1311930c74f1SLei Zhang /// | V | 1312930c74f1SLei Zhang /// | +------------------------------------+ | 1313930c74f1SLei Zhang /// | | ^body: | | 1314930c74f1SLei Zhang /// | | <body code> | | 1315930c74f1SLei Zhang /// | | llvm.br ^continue | | 1316930c74f1SLei Zhang /// | +------------------------------------+ | 1317930c74f1SLei Zhang /// | | | 1318930c74f1SLei Zhang /// | V | 1319930c74f1SLei Zhang /// | +------------------------------------+ | 1320930c74f1SLei Zhang /// | | ^continue: | | 1321930c74f1SLei Zhang /// | | <continue code> | | 1322930c74f1SLei Zhang /// | | llvm.br ^header | | 1323930c74f1SLei Zhang /// | +------------------------------------+ | 1324930c74f1SLei Zhang /// | | | 1325930c74f1SLei Zhang /// +---------------+ +----------------------+ 1326930c74f1SLei Zhang /// | 1327930c74f1SLei Zhang /// V 1328930c74f1SLei Zhang /// +------------------------------------+ 1329930c74f1SLei Zhang /// | ^exit: | 1330930c74f1SLei Zhang /// | llvm.br ^remaining | 1331930c74f1SLei Zhang /// +------------------------------------+ 1332930c74f1SLei Zhang /// | 1333930c74f1SLei Zhang /// V 1334930c74f1SLei Zhang /// +------------------------------------+ 1335930c74f1SLei Zhang /// | ^remaining: | 13365ab6ef75SJakub Kuderski /// | <code after spirv.mlir.loop> | 1337930c74f1SLei Zhang /// +------------------------------------+ 1338930c74f1SLei Zhang /// 1339930c74f1SLei Zhang class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> { 1340930c74f1SLei Zhang public: 1341930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion; 1342930c74f1SLei Zhang 1343930c74f1SLei Zhang LogicalResult 1344b54c724bSRiver Riddle matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor, 1345930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 1346930c74f1SLei Zhang // There is no support of loop control at the moment. 134790a1632dSJakub Kuderski if (loopOp.getLoopControl() != spirv::LoopControl::None) 1348930c74f1SLei Zhang return failure(); 1349930c74f1SLei Zhang 13505aa741d7SLongsheng Mou // `spirv.mlir.loop` with empty region is redundant and should be erased. 13515aa741d7SLongsheng Mou if (loopOp.getBody().empty()) { 13525aa741d7SLongsheng Mou rewriter.eraseOp(loopOp); 13535aa741d7SLongsheng Mou return success(); 13545aa741d7SLongsheng Mou } 13555aa741d7SLongsheng Mou 1356930c74f1SLei Zhang Location loc = loopOp.getLoc(); 1357930c74f1SLei Zhang 13585ab6ef75SJakub Kuderski // Split the current block after `spirv.mlir.loop`. The remaining ops will 13595ab6ef75SJakub Kuderski // be used in `endBlock`. 1360930c74f1SLei Zhang Block *currentBlock = rewriter.getBlock(); 1361930c74f1SLei Zhang auto position = Block::iterator(loopOp); 1362930c74f1SLei Zhang Block *endBlock = rewriter.splitBlock(currentBlock, position); 1363930c74f1SLei Zhang 1364930c74f1SLei Zhang // Remove entry block and create a branch in the current block going to the 1365930c74f1SLei Zhang // header block. 1366930c74f1SLei Zhang Block *entryBlock = loopOp.getEntryBlock(); 1367930c74f1SLei Zhang assert(entryBlock->getOperations().size() == 1); 1368930c74f1SLei Zhang auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front()); 1369930c74f1SLei Zhang if (!brOp) 1370930c74f1SLei Zhang return failure(); 1371930c74f1SLei Zhang Block *headerBlock = loopOp.getHeaderBlock(); 1372930c74f1SLei Zhang rewriter.setInsertionPointToEnd(currentBlock); 1373930c74f1SLei Zhang rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock); 1374930c74f1SLei Zhang rewriter.eraseBlock(entryBlock); 1375930c74f1SLei Zhang 1376930c74f1SLei Zhang // Branch from merge block to end block. 1377930c74f1SLei Zhang Block *mergeBlock = loopOp.getMergeBlock(); 1378930c74f1SLei Zhang Operation *terminator = mergeBlock->getTerminator(); 1379930c74f1SLei Zhang ValueRange terminatorOperands = terminator->getOperands(); 1380930c74f1SLei Zhang rewriter.setInsertionPointToEnd(mergeBlock); 1381930c74f1SLei Zhang rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock); 1382930c74f1SLei Zhang 138390a1632dSJakub Kuderski rewriter.inlineRegionBefore(loopOp.getBody(), endBlock); 1384930c74f1SLei Zhang rewriter.replaceOp(loopOp, endBlock->getArguments()); 1385930c74f1SLei Zhang return success(); 1386930c74f1SLei Zhang } 1387930c74f1SLei Zhang }; 1388930c74f1SLei Zhang 13895ab6ef75SJakub Kuderski /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header 13903fb384d5SKareemErgawy-TomTom /// block. All blocks within selection should be reachable for conversion to 13913fb384d5SKareemErgawy-TomTom /// succeed. 1392930c74f1SLei Zhang class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> { 1393930c74f1SLei Zhang public: 1394930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion; 1395930c74f1SLei Zhang 1396930c74f1SLei Zhang LogicalResult 1397b54c724bSRiver Riddle matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor, 1398930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 1399930c74f1SLei Zhang // There is no support for `Flatten` or `DontFlatten` selection control at 1400930c74f1SLei Zhang // the moment. This are just compiler hints and can be performed during the 1401930c74f1SLei Zhang // optimization passes. 140290a1632dSJakub Kuderski if (op.getSelectionControl() != spirv::SelectionControl::None) 1403930c74f1SLei Zhang return failure(); 1404930c74f1SLei Zhang 14055ab6ef75SJakub Kuderski // `spirv.mlir.selection` should have at least two blocks: one selection 14063fb384d5SKareemErgawy-TomTom // header block and one merge block. If no blocks are present, or control 14073fb384d5SKareemErgawy-TomTom // flow branches straight to merge block (two blocks are present), the op is 1408930c74f1SLei Zhang // redundant and it is erased. 140990a1632dSJakub Kuderski if (op.getBody().getBlocks().size() <= 2) { 1410930c74f1SLei Zhang rewriter.eraseOp(op); 1411930c74f1SLei Zhang return success(); 1412930c74f1SLei Zhang } 1413930c74f1SLei Zhang 1414930c74f1SLei Zhang Location loc = op.getLoc(); 1415930c74f1SLei Zhang 14165ab6ef75SJakub Kuderski // Split the current block after `spirv.mlir.selection`. The remaining ops 14173fb384d5SKareemErgawy-TomTom // will be used in `continueBlock`. 1418930c74f1SLei Zhang auto *currentBlock = rewriter.getInsertionBlock(); 1419930c74f1SLei Zhang rewriter.setInsertionPointAfter(op); 1420930c74f1SLei Zhang auto position = rewriter.getInsertionPoint(); 1421930c74f1SLei Zhang auto *continueBlock = rewriter.splitBlock(currentBlock, position); 1422930c74f1SLei Zhang 1423930c74f1SLei Zhang // Extract conditional branch information from the header block. By SPIR-V 14245ab6ef75SJakub Kuderski // dialect spec, it should contain `spirv.BranchConditional` or 14255ab6ef75SJakub Kuderski // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the 14265ab6ef75SJakub Kuderski // moment in the SPIR-V dialect. Remove this block when finished. 1427930c74f1SLei Zhang auto *headerBlock = op.getHeaderBlock(); 1428930c74f1SLei Zhang assert(headerBlock->getOperations().size() == 1); 1429930c74f1SLei Zhang auto condBrOp = dyn_cast<spirv::BranchConditionalOp>( 1430930c74f1SLei Zhang headerBlock->getOperations().front()); 1431930c74f1SLei Zhang if (!condBrOp) 1432930c74f1SLei Zhang return failure(); 1433930c74f1SLei Zhang rewriter.eraseBlock(headerBlock); 1434930c74f1SLei Zhang 1435930c74f1SLei Zhang // Branch from merge block to continue block. 1436930c74f1SLei Zhang auto *mergeBlock = op.getMergeBlock(); 1437930c74f1SLei Zhang Operation *terminator = mergeBlock->getTerminator(); 1438930c74f1SLei Zhang ValueRange terminatorOperands = terminator->getOperands(); 1439930c74f1SLei Zhang rewriter.setInsertionPointToEnd(mergeBlock); 1440930c74f1SLei Zhang rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock); 1441930c74f1SLei Zhang 1442930c74f1SLei Zhang // Link current block to `true` and `false` blocks within the selection. 1443930c74f1SLei Zhang Block *trueBlock = condBrOp.getTrueBlock(); 1444930c74f1SLei Zhang Block *falseBlock = condBrOp.getFalseBlock(); 1445930c74f1SLei Zhang rewriter.setInsertionPointToEnd(currentBlock); 144690a1632dSJakub Kuderski rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock, 1447162f7572SMahesh Ravishankar condBrOp.getTrueTargetOperands(), 1448162f7572SMahesh Ravishankar falseBlock, 144990a1632dSJakub Kuderski condBrOp.getFalseTargetOperands()); 1450930c74f1SLei Zhang 145190a1632dSJakub Kuderski rewriter.inlineRegionBefore(op.getBody(), continueBlock); 1452930c74f1SLei Zhang rewriter.replaceOp(op, continueBlock->getArguments()); 1453930c74f1SLei Zhang return success(); 1454930c74f1SLei Zhang } 1455930c74f1SLei Zhang }; 1456930c74f1SLei Zhang 1457930c74f1SLei Zhang /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect 1458930c74f1SLei Zhang /// puts a restriction on `Shift` and `Base` to have the same bit width, 1459930c74f1SLei Zhang /// `Shift` is zero or sign extended to match this specification. Cases when 1460930c74f1SLei Zhang /// `Shift` bit width > `Base` bit width are considered to be illegal. 1461930c74f1SLei Zhang template <typename SPIRVOp, typename LLVMOp> 1462930c74f1SLei Zhang class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> { 1463930c74f1SLei Zhang public: 1464930c74f1SLei Zhang using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 1465930c74f1SLei Zhang 1466930c74f1SLei Zhang LogicalResult 1467c19436eeSKohei Yamaguchi matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 1468930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 1469930c74f1SLei Zhang 1470206fad0eSMatthias Springer auto dstType = this->getTypeConverter()->convertType(op.getType()); 1471930c74f1SLei Zhang if (!dstType) 1472c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 1473930c74f1SLei Zhang 1474c19436eeSKohei Yamaguchi Type op1Type = op.getOperand1().getType(); 1475c19436eeSKohei Yamaguchi Type op2Type = op.getOperand2().getType(); 1476930c74f1SLei Zhang 1477930c74f1SLei Zhang if (op1Type == op2Type) { 1478c19436eeSKohei Yamaguchi rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType, 1479b54c724bSRiver Riddle adaptor.getOperands()); 1480930c74f1SLei Zhang return success(); 1481930c74f1SLei Zhang } 1482930c74f1SLei Zhang 148300b12a94SThéo Degioanni std::optional<uint64_t> dstTypeWidth = 148400b12a94SThéo Degioanni getIntegerOrVectorElementWidth(dstType); 148500b12a94SThéo Degioanni std::optional<uint64_t> op2TypeWidth = 148600b12a94SThéo Degioanni getIntegerOrVectorElementWidth(op2Type); 148700b12a94SThéo Degioanni 148800b12a94SThéo Degioanni if (!dstTypeWidth || !op2TypeWidth) 148900b12a94SThéo Degioanni return failure(); 149000b12a94SThéo Degioanni 1491c19436eeSKohei Yamaguchi Location loc = op.getLoc(); 1492930c74f1SLei Zhang Value extended; 149300b12a94SThéo Degioanni if (op2TypeWidth < dstTypeWidth) { 1494930c74f1SLei Zhang if (isUnsignedIntegerOrVector(op2Type)) { 149500b12a94SThéo Degioanni extended = rewriter.template create<LLVM::ZExtOp>( 149600b12a94SThéo Degioanni loc, dstType, adaptor.getOperand2()); 1497930c74f1SLei Zhang } else { 149800b12a94SThéo Degioanni extended = rewriter.template create<LLVM::SExtOp>( 149900b12a94SThéo Degioanni loc, dstType, adaptor.getOperand2()); 1500930c74f1SLei Zhang } 150100b12a94SThéo Degioanni } else if (op2TypeWidth == dstTypeWidth) { 150200b12a94SThéo Degioanni extended = adaptor.getOperand2(); 150300b12a94SThéo Degioanni } else { 150400b12a94SThéo Degioanni return failure(); 150500b12a94SThéo Degioanni } 150600b12a94SThéo Degioanni 1507930c74f1SLei Zhang Value result = rewriter.template create<LLVMOp>( 150890a1632dSJakub Kuderski loc, dstType, adaptor.getOperand1(), extended); 1509c19436eeSKohei Yamaguchi rewriter.replaceOp(op, result); 1510930c74f1SLei Zhang return success(); 1511930c74f1SLei Zhang } 1512930c74f1SLei Zhang }; 1513930c74f1SLei Zhang 151452b630daSJakub Kuderski class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> { 1515930c74f1SLei Zhang public: 151652b630daSJakub Kuderski using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion; 1517930c74f1SLei Zhang 1518930c74f1SLei Zhang LogicalResult 151952b630daSJakub Kuderski matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor, 1520930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 1521206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(tanOp.getType()); 1522930c74f1SLei Zhang if (!dstType) 1523c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(tanOp, "type conversion failed"); 1524930c74f1SLei Zhang 1525930c74f1SLei Zhang Location loc = tanOp.getLoc(); 152690a1632dSJakub Kuderski Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand()); 152790a1632dSJakub Kuderski Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand()); 1528930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); 1529930c74f1SLei Zhang return success(); 1530930c74f1SLei Zhang } 1531930c74f1SLei Zhang }; 1532930c74f1SLei Zhang 15335ab6ef75SJakub Kuderski /// Convert `spirv.Tanh` to 1534930c74f1SLei Zhang /// 1535930c74f1SLei Zhang /// exp(2x) - 1 1536930c74f1SLei Zhang /// ----------- 1537930c74f1SLei Zhang /// exp(2x) + 1 1538930c74f1SLei Zhang /// 153952b630daSJakub Kuderski class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> { 1540930c74f1SLei Zhang public: 154152b630daSJakub Kuderski using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion; 1542930c74f1SLei Zhang 1543930c74f1SLei Zhang LogicalResult 154452b630daSJakub Kuderski matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor, 1545930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 1546930c74f1SLei Zhang auto srcType = tanhOp.getType(); 1547206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(srcType); 1548930c74f1SLei Zhang if (!dstType) 1549c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(tanhOp, "type conversion failed"); 1550930c74f1SLei Zhang 1551930c74f1SLei Zhang Location loc = tanhOp.getLoc(); 1552930c74f1SLei Zhang Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); 1553930c74f1SLei Zhang Value multiplied = 155490a1632dSJakub Kuderski rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand()); 1555930c74f1SLei Zhang Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); 1556930c74f1SLei Zhang Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); 1557930c74f1SLei Zhang Value numerator = 1558930c74f1SLei Zhang rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one); 1559930c74f1SLei Zhang Value denominator = 1560930c74f1SLei Zhang rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one); 1561930c74f1SLei Zhang rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, 1562930c74f1SLei Zhang denominator); 1563930c74f1SLei Zhang return success(); 1564930c74f1SLei Zhang } 1565930c74f1SLei Zhang }; 1566930c74f1SLei Zhang 1567930c74f1SLei Zhang class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> { 1568930c74f1SLei Zhang public: 1569930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion; 1570930c74f1SLei Zhang 1571930c74f1SLei Zhang LogicalResult 1572b54c724bSRiver Riddle matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor, 1573930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 1574930c74f1SLei Zhang auto srcType = varOp.getType(); 1575930c74f1SLei Zhang // Initialization is supported for scalars and vectors only. 15765550c821STres Popp auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType(); 157790a1632dSJakub Kuderski auto init = varOp.getInitializer(); 15785550c821STres Popp if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo)) 1579930c74f1SLei Zhang return failure(); 1580930c74f1SLei Zhang 1581206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(srcType); 1582930c74f1SLei Zhang if (!dstType) 1583c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(varOp, "type conversion failed"); 1584930c74f1SLei Zhang 1585930c74f1SLei Zhang Location loc = varOp.getLoc(); 1586930c74f1SLei Zhang Value size = createI32ConstantOf(loc, rewriter, 1); 1587930c74f1SLei Zhang if (!init) { 1588206fad0eSMatthias Springer auto elementType = getTypeConverter()->convertType(pointerTo); 1589c19436eeSKohei Yamaguchi if (!elementType) 1590c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(varOp, "type conversion failed"); 1591c19436eeSKohei Yamaguchi rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType, 1592c19436eeSKohei Yamaguchi size); 1593930c74f1SLei Zhang return success(); 1594930c74f1SLei Zhang } 1595206fad0eSMatthias Springer auto elementType = getTypeConverter()->convertType(pointerTo); 1596c19436eeSKohei Yamaguchi if (!elementType) 1597c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(varOp, "type conversion failed"); 1598c19436eeSKohei Yamaguchi Value allocated = 1599c19436eeSKohei Yamaguchi rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size); 160090a1632dSJakub Kuderski rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated); 1601930c74f1SLei Zhang rewriter.replaceOp(varOp, allocated); 1602930c74f1SLei Zhang return success(); 1603930c74f1SLei Zhang } 1604930c74f1SLei Zhang }; 1605930c74f1SLei Zhang 1606930c74f1SLei Zhang //===----------------------------------------------------------------------===// 160719b1e27fSMarkus Böck // BitcastOp conversion 160819b1e27fSMarkus Böck //===----------------------------------------------------------------------===// 160919b1e27fSMarkus Böck 161019b1e27fSMarkus Böck class BitcastConversionPattern 161119b1e27fSMarkus Böck : public SPIRVToLLVMConversion<spirv::BitcastOp> { 161219b1e27fSMarkus Böck public: 161319b1e27fSMarkus Böck using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion; 161419b1e27fSMarkus Böck 161519b1e27fSMarkus Böck LogicalResult 161619b1e27fSMarkus Böck matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor, 161719b1e27fSMarkus Böck ConversionPatternRewriter &rewriter) const override { 1618206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); 161919b1e27fSMarkus Böck if (!dstType) 1620c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed"); 162119b1e27fSMarkus Böck 162241f3b83fSChristian Ulmann // LLVM's opaque pointers do not require bitcasts. 162341f3b83fSChristian Ulmann if (isa<LLVM::LLVMPointerType>(dstType)) { 162419b1e27fSMarkus Böck rewriter.replaceOp(bitcastOp, adaptor.getOperand()); 162519b1e27fSMarkus Böck return success(); 162619b1e27fSMarkus Böck } 162719b1e27fSMarkus Böck 162819b1e27fSMarkus Böck rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( 162919b1e27fSMarkus Böck bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs()); 163019b1e27fSMarkus Böck return success(); 163119b1e27fSMarkus Böck } 163219b1e27fSMarkus Böck }; 163319b1e27fSMarkus Böck 163419b1e27fSMarkus Böck //===----------------------------------------------------------------------===// 1635930c74f1SLei Zhang // FuncOp conversion 1636930c74f1SLei Zhang //===----------------------------------------------------------------------===// 1637930c74f1SLei Zhang 1638930c74f1SLei Zhang class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> { 1639930c74f1SLei Zhang public: 1640930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion; 1641930c74f1SLei Zhang 1642930c74f1SLei Zhang LogicalResult 1643b54c724bSRiver Riddle matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, 1644930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 1645930c74f1SLei Zhang 1646930c74f1SLei Zhang // Convert function signature. At the moment LLVMType converter is enough 1647930c74f1SLei Zhang // for currently supported types. 16484a3460a7SRiver Riddle auto funcType = funcOp.getFunctionType(); 1649930c74f1SLei Zhang TypeConverter::SignatureConversion signatureConverter( 1650930c74f1SLei Zhang funcType.getNumInputs()); 165198723e65SMatthias Springer auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter()) 165298723e65SMatthias Springer ->convertFunctionSignature( 165398723e65SMatthias Springer funcType, /*isVariadic=*/false, 165498723e65SMatthias Springer /*useBarePtrCallConv=*/false, signatureConverter); 1655930c74f1SLei Zhang if (!llvmType) 1656930c74f1SLei Zhang return failure(); 1657930c74f1SLei Zhang 1658930c74f1SLei Zhang // Create a new `LLVMFuncOp` 1659930c74f1SLei Zhang Location loc = funcOp.getLoc(); 1660930c74f1SLei Zhang StringRef name = funcOp.getName(); 1661930c74f1SLei Zhang auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType); 1662930c74f1SLei Zhang 1663930c74f1SLei Zhang // Convert SPIR-V Function Control to equivalent LLVM function attribute 1664930c74f1SLei Zhang MLIRContext *context = funcOp.getContext(); 166590a1632dSJakub Kuderski switch (funcOp.getFunctionControl()) { 1666c012e487SJohannes de Fine Licht case spirv::FunctionControl::Inline: 1667c012e487SJohannes de Fine Licht newFuncOp.setAlwaysInline(true); 1668c012e487SJohannes de Fine Licht break; 1669c012e487SJohannes de Fine Licht case spirv::FunctionControl::DontInline: 1670c012e487SJohannes de Fine Licht newFuncOp.setNoInline(true); 1671c012e487SJohannes de Fine Licht break; 1672c012e487SJohannes de Fine Licht 1673930c74f1SLei Zhang #define DISPATCH(functionControl, llvmAttr) \ 1674930c74f1SLei Zhang case functionControl: \ 1675c2c83e97STres Popp newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \ 1676930c74f1SLei Zhang break; 1677930c74f1SLei Zhang 1678930c74f1SLei Zhang DISPATCH(spirv::FunctionControl::Pure, 1679c2c83e97STres Popp StringAttr::get(context, "readonly")); 1680930c74f1SLei Zhang DISPATCH(spirv::FunctionControl::Const, 1681c2c83e97STres Popp StringAttr::get(context, "readnone")); 1682930c74f1SLei Zhang 1683930c74f1SLei Zhang #undef DISPATCH 1684930c74f1SLei Zhang 1685930c74f1SLei Zhang // Default: if `spirv::FunctionControl::None`, then no attributes are 1686930c74f1SLei Zhang // needed. 1687930c74f1SLei Zhang default: 1688930c74f1SLei Zhang break; 1689930c74f1SLei Zhang } 1690930c74f1SLei Zhang 1691930c74f1SLei Zhang rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 1692930c74f1SLei Zhang newFuncOp.end()); 1693206fad0eSMatthias Springer if (failed(rewriter.convertRegionTypes( 1694206fad0eSMatthias Springer &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) { 1695930c74f1SLei Zhang return failure(); 1696930c74f1SLei Zhang } 1697930c74f1SLei Zhang rewriter.eraseOp(funcOp); 1698930c74f1SLei Zhang return success(); 1699930c74f1SLei Zhang } 1700930c74f1SLei Zhang }; 1701930c74f1SLei Zhang 1702930c74f1SLei Zhang //===----------------------------------------------------------------------===// 1703930c74f1SLei Zhang // ModuleOp conversion 1704930c74f1SLei Zhang //===----------------------------------------------------------------------===// 1705930c74f1SLei Zhang 1706930c74f1SLei Zhang class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> { 1707930c74f1SLei Zhang public: 1708930c74f1SLei Zhang using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion; 1709930c74f1SLei Zhang 1710930c74f1SLei Zhang LogicalResult 1711b54c724bSRiver Riddle matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor, 1712930c74f1SLei Zhang ConversionPatternRewriter &rewriter) const override { 1713930c74f1SLei Zhang 1714930c74f1SLei Zhang auto newModuleOp = 1715930c74f1SLei Zhang rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName()); 171656f60a1cSLei Zhang rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody()); 1717930c74f1SLei Zhang 1718930c74f1SLei Zhang // Remove the terminator block that was automatically added by builder 1719930c74f1SLei Zhang rewriter.eraseBlock(&newModuleOp.getBodyRegion().back()); 1720930c74f1SLei Zhang rewriter.eraseOp(spvModuleOp); 1721930c74f1SLei Zhang return success(); 1722930c74f1SLei Zhang } 1723930c74f1SLei Zhang }; 1724930c74f1SLei Zhang 17253483fc5aSWeiwei Li //===----------------------------------------------------------------------===// 17263483fc5aSWeiwei Li // VectorShuffleOp conversion 17273483fc5aSWeiwei Li //===----------------------------------------------------------------------===// 17283483fc5aSWeiwei Li 17293483fc5aSWeiwei Li class VectorShufflePattern 17303483fc5aSWeiwei Li : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> { 17313483fc5aSWeiwei Li public: 17323483fc5aSWeiwei Li using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion; 17333483fc5aSWeiwei Li LogicalResult 17343483fc5aSWeiwei Li matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor, 17353483fc5aSWeiwei Li ConversionPatternRewriter &rewriter) const override { 17363483fc5aSWeiwei Li Location loc = op.getLoc(); 173790a1632dSJakub Kuderski auto components = adaptor.getComponents(); 173890a1632dSJakub Kuderski auto vector1 = adaptor.getVector1(); 173990a1632dSJakub Kuderski auto vector2 = adaptor.getVector2(); 17405550c821STres Popp int vector1Size = cast<VectorType>(vector1.getType()).getNumElements(); 17415550c821STres Popp int vector2Size = cast<VectorType>(vector2.getType()).getNumElements(); 17423483fc5aSWeiwei Li if (vector1Size == vector2Size) { 1743b2ccfb4dSJeff Niu rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>( 1744b2ccfb4dSJeff Niu op, vector1, vector2, 1745b2ccfb4dSJeff Niu LLVM::convertArrayToIndices<int32_t>(components)); 17463483fc5aSWeiwei Li return success(); 17473483fc5aSWeiwei Li } 17483483fc5aSWeiwei Li 1749206fad0eSMatthias Springer auto dstType = getTypeConverter()->convertType(op.getType()); 1750c19436eeSKohei Yamaguchi if (!dstType) 1751c19436eeSKohei Yamaguchi return rewriter.notifyMatchFailure(op, "type conversion failed"); 17525550c821STres Popp auto scalarType = cast<VectorType>(dstType).getElementType(); 17533483fc5aSWeiwei Li auto componentsArray = components.getValue(); 175402b6fb21SMehdi Amini auto *context = rewriter.getContext(); 17553483fc5aSWeiwei Li auto llvmI32Type = IntegerType::get(context, 32); 17563483fc5aSWeiwei Li Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType); 17573483fc5aSWeiwei Li for (unsigned i = 0; i < componentsArray.size(); i++) { 17585550c821STres Popp if (!isa<IntegerAttr>(componentsArray[i])) 1759a29fffc4SLei Zhang return op.emitError("unable to support non-constant component"); 17603483fc5aSWeiwei Li 17615550c821STres Popp int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt(); 17623483fc5aSWeiwei Li if (indexVal == -1) 17633483fc5aSWeiwei Li continue; 17643483fc5aSWeiwei Li 17653483fc5aSWeiwei Li int offsetVal = 0; 17663483fc5aSWeiwei Li Value baseVector = vector1; 17673483fc5aSWeiwei Li if (indexVal >= vector1Size) { 17683483fc5aSWeiwei Li offsetVal = vector1Size; 17693483fc5aSWeiwei Li baseVector = vector2; 17703483fc5aSWeiwei Li } 17713483fc5aSWeiwei Li 17723483fc5aSWeiwei Li Value dstIndex = rewriter.create<LLVM::ConstantOp>( 17733483fc5aSWeiwei Li loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i)); 17743483fc5aSWeiwei Li Value index = rewriter.create<LLVM::ConstantOp>( 17753483fc5aSWeiwei Li loc, llvmI32Type, 17763483fc5aSWeiwei Li rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal)); 17773483fc5aSWeiwei Li 17783483fc5aSWeiwei Li auto extractOp = rewriter.create<LLVM::ExtractElementOp>( 17793483fc5aSWeiwei Li loc, scalarType, baseVector, index); 17803483fc5aSWeiwei Li targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp, 17813483fc5aSWeiwei Li extractOp, dstIndex); 17823483fc5aSWeiwei Li } 17833483fc5aSWeiwei Li rewriter.replaceOp(op, targetOp); 17843483fc5aSWeiwei Li return success(); 17853483fc5aSWeiwei Li } 17863483fc5aSWeiwei Li }; 1787930c74f1SLei Zhang } // namespace 1788930c74f1SLei Zhang 1789930c74f1SLei Zhang //===----------------------------------------------------------------------===// 1790930c74f1SLei Zhang // Pattern population 1791930c74f1SLei Zhang //===----------------------------------------------------------------------===// 1792930c74f1SLei Zhang 179312ce9fd1SVictor Perez void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, 179412ce9fd1SVictor Perez spirv::ClientAPI clientAPI) { 1795930c74f1SLei Zhang typeConverter.addConversion([&](spirv::ArrayType type) { 1796930c74f1SLei Zhang return convertArrayType(type, typeConverter); 1797930c74f1SLei Zhang }); 179812ce9fd1SVictor Perez typeConverter.addConversion([&, clientAPI](spirv::PointerType type) { 179912ce9fd1SVictor Perez return convertPointerType(type, typeConverter, clientAPI); 1800930c74f1SLei Zhang }); 1801930c74f1SLei Zhang typeConverter.addConversion([&](spirv::RuntimeArrayType type) { 1802930c74f1SLei Zhang return convertRuntimeArrayType(type, typeConverter); 1803930c74f1SLei Zhang }); 1804930c74f1SLei Zhang typeConverter.addConversion([&](spirv::StructType type) { 1805930c74f1SLei Zhang return convertStructType(type, typeConverter); 1806930c74f1SLei Zhang }); 1807930c74f1SLei Zhang } 1808930c74f1SLei Zhang 1809930c74f1SLei Zhang void mlir::populateSPIRVToLLVMConversionPatterns( 1810206fad0eSMatthias Springer const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, 181112ce9fd1SVictor Perez spirv::ClientAPI clientAPI) { 1812dc4e913bSChris Lattner patterns.add< 1813930c74f1SLei Zhang // Arithmetic ops 1814930c74f1SLei Zhang DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>, 1815930c74f1SLei Zhang DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>, 1816930c74f1SLei Zhang DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>, 1817930c74f1SLei Zhang DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>, 1818930c74f1SLei Zhang DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>, 1819930c74f1SLei Zhang DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>, 1820930c74f1SLei Zhang DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>, 1821930c74f1SLei Zhang DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>, 1822930c74f1SLei Zhang DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>, 1823930c74f1SLei Zhang DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>, 1824930c74f1SLei Zhang DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>, 1825930c74f1SLei Zhang DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>, 1826930c74f1SLei Zhang DirectConversionPattern<spirv::UModOp, LLVM::URemOp>, 1827930c74f1SLei Zhang 1828930c74f1SLei Zhang // Bitwise ops 1829930c74f1SLei Zhang BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern, 1830930c74f1SLei Zhang DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>, 1831930c74f1SLei Zhang DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>, 1832930c74f1SLei Zhang DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>, 1833930c74f1SLei Zhang DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>, 1834930c74f1SLei Zhang DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>, 1835930c74f1SLei Zhang NotPattern<spirv::NotOp>, 1836930c74f1SLei Zhang 1837930c74f1SLei Zhang // Cast ops 183819b1e27fSMarkus Böck BitcastConversionPattern, 1839930c74f1SLei Zhang DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>, 1840930c74f1SLei Zhang DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>, 1841930c74f1SLei Zhang DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>, 1842930c74f1SLei Zhang DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>, 1843930c74f1SLei Zhang IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>, 1844930c74f1SLei Zhang IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>, 1845930c74f1SLei Zhang IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>, 1846930c74f1SLei Zhang 1847930c74f1SLei Zhang // Comparison ops 1848930c74f1SLei Zhang IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>, 1849930c74f1SLei Zhang IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>, 1850930c74f1SLei Zhang FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>, 1851930c74f1SLei Zhang FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>, 1852930c74f1SLei Zhang FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>, 1853930c74f1SLei Zhang FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>, 1854930c74f1SLei Zhang FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>, 1855930c74f1SLei Zhang FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>, 1856930c74f1SLei Zhang FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>, 1857930c74f1SLei Zhang FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>, 1858930c74f1SLei Zhang FComparePattern<spirv::FUnordGreaterThanEqualOp, 1859930c74f1SLei Zhang LLVM::FCmpPredicate::uge>, 1860930c74f1SLei Zhang FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>, 1861930c74f1SLei Zhang FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>, 1862930c74f1SLei Zhang FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>, 1863930c74f1SLei Zhang IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>, 1864930c74f1SLei Zhang IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>, 1865930c74f1SLei Zhang IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>, 1866930c74f1SLei Zhang IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>, 1867930c74f1SLei Zhang IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>, 1868930c74f1SLei Zhang IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>, 1869930c74f1SLei Zhang IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>, 1870930c74f1SLei Zhang IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>, 1871930c74f1SLei Zhang 1872930c74f1SLei Zhang // Constant op 1873930c74f1SLei Zhang ConstantScalarAndVectorPattern, 1874930c74f1SLei Zhang 1875930c74f1SLei Zhang // Control Flow ops 1876930c74f1SLei Zhang BranchConversionPattern, BranchConditionalConversionPattern, 1877930c74f1SLei Zhang FunctionCallPattern, LoopPattern, SelectionPattern, 1878930c74f1SLei Zhang ErasePattern<spirv::MergeOp>, 1879930c74f1SLei Zhang 1880930c74f1SLei Zhang // Entry points and execution mode are handled separately. 1881930c74f1SLei Zhang ErasePattern<spirv::EntryPointOp>, ExecutionModePattern, 1882930c74f1SLei Zhang 1883930c74f1SLei Zhang // GLSL extended instruction set ops 188452b630daSJakub Kuderski DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>, 188552b630daSJakub Kuderski DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>, 188652b630daSJakub Kuderski DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>, 188752b630daSJakub Kuderski DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>, 188852b630daSJakub Kuderski DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>, 188952b630daSJakub Kuderski DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>, 189052b630daSJakub Kuderski DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>, 189152b630daSJakub Kuderski DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>, 189252b630daSJakub Kuderski DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>, 189352b630daSJakub Kuderski DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>, 189452b630daSJakub Kuderski DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>, 189552b630daSJakub Kuderski DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>, 1896930c74f1SLei Zhang InverseSqrtPattern, TanPattern, TanhPattern, 1897930c74f1SLei Zhang 1898930c74f1SLei Zhang // Logical ops 1899930c74f1SLei Zhang DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>, 1900930c74f1SLei Zhang DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>, 1901930c74f1SLei Zhang IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>, 1902930c74f1SLei Zhang IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>, 1903930c74f1SLei Zhang NotPattern<spirv::LogicalNotOp>, 1904930c74f1SLei Zhang 1905930c74f1SLei Zhang // Memory ops 190612ce9fd1SVictor Perez AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>, 190712ce9fd1SVictor Perez LoadStorePattern<spirv::StoreOp>, VariablePattern, 1908930c74f1SLei Zhang 1909930c74f1SLei Zhang // Miscellaneous ops 1910930c74f1SLei Zhang CompositeExtractPattern, CompositeInsertPattern, 1911930c74f1SLei Zhang DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>, 1912930c74f1SLei Zhang DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>, 19133483fc5aSWeiwei Li VectorShufflePattern, 1914930c74f1SLei Zhang 1915930c74f1SLei Zhang // Shift ops 1916930c74f1SLei Zhang ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>, 1917930c74f1SLei Zhang ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>, 1918930c74f1SLei Zhang ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>, 1919930c74f1SLei Zhang 1920930c74f1SLei Zhang // Return ops 19211775b98dSFinlay ReturnPattern, ReturnValuePattern, 19221775b98dSFinlay 19231775b98dSFinlay // Barrier ops 192405fcdd55SVictor Perez ControlBarrierPattern<spirv::ControlBarrierOp>, 192505fcdd55SVictor Perez ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>, 192605fcdd55SVictor Perez ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>, 19276ade03d7SLukas Sommer 19286ade03d7SLukas Sommer // Group reduction operations 19296ade03d7SLukas Sommer GroupReducePattern<spirv::GroupIAddOp>, 19306ade03d7SLukas Sommer GroupReducePattern<spirv::GroupFAddOp>, 19316ade03d7SLukas Sommer GroupReducePattern<spirv::GroupFMinOp>, 19326ade03d7SLukas Sommer GroupReducePattern<spirv::GroupUMinOp>, 19336ade03d7SLukas Sommer GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>, 19346ade03d7SLukas Sommer GroupReducePattern<spirv::GroupFMaxOp>, 19356ade03d7SLukas Sommer GroupReducePattern<spirv::GroupUMaxOp>, 19366ade03d7SLukas Sommer GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>, 19376ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false, 19386ade03d7SLukas Sommer /*NonUniform=*/true>, 19396ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false, 19406ade03d7SLukas Sommer /*NonUniform=*/true>, 19416ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false, 19426ade03d7SLukas Sommer /*NonUniform=*/true>, 19436ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false, 19446ade03d7SLukas Sommer /*NonUniform=*/true>, 19456ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true, 19466ade03d7SLukas Sommer /*NonUniform=*/true>, 19476ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false, 19486ade03d7SLukas Sommer /*NonUniform=*/true>, 19496ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false, 19506ade03d7SLukas Sommer /*NonUniform=*/true>, 19516ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true, 19526ade03d7SLukas Sommer /*NonUniform=*/true>, 19536ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false, 19546ade03d7SLukas Sommer /*NonUniform=*/true>, 19556ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false, 19566ade03d7SLukas Sommer /*NonUniform=*/true>, 19576ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false, 19586ade03d7SLukas Sommer /*NonUniform=*/true>, 19596ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false, 19606ade03d7SLukas Sommer /*NonUniform=*/true>, 19616ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false, 19626ade03d7SLukas Sommer /*NonUniform=*/true>, 19636ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false, 19646ade03d7SLukas Sommer /*NonUniform=*/true>, 19656ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false, 19666ade03d7SLukas Sommer /*NonUniform=*/true>, 19676ade03d7SLukas Sommer GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false, 19686ade03d7SLukas Sommer /*NonUniform=*/true>>(patterns.getContext(), 19696ade03d7SLukas Sommer typeConverter); 197012ce9fd1SVictor Perez 197112ce9fd1SVictor Perez patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(), 197212ce9fd1SVictor Perez typeConverter); 1973930c74f1SLei Zhang } 1974930c74f1SLei Zhang 1975930c74f1SLei Zhang void mlir::populateSPIRVToLLVMFunctionConversionPatterns( 1976206fad0eSMatthias Springer const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { 1977dc4e913bSChris Lattner patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter); 1978930c74f1SLei Zhang } 1979930c74f1SLei Zhang 1980930c74f1SLei Zhang void mlir::populateSPIRVToLLVMModuleConversionPatterns( 1981206fad0eSMatthias Springer const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { 198256f60a1cSLei Zhang patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter); 1983930c74f1SLei Zhang } 1984930c74f1SLei Zhang 1985930c74f1SLei Zhang //===----------------------------------------------------------------------===// 1986930c74f1SLei Zhang // Pre-conversion hooks 1987930c74f1SLei Zhang //===----------------------------------------------------------------------===// 1988930c74f1SLei Zhang 1989930c74f1SLei Zhang /// Hook for descriptor set and binding number encoding. 1990930c74f1SLei Zhang static constexpr StringRef kBinding = "binding"; 1991930c74f1SLei Zhang static constexpr StringRef kDescriptorSet = "descriptor_set"; 1992930c74f1SLei Zhang void mlir::encodeBindAttribute(ModuleOp module) { 1993930c74f1SLei Zhang auto spvModules = module.getOps<spirv::ModuleOp>(); 1994930c74f1SLei Zhang for (auto spvModule : spvModules) { 1995930c74f1SLei Zhang spvModule.walk([&](spirv::GlobalVariableOp op) { 1996930c74f1SLei Zhang IntegerAttr descriptorSet = 1997930c74f1SLei Zhang op->getAttrOfType<IntegerAttr>(kDescriptorSet); 1998930c74f1SLei Zhang IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding); 1999930c74f1SLei Zhang // For every global variable in the module, get the ones with descriptor 2000930c74f1SLei Zhang // set and binding numbers. 2001930c74f1SLei Zhang if (descriptorSet && binding) { 2002930c74f1SLei Zhang // Encode these numbers into the variable's symbolic name. If the 2003930c74f1SLei Zhang // SPIR-V module has a name, add it at the beginning. 2004c27d8152SKazu Hirata auto moduleAndName = 2005c27d8152SKazu Hirata spvModule.getName().has_value() 200690a1632dSJakub Kuderski ? spvModule.getName()->str() + "_" + op.getSymName().str() 200790a1632dSJakub Kuderski : op.getSymName().str(); 2008930c74f1SLei Zhang std::string name = 2009930c74f1SLei Zhang llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName, 2010930c74f1SLei Zhang std::to_string(descriptorSet.getInt()), 2011930c74f1SLei Zhang std::to_string(binding.getInt())); 201241d4aa7dSChris Lattner auto nameAttr = StringAttr::get(op->getContext(), name); 2013930c74f1SLei Zhang 2014930c74f1SLei Zhang // Replace all symbol uses and set the new symbol name. Finally, remove 2015930c74f1SLei Zhang // descriptor set and binding attributes. 201641d4aa7dSChris Lattner if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule))) 2017930c74f1SLei Zhang op.emitError("unable to replace all symbol uses for ") << name; 201841d4aa7dSChris Lattner SymbolTable::setSymbolName(op, nameAttr); 2019dffc487bSChristian Sigg op->removeAttr(kDescriptorSet); 2020dffc487bSChristian Sigg op->removeAttr(kBinding); 2021930c74f1SLei Zhang } 2022930c74f1SLei Zhang }); 2023930c74f1SLei Zhang } 2024930c74f1SLei Zhang } 2025