//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the gpu.launch_func op into a
// sequence of GPU runtime calls. As most GPU runtimes do not have a stable
// published ABI, this pass uses a slim runtime layer that builds on top of the
// public API from GPU runtime headers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
class GpuToLLVMConversionPass
    : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  using Base::Base;
  void getDependentDialects(DialectRegistry &registry) const final {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }
  // Run the dialect converter on the module.
  void runOnOperation() override;
};

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    Type indexType = ConvertToLLVMPattern::getIndexType();
    return type.hasStaticShape()
               ? ConvertToLLVMPattern::createIndexAttrConstant(
                     rewriter, loc, indexType, type.getNumElements())
               // For identity maps (verified by caller), the number of
               // elements is stride[0] * size[0].
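               // (Illustrative note: for a `memref<?x4xf32>` with identity
               // layout, size[0] is the dynamic leading extent and stride[0]
               // is 4, so their product equals the total element count.)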
84 : rewriter.create<LLVM::MulOp>(loc, 85 desc.stride(rewriter, loc, 0), 86 desc.size(rewriter, loc, 0)); 87 } 88 89 MLIRContext *context = &this->getTypeConverter()->getContext(); 90 91 Type llvmVoidType = LLVM::LLVMVoidType::get(context); 92 LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context); 93 Type llvmInt8Type = IntegerType::get(context, 8); 94 Type llvmInt16Type = IntegerType::get(context, 16); 95 Type llvmInt32Type = IntegerType::get(context, 32); 96 Type llvmInt64Type = IntegerType::get(context, 64); 97 Type llvmFloat32Type = Float32Type::get(context); 98 Type llvmIntPtrType = IntegerType::get( 99 context, this->getTypeConverter()->getPointerBitwidth(0)); 100 101 FunctionCallBuilder streamCreateCallBuilder = { 102 "mgpuStreamCreate", llvmPointerType /* void *stream */, {}}; 103 FunctionCallBuilder streamDestroyCallBuilder = { 104 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}}; 105 FunctionCallBuilder streamSynchronizeCallBuilder = { 106 "mgpuStreamSynchronize", 107 llvmVoidType, 108 {llvmPointerType /* void *stream */}}; 109 FunctionCallBuilder streamWaitEventCallBuilder = { 110 "mgpuStreamWaitEvent", 111 llvmVoidType, 112 {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}}; 113 FunctionCallBuilder eventCreateCallBuilder = { 114 "mgpuEventCreate", llvmPointerType /* void *event */, {}}; 115 FunctionCallBuilder eventDestroyCallBuilder = { 116 "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}}; 117 FunctionCallBuilder eventSynchronizeCallBuilder = { 118 "mgpuEventSynchronize", 119 llvmVoidType, 120 {llvmPointerType /* void *event */}}; 121 FunctionCallBuilder eventRecordCallBuilder = { 122 "mgpuEventRecord", 123 llvmVoidType, 124 {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}}; 125 FunctionCallBuilder hostRegisterCallBuilder = { 126 "mgpuMemHostRegisterMemRef", 127 llvmVoidType, 128 {llvmIntPtrType /* intptr_t rank */, 129 llvmPointerType /* void *memrefDesc */, 130 llvmIntPtrType /* intptr_t elementSizeBytes */}}; 131 FunctionCallBuilder hostUnregisterCallBuilder = { 132 "mgpuMemHostUnregisterMemRef", 133 llvmVoidType, 134 {llvmIntPtrType /* intptr_t rank */, 135 llvmPointerType /* void *memrefDesc */, 136 llvmIntPtrType /* intptr_t elementSizeBytes */}}; 137 FunctionCallBuilder allocCallBuilder = { 138 "mgpuMemAlloc", 139 llvmPointerType /* void * */, 140 {llvmIntPtrType /* intptr_t sizeBytes */, 141 llvmPointerType /* void *stream */, 142 llvmInt8Type /* bool isHostShared */}}; 143 FunctionCallBuilder deallocCallBuilder = { 144 "mgpuMemFree", 145 llvmVoidType, 146 {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}}; 147 FunctionCallBuilder memcpyCallBuilder = { 148 "mgpuMemcpy", 149 llvmVoidType, 150 {llvmPointerType /* void *dst */, llvmPointerType /* void *src */, 151 llvmIntPtrType /* intptr_t sizeBytes */, 152 llvmPointerType /* void *stream */}}; 153 FunctionCallBuilder memset16CallBuilder = { 154 "mgpuMemset16", 155 llvmVoidType, 156 {llvmPointerType /* void *dst */, 157 llvmInt16Type /* unsigned short value */, 158 llvmIntPtrType /* intptr_t sizeBytes */, 159 llvmPointerType /* void *stream */}}; 160 FunctionCallBuilder memset32CallBuilder = { 161 "mgpuMemset32", 162 llvmVoidType, 163 {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */, 164 llvmIntPtrType /* intptr_t sizeBytes */, 165 llvmPointerType /* void *stream */}}; 166 FunctionCallBuilder setDefaultDeviceCallBuilder = { 167 "mgpuSetDefaultDevice", 168 
llvmVoidType, 169 {llvmInt32Type /* uint32_t devIndex */}}; 170 FunctionCallBuilder createDnVecCallBuilder = { 171 "mgpuCreateDnVec", 172 llvmPointerType, 173 {llvmIntPtrType, llvmPointerType, llvmInt32Type, 174 llvmPointerType /* void *stream */}}; 175 FunctionCallBuilder destroyDnVecCallBuilder = { 176 "mgpuDestroyDnVec", 177 llvmVoidType, 178 {llvmPointerType, llvmPointerType /* void *stream */}}; 179 FunctionCallBuilder createDnMatCallBuilder = { 180 "mgpuCreateDnMat", 181 llvmPointerType, 182 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type, 183 llvmPointerType /* void *stream */}}; 184 FunctionCallBuilder destroyDnMatCallBuilder = { 185 "mgpuDestroyDnMat", 186 llvmVoidType, 187 {llvmPointerType, llvmPointerType /* void *stream */}}; 188 FunctionCallBuilder createCooCallBuilder = { 189 "mgpuCreateCoo", 190 llvmPointerType, 191 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, 192 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, 193 llvmPointerType /* void *stream */}}; 194 FunctionCallBuilder createCooAoSCallBuilder = { 195 "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2 196 llvmPointerType, 197 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, 198 llvmPointerType, llvmInt32Type, llvmInt32Type, 199 llvmPointerType /* void *stream */}}; 200 FunctionCallBuilder createCsrCallBuilder = { 201 "mgpuCreateCsr", 202 llvmPointerType, 203 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, 204 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, 205 llvmInt32Type, llvmPointerType /* void *stream */}}; 206 FunctionCallBuilder createCscCallBuilder = { 207 "mgpuCreateCsc", 208 llvmPointerType, 209 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, 210 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, 211 llvmInt32Type, llvmPointerType /* void *stream */}}; 212 FunctionCallBuilder createBsrCallBuilder = { 213 "mgpuCreateBsr", 214 llvmPointerType, 215 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, 216 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType, 217 llvmInt32Type, llvmInt32Type, llvmInt32Type, 218 llvmPointerType /* void *stream */}}; 219 FunctionCallBuilder destroySpMatCallBuilder = { 220 "mgpuDestroySpMat", 221 llvmVoidType, 222 {llvmPointerType, llvmPointerType /* void *stream */}}; 223 FunctionCallBuilder spMVBufferSizeCallBuilder = { 224 "mgpuSpMVBufferSize", 225 llvmIntPtrType, 226 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType, 227 llvmInt32Type, llvmPointerType /* void *stream */}}; 228 FunctionCallBuilder spMVCallBuilder = { 229 "mgpuSpMV", 230 llvmVoidType, 231 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType, 232 llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}}; 233 FunctionCallBuilder createSpMMBufferSizeCallBuilder = { 234 "mgpuSpMMBufferSize", 235 llvmIntPtrType, 236 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType, 237 llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}}; 238 FunctionCallBuilder createSpMMCallBuilder = { 239 "mgpuSpMM", 240 llvmVoidType, 241 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType, 242 llvmPointerType, llvmInt32Type, llvmPointerType, 243 llvmPointerType /* void *stream */}}; 244 FunctionCallBuilder createSDDMMBufferSizeCallBuilder = { 245 "mgpuSDDMMBufferSize", 246 llvmIntPtrType, 247 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType, 248 llvmPointerType, llvmInt32Type, llvmPointerType /* 
void *stream */}}; 249 FunctionCallBuilder createSDDMMCallBuilder = { 250 "mgpuSDDMM", 251 llvmVoidType, 252 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType, 253 llvmPointerType, llvmInt32Type, llvmPointerType, 254 llvmPointerType /* void *stream */}}; 255 FunctionCallBuilder createLtDnMatCallBuilder = { 256 "mgpuCreateCuSparseLtDnMat", 257 llvmVoidType, 258 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, 259 llvmInt32Type, llvmPointerType /* void *stream */}}; 260 FunctionCallBuilder destroyCuSparseLtSpMatBuilder = { 261 "mgpuDestroyCuSparseLtSpMat", 262 llvmVoidType, 263 {llvmPointerType, llvmPointerType /* void *stream */}}; 264 FunctionCallBuilder destroyCuSparseLtDnMatBuilder = { 265 "mgpuDestroyCuSparseLtDnMat", 266 llvmVoidType, 267 {llvmPointerType, llvmPointerType /* void *stream */}}; 268 FunctionCallBuilder create2To4SpMatCallBuilder = { 269 "mgpuCusparseLtCreate2To4SpMat", 270 llvmVoidType, 271 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, 272 llvmInt32Type, llvmPointerType /* void *stream */}}; 273 FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = { 274 "mgpuCuSparseLtSpMMBufferSize", 275 llvmVoidType, 276 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType, 277 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, 278 llvmPointerType /*void *stream*/}}; 279 FunctionCallBuilder createCuSparseLtSpMMBuilder = { 280 "mgpuCuSparseLtSpMM", 281 llvmVoidType, 282 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, 283 llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}}; 284 FunctionCallBuilder createSpGEMMCreateDescrBuilder = { 285 "mgpuSpGEMMCreateDescr", 286 llvmPointerType, 287 {llvmPointerType /*void *stream*/}}; 288 FunctionCallBuilder createSpGEMMDestroyDescrBuilder = { 289 "mgpuSpGEMMDestroyDescr", 290 llvmVoidType, 291 {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}}; 292 FunctionCallBuilder createSpGEMMWorkEstimationBuilder = { 293 "mgpuSpGEMMWorkEstimation", 294 llvmIntPtrType, 295 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/, 296 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/, 297 llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/, 298 llvmPointerType /*void *stream*/}}; 299 FunctionCallBuilder createSpGEMMComputeBuilder = { 300 "mgpuSpGEMMCompute", 301 llvmIntPtrType, 302 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/, 303 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/, 304 llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/, 305 llvmPointerType /*void *stream*/}}; 306 FunctionCallBuilder createSpGEMMCopyBuilder = { 307 "mgpuSpGEMMCopy", 308 llvmVoidType, 309 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/, 310 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/, 311 llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}}; 312 FunctionCallBuilder createSpMatGetSizeBuilder = { 313 "mgpuSpMatGetSize", 314 llvmVoidType, 315 {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/, 316 llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}}; 317 FunctionCallBuilder createSetCsrPointersBuilder = { 318 "mgpuSetCsrPointers", 319 llvmVoidType, 320 {llvmPointerType /*spmat*/, llvmPointerType /*pos*/, 321 llvmPointerType /*crd*/, llvmPointerType /*val*/, 322 llvmPointerType /*void *stream*/}}; 323 }; 324 325 /// A rewrite pattern to convert gpu.host_register operations into a GPU 
runtime 326 /// call. Currently it supports CUDA and ROCm (HIP). 327 class ConvertHostRegisterOpToGpuRuntimeCallPattern 328 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> { 329 public: 330 ConvertHostRegisterOpToGpuRuntimeCallPattern( 331 const LLVMTypeConverter &typeConverter) 332 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {} 333 334 private: 335 LogicalResult 336 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor, 337 ConversionPatternRewriter &rewriter) const override; 338 }; 339 340 class ConvertHostUnregisterOpToGpuRuntimeCallPattern 341 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> { 342 public: 343 ConvertHostUnregisterOpToGpuRuntimeCallPattern( 344 const LLVMTypeConverter &typeConverter) 345 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) { 346 } 347 348 private: 349 LogicalResult 350 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor, 351 ConversionPatternRewriter &rewriter) const override; 352 }; 353 354 /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime 355 /// call. Currently it supports CUDA and ROCm (HIP). 356 class ConvertAllocOpToGpuRuntimeCallPattern 357 : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> { 358 public: 359 ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) 360 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {} 361 362 private: 363 LogicalResult 364 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor, 365 ConversionPatternRewriter &rewriter) const override; 366 }; 367 368 /// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime 369 /// call. Currently it supports CUDA and ROCm (HIP). 370 class ConvertDeallocOpToGpuRuntimeCallPattern 371 : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> { 372 public: 373 ConvertDeallocOpToGpuRuntimeCallPattern( 374 const LLVMTypeConverter &typeConverter) 375 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {} 376 377 private: 378 LogicalResult 379 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor, 380 ConversionPatternRewriter &rewriter) const override; 381 }; 382 383 class ConvertAsyncYieldToGpuRuntimeCallPattern 384 : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> { 385 public: 386 ConvertAsyncYieldToGpuRuntimeCallPattern( 387 const LLVMTypeConverter &typeConverter) 388 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {} 389 390 private: 391 LogicalResult 392 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor, 393 ConversionPatternRewriter &rewriter) const override; 394 }; 395 396 /// A rewrite pattern to convert gpu.wait operations into a GPU runtime 397 /// call. Currently it supports CUDA and ROCm (HIP). 398 class ConvertWaitOpToGpuRuntimeCallPattern 399 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> { 400 public: 401 ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) 402 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {} 403 404 private: 405 LogicalResult 406 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor, 407 ConversionPatternRewriter &rewriter) const override; 408 }; 409 410 /// A rewrite pattern to convert gpu.wait async operations into a GPU runtime 411 /// call. Currently it supports CUDA and ROCm (HIP). 
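/// For illustration (a sketch of the resulting IR, assuming the mgpu* wrapper
/// functions declared above), a dependency-free
///   %t = gpu.wait async
/// lowers to roughly
///   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
/// and the created stream then stands in for the async token.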
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to legalize gpu.launch_func with LLVM types.
class LegalizeLaunchFuncOpPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
                              bool kernelBarePtrCallConv,
                              bool kernelIntersperseSizeCallConv)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        kernelBarePtrCallConv(kernelBarePtrCallConv),
        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}

private:
  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  bool kernelBarePtrCallConv;
  bool kernelIntersperseSizeCallConv;
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
/// Currently supports CUDA and ROCm (HIP).
class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
            typeConverter) {}

  LogicalResult
  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Generic rewriting rule for operations on sparse matrices.
/// Currently supports CUDA (by means of cuSparse and cuSparseLt).
490 #define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \ 491 class Convert##op_name##ToGpuRuntimeCallPattern \ 492 : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \ 493 public: \ 494 Convert##op_name##ToGpuRuntimeCallPattern( \ 495 const LLVMTypeConverter &typeConverter) \ 496 : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \ 497 \ 498 private: \ 499 LogicalResult \ 500 matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \ 501 ConversionPatternRewriter &rewriter) const override; \ 502 }; 503 504 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp) 505 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp) 506 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp) 507 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp) 508 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp) 509 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp) 510 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp) 511 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp) 512 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp) 513 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp) 514 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp) 515 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp) 516 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp) 517 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp) 518 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp) 519 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp) 520 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp) 521 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp) 522 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp) 523 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp) 524 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp) 525 526 } // namespace 527 528 void GpuToLLVMConversionPass::runOnOperation() { 529 MLIRContext *context = &getContext(); 530 531 // Perform progressive lowering of vector transfer operations. 532 { 533 RewritePatternSet patterns(&getContext()); 534 // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 535 vector::populateVectorTransferLoweringPatterns(patterns, 536 /*maxTransferRank=*/1); 537 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 538 return signalPassFailure(); 539 } 540 541 LowerToLLVMOptions options(context); 542 options.useBarePtrCallConv = hostBarePtrCallConv; 543 RewritePatternSet patterns(context); 544 ConversionTarget target(*context); 545 target.addLegalDialect<LLVM::LLVMDialect>(); 546 LLVMTypeConverter converter(context, options); 547 548 // Populate all patterns from all dialects that implement the 549 // `ConvertToLLVMPatternInterface` interface. 550 for (Dialect *dialect : context->getLoadedDialects()) { 551 auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); 552 if (!iface) 553 continue; 554 iface->populateConvertToLLVMConversionPatterns(target, converter, patterns); 555 } 556 557 // Preserve GPU modules and binaries. Modules are preserved as they can be 558 // converted later by `gpu-module-to-binary`. 559 target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>(); 560 // Accept as legal LaunchFuncOps if the operands have been lowered. 
  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
      [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });

  // These aren't covered by the ConvertToLLVMPatternInterface right now.
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns,
                                      kernelBarePtrCallConv,
                                      kernelIntersperseSizeCallConv);

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

// Looks up the runtime wrapper function in the module, declaring it at the end
// of the module if it is not present yet, and emits a call to it.
LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(loc, function, arguments);
}

// Corresponding to cusparseIndexType_t defined in cusparse.h.
static int32_t getCuSparseIndexTypeFrom(Type type) {
  if (type.isInteger(16))
    return 1; // CUSPARSE_INDEX_16U
  if (type.isInteger(32))
    return 2; // CUSPARSE_INDEX_32I
  return 3; // CUSPARSE_INDEX_64I
}

static int32_t getCuSparseLtDataTypeFrom(Type type) {
  if (type.isF16())
    return 0; // CUSPARSE_COMPUTE_16F
  if (type.isInteger(32))
    return 1; // CUSPARSE_COMPUTE_32I
  llvm_unreachable("unsupported type");
  // TODO: add support for TF32
}

// Corresponding to cudaDataType_t defined in CUDA library_types.h.
static int32_t getCuSparseDataTypeFrom(Type type) {
  if (llvm::isa<ComplexType>(type)) {
    // Get the element type.
    auto elementType = cast<ComplexType>(type).getElementType();
    if (elementType.isBF16())
      return 15; // CUDA_C_16BF
    if (elementType.isF16())
      return 6; // CUDA_C_16F
    if (elementType.isF32())
      return 4; // CUDA_C_32F
    if (elementType.isF64())
      return 5; // CUDA_C_64F
    if (elementType.isInteger(8))
      return 7; // CUDA_C_8I
    if (elementType.isInteger(16))
      return 21; // CUDA_C_16I
    if (elementType.isInteger(32))
      return 11; // CUDA_C_32I
  }
  if (type.isBF16())
    return 14; // CUDA_R_16BF
  if (type.isF16())
    return 2; // CUDA_R_16F
  if (type.isF32())
    return 0; // CUDA_R_32F
  if (type.isF64())
    return 1; // CUDA_R_64F
  if (type.isInteger(8))
    return 3; // CUDA_R_8I
  if (type.isInteger(16))
    return 20; // CUDA_R_16I
  if (type.isInteger(32))
    return 10; // CUDA_R_32I

  llvm_unreachable("unsupported element type");
}

static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
}

// TODO: We may want a disablement/warning at MLIR-compiler run time:
// cusparseLt currently does not work for CUDA architectures below 8.0 and
// triggers an error at CUDA-program run time. It would be good to at least
// emit a warning when the target architecture is below 8.0 and the user still
// wants to use cusparseLt, or to make sure that, when lowering the GPU sparse
// dialect to LLVM calls, the cusparseLt calls are disabled for CUDA
// architectures below 8.0.
static bool is2To4Sparsity(Value spMat) {
  if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
    return true;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
    return false;
  // Print the spMat defining op.
  spMat.getDefiningOp()->print(llvm::errs());
  llvm_unreachable("cannot find spmat def");
}

static bool isSpMMCusparseLtOp(Value op) {
  for (Operation *user : op.getUsers()) {
    auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
    // If the other operand has 2:4 (50%) sparsity then we should use
    // cusparseLt.
    if (!spmmOp)
      continue;
    if (is2To4Sparsity(spmmOp.getSpmatA()))
      return true;
  }
  return false;
}

// Returns whether all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Operation *op = hostUnregisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostUnregisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostUnregisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType))
    return failure();

  auto loc = allocOp.getLoc();

  bool isShared = allocOp.getHostShared();

  if (isShared && allocOp.getAsyncToken())
    return rewriter.notifyMatchFailure(
        allocOp, "Host Shared allocation cannot be done async");
  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
  Value stream = adaptor.getAsyncDependencies().empty()
                     ? nullPtr
                     : adaptor.getAsyncDependencies().front();

  auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));

  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
          .getResult();

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  if (allocOp.getAsyncToken()) {
    // Async alloc: make dependent ops use the same stream.
    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
  } else {
    rewriter.replaceOp(allocOp, {memRefDescriptor});
  }

  return success();
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Value stream = adaptor.getAsyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {pointer, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return isa<gpu::AsyncTokenType>(value.getType());
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token values are lowered to streams within the async.execute
// region, but are passed as events between regions. For each !gpu.async.token
// operand, we create an event and record it on the stream.
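// For illustration (a sketch, assuming the mgpu* wrappers declared above), a
// yielded stream %s roughly becomes:
//   %ev = llvm.call @mgpuEventCreate() : () -> !llvm.ptr
//   llvm.call @mgpuEventRecord(%ev, %s) : (!llvm.ptr, !llvm.ptr) -> ()
//   llvm.call @mgpuStreamDestroy(%s) : (!llvm.ptr) -> ()
// and %ev replaces the corresponding operand of the yield.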
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(isa<LLVM::LLVMPointerType>(value.getType()));
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return *defOp.getCallee() == functionName;
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands. The operands are destroyed, i.e., the
// lowering assumes they are not used afterwards or elsewhere; otherwise we
// will get a runtime error. Eventually, we should guarantee this property.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands are
// destroyed, i.e., the lowering assumes they are not used afterwards or
// elsewhere; otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
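// For illustration (a sketch), `%t1 = gpu.wait async [%t0]` where %t0 has been
// lowered to an event %ev roughly becomes:
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
//   llvm.call @mgpuStreamWaitEvent(%stream, %ev) : (!llvm.ptr, !llvm.ptr) -> ()
//   llvm.call @mgpuEventDestroy(%ev) : (!llvm.ptr) -> ()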
908 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite( 909 gpu::WaitOp waitOp, OpAdaptor adaptor, 910 ConversionPatternRewriter &rewriter) const { 911 if (!waitOp.getAsyncToken()) 912 return rewriter.notifyMatchFailure(waitOp, "Can only convert async op."); 913 914 Location loc = waitOp.getLoc(); 915 916 auto insertionPoint = rewriter.saveInsertionPoint(); 917 SmallVector<Value, 1> events; 918 for (auto pair : 919 llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) { 920 auto operand = std::get<1>(pair); 921 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { 922 // The converted operand's definition created a stream. Insert an event 923 // into the stream just after the last use of the original token operand. 924 auto *defOp = std::get<0>(pair).getDefiningOp(); 925 rewriter.setInsertionPointAfter(defOp); 926 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(); 927 eventRecordCallBuilder.create(loc, rewriter, {event, operand}); 928 events.push_back(event); 929 } else { 930 // Otherwise the converted operand is an event. This assumes that we use 931 // events in control flow code as well. 932 events.push_back(operand); 933 } 934 } 935 rewriter.restoreInsertionPoint(insertionPoint); 936 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(); 937 for (auto event : events) 938 streamWaitEventCallBuilder.create(loc, rewriter, {stream, event}); 939 for (auto event : events) 940 eventDestroyCallBuilder.create(loc, rewriter, {event}); 941 rewriter.replaceOp(waitOp, {stream}); 942 943 return success(); 944 } 945 946 // Legalize the op's operands. 947 LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( 948 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, 949 ConversionPatternRewriter &rewriter) const { 950 if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter))) 951 return failure(); 952 953 if (launchOp.getAsyncDependencies().size() > 1) 954 return rewriter.notifyMatchFailure( 955 launchOp, "Cannot convert with more than one async dependency."); 956 957 // Fail when the synchronous version of the op has async dependencies. The 958 // lowering destroys the stream, and we do not want to check that there is no 959 // use of the stream after this op. 960 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty()) 961 return rewriter.notifyMatchFailure( 962 launchOp, "Cannot convert non-async op with async dependencies."); 963 964 Location loc = launchOp.getLoc(); 965 966 Value stream = Value(); 967 if (!adaptor.getAsyncDependencies().empty()) 968 stream = adaptor.getAsyncDependencies().front(); 969 // If the async keyword is present and there are no dependencies, then a 970 // stream must be created to pass to subsequent operations. 971 else if (launchOp.getAsyncToken()) 972 stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(); 973 974 // Lower the kernel operands to match kernel parameters. 975 // Note: If `useBarePtrCallConv` is set in the type converter's options, 976 // the value of `kernelBarePtrCallConv` will be ignored. 977 OperandRange origArguments = launchOp.getKernelOperands(); 978 SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands( 979 loc, origArguments, adaptor.getKernelOperands(), rewriter, 980 /*useBarePtrCallConv=*/kernelBarePtrCallConv); 981 SmallVector<Value, 8> llvmArgumentsWithSizes; 982 983 // Intersperse size information if requested. 
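  // (Illustrative sketch: with this option enabled, a kernel taking two
  // statically-shaped memref operands is called with interleaved arguments,
  // roughly (%ptr0, %size0_bytes, %ptr1, %size1_bytes), where each size is the
  // constant byte size computed below.)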
984 if (kernelIntersperseSizeCallConv) { 985 if (origArguments.size() != llvmArguments.size()) { 986 // This shouldn't happen if the bare-pointer calling convention is used. 987 return rewriter.notifyMatchFailure( 988 launchOp, 989 "Cannot add sizes to arguments with one-to-many LLVM IR expansion."); 990 } 991 992 llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2); 993 for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) { 994 auto memrefTy = dyn_cast<MemRefType>(origArg.getType()); 995 if (!memrefTy) { 996 return rewriter.notifyMatchFailure( 997 launchOp, "Operand to launch op is not a memref."); 998 } 999 1000 if (!memrefTy.hasStaticShape() || 1001 !memrefTy.getElementType().isIntOrFloat()) { 1002 return rewriter.notifyMatchFailure( 1003 launchOp, "Operand to launch op is not a memref with a static " 1004 "shape and an integer or float element type."); 1005 } 1006 1007 unsigned bitwidth = memrefTy.getElementTypeBitWidth(); 1008 if (bitwidth % 8 != 0) { 1009 return rewriter.notifyMatchFailure( 1010 launchOp, "Operand to launch op is not a memref with a " 1011 "byte-aligned element type."); 1012 } 1013 1014 uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) * 1015 static_cast<uint64_t>(memrefTy.getNumElements()); 1016 1017 Value sizeArg = rewriter.create<LLVM::ConstantOp>( 1018 loc, getIndexType(), rewriter.getIndexAttr(staticSize)); 1019 llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer. 1020 llvmArgumentsWithSizes.push_back(sizeArg); 1021 } 1022 } 1023 1024 std::optional<gpu::KernelDim3> clusterSize = std::nullopt; 1025 if (launchOp.hasClusterSize()) { 1026 clusterSize = 1027 gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(), 1028 adaptor.getClusterSizeZ()}; 1029 } 1030 rewriter.create<gpu::LaunchFuncOp>( 1031 launchOp.getLoc(), launchOp.getKernelAttr(), 1032 gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(), 1033 adaptor.getGridSizeZ()}, 1034 gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), 1035 adaptor.getBlockSizeZ()}, 1036 adaptor.getDynamicSharedMemorySize(), 1037 llvmArgumentsWithSizes.empty() ? 
llvmArguments : llvmArgumentsWithSizes, 1038 stream, clusterSize); 1039 if (launchOp.getAsyncToken()) 1040 rewriter.replaceOp(launchOp, {stream}); 1041 else 1042 rewriter.eraseOp(launchOp); 1043 return success(); 1044 } 1045 1046 static Value bitAndAddrspaceCast(Location loc, 1047 ConversionPatternRewriter &rewriter, 1048 LLVM::LLVMPointerType destinationType, 1049 Value sourcePtr, 1050 const LLVMTypeConverter &typeConverter) { 1051 auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType()); 1052 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) 1053 sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>( 1054 loc, 1055 LLVM::LLVMPointerType::get(rewriter.getContext(), 1056 destinationType.getAddressSpace()), 1057 sourcePtr); 1058 return sourcePtr; 1059 } 1060 1061 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( 1062 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, 1063 ConversionPatternRewriter &rewriter) const { 1064 auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType()); 1065 1066 if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) || 1067 !isConvertibleAndHasIdentityMaps(memRefType) || 1068 failed(isAsyncWithOneDependency(rewriter, memcpyOp))) 1069 return failure(); 1070 1071 auto loc = memcpyOp.getLoc(); 1072 1073 MemRefDescriptor srcDesc(adaptor.getSrc()); 1074 Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc); 1075 1076 Type elementPtrType = getElementPtrType(memRefType); 1077 Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType); 1078 Value gepPtr = rewriter.create<LLVM::GEPOp>( 1079 loc, elementPtrType, 1080 typeConverter->convertType(memRefType.getElementType()), nullPtr, 1081 numElements); 1082 auto sizeBytes = 1083 rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); 1084 1085 auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, 1086 srcDesc.alignedPtr(rewriter, loc), 1087 *getTypeConverter()); 1088 auto dst = bitAndAddrspaceCast( 1089 loc, rewriter, llvmPointerType, 1090 MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc), 1091 *getTypeConverter()); 1092 1093 auto stream = adaptor.getAsyncDependencies().front(); 1094 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream}); 1095 1096 rewriter.replaceOp(memcpyOp, {stream}); 1097 1098 return success(); 1099 } 1100 1101 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( 1102 gpu::MemsetOp memsetOp, OpAdaptor adaptor, 1103 ConversionPatternRewriter &rewriter) const { 1104 auto memRefType = cast<MemRefType>(memsetOp.getDst().getType()); 1105 1106 if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) || 1107 !isConvertibleAndHasIdentityMaps(memRefType) || 1108 failed(isAsyncWithOneDependency(rewriter, memsetOp))) 1109 return failure(); 1110 1111 auto loc = memsetOp.getLoc(); 1112 1113 Type valueType = adaptor.getValue().getType(); 1114 unsigned bitWidth = valueType.getIntOrFloatBitWidth(); 1115 // Ints and floats of 16 or 32 bit width are allowed. 1116 if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) { 1117 return rewriter.notifyMatchFailure( 1118 memsetOp, "value must be a 16 or 32 bit int or float"); 1119 } 1120 1121 unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth(); 1122 Type bitCastType = valueTypeWidth == 32 ? 
      llvmInt32Type : llvmInt16Type;

  MemRefDescriptor dstDesc(adaptor.getDst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 dstDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  FunctionCallBuilder builder =
      valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
  builder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}

LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
                                                 {adaptor.getDevIndex()});
  rewriter.replaceOp(op, call);
  return success();
}

template <typename T>
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmInt32Type = builder.getIntegerType(32);
  return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                          static_cast<int32_t>(tValue));
}

template <typename T>
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmFloat32Type = builder.getF32Type();
  return builder.create<LLVM::ConstantOp>(
      loc, llvmFloat32Type,
      builder.getF32FloatAttr(static_cast<float>(tValue)));
}

LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pTensor =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType = op.getMemref().getType().getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));

  SmallVector<Value, 4> dims;
  for (Value dim : adaptor.getDims()) {
    dims.push_back(dim);
  }

  Value handle;
  // TODO: For now, we track the use of the handle and lower it to cusparse /
  // cusparseLt accordingly. If both cusparse and cusparseLt are used within
  // the same block, two separate creation ops are required for the logic to
  // be correct. In the future, we may add support for using one handle with
  // both cusparse and cusparseLt in the sparse tensor / GPU dialects.
  // Use the cusparseLt create call if the dnmat is used with a spmat with 2:4
  // sparsity.
  if (dims.size() == 2) {
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
      auto handleSz = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(11032));
      handle = rewriter.create<LLVM::AllocaOp>(
          loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
      handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);

      createLtDnMatCallBuilder
          .create(loc, rewriter,
                  {handle, dims[0], dims[1], pTensor, dtp, stream})
          .getResult();
    } else {
      handle =
          createDnMatCallBuilder
              .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
              .getResult();
    }
  } else {
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
    handle = createDnVecCallBuilder
                 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
                 .getResult();
  }
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
  SmallVector<Value, 4> dims;
  for (Value dim : definingOp.getDims()) {
    dims.push_back(dim);
  }
  if (dims.size() == 2) {
    // Use the cusparseLt destroy call if the dnmat is used with a spmat with
    // 2:4 sparsity.
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
      destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
                                           {adaptor.getDnTensor(), stream});
    } else {
      destroyDnMatCallBuilder.create(loc, rewriter,
                                     {adaptor.getDnTensor(), stream});
    }
  } else {
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
    destroyDnVecCallBuilder.create(loc, rewriter,
                                   {adaptor.getDnTensor(), stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}

LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCooOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type iType =
      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCooCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
.getResult(); 1278 rewriter.replaceOp(op, {handle, stream}); 1279 return success(); 1280 } 1281 1282 LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite( 1283 gpu::CreateCooAoSOp op, OpAdaptor adaptor, 1284 ConversionPatternRewriter &rewriter) const { 1285 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1286 failed(isAsyncWithOneDependency(rewriter, op))) 1287 return failure(); 1288 Location loc = op.getLoc(); 1289 auto stream = adaptor.getAsyncDependencies().front(); 1290 Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc); 1291 Value pValues = 1292 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc); 1293 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType(); 1294 Type dType = 1295 llvm::cast<MemRefType>(op.getValues().getType()).getElementType(); 1296 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType)); 1297 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); 1298 auto handle = 1299 createCooAoSCallBuilder 1300 .create(loc, rewriter, 1301 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), 1302 pIdxs, pValues, itp, dtp, stream}) 1303 .getResult(); 1304 rewriter.replaceOp(op, {handle, stream}); 1305 return success(); 1306 } 1307 1308 LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite( 1309 gpu::CreateCsrOp op, OpAdaptor adaptor, 1310 ConversionPatternRewriter &rewriter) const { 1311 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1312 failed(isAsyncWithOneDependency(rewriter, op))) 1313 return failure(); 1314 Location loc = op.getLoc(); 1315 auto stream = adaptor.getAsyncDependencies().front(); 1316 Value pRowPos = 1317 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc); 1318 Value pColIdxs = 1319 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc); 1320 Value pValues = 1321 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc); 1322 Type pType = 1323 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType(); 1324 Type iType = 1325 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType(); 1326 Type dType = 1327 llvm::cast<MemRefType>(op.getValues().getType()).getElementType(); 1328 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType)); 1329 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType)); 1330 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); 1331 auto handle = 1332 createCsrCallBuilder 1333 .create(loc, rewriter, 1334 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), 1335 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream}) 1336 .getResult(); 1337 rewriter.replaceOp(op, {handle, stream}); 1338 return success(); 1339 } 1340 1341 LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite( 1342 gpu::Create2To4SpMatOp op, OpAdaptor adaptor, 1343 ConversionPatternRewriter &rewriter) const { 1344 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1345 failed(isAsyncWithOneDependency(rewriter, op))) 1346 return failure(); 1347 Location loc = op.getLoc(); 1348 auto stream = adaptor.getAsyncDependencies().front(); 1349 Value pMat = 1350 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); 1351 Type dType = 1352 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType(); 1353 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); 1354 1355 // CUDA runner asserts the size is 
44104 bytes. 1356 auto handleSz = rewriter.create<LLVM::ConstantOp>( 1357 loc, getIndexType(), rewriter.getIndexAttr(44104)); 1358 Value handle = rewriter.create<LLVM::AllocaOp>( 1359 loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); 1360 handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle); 1361 1362 create2To4SpMatCallBuilder 1363 .create(loc, rewriter, 1364 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream}) 1365 .getResult(); 1366 rewriter.replaceOp(op, {handle, stream}); 1367 return success(); 1368 } 1369 1370 LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite( 1371 gpu::DestroySpMatOp op, OpAdaptor adaptor, 1372 ConversionPatternRewriter &rewriter) const { 1373 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1374 failed(isAsyncWithOneDependency(rewriter, op))) 1375 return failure(); 1376 Location loc = op.getLoc(); 1377 auto stream = adaptor.getAsyncDependencies().front(); 1378 // Use the cusparseLt destroy call if the spmat is 2:4 sparsity 1379 if (is2To4Sparsity(op.getSpmat())) { 1380 destroyCuSparseLtSpMatBuilder.create(loc, rewriter, 1381 {adaptor.getSpmat(), stream}); 1382 1383 } else { 1384 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream}); 1385 } 1386 rewriter.replaceOp(op, {stream}); 1387 return success(); 1388 } 1389 1390 LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( 1391 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor, 1392 ConversionPatternRewriter &rewriter) const { 1393 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1394 failed(isAsyncWithOneDependency(rewriter, op))) 1395 return failure(); 1396 Location loc = op.getLoc(); 1397 auto modeA = genConstInt32From(rewriter, loc, op.getModeA()); 1398 auto computeType = genConstInt32From( 1399 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); 1400 auto stream = adaptor.getAsyncDependencies().front(); 1401 auto bufferSize = spMVBufferSizeCallBuilder 1402 .create(loc, rewriter, 1403 {modeA, adaptor.getSpmatA(), adaptor.getDnX(), 1404 adaptor.getDnY(), computeType, stream}) 1405 .getResult(); 1406 rewriter.replaceOp(op, {bufferSize, stream}); 1407 return success(); 1408 } 1409 1410 LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite( 1411 gpu::SpMVOp op, OpAdaptor adaptor, 1412 ConversionPatternRewriter &rewriter) const { 1413 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1414 failed(isAsyncWithOneDependency(rewriter, op))) 1415 return failure(); 1416 Location loc = op.getLoc(); 1417 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); 1418 auto computeType = genConstInt32From( 1419 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); 1420 auto stream = adaptor.getAsyncDependencies().front(); 1421 Value pBuf = 1422 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); 1423 spMVCallBuilder.create(loc, rewriter, 1424 {modeA, adaptor.getSpmatA(), adaptor.getDnX(), 1425 adaptor.getDnY(), computeType, pBuf, stream}); 1426 rewriter.replaceOp(op, {stream}); 1427 return success(); 1428 } 1429 1430 LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( 1431 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor, 1432 ConversionPatternRewriter &rewriter) const { 1433 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1434 failed(isAsyncWithOneDependency(rewriter, op))) 1435 return failure(); 1436 Location loc = op.getLoc(); 1437 auto modeA = 
LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  Value bufferSize;
  if (is2To4Sparsity(op.getSpmatA())) {
    auto pruneFlag =
        genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
    auto computeType = genConstInt32From(
        rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
    auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(3));
    auto bufferSize = rewriter.create<LLVM::AllocaOp>(
        loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
    createCuSparseLtSpMMBufferSizeBuilder
        .create(loc, rewriter,
                {bufferSize, modeA, modeB, adaptor.getSpmatA(),
                 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
                 pruneFlag, stream})
        .getResult();

    auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
        loc, llvmPointerType, llvmPointerType, bufferSize,
        ValueRange{rewriter.create<LLVM::ConstantOp>(
            loc, getIndexType(), rewriter.getIndexAttr(1))});
    auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
        loc, llvmPointerType, llvmPointerType, bufferSize,
        ValueRange{rewriter.create<LLVM::ConstantOp>(
            loc, getIndexType(), rewriter.getIndexAttr(2))});
    auto bufferSize0 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
    auto bufferSize1 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
    auto bufferSize2 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);

    rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
  } else {
    auto computeType = genConstInt32From(
        rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
    bufferSize =
        createSpMMBufferSizeCallBuilder
            .create(loc, rewriter,
                    {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
                     adaptor.getDnmatC(), computeType, stream})
            .getResult();
    rewriter.replaceOp(op, {bufferSize, stream});
  }
  return success();
}

LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  auto bufferSize =
      createSDDMMBufferSizeCallBuilder
          .create(loc, rewriter,
                  {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
                   adaptor.getSpmatC(), computeType, stream})
          .getResult();
  rewriter.replaceOp(op, {bufferSize, stream});
  return success();
}
LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));

  auto stream = adaptor.getAsyncDependencies().front();

  // Lower to cusparseLt if applicable.
  if (is2To4Sparsity(op.getSpmatA())) {
    SmallVector<Value> pBufs;
    for (Value buffer : adaptor.getBuffers()) {
      Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
      pBufs.push_back(pBuf);
    }
    createCuSparseLtSpMMBuilder.create(
        loc, rewriter,
        {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
         pBufs[0], pBufs[1], pBufs[2], stream});
  } else {
    Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
                     .allocatedPtr(rewriter, loc);
    createSpMMCallBuilder.create(loc, rewriter,
                                 {modeA, modeB, adaptor.getSpmatA(),
                                  adaptor.getDnmatB(), adaptor.getDnmatC(),
                                  computeType, pBuf, stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}

template <typename T>
static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
  converter.addConversion([&converter](T) -> Type {
    return LLVM::LLVMPointerType::get(&converter.getContext());
  });
}

LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SDDMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  createSDDMMCallBuilder.create(loc, rewriter,
                                {modeA, modeB, adaptor.getDnmatA(),
                                 adaptor.getDnmatB(), adaptor.getSpmatC(),
                                 computeType, pBuf, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

LogicalResult
ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
                    .getResult();
  rewriter.replaceOp(op, {descr, stream});
  return success();
}

LogicalResult
ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
                                         {adaptor.getDesc(), stream});
  rewriter.replaceOp(op, {stream});
  return success();
}
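// SpGEMM work estimation and compute are modeled as a single op distinguished
// by its `kind` attribute; both phases forward the current buffer size and
// scratch buffer to the runtime and yield the buffer size required by the
// next phase.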
LogicalResult
ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();

  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  Value bufferSizeNew;

  if (adaptor.getKind() ==
      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
    bufferSizeNew =
        createSpGEMMWorkEstimationBuilder
            .create(loc, rewriter,
                    {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                     adaptor.getBufferSz(), pBuf, stream})
            .getResult();
  } else {
    bufferSizeNew =
        createSpGEMMComputeBuilder
            .create(loc, rewriter,
                    {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                     adaptor.getBufferSz(), pBuf, stream})
            .getResult();
  }
  rewriter.replaceOp(op, {bufferSizeNew, stream});
  return success();
}

LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  createSpGEMMCopyBuilder.create(loc, rewriter,
                                 {adaptor.getDesc(), modeA, modeB,
                                  adaptor.getSpmatA(), adaptor.getSpmatB(),
                                  adaptor.getSpmatC(), computeType, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}
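// The size query below passes three i64 stack slots to the runtime wrapper
// and loads rows/cols/nnz back out of them to materialize the op results.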
LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();

  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(3));
  auto buffer = rewriter.create<LLVM::AllocaOp>(
      loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);

  auto rowsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(0))});
  auto colsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(1))});
  auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(2))});
  createSpMatGetSizeBuilder.create(
      loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
  auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
  auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
  auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);

  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
  return success();
}

LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetCsrPointersOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pPos =
      MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
  Value pCrd =
      MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
  Value pVal =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  createSetCsrPointersBuilder.create(
      loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCscOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pColPos =
      MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCscCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}
LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateBsrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowPos =
      MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createBsrCallBuilder
          .create(loc, rewriter,
                  {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
                   adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
                   pColIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

void mlir::populateGpuToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
  addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);

  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertHostUnregisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern,
               ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
               ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
               ConvertCreateCooOpToGpuRuntimeCallPattern,
               ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
               ConvertCreateCsrOpToGpuRuntimeCallPattern,
               ConvertCreateCscOpToGpuRuntimeCallPattern,
               ConvertCreateBsrOpToGpuRuntimeCallPattern,
               ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
               ConvertDestroySpMatOpToGpuRuntimeCallPattern,
               ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSpMVOpToGpuRuntimeCallPattern,
               ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSpMMOpToGpuRuntimeCallPattern,
               ConvertSDDMMOpToGpuRuntimeCallPattern,
               ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
               ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
               ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
               ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
               ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
               ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
                                            kernelIntersperseSizeCallConv);
}

//===----------------------------------------------------------------------===//
// GPUModuleOp convert to LLVM op interface
//===----------------------------------------------------------------------===//
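// Attaching this external model lets the generic convert-to-llvm driver pull
// conversion attributes from a gpu.module, provided the module declares
// exactly one target attribute.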
namespace {
struct GPUModuleOpConvertToLLVMInterface
    : public ConvertToLLVMOpInterface::ExternalModel<
          GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
  /// Get the conversion patterns from the target attribute.
  void getConvertToLLVMConversionAttrs(
      Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
};
} // namespace

void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
    Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
  auto module = cast<gpu::GPUModuleOp>(op);
  ArrayAttr targetsAttr = module.getTargetsAttr();
  // Fail if there are no target attributes or there is more than one target.
  if (!targetsAttr || targetsAttr.size() != 1)
    return;
  if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
    attrs.push_back(patternAttr);
}

void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
    gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
  });
}