xref: /llvm-project/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp (revision cdd652eb28d1dcec28fec289674940d11a92c4fa)
198d5d344SVictor Perez //===- GPUToLLVMSPV.cpp - Convert GPU operations to LLVM dialect ----------===//
298d5d344SVictor Perez //
398d5d344SVictor Perez // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
498d5d344SVictor Perez // See https://llvm.org/LICENSE.txt for license information.
598d5d344SVictor Perez // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
698d5d344SVictor Perez //
798d5d344SVictor Perez //===----------------------------------------------------------------------===//
898d5d344SVictor Perez 
998d5d344SVictor Perez #include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h"
1098d5d344SVictor Perez 
11d45de800SVictor Perez #include "../GPUCommon/GPUOpsLowering.h"
1275cb9edfSVictor Perez #include "mlir/Conversion/GPUCommon/AttrToSPIRVConverter.h"
1375cb9edfSVictor Perez #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1498d5d344SVictor Perez #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1598d5d344SVictor Perez #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
1698d5d344SVictor Perez #include "mlir/Conversion/LLVMCommon/Pattern.h"
1798d5d344SVictor Perez #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
18d45de800SVictor Perez #include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
1998d5d344SVictor Perez #include "mlir/Dialect/GPU/IR/GPUDialect.h"
2098d5d344SVictor Perez #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
2198d5d344SVictor Perez #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2298d5d344SVictor Perez #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
23d45de800SVictor Perez #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
2498d5d344SVictor Perez #include "mlir/IR/BuiltinTypes.h"
2598d5d344SVictor Perez #include "mlir/IR/Matchers.h"
2698d5d344SVictor Perez #include "mlir/IR/PatternMatch.h"
2798d5d344SVictor Perez #include "mlir/IR/SymbolTable.h"
2898d5d344SVictor Perez #include "mlir/Pass/Pass.h"
2998d5d344SVictor Perez #include "mlir/Support/LLVM.h"
3098d5d344SVictor Perez #include "mlir/Transforms/DialectConversion.h"
3198d5d344SVictor Perez 
3298d5d344SVictor Perez #include "llvm/ADT/TypeSwitch.h"
3398d5d344SVictor Perez #include "llvm/Support/FormatVariadic.h"
3498d5d344SVictor Perez 
35f8b7a653SPetr Kurapov #define DEBUG_TYPE "gpu-to-llvm-spv"
36f8b7a653SPetr Kurapov 
3798d5d344SVictor Perez using namespace mlir;
3898d5d344SVictor Perez 
3998d5d344SVictor Perez namespace mlir {
4098d5d344SVictor Perez #define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
4198d5d344SVictor Perez #include "mlir/Conversion/Passes.h.inc"
4298d5d344SVictor Perez } // namespace mlir
4398d5d344SVictor Perez 
4498d5d344SVictor Perez //===----------------------------------------------------------------------===//
4598d5d344SVictor Perez // Helper Functions
4698d5d344SVictor Perez //===----------------------------------------------------------------------===//
4798d5d344SVictor Perez 
4898d5d344SVictor Perez static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
4998d5d344SVictor Perez                                               StringRef name,
5098d5d344SVictor Perez                                               ArrayRef<Type> paramTypes,
515a53add8SFinlay                                               Type resultType, bool isMemNone,
525a53add8SFinlay                                               bool isConvergent) {
5398d5d344SVictor Perez   auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
5498d5d344SVictor Perez       SymbolTable::lookupSymbolIn(symbolTable, name));
5598d5d344SVictor Perez   if (!func) {
5698d5d344SVictor Perez     OpBuilder b(symbolTable->getRegion(0));
5798d5d344SVictor Perez     func = b.create<LLVM::LLVMFuncOp>(
5898d5d344SVictor Perez         symbolTable->getLoc(), name,
5998d5d344SVictor Perez         LLVM::LLVMFunctionType::get(resultType, paramTypes));
6098d5d344SVictor Perez     func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
615a53add8SFinlay     func.setNoUnwind(true);
625a53add8SFinlay     func.setWillReturn(true);
635a53add8SFinlay 
645a53add8SFinlay     if (isMemNone) {
655a53add8SFinlay       // no externally observable effects
665a53add8SFinlay       constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
675a53add8SFinlay       auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>(
685a53add8SFinlay           /*other=*/noModRef,
695a53add8SFinlay           /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
705a53add8SFinlay       func.setMemoryEffectsAttr(memAttr);
715a53add8SFinlay     }
725a53add8SFinlay 
733670e7f8SFinlay     func.setConvergent(isConvergent);
7498d5d344SVictor Perez   }
7598d5d344SVictor Perez   return func;
7698d5d344SVictor Perez }
7798d5d344SVictor Perez 
7898d5d344SVictor Perez static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
7998d5d344SVictor Perez                                            ConversionPatternRewriter &rewriter,
8098d5d344SVictor Perez                                            LLVM::LLVMFuncOp func,
8198d5d344SVictor Perez                                            ValueRange args) {
8298d5d344SVictor Perez   auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
8398d5d344SVictor Perez   call.setCConv(func.getCConv());
845a53add8SFinlay   call.setConvergentAttr(func.getConvergentAttr());
855a53add8SFinlay   call.setNoUnwindAttr(func.getNoUnwindAttr());
865a53add8SFinlay   call.setWillReturnAttr(func.getWillReturnAttr());
875a53add8SFinlay   call.setMemoryEffectsAttr(func.getMemoryEffectsAttr());
8898d5d344SVictor Perez   return call;
8998d5d344SVictor Perez }
9098d5d344SVictor Perez 
9198d5d344SVictor Perez namespace {
9298d5d344SVictor Perez //===----------------------------------------------------------------------===//
9398d5d344SVictor Perez // Barriers
9498d5d344SVictor Perez //===----------------------------------------------------------------------===//
9598d5d344SVictor Perez 
9698d5d344SVictor Perez /// Replace `gpu.barrier` with an `llvm.call` to `barrier` with
9798d5d344SVictor Perez /// `CLK_LOCAL_MEM_FENCE` argument, indicating work-group memory scope:
9898d5d344SVictor Perez /// ```
9998d5d344SVictor Perez /// // gpu.barrier
10098d5d344SVictor Perez /// %c1 = llvm.mlir.constant(1: i32) : i32
10198d5d344SVictor Perez /// llvm.call spir_funccc @_Z7barrierj(%c1) : (i32) -> ()
10298d5d344SVictor Perez /// ```
10398d5d344SVictor Perez struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
10498d5d344SVictor Perez   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
10598d5d344SVictor Perez 
10698d5d344SVictor Perez   LogicalResult
10798d5d344SVictor Perez   matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
10898d5d344SVictor Perez                   ConversionPatternRewriter &rewriter) const final {
10998d5d344SVictor Perez     constexpr StringLiteral funcName = "_Z7barrierj";
11098d5d344SVictor Perez 
11198d5d344SVictor Perez     Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
11298d5d344SVictor Perez     assert(moduleOp && "Expecting module");
11398d5d344SVictor Perez     Type flagTy = rewriter.getI32Type();
11498d5d344SVictor Perez     Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1155a53add8SFinlay     LLVM::LLVMFuncOp func =
1165a53add8SFinlay         lookupOrCreateSPIRVFn(moduleOp, funcName, flagTy, voidTy,
1175a53add8SFinlay                               /*isMemNone=*/false, /*isConvergent=*/true);
11898d5d344SVictor Perez 
11998d5d344SVictor Perez     // Value used by SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`.
12098d5d344SVictor Perez     // See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
12198d5d344SVictor Perez     constexpr int64_t localMemFenceFlag = 1;
12298d5d344SVictor Perez     Location loc = op->getLoc();
12398d5d344SVictor Perez     Value flag =
12498d5d344SVictor Perez         rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
12598d5d344SVictor Perez     rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
12698d5d344SVictor Perez     return success();
12798d5d344SVictor Perez   }
12898d5d344SVictor Perez };
12998d5d344SVictor Perez 
13098d5d344SVictor Perez //===----------------------------------------------------------------------===//
13198d5d344SVictor Perez // SPIR-V Builtins
13298d5d344SVictor Perez //===----------------------------------------------------------------------===//
13398d5d344SVictor Perez 
13498d5d344SVictor Perez /// Replace `gpu.*` with an `llvm.call` to the corresponding SPIR-V builtin with
13598d5d344SVictor Perez /// a constant argument for the `dimension` attribute. Return type will depend
13698d5d344SVictor Perez /// on index width option:
13798d5d344SVictor Perez /// ```
13898d5d344SVictor Perez /// // %thread_id_y = gpu.thread_id y
13998d5d344SVictor Perez /// %c1 = llvm.mlir.constant(1: i32) : i32
14098d5d344SVictor Perez /// %0 = llvm.call spir_funccc @_Z12get_local_idj(%c1) : (i32) -> i64
14198d5d344SVictor Perez /// ```
14298d5d344SVictor Perez struct LaunchConfigConversion : ConvertToLLVMPattern {
14398d5d344SVictor Perez   LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
14498d5d344SVictor Perez                          MLIRContext *context,
14598d5d344SVictor Perez                          const LLVMTypeConverter &typeConverter,
14698d5d344SVictor Perez                          PatternBenefit benefit)
14798d5d344SVictor Perez       : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
14898d5d344SVictor Perez         funcName(funcName) {}
14998d5d344SVictor Perez 
15098d5d344SVictor Perez   virtual gpu::Dimension getDimension(Operation *op) const = 0;
15198d5d344SVictor Perez 
15298d5d344SVictor Perez   LogicalResult
15398d5d344SVictor Perez   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
15498d5d344SVictor Perez                   ConversionPatternRewriter &rewriter) const final {
15598d5d344SVictor Perez     Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
15698d5d344SVictor Perez     assert(moduleOp && "Expecting module");
15798d5d344SVictor Perez     Type dimTy = rewriter.getI32Type();
15898d5d344SVictor Perez     Type indexTy = getTypeConverter()->getIndexType();
1595a53add8SFinlay     LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy,
1605a53add8SFinlay                                                   indexTy, /*isMemNone=*/true,
1615a53add8SFinlay                                                   /*isConvergent=*/false);
16298d5d344SVictor Perez 
16398d5d344SVictor Perez     Location loc = op->getLoc();
16498d5d344SVictor Perez     gpu::Dimension dim = getDimension(op);
16598d5d344SVictor Perez     Value dimVal = rewriter.create<LLVM::ConstantOp>(loc, dimTy,
16698d5d344SVictor Perez                                                      static_cast<int64_t>(dim));
16798d5d344SVictor Perez     rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
16898d5d344SVictor Perez     return success();
16998d5d344SVictor Perez   }
17098d5d344SVictor Perez 
17198d5d344SVictor Perez   StringRef funcName;
17298d5d344SVictor Perez };
17398d5d344SVictor Perez 
17498d5d344SVictor Perez template <typename SourceOp>
17598d5d344SVictor Perez struct LaunchConfigOpConversion final : LaunchConfigConversion {
17698d5d344SVictor Perez   static StringRef getFuncName();
17798d5d344SVictor Perez 
17898d5d344SVictor Perez   explicit LaunchConfigOpConversion(const LLVMTypeConverter &typeConverter,
17998d5d344SVictor Perez                                     PatternBenefit benefit = 1)
18098d5d344SVictor Perez       : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
18198d5d344SVictor Perez                                &typeConverter.getContext(), typeConverter,
18298d5d344SVictor Perez                                benefit) {}
18398d5d344SVictor Perez 
18498d5d344SVictor Perez   gpu::Dimension getDimension(Operation *op) const final {
18598d5d344SVictor Perez     return cast<SourceOp>(op).getDimension();
18698d5d344SVictor Perez   }
18798d5d344SVictor Perez };
18898d5d344SVictor Perez 
18998d5d344SVictor Perez template <>
19098d5d344SVictor Perez StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
19198d5d344SVictor Perez   return "_Z12get_group_idj";
19298d5d344SVictor Perez }
19398d5d344SVictor Perez 
19498d5d344SVictor Perez template <>
19598d5d344SVictor Perez StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
19698d5d344SVictor Perez   return "_Z14get_num_groupsj";
19798d5d344SVictor Perez }
19898d5d344SVictor Perez 
19998d5d344SVictor Perez template <>
20098d5d344SVictor Perez StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
20198d5d344SVictor Perez   return "_Z14get_local_sizej";
20298d5d344SVictor Perez }
20398d5d344SVictor Perez 
20498d5d344SVictor Perez template <>
20598d5d344SVictor Perez StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
20698d5d344SVictor Perez   return "_Z12get_local_idj";
20798d5d344SVictor Perez }
20898d5d344SVictor Perez 
20998d5d344SVictor Perez template <>
21098d5d344SVictor Perez StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
21198d5d344SVictor Perez   return "_Z13get_global_idj";
21298d5d344SVictor Perez }
21398d5d344SVictor Perez 
21498d5d344SVictor Perez //===----------------------------------------------------------------------===//
21598d5d344SVictor Perez // Shuffles
21698d5d344SVictor Perez //===----------------------------------------------------------------------===//
21798d5d344SVictor Perez 
21898d5d344SVictor Perez /// Replace `gpu.shuffle` with an `llvm.call` to the corresponding SPIR-V
21998d5d344SVictor Perez /// builtin for `shuffleResult`, keeping `value` and `offset` arguments, and a
22098d5d344SVictor Perez /// `true` constant for the `valid` result type. Conversion will only take place
22198d5d344SVictor Perez /// if `width` is constant and equal to the `subgroup` pass option:
22298d5d344SVictor Perez /// ```
22398d5d344SVictor Perez /// // %0 = gpu.shuffle idx %value, %offset, %width : f64
22498d5d344SVictor Perez /// %0 = llvm.call spir_funccc @_Z17sub_group_shuffledj(%value, %offset)
22598d5d344SVictor Perez ///     : (f64, i32) -> f64
22698d5d344SVictor Perez /// ```
22798d5d344SVictor Perez struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
22898d5d344SVictor Perez   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
22998d5d344SVictor Perez 
23098d5d344SVictor Perez   static StringRef getBaseName(gpu::ShuffleMode mode) {
23198d5d344SVictor Perez     switch (mode) {
23298d5d344SVictor Perez     case gpu::ShuffleMode::IDX:
23398d5d344SVictor Perez       return "sub_group_shuffle";
23498d5d344SVictor Perez     case gpu::ShuffleMode::XOR:
23598d5d344SVictor Perez       return "sub_group_shuffle_xor";
23698d5d344SVictor Perez     case gpu::ShuffleMode::UP:
23798d5d344SVictor Perez       return "sub_group_shuffle_up";
23898d5d344SVictor Perez     case gpu::ShuffleMode::DOWN:
23998d5d344SVictor Perez       return "sub_group_shuffle_down";
24098d5d344SVictor Perez     }
24198d5d344SVictor Perez     llvm_unreachable("Unhandled shuffle mode");
24298d5d344SVictor Perez   }
24398d5d344SVictor Perez 
244552d26e2SFinlay   static std::optional<StringRef> getTypeMangling(Type type) {
245552d26e2SFinlay     return TypeSwitch<Type, std::optional<StringRef>>(type)
246552d26e2SFinlay         .Case<Float16Type>([](auto) { return "Dhj"; })
24798d5d344SVictor Perez         .Case<Float32Type>([](auto) { return "fj"; })
24898d5d344SVictor Perez         .Case<Float64Type>([](auto) { return "dj"; })
249552d26e2SFinlay         .Case<IntegerType>([](auto intTy) -> std::optional<StringRef> {
25098d5d344SVictor Perez           switch (intTy.getWidth()) {
251552d26e2SFinlay           case 8:
252552d26e2SFinlay             return "cj";
253552d26e2SFinlay           case 16:
254552d26e2SFinlay             return "sj";
25598d5d344SVictor Perez           case 32:
25698d5d344SVictor Perez             return "ij";
25798d5d344SVictor Perez           case 64:
25898d5d344SVictor Perez             return "lj";
25998d5d344SVictor Perez           }
260552d26e2SFinlay           return std::nullopt;
261552d26e2SFinlay         })
262552d26e2SFinlay         .Default([](auto) { return std::nullopt; });
26398d5d344SVictor Perez   }
26498d5d344SVictor Perez 
265*cdd652ebSPietro Ghiglio   static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
266*cdd652ebSPietro Ghiglio                                                 Type type) {
267*cdd652ebSPietro Ghiglio     StringRef baseName = getBaseName(mode);
268*cdd652ebSPietro Ghiglio     std::optional<StringRef> typeMangling = getTypeMangling(type);
269552d26e2SFinlay     if (!typeMangling)
270552d26e2SFinlay       return std::nullopt;
271*cdd652ebSPietro Ghiglio     return llvm::formatv("_Z{}{}{}", baseName.size(), baseName,
272552d26e2SFinlay                          typeMangling.value());
27398d5d344SVictor Perez   }
27498d5d344SVictor Perez 
27598d5d344SVictor Perez   /// Get the subgroup size from the target or return a default.
276a807bbeaSVictor Perez   static std::optional<int> getSubgroupSize(Operation *op) {
277a807bbeaSVictor Perez     auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
278a807bbeaSVictor Perez     if (!parentFunc)
279a807bbeaSVictor Perez       return std::nullopt;
280a807bbeaSVictor Perez     return parentFunc.getIntelReqdSubGroupSize();
28198d5d344SVictor Perez   }
28298d5d344SVictor Perez 
28398d5d344SVictor Perez   static bool hasValidWidth(gpu::ShuffleOp op) {
28498d5d344SVictor Perez     llvm::APInt val;
28598d5d344SVictor Perez     Value width = op.getWidth();
28698d5d344SVictor Perez     return matchPattern(width, m_ConstantInt(&val)) &&
28798d5d344SVictor Perez            val == getSubgroupSize(op);
28898d5d344SVictor Perez   }
28998d5d344SVictor Perez 
290*cdd652ebSPietro Ghiglio   static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc,
291*cdd652ebSPietro Ghiglio                                          ConversionPatternRewriter &rewriter) {
292*cdd652ebSPietro Ghiglio     return TypeSwitch<Type, Value>(oldVal.getType())
293*cdd652ebSPietro Ghiglio         .Case([&](BFloat16Type) {
294*cdd652ebSPietro Ghiglio           return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(),
295*cdd652ebSPietro Ghiglio                                                   oldVal);
296*cdd652ebSPietro Ghiglio         })
297*cdd652ebSPietro Ghiglio         .Case([&](IntegerType intTy) -> Value {
298*cdd652ebSPietro Ghiglio           if (intTy.getWidth() == 1)
299*cdd652ebSPietro Ghiglio             return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(),
300*cdd652ebSPietro Ghiglio                                                  oldVal);
301*cdd652ebSPietro Ghiglio           return oldVal;
302*cdd652ebSPietro Ghiglio         })
303*cdd652ebSPietro Ghiglio         .Default(oldVal);
304*cdd652ebSPietro Ghiglio   }
305*cdd652ebSPietro Ghiglio 
306*cdd652ebSPietro Ghiglio   static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy,
307*cdd652ebSPietro Ghiglio                                           Location loc,
308*cdd652ebSPietro Ghiglio                                           ConversionPatternRewriter &rewriter) {
309*cdd652ebSPietro Ghiglio     return TypeSwitch<Type, Value>(newTy)
310*cdd652ebSPietro Ghiglio         .Case([&](BFloat16Type) {
311*cdd652ebSPietro Ghiglio           return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
312*cdd652ebSPietro Ghiglio         })
313*cdd652ebSPietro Ghiglio         .Case([&](IntegerType intTy) -> Value {
314*cdd652ebSPietro Ghiglio           if (intTy.getWidth() == 1)
315*cdd652ebSPietro Ghiglio             return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
316*cdd652ebSPietro Ghiglio           return oldVal;
317*cdd652ebSPietro Ghiglio         })
318*cdd652ebSPietro Ghiglio         .Default(oldVal);
319*cdd652ebSPietro Ghiglio   }
320*cdd652ebSPietro Ghiglio 
32198d5d344SVictor Perez   LogicalResult
32298d5d344SVictor Perez   matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
32398d5d344SVictor Perez                   ConversionPatternRewriter &rewriter) const final {
32498d5d344SVictor Perez     if (!hasValidWidth(op))
32598d5d344SVictor Perez       return rewriter.notifyMatchFailure(
32698d5d344SVictor Perez           op, "shuffle width and subgroup size mismatch");
32798d5d344SVictor Perez 
328*cdd652ebSPietro Ghiglio     Location loc = op->getLoc();
329*cdd652ebSPietro Ghiglio     Value inValue =
330*cdd652ebSPietro Ghiglio         bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
331*cdd652ebSPietro Ghiglio     std::optional<std::string> funcName =
332*cdd652ebSPietro Ghiglio         getFuncName(op.getMode(), inValue.getType());
333552d26e2SFinlay     if (!funcName)
334552d26e2SFinlay       return rewriter.notifyMatchFailure(op, "unsupported value type");
33598d5d344SVictor Perez 
33698d5d344SVictor Perez     Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
33798d5d344SVictor Perez     assert(moduleOp && "Expecting module");
338*cdd652ebSPietro Ghiglio     Type valueType = inValue.getType();
33998d5d344SVictor Perez     Type offsetType = adaptor.getOffset().getType();
34098d5d344SVictor Perez     Type resultType = valueType;
3415a53add8SFinlay     LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
342552d26e2SFinlay         moduleOp, funcName.value(), {valueType, offsetType}, resultType,
3435a53add8SFinlay         /*isMemNone=*/false, /*isConvergent=*/true);
34498d5d344SVictor Perez 
345*cdd652ebSPietro Ghiglio     std::array<Value, 2> args{inValue, adaptor.getOffset()};
34698d5d344SVictor Perez     Value result =
34798d5d344SVictor Perez         createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
348*cdd652ebSPietro Ghiglio     Value resultOrConversion =
349*cdd652ebSPietro Ghiglio         bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);
350*cdd652ebSPietro Ghiglio 
35198d5d344SVictor Perez     Value trueVal =
35298d5d344SVictor Perez         rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
353*cdd652ebSPietro Ghiglio     rewriter.replaceOp(op, {resultOrConversion, trueVal});
35498d5d344SVictor Perez     return success();
35598d5d344SVictor Perez   }
35698d5d344SVictor Perez };
35798d5d344SVictor Perez 
358f8b7a653SPetr Kurapov class MemorySpaceToOpenCLMemorySpaceConverter final : public TypeConverter {
359f8b7a653SPetr Kurapov public:
360f8b7a653SPetr Kurapov   MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
361f8b7a653SPetr Kurapov     addConversion([](Type t) { return t; });
362f8b7a653SPetr Kurapov     addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
363f8b7a653SPetr Kurapov       // Attach global addr space attribute to memrefs with no addr space attr
364f8b7a653SPetr Kurapov       Attribute memSpaceAttr = memRefType.getMemorySpace();
365f8b7a653SPetr Kurapov       if (memSpaceAttr)
366f8b7a653SPetr Kurapov         return std::nullopt;
367f8b7a653SPetr Kurapov 
368f8b7a653SPetr Kurapov       unsigned globalAddrspace = storageClassToAddressSpace(
369f8b7a653SPetr Kurapov           spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
370f8b7a653SPetr Kurapov       Attribute addrSpaceAttr =
371f8b7a653SPetr Kurapov           IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
372f8b7a653SPetr Kurapov       if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
373f8b7a653SPetr Kurapov         return MemRefType::get(memRefType.getShape(),
374f8b7a653SPetr Kurapov                                memRefType.getElementType(),
375f8b7a653SPetr Kurapov                                rankedType.getLayout(), addrSpaceAttr);
376f8b7a653SPetr Kurapov       }
377f8b7a653SPetr Kurapov       return UnrankedMemRefType::get(memRefType.getElementType(),
378f8b7a653SPetr Kurapov                                      addrSpaceAttr);
379f8b7a653SPetr Kurapov     });
380f8b7a653SPetr Kurapov     addConversion([this](FunctionType type) {
381f8b7a653SPetr Kurapov       auto inputs = llvm::map_to_vector(
382f8b7a653SPetr Kurapov           type.getInputs(), [this](Type ty) { return convertType(ty); });
383f8b7a653SPetr Kurapov       auto results = llvm::map_to_vector(
384f8b7a653SPetr Kurapov           type.getResults(), [this](Type ty) { return convertType(ty); });
385f8b7a653SPetr Kurapov       return FunctionType::get(type.getContext(), inputs, results);
386f8b7a653SPetr Kurapov     });
387f8b7a653SPetr Kurapov   }
388f8b7a653SPetr Kurapov };
389f8b7a653SPetr Kurapov 
39098d5d344SVictor Perez //===----------------------------------------------------------------------===//
391af7aa223SFinlay // Subgroup query ops.
392af7aa223SFinlay //===----------------------------------------------------------------------===//
393af7aa223SFinlay 
394af7aa223SFinlay template <typename SubgroupOp>
395af7aa223SFinlay struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
396af7aa223SFinlay   using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
397af7aa223SFinlay   using ConvertToLLVMPattern::getTypeConverter;
398af7aa223SFinlay 
399af7aa223SFinlay   LogicalResult
400af7aa223SFinlay   matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
401af7aa223SFinlay                   ConversionPatternRewriter &rewriter) const final {
402af7aa223SFinlay     constexpr StringRef funcName = [] {
403af7aa223SFinlay       if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
404af7aa223SFinlay         return "_Z16get_sub_group_id";
405af7aa223SFinlay       } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
406af7aa223SFinlay         return "_Z22get_sub_group_local_id";
407af7aa223SFinlay       } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
408af7aa223SFinlay         return "_Z18get_num_sub_groups";
409af7aa223SFinlay       } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
410af7aa223SFinlay         return "_Z18get_sub_group_size";
411af7aa223SFinlay       }
412af7aa223SFinlay     }();
413af7aa223SFinlay 
414af7aa223SFinlay     Operation *moduleOp =
415af7aa223SFinlay         op->template getParentWithTrait<OpTrait::SymbolTable>();
416af7aa223SFinlay     Type resultTy = rewriter.getI32Type();
417af7aa223SFinlay     LLVM::LLVMFuncOp func =
418af7aa223SFinlay         lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultTy,
419af7aa223SFinlay                               /*isMemNone=*/false, /*isConvergent=*/false);
420af7aa223SFinlay 
421af7aa223SFinlay     Location loc = op->getLoc();
422af7aa223SFinlay     Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult();
423af7aa223SFinlay 
424af7aa223SFinlay     Type indexTy = getTypeConverter()->getIndexType();
425af7aa223SFinlay     if (resultTy != indexTy) {
426af7aa223SFinlay       if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) {
427af7aa223SFinlay         return failure();
428af7aa223SFinlay       }
429af7aa223SFinlay       result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
430af7aa223SFinlay     }
431af7aa223SFinlay 
432af7aa223SFinlay     rewriter.replaceOp(op, result);
433af7aa223SFinlay     return success();
434af7aa223SFinlay   }
435af7aa223SFinlay };
436af7aa223SFinlay 
437af7aa223SFinlay //===----------------------------------------------------------------------===//
43898d5d344SVictor Perez // GPU To LLVM-SPV Pass.
43998d5d344SVictor Perez //===----------------------------------------------------------------------===//
44098d5d344SVictor Perez 
44198d5d344SVictor Perez struct GPUToLLVMSPVConversionPass final
44298d5d344SVictor Perez     : impl::ConvertGpuOpsToLLVMSPVOpsBase<GPUToLLVMSPVConversionPass> {
44398d5d344SVictor Perez   using Base::Base;
44498d5d344SVictor Perez 
44598d5d344SVictor Perez   void runOnOperation() final {
44698d5d344SVictor Perez     MLIRContext *context = &getContext();
44798d5d344SVictor Perez     RewritePatternSet patterns(context);
44898d5d344SVictor Perez 
44998d5d344SVictor Perez     LowerToLLVMOptions options(context);
45081825687SJefferson Le Quellec     options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
45198d5d344SVictor Perez     LLVMTypeConverter converter(context, options);
45298d5d344SVictor Perez     LLVMConversionTarget target(*context);
45398d5d344SVictor Perez 
454f8b7a653SPetr Kurapov     // Force OpenCL address spaces when they are not present
455f8b7a653SPetr Kurapov     {
456f8b7a653SPetr Kurapov       MemorySpaceToOpenCLMemorySpaceConverter converter(context);
457f8b7a653SPetr Kurapov       AttrTypeReplacer replacer;
458f8b7a653SPetr Kurapov       replacer.addReplacement([&converter](BaseMemRefType origType)
459f8b7a653SPetr Kurapov                                   -> std::optional<BaseMemRefType> {
460f8b7a653SPetr Kurapov         return converter.convertType<BaseMemRefType>(origType);
461f8b7a653SPetr Kurapov       });
462f8b7a653SPetr Kurapov 
463f8b7a653SPetr Kurapov       replacer.recursivelyReplaceElementsIn(getOperation(),
464f8b7a653SPetr Kurapov                                             /*replaceAttrs=*/true,
465f8b7a653SPetr Kurapov                                             /*replaceLocs=*/false,
466f8b7a653SPetr Kurapov                                             /*replaceTypes=*/true);
467f8b7a653SPetr Kurapov     }
468f8b7a653SPetr Kurapov 
46998d5d344SVictor Perez     target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
470d45de800SVictor Perez                         gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
471af7aa223SFinlay                         gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
472af7aa223SFinlay                         gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
473af7aa223SFinlay                         gpu::ThreadIdOp>();
47498d5d344SVictor Perez 
47598d5d344SVictor Perez     populateGpuToLLVMSPVConversionPatterns(converter, patterns);
47675cb9edfSVictor Perez     populateGpuMemorySpaceAttributeConversions(converter);
47798d5d344SVictor Perez 
47898d5d344SVictor Perez     if (failed(applyPartialConversion(getOperation(), target,
47998d5d344SVictor Perez                                       std::move(patterns))))
48098d5d344SVictor Perez       signalPassFailure();
48198d5d344SVictor Perez   }
48298d5d344SVictor Perez };
48398d5d344SVictor Perez } // namespace
48498d5d344SVictor Perez 
48598d5d344SVictor Perez //===----------------------------------------------------------------------===//
48698d5d344SVictor Perez // GPU To LLVM-SPV Patterns.
48798d5d344SVictor Perez //===----------------------------------------------------------------------===//
48898d5d344SVictor Perez 
48998d5d344SVictor Perez namespace mlir {
49075cb9edfSVictor Perez namespace {
49175cb9edfSVictor Perez static unsigned
49275cb9edfSVictor Perez gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
49375cb9edfSVictor Perez   constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL;
49475cb9edfSVictor Perez   return storageClassToAddressSpace(clientAPI,
49575cb9edfSVictor Perez                                     addressSpaceToStorageClass(addressSpace));
49675cb9edfSVictor Perez }
49775cb9edfSVictor Perez } // namespace
49875cb9edfSVictor Perez 
499206fad0eSMatthias Springer void populateGpuToLLVMSPVConversionPatterns(
500206fad0eSMatthias Springer     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
501d45de800SVictor Perez   patterns.add<GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
502af7aa223SFinlay                GPUSubgroupOpConversion<gpu::LaneIdOp>,
503af7aa223SFinlay                GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
504af7aa223SFinlay                GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
505af7aa223SFinlay                GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
50698d5d344SVictor Perez                LaunchConfigOpConversion<gpu::BlockDimOp>,
507af7aa223SFinlay                LaunchConfigOpConversion<gpu::BlockIdOp>,
508af7aa223SFinlay                LaunchConfigOpConversion<gpu::GlobalIdOp>,
509af7aa223SFinlay                LaunchConfigOpConversion<gpu::GridDimOp>,
510af7aa223SFinlay                LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
511d45de800SVictor Perez   MLIRContext *context = &typeConverter.getContext();
512d45de800SVictor Perez   unsigned privateAddressSpace =
51375cb9edfSVictor Perez       gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);
514d45de800SVictor Perez   unsigned localAddressSpace =
51575cb9edfSVictor Perez       gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup);
516d45de800SVictor Perez   OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context);
517d45de800SVictor Perez   StringAttr kernelBlockSizeAttributeName =
518d45de800SVictor Perez       LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName);
519d45de800SVictor Perez   patterns.add<GPUFuncOpLowering>(
520d45de800SVictor Perez       typeConverter,
521d45de800SVictor Perez       GPUFuncOpLoweringOptions{
522d45de800SVictor Perez           privateAddressSpace, localAddressSpace,
523d45de800SVictor Perez           /*kernelAttributeName=*/{}, kernelBlockSizeAttributeName,
524d45de800SVictor Perez           LLVM::CConv::SPIR_KERNEL, LLVM::CConv::SPIR_FUNC,
525d45de800SVictor Perez           /*encodeWorkgroupAttributionsAsArguments=*/true});
52698d5d344SVictor Perez }
52775cb9edfSVictor Perez 
52875cb9edfSVictor Perez void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter) {
52975cb9edfSVictor Perez   populateGpuMemorySpaceAttributeConversions(typeConverter,
53075cb9edfSVictor Perez                                              gpuAddressSpaceToOCLAddressSpace);
53175cb9edfSVictor Perez }
53298d5d344SVictor Perez } // namespace mlir
533