19d7be77bSChristian Sigg //===- CudaRuntimeWrappers.cpp - MLIR CUDA API wrapper library ------------===// 29d7be77bSChristian Sigg // 39d7be77bSChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 49d7be77bSChristian Sigg // See https://llvm.org/LICENSE.txt for license information. 59d7be77bSChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 69d7be77bSChristian Sigg // 79d7be77bSChristian Sigg //===----------------------------------------------------------------------===// 89d7be77bSChristian Sigg // 99d7be77bSChristian Sigg // Implements C wrappers around the CUDA library for easy linking in ORC jit. 109d7be77bSChristian Sigg // Also adds some debugging helpers that are helpful when writing MLIR code to 119d7be77bSChristian Sigg // run on GPUs. 129d7be77bSChristian Sigg // 139d7be77bSChristian Sigg //===----------------------------------------------------------------------===// 149d7be77bSChristian Sigg 159d7be77bSChristian Sigg #include "mlir/ExecutionEngine/CRunnerUtils.h" 169d7be77bSChristian Sigg 17b9f87e24SAart Bik #include <stdio.h> 18b9f87e24SAart Bik 199d7be77bSChristian Sigg #include "cuda.h" 20cc402de0SKun Wu #include "cuda_bf16.h" 21cc402de0SKun Wu #include "cuda_fp16.h" 228ed59c53SKun Wu 2350db4789SAart Bik #ifdef MLIR_ENABLE_CUDA_CUSPARSE 24b700a90cSAart Bik #include "cusparse.h" 2550db4789SAart Bik #ifdef MLIR_ENABLE_CUDA_CUSPARSELT 268ed59c53SKun Wu #include "cusparseLt.h" 278ed59c53SKun Wu #endif // MLIR_ENABLE_CUDA_CUSPARSELT 288ed59c53SKun Wu #endif // MLIR_ENABLE_CUDA_CUSPARSE 299d7be77bSChristian Sigg 309d7be77bSChristian Sigg #ifdef _WIN32 315e78417dSJustin Holewinski #include <malloc.h> 329d7be77bSChristian Sigg #define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport) 339d7be77bSChristian Sigg #else 341c2a0768SAdam Paszke #define MLIR_CUDA_WRAPPERS_EXPORT __attribute__((visibility("default"))) 359d7be77bSChristian Sigg #endif // _WIN32 369d7be77bSChristian Sigg 
379d7be77bSChristian Sigg #define CUDA_REPORT_IF_ERROR(expr) \ 389d7be77bSChristian Sigg [](CUresult result) { \ 399d7be77bSChristian Sigg if (!result) \ 409d7be77bSChristian Sigg return; \ 419d7be77bSChristian Sigg const char *name = nullptr; \ 429d7be77bSChristian Sigg cuGetErrorName(result, &name); \ 439d7be77bSChristian Sigg if (!name) \ 449d7be77bSChristian Sigg name = "<unknown>"; \ 459d7be77bSChristian Sigg fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ 469d7be77bSChristian Sigg }(expr) 479d7be77bSChristian Sigg 48b700a90cSAart Bik #define CUSPARSE_REPORT_IF_ERROR(expr) \ 49b700a90cSAart Bik { \ 50b700a90cSAart Bik cusparseStatus_t status = (expr); \ 51b700a90cSAart Bik if (status != CUSPARSE_STATUS_SUCCESS) { \ 52b700a90cSAart Bik fprintf(stderr, "cuSPARSE '%s' failed with '%s'\n", #expr, \ 53b700a90cSAart Bik cusparseGetErrorString(status)); \ 54b700a90cSAart Bik } \ 55b700a90cSAart Bik } 56b700a90cSAart Bik 5784718d37SKrzysztof Drewniak thread_local static int32_t defaultDevice = 0; 5884718d37SKrzysztof Drewniak 5919b11079SGuray Ozen const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG"; 6019b11079SGuray Ozen 6119b11079SGuray Ozen /// Helper method that checks environment value for debugging. 6219b11079SGuray Ozen bool isDebugEnabled() { 6319b11079SGuray Ozen static bool isInitialized = false; 6419b11079SGuray Ozen static bool isEnabled = false; 6519b11079SGuray Ozen if (!isInitialized) 6619b11079SGuray Ozen isEnabled = getenv(kDebugEnvironmentVariable) != nullptr; 6719b11079SGuray Ozen return isEnabled; 6819b11079SGuray Ozen } 6919b11079SGuray Ozen 7019b11079SGuray Ozen #define debug_print(fmt, ...) 
\ 7119b11079SGuray Ozen do { \ 7219b11079SGuray Ozen if (isDebugEnabled()) \ 7319b11079SGuray Ozen fprintf(stderr, "%s:%d:%s(): " fmt, "CudaRuntimeWrappers.cpp", __LINE__, \ 7419b11079SGuray Ozen __func__, __VA_ARGS__); \ 7519b11079SGuray Ozen } while (0) 7619b11079SGuray Ozen 7753881490SGuray Ozen // Returns default CUdevice 7853881490SGuray Ozen CUdevice getDefaultCuDevice() { 7953881490SGuray Ozen CUdevice device; 8053881490SGuray Ozen CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice)); 8153881490SGuray Ozen return device; 8253881490SGuray Ozen } 8353881490SGuray Ozen 8484718d37SKrzysztof Drewniak // Make the primary context of the current default device current for the 8584718d37SKrzysztof Drewniak // duration 8684718d37SKrzysztof Drewniak // of the instance and restore the previous context on destruction. 879d7be77bSChristian Sigg class ScopedContext { 889d7be77bSChristian Sigg public: 899d7be77bSChristian Sigg ScopedContext() { 9084718d37SKrzysztof Drewniak // Static reference to CUDA primary context for device ordinal 9184718d37SKrzysztof Drewniak // defaultDevice. 92f69d5a7fSChristian Sigg static CUcontext context = [] { 93f69d5a7fSChristian Sigg CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0)); 94f69d5a7fSChristian Sigg CUcontext ctx; 95f69d5a7fSChristian Sigg // Note: this does not affect the current context. 
9653881490SGuray Ozen CUDA_REPORT_IF_ERROR( 9753881490SGuray Ozen cuDevicePrimaryCtxRetain(&ctx, getDefaultCuDevice())); 98f69d5a7fSChristian Sigg return ctx; 99f69d5a7fSChristian Sigg }(); 100f69d5a7fSChristian Sigg 101f69d5a7fSChristian Sigg CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context)); 1029d7be77bSChristian Sigg } 1039d7be77bSChristian Sigg 104f69d5a7fSChristian Sigg ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); } 1059d7be77bSChristian Sigg }; 1069d7be77bSChristian Sigg 10703125e68SAart Bik #ifdef MLIR_ENABLE_CUDA_CUSPARSE 108be2dd22bSKun Wu // Note that (1) Nvidia confirms the safety to share handle across multiple 109be2dd22bSKun Wu // instances, and streams. (2) Clients are responsible to call the @mgpu 110be2dd22bSKun Wu // environment initialization/destruction in a thread-safe manner, e.g., 111be2dd22bSKun Wu // at the beginning of the program before multi-threads are created. 112be2dd22bSKun Wu static cusparseHandle_t cusparse_env = nullptr; 113be2dd22bSKun Wu 114be2dd22bSKun Wu #ifdef MLIR_ENABLE_CUDA_CUSPARSELT 115be2dd22bSKun Wu // cusparseLtHandle_t is not a pointer type, so we need an additional flag to 116be2dd22bSKun Wu // indicate whether it is initialized. 
117be2dd22bSKun Wu static cusparseLtHandle_t cusparseLt_env; 118be2dd22bSKun Wu static bool cusparseLt_initiated = false; 119be2dd22bSKun Wu 120be2dd22bSKun Wu #endif // MLIR_ENABLE_CUDA_CUSPARSELT 121be2dd22bSKun Wu #endif // MLIR_ENABLE_CUDA_CUSPARSE 122be2dd22bSKun Wu 123ebfea261SNishant Patel extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule 124ebfea261SNishant Patel mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) { 1259d7be77bSChristian Sigg ScopedContext scopedContext; 1269d7be77bSChristian Sigg CUmodule module = nullptr; 1279d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data)); 1289d7be77bSChristian Sigg return module; 1299d7be77bSChristian Sigg } 1309d7be77bSChristian Sigg 1315093413aSFabian Mora extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoadJIT(void *data, 1325093413aSFabian Mora int optLevel) { 1335093413aSFabian Mora ScopedContext scopedContext; 1345093413aSFabian Mora CUmodule module = nullptr; 1355093413aSFabian Mora char jitErrorBuffer[4096] = {0}; 1365093413aSFabian Mora CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER, 1375093413aSFabian Mora CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, 1385093413aSFabian Mora CU_JIT_OPTIMIZATION_LEVEL}; 1395093413aSFabian Mora void *jitOptionsVals[] = {jitErrorBuffer, 1405093413aSFabian Mora reinterpret_cast<void *>(sizeof(jitErrorBuffer)), 1415093413aSFabian Mora reinterpret_cast<void *>(optLevel)}; 1425093413aSFabian Mora 1435093413aSFabian Mora CUresult result = 1445093413aSFabian Mora cuModuleLoadDataEx(&module, data, 3, jitOptions, jitOptionsVals); 1455093413aSFabian Mora if (result) { 1465093413aSFabian Mora fprintf(stderr, "JIT compilation failed with: '%s'\n", jitErrorBuffer); 1475093413aSFabian Mora CUDA_REPORT_IF_ERROR(result); 1485093413aSFabian Mora } 1495093413aSFabian Mora return module; 1505093413aSFabian Mora } 1515093413aSFabian Mora 1529d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) { 1539d7be77bSChristian Sigg 
CUDA_REPORT_IF_ERROR(cuModuleUnload(module)); 1549d7be77bSChristian Sigg } 1559d7be77bSChristian Sigg 1569d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction 1579d7be77bSChristian Sigg mgpuModuleGetFunction(CUmodule module, const char *name) { 1589d7be77bSChristian Sigg CUfunction function = nullptr; 1599d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name)); 1609d7be77bSChristian Sigg return function; 1619d7be77bSChristian Sigg } 1629d7be77bSChristian Sigg 1639d7be77bSChristian Sigg // The wrapper uses intptr_t instead of CUDA's unsigned int to match 1649d7be77bSChristian Sigg // the type of MLIR's index type. This avoids the need for casts in the 1659d7be77bSChristian Sigg // generated MLIR code. 1669d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 1679d7be77bSChristian Sigg mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY, 1689d7be77bSChristian Sigg intptr_t gridZ, intptr_t blockX, intptr_t blockY, 1699d7be77bSChristian Sigg intptr_t blockZ, int32_t smem, CUstream stream, void **params, 170ebfea261SNishant Patel void **extra, size_t /*paramsCount*/) { 1719d7be77bSChristian Sigg ScopedContext scopedContext; 1725ef45c02SGuray Ozen if (smem > 0) { 1735ef45c02SGuray Ozen // Avoid checking driver as it's more expensive than if statement 17453881490SGuray Ozen int32_t maxShmem = 0; 17553881490SGuray Ozen CUdevice device = getDefaultCuDevice(); 17653881490SGuray Ozen CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice)); 17753881490SGuray Ozen CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute( 17853881490SGuray Ozen &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, 17953881490SGuray Ozen device)); 18053881490SGuray Ozen if (maxShmem < smem) { 18153881490SGuray Ozen fprintf(stderr, 18253881490SGuray Ozen "Requested shared memory (%dkb) is larger than maximum allowed " 18353881490SGuray Ozen "shared memory (%dkb) for this device\n", 18453881490SGuray Ozen 
smem, maxShmem); 18553881490SGuray Ozen } 18653881490SGuray Ozen CUDA_REPORT_IF_ERROR(cuFuncSetAttribute( 18753881490SGuray Ozen function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem)); 1885ef45c02SGuray Ozen } 18953881490SGuray Ozen debug_print("Launching kernel, grid=%ld,%ld,%ld, " 19053881490SGuray Ozen "threads: %ld, %ld, %ld, " 19153881490SGuray Ozen "smem: %dkb\n", 19253881490SGuray Ozen gridX, gridY, gridZ, blockX, blockY, blockZ, smem); 1939d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX, 1949d7be77bSChristian Sigg blockY, blockZ, smem, stream, params, 1959d7be77bSChristian Sigg extra)); 1969d7be77bSChristian Sigg } 1979d7be77bSChristian Sigg 1989d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() { 1999d7be77bSChristian Sigg ScopedContext scopedContext; 2009d7be77bSChristian Sigg CUstream stream = nullptr; 2019d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); 2029d7be77bSChristian Sigg return stream; 2039d7be77bSChristian Sigg } 2049d7be77bSChristian Sigg 2059d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamDestroy(CUstream stream) { 2069d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream)); 2079d7be77bSChristian Sigg } 2089d7be77bSChristian Sigg 2099d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 2109d7be77bSChristian Sigg mgpuStreamSynchronize(CUstream stream) { 2119d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream)); 2129d7be77bSChristian Sigg } 2139d7be77bSChristian Sigg 2149d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamWaitEvent(CUstream stream, 2159d7be77bSChristian Sigg CUevent event) { 2169d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuStreamWaitEvent(stream, event, /*flags=*/0)); 2179d7be77bSChristian Sigg } 2189d7be77bSChristian Sigg 2199d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUevent 
mgpuEventCreate() { 2209d7be77bSChristian Sigg ScopedContext scopedContext; 2219d7be77bSChristian Sigg CUevent event = nullptr; 2229d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING)); 2239d7be77bSChristian Sigg return event; 2249d7be77bSChristian Sigg } 2259d7be77bSChristian Sigg 2269d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventDestroy(CUevent event) { 2279d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuEventDestroy(event)); 2289d7be77bSChristian Sigg } 2299d7be77bSChristian Sigg 2301c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventSynchronize(CUevent event) { 2319d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuEventSynchronize(event)); 2329d7be77bSChristian Sigg } 2339d7be77bSChristian Sigg 2341c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventRecord(CUevent event, 2359d7be77bSChristian Sigg CUstream stream) { 2369d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream)); 2379d7be77bSChristian Sigg } 2389d7be77bSChristian Sigg 2391c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void * 240*20861f1fSGuray Ozen mgpuMemAlloc(uint64_t sizeBytes, CUstream stream, bool isHostShared) { 2419d7be77bSChristian Sigg ScopedContext scopedContext; 2423635c743SAart Bik CUdeviceptr ptr = 0; 243*20861f1fSGuray Ozen if (sizeBytes == 0) 244*20861f1fSGuray Ozen return reinterpret_cast<void *>(ptr); 245*20861f1fSGuray Ozen 246*20861f1fSGuray Ozen if (isHostShared) { 247*20861f1fSGuray Ozen CUDA_REPORT_IF_ERROR( 248*20861f1fSGuray Ozen cuMemAllocManaged(&ptr, sizeBytes, CU_MEM_ATTACH_GLOBAL)); 249*20861f1fSGuray Ozen return reinterpret_cast<void *>(ptr); 250*20861f1fSGuray Ozen } 2519d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes)); 2529d7be77bSChristian Sigg return reinterpret_cast<void *>(ptr); 2539d7be77bSChristian Sigg } 2549d7be77bSChristian Sigg 2551c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 
mgpuMemFree(void *ptr, 2561c2a0768SAdam Paszke CUstream /*stream*/) { 2579d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr))); 2589d7be77bSChristian Sigg } 2599d7be77bSChristian Sigg 2601c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 2611c2a0768SAdam Paszke mgpuMemcpy(void *dst, void *src, size_t sizeBytes, CUstream stream) { 2629d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst), 2639d7be77bSChristian Sigg reinterpret_cast<CUdeviceptr>(src), 2649d7be77bSChristian Sigg sizeBytes, stream)); 2659d7be77bSChristian Sigg } 2669d7be77bSChristian Sigg 2671c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 2681c2a0768SAdam Paszke mgpuMemset32(void *dst, unsigned int value, size_t count, CUstream stream) { 269361458b1SLoren Maggiore CUDA_REPORT_IF_ERROR(cuMemsetD32Async(reinterpret_cast<CUdeviceptr>(dst), 270361458b1SLoren Maggiore value, count, stream)); 271361458b1SLoren Maggiore } 272361458b1SLoren Maggiore 2731c2a0768SAdam Paszke extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 2741c2a0768SAdam Paszke mgpuMemset16(void *dst, unsigned short value, size_t count, CUstream stream) { 27518cc07aaSNavdeep Katel CUDA_REPORT_IF_ERROR(cuMemsetD16Async(reinterpret_cast<CUdeviceptr>(dst), 27618cc07aaSNavdeep Katel value, count, stream)); 27718cc07aaSNavdeep Katel } 27818cc07aaSNavdeep Katel 279b700a90cSAart Bik /// 2809d7be77bSChristian Sigg /// Helper functions for writing mlir example code 281b700a90cSAart Bik /// 2829d7be77bSChristian Sigg 2839d7be77bSChristian Sigg // Allows to register byte array with the CUDA runtime. Helpful until we have 2849d7be77bSChristian Sigg // transfer functions implemented. 
2859d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 2869d7be77bSChristian Sigg mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) { 2879d7be77bSChristian Sigg ScopedContext scopedContext; 2889d7be77bSChristian Sigg CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0)); 2899d7be77bSChristian Sigg } 2909d7be77bSChristian Sigg 2914edc9e2aSUday Bondhugula /// Registers a memref with the CUDA runtime. `descriptor` is a pointer to a 2924edc9e2aSUday Bondhugula /// ranked memref descriptor struct of rank `rank`. Helpful until we have 2934edc9e2aSUday Bondhugula /// transfer functions implemented. 2949d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 2959d7be77bSChristian Sigg mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor, 2969d7be77bSChristian Sigg int64_t elementSizeBytes) { 2979d7be77bSChristian Sigg // Only densely packed tensors are currently supported. 2985e78417dSJustin Holewinski #ifdef _WIN32 2995e78417dSJustin Holewinski int64_t *denseStrides = (int64_t *)_alloca(rank * sizeof(int64_t)); 3005e78417dSJustin Holewinski #else 3014edc9e2aSUday Bondhugula int64_t *denseStrides = (int64_t *)alloca(rank * sizeof(int64_t)); 3025e78417dSJustin Holewinski #endif // _WIN32 3034edc9e2aSUday Bondhugula int64_t *sizes = descriptor->sizes; 3044edc9e2aSUday Bondhugula for (int64_t i = rank - 1, runningStride = 1; i >= 0; i--) { 3054edc9e2aSUday Bondhugula denseStrides[i] = runningStride; 3064edc9e2aSUday Bondhugula runningStride *= sizes[i]; 3074edc9e2aSUday Bondhugula } 3084edc9e2aSUday Bondhugula uint64_t sizeBytes = sizes[0] * denseStrides[0] * elementSizeBytes; 3094edc9e2aSUday Bondhugula int64_t *strides = &sizes[rank]; 310012c0cc7SNicolas Vasilache (void)strides; 3114edc9e2aSUday Bondhugula for (unsigned i = 0; i < rank; ++i) 3124edc9e2aSUday Bondhugula assert(strides[i] == denseStrides[i] && 3134edc9e2aSUday Bondhugula "Mismatch in computed dense strides"); 3149d7be77bSChristian Sigg 
3154edc9e2aSUday Bondhugula auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes; 3169d7be77bSChristian Sigg mgpuMemHostRegister(ptr, sizeBytes); 3179d7be77bSChristian Sigg } 31884718d37SKrzysztof Drewniak 3198f7c8a6eSmax // Allows to unregister byte array with the CUDA runtime. 3208f7c8a6eSmax extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemHostUnregister(void *ptr) { 3218f7c8a6eSmax ScopedContext scopedContext; 3228f7c8a6eSmax CUDA_REPORT_IF_ERROR(cuMemHostUnregister(ptr)); 3238f7c8a6eSmax } 3248f7c8a6eSmax 3258f7c8a6eSmax /// Unregisters a memref with the CUDA runtime. `descriptor` is a pointer to a 3268f7c8a6eSmax /// ranked memref descriptor struct of rank `rank` 3278f7c8a6eSmax extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 3288f7c8a6eSmax mgpuMemHostUnregisterMemRef(int64_t rank, 3298f7c8a6eSmax StridedMemRefType<char, 1> *descriptor, 3308f7c8a6eSmax int64_t elementSizeBytes) { 3318f7c8a6eSmax auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes; 3328f7c8a6eSmax mgpuMemHostUnregister(ptr); 3338f7c8a6eSmax } 3348f7c8a6eSmax 33584718d37SKrzysztof Drewniak extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) { 33684718d37SKrzysztof Drewniak defaultDevice = device; 33784718d37SKrzysztof Drewniak } 338b700a90cSAart Bik 3391dc00712SGuray Ozen /// 3401dc00712SGuray Ozen /// Runtime methods using CUDA 12.0+ driver 3411dc00712SGuray Ozen /// 3421dc00712SGuray Ozen 3431dc00712SGuray Ozen #if (CUDA_VERSION >= 12000) 3441dc00712SGuray Ozen 345f21a70f9SGuray Ozen extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel( 346f21a70f9SGuray Ozen CUfunction function, intptr_t clusterX, intptr_t clusterY, 347f21a70f9SGuray Ozen intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ, 348f21a70f9SGuray Ozen intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem, 349f21a70f9SGuray Ozen CUstream stream, void **params, void **extra, size_t /*paramsCount*/) { 350f21a70f9SGuray Ozen ScopedContext 
scopedContext; 351f21a70f9SGuray Ozen if (smem > 0) { 352f21a70f9SGuray Ozen // Avoid checking driver as it's more expensive than if statement 353f21a70f9SGuray Ozen int32_t maxShmem = 0; 354f21a70f9SGuray Ozen CUdevice device = getDefaultCuDevice(); 355f21a70f9SGuray Ozen CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice)); 356f21a70f9SGuray Ozen CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute( 357f21a70f9SGuray Ozen &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, 358f21a70f9SGuray Ozen device)); 359f21a70f9SGuray Ozen if (maxShmem < smem) { 360f21a70f9SGuray Ozen fprintf(stderr, 361f21a70f9SGuray Ozen "Requested shared memory (%dkb) is larger than maximum allowed " 362f21a70f9SGuray Ozen "shared memory (%dkb) for this device\n", 363f21a70f9SGuray Ozen smem, maxShmem); 364f21a70f9SGuray Ozen } 365f21a70f9SGuray Ozen CUDA_REPORT_IF_ERROR(cuFuncSetAttribute( 366f21a70f9SGuray Ozen function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem)); 367f21a70f9SGuray Ozen } 368f21a70f9SGuray Ozen CUlaunchConfig config; 369f21a70f9SGuray Ozen config.gridDimX = gridX; 370f21a70f9SGuray Ozen config.gridDimY = gridY; 371f21a70f9SGuray Ozen config.gridDimZ = gridZ; 372f21a70f9SGuray Ozen config.blockDimX = blockX; 373f21a70f9SGuray Ozen config.blockDimY = blockY; 374f21a70f9SGuray Ozen config.blockDimZ = blockZ; 375f21a70f9SGuray Ozen config.sharedMemBytes = smem; 376f21a70f9SGuray Ozen config.hStream = stream; 377f21a70f9SGuray Ozen CUlaunchAttribute launchAttr[2]; 378f21a70f9SGuray Ozen launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; 379f21a70f9SGuray Ozen launchAttr[0].value.clusterDim.x = clusterX; 380f21a70f9SGuray Ozen launchAttr[0].value.clusterDim.y = clusterY; 381f21a70f9SGuray Ozen launchAttr[0].value.clusterDim.z = clusterZ; 382f21a70f9SGuray Ozen launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; 383f21a70f9SGuray Ozen launchAttr[1].value.clusterSchedulingPolicyPreference = 384f21a70f9SGuray Ozen 
CU_CLUSTER_SCHEDULING_POLICY_SPREAD; 385f21a70f9SGuray Ozen config.numAttrs = 2; 386f21a70f9SGuray Ozen config.attrs = launchAttr; 387f21a70f9SGuray Ozen 388f21a70f9SGuray Ozen debug_print("Launching kernel," 389f21a70f9SGuray Ozen "cluster: %ld, %ld, %ld, " 390f21a70f9SGuray Ozen "grid=%ld,%ld,%ld, " 391f21a70f9SGuray Ozen "threads: %ld, %ld, %ld, " 392f21a70f9SGuray Ozen "smem: %dkb\n", 393f21a70f9SGuray Ozen clusterX, clusterY, clusterZ, gridX, gridY, gridZ, blockX, blockY, 394f21a70f9SGuray Ozen blockZ, smem); 395f21a70f9SGuray Ozen 396f21a70f9SGuray Ozen CUDA_REPORT_IF_ERROR(cuLaunchKernelEx(&config, function, params, extra)); 397f21a70f9SGuray Ozen } 398f21a70f9SGuray Ozen 399e56d6745SGuray Ozen extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuTensorMapEncodeTiled( 400e56d6745SGuray Ozen CUtensorMap *tensorMap, // Tensor map object 401e56d6745SGuray Ozen CUtensorMapDataType tensorDataType, // Tensor data type 402e56d6745SGuray Ozen cuuint32_t tensorRank, // Dimensionality of tensor 403e56d6745SGuray Ozen void *globalAddress, // Starting address 404e56d6745SGuray Ozen const cuuint64_t *globalDim, // Tensor size (number of elements) 405e56d6745SGuray Ozen const cuuint64_t *globalStrides, // Stride size (in bytes) 406e56d6745SGuray Ozen const cuuint32_t *boxDim, // Traversal box (number of elments) 407e56d6745SGuray Ozen const cuuint32_t *elementStrides, // Traversal stride 408e56d6745SGuray Ozen CUtensorMapInterleave interleave, // Type of interleaved layout 409e56d6745SGuray Ozen CUtensorMapSwizzle swizzle, // Bank swizzling pattern 410e56d6745SGuray Ozen CUtensorMapL2promotion l2Promotion, // L2 promotion size 411e56d6745SGuray Ozen CUtensorMapFloatOOBfill oobFill // Padding zfill or NaN fill 412e56d6745SGuray Ozen ) { 413e56d6745SGuray Ozen ScopedContext scopedContext; 414e56d6745SGuray Ozen CUDA_REPORT_IF_ERROR(cuTensorMapEncodeTiled( 415e56d6745SGuray Ozen tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, 416e56d6745SGuray Ozen globalStrides, 
boxDim, elementStrides, interleave, swizzle, l2Promotion, 417e56d6745SGuray Ozen oobFill)); 41819b11079SGuray Ozen debug_print("Created TMA descriptor\n Addr: %p\n" 41919b11079SGuray Ozen "data type : %d\n" 42019b11079SGuray Ozen "rank : %d\n" 42119b11079SGuray Ozen "globalDim[5]: %zu, %zu, %zu, %zu, %zu\n" 42219b11079SGuray Ozen "globalStrides[5]: %zu, %zu, %zu, %zu, %zu\n" 42319b11079SGuray Ozen "boxDim[5]: %u, %u, %u, %u, %u\n" 42419b11079SGuray Ozen "elementStrides[5]: %u, %u, %u, %u, %u\n" 42519b11079SGuray Ozen "interleave: %u \n" 42619b11079SGuray Ozen "swizzle: %u \n" 42719b11079SGuray Ozen "l2Promotion: %u \n" 42819b11079SGuray Ozen "oobFill: %u \n", 42919b11079SGuray Ozen (void *)&tensorMap, tensorDataType, tensorRank, globalDim[0], 43019b11079SGuray Ozen globalDim[1], globalDim[2], globalDim[3], globalDim[4], 43119b11079SGuray Ozen globalStrides[0], globalStrides[1], globalStrides[2], 43219b11079SGuray Ozen globalStrides[3], globalStrides[4], boxDim[0], boxDim[1], 43319b11079SGuray Ozen boxDim[2], boxDim[3], boxDim[4], elementStrides[0], 43419b11079SGuray Ozen elementStrides[1], elementStrides[2], elementStrides[3], 43519b11079SGuray Ozen elementStrides[4], interleave, swizzle, l2Promotion, oobFill); 436e56d6745SGuray Ozen } 437e56d6745SGuray Ozen 4387d55b916SGuray Ozen template <int Rank> 4397d55b916SGuray Ozen void mgpuGetMemRefDataAndShape(void *rawDescriptor, char **addr, 4407d55b916SGuray Ozen uint64_t *globalDim, uint64_t *globalStrides, 4417d55b916SGuray Ozen const CUtensorMapDataType tensorDataType) { 44265aab9e7SAdam Paszke auto descriptor = 4437d55b916SGuray Ozen reinterpret_cast<StridedMemRefType<char, Rank> *>(rawDescriptor); 44465aab9e7SAdam Paszke *addr = descriptor->data; 4457d55b916SGuray Ozen for (int i = 0; i < Rank; ++i) { 4467d55b916SGuray Ozen globalDim[i] = static_cast<uint64_t>(descriptor->sizes[Rank - i - 1]); 4477d55b916SGuray Ozen } 4487d55b916SGuray Ozen static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2, 
4497d55b916SGuray Ozen 4, 8, 2, 4, 4, 4}; 4507d55b916SGuray Ozen for (int i = 0; i < Rank - 1; ++i) { 4517d55b916SGuray Ozen globalStrides[i] = static_cast<uint64_t>( 4527d55b916SGuray Ozen descriptor->strides[Rank - i - 2] * elementSizeInBytes[tensorDataType]); 45365aab9e7SAdam Paszke } 45465aab9e7SAdam Paszke } 45565aab9e7SAdam Paszke 456e56d6745SGuray Ozen extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref( 457e56d6745SGuray Ozen int64_t tensorRank, // Dimensionality of tensor 4587d55b916SGuray Ozen void *rankedDescriptor, // Ranked MemRef descriptor 459e56d6745SGuray Ozen const CUtensorMapDataType tensorDataType, // Stride size (in bytes) 460e56d6745SGuray Ozen CUtensorMapInterleave interleave, // Type of interleaved layout 461e56d6745SGuray Ozen CUtensorMapSwizzle swizzle, // Bank swizzling pattern 462e56d6745SGuray Ozen CUtensorMapL2promotion l2Promotion, // L2 promotion size 463e56d6745SGuray Ozen CUtensorMapFloatOOBfill oobFill, // Padding zfill or NaN fill 464e56d6745SGuray Ozen int64_t *inputBoxDims // Tensor size (number of elements) 465e56d6745SGuray Ozen ) { 466e56d6745SGuray Ozen CUtensorMap tensorMap; 467e56d6745SGuray Ozen 4682379432cSGuray Ozen uint32_t boxDim[5] = {1, 1, 1, 1, 1}, elementStrides[5] = {1, 1, 1, 1, 1}; 4692379432cSGuray Ozen uint64_t globalDim[5] = {1, 1, 1, 1, 1}, globalStrides[5] = {0}; 470e56d6745SGuray Ozen uint32_t tensorRank32 = uint32_t(tensorRank); 471e56d6745SGuray Ozen 47265aab9e7SAdam Paszke char *globalAddress = nullptr; 47365aab9e7SAdam Paszke switch (tensorRank) { 47465aab9e7SAdam Paszke case 1: 4757d55b916SGuray Ozen mgpuGetMemRefDataAndShape<1>(rankedDescriptor, &globalAddress, globalDim, 4767d55b916SGuray Ozen globalStrides, tensorDataType); 47765aab9e7SAdam Paszke break; 47865aab9e7SAdam Paszke case 2: 4797d55b916SGuray Ozen mgpuGetMemRefDataAndShape<2>(rankedDescriptor, &globalAddress, globalDim, 4807d55b916SGuray Ozen globalStrides, tensorDataType); 48165aab9e7SAdam Paszke break; 
48265aab9e7SAdam Paszke case 3: 4837d55b916SGuray Ozen mgpuGetMemRefDataAndShape<3>(rankedDescriptor, &globalAddress, globalDim, 4847d55b916SGuray Ozen globalStrides, tensorDataType); 48565aab9e7SAdam Paszke break; 48665aab9e7SAdam Paszke case 4: 4877d55b916SGuray Ozen mgpuGetMemRefDataAndShape<4>(rankedDescriptor, &globalAddress, globalDim, 4887d55b916SGuray Ozen globalStrides, tensorDataType); 48965aab9e7SAdam Paszke break; 49065aab9e7SAdam Paszke case 5: 4917d55b916SGuray Ozen mgpuGetMemRefDataAndShape<5>(rankedDescriptor, &globalAddress, globalDim, 4927d55b916SGuray Ozen globalStrides, tensorDataType); 49365aab9e7SAdam Paszke break; 49465aab9e7SAdam Paszke default: 49565aab9e7SAdam Paszke fprintf( 49665aab9e7SAdam Paszke stderr, 49765aab9e7SAdam Paszke "'mgpuTensorMapEncodeTiledMemref' failed with 'rank is too high'\n"); 4987d55b916SGuray Ozen return nullptr; 49965aab9e7SAdam Paszke } 50065aab9e7SAdam Paszke 501e56d6745SGuray Ozen for (int64_t r = 0; r < tensorRank; ++r) { 502e56d6745SGuray Ozen boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]); 503e56d6745SGuray Ozen } 504e56d6745SGuray Ozen 505e56d6745SGuray Ozen ScopedContext scopedContext; 506e56d6745SGuray Ozen mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32, 507e56d6745SGuray Ozen globalAddress, globalDim, globalStrides, boxDim, 508e56d6745SGuray Ozen elementStrides, interleave, swizzle, l2Promotion, 509e56d6745SGuray Ozen oobFill); 510e56d6745SGuray Ozen // Copy created tensor map to device 511e56d6745SGuray Ozen CUdeviceptr dTensorMap; 512e56d6745SGuray Ozen CUDA_REPORT_IF_ERROR(cuMemAlloc(&dTensorMap, sizeof(CUtensorMap))); 513e56d6745SGuray Ozen CUDA_REPORT_IF_ERROR(cuMemcpy(dTensorMap, 514e56d6745SGuray Ozen reinterpret_cast<CUdeviceptr>(&tensorMap), 515e56d6745SGuray Ozen sizeof(CUtensorMap))); 516e56d6745SGuray Ozen return reinterpret_cast<void *>(dTensorMap); 517e56d6745SGuray Ozen } 5181dc00712SGuray Ozen #endif 519e56d6745SGuray Ozen 52050db4789SAart Bik 
#ifdef MLIR_ENABLE_CUDA_CUSPARSE

///
/// Wrapper methods for the cuSparse library.
///

// Some macro magic to get float/double alpha and beta on host.
// TODO: add support to passing alpha and beta as arguments
//
// Declares host-side scalars alpha/beta in every supported element type
// (bf16, f16, f32, f64), all initialized to one, and sets (alpha##p) /
// (beta##p) to point at the variant matching the runtime data type `dtp`.
// Any type other than 16BF/16F/32F falls through to the double variant.
#define ALPHABETA(dtp, alpha, beta)                                            \
  __nv_bfloat16(alpha##16bf) = 1.0f;                                           \
  __nv_bfloat16(beta##16bf) = 1.0f;                                            \
  __half(alpha##16f) = 1.0f;                                                   \
  __half(beta##16f) = 1.0f;                                                    \
  float(alpha##f) = 1.0f;                                                      \
  float(beta##f) = 1.0f;                                                       \
  double(alpha##d) = 1.0;                                                      \
  double(beta##d) = 1.0;                                                       \
  const void *(alpha##p) = nullptr;                                            \
  const void *(beta##p) = nullptr;                                             \
  if (dtp == CUDA_R_16BF || dtp == CUDA_C_16BF) {                              \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##16bf));                     \
    (beta##p) = reinterpret_cast<void *>(&(beta##16bf));                       \
  } else if (dtp == CUDA_R_16F || dtp == CUDA_C_16F) {                         \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##16f));                      \
    (beta##p) = reinterpret_cast<void *>(&(beta##16f));                        \
  } else if (dtp == CUDA_R_32F || dtp == CUDA_C_32F) {                         \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##f));                        \
    (beta##p) = reinterpret_cast<void *>(&(beta##f));                          \
  } else {                                                                     \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##d));                        \
    (beta##p) = reinterpret_cast<void *>(&(beta##d));                          \
  }

/// Creates the shared cuSPARSE handle (`cusparse_env`). Must be called exactly
/// once before any other mgpu* sparse wrapper is used.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCreateSparseEnv() {
  // ScopedContext is for cuda initialization.
  ScopedContext scopedContext;
  assert(!cusparse_env && "client called mgpuCreateSparseEnv() twice");
  CUSPARSE_REPORT_IF_ERROR(cusparseCreate(&cusparse_env));
}

/// Destroys the shared cuSPARSE handle created by mgpuCreateSparseEnv().
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuDestroySparseEnv() {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroy(cusparse_env));
  cusparse_env = nullptr;
}

/// Creates a cuSPARSE dense-vector descriptor over `values` with element type
/// `dtp` (a cudaDataType_t) and returns it type-erased. The buffer is assumed
/// device-resident — confirm with callers.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateDnVec(intptr_t size, void *values, int32_t dtp, CUstream /*stream*/) {
  cusparseDnVecDescr_t vec = nullptr;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnVec(&vec, size, values, dTp))
  return reinterpret_cast<void *>(vec);
}

/// Destroys a dense-vector descriptor created by mgpuCreateDnVec().
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyDnVec(void *v, CUstream /*stream*/) {
  cusparseDnVecDescr_t vec = reinterpret_cast<cusparseDnVecDescr_t>(v);
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroyDnVec(vec))
}

/// Creates a row-major dense-matrix descriptor (leading dimension == cols).
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateDnMat(intptr_t rows, intptr_t cols, void *values, int32_t dtp,
                CUstream /*stream*/) {
  cusparseDnMatDescr_t mat = nullptr;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnMat(&mat, rows, cols, /*ld=*/cols,
                                               values, dTp, CUSPARSE_ORDER_ROW))
  return reinterpret_cast<void *>(mat);
}

/// Destroys a dense-matrix descriptor created by mgpuCreateDnMat().
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyDnMat(void *m, CUstream /*stream*/) {
  cusparseDnMatDescr_t mat = reinterpret_cast<cusparseDnMatDescr_t>(m);
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroyDnMat(mat))
}

/// Creates a COO sparse-matrix descriptor with zero-based indices. `itp` is
/// the cusparseIndexType_t of the index buffers, `dtp` the cudaDataType_t of
/// the values.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCoo(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowIdxs,
              void *colIdxs, void *values, int32_t itp, int32_t dtp,
              CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCoo(&mat, rows, cols, nnz, rowIdxs,
                                             colIdxs, values, iTp,
                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}

#ifdef CUSPARSE_COO_AOS // deprecated in cuSPARSE 11.2
/// Creates an array-of-structures COO descriptor where row/column indices are
/// interleaved in a single buffer `idxs`.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCooAoS(intptr_t rows, intptr_t cols, intptr_t nnz, void *idxs,
                 void *values, int32_t itp, int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCooAoS(
      &mat, rows, cols, nnz, idxs, values, iTp, CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}
#endif // CUSPARSE_COO_AOS

/// Creates a CSR sparse-matrix descriptor. `ptp`/`itp` are the index types of
/// the row-position (offset) and column-index buffers respectively.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCsr(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowPos,
              void *colIdxs, void *values, int32_t ptp, int32_t itp,
              int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto pTp = static_cast<cusparseIndexType_t>(ptp);
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsr(&mat, rows, cols, nnz, rowPos,
                                             colIdxs, values, pTp, iTp,
                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}

/// Creates a CSC sparse-matrix descriptor (column positions + row indices).
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCsc(intptr_t rows, intptr_t cols, intptr_t nnz, void *colPos,
              void *rowIdxs, void *values, int32_t ptp, int32_t itp,
              int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto pTp = static_cast<cusparseIndexType_t>(ptp);
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsc(&mat, rows, cols, nnz, colPos,
                                             rowIdxs, values, pTp, iTp,
                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}

/// Creates a row-major BSR sparse-matrix descriptor with rBsz x cBsz blocks.
/// Only available from cuSPARSE 12.1; otherwise the returned descriptor stays
/// null.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateBsr(intptr_t brows, intptr_t bcols, intptr_t bnnz, intptr_t rBsz,
              intptr_t cBsz, void *rowPos, void *colIdxs, void *values,
              int32_t ptp, int32_t itp, int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t
      mat = nullptr;
#if CUSPARSE_VERSION >= 12100
  auto pTp = static_cast<cusparseIndexType_t>(ptp);
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateBsr(
      &mat, brows, bcols, bnnz, rBsz, cBsz, rowPos, colIdxs, values, pTp, iTp,
      CUSPARSE_INDEX_BASE_ZERO, dTp, CUSPARSE_ORDER_ROW))
#endif
  return reinterpret_cast<void *>(mat);
}

/// Destroys any sparse-matrix descriptor (COO / CSR / CSC / BSR).
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroySpMat(void *m, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = reinterpret_cast<cusparseSpMatDescr_t>(m);
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroySpMat(mat))
}

/// Returns the workspace size (in bytes) required by mgpuSpMV for the same
/// operands; `ma` is the cusparseOperation_t applied to A, `ctp` the compute
/// data type. Alpha and beta are fixed to one (see ALPHABETA).
extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMVBufferSize(
    int32_t ma, void *a, void *x, void *y, int32_t ctp, CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
  cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t bufferSize = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV_bufferSize(
      cusparse_env, modeA, alphap, matA, vecX, betap, vecY, cTp,
      CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
  return bufferSize;
}

/// Sparse matrix-vector product y = op(A) * x + y (alpha = beta = 1) using
/// the workspace `buf` sized by mgpuSpMVBufferSize.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(int32_t ma, void *a, void *x,
                                                   void *y, int32_t ctp,
                                                   void *buf,
                                                   CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
  cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(cusparse_env, modeA, alphap, matA, vecX,
                                        betap, vecY, cTp,
                                        CUSPARSE_SPMV_ALG_DEFAULT, buf))
}

/// Returns the workspace size (in bytes) required by mgpuSpMM for the same
/// operands (sparse A, dense B and C).
extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSpMMBufferSize(int32_t ma, int32_t mb, void *a, void *b, void *c,
                   int32_t ctp, CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t bufferSize = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
  return bufferSize;
}

/// Sparse matrix-matrix product C = op(A) * op(B) + C (alpha = beta = 1),
/// with sparse A and dense B, C, using the workspace `buf` sized by
/// mgpuSpMMBufferSize.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMM(int32_t ma, int32_t mb,
                                                   void *a, void *b, void *c,
                                                   int32_t ctp, void *buf,
                                                   CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(cusparse_env, modeA, modeB, alphap,
                                        matA, matB, betap, matC, cTp,
                                        CUSPARSE_SPMM_ALG_DEFAULT, buf))
}

/// Returns the workspace size (in bytes) required by mgpuSDDMM for the same
/// operands (dense A, B and sparse C).
extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSDDMMBufferSize(int32_t ma, int32_t mb, void *a, void *b, void *c,
                    int32_t ctp, CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t bufferSize = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM_bufferSize(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SDDMM_ALG_DEFAULT, &bufferSize))
  return bufferSize;
}

/// Sampled dense-dense matrix multiplication: C = op(A) * op(B) sampled on
/// C's sparsity pattern, accumulated into C (alpha = beta = 1), with dense
/// A, B and sparse C.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSDDMM(int32_t ma, int32_t mb,
                                                    void *a, void *b, void *c,
                                                    int32_t ctp, void *buf,
                                                    CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM(cusparse_env, modeA, modeB, alphap,
                                         matA, matB, betap, matC, cTp,
                                         CUSPARSE_SDDMM_ALG_DEFAULT, buf))
}

/// Creates the opaque descriptor threaded through the SpGEMM stages
/// (work-estimation -> compute -> copy).
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuSpGEMMCreateDescr(CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = nullptr;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_createDescr(&spgemmDesc))
  return reinterpret_cast<void *>(spgemmDesc);
}

/// Destroys an SpGEMM descriptor created by mgpuSpGEMMCreateDescr().
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSpGEMMDestroyDescr(void *s, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_destroyDescr(spgemmDesc))
}

/// First SpGEMM stage for C = op(A) * op(B) with all matrices sparse. Per the
/// cuSPARSE protocol this is typically invoked twice: first with a null `buf`
/// to query the size, then with the allocated buffer of size `bs`. Returns
/// the (possibly updated) buffer size.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMWorkEstimation(
    void *s, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t ctp,
    intptr_t bs, void *buf, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t newBufferSize = bs;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_workEstimation(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize, buf))
  return newBufferSize;
}

/// Second SpGEMM stage: performs the multiplication (result held inside the
/// descriptor until mgpuSpGEMMCopy). Same two-call query/execute convention
/// as the work-estimation stage; returns the (possibly updated) size of the
/// second buffer.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSpGEMMCompute(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
                  int32_t ctp, intptr_t bsz2, void *buf2, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t newBufferSize2 = bsz2;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_compute(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize2, buf2))
  return newBufferSize2;
}

/// Final SpGEMM stage: copies the computed product out of the descriptor into
/// matC's buffers (which must have been sized via mgpuSpMatGetSize and
/// re-pointed, e.g. with mgpuSetCsrPointers).
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSpGEMMCopy(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
               int32_t ctp, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(
      cusparseSpGEMM_copy(cusparse_env, modeA, modeB,
                          alphap, matA, matB, betap,
                          matC, cTp, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc))
}

/// Queries the dimensions and number of nonzeros of a sparse matrix, writing
/// them to the int64_t locations `r`, `c`, and `n`.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSpMatGetSize(void *m, void *r, void *c, void *n, CUstream /*stream*/) {
  cusparseConstSpMatDescr_t matDescr =
      reinterpret_cast<cusparseConstSpMatDescr_t>(m);
  int64_t *rows = reinterpret_cast<int64_t *>(r);
  int64_t *cols = reinterpret_cast<int64_t *>(c);
  int64_t *nnz = reinterpret_cast<int64_t *>(n);
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMatGetSize(matDescr, rows, cols, nnz));
}

/// Re-points a CSR descriptor at new position / column-index / value buffers.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSetCsrPointers(void *m, void *p, void *c, void *v, CUstream /*stream*/) {
  cusparseSpMatDescr_t matDescr = reinterpret_cast<cusparseSpMatDescr_t>(m);
  CUSPARSE_REPORT_IF_ERROR(cusparseCsrSetPointers(matDescr, p, c, v));
}

#ifdef MLIR_ENABLE_CUDA_CUSPARSELT

///
/// Wrapper methods for the cuSparseLt library.
///

// Bundles a cuSparseLt sparse-matrix descriptor with per-matmul state and the
// raw values pointer.
struct cusparseLtSpMatHandleAndData {
  cusparseLtMatDescriptor_t mat;
  // TODO: the following three are associated with the SpMM operator rather than
  // the sparse matrix. Create workspace buffers and pass them to the SpMM
  // execution.
  cusparseLtMatmulAlgSelection_t alg_sel;
  cusparseLtMatmulPlan_t plan;
  cusparseLtMatmulDescriptor_t matmul;
  void *values{nullptr}; // device values buffer; not owned
};

// Bundles a cuSparseLt dense-matrix descriptor with its values pointer.
struct cusparseLtDnMatHandleAndData {
  cusparseLtMatDescriptor_t mat;
  void *values{nullptr}; // device values buffer; not owned
};

// NOTE(review): these exact sizes are presumably assumed by the compiler-side
// lowering that allocates storage for the handles — confirm there before
// changing. They fail loudly on a cuSparseLt version bump.
static_assert(sizeof(cusparseLtHandle_t) == 11024,
              "Unexpected cusparseLt handle size");
static_assert(sizeof(cusparseLtSpMatHandleAndData) == 44104,
              "Unexpected cusparseLt sparse matrix handle size");
static_assert(sizeof(cusparseLtDnMatHandleAndData) == 11032,
              "Unexpected cusparseLt dense matrix handle size");

/// Initializes the shared cuSparseLt handle (`cusparseLt_env`). Must be
/// called exactly once before any other cuSparseLt wrapper.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCreateSparseLtEnv() {
  // ScopedContext is for cuda initialization.
  ScopedContext scopedContext;
  assert(!cusparseLt_initiated &&
         "client called mgpuCreateSparseLtEnv() twice");
  // Note that cuSparseLt still uses cusparseStatus_t.
  CUSPARSE_REPORT_IF_ERROR(cusparseLtInit(&cusparseLt_env));
  cusparseLt_initiated = true;
}

/// Tears down the shared cuSparseLt handle created by mgpuCreateSparseLtEnv().
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuDestroySparseLtEnv() {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  CUSPARSE_REPORT_IF_ERROR(cusparseLtDestroy(&cusparseLt_env));
  cusparseLt_initiated = false;
}

/// Initializes, in caller-provided storage `dh`, a cuSparseLt dense-matrix
/// handle over `values` (row-major layout, 16-byte alignment).
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCreateCuSparseLtDnMat(void *dh, intptr_t rows, intptr_t cols, void *values,
                          int32_t dtp, CUstream /*stream*/) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  auto dnmat_handle = reinterpret_cast<cusparseLtDnMatHandleAndData *>(dh);
  dnmat_handle->values = values;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  // Assume row-major when deciding lda.
  const uint32_t alignment = 16;
  CUSPARSE_REPORT_IF_ERROR(cusparseLtDenseDescriptorInit(
      &cusparseLt_env, &(dnmat_handle->mat), rows, cols, /*lda=*/cols,
      alignment, dTp, CUSPARSE_ORDER_ROW))
}

/// Destroys the descriptor held inside a dense-matrix handle (the storage
/// itself is caller-owned).
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyCuSparseLtDnMat(void *dh, CUstream /*stream*/) {
  auto dnmat_handle = reinterpret_cast<cusparseLtDnMatHandleAndData *>(dh);
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(dnmat_handle->mat)))
}

/// Initializes, in caller-provided storage `sh`, a cuSparseLt 2:4 structured
/// sparse-matrix handle over `values`.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCusparseLtCreate2To4SpMat(void *sh, intptr_t rows, intptr_t cols,
                              void *values, int32_t dtp, CUstream /*stream*/) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  auto spmat_handle = reinterpret_cast<cusparseLtSpMatHandleAndData *>(sh);
  spmat_handle->values = values;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  // Assume row-major when deciding lda.
  const uint32_t alignment = 16;
  CUSPARSE_REPORT_IF_ERROR(cusparseLtStructuredDescriptorInit(
      &cusparseLt_env, &(spmat_handle->mat), rows, cols, /*ld=*/cols, alignment,
      dTp, CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT))
}

/// Destroys the descriptor held inside a 2:4 sparse-matrix handle (the
/// storage itself is caller-owned).
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyCuSparseLtSpMat(void *sh, CUstream /*stream*/) {
  auto spmat_handle = reinterpret_cast<cusparseLtSpMatHandleAndData *>(sh);
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(spmat_handle->mat)))
}

// Several things are being done in this stage, algorithm selection, planning,
// and returning workspace and compressed matrices data buffer sizes.
// The parameter prune_flag is used to indicate whether pruning and pruning
// check will happen 0 means not prune or prune check, 1 means prune, 2 means
// prune & prune check
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCuSparseLtSpMMBufferSize(void *bs, int32_t ma, int32_t mb, void *a, void *b,
                             void *c, int32_t ctp, int32_t prune_flag,
                             CUstream stream) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  // TODO: support more advanced settings, e.g., the input right operand is a
  // sparse matrix assuming matA is the sparse matrix
  auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
  auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
  auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);
  // `bs` points at three consecutive size_t outputs: workspace size,
  // compressed-matrix size, and compression temp-buffer size.
  auto workspace_size = reinterpret_cast<size_t *>(bs);
  auto compressed_size = &(reinterpret_cast<size_t *>(bs)[1]);
  auto compressed_buffer_size = &(reinterpret_cast<size_t *>(bs)[2]);
  auto cTp = static_cast<cusparseComputeType>(ctp);

  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  // Describe the matmul with matC also serving as the output D.
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulDescriptorInit(
      &cusparseLt_env, &(matA->matmul), modeA, modeB, &(matA->mat),
      &(matB->mat), &(matC->mat), &(matC->mat), cTp))
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSelectionInit(
      &cusparseLt_env, &(matA->alg_sel), &(matA->matmul),
      CUSPARSELT_MATMUL_ALG_DEFAULT))
  // Pin algorithm config id 0 rather than searching.
  int alg = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSetAttribute(
      &cusparseLt_env, &(matA->alg_sel), CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg,
      sizeof(alg)))

  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanInit(
      &cusparseLt_env, &(matA->plan), &(matA->matmul), &(matA->alg_sel)))

  // Pruning step (in-place).
  if (prune_flag > 0)
    CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPrune(
        &cusparseLt_env, &(matA->matmul), matA->values, matA->values,
        CUSPARSELT_PRUNE_SPMMA_STRIP, stream))

  // Check structure of A.
  // Note that this adds a synchronization on the stream.
  // TODO: Do we want that?
9851e491c42SKun Wu if (prune_flag == 2) { 98641a07e66SAart Bik int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream, false); 9874df01dc2SAart Bik CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPruneCheck( 9884df01dc2SAart Bik &cusparseLt_env, &(matA->matmul), matA->values, dvalid, stream)) 9894df01dc2SAart Bik int valid = 0; 9904df01dc2SAart Bik mgpuMemcpy(&valid, dvalid, sizeof(int), stream); 9914df01dc2SAart Bik mgpuStreamSynchronize(stream); 9924df01dc2SAart Bik mgpuMemFree(dvalid, stream); 9934df01dc2SAart Bik if (valid != 0) 9944df01dc2SAart Bik fprintf(stderr, "CUPARSE-LT: sparse matrix is not 2:4; computed results " 9954df01dc2SAart Bik "will be invalid\n"); 9961e491c42SKun Wu } 9974df01dc2SAart Bik 998be2dd22bSKun Wu CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulGetWorkspace( 9993635c743SAart Bik &cusparseLt_env, &(matA->plan), workspace_size)) 1000ac30f48eSKun Wu CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMACompressedSize( 10013635c743SAart Bik &cusparseLt_env, &(matA->plan), compressed_size, compressed_buffer_size)) 10028ed59c53SKun Wu } 10038ed59c53SKun Wu 10048ed59c53SKun Wu extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 1005be2dd22bSKun Wu mgpuCuSparseLtSpMM(void *a, void *b, void *c, void *d_workspace, 1006ac30f48eSKun Wu void *dA_compressed, void *dA_compressedBuffer, 1007ac30f48eSKun Wu CUstream stream) { 1008be2dd22bSKun Wu assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()"); 10098ed59c53SKun Wu auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a); 10108ed59c53SKun Wu auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b); 10118ed59c53SKun Wu auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c); 10128ed59c53SKun Wu 1013ac30f48eSKun Wu ALPHABETA(CUDA_R_32F, alpha, beta) 1014ac30f48eSKun Wu CUSPARSE_REPORT_IF_ERROR( 1015be2dd22bSKun Wu cusparseLtSpMMACompress(&cusparseLt_env, &(matA->plan), (matA->values), 1016ac30f48eSKun Wu dA_compressed, dA_compressedBuffer, stream)) 10178ed59c53SKun Wu 10188ed59c53SKun Wu // 
TODO: add support to multi-stream execution 10198ed59c53SKun Wu // Perform the matrix multiplication. D = A*B+C using C==D for now 1020ac30f48eSKun Wu CUSPARSE_REPORT_IF_ERROR( 1021be2dd22bSKun Wu cusparseLtMatmul(&cusparseLt_env, &(matA->plan), alphap, dA_compressed, 1022ac30f48eSKun Wu matB->values, betap, matC->values, 1023ac30f48eSKun Wu /*dD*/ matC->values, d_workspace, nullptr, 0)) 10248ed59c53SKun Wu 1025ac30f48eSKun Wu CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matA->mat))) 10268ed59c53SKun Wu // destroy the plan associated with the sparse matrix 1027ac30f48eSKun Wu CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanDestroy(&(matA->plan))) 10288ed59c53SKun Wu } 10298ed59c53SKun Wu 10308ed59c53SKun Wu #endif // MLIR_ENABLE_CUDA_CUSPARSELT 10318ed59c53SKun Wu #endif // MLIR_ENABLE_CUDA_CUSPARSE 1032