//===- CudaRuntimeWrappers.cpp - MLIR CUDA API wrapper library ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements C wrappers around the CUDA library for easy linking in ORC jit.
// Also adds some debugging helpers that are helpful when writing MLIR code to
// run on GPUs.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/CRunnerUtils.h"

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

#include "cuda.h"
#include "cuda_bf16.h"
#include "cuda_fp16.h"

#ifdef MLIR_ENABLE_CUDA_CUSPARSE
#include "cusparse.h"
#ifdef MLIR_ENABLE_CUDA_CUSPARSELT
#include "cusparseLt.h"
#endif // MLIR_ENABLE_CUDA_CUSPARSELT
#endif // MLIR_ENABLE_CUDA_CUSPARSE

#ifdef _WIN32
#include <malloc.h>
#define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport)
#else
#define MLIR_CUDA_WRAPPERS_EXPORT __attribute__((visibility("default")))
#endif // _WIN32

#define CUDA_REPORT_IF_ERROR(expr)                                            \
  [](CUresult result) {                                                       \
    if (!result)                                                              \
      return;                                                                 \
    const char *name = nullptr;                                               \
    cuGetErrorName(result, &name);                                            \
    if (!name)                                                                \
      name = "<unknown>";                                                     \
    fprintf(stderr, "'%s' failed with '%s'\n", #expr, name);                  \
  }(expr)

#define CUSPARSE_REPORT_IF_ERROR(expr)                                        \
  {                                                                           \
    cusparseStatus_t status = (expr);                                         \
    if (status != CUSPARSE_STATUS_SUCCESS) {                                  \
      fprintf(stderr, "cuSPARSE '%s' failed with '%s'\n", #expr,              \
              cusparseGetErrorString(status));                                \
    }                                                                         \
  }

thread_local static int32_t defaultDevice = 0;

const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG";

/// Helper method that checks the environment for the debugging variable.
bool isDebugEnabled() {
  static bool isInitialized = false;
  static bool isEnabled = false;
  if (!isInitialized) {
    isEnabled = getenv(kDebugEnvironmentVariable) != nullptr;
    isInitialized = true;
  }
  return isEnabled;
}

#define debug_print(fmt, ...)                                                 \
  do {                                                                        \
    if (isDebugEnabled())                                                     \
      fprintf(stderr, "%s:%d:%s(): " fmt, "CudaRuntimeWrappers.cpp",          \
              __LINE__, __func__, __VA_ARGS__);                               \
  } while (0)

// Returns the default CUdevice.
CUdevice getDefaultCuDevice() {
  CUdevice device;
  CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
  return device;
}

// Makes the primary context of the current default device current for the
// duration of the instance and restores the previous context on destruction.
class ScopedContext {
public:
  ScopedContext() {
    // Static reference to the CUDA primary context for device ordinal
    // defaultDevice.
    static CUcontext context = [] {
      CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
      CUcontext ctx;
      // Note: this does not affect the current context.
      CUDA_REPORT_IF_ERROR(
          cuDevicePrimaryCtxRetain(&ctx, getDefaultCuDevice()));
      return ctx;
    }();

    CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
  }

  ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
};
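// Illustrative sketch (not part of the exported API): every wrapper below
// follows the same pattern -- enter the device's primary context through a
// ScopedContext instance and report, but do not abort on, driver failures via
// CUDA_REPORT_IF_ERROR. A hypothetical wrapper would look like:
//
//   extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuExampleSynchronize() {
//     ScopedContext scopedContext;              // push the primary context
//     CUDA_REPORT_IF_ERROR(cuCtxSynchronize()); // report failures to stderr
//   }                                           // context popped on scope exit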
#ifdef MLIR_ENABLE_CUDA_CUSPARSE
// Note that (1) Nvidia confirms it is safe to share a handle across multiple
// instances and streams, and (2) clients are responsible for calling the @mgpu
// environment initialization/destruction functions in a thread-safe manner,
// e.g., at the beginning of the program before multiple threads are created.
static cusparseHandle_t cusparse_env = nullptr;

#ifdef MLIR_ENABLE_CUDA_CUSPARSELT
// cusparseLtHandle_t is not a pointer type, so we need an additional flag to
// indicate whether it is initialized.
static cusparseLtHandle_t cusparseLt_env;
static bool cusparseLt_initiated = false;

#endif // MLIR_ENABLE_CUDA_CUSPARSELT
#endif // MLIR_ENABLE_CUDA_CUSPARSE

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule
mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
  ScopedContext scopedContext;
  CUmodule module = nullptr;
  CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
  return module;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoadJIT(void *data,
                                                                int optLevel) {
  ScopedContext scopedContext;
  CUmodule module = nullptr;
  char jitErrorBuffer[4096] = {0};
  CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER,
                               CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
                               CU_JIT_OPTIMIZATION_LEVEL};
  void *jitOptionsVals[] = {jitErrorBuffer,
                            reinterpret_cast<void *>(sizeof(jitErrorBuffer)),
                            reinterpret_cast<void *>(optLevel)};

  CUresult result =
      cuModuleLoadDataEx(&module, data, 3, jitOptions, jitOptionsVals);
  if (result) {
    fprintf(stderr, "JIT compilation failed with: '%s'\n", jitErrorBuffer);
    CUDA_REPORT_IF_ERROR(result);
  }
  return module;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) {
  CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction
mgpuModuleGetFunction(CUmodule module, const char *name) {
  CUfunction function = nullptr;
  CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
  return function;
}
// The wrapper uses intptr_t instead of CUDA's unsigned int to match
// the type of MLIR's index type. This avoids the need for casts in the
// generated MLIR code.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
                 intptr_t gridZ, intptr_t blockX, intptr_t blockY,
                 intptr_t blockZ, int32_t smem, CUstream stream, void **params,
                 void **extra, size_t /*paramsCount*/) {
  ScopedContext scopedContext;
  if (smem > 0) {
    // Only query the driver when dynamic shared memory is requested; the
    // attribute lookups below are more expensive than this check.
    int32_t maxShmem = 0;
    CUdevice device = getDefaultCuDevice();
    CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
    CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
        &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
        device));
    if (maxShmem < smem) {
      fprintf(stderr,
              "Requested shared memory (%d bytes) is larger than maximum "
              "allowed shared memory (%d bytes) for this device\n",
              smem, maxShmem);
    }
    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
        function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
  }
  debug_print("Launching kernel, grid=%ld,%ld,%ld, "
              "threads: %ld, %ld, %ld, "
              "smem: %d bytes\n",
              gridX, gridY, gridZ, blockX, blockY, blockZ, smem);
  CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
                                      blockY, blockZ, smem, stream, params,
                                      extra));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() {
  ScopedContext scopedContext;
  CUstream stream = nullptr;
  CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
  return stream;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamDestroy(CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuStreamSynchronize(CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamWaitEvent(CUstream stream,
                                                              CUevent event) {
  CUDA_REPORT_IF_ERROR(cuStreamWaitEvent(stream, event, /*flags=*/0));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUevent mgpuEventCreate() {
  ScopedContext scopedContext;
  CUevent event = nullptr;
  CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING));
  return event;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventDestroy(CUevent event) {
  CUDA_REPORT_IF_ERROR(cuEventDestroy(event));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventSynchronize(CUevent event) {
  CUDA_REPORT_IF_ERROR(cuEventSynchronize(event));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventRecord(CUevent event,
                                                          CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuMemAlloc(uint64_t sizeBytes, CUstream stream, bool isHostShared) {
  ScopedContext scopedContext;
  CUdeviceptr ptr = 0;
  if (sizeBytes == 0)
    return reinterpret_cast<void *>(ptr);

  if (isHostShared) {
    CUDA_REPORT_IF_ERROR(
        cuMemAllocManaged(&ptr, sizeBytes, CU_MEM_ATTACH_GLOBAL));
    return reinterpret_cast<void *>(ptr);
  }
  CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
  return reinterpret_cast<void *>(ptr);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemFree(void *ptr,
                                                      CUstream /*stream*/) {
  CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
}
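// Illustrative sketch of how MLIR-generated host code is expected to drive
// the wrappers above; identifiers such as `cubinBlob`, `blobSize`, `args`,
// and `numArgs` are hypothetical placeholders, not part of this file:
//
//   CUmodule module = mgpuModuleLoad(cubinBlob, blobSize);
//   CUfunction func = mgpuModuleGetFunction(module, "kernel_name");
//   CUstream stream = mgpuStreamCreate();
//   void *buffer = mgpuMemAlloc(numBytes, stream, /*isHostShared=*/false);
//   mgpuLaunchKernel(func, gridX, 1, 1, blockX, 1, 1, /*smem=*/0, stream,
//                    args, /*extra=*/nullptr, /*paramsCount=*/numArgs);
//   mgpuStreamSynchronize(stream);
//   mgpuMemFree(buffer, stream);
//   mgpuStreamDestroy(stream);
//   mgpuModuleUnload(module);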
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemcpy(void *dst, void *src, size_t sizeBytes, CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst),
                                     reinterpret_cast<CUdeviceptr>(src),
                                     sizeBytes, stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemset32(void *dst, unsigned int value, size_t count, CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemsetD32Async(reinterpret_cast<CUdeviceptr>(dst),
                                        value, count, stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemset16(void *dst, unsigned short value, size_t count, CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemsetD16Async(reinterpret_cast<CUdeviceptr>(dst),
                                        value, count, stream));
}

///
/// Helper functions for writing MLIR example code
///

// Registers a byte array with the CUDA runtime. Helpful until we have
// transfer functions implemented.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
  ScopedContext scopedContext;
  CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
}

/// Registers a memref with the CUDA runtime. `descriptor` is a pointer to a
/// ranked memref descriptor struct of rank `rank`. Helpful until we have
/// transfer functions implemented.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
                          int64_t elementSizeBytes) {
  // Only densely packed tensors are currently supported.
#ifdef _WIN32
  int64_t *denseStrides = (int64_t *)_alloca(rank * sizeof(int64_t));
#else
  int64_t *denseStrides = (int64_t *)alloca(rank * sizeof(int64_t));
#endif // _WIN32
  int64_t *sizes = descriptor->sizes;
  for (int64_t i = rank - 1, runningStride = 1; i >= 0; i--) {
    denseStrides[i] = runningStride;
    runningStride *= sizes[i];
  }
  uint64_t sizeBytes = sizes[0] * denseStrides[0] * elementSizeBytes;
  int64_t *strides = &sizes[rank];
  (void)strides;
  for (unsigned i = 0; i < rank; ++i)
    assert(strides[i] == denseStrides[i] &&
           "Mismatch in computed dense strides");

  auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
  mgpuMemHostRegister(ptr, sizeBytes);
}
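// Worked example for the stride computation above (illustrative only,
// assuming a densely packed memref<2x3xf32>, i.e. elementSizeBytes == 4):
// sizes = {2, 3} yields denseStrides = {3, 1}, and the registered byte size
// is sizes[0] * denseStrides[0] * elementSizeBytes = 2 * 3 * 4 = 24 bytes.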
// Unregisters a byte array with the CUDA runtime.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemHostUnregister(void *ptr) {
  ScopedContext scopedContext;
  CUDA_REPORT_IF_ERROR(cuMemHostUnregister(ptr));
}

/// Unregisters a memref with the CUDA runtime. `descriptor` is a pointer to a
/// ranked memref descriptor struct of rank `rank`.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostUnregisterMemRef(int64_t rank,
                            StridedMemRefType<char, 1> *descriptor,
                            int64_t elementSizeBytes) {
  auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
  mgpuMemHostUnregister(ptr);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
  defaultDevice = device;
}

///
/// Runtime methods using CUDA 12.0+ driver
///

#if (CUDA_VERSION >= 12000)

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
    CUfunction function, intptr_t clusterX, intptr_t clusterY,
    intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
    intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
    CUstream stream, void **params, void **extra, size_t /*paramsCount*/) {
  ScopedContext scopedContext;
  if (smem > 0) {
    // Only query the driver when dynamic shared memory is requested; the
    // attribute lookups below are more expensive than this check.
    int32_t maxShmem = 0;
    CUdevice device = getDefaultCuDevice();
    CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
    CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
        &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
        device));
    if (maxShmem < smem) {
      fprintf(stderr,
              "Requested shared memory (%d bytes) is larger than maximum "
              "allowed shared memory (%d bytes) for this device\n",
              smem, maxShmem);
    }
    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
        function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
  }
  CUlaunchConfig config;
  config.gridDimX = gridX;
  config.gridDimY = gridY;
  config.gridDimZ = gridZ;
  config.blockDimX = blockX;
  config.blockDimY = blockY;
  config.blockDimZ = blockZ;
  config.sharedMemBytes = smem;
  config.hStream = stream;
  CUlaunchAttribute launchAttr[2];
  launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  launchAttr[0].value.clusterDim.x = clusterX;
  launchAttr[0].value.clusterDim.y = clusterY;
  launchAttr[0].value.clusterDim.z = clusterZ;
  launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
  launchAttr[1].value.clusterSchedulingPolicyPreference =
      CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
  config.numAttrs = 2;
  config.attrs = launchAttr;

  debug_print("Launching kernel, "
              "cluster: %ld, %ld, %ld, "
              "grid=%ld,%ld,%ld, "
              "threads: %ld, %ld, %ld, "
              "smem: %d bytes\n",
              clusterX, clusterY, clusterZ, gridX, gridY, gridZ, blockX, blockY,
              blockZ, smem);

  CUDA_REPORT_IF_ERROR(cuLaunchKernelEx(&config, function, params, extra));
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuTensorMapEncodeTiled(
    CUtensorMap *tensorMap,             // Tensor map object
    CUtensorMapDataType tensorDataType, // Tensor data type
    cuuint32_t tensorRank,              // Dimensionality of tensor
    void *globalAddress,                // Starting address
    const cuuint64_t *globalDim,        // Tensor size (number of elements)
    const cuuint64_t *globalStrides,    // Stride size (in bytes)
    const cuuint32_t *boxDim,           // Traversal box (number of elements)
    const cuuint32_t *elementStrides,   // Traversal stride
    CUtensorMapInterleave interleave,   // Type of interleaved layout
    CUtensorMapSwizzle swizzle,         // Bank swizzling pattern
    CUtensorMapL2promotion l2Promotion, // L2 promotion size
    CUtensorMapFloatOOBfill oobFill     // Padding zfill or NaN fill
) {
  ScopedContext scopedContext;
  CUDA_REPORT_IF_ERROR(cuTensorMapEncodeTiled(
      tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
      globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
      oobFill));
  debug_print("Created TMA descriptor\n Addr: %p\n"
              "data type : %d\n"
              "rank : %d\n"
              "globalDim[5]: %zu, %zu, %zu, %zu, %zu\n"
              "globalStrides[5]: %zu, %zu, %zu, %zu, %zu\n"
              "boxDim[5]: %u, %u, %u, %u, %u\n"
              "elementStrides[5]: %u, %u, %u, %u, %u\n"
              "interleave: %u \n"
              "swizzle: %u \n"
              "l2Promotion: %u \n"
              "oobFill: %u \n",
              (void *)tensorMap, tensorDataType, tensorRank, globalDim[0],
              globalDim[1], globalDim[2], globalDim[3], globalDim[4],
              globalStrides[0], globalStrides[1], globalStrides[2],
              globalStrides[3], globalStrides[4], boxDim[0], boxDim[1],
              boxDim[2], boxDim[3], boxDim[4], elementStrides[0],
              elementStrides[1], elementStrides[2], elementStrides[3],
              elementStrides[4], interleave, swizzle, l2Promotion, oobFill);
}

template <int Rank>
void mgpuGetMemRefDataAndShape(void *rawDescriptor, char **addr,
                               uint64_t *globalDim, uint64_t *globalStrides,
                               const CUtensorMapDataType tensorDataType) {
  auto descriptor =
      reinterpret_cast<StridedMemRefType<char, Rank> *>(rawDescriptor);
  *addr = descriptor->data;
  for (int i = 0; i < Rank; ++i) {
    globalDim[i] = static_cast<uint64_t>(descriptor->sizes[Rank - i - 1]);
  }
  static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
                                               4, 8, 2, 4, 4, 4};
  for (int i = 0; i < Rank - 1; ++i) {
    globalStrides[i] = static_cast<uint64_t>(
        descriptor->strides[Rank - i - 2] * elementSizeInBytes[tensorDataType]);
  }
}
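// Worked example for the dim/stride conversion above (illustrative only,
// assuming a densely packed, row-major memref<64x128xf16>, for which the
// elementSizeInBytes table gives 2 bytes): the descriptor holds
// sizes = {64, 128} and strides = {128, 1}, so the function produces
// globalDim = {128, 64} (fastest-varying dimension first, as CUtensorMap
// expects) and globalStrides = {128 * 2} = {256} bytes.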
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
    int64_t tensorRank,                       // Dimensionality of tensor
    void *rankedDescriptor,                   // Ranked MemRef descriptor
    const CUtensorMapDataType tensorDataType, // Tensor data type
    CUtensorMapInterleave interleave,         // Type of interleaved layout
    CUtensorMapSwizzle swizzle,               // Bank swizzling pattern
    CUtensorMapL2promotion l2Promotion,       // L2 promotion size
    CUtensorMapFloatOOBfill oobFill,          // Padding zfill or NaN fill
    int64_t *inputBoxDims                     // Box size (number of elements)
) {
  CUtensorMap tensorMap;

  uint32_t boxDim[5] = {1, 1, 1, 1, 1}, elementStrides[5] = {1, 1, 1, 1, 1};
  uint64_t globalDim[5] = {1, 1, 1, 1, 1}, globalStrides[5] = {0};
  uint32_t tensorRank32 = uint32_t(tensorRank);

  char *globalAddress = nullptr;
  switch (tensorRank) {
  case 1:
    mgpuGetMemRefDataAndShape<1>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  case 2:
    mgpuGetMemRefDataAndShape<2>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  case 3:
    mgpuGetMemRefDataAndShape<3>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  case 4:
    mgpuGetMemRefDataAndShape<4>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  case 5:
    mgpuGetMemRefDataAndShape<5>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  default:
    fprintf(
        stderr,
        "'mgpuTensorMapEncodeTiledMemref' failed with 'rank is too high'\n");
    return nullptr;
  }

  for (int64_t r = 0; r < tensorRank; ++r) {
    boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
  }

  ScopedContext scopedContext;
  mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32,
                           globalAddress, globalDim, globalStrides, boxDim,
                           elementStrides, interleave, swizzle, l2Promotion,
                           oobFill);
  // Copy the created tensor map to the device.
  CUdeviceptr dTensorMap;
  CUDA_REPORT_IF_ERROR(cuMemAlloc(&dTensorMap, sizeof(CUtensorMap)));
  CUDA_REPORT_IF_ERROR(cuMemcpy(dTensorMap,
                                reinterpret_cast<CUdeviceptr>(&tensorMap),
                                sizeof(CUtensorMap)));
  return reinterpret_cast<void *>(dTensorMap);
}
#endif // CUDA_VERSION >= 12000

#ifdef MLIR_ENABLE_CUDA_CUSPARSE

///
/// Wrapper methods for the cuSparse library.
///

// Some macro magic to get float/double alpha and beta on host.
// TODO: add support for passing alpha and beta as arguments
#define ALPHABETA(dtp, alpha, beta)                                           \
  __nv_bfloat16(alpha##16bf) = 1.0f;                                          \
  __nv_bfloat16(beta##16bf) = 1.0f;                                           \
  __half(alpha##16f) = 1.0f;                                                  \
  __half(beta##16f) = 1.0f;                                                   \
  float(alpha##f) = 1.0f;                                                     \
  float(beta##f) = 1.0f;                                                      \
  double(alpha##d) = 1.0;                                                     \
  double(beta##d) = 1.0;                                                      \
  const void *(alpha##p) = nullptr;                                           \
  const void *(beta##p) = nullptr;                                            \
  if (dtp == CUDA_R_16BF || dtp == CUDA_C_16BF) {                             \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##16bf));                    \
    (beta##p) = reinterpret_cast<void *>(&(beta##16bf));                      \
  } else if (dtp == CUDA_R_16F || dtp == CUDA_C_16F) {                        \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##16f));                     \
    (beta##p) = reinterpret_cast<void *>(&(beta##16f));                       \
  } else if (dtp == CUDA_R_32F || dtp == CUDA_C_32F) {                        \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##f));                       \
    (beta##p) = reinterpret_cast<void *>(&(beta##f));                         \
  } else {                                                                    \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##d));                       \
    (beta##p) = reinterpret_cast<void *>(&(beta##d));                         \
  }
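// Illustrative note on ALPHABETA: after `ALPHABETA(cTp, alpha, beta)` expands,
// the names `alphap` and `betap` are host pointers to a constant 1.0 of the
// element type selected by `cTp` (bf16, f16, f32, or f64). They are meant to
// be passed as the alpha/beta arguments of the cuSPARSE calls below, e.g.:
//
//   ALPHABETA(cTp, alpha, beta)
//   cusparseSpMV(cusparse_env, modeA, alphap, matA, vecX, betap, vecY, cTp,
//                CUSPARSE_SPMV_ALG_DEFAULT, buf);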
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCreateSparseEnv() {
  // ScopedContext is for cuda initialization.
  ScopedContext scopedContext;
  assert(!cusparse_env && "client called mgpuCreateSparseEnv() twice");
  CUSPARSE_REPORT_IF_ERROR(cusparseCreate(&cusparse_env));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuDestroySparseEnv() {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroy(cusparse_env));
  cusparse_env = nullptr;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateDnVec(intptr_t size, void *values, int32_t dtp, CUstream /*stream*/) {
  cusparseDnVecDescr_t vec = nullptr;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnVec(&vec, size, values, dTp))
  return reinterpret_cast<void *>(vec);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyDnVec(void *v, CUstream /*stream*/) {
  cusparseDnVecDescr_t vec = reinterpret_cast<cusparseDnVecDescr_t>(v);
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroyDnVec(vec))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateDnMat(intptr_t rows, intptr_t cols, void *values, int32_t dtp,
                CUstream /*stream*/) {
  cusparseDnMatDescr_t mat = nullptr;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnMat(&mat, rows, cols, /*ld=*/cols,
                                               values, dTp, CUSPARSE_ORDER_ROW))
  return reinterpret_cast<void *>(mat);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyDnMat(void *m, CUstream /*stream*/) {
  cusparseDnMatDescr_t mat = reinterpret_cast<cusparseDnMatDescr_t>(m);
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroyDnMat(mat))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCoo(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowIdxs,
              void *colIdxs, void *values, int32_t itp, int32_t dtp,
              CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCoo(&mat, rows, cols, nnz, rowIdxs,
                                             colIdxs, values, iTp,
                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}

#ifdef CUSPARSE_COO_AOS // deprecated in cuSPARSE 11.2
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCooAoS(intptr_t rows, intptr_t cols, intptr_t nnz, void *idxs,
                 void *values, int32_t itp, int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCooAoS(
      &mat, rows, cols, nnz, idxs, values, iTp, CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}
#endif // CUSPARSE_COO_AOS

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCsr(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowPos,
              void *colIdxs, void *values, int32_t ptp, int32_t itp,
              int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto pTp = static_cast<cusparseIndexType_t>(ptp);
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsr(&mat, rows, cols, nnz, rowPos,
                                             colIdxs, values, pTp, iTp,
                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCsc(intptr_t rows, intptr_t cols, intptr_t nnz, void *colPos,
              void *rowIdxs, void *values, int32_t ptp, int32_t itp,
              int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto pTp = static_cast<cusparseIndexType_t>(ptp);
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsc(&mat, rows, cols, nnz, colPos,
                                             rowIdxs, values, pTp, iTp,
                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateBsr(intptr_t brows, intptr_t bcols, intptr_t bnnz, intptr_t rBsz,
              intptr_t cBsz, void *rowPos, void *colIdxs, void *values,
              int32_t ptp, int32_t itp, int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
#if CUSPARSE_VERSION >= 12100
  auto pTp = static_cast<cusparseIndexType_t>(ptp);
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateBsr(
      &mat, brows, bcols, bnnz, rBsz, cBsz, rowPos, colIdxs, values, pTp, iTp,
      CUSPARSE_INDEX_BASE_ZERO, dTp, CUSPARSE_ORDER_ROW))
#endif
  return reinterpret_cast<void *>(mat);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroySpMat(void *m, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = reinterpret_cast<cusparseSpMatDescr_t>(m);
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroySpMat(mat))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMVBufferSize(
    int32_t ma, void *a, void *x, void *y, int32_t ctp, CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
  cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t bufferSize = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV_bufferSize(
      cusparse_env, modeA, alphap, matA, vecX, betap, vecY, cTp,
      CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
  return bufferSize;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(int32_t ma, void *a, void *x,
                                                   void *y, int32_t ctp,
                                                   void *buf,
                                                   CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
  cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(cusparse_env, modeA, alphap, matA, vecX,
                                        betap, vecY, cTp,
                                        CUSPARSE_SPMV_ALG_DEFAULT, buf))
}
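// Illustrative call sequence for the SpMV wrappers above. Generated host code
// is expected to follow the query-then-compute protocol; `stream`, `ma`,
// `ctp`, and the descriptor handles are hypothetical placeholders created via
// the wrappers earlier in this file:
//
//   intptr_t bufSize = mgpuSpMVBufferSize(ma, matA, vecX, vecY, ctp, stream);
//   void *buf = mgpuMemAlloc(bufSize, stream, /*isHostShared=*/false);
//   mgpuSpMV(ma, matA, vecX, vecY, ctp, buf, stream);
//   mgpuMemFree(buf, stream);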
extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSpMMBufferSize(int32_t ma, int32_t mb, void *a, void *b, void *c,
                   int32_t ctp, CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t bufferSize = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
  return bufferSize;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMM(int32_t ma, int32_t mb,
                                                   void *a, void *b, void *c,
                                                   int32_t ctp, void *buf,
                                                   CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(cusparse_env, modeA, modeB, alphap,
                                        matA, matB, betap, matC, cTp,
                                        CUSPARSE_SPMM_ALG_DEFAULT, buf))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSDDMMBufferSize(int32_t ma, int32_t mb, void *a, void *b, void *c,
                    int32_t ctp, CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t bufferSize = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM_bufferSize(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SDDMM_ALG_DEFAULT, &bufferSize))
  return bufferSize;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSDDMM(int32_t ma, int32_t mb,
                                                    void *a, void *b, void *c,
                                                    int32_t ctp, void *buf,
                                                    CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM(cusparse_env, modeA, modeB, alphap,
                                         matA, matB, betap, matC, cTp,
                                         CUSPARSE_SDDMM_ALG_DEFAULT, buf))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuSpGEMMCreateDescr(CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = nullptr;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_createDescr(&spgemmDesc))
  return reinterpret_cast<void *>(spgemmDesc);
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSpGEMMDestroyDescr(void *s, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_destroyDescr(spgemmDesc))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMWorkEstimation(
    void *s, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t ctp,
    intptr_t bs, void *buf, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t newBufferSize = bs;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_workEstimation(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize, buf))
  return newBufferSize;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSpGEMMCompute(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
                  int32_t ctp, intptr_t bsz2, void *buf2, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t newBufferSize2 = bsz2;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_compute(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize2, buf2))
  return newBufferSize2;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSpGEMMCopy(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
               int32_t ctp, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(
      cusparseSpGEMM_copy(cusparse_env, modeA, modeB, alphap, matA, matB, betap,
                          matC, cTp, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSpMatGetSize(void *m, void *r, void *c, void *n, CUstream /*stream*/) {
  cusparseConstSpMatDescr_t matDescr =
      reinterpret_cast<cusparseConstSpMatDescr_t>(m);
  int64_t *rows = reinterpret_cast<int64_t *>(r);
  int64_t *cols = reinterpret_cast<int64_t *>(c);
  int64_t *nnz = reinterpret_cast<int64_t *>(n);
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMatGetSize(matDescr, rows, cols, nnz));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSetCsrPointers(void *m, void *p, void *c, void *v, CUstream /*stream*/) {
  cusparseSpMatDescr_t matDescr = reinterpret_cast<cusparseSpMatDescr_t>(m);
  CUSPARSE_REPORT_IF_ERROR(cusparseCsrSetPointers(matDescr, p, c, v));
}
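// Illustrative SpGEMM driving sequence for the wrappers above. The underlying
// cuSPARSE SpGEMM API is two-phase: workEstimation and compute are each
// called once with a null buffer to query a size and once more with the
// allocated buffer. Identifiers such as `stream`, `ma`, `mb`, `ctp`, and the
// matrix handles are hypothetical placeholders:
//
//   void *desc = mgpuSpGEMMCreateDescr(stream);
//   intptr_t sz1 = mgpuSpGEMMWorkEstimation(desc, ma, mb, matA, matB, matC,
//                                           ctp, /*bs=*/0, nullptr, stream);
//   void *buf1 = mgpuMemAlloc(sz1, stream, false);
//   mgpuSpGEMMWorkEstimation(desc, ma, mb, matA, matB, matC, ctp, sz1, buf1,
//                            stream);
//   intptr_t sz2 = mgpuSpGEMMCompute(desc, ma, mb, matA, matB, matC, ctp,
//                                    /*bsz2=*/0, nullptr, stream);
//   void *buf2 = mgpuMemAlloc(sz2, stream, false);
//   mgpuSpGEMMCompute(desc, ma, mb, matA, matB, matC, ctp, sz2, buf2, stream);
//   mgpuSpMatGetSize(matC, &rows, &cols, &nnz, stream);
//   // ...allocate C's positions/coordinates/values accordingly, then:
//   mgpuSetCsrPointers(matC, pos, crd, val, stream);
//   mgpuSpGEMMCopy(desc, ma, mb, matA, matB, matC, ctp, stream);
//   mgpuSpGEMMDestroyDescr(desc, stream);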
#ifdef MLIR_ENABLE_CUDA_CUSPARSELT

///
/// Wrapper methods for the cuSparseLt library.
///

struct cusparseLtSpMatHandleAndData {
  cusparseLtMatDescriptor_t mat;
  // TODO: the following three are associated with the SpMM operator rather
  // than the sparse matrix. Create workspace buffers and pass them to the
  // SpMM execution.
  cusparseLtMatmulAlgSelection_t alg_sel;
  cusparseLtMatmulPlan_t plan;
  cusparseLtMatmulDescriptor_t matmul;
  void *values{nullptr};
};

struct cusparseLtDnMatHandleAndData {
  cusparseLtMatDescriptor_t mat;
  void *values{nullptr};
};

static_assert(sizeof(cusparseLtHandle_t) == 11024,
              "Unexpected cusparseLt handle size");
static_assert(sizeof(cusparseLtSpMatHandleAndData) == 44104,
              "Unexpected cusparseLt sparse matrix handle size");
static_assert(sizeof(cusparseLtDnMatHandleAndData) == 11032,
              "Unexpected cusparseLt dense matrix handle size");

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCreateSparseLtEnv() {
  // ScopedContext is for cuda initialization.
  ScopedContext scopedContext;
  assert(!cusparseLt_initiated &&
         "client called mgpuCreateSparseLtEnv() twice");
  // Note that cuSparseLt still uses cusparseStatus_t.
  CUSPARSE_REPORT_IF_ERROR(cusparseLtInit(&cusparseLt_env));
  cusparseLt_initiated = true;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuDestroySparseLtEnv() {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  CUSPARSE_REPORT_IF_ERROR(cusparseLtDestroy(&cusparseLt_env));
  cusparseLt_initiated = false;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCreateCuSparseLtDnMat(void *dh, intptr_t rows, intptr_t cols, void *values,
                          int32_t dtp, CUstream /*stream*/) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  auto dnmat_handle = reinterpret_cast<cusparseLtDnMatHandleAndData *>(dh);
  dnmat_handle->values = values;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  // Assume row-major when deciding lda.
  const uint32_t alignment = 16;
  CUSPARSE_REPORT_IF_ERROR(cusparseLtDenseDescriptorInit(
      &cusparseLt_env, &(dnmat_handle->mat), rows, cols, /*lda=*/cols,
      alignment, dTp, CUSPARSE_ORDER_ROW))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyCuSparseLtDnMat(void *dh, CUstream /*stream*/) {
  auto dnmat_handle = reinterpret_cast<cusparseLtDnMatHandleAndData *>(dh);
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(dnmat_handle->mat)))
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCusparseLtCreate2To4SpMat(void *sh, intptr_t rows, intptr_t cols,
                              void *values, int32_t dtp, CUstream /*stream*/) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  auto spmat_handle = reinterpret_cast<cusparseLtSpMatHandleAndData *>(sh);
  spmat_handle->values = values;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  // Assume row-major when deciding lda.
  const uint32_t alignment = 16;
  CUSPARSE_REPORT_IF_ERROR(cusparseLtStructuredDescriptorInit(
      &cusparseLt_env, &(spmat_handle->mat), rows, cols, /*ld=*/cols, alignment,
      dTp, CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyCuSparseLtSpMat(void *sh, CUstream /*stream*/) {
  auto spmat_handle = reinterpret_cast<cusparseLtSpMatHandleAndData *>(sh);
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(spmat_handle->mat)))
}

// Several things happen in this stage: algorithm selection, planning, and
// returning the workspace and compressed-matrix data buffer sizes.
// The parameter prune_flag indicates whether pruning and the pruning check
// are performed: 0 means neither prune nor prune check, 1 means prune, and
// 2 means prune and prune check.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCuSparseLtSpMMBufferSize(void *bs, int32_t ma, int32_t mb, void *a, void *b,
                             void *c, int32_t ctp, int32_t prune_flag,
                             CUstream stream) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  // TODO: support more advanced settings, e.g., the input right operand being
  // a sparse matrix; for now we assume matA is the sparse matrix.
  auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
  auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
  auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);
  auto workspace_size = reinterpret_cast<size_t *>(bs);
  auto compressed_size = &(reinterpret_cast<size_t *>(bs)[1]);
  auto compressed_buffer_size = &(reinterpret_cast<size_t *>(bs)[2]);
  auto cTp = static_cast<cusparseComputeType>(ctp);

  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulDescriptorInit(
      &cusparseLt_env, &(matA->matmul), modeA, modeB, &(matA->mat),
      &(matB->mat), &(matC->mat), &(matC->mat), cTp))
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSelectionInit(
      &cusparseLt_env, &(matA->alg_sel), &(matA->matmul),
      CUSPARSELT_MATMUL_ALG_DEFAULT))
  int alg = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSetAttribute(
      &cusparseLt_env, &(matA->alg_sel), CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg,
      sizeof(alg)))

  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanInit(
      &cusparseLt_env, &(matA->plan), &(matA->matmul), &(matA->alg_sel)))

  // Pruning step (in-place).
  if (prune_flag > 0)
    CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPrune(
        &cusparseLt_env, &(matA->matmul), matA->values, matA->values,
        CUSPARSELT_PRUNE_SPMMA_STRIP, stream))

  // Check the structure of A.
  // Note that this adds a synchronization on the stream.
  // TODO: Do we want that?
  if (prune_flag == 2) {
    int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream, false);
    CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPruneCheck(
        &cusparseLt_env, &(matA->matmul), matA->values, dvalid, stream))
    int valid = 0;
    mgpuMemcpy(&valid, dvalid, sizeof(int), stream);
    mgpuStreamSynchronize(stream);
    mgpuMemFree(dvalid, stream);
    if (valid != 0)
      fprintf(stderr, "CUSPARSE-LT: sparse matrix is not 2:4; computed results "
                      "will be invalid\n");
  }

  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulGetWorkspace(
      &cusparseLt_env, &(matA->plan), workspace_size))
  CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMACompressedSize(
      &cusparseLt_env, &(matA->plan), compressed_size, compressed_buffer_size))
}
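// Illustrative layout of the `bs` output written by
// mgpuCuSparseLtSpMMBufferSize above and consumed before calling
// mgpuCuSparseLtSpMM below; buffer names are hypothetical placeholders:
//
//   size_t sizes[3]; // {workspace, compressed A, compression scratch}
//   mgpuCuSparseLtSpMMBufferSize(sizes, ma, mb, matA, matB, matC, ctp,
//                                /*prune_flag=*/1, stream);
//   void *workspace = mgpuMemAlloc(sizes[0], stream, false);
//   void *compressed = mgpuMemAlloc(sizes[1], stream, false);
//   void *compressedBuf = mgpuMemAlloc(sizes[2], stream, false);
//   mgpuCuSparseLtSpMM(matA, matB, matC, workspace, compressed, compressedBuf,
//                      stream);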
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCuSparseLtSpMM(void *a, void *b, void *c, void *d_workspace,
                   void *dA_compressed, void *dA_compressedBuffer,
                   CUstream stream) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
  auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
  auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);

  ALPHABETA(CUDA_R_32F, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(
      cusparseLtSpMMACompress(&cusparseLt_env, &(matA->plan), (matA->values),
                              dA_compressed, dA_compressedBuffer, stream))

  // TODO: add support for multi-stream execution.
  // Perform the matrix multiplication D = A*B+C, using C == D for now.
  CUSPARSE_REPORT_IF_ERROR(
      cusparseLtMatmul(&cusparseLt_env, &(matA->plan), alphap, dA_compressed,
                       matB->values, betap, matC->values,
                       /*dD*/ matC->values, d_workspace, nullptr, 0))

  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matA->mat)))
  // Destroy the plan associated with the sparse matrix.
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanDestroy(&(matA->plan)))
}

#endif // MLIR_ENABLE_CUDA_CUSPARSELT
#endif // MLIR_ENABLE_CUDA_CUSPARSE