xref: /llvm-project/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp (revision 20861f1f2fdc5d9aaa76140ade96f7edfdefb0e1)
19d7be77bSChristian Sigg //===- CudaRuntimeWrappers.cpp - MLIR CUDA API wrapper library ------------===//
29d7be77bSChristian Sigg //
39d7be77bSChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49d7be77bSChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
59d7be77bSChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69d7be77bSChristian Sigg //
79d7be77bSChristian Sigg //===----------------------------------------------------------------------===//
89d7be77bSChristian Sigg //
99d7be77bSChristian Sigg // Implements C wrappers around the CUDA library for easy linking in ORC jit.
109d7be77bSChristian Sigg // Also adds some debugging helpers that are helpful when writing MLIR code to
119d7be77bSChristian Sigg // run on GPUs.
129d7be77bSChristian Sigg //
139d7be77bSChristian Sigg //===----------------------------------------------------------------------===//
149d7be77bSChristian Sigg 
159d7be77bSChristian Sigg #include "mlir/ExecutionEngine/CRunnerUtils.h"
169d7be77bSChristian Sigg 
17b9f87e24SAart Bik #include <stdio.h>
18b9f87e24SAart Bik 
199d7be77bSChristian Sigg #include "cuda.h"
20cc402de0SKun Wu #include "cuda_bf16.h"
21cc402de0SKun Wu #include "cuda_fp16.h"
228ed59c53SKun Wu 
2350db4789SAart Bik #ifdef MLIR_ENABLE_CUDA_CUSPARSE
24b700a90cSAart Bik #include "cusparse.h"
2550db4789SAart Bik #ifdef MLIR_ENABLE_CUDA_CUSPARSELT
268ed59c53SKun Wu #include "cusparseLt.h"
278ed59c53SKun Wu #endif // MLIR_ENABLE_CUDA_CUSPARSELT
288ed59c53SKun Wu #endif // MLIR_ENABLE_CUDA_CUSPARSE
299d7be77bSChristian Sigg 
309d7be77bSChristian Sigg #ifdef _WIN32
315e78417dSJustin Holewinski #include <malloc.h>
329d7be77bSChristian Sigg #define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport)
339d7be77bSChristian Sigg #else
341c2a0768SAdam Paszke #define MLIR_CUDA_WRAPPERS_EXPORT __attribute__((visibility("default")))
359d7be77bSChristian Sigg #endif // _WIN32
369d7be77bSChristian Sigg 
379d7be77bSChristian Sigg #define CUDA_REPORT_IF_ERROR(expr)                                             \
389d7be77bSChristian Sigg   [](CUresult result) {                                                        \
399d7be77bSChristian Sigg     if (!result)                                                               \
409d7be77bSChristian Sigg       return;                                                                  \
419d7be77bSChristian Sigg     const char *name = nullptr;                                                \
429d7be77bSChristian Sigg     cuGetErrorName(result, &name);                                             \
439d7be77bSChristian Sigg     if (!name)                                                                 \
449d7be77bSChristian Sigg       name = "<unknown>";                                                      \
459d7be77bSChristian Sigg     fprintf(stderr, "'%s' failed with '%s'\n", #expr, name);                   \
469d7be77bSChristian Sigg   }(expr)
479d7be77bSChristian Sigg 
48b700a90cSAart Bik #define CUSPARSE_REPORT_IF_ERROR(expr)                                         \
49b700a90cSAart Bik   {                                                                            \
50b700a90cSAart Bik     cusparseStatus_t status = (expr);                                          \
51b700a90cSAart Bik     if (status != CUSPARSE_STATUS_SUCCESS) {                                   \
52b700a90cSAart Bik       fprintf(stderr, "cuSPARSE '%s' failed with '%s'\n", #expr,               \
53b700a90cSAart Bik               cusparseGetErrorString(status));                                 \
54b700a90cSAart Bik     }                                                                          \
55b700a90cSAart Bik   }
56b700a90cSAart Bik 
5784718d37SKrzysztof Drewniak thread_local static int32_t defaultDevice = 0;
5884718d37SKrzysztof Drewniak 
5919b11079SGuray Ozen const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG";
6019b11079SGuray Ozen 
6119b11079SGuray Ozen /// Helper method that checks environment value for debugging.
6219b11079SGuray Ozen bool isDebugEnabled() {
6319b11079SGuray Ozen   static bool isInitialized = false;
6419b11079SGuray Ozen   static bool isEnabled = false;
6519b11079SGuray Ozen   if (!isInitialized)
6619b11079SGuray Ozen     isEnabled = getenv(kDebugEnvironmentVariable) != nullptr;
6719b11079SGuray Ozen   return isEnabled;
6819b11079SGuray Ozen }
6919b11079SGuray Ozen 
7019b11079SGuray Ozen #define debug_print(fmt, ...)                                                  \
7119b11079SGuray Ozen   do {                                                                         \
7219b11079SGuray Ozen     if (isDebugEnabled())                                                      \
7319b11079SGuray Ozen       fprintf(stderr, "%s:%d:%s(): " fmt, "CudaRuntimeWrappers.cpp", __LINE__, \
7419b11079SGuray Ozen               __func__, __VA_ARGS__);                                          \
7519b11079SGuray Ozen   } while (0)
7619b11079SGuray Ozen 
7753881490SGuray Ozen // Returns default CUdevice
7853881490SGuray Ozen CUdevice getDefaultCuDevice() {
7953881490SGuray Ozen   CUdevice device;
8053881490SGuray Ozen   CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
8153881490SGuray Ozen   return device;
8253881490SGuray Ozen }
8353881490SGuray Ozen 
8484718d37SKrzysztof Drewniak // Make the primary context of the current default device current for the
8584718d37SKrzysztof Drewniak // duration
8684718d37SKrzysztof Drewniak //  of the instance and restore the previous context on destruction.
879d7be77bSChristian Sigg class ScopedContext {
889d7be77bSChristian Sigg public:
899d7be77bSChristian Sigg   ScopedContext() {
9084718d37SKrzysztof Drewniak     // Static reference to CUDA primary context for device ordinal
9184718d37SKrzysztof Drewniak     // defaultDevice.
92f69d5a7fSChristian Sigg     static CUcontext context = [] {
93f69d5a7fSChristian Sigg       CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
94f69d5a7fSChristian Sigg       CUcontext ctx;
95f69d5a7fSChristian Sigg       // Note: this does not affect the current context.
9653881490SGuray Ozen       CUDA_REPORT_IF_ERROR(
9753881490SGuray Ozen           cuDevicePrimaryCtxRetain(&ctx, getDefaultCuDevice()));
98f69d5a7fSChristian Sigg       return ctx;
99f69d5a7fSChristian Sigg     }();
100f69d5a7fSChristian Sigg 
101f69d5a7fSChristian Sigg     CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
1029d7be77bSChristian Sigg   }
1039d7be77bSChristian Sigg 
104f69d5a7fSChristian Sigg   ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
1059d7be77bSChristian Sigg };
1069d7be77bSChristian Sigg 
10703125e68SAart Bik #ifdef MLIR_ENABLE_CUDA_CUSPARSE
108be2dd22bSKun Wu // Note that (1) Nvidia confirms the safety to share handle across multiple
109be2dd22bSKun Wu // instances, and streams. (2) Clients are responsible to call the @mgpu
110be2dd22bSKun Wu // environment initialization/destruction in a thread-safe manner, e.g.,
111be2dd22bSKun Wu // at the beginning of the program before multi-threads are created.
112be2dd22bSKun Wu static cusparseHandle_t cusparse_env = nullptr;
113be2dd22bSKun Wu 
114be2dd22bSKun Wu #ifdef MLIR_ENABLE_CUDA_CUSPARSELT
115be2dd22bSKun Wu // cusparseLtHandle_t is not a pointer type, so we need an additional flag to
116be2dd22bSKun Wu // indicate whether it is initialized.
117be2dd22bSKun Wu static cusparseLtHandle_t cusparseLt_env;
118be2dd22bSKun Wu static bool cusparseLt_initiated = false;
119be2dd22bSKun Wu 
120be2dd22bSKun Wu #endif // MLIR_ENABLE_CUDA_CUSPARSELT
121be2dd22bSKun Wu #endif // MLIR_ENABLE_CUDA_CUSPARSE
122be2dd22bSKun Wu 
123ebfea261SNishant Patel extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule
124ebfea261SNishant Patel mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
1259d7be77bSChristian Sigg   ScopedContext scopedContext;
1269d7be77bSChristian Sigg   CUmodule module = nullptr;
1279d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
1289d7be77bSChristian Sigg   return module;
1299d7be77bSChristian Sigg }
1309d7be77bSChristian Sigg 
1315093413aSFabian Mora extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoadJIT(void *data,
1325093413aSFabian Mora                                                                 int optLevel) {
1335093413aSFabian Mora   ScopedContext scopedContext;
1345093413aSFabian Mora   CUmodule module = nullptr;
1355093413aSFabian Mora   char jitErrorBuffer[4096] = {0};
1365093413aSFabian Mora   CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER,
1375093413aSFabian Mora                                CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
1385093413aSFabian Mora                                CU_JIT_OPTIMIZATION_LEVEL};
1395093413aSFabian Mora   void *jitOptionsVals[] = {jitErrorBuffer,
1405093413aSFabian Mora                             reinterpret_cast<void *>(sizeof(jitErrorBuffer)),
1415093413aSFabian Mora                             reinterpret_cast<void *>(optLevel)};
1425093413aSFabian Mora 
1435093413aSFabian Mora   CUresult result =
1445093413aSFabian Mora       cuModuleLoadDataEx(&module, data, 3, jitOptions, jitOptionsVals);
1455093413aSFabian Mora   if (result) {
1465093413aSFabian Mora     fprintf(stderr, "JIT compilation failed with: '%s'\n", jitErrorBuffer);
1475093413aSFabian Mora     CUDA_REPORT_IF_ERROR(result);
1485093413aSFabian Mora   }
1495093413aSFabian Mora   return module;
1505093413aSFabian Mora }
1515093413aSFabian Mora 
1529d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) {
1539d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
1549d7be77bSChristian Sigg }
1559d7be77bSChristian Sigg 
1569d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction
1579d7be77bSChristian Sigg mgpuModuleGetFunction(CUmodule module, const char *name) {
1589d7be77bSChristian Sigg   CUfunction function = nullptr;
1599d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
1609d7be77bSChristian Sigg   return function;
1619d7be77bSChristian Sigg }
1629d7be77bSChristian Sigg 
1639d7be77bSChristian Sigg // The wrapper uses intptr_t instead of CUDA's unsigned int to match
1649d7be77bSChristian Sigg // the type of MLIR's index type. This avoids the need for casts in the
1659d7be77bSChristian Sigg // generated MLIR code.
1669d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
1679d7be77bSChristian Sigg mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
1689d7be77bSChristian Sigg                  intptr_t gridZ, intptr_t blockX, intptr_t blockY,
1699d7be77bSChristian Sigg                  intptr_t blockZ, int32_t smem, CUstream stream, void **params,
170ebfea261SNishant Patel                  void **extra, size_t /*paramsCount*/) {
1719d7be77bSChristian Sigg   ScopedContext scopedContext;
1725ef45c02SGuray Ozen   if (smem > 0) {
1735ef45c02SGuray Ozen     // Avoid checking driver as it's more expensive than if statement
17453881490SGuray Ozen     int32_t maxShmem = 0;
17553881490SGuray Ozen     CUdevice device = getDefaultCuDevice();
17653881490SGuray Ozen     CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
17753881490SGuray Ozen     CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
17853881490SGuray Ozen         &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
17953881490SGuray Ozen         device));
18053881490SGuray Ozen     if (maxShmem < smem) {
18153881490SGuray Ozen       fprintf(stderr,
18253881490SGuray Ozen               "Requested shared memory (%dkb) is larger than maximum allowed "
18353881490SGuray Ozen               "shared memory (%dkb) for this device\n",
18453881490SGuray Ozen               smem, maxShmem);
18553881490SGuray Ozen     }
18653881490SGuray Ozen     CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
18753881490SGuray Ozen         function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
1885ef45c02SGuray Ozen   }
18953881490SGuray Ozen   debug_print("Launching kernel, grid=%ld,%ld,%ld, "
19053881490SGuray Ozen               "threads: %ld, %ld, %ld, "
19153881490SGuray Ozen               "smem: %dkb\n",
19253881490SGuray Ozen               gridX, gridY, gridZ, blockX, blockY, blockZ, smem);
1939d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
1949d7be77bSChristian Sigg                                       blockY, blockZ, smem, stream, params,
1959d7be77bSChristian Sigg                                       extra));
1969d7be77bSChristian Sigg }
1979d7be77bSChristian Sigg 
1989d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() {
1999d7be77bSChristian Sigg   ScopedContext scopedContext;
2009d7be77bSChristian Sigg   CUstream stream = nullptr;
2019d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
2029d7be77bSChristian Sigg   return stream;
2039d7be77bSChristian Sigg }
2049d7be77bSChristian Sigg 
2059d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamDestroy(CUstream stream) {
2069d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
2079d7be77bSChristian Sigg }
2089d7be77bSChristian Sigg 
2099d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
2109d7be77bSChristian Sigg mgpuStreamSynchronize(CUstream stream) {
2119d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream));
2129d7be77bSChristian Sigg }
2139d7be77bSChristian Sigg 
2149d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamWaitEvent(CUstream stream,
2159d7be77bSChristian Sigg                                                               CUevent event) {
2169d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuStreamWaitEvent(stream, event, /*flags=*/0));
2179d7be77bSChristian Sigg }
2189d7be77bSChristian Sigg 
2199d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUevent mgpuEventCreate() {
2209d7be77bSChristian Sigg   ScopedContext scopedContext;
2219d7be77bSChristian Sigg   CUevent event = nullptr;
2229d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING));
2239d7be77bSChristian Sigg   return event;
2249d7be77bSChristian Sigg }
2259d7be77bSChristian Sigg 
2269d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventDestroy(CUevent event) {
2279d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuEventDestroy(event));
2289d7be77bSChristian Sigg }
2299d7be77bSChristian Sigg 
2301c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventSynchronize(CUevent event) {
2319d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuEventSynchronize(event));
2329d7be77bSChristian Sigg }
2339d7be77bSChristian Sigg 
2341c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventRecord(CUevent event,
2359d7be77bSChristian Sigg                                                           CUstream stream) {
2369d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
2379d7be77bSChristian Sigg }
2389d7be77bSChristian Sigg 
2391c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
240*20861f1fSGuray Ozen mgpuMemAlloc(uint64_t sizeBytes, CUstream stream, bool isHostShared) {
2419d7be77bSChristian Sigg   ScopedContext scopedContext;
2423635c743SAart Bik   CUdeviceptr ptr = 0;
243*20861f1fSGuray Ozen   if (sizeBytes == 0)
244*20861f1fSGuray Ozen     return reinterpret_cast<void *>(ptr);
245*20861f1fSGuray Ozen 
246*20861f1fSGuray Ozen   if (isHostShared) {
247*20861f1fSGuray Ozen     CUDA_REPORT_IF_ERROR(
248*20861f1fSGuray Ozen         cuMemAllocManaged(&ptr, sizeBytes, CU_MEM_ATTACH_GLOBAL));
249*20861f1fSGuray Ozen     return reinterpret_cast<void *>(ptr);
250*20861f1fSGuray Ozen   }
2519d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
2529d7be77bSChristian Sigg   return reinterpret_cast<void *>(ptr);
2539d7be77bSChristian Sigg }
2549d7be77bSChristian Sigg 
2551c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemFree(void *ptr,
2561c2a0768SAdam Paszke                                                       CUstream /*stream*/) {
2579d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
2589d7be77bSChristian Sigg }
2599d7be77bSChristian Sigg 
2601c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
2611c2a0768SAdam Paszke mgpuMemcpy(void *dst, void *src, size_t sizeBytes, CUstream stream) {
2629d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst),
2639d7be77bSChristian Sigg                                      reinterpret_cast<CUdeviceptr>(src),
2649d7be77bSChristian Sigg                                      sizeBytes, stream));
2659d7be77bSChristian Sigg }
2669d7be77bSChristian Sigg 
2671c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
2681c2a0768SAdam Paszke mgpuMemset32(void *dst, unsigned int value, size_t count, CUstream stream) {
269361458b1SLoren Maggiore   CUDA_REPORT_IF_ERROR(cuMemsetD32Async(reinterpret_cast<CUdeviceptr>(dst),
270361458b1SLoren Maggiore                                         value, count, stream));
271361458b1SLoren Maggiore }
272361458b1SLoren Maggiore 
2731c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
2741c2a0768SAdam Paszke mgpuMemset16(void *dst, unsigned short value, size_t count, CUstream stream) {
27518cc07aaSNavdeep Katel   CUDA_REPORT_IF_ERROR(cuMemsetD16Async(reinterpret_cast<CUdeviceptr>(dst),
27618cc07aaSNavdeep Katel                                         value, count, stream));
27718cc07aaSNavdeep Katel }
27818cc07aaSNavdeep Katel 
279b700a90cSAart Bik ///
2809d7be77bSChristian Sigg /// Helper functions for writing mlir example code
281b700a90cSAart Bik ///
2829d7be77bSChristian Sigg 
2839d7be77bSChristian Sigg // Allows to register byte array with the CUDA runtime. Helpful until we have
2849d7be77bSChristian Sigg // transfer functions implemented.
2859d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
2869d7be77bSChristian Sigg mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
2879d7be77bSChristian Sigg   ScopedContext scopedContext;
2889d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
2899d7be77bSChristian Sigg }
2909d7be77bSChristian Sigg 
2914edc9e2aSUday Bondhugula /// Registers a memref with the CUDA runtime. `descriptor` is a pointer to a
2924edc9e2aSUday Bondhugula /// ranked memref descriptor struct of rank `rank`. Helpful until we have
2934edc9e2aSUday Bondhugula /// transfer functions implemented.
2949d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
2959d7be77bSChristian Sigg mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
2969d7be77bSChristian Sigg                           int64_t elementSizeBytes) {
2979d7be77bSChristian Sigg   // Only densely packed tensors are currently supported.
2985e78417dSJustin Holewinski #ifdef _WIN32
2995e78417dSJustin Holewinski   int64_t *denseStrides = (int64_t *)_alloca(rank * sizeof(int64_t));
3005e78417dSJustin Holewinski #else
3014edc9e2aSUday Bondhugula   int64_t *denseStrides = (int64_t *)alloca(rank * sizeof(int64_t));
3025e78417dSJustin Holewinski #endif // _WIN32
3034edc9e2aSUday Bondhugula   int64_t *sizes = descriptor->sizes;
3044edc9e2aSUday Bondhugula   for (int64_t i = rank - 1, runningStride = 1; i >= 0; i--) {
3054edc9e2aSUday Bondhugula     denseStrides[i] = runningStride;
3064edc9e2aSUday Bondhugula     runningStride *= sizes[i];
3074edc9e2aSUday Bondhugula   }
3084edc9e2aSUday Bondhugula   uint64_t sizeBytes = sizes[0] * denseStrides[0] * elementSizeBytes;
3094edc9e2aSUday Bondhugula   int64_t *strides = &sizes[rank];
310012c0cc7SNicolas Vasilache   (void)strides;
3114edc9e2aSUday Bondhugula   for (unsigned i = 0; i < rank; ++i)
3124edc9e2aSUday Bondhugula     assert(strides[i] == denseStrides[i] &&
3134edc9e2aSUday Bondhugula            "Mismatch in computed dense strides");
3149d7be77bSChristian Sigg 
3154edc9e2aSUday Bondhugula   auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
3169d7be77bSChristian Sigg   mgpuMemHostRegister(ptr, sizeBytes);
3179d7be77bSChristian Sigg }
31884718d37SKrzysztof Drewniak 
3198f7c8a6eSmax // Allows to unregister byte array with the CUDA runtime.
3208f7c8a6eSmax extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemHostUnregister(void *ptr) {
3218f7c8a6eSmax   ScopedContext scopedContext;
3228f7c8a6eSmax   CUDA_REPORT_IF_ERROR(cuMemHostUnregister(ptr));
3238f7c8a6eSmax }
3248f7c8a6eSmax 
3258f7c8a6eSmax /// Unregisters a memref with the CUDA runtime. `descriptor` is a pointer to a
3268f7c8a6eSmax /// ranked memref descriptor struct of rank `rank`
3278f7c8a6eSmax extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
3288f7c8a6eSmax mgpuMemHostUnregisterMemRef(int64_t rank,
3298f7c8a6eSmax                             StridedMemRefType<char, 1> *descriptor,
3308f7c8a6eSmax                             int64_t elementSizeBytes) {
3318f7c8a6eSmax   auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
3328f7c8a6eSmax   mgpuMemHostUnregister(ptr);
3338f7c8a6eSmax }
3348f7c8a6eSmax 
33584718d37SKrzysztof Drewniak extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
33684718d37SKrzysztof Drewniak   defaultDevice = device;
33784718d37SKrzysztof Drewniak }
338b700a90cSAart Bik 
3391dc00712SGuray Ozen ///
3401dc00712SGuray Ozen /// Runtime methods using CUDA 12.0+ driver
3411dc00712SGuray Ozen ///
3421dc00712SGuray Ozen 
3431dc00712SGuray Ozen #if (CUDA_VERSION >= 12000)
3441dc00712SGuray Ozen 
345f21a70f9SGuray Ozen extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
346f21a70f9SGuray Ozen     CUfunction function, intptr_t clusterX, intptr_t clusterY,
347f21a70f9SGuray Ozen     intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
348f21a70f9SGuray Ozen     intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
349f21a70f9SGuray Ozen     CUstream stream, void **params, void **extra, size_t /*paramsCount*/) {
350f21a70f9SGuray Ozen   ScopedContext scopedContext;
351f21a70f9SGuray Ozen   if (smem > 0) {
352f21a70f9SGuray Ozen     // Avoid checking driver as it's more expensive than if statement
353f21a70f9SGuray Ozen     int32_t maxShmem = 0;
354f21a70f9SGuray Ozen     CUdevice device = getDefaultCuDevice();
355f21a70f9SGuray Ozen     CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
356f21a70f9SGuray Ozen     CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
357f21a70f9SGuray Ozen         &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
358f21a70f9SGuray Ozen         device));
359f21a70f9SGuray Ozen     if (maxShmem < smem) {
360f21a70f9SGuray Ozen       fprintf(stderr,
361f21a70f9SGuray Ozen               "Requested shared memory (%dkb) is larger than maximum allowed "
362f21a70f9SGuray Ozen               "shared memory (%dkb) for this device\n",
363f21a70f9SGuray Ozen               smem, maxShmem);
364f21a70f9SGuray Ozen     }
365f21a70f9SGuray Ozen     CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
366f21a70f9SGuray Ozen         function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
367f21a70f9SGuray Ozen   }
368f21a70f9SGuray Ozen   CUlaunchConfig config;
369f21a70f9SGuray Ozen   config.gridDimX = gridX;
370f21a70f9SGuray Ozen   config.gridDimY = gridY;
371f21a70f9SGuray Ozen   config.gridDimZ = gridZ;
372f21a70f9SGuray Ozen   config.blockDimX = blockX;
373f21a70f9SGuray Ozen   config.blockDimY = blockY;
374f21a70f9SGuray Ozen   config.blockDimZ = blockZ;
375f21a70f9SGuray Ozen   config.sharedMemBytes = smem;
376f21a70f9SGuray Ozen   config.hStream = stream;
377f21a70f9SGuray Ozen   CUlaunchAttribute launchAttr[2];
378f21a70f9SGuray Ozen   launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
379f21a70f9SGuray Ozen   launchAttr[0].value.clusterDim.x = clusterX;
380f21a70f9SGuray Ozen   launchAttr[0].value.clusterDim.y = clusterY;
381f21a70f9SGuray Ozen   launchAttr[0].value.clusterDim.z = clusterZ;
382f21a70f9SGuray Ozen   launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
383f21a70f9SGuray Ozen   launchAttr[1].value.clusterSchedulingPolicyPreference =
384f21a70f9SGuray Ozen       CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
385f21a70f9SGuray Ozen   config.numAttrs = 2;
386f21a70f9SGuray Ozen   config.attrs = launchAttr;
387f21a70f9SGuray Ozen 
388f21a70f9SGuray Ozen   debug_print("Launching kernel,"
389f21a70f9SGuray Ozen               "cluster: %ld, %ld, %ld, "
390f21a70f9SGuray Ozen               "grid=%ld,%ld,%ld, "
391f21a70f9SGuray Ozen               "threads: %ld, %ld, %ld, "
392f21a70f9SGuray Ozen               "smem: %dkb\n",
393f21a70f9SGuray Ozen               clusterX, clusterY, clusterZ, gridX, gridY, gridZ, blockX, blockY,
394f21a70f9SGuray Ozen               blockZ, smem);
395f21a70f9SGuray Ozen 
396f21a70f9SGuray Ozen   CUDA_REPORT_IF_ERROR(cuLaunchKernelEx(&config, function, params, extra));
397f21a70f9SGuray Ozen }
398f21a70f9SGuray Ozen 
399e56d6745SGuray Ozen extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuTensorMapEncodeTiled(
400e56d6745SGuray Ozen     CUtensorMap *tensorMap,             // Tensor map object
401e56d6745SGuray Ozen     CUtensorMapDataType tensorDataType, // Tensor data type
402e56d6745SGuray Ozen     cuuint32_t tensorRank,              // Dimensionality of tensor
403e56d6745SGuray Ozen     void *globalAddress,                // Starting address
404e56d6745SGuray Ozen     const cuuint64_t *globalDim,        // Tensor size (number of elements)
405e56d6745SGuray Ozen     const cuuint64_t *globalStrides,    // Stride size (in bytes)
406e56d6745SGuray Ozen     const cuuint32_t *boxDim,           // Traversal box (number of elments)
407e56d6745SGuray Ozen     const cuuint32_t *elementStrides,   // Traversal stride
408e56d6745SGuray Ozen     CUtensorMapInterleave interleave,   // Type of interleaved layout
409e56d6745SGuray Ozen     CUtensorMapSwizzle swizzle,         // Bank swizzling pattern
410e56d6745SGuray Ozen     CUtensorMapL2promotion l2Promotion, // L2 promotion size
411e56d6745SGuray Ozen     CUtensorMapFloatOOBfill oobFill     // Padding zfill or NaN fill
412e56d6745SGuray Ozen ) {
413e56d6745SGuray Ozen   ScopedContext scopedContext;
414e56d6745SGuray Ozen   CUDA_REPORT_IF_ERROR(cuTensorMapEncodeTiled(
415e56d6745SGuray Ozen       tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
416e56d6745SGuray Ozen       globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
417e56d6745SGuray Ozen       oobFill));
41819b11079SGuray Ozen   debug_print("Created TMA descriptor\n Addr: %p\n"
41919b11079SGuray Ozen               "data type : %d\n"
42019b11079SGuray Ozen               "rank : %d\n"
42119b11079SGuray Ozen               "globalDim[5]: %zu, %zu, %zu, %zu, %zu\n"
42219b11079SGuray Ozen               "globalStrides[5]: %zu, %zu, %zu, %zu, %zu\n"
42319b11079SGuray Ozen               "boxDim[5]: %u, %u, %u, %u, %u\n"
42419b11079SGuray Ozen               "elementStrides[5]: %u, %u, %u, %u, %u\n"
42519b11079SGuray Ozen               "interleave: %u \n"
42619b11079SGuray Ozen               "swizzle: %u \n"
42719b11079SGuray Ozen               "l2Promotion: %u \n"
42819b11079SGuray Ozen               "oobFill: %u \n",
42919b11079SGuray Ozen               (void *)&tensorMap, tensorDataType, tensorRank, globalDim[0],
43019b11079SGuray Ozen               globalDim[1], globalDim[2], globalDim[3], globalDim[4],
43119b11079SGuray Ozen               globalStrides[0], globalStrides[1], globalStrides[2],
43219b11079SGuray Ozen               globalStrides[3], globalStrides[4], boxDim[0], boxDim[1],
43319b11079SGuray Ozen               boxDim[2], boxDim[3], boxDim[4], elementStrides[0],
43419b11079SGuray Ozen               elementStrides[1], elementStrides[2], elementStrides[3],
43519b11079SGuray Ozen               elementStrides[4], interleave, swizzle, l2Promotion, oobFill);
436e56d6745SGuray Ozen }
437e56d6745SGuray Ozen 
4387d55b916SGuray Ozen template <int Rank>
4397d55b916SGuray Ozen void mgpuGetMemRefDataAndShape(void *rawDescriptor, char **addr,
4407d55b916SGuray Ozen                                uint64_t *globalDim, uint64_t *globalStrides,
4417d55b916SGuray Ozen                                const CUtensorMapDataType tensorDataType) {
44265aab9e7SAdam Paszke   auto descriptor =
4437d55b916SGuray Ozen       reinterpret_cast<StridedMemRefType<char, Rank> *>(rawDescriptor);
44465aab9e7SAdam Paszke   *addr = descriptor->data;
4457d55b916SGuray Ozen   for (int i = 0; i < Rank; ++i) {
4467d55b916SGuray Ozen     globalDim[i] = static_cast<uint64_t>(descriptor->sizes[Rank - i - 1]);
4477d55b916SGuray Ozen   }
4487d55b916SGuray Ozen   static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
4497d55b916SGuray Ozen                                                4, 8, 2, 4, 4, 4};
4507d55b916SGuray Ozen   for (int i = 0; i < Rank - 1; ++i) {
4517d55b916SGuray Ozen     globalStrides[i] = static_cast<uint64_t>(
4527d55b916SGuray Ozen         descriptor->strides[Rank - i - 2] * elementSizeInBytes[tensorDataType]);
45365aab9e7SAdam Paszke   }
45465aab9e7SAdam Paszke }
45565aab9e7SAdam Paszke 
456e56d6745SGuray Ozen extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
457e56d6745SGuray Ozen     int64_t tensorRank,                       // Dimensionality of tensor
4587d55b916SGuray Ozen     void *rankedDescriptor,                   // Ranked MemRef descriptor
459e56d6745SGuray Ozen     const CUtensorMapDataType tensorDataType, // Stride size (in bytes)
460e56d6745SGuray Ozen     CUtensorMapInterleave interleave,         // Type of interleaved layout
461e56d6745SGuray Ozen     CUtensorMapSwizzle swizzle,               // Bank swizzling pattern
462e56d6745SGuray Ozen     CUtensorMapL2promotion l2Promotion,       // L2 promotion size
463e56d6745SGuray Ozen     CUtensorMapFloatOOBfill oobFill,          // Padding zfill or NaN fill
464e56d6745SGuray Ozen     int64_t *inputBoxDims // Tensor size (number of elements)
465e56d6745SGuray Ozen ) {
466e56d6745SGuray Ozen   CUtensorMap tensorMap;
467e56d6745SGuray Ozen 
4682379432cSGuray Ozen   uint32_t boxDim[5] = {1, 1, 1, 1, 1}, elementStrides[5] = {1, 1, 1, 1, 1};
4692379432cSGuray Ozen   uint64_t globalDim[5] = {1, 1, 1, 1, 1}, globalStrides[5] = {0};
470e56d6745SGuray Ozen   uint32_t tensorRank32 = uint32_t(tensorRank);
471e56d6745SGuray Ozen 
47265aab9e7SAdam Paszke   char *globalAddress = nullptr;
47365aab9e7SAdam Paszke   switch (tensorRank) {
47465aab9e7SAdam Paszke   case 1:
4757d55b916SGuray Ozen     mgpuGetMemRefDataAndShape<1>(rankedDescriptor, &globalAddress, globalDim,
4767d55b916SGuray Ozen                                  globalStrides, tensorDataType);
47765aab9e7SAdam Paszke     break;
47865aab9e7SAdam Paszke   case 2:
4797d55b916SGuray Ozen     mgpuGetMemRefDataAndShape<2>(rankedDescriptor, &globalAddress, globalDim,
4807d55b916SGuray Ozen                                  globalStrides, tensorDataType);
48165aab9e7SAdam Paszke     break;
48265aab9e7SAdam Paszke   case 3:
4837d55b916SGuray Ozen     mgpuGetMemRefDataAndShape<3>(rankedDescriptor, &globalAddress, globalDim,
4847d55b916SGuray Ozen                                  globalStrides, tensorDataType);
48565aab9e7SAdam Paszke     break;
48665aab9e7SAdam Paszke   case 4:
4877d55b916SGuray Ozen     mgpuGetMemRefDataAndShape<4>(rankedDescriptor, &globalAddress, globalDim,
4887d55b916SGuray Ozen                                  globalStrides, tensorDataType);
48965aab9e7SAdam Paszke     break;
49065aab9e7SAdam Paszke   case 5:
4917d55b916SGuray Ozen     mgpuGetMemRefDataAndShape<5>(rankedDescriptor, &globalAddress, globalDim,
4927d55b916SGuray Ozen                                  globalStrides, tensorDataType);
49365aab9e7SAdam Paszke     break;
49465aab9e7SAdam Paszke   default:
49565aab9e7SAdam Paszke     fprintf(
49665aab9e7SAdam Paszke         stderr,
49765aab9e7SAdam Paszke         "'mgpuTensorMapEncodeTiledMemref' failed with 'rank is too high'\n");
4987d55b916SGuray Ozen     return nullptr;
49965aab9e7SAdam Paszke   }
50065aab9e7SAdam Paszke 
501e56d6745SGuray Ozen   for (int64_t r = 0; r < tensorRank; ++r) {
502e56d6745SGuray Ozen     boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
503e56d6745SGuray Ozen   }
504e56d6745SGuray Ozen 
505e56d6745SGuray Ozen   ScopedContext scopedContext;
506e56d6745SGuray Ozen   mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32,
507e56d6745SGuray Ozen                            globalAddress, globalDim, globalStrides, boxDim,
508e56d6745SGuray Ozen                            elementStrides, interleave, swizzle, l2Promotion,
509e56d6745SGuray Ozen                            oobFill);
510e56d6745SGuray Ozen   // Copy created tensor map to device
511e56d6745SGuray Ozen   CUdeviceptr dTensorMap;
512e56d6745SGuray Ozen   CUDA_REPORT_IF_ERROR(cuMemAlloc(&dTensorMap, sizeof(CUtensorMap)));
513e56d6745SGuray Ozen   CUDA_REPORT_IF_ERROR(cuMemcpy(dTensorMap,
514e56d6745SGuray Ozen                                 reinterpret_cast<CUdeviceptr>(&tensorMap),
515e56d6745SGuray Ozen                                 sizeof(CUtensorMap)));
516e56d6745SGuray Ozen   return reinterpret_cast<void *>(dTensorMap);
517e56d6745SGuray Ozen }
5181dc00712SGuray Ozen #endif
519e56d6745SGuray Ozen 
52050db4789SAart Bik #ifdef MLIR_ENABLE_CUDA_CUSPARSE
5218ed59c53SKun Wu 
522b700a90cSAart Bik ///
523b700a90cSAart Bik /// Wrapper methods for the cuSparse library.
524b700a90cSAart Bik ///
525b700a90cSAart Bik 
5264ebd836dSAart Bik // Some macro magic to get float/double alpha and beta on host.
527dfe29429SKun Wu // TODO: add support to passing alpha and beta as arguments
528cc402de0SKun Wu #define ALPHABETA(dtp, alpha, beta)                                            \
529be6c5320SKun Wu   __nv_bfloat16(alpha##16bf) = 1.0f;                                           \
530be6c5320SKun Wu   __nv_bfloat16(beta##16bf) = 1.0f;                                            \
531be6c5320SKun Wu   __half(alpha##16f) = 1.0f;                                                   \
532be6c5320SKun Wu   __half(beta##16f) = 1.0f;                                                    \
533bcb698bfSAart Bik   float(alpha##f) = 1.0f;                                                      \
534bcb698bfSAart Bik   float(beta##f) = 1.0f;                                                       \
535bcb698bfSAart Bik   double(alpha##d) = 1.0;                                                      \
536bcb698bfSAart Bik   double(beta##d) = 1.0;                                                       \
537bcb698bfSAart Bik   const void *(alpha##p) = nullptr;                                            \
538bcb698bfSAart Bik   const void *(beta##p) = nullptr;                                             \
539cc402de0SKun Wu   if (dtp == CUDA_R_16BF || dtp == CUDA_C_16BF) {                              \
540cc402de0SKun Wu     (alpha##p) = reinterpret_cast<void *>(&(alpha##16bf));                     \
541cc402de0SKun Wu     (beta##p) = reinterpret_cast<void *>(&(beta##16bf));                       \
542cc402de0SKun Wu   } else if (dtp == CUDA_R_16F || dtp == CUDA_C_16F) {                         \
543cc402de0SKun Wu     (alpha##p) = reinterpret_cast<void *>(&(alpha##16f));                      \
544cc402de0SKun Wu     (beta##p) = reinterpret_cast<void *>(&(beta##16f));                        \
545cc402de0SKun Wu   } else if (dtp == CUDA_R_32F || dtp == CUDA_C_32F) {                         \
5464ebd836dSAart Bik     (alpha##p) = reinterpret_cast<void *>(&(alpha##f));                        \
5474ebd836dSAart Bik     (beta##p) = reinterpret_cast<void *>(&(beta##f));                          \
548be6c5320SKun Wu   } else {                                                                     \
5494ebd836dSAart Bik     (alpha##p) = reinterpret_cast<void *>(&(alpha##d));                        \
5504ebd836dSAart Bik     (beta##p) = reinterpret_cast<void *>(&(beta##d));                          \
5514ebd836dSAart Bik   }
5524ebd836dSAart Bik 
553be2dd22bSKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCreateSparseEnv() {
554be2dd22bSKun Wu   // ScopedContext is for cuda initialization.
555be2dd22bSKun Wu   ScopedContext scopedContext;
556be2dd22bSKun Wu   assert(!cusparse_env && "client called mgpuCreateSparseEnv() twice");
557be2dd22bSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseCreate(&cusparse_env));
558b700a90cSAart Bik }
559b700a90cSAart Bik 
560be2dd22bSKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuDestroySparseEnv() {
561be2dd22bSKun Wu   assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
562be2dd22bSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseDestroy(cusparse_env));
563be2dd22bSKun Wu   cusparse_env = nullptr;
564b700a90cSAart Bik }
565b700a90cSAart Bik 
566b700a90cSAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
567cc402de0SKun Wu mgpuCreateDnVec(intptr_t size, void *values, int32_t dtp, CUstream /*stream*/) {
568b700a90cSAart Bik   cusparseDnVecDescr_t vec = nullptr;
569cc402de0SKun Wu   auto dTp = static_cast<cudaDataType_t>(dtp);
570cc402de0SKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnVec(&vec, size, values, dTp))
571b700a90cSAart Bik   return reinterpret_cast<void *>(vec);
572b700a90cSAart Bik }
573b700a90cSAart Bik 
574b700a90cSAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
575b700a90cSAart Bik mgpuDestroyDnVec(void *v, CUstream /*stream*/) {
576b700a90cSAart Bik   cusparseDnVecDescr_t vec = reinterpret_cast<cusparseDnVecDescr_t>(v);
577b700a90cSAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseDestroyDnVec(vec))
578b700a90cSAart Bik }
579b700a90cSAart Bik 
580b700a90cSAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
581cc402de0SKun Wu mgpuCreateDnMat(intptr_t rows, intptr_t cols, void *values, int32_t dtp,
582b700a90cSAart Bik                 CUstream /*stream*/) {
583b700a90cSAart Bik   cusparseDnMatDescr_t mat = nullptr;
584cc402de0SKun Wu   auto dTp = static_cast<cudaDataType_t>(dtp);
585b700a90cSAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnMat(&mat, rows, cols, /*ld=*/cols,
586cc402de0SKun Wu                                                values, dTp, CUSPARSE_ORDER_ROW))
587b700a90cSAart Bik   return reinterpret_cast<void *>(mat);
588b700a90cSAart Bik }
589b700a90cSAart Bik 
590b700a90cSAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
591b700a90cSAart Bik mgpuDestroyDnMat(void *m, CUstream /*stream*/) {
592b700a90cSAart Bik   cusparseDnMatDescr_t mat = reinterpret_cast<cusparseDnMatDescr_t>(m);
593b700a90cSAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseDestroyDnMat(mat))
594b700a90cSAart Bik }
595b700a90cSAart Bik 
596b700a90cSAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
597b700a90cSAart Bik mgpuCreateCoo(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowIdxs,
598cc402de0SKun Wu               void *colIdxs, void *values, int32_t itp, int32_t dtp,
599b700a90cSAart Bik               CUstream /*stream*/) {
600b700a90cSAart Bik   cusparseSpMatDescr_t mat = nullptr;
601cc402de0SKun Wu   auto iTp = static_cast<cusparseIndexType_t>(itp);
602cc402de0SKun Wu   auto dTp = static_cast<cudaDataType_t>(dtp);
603b700a90cSAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseCreateCoo(&mat, rows, cols, nnz, rowIdxs,
604cc402de0SKun Wu                                              colIdxs, values, iTp,
605cc402de0SKun Wu                                              CUSPARSE_INDEX_BASE_ZERO, dTp))
606b700a90cSAart Bik   return reinterpret_cast<void *>(mat);
607b700a90cSAart Bik }
608b700a90cSAart Bik 
6099fc02a7aSAart Bik #ifdef CUSPARSE_COO_AOS // deprecated in cuSPARSE 11.2
6109fc02a7aSAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
6119fc02a7aSAart Bik mgpuCreateCooAoS(intptr_t rows, intptr_t cols, intptr_t nnz, void *idxs,
6129fc02a7aSAart Bik                  void *values, int32_t itp, int32_t dtp, CUstream /*stream*/) {
6139fc02a7aSAart Bik   cusparseSpMatDescr_t mat = nullptr;
6149fc02a7aSAart Bik   auto iTp = static_cast<cusparseIndexType_t>(itp);
6159fc02a7aSAart Bik   auto dTp = static_cast<cudaDataType_t>(dtp);
6169fc02a7aSAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseCreateCooAoS(
6179fc02a7aSAart Bik       &mat, rows, cols, nnz, idxs, values, iTp, CUSPARSE_INDEX_BASE_ZERO, dTp))
6189fc02a7aSAart Bik   return reinterpret_cast<void *>(mat);
6199fc02a7aSAart Bik }
6209fc02a7aSAart Bik #endif // CUSPARSE_COO_AOS
6219fc02a7aSAart Bik 
622b700a90cSAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
623b700a90cSAart Bik mgpuCreateCsr(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowPos,
624cc402de0SKun Wu               void *colIdxs, void *values, int32_t ptp, int32_t itp,
625cc402de0SKun Wu               int32_t dtp, CUstream /*stream*/) {
626b700a90cSAart Bik   cusparseSpMatDescr_t mat = nullptr;
627cc402de0SKun Wu   auto pTp = static_cast<cusparseIndexType_t>(ptp);
628cc402de0SKun Wu   auto iTp = static_cast<cusparseIndexType_t>(itp);
629be6c5320SKun Wu   auto dTp = static_cast<cudaDataType_t>(dtp);
630b700a90cSAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsr(&mat, rows, cols, nnz, rowPos,
631cc402de0SKun Wu                                              colIdxs, values, pTp, iTp,
6327e44f073SKun Wu                                              CUSPARSE_INDEX_BASE_ZERO, dTp))
633b700a90cSAart Bik   return reinterpret_cast<void *>(mat);
634b700a90cSAart Bik }
635b700a90cSAart Bik 
63639038177SAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
63739038177SAart Bik mgpuCreateCsc(intptr_t rows, intptr_t cols, intptr_t nnz, void *colPos,
63839038177SAart Bik               void *rowIdxs, void *values, int32_t ptp, int32_t itp,
63939038177SAart Bik               int32_t dtp, CUstream /*stream*/) {
64039038177SAart Bik   cusparseSpMatDescr_t mat = nullptr;
64139038177SAart Bik   auto pTp = static_cast<cusparseIndexType_t>(ptp);
64239038177SAart Bik   auto iTp = static_cast<cusparseIndexType_t>(itp);
64339038177SAart Bik   auto dTp = static_cast<cudaDataType_t>(dtp);
64439038177SAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsc(&mat, rows, cols, nnz, colPos,
64539038177SAart Bik                                              rowIdxs, values, pTp, iTp,
64639038177SAart Bik                                              CUSPARSE_INDEX_BASE_ZERO, dTp))
64739038177SAart Bik   return reinterpret_cast<void *>(mat);
64839038177SAart Bik }
64939038177SAart Bik 
65039038177SAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
65139038177SAart Bik mgpuCreateBsr(intptr_t brows, intptr_t bcols, intptr_t bnnz, intptr_t rBsz,
65239038177SAart Bik               intptr_t cBsz, void *rowPos, void *colIdxs, void *values,
65339038177SAart Bik               int32_t ptp, int32_t itp, int32_t dtp, CUstream /*stream*/) {
65439038177SAart Bik   cusparseSpMatDescr_t mat = nullptr;
6557ac330a4SAart Bik #if CUSPARSE_VERSION >= 12100
65639038177SAart Bik   auto pTp = static_cast<cusparseIndexType_t>(ptp);
65739038177SAart Bik   auto iTp = static_cast<cusparseIndexType_t>(itp);
65839038177SAart Bik   auto dTp = static_cast<cudaDataType_t>(dtp);
65939038177SAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseCreateBsr(
66039038177SAart Bik       &mat, brows, bcols, bnnz, rBsz, cBsz, rowPos, colIdxs, values, pTp, iTp,
66139038177SAart Bik       CUSPARSE_INDEX_BASE_ZERO, dTp, CUSPARSE_ORDER_ROW))
6627ac330a4SAart Bik #endif
66339038177SAart Bik   return reinterpret_cast<void *>(mat);
66439038177SAart Bik }
66539038177SAart Bik 
666b700a90cSAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
667b700a90cSAart Bik mgpuDestroySpMat(void *m, CUstream /*stream*/) {
668b700a90cSAart Bik   cusparseSpMatDescr_t mat = reinterpret_cast<cusparseSpMatDescr_t>(m);
669b700a90cSAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseDestroySpMat(mat))
670b700a90cSAart Bik }
671b700a90cSAart Bik 
672be2dd22bSKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMVBufferSize(
673be2dd22bSKun Wu     int32_t ma, void *a, void *x, void *y, int32_t ctp, CUstream /*stream*/) {
674be2dd22bSKun Wu   assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
675235fbe79SKun Wu   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
676b700a90cSAart Bik   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
677b700a90cSAart Bik   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
678b700a90cSAart Bik   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
679cc402de0SKun Wu   cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
680cc402de0SKun Wu   ALPHABETA(cTp, alpha, beta)
681b700a90cSAart Bik   size_t bufferSize = 0;
682be2dd22bSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseSpMV_bufferSize(
683be2dd22bSKun Wu       cusparse_env, modeA, alphap, matA, vecX, betap, vecY, cTp,
684be2dd22bSKun Wu       CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
6853635c743SAart Bik   return bufferSize;
686b700a90cSAart Bik }
687b700a90cSAart Bik 
688be2dd22bSKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(int32_t ma, void *a, void *x,
689be2dd22bSKun Wu                                                    void *y, int32_t ctp,
690be2dd22bSKun Wu                                                    void *buf,
691a8e1f80fSAart Bik                                                    CUstream /*stream*/) {
692be2dd22bSKun Wu   assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
693235fbe79SKun Wu   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
694b700a90cSAart Bik   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
695b700a90cSAart Bik   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
696b700a90cSAart Bik   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
697cc402de0SKun Wu   cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
698cc402de0SKun Wu   ALPHABETA(cTp, alpha, beta)
699be2dd22bSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(cusparse_env, modeA, alphap, matA, vecX,
700cc402de0SKun Wu                                         betap, vecY, cTp,
701235fbe79SKun Wu                                         CUSPARSE_SPMV_ALG_DEFAULT, buf))
702981cf167SAart Bik }
703981cf167SAart Bik 
704235fbe79SKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
705be2dd22bSKun Wu mgpuSpMMBufferSize(int32_t ma, int32_t mb, void *a, void *b, void *c,
706cc402de0SKun Wu                    int32_t ctp, CUstream /*stream*/) {
707be2dd22bSKun Wu   assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
708235fbe79SKun Wu   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
709235fbe79SKun Wu   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
710981cf167SAart Bik   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
711981cf167SAart Bik   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
712981cf167SAart Bik   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
713cc402de0SKun Wu   cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
714cc402de0SKun Wu   ALPHABETA(cTp, alpha, beta)
715981cf167SAart Bik   size_t bufferSize = 0;
716981cf167SAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
717be2dd22bSKun Wu       cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
718a8e1f80fSAart Bik       CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
7193635c743SAart Bik   return bufferSize;
720981cf167SAart Bik }
721981cf167SAart Bik 
722be2dd22bSKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMM(int32_t ma, int32_t mb,
723be2dd22bSKun Wu                                                    void *a, void *b, void *c,
724be2dd22bSKun Wu                                                    int32_t ctp, void *buf,
725be2dd22bSKun Wu                                                    CUstream /*stream*/) {
726be2dd22bSKun Wu   assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
727235fbe79SKun Wu   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
728235fbe79SKun Wu   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
729981cf167SAart Bik   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
730981cf167SAart Bik   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
731981cf167SAart Bik   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
732cc402de0SKun Wu   cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
733cc402de0SKun Wu   ALPHABETA(cTp, alpha, beta)
734be2dd22bSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(cusparse_env, modeA, modeB, alphap,
735be2dd22bSKun Wu                                         matA, matB, betap, matC, cTp,
736235fbe79SKun Wu                                         CUSPARSE_SPMM_ALG_DEFAULT, buf))
737b700a90cSAart Bik }
738cf44847bSKun Wu 
739cf44847bSKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
740be2dd22bSKun Wu mgpuSDDMMBufferSize(int32_t ma, int32_t mb, void *a, void *b, void *c,
741cc402de0SKun Wu                     int32_t ctp, CUstream /*stream*/) {
742be2dd22bSKun Wu   assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
743cf44847bSKun Wu   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
744cf44847bSKun Wu   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
745cf44847bSKun Wu   cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
746cf44847bSKun Wu   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
747cf44847bSKun Wu   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
748cc402de0SKun Wu   auto cTp = static_cast<cudaDataType_t>(ctp);
749cc402de0SKun Wu   ALPHABETA(cTp, alpha, beta)
750cf44847bSKun Wu   size_t bufferSize = 0;
751cf44847bSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM_bufferSize(
752be2dd22bSKun Wu       cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
753cf44847bSKun Wu       CUSPARSE_SDDMM_ALG_DEFAULT, &bufferSize))
7543635c743SAart Bik   return bufferSize;
755cf44847bSKun Wu }
756cf44847bSKun Wu 
757be2dd22bSKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSDDMM(int32_t ma, int32_t mb,
758be2dd22bSKun Wu                                                     void *a, void *b, void *c,
759be2dd22bSKun Wu                                                     int32_t ctp, void *buf,
760be2dd22bSKun Wu                                                     CUstream /*stream*/) {
761be2dd22bSKun Wu   assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
762cf44847bSKun Wu   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
763cf44847bSKun Wu   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
764cf44847bSKun Wu   cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
765cf44847bSKun Wu   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
766cf44847bSKun Wu   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
767cc402de0SKun Wu   auto cTp = static_cast<cudaDataType_t>(ctp);
768cc402de0SKun Wu   ALPHABETA(cTp, alpha, beta)
769be2dd22bSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM(cusparse_env, modeA, modeB, alphap,
770be2dd22bSKun Wu                                          matA, matB, betap, matC, cTp,
771cf44847bSKun Wu                                          CUSPARSE_SDDMM_ALG_DEFAULT, buf))
772cf44847bSKun Wu }
7738ed59c53SKun Wu 
774289f7231SAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
775289f7231SAart Bik mgpuSpGEMMCreateDescr(CUstream /*stream*/) {
776289f7231SAart Bik   cusparseSpGEMMDescr_t spgemmDesc = nullptr;
777289f7231SAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_createDescr(&spgemmDesc))
778289f7231SAart Bik   return reinterpret_cast<void *>(spgemmDesc);
779289f7231SAart Bik }
780289f7231SAart Bik 
781289f7231SAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
782289f7231SAart Bik mgpuSpGEMMDestroyDescr(void *s, CUstream /*stream*/) {
783289f7231SAart Bik   cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
784289f7231SAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_destroyDescr(spgemmDesc))
785289f7231SAart Bik }
786289f7231SAart Bik 
787dfe29429SKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMWorkEstimation(
788dfe29429SKun Wu     void *s, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t ctp,
789e7e4ed0dSAart Bik     intptr_t bs, void *buf, CUstream /*stream*/) {
790dfe29429SKun Wu   cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
791dfe29429SKun Wu   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
792dfe29429SKun Wu   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
793dfe29429SKun Wu   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
794dfe29429SKun Wu   cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
795dfe29429SKun Wu   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
796dfe29429SKun Wu   auto cTp = static_cast<cudaDataType_t>(ctp);
797dfe29429SKun Wu   ALPHABETA(cTp, alpha, beta)
798dfe29429SKun Wu   size_t newBufferSize = bs;
799dfe29429SKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_workEstimation(
800dfe29429SKun Wu       cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
801e7e4ed0dSAart Bik       CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize, buf))
8023635c743SAart Bik   return newBufferSize;
803dfe29429SKun Wu }
804dfe29429SKun Wu 
805e7e4ed0dSAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
806e7e4ed0dSAart Bik mgpuSpGEMMCompute(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
807e7e4ed0dSAart Bik                   int32_t ctp, intptr_t bsz2, void *buf2, CUstream /*stream*/) {
808dfe29429SKun Wu   cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
809dfe29429SKun Wu   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
810dfe29429SKun Wu   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
811dfe29429SKun Wu   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
812dfe29429SKun Wu   cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
813dfe29429SKun Wu   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
814dfe29429SKun Wu   auto cTp = static_cast<cudaDataType_t>(ctp);
815dfe29429SKun Wu   ALPHABETA(cTp, alpha, beta)
816dfe29429SKun Wu   size_t newBufferSize2 = bsz2;
817dfe29429SKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_compute(
818dfe29429SKun Wu       cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
819e7e4ed0dSAart Bik       CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize2, buf2))
8203635c743SAart Bik   return newBufferSize2;
821dfe29429SKun Wu }
822dfe29429SKun Wu 
823dfe29429SKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
824dfe29429SKun Wu mgpuSpGEMMCopy(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
825e7e4ed0dSAart Bik                int32_t ctp, CUstream /*stream*/) {
826dfe29429SKun Wu   cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
827dfe29429SKun Wu   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
828dfe29429SKun Wu   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
829dfe29429SKun Wu   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
830dfe29429SKun Wu   cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
831dfe29429SKun Wu   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
832dfe29429SKun Wu   auto cTp = static_cast<cudaDataType_t>(ctp);
833dfe29429SKun Wu   ALPHABETA(cTp, alpha, beta)
834e7e4ed0dSAart Bik   CUSPARSE_REPORT_IF_ERROR(
835e7e4ed0dSAart Bik       cusparseSpGEMM_copy(cusparse_env, modeA, modeB, alphap, matA, matB, betap,
836e7e4ed0dSAart Bik                           matC, cTp, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc))
837dfe29429SKun Wu }
838dfe29429SKun Wu 
839dfe29429SKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
840289f7231SAart Bik mgpuSpMatGetSize(void *m, void *r, void *c, void *n, CUstream /*stream*/) {
841dfe29429SKun Wu   cusparseConstSpMatDescr_t matDescr =
842dfe29429SKun Wu       reinterpret_cast<cusparseConstSpMatDescr_t>(m);
843dfe29429SKun Wu   int64_t *rows = reinterpret_cast<int64_t *>(r);
844dfe29429SKun Wu   int64_t *cols = reinterpret_cast<int64_t *>(c);
845dfe29429SKun Wu   int64_t *nnz = reinterpret_cast<int64_t *>(n);
846dfe29429SKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseSpMatGetSize(matDescr, rows, cols, nnz));
847dfe29429SKun Wu }
848dfe29429SKun Wu 
84995a6c509SAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
85095a6c509SAart Bik mgpuSetCsrPointers(void *m, void *p, void *c, void *v, CUstream /*stream*/) {
85195a6c509SAart Bik   cusparseSpMatDescr_t matDescr = reinterpret_cast<cusparseSpMatDescr_t>(m);
85295a6c509SAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseCsrSetPointers(matDescr, p, c, v));
85395a6c509SAart Bik }
85495a6c509SAart Bik 
85550db4789SAart Bik #ifdef MLIR_ENABLE_CUDA_CUSPARSELT
8568ed59c53SKun Wu 
8578ed59c53SKun Wu ///
8588ed59c53SKun Wu /// Wrapper methods for the cuSparseLt library.
8598ed59c53SKun Wu ///
8608ed59c53SKun Wu 
8618ed59c53SKun Wu struct cusparseLtSpMatHandleAndData {
8628ed59c53SKun Wu   cusparseLtMatDescriptor_t mat;
863ac30f48eSKun Wu   // TODO: the following three are associated with the SpMM operator rather than
864ac30f48eSKun Wu   // the sparse matrix. Create workspace buffers and pass them to the SpMM
8658ed59c53SKun Wu   // execution.
8668ed59c53SKun Wu   cusparseLtMatmulAlgSelection_t alg_sel;
8678ed59c53SKun Wu   cusparseLtMatmulPlan_t plan;
8688ed59c53SKun Wu   cusparseLtMatmulDescriptor_t matmul;
869ac30f48eSKun Wu   void *values{nullptr};
8708ed59c53SKun Wu };
8718ed59c53SKun Wu 
8728ed59c53SKun Wu struct cusparseLtDnMatHandleAndData {
8738ed59c53SKun Wu   cusparseLtMatDescriptor_t mat;
8748ed59c53SKun Wu   void *values{nullptr};
8758ed59c53SKun Wu };
8768ed59c53SKun Wu 
8777a3ebba9SKun Wu static_assert(sizeof(cusparseLtHandle_t) == 11024,
8787a3ebba9SKun Wu               "Unexpected cusparseLt handle size");
8797a3ebba9SKun Wu static_assert(sizeof(cusparseLtSpMatHandleAndData) == 44104,
8807a3ebba9SKun Wu               "Unexpected cusparseLt sparse matrix handle size");
8817a3ebba9SKun Wu static_assert(sizeof(cusparseLtDnMatHandleAndData) == 11032,
8827a3ebba9SKun Wu               "Unexpected cusparseLt dense matrix handle size");
8838ed59c53SKun Wu 
884be2dd22bSKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCreateSparseLtEnv() {
885be2dd22bSKun Wu   // ScopedContext is for cuda initialization.
886be2dd22bSKun Wu   ScopedContext scopedContext;
887be2dd22bSKun Wu   assert(!cusparseLt_initiated &&
888be2dd22bSKun Wu          "client called mgpuCreateSparseLtEnv() twice");
88986eff489SAart Bik   // Note that cuSparseLt still uses cusparseStatus_t.
890be2dd22bSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseLtInit(&cusparseLt_env));
891be2dd22bSKun Wu   cusparseLt_initiated = true;
892be2dd22bSKun Wu }
893be2dd22bSKun Wu 
894be2dd22bSKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuDestroySparseLtEnv() {
895be2dd22bSKun Wu   assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
896be2dd22bSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseLtDestroy(&cusparseLt_env));
897be2dd22bSKun Wu   cusparseLt_initiated = false;
8988ed59c53SKun Wu }
8998ed59c53SKun Wu 
9008ed59c53SKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
901be2dd22bSKun Wu mgpuCreateCuSparseLtDnMat(void *dh, intptr_t rows, intptr_t cols, void *values,
902be2dd22bSKun Wu                           int32_t dtp, CUstream /*stream*/) {
903be2dd22bSKun Wu   assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
9048ed59c53SKun Wu   auto dnmat_handle = reinterpret_cast<cusparseLtDnMatHandleAndData *>(dh);
90586eff489SAart Bik   dnmat_handle->values = values;
906ac30f48eSKun Wu   auto dTp = static_cast<cudaDataType_t>(dtp);
90786eff489SAart Bik   // Assume row-major when deciding lda.
90886eff489SAart Bik   const uint32_t alignment = 16;
9098ed59c53SKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseLtDenseDescriptorInit(
910be2dd22bSKun Wu       &cusparseLt_env, &(dnmat_handle->mat), rows, cols, /*lda=*/cols,
91186eff489SAart Bik       alignment, dTp, CUSPARSE_ORDER_ROW))
9128ed59c53SKun Wu }
9138ed59c53SKun Wu 
9148ed59c53SKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
91586eff489SAart Bik mgpuDestroyCuSparseLtDnMat(void *dh, CUstream /*stream*/) {
91686eff489SAart Bik   auto dnmat_handle = reinterpret_cast<cusparseLtDnMatHandleAndData *>(dh);
91786eff489SAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(dnmat_handle->mat)))
9188ed59c53SKun Wu }
9198ed59c53SKun Wu 
9208ed59c53SKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
921be2dd22bSKun Wu mgpuCusparseLtCreate2To4SpMat(void *sh, intptr_t rows, intptr_t cols,
922ac30f48eSKun Wu                               void *values, int32_t dtp, CUstream /*stream*/) {
923be2dd22bSKun Wu   assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
9248ed59c53SKun Wu   auto spmat_handle = reinterpret_cast<cusparseLtSpMatHandleAndData *>(sh);
9258ed59c53SKun Wu   spmat_handle->values = values;
926ac30f48eSKun Wu   auto dTp = static_cast<cudaDataType_t>(dtp);
92786eff489SAart Bik   // Assume row-major when deciding lda.
92886eff489SAart Bik   const uint32_t alignment = 16;
9298ed59c53SKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseLtStructuredDescriptorInit(
93086eff489SAart Bik       &cusparseLt_env, &(spmat_handle->mat), rows, cols, /*ld=*/cols, alignment,
93186eff489SAart Bik       dTp, CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT))
93286eff489SAart Bik }
93386eff489SAart Bik 
93486eff489SAart Bik extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
93586eff489SAart Bik mgpuDestroyCuSparseLtSpMat(void *sh, CUstream /*stream*/) {
93686eff489SAart Bik   auto spmat_handle = reinterpret_cast<cusparseLtSpMatHandleAndData *>(sh);
93786eff489SAart Bik   CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(spmat_handle->mat)))
9388ed59c53SKun Wu }
9398ed59c53SKun Wu 
9408ed59c53SKun Wu // Several things are being done in this stage, algorithm selection, planning,
9418ed59c53SKun Wu // and returning workspace and compressed matrices data buffer sizes.
9421e491c42SKun Wu // The parameter prune_flag is used to indicate whether pruning and pruning
9431e491c42SKun Wu // check will happen 0 means not prune or prune check, 1 means prune, 2 means
9441e491c42SKun Wu // prune & prune check
9458ed59c53SKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
946be2dd22bSKun Wu mgpuCuSparseLtSpMMBufferSize(void *bs, int32_t ma, int32_t mb, void *a, void *b,
9471e491c42SKun Wu                              void *c, int32_t ctp, int32_t prune_flag,
9481e491c42SKun Wu                              CUstream stream) {
949be2dd22bSKun Wu   assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
9508ed59c53SKun Wu   // TODO: support more advanced settings, e.g., the input right operand is a
9518ed59c53SKun Wu   // sparse matrix assuming matA is the sparse matrix
9528ed59c53SKun Wu   auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
953ac30f48eSKun Wu   auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
954ac30f48eSKun Wu   auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);
9558998bcfbSAart Bik   auto workspace_size = reinterpret_cast<size_t *>(bs);
9568998bcfbSAart Bik   auto compressed_size = &(reinterpret_cast<size_t *>(bs)[1]);
9578998bcfbSAart Bik   auto compressed_buffer_size = &(reinterpret_cast<size_t *>(bs)[2]);
958ac30f48eSKun Wu   auto cTp = static_cast<cusparseComputeType>(ctp);
9598ed59c53SKun Wu 
960ac30f48eSKun Wu   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
961ac30f48eSKun Wu   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
962ac30f48eSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulDescriptorInit(
963be2dd22bSKun Wu       &cusparseLt_env, &(matA->matmul), modeA, modeB, &(matA->mat),
964be2dd22bSKun Wu       &(matB->mat), &(matC->mat), &(matC->mat), cTp))
965ac30f48eSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSelectionInit(
966be2dd22bSKun Wu       &cusparseLt_env, &(matA->alg_sel), &(matA->matmul),
967be2dd22bSKun Wu       CUSPARSELT_MATMUL_ALG_DEFAULT))
9688ed59c53SKun Wu   int alg = 0;
969ac30f48eSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSetAttribute(
970be2dd22bSKun Wu       &cusparseLt_env, &(matA->alg_sel), CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg,
9718ed59c53SKun Wu       sizeof(alg)))
9728ed59c53SKun Wu 
973ac30f48eSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanInit(
974be2dd22bSKun Wu       &cusparseLt_env, &(matA->plan), &(matA->matmul), &(matA->alg_sel)))
975ac30f48eSKun Wu 
9764df01dc2SAart Bik   // Pruning step (in-place).
9771e491c42SKun Wu   if (prune_flag > 0)
9781e491c42SKun Wu     CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPrune(
9791e491c42SKun Wu         &cusparseLt_env, &(matA->matmul), matA->values, matA->values,
9801e491c42SKun Wu         CUSPARSELT_PRUNE_SPMMA_STRIP, stream))
9814df01dc2SAart Bik 
9824df01dc2SAart Bik   // Check structure of A.
9834df01dc2SAart Bik   // Note that this adds a synchronization on the stream.
9844df01dc2SAart Bik   // TODO: Do we want that?
9851e491c42SKun Wu   if (prune_flag == 2) {
98641a07e66SAart Bik     int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream, false);
9874df01dc2SAart Bik     CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPruneCheck(
9884df01dc2SAart Bik         &cusparseLt_env, &(matA->matmul), matA->values, dvalid, stream))
9894df01dc2SAart Bik     int valid = 0;
9904df01dc2SAart Bik     mgpuMemcpy(&valid, dvalid, sizeof(int), stream);
9914df01dc2SAart Bik     mgpuStreamSynchronize(stream);
9924df01dc2SAart Bik     mgpuMemFree(dvalid, stream);
9934df01dc2SAart Bik     if (valid != 0)
9944df01dc2SAart Bik       fprintf(stderr, "CUPARSE-LT: sparse matrix is not 2:4; computed results "
9954df01dc2SAart Bik                       "will be invalid\n");
9961e491c42SKun Wu   }
9974df01dc2SAart Bik 
998be2dd22bSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulGetWorkspace(
9993635c743SAart Bik       &cusparseLt_env, &(matA->plan), workspace_size))
1000ac30f48eSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMACompressedSize(
10013635c743SAart Bik       &cusparseLt_env, &(matA->plan), compressed_size, compressed_buffer_size))
10028ed59c53SKun Wu }
10038ed59c53SKun Wu 
10048ed59c53SKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
1005be2dd22bSKun Wu mgpuCuSparseLtSpMM(void *a, void *b, void *c, void *d_workspace,
1006ac30f48eSKun Wu                    void *dA_compressed, void *dA_compressedBuffer,
1007ac30f48eSKun Wu                    CUstream stream) {
1008be2dd22bSKun Wu   assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
10098ed59c53SKun Wu   auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
10108ed59c53SKun Wu   auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
10118ed59c53SKun Wu   auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);
10128ed59c53SKun Wu 
1013ac30f48eSKun Wu   ALPHABETA(CUDA_R_32F, alpha, beta)
1014ac30f48eSKun Wu   CUSPARSE_REPORT_IF_ERROR(
1015be2dd22bSKun Wu       cusparseLtSpMMACompress(&cusparseLt_env, &(matA->plan), (matA->values),
1016ac30f48eSKun Wu                               dA_compressed, dA_compressedBuffer, stream))
10178ed59c53SKun Wu 
10188ed59c53SKun Wu   // TODO: add support to multi-stream execution
10198ed59c53SKun Wu   // Perform the matrix multiplication. D = A*B+C using C==D for now
1020ac30f48eSKun Wu   CUSPARSE_REPORT_IF_ERROR(
1021be2dd22bSKun Wu       cusparseLtMatmul(&cusparseLt_env, &(matA->plan), alphap, dA_compressed,
1022ac30f48eSKun Wu                        matB->values, betap, matC->values,
1023ac30f48eSKun Wu                        /*dD*/ matC->values, d_workspace, nullptr, 0))
10248ed59c53SKun Wu 
1025ac30f48eSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matA->mat)))
10268ed59c53SKun Wu   // destroy the plan associated with the sparse matrix
1027ac30f48eSKun Wu   CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanDestroy(&(matA->plan)))
10288ed59c53SKun Wu }
10298ed59c53SKun Wu 
10308ed59c53SKun Wu #endif // MLIR_ENABLE_CUDA_CUSPARSELT
10318ed59c53SKun Wu #endif // MLIR_ENABLE_CUDA_CUSPARSE
1032