198d5d344SVictor Perez //===- GPUToLLVMSPV.cpp - Convert GPU operations to LLVM dialect ----------===// 298d5d344SVictor Perez // 398d5d344SVictor Perez // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 498d5d344SVictor Perez // See https://llvm.org/LICENSE.txt for license information. 598d5d344SVictor Perez // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 698d5d344SVictor Perez // 798d5d344SVictor Perez //===----------------------------------------------------------------------===// 898d5d344SVictor Perez 998d5d344SVictor Perez #include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h" 1098d5d344SVictor Perez 11d45de800SVictor Perez #include "../GPUCommon/GPUOpsLowering.h" 1275cb9edfSVictor Perez #include "mlir/Conversion/GPUCommon/AttrToSPIRVConverter.h" 1375cb9edfSVictor Perez #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" 1498d5d344SVictor Perez #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 1598d5d344SVictor Perez #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" 1698d5d344SVictor Perez #include "mlir/Conversion/LLVMCommon/Pattern.h" 1798d5d344SVictor Perez #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 18d45de800SVictor Perez #include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h" 1998d5d344SVictor Perez #include "mlir/Dialect/GPU/IR/GPUDialect.h" 2098d5d344SVictor Perez #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" 2198d5d344SVictor Perez #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 2298d5d344SVictor Perez #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 23d45de800SVictor Perez #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 2498d5d344SVictor Perez #include "mlir/IR/BuiltinTypes.h" 2598d5d344SVictor Perez #include "mlir/IR/Matchers.h" 2698d5d344SVictor Perez #include "mlir/IR/PatternMatch.h" 2798d5d344SVictor Perez #include "mlir/IR/SymbolTable.h" 2898d5d344SVictor Perez #include "mlir/Pass/Pass.h" 2998d5d344SVictor Perez #include "mlir/Support/LLVM.h" 3098d5d344SVictor Perez #include "mlir/Transforms/DialectConversion.h" 3198d5d344SVictor Perez 3298d5d344SVictor Perez #include "llvm/ADT/TypeSwitch.h" 3398d5d344SVictor Perez #include "llvm/Support/FormatVariadic.h" 3498d5d344SVictor Perez 35f8b7a653SPetr Kurapov #define DEBUG_TYPE "gpu-to-llvm-spv" 36f8b7a653SPetr Kurapov 3798d5d344SVictor Perez using namespace mlir; 3898d5d344SVictor Perez 3998d5d344SVictor Perez namespace mlir { 4098d5d344SVictor Perez #define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS 4198d5d344SVictor Perez #include "mlir/Conversion/Passes.h.inc" 4298d5d344SVictor Perez } // namespace mlir 4398d5d344SVictor Perez 4498d5d344SVictor Perez //===----------------------------------------------------------------------===// 4598d5d344SVictor Perez // Helper Functions 4698d5d344SVictor Perez //===----------------------------------------------------------------------===// 4798d5d344SVictor Perez 4898d5d344SVictor Perez static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, 4998d5d344SVictor Perez StringRef name, 5098d5d344SVictor Perez ArrayRef<Type> paramTypes, 515a53add8SFinlay Type resultType, bool isMemNone, 525a53add8SFinlay bool isConvergent) { 5398d5d344SVictor Perez auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>( 5498d5d344SVictor Perez SymbolTable::lookupSymbolIn(symbolTable, name)); 5598d5d344SVictor Perez if (!func) { 5698d5d344SVictor Perez OpBuilder b(symbolTable->getRegion(0)); 5798d5d344SVictor Perez func = b.create<LLVM::LLVMFuncOp>( 5898d5d344SVictor Perez symbolTable->getLoc(), name, 5998d5d344SVictor Perez LLVM::LLVMFunctionType::get(resultType, paramTypes)); 6098d5d344SVictor Perez func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); 615a53add8SFinlay func.setNoUnwind(true); 625a53add8SFinlay func.setWillReturn(true); 635a53add8SFinlay 645a53add8SFinlay if (isMemNone) { 655a53add8SFinlay // no externally observable effects 665a53add8SFinlay constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef; 675a53add8SFinlay auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>( 685a53add8SFinlay /*other=*/noModRef, 695a53add8SFinlay /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); 705a53add8SFinlay func.setMemoryEffectsAttr(memAttr); 715a53add8SFinlay } 725a53add8SFinlay 733670e7f8SFinlay func.setConvergent(isConvergent); 7498d5d344SVictor Perez } 7598d5d344SVictor Perez return func; 7698d5d344SVictor Perez } 7798d5d344SVictor Perez 7898d5d344SVictor Perez static LLVM::CallOp createSPIRVBuiltinCall(Location loc, 7998d5d344SVictor Perez ConversionPatternRewriter &rewriter, 8098d5d344SVictor Perez LLVM::LLVMFuncOp func, 8198d5d344SVictor Perez ValueRange args) { 8298d5d344SVictor Perez auto call = rewriter.create<LLVM::CallOp>(loc, func, args); 8398d5d344SVictor Perez call.setCConv(func.getCConv()); 845a53add8SFinlay call.setConvergentAttr(func.getConvergentAttr()); 855a53add8SFinlay call.setNoUnwindAttr(func.getNoUnwindAttr()); 865a53add8SFinlay call.setWillReturnAttr(func.getWillReturnAttr()); 875a53add8SFinlay call.setMemoryEffectsAttr(func.getMemoryEffectsAttr()); 8898d5d344SVictor Perez return call; 8998d5d344SVictor Perez } 9098d5d344SVictor Perez 9198d5d344SVictor Perez namespace { 9298d5d344SVictor Perez //===----------------------------------------------------------------------===// 9398d5d344SVictor Perez // Barriers 9498d5d344SVictor Perez //===----------------------------------------------------------------------===// 9598d5d344SVictor Perez 9698d5d344SVictor Perez /// Replace `gpu.barrier` with an `llvm.call` to `barrier` with 9798d5d344SVictor Perez /// `CLK_LOCAL_MEM_FENCE` argument, indicating work-group memory scope: 9898d5d344SVictor Perez /// ``` 9998d5d344SVictor Perez /// // gpu.barrier 10098d5d344SVictor Perez /// %c1 = llvm.mlir.constant(1: i32) : i32 10198d5d344SVictor Perez /// llvm.call spir_funccc @_Z7barrierj(%c1) : (i32) -> () 10298d5d344SVictor Perez /// ``` 10398d5d344SVictor Perez struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> { 10498d5d344SVictor Perez using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 10598d5d344SVictor Perez 10698d5d344SVictor Perez LogicalResult 10798d5d344SVictor Perez matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor, 10898d5d344SVictor Perez ConversionPatternRewriter &rewriter) const final { 10998d5d344SVictor Perez constexpr StringLiteral funcName = "_Z7barrierj"; 11098d5d344SVictor Perez 11198d5d344SVictor Perez Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>(); 11298d5d344SVictor Perez assert(moduleOp && "Expecting module"); 11398d5d344SVictor Perez Type flagTy = rewriter.getI32Type(); 11498d5d344SVictor Perez Type voidTy = rewriter.getType<LLVM::LLVMVoidType>(); 1155a53add8SFinlay LLVM::LLVMFuncOp func = 1165a53add8SFinlay lookupOrCreateSPIRVFn(moduleOp, funcName, flagTy, voidTy, 1175a53add8SFinlay /*isMemNone=*/false, /*isConvergent=*/true); 11898d5d344SVictor Perez 11998d5d344SVictor Perez // Value used by SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`. 12098d5d344SVictor Perez // See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`. 12198d5d344SVictor Perez constexpr int64_t localMemFenceFlag = 1; 12298d5d344SVictor Perez Location loc = op->getLoc(); 12398d5d344SVictor Perez Value flag = 12498d5d344SVictor Perez rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag); 12598d5d344SVictor Perez rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag)); 12698d5d344SVictor Perez return success(); 12798d5d344SVictor Perez } 12898d5d344SVictor Perez }; 12998d5d344SVictor Perez 13098d5d344SVictor Perez //===----------------------------------------------------------------------===// 13198d5d344SVictor Perez // SPIR-V Builtins 13298d5d344SVictor Perez //===----------------------------------------------------------------------===// 13398d5d344SVictor Perez 13498d5d344SVictor Perez /// Replace `gpu.*` with an `llvm.call` to the corresponding SPIR-V builtin with 13598d5d344SVictor Perez /// a constant argument for the `dimension` attribute. Return type will depend 13698d5d344SVictor Perez /// on index width option: 13798d5d344SVictor Perez /// ``` 13898d5d344SVictor Perez /// // %thread_id_y = gpu.thread_id y 13998d5d344SVictor Perez /// %c1 = llvm.mlir.constant(1: i32) : i32 14098d5d344SVictor Perez /// %0 = llvm.call spir_funccc @_Z12get_local_idj(%c1) : (i32) -> i64 14198d5d344SVictor Perez /// ``` 14298d5d344SVictor Perez struct LaunchConfigConversion : ConvertToLLVMPattern { 14398d5d344SVictor Perez LaunchConfigConversion(StringRef funcName, StringRef rootOpName, 14498d5d344SVictor Perez MLIRContext *context, 14598d5d344SVictor Perez const LLVMTypeConverter &typeConverter, 14698d5d344SVictor Perez PatternBenefit benefit) 14798d5d344SVictor Perez : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit), 14898d5d344SVictor Perez funcName(funcName) {} 14998d5d344SVictor Perez 15098d5d344SVictor Perez virtual gpu::Dimension getDimension(Operation *op) const = 0; 15198d5d344SVictor Perez 15298d5d344SVictor Perez LogicalResult 15398d5d344SVictor Perez matchAndRewrite(Operation *op, ArrayRef<Value> operands, 15498d5d344SVictor Perez ConversionPatternRewriter &rewriter) const final { 15598d5d344SVictor Perez Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>(); 15698d5d344SVictor Perez assert(moduleOp && "Expecting module"); 15798d5d344SVictor Perez Type dimTy = rewriter.getI32Type(); 15898d5d344SVictor Perez Type indexTy = getTypeConverter()->getIndexType(); 1595a53add8SFinlay LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy, 1605a53add8SFinlay indexTy, /*isMemNone=*/true, 1615a53add8SFinlay /*isConvergent=*/false); 16298d5d344SVictor Perez 16398d5d344SVictor Perez Location loc = op->getLoc(); 16498d5d344SVictor Perez gpu::Dimension dim = getDimension(op); 16598d5d344SVictor Perez Value dimVal = rewriter.create<LLVM::ConstantOp>(loc, dimTy, 16698d5d344SVictor Perez static_cast<int64_t>(dim)); 16798d5d344SVictor Perez rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal)); 16898d5d344SVictor Perez return success(); 16998d5d344SVictor Perez } 17098d5d344SVictor Perez 17198d5d344SVictor Perez StringRef funcName; 17298d5d344SVictor Perez }; 17398d5d344SVictor Perez 17498d5d344SVictor Perez template <typename SourceOp> 17598d5d344SVictor Perez struct LaunchConfigOpConversion final : LaunchConfigConversion { 17698d5d344SVictor Perez static StringRef getFuncName(); 17798d5d344SVictor Perez 17898d5d344SVictor Perez explicit LaunchConfigOpConversion(const LLVMTypeConverter &typeConverter, 17998d5d344SVictor Perez PatternBenefit benefit = 1) 18098d5d344SVictor Perez : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(), 18198d5d344SVictor Perez &typeConverter.getContext(), typeConverter, 18298d5d344SVictor Perez benefit) {} 18398d5d344SVictor Perez 18498d5d344SVictor Perez gpu::Dimension getDimension(Operation *op) const final { 18598d5d344SVictor Perez return cast<SourceOp>(op).getDimension(); 18698d5d344SVictor Perez } 18798d5d344SVictor Perez }; 18898d5d344SVictor Perez 18998d5d344SVictor Perez template <> 19098d5d344SVictor Perez StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() { 19198d5d344SVictor Perez return "_Z12get_group_idj"; 19298d5d344SVictor Perez } 19398d5d344SVictor Perez 19498d5d344SVictor Perez template <> 19598d5d344SVictor Perez StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() { 19698d5d344SVictor Perez return "_Z14get_num_groupsj"; 19798d5d344SVictor Perez } 19898d5d344SVictor Perez 19998d5d344SVictor Perez template <> 20098d5d344SVictor Perez StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() { 20198d5d344SVictor Perez return "_Z14get_local_sizej"; 20298d5d344SVictor Perez } 20398d5d344SVictor Perez 20498d5d344SVictor Perez template <> 20598d5d344SVictor Perez StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() { 20698d5d344SVictor Perez return "_Z12get_local_idj"; 20798d5d344SVictor Perez } 20898d5d344SVictor Perez 20998d5d344SVictor Perez template <> 21098d5d344SVictor Perez StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() { 21198d5d344SVictor Perez return "_Z13get_global_idj"; 21298d5d344SVictor Perez } 21398d5d344SVictor Perez 21498d5d344SVictor Perez //===----------------------------------------------------------------------===// 21598d5d344SVictor Perez // Shuffles 21698d5d344SVictor Perez //===----------------------------------------------------------------------===// 21798d5d344SVictor Perez 21898d5d344SVictor Perez /// Replace `gpu.shuffle` with an `llvm.call` to the corresponding SPIR-V 21998d5d344SVictor Perez /// builtin for `shuffleResult`, keeping `value` and `offset` arguments, and a 22098d5d344SVictor Perez /// `true` constant for the `valid` result type. Conversion will only take place 22198d5d344SVictor Perez /// if `width` is constant and equal to the `subgroup` pass option: 22298d5d344SVictor Perez /// ``` 22398d5d344SVictor Perez /// // %0 = gpu.shuffle idx %value, %offset, %width : f64 22498d5d344SVictor Perez /// %0 = llvm.call spir_funccc @_Z17sub_group_shuffledj(%value, %offset) 22598d5d344SVictor Perez /// : (f64, i32) -> f64 22698d5d344SVictor Perez /// ``` 22798d5d344SVictor Perez struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> { 22898d5d344SVictor Perez using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 22998d5d344SVictor Perez 23098d5d344SVictor Perez static StringRef getBaseName(gpu::ShuffleMode mode) { 23198d5d344SVictor Perez switch (mode) { 23298d5d344SVictor Perez case gpu::ShuffleMode::IDX: 23398d5d344SVictor Perez return "sub_group_shuffle"; 23498d5d344SVictor Perez case gpu::ShuffleMode::XOR: 23598d5d344SVictor Perez return "sub_group_shuffle_xor"; 23698d5d344SVictor Perez case gpu::ShuffleMode::UP: 23798d5d344SVictor Perez return "sub_group_shuffle_up"; 23898d5d344SVictor Perez case gpu::ShuffleMode::DOWN: 23998d5d344SVictor Perez return "sub_group_shuffle_down"; 24098d5d344SVictor Perez } 24198d5d344SVictor Perez llvm_unreachable("Unhandled shuffle mode"); 24298d5d344SVictor Perez } 24398d5d344SVictor Perez 244552d26e2SFinlay static std::optional<StringRef> getTypeMangling(Type type) { 245552d26e2SFinlay return TypeSwitch<Type, std::optional<StringRef>>(type) 246552d26e2SFinlay .Case<Float16Type>([](auto) { return "Dhj"; }) 24798d5d344SVictor Perez .Case<Float32Type>([](auto) { return "fj"; }) 24898d5d344SVictor Perez .Case<Float64Type>([](auto) { return "dj"; }) 249552d26e2SFinlay .Case<IntegerType>([](auto intTy) -> std::optional<StringRef> { 25098d5d344SVictor Perez switch (intTy.getWidth()) { 251552d26e2SFinlay case 8: 252552d26e2SFinlay return "cj"; 253552d26e2SFinlay case 16: 254552d26e2SFinlay return "sj"; 25598d5d344SVictor Perez case 32: 25698d5d344SVictor Perez return "ij"; 25798d5d344SVictor Perez case 64: 25898d5d344SVictor Perez return "lj"; 25998d5d344SVictor Perez } 260552d26e2SFinlay return std::nullopt; 261552d26e2SFinlay }) 262552d26e2SFinlay .Default([](auto) { return std::nullopt; }); 26398d5d344SVictor Perez } 26498d5d344SVictor Perez 265*cdd652ebSPietro Ghiglio static std::optional<std::string> getFuncName(gpu::ShuffleMode mode, 266*cdd652ebSPietro Ghiglio Type type) { 267*cdd652ebSPietro Ghiglio StringRef baseName = getBaseName(mode); 268*cdd652ebSPietro Ghiglio std::optional<StringRef> typeMangling = getTypeMangling(type); 269552d26e2SFinlay if (!typeMangling) 270552d26e2SFinlay return std::nullopt; 271*cdd652ebSPietro Ghiglio return llvm::formatv("_Z{}{}{}", baseName.size(), baseName, 272552d26e2SFinlay typeMangling.value()); 27398d5d344SVictor Perez } 27498d5d344SVictor Perez 27598d5d344SVictor Perez /// Get the subgroup size from the target or return a default. 276a807bbeaSVictor Perez static std::optional<int> getSubgroupSize(Operation *op) { 277a807bbeaSVictor Perez auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>(); 278a807bbeaSVictor Perez if (!parentFunc) 279a807bbeaSVictor Perez return std::nullopt; 280a807bbeaSVictor Perez return parentFunc.getIntelReqdSubGroupSize(); 28198d5d344SVictor Perez } 28298d5d344SVictor Perez 28398d5d344SVictor Perez static bool hasValidWidth(gpu::ShuffleOp op) { 28498d5d344SVictor Perez llvm::APInt val; 28598d5d344SVictor Perez Value width = op.getWidth(); 28698d5d344SVictor Perez return matchPattern(width, m_ConstantInt(&val)) && 28798d5d344SVictor Perez val == getSubgroupSize(op); 28898d5d344SVictor Perez } 28998d5d344SVictor Perez 290*cdd652ebSPietro Ghiglio static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc, 291*cdd652ebSPietro Ghiglio ConversionPatternRewriter &rewriter) { 292*cdd652ebSPietro Ghiglio return TypeSwitch<Type, Value>(oldVal.getType()) 293*cdd652ebSPietro Ghiglio .Case([&](BFloat16Type) { 294*cdd652ebSPietro Ghiglio return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(), 295*cdd652ebSPietro Ghiglio oldVal); 296*cdd652ebSPietro Ghiglio }) 297*cdd652ebSPietro Ghiglio .Case([&](IntegerType intTy) -> Value { 298*cdd652ebSPietro Ghiglio if (intTy.getWidth() == 1) 299*cdd652ebSPietro Ghiglio return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(), 300*cdd652ebSPietro Ghiglio oldVal); 301*cdd652ebSPietro Ghiglio return oldVal; 302*cdd652ebSPietro Ghiglio }) 303*cdd652ebSPietro Ghiglio .Default(oldVal); 304*cdd652ebSPietro Ghiglio } 305*cdd652ebSPietro Ghiglio 306*cdd652ebSPietro Ghiglio static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy, 307*cdd652ebSPietro Ghiglio Location loc, 308*cdd652ebSPietro Ghiglio ConversionPatternRewriter &rewriter) { 309*cdd652ebSPietro Ghiglio return TypeSwitch<Type, Value>(newTy) 310*cdd652ebSPietro Ghiglio .Case([&](BFloat16Type) { 311*cdd652ebSPietro Ghiglio return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal); 312*cdd652ebSPietro Ghiglio }) 313*cdd652ebSPietro Ghiglio .Case([&](IntegerType intTy) -> Value { 314*cdd652ebSPietro Ghiglio if (intTy.getWidth() == 1) 315*cdd652ebSPietro Ghiglio return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal); 316*cdd652ebSPietro Ghiglio return oldVal; 317*cdd652ebSPietro Ghiglio }) 318*cdd652ebSPietro Ghiglio .Default(oldVal); 319*cdd652ebSPietro Ghiglio } 320*cdd652ebSPietro Ghiglio 32198d5d344SVictor Perez LogicalResult 32298d5d344SVictor Perez matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor, 32398d5d344SVictor Perez ConversionPatternRewriter &rewriter) const final { 32498d5d344SVictor Perez if (!hasValidWidth(op)) 32598d5d344SVictor Perez return rewriter.notifyMatchFailure( 32698d5d344SVictor Perez op, "shuffle width and subgroup size mismatch"); 32798d5d344SVictor Perez 328*cdd652ebSPietro Ghiglio Location loc = op->getLoc(); 329*cdd652ebSPietro Ghiglio Value inValue = 330*cdd652ebSPietro Ghiglio bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter); 331*cdd652ebSPietro Ghiglio std::optional<std::string> funcName = 332*cdd652ebSPietro Ghiglio getFuncName(op.getMode(), inValue.getType()); 333552d26e2SFinlay if (!funcName) 334552d26e2SFinlay return rewriter.notifyMatchFailure(op, "unsupported value type"); 33598d5d344SVictor Perez 33698d5d344SVictor Perez Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>(); 33798d5d344SVictor Perez assert(moduleOp && "Expecting module"); 338*cdd652ebSPietro Ghiglio Type valueType = inValue.getType(); 33998d5d344SVictor Perez Type offsetType = adaptor.getOffset().getType(); 34098d5d344SVictor Perez Type resultType = valueType; 3415a53add8SFinlay LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn( 342552d26e2SFinlay moduleOp, funcName.value(), {valueType, offsetType}, resultType, 3435a53add8SFinlay /*isMemNone=*/false, /*isConvergent=*/true); 34498d5d344SVictor Perez 345*cdd652ebSPietro Ghiglio std::array<Value, 2> args{inValue, adaptor.getOffset()}; 34698d5d344SVictor Perez Value result = 34798d5d344SVictor Perez createSPIRVBuiltinCall(loc, rewriter, func, args).getResult(); 348*cdd652ebSPietro Ghiglio Value resultOrConversion = 349*cdd652ebSPietro Ghiglio bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter); 350*cdd652ebSPietro Ghiglio 35198d5d344SVictor Perez Value trueVal = 35298d5d344SVictor Perez rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true); 353*cdd652ebSPietro Ghiglio rewriter.replaceOp(op, {resultOrConversion, trueVal}); 35498d5d344SVictor Perez return success(); 35598d5d344SVictor Perez } 35698d5d344SVictor Perez }; 35798d5d344SVictor Perez 358f8b7a653SPetr Kurapov class MemorySpaceToOpenCLMemorySpaceConverter final : public TypeConverter { 359f8b7a653SPetr Kurapov public: 360f8b7a653SPetr Kurapov MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) { 361f8b7a653SPetr Kurapov addConversion([](Type t) { return t; }); 362f8b7a653SPetr Kurapov addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> { 363f8b7a653SPetr Kurapov // Attach global addr space attribute to memrefs with no addr space attr 364f8b7a653SPetr Kurapov Attribute memSpaceAttr = memRefType.getMemorySpace(); 365f8b7a653SPetr Kurapov if (memSpaceAttr) 366f8b7a653SPetr Kurapov return std::nullopt; 367f8b7a653SPetr Kurapov 368f8b7a653SPetr Kurapov unsigned globalAddrspace = storageClassToAddressSpace( 369f8b7a653SPetr Kurapov spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup); 370f8b7a653SPetr Kurapov Attribute addrSpaceAttr = 371f8b7a653SPetr Kurapov IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace); 372f8b7a653SPetr Kurapov if (auto rankedType = dyn_cast<MemRefType>(memRefType)) { 373f8b7a653SPetr Kurapov return MemRefType::get(memRefType.getShape(), 374f8b7a653SPetr Kurapov memRefType.getElementType(), 375f8b7a653SPetr Kurapov rankedType.getLayout(), addrSpaceAttr); 376f8b7a653SPetr Kurapov } 377f8b7a653SPetr Kurapov return UnrankedMemRefType::get(memRefType.getElementType(), 378f8b7a653SPetr Kurapov addrSpaceAttr); 379f8b7a653SPetr Kurapov }); 380f8b7a653SPetr Kurapov addConversion([this](FunctionType type) { 381f8b7a653SPetr Kurapov auto inputs = llvm::map_to_vector( 382f8b7a653SPetr Kurapov type.getInputs(), [this](Type ty) { return convertType(ty); }); 383f8b7a653SPetr Kurapov auto results = llvm::map_to_vector( 384f8b7a653SPetr Kurapov type.getResults(), [this](Type ty) { return convertType(ty); }); 385f8b7a653SPetr Kurapov return FunctionType::get(type.getContext(), inputs, results); 386f8b7a653SPetr Kurapov }); 387f8b7a653SPetr Kurapov } 388f8b7a653SPetr Kurapov }; 389f8b7a653SPetr Kurapov 39098d5d344SVictor Perez //===----------------------------------------------------------------------===// 391af7aa223SFinlay // Subgroup query ops. 392af7aa223SFinlay //===----------------------------------------------------------------------===// 393af7aa223SFinlay 394af7aa223SFinlay template <typename SubgroupOp> 395af7aa223SFinlay struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> { 396af7aa223SFinlay using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern; 397af7aa223SFinlay using ConvertToLLVMPattern::getTypeConverter; 398af7aa223SFinlay 399af7aa223SFinlay LogicalResult 400af7aa223SFinlay matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor, 401af7aa223SFinlay ConversionPatternRewriter &rewriter) const final { 402af7aa223SFinlay constexpr StringRef funcName = [] { 403af7aa223SFinlay if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) { 404af7aa223SFinlay return "_Z16get_sub_group_id"; 405af7aa223SFinlay } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) { 406af7aa223SFinlay return "_Z22get_sub_group_local_id"; 407af7aa223SFinlay } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) { 408af7aa223SFinlay return "_Z18get_num_sub_groups"; 409af7aa223SFinlay } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) { 410af7aa223SFinlay return "_Z18get_sub_group_size"; 411af7aa223SFinlay } 412af7aa223SFinlay }(); 413af7aa223SFinlay 414af7aa223SFinlay Operation *moduleOp = 415af7aa223SFinlay op->template getParentWithTrait<OpTrait::SymbolTable>(); 416af7aa223SFinlay Type resultTy = rewriter.getI32Type(); 417af7aa223SFinlay LLVM::LLVMFuncOp func = 418af7aa223SFinlay lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultTy, 419af7aa223SFinlay /*isMemNone=*/false, /*isConvergent=*/false); 420af7aa223SFinlay 421af7aa223SFinlay Location loc = op->getLoc(); 422af7aa223SFinlay Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult(); 423af7aa223SFinlay 424af7aa223SFinlay Type indexTy = getTypeConverter()->getIndexType(); 425af7aa223SFinlay if (resultTy != indexTy) { 426af7aa223SFinlay if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) { 427af7aa223SFinlay return failure(); 428af7aa223SFinlay } 429af7aa223SFinlay result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result); 430af7aa223SFinlay } 431af7aa223SFinlay 432af7aa223SFinlay rewriter.replaceOp(op, result); 433af7aa223SFinlay return success(); 434af7aa223SFinlay } 435af7aa223SFinlay }; 436af7aa223SFinlay 437af7aa223SFinlay //===----------------------------------------------------------------------===// 43898d5d344SVictor Perez // GPU To LLVM-SPV Pass. 43998d5d344SVictor Perez //===----------------------------------------------------------------------===// 44098d5d344SVictor Perez 44198d5d344SVictor Perez struct GPUToLLVMSPVConversionPass final 44298d5d344SVictor Perez : impl::ConvertGpuOpsToLLVMSPVOpsBase<GPUToLLVMSPVConversionPass> { 44398d5d344SVictor Perez using Base::Base; 44498d5d344SVictor Perez 44598d5d344SVictor Perez void runOnOperation() final { 44698d5d344SVictor Perez MLIRContext *context = &getContext(); 44798d5d344SVictor Perez RewritePatternSet patterns(context); 44898d5d344SVictor Perez 44998d5d344SVictor Perez LowerToLLVMOptions options(context); 45081825687SJefferson Le Quellec options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32); 45198d5d344SVictor Perez LLVMTypeConverter converter(context, options); 45298d5d344SVictor Perez LLVMConversionTarget target(*context); 45398d5d344SVictor Perez 454f8b7a653SPetr Kurapov // Force OpenCL address spaces when they are not present 455f8b7a653SPetr Kurapov { 456f8b7a653SPetr Kurapov MemorySpaceToOpenCLMemorySpaceConverter converter(context); 457f8b7a653SPetr Kurapov AttrTypeReplacer replacer; 458f8b7a653SPetr Kurapov replacer.addReplacement([&converter](BaseMemRefType origType) 459f8b7a653SPetr Kurapov -> std::optional<BaseMemRefType> { 460f8b7a653SPetr Kurapov return converter.convertType<BaseMemRefType>(origType); 461f8b7a653SPetr Kurapov }); 462f8b7a653SPetr Kurapov 463f8b7a653SPetr Kurapov replacer.recursivelyReplaceElementsIn(getOperation(), 464f8b7a653SPetr Kurapov /*replaceAttrs=*/true, 465f8b7a653SPetr Kurapov /*replaceLocs=*/false, 466f8b7a653SPetr Kurapov /*replaceTypes=*/true); 467f8b7a653SPetr Kurapov } 468f8b7a653SPetr Kurapov 46998d5d344SVictor Perez target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp, 470d45de800SVictor Perez gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp, 471af7aa223SFinlay gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp, 472af7aa223SFinlay gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp, 473af7aa223SFinlay gpu::ThreadIdOp>(); 47498d5d344SVictor Perez 47598d5d344SVictor Perez populateGpuToLLVMSPVConversionPatterns(converter, patterns); 47675cb9edfSVictor Perez populateGpuMemorySpaceAttributeConversions(converter); 47798d5d344SVictor Perez 47898d5d344SVictor Perez if (failed(applyPartialConversion(getOperation(), target, 47998d5d344SVictor Perez std::move(patterns)))) 48098d5d344SVictor Perez signalPassFailure(); 48198d5d344SVictor Perez } 48298d5d344SVictor Perez }; 48398d5d344SVictor Perez } // namespace 48498d5d344SVictor Perez 48598d5d344SVictor Perez //===----------------------------------------------------------------------===// 48698d5d344SVictor Perez // GPU To LLVM-SPV Patterns. 48798d5d344SVictor Perez //===----------------------------------------------------------------------===// 48898d5d344SVictor Perez 48998d5d344SVictor Perez namespace mlir { 49075cb9edfSVictor Perez namespace { 49175cb9edfSVictor Perez static unsigned 49275cb9edfSVictor Perez gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) { 49375cb9edfSVictor Perez constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL; 49475cb9edfSVictor Perez return storageClassToAddressSpace(clientAPI, 49575cb9edfSVictor Perez addressSpaceToStorageClass(addressSpace)); 49675cb9edfSVictor Perez } 49775cb9edfSVictor Perez } // namespace 49875cb9edfSVictor Perez 499206fad0eSMatthias Springer void populateGpuToLLVMSPVConversionPatterns( 500206fad0eSMatthias Springer const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { 501d45de800SVictor Perez patterns.add<GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion, 502af7aa223SFinlay GPUSubgroupOpConversion<gpu::LaneIdOp>, 503af7aa223SFinlay GPUSubgroupOpConversion<gpu::NumSubgroupsOp>, 504af7aa223SFinlay GPUSubgroupOpConversion<gpu::SubgroupIdOp>, 505af7aa223SFinlay GPUSubgroupOpConversion<gpu::SubgroupSizeOp>, 50698d5d344SVictor Perez LaunchConfigOpConversion<gpu::BlockDimOp>, 507af7aa223SFinlay LaunchConfigOpConversion<gpu::BlockIdOp>, 508af7aa223SFinlay LaunchConfigOpConversion<gpu::GlobalIdOp>, 509af7aa223SFinlay LaunchConfigOpConversion<gpu::GridDimOp>, 510af7aa223SFinlay LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter); 511d45de800SVictor Perez MLIRContext *context = &typeConverter.getContext(); 512d45de800SVictor Perez unsigned privateAddressSpace = 51375cb9edfSVictor Perez gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private); 514d45de800SVictor Perez unsigned localAddressSpace = 51575cb9edfSVictor Perez gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup); 516d45de800SVictor Perez OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context); 517d45de800SVictor Perez StringAttr kernelBlockSizeAttributeName = 518d45de800SVictor Perez LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName); 519d45de800SVictor Perez patterns.add<GPUFuncOpLowering>( 520d45de800SVictor Perez typeConverter, 521d45de800SVictor Perez GPUFuncOpLoweringOptions{ 522d45de800SVictor Perez privateAddressSpace, localAddressSpace, 523d45de800SVictor Perez /*kernelAttributeName=*/{}, kernelBlockSizeAttributeName, 524d45de800SVictor Perez LLVM::CConv::SPIR_KERNEL, LLVM::CConv::SPIR_FUNC, 525d45de800SVictor Perez /*encodeWorkgroupAttributionsAsArguments=*/true}); 52698d5d344SVictor Perez } 52775cb9edfSVictor Perez 52875cb9edfSVictor Perez void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter) { 52975cb9edfSVictor Perez populateGpuMemorySpaceAttributeConversions(typeConverter, 53075cb9edfSVictor Perez gpuAddressSpaceToOCLAddressSpace); 53175cb9edfSVictor Perez } 53298d5d344SVictor Perez } // namespace mlir 533