1a5f9cda1SChristian Sigg //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===// 2a5f9cda1SChristian Sigg // 3a5f9cda1SChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a5f9cda1SChristian Sigg // See https://llvm.org/LICENSE.txt for license information. 5a5f9cda1SChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a5f9cda1SChristian Sigg // 7a5f9cda1SChristian Sigg //===----------------------------------------------------------------------===// 8a5f9cda1SChristian Sigg // 9a5f9cda1SChristian Sigg // This file implements a pass to convert gpu.launch_func op into a sequence of 10a5f9cda1SChristian Sigg // GPU runtime calls. As most of GPU runtimes does not have a stable published 11a5f9cda1SChristian Sigg // ABI, this pass uses a slim runtime layer that builds on top of the public 12a5f9cda1SChristian Sigg // API from GPU runtime headers. 13a5f9cda1SChristian Sigg // 14a5f9cda1SChristian Sigg //===----------------------------------------------------------------------===// 15a5f9cda1SChristian Sigg 16a5f9cda1SChristian Sigg #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" 17a5f9cda1SChristian Sigg 18abc362a1SJakub Kuderski #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" 19a5f9cda1SChristian Sigg #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 20330838ebSRiver Riddle #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 219e7b6f46SMehdi Amini #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 229e7b6f46SMehdi Amini #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" 235a7b9194SRiver Riddle #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 245a7b9194SRiver Riddle #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" 257498eaa9SFabian Mora #include "mlir/Conversion/GPUCommon/GPUToLLVM.h" 2675e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 27684dfe8aSAlex Zinenko #include 
"mlir/Conversion/LLVMCommon/Pattern.h" 2875e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 29a5f9cda1SChristian Sigg #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 30a5f9cda1SChristian Sigg #include "mlir/Dialect/Async/IR/Async.h" 31d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h" 32d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/Passes.h" 33a5f9cda1SChristian Sigg #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 3416f8d17fSXiang Li #include "mlir/Dialect/MemRef/IR/MemRef.h" 350693b9e9SMatthias Springer #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 36a5f9cda1SChristian Sigg #include "mlir/IR/Attributes.h" 37a5f9cda1SChristian Sigg #include "mlir/IR/Builders.h" 38a5f9cda1SChristian Sigg #include "mlir/IR/BuiltinOps.h" 39a5f9cda1SChristian Sigg #include "mlir/IR/BuiltinTypes.h" 400693b9e9SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 41a5f9cda1SChristian Sigg 42a5f9cda1SChristian Sigg #include "llvm/ADT/STLExtras.h" 43a5f9cda1SChristian Sigg #include "llvm/Support/Error.h" 44a5f9cda1SChristian Sigg #include "llvm/Support/FormatVariadic.h" 45a5f9cda1SChristian Sigg 469e7b6f46SMehdi Amini #define DEBUG_TYPE "gpu-to-llvm" 479e7b6f46SMehdi Amini 4867d0d7acSMichele Scuttari namespace mlir { 4967d0d7acSMichele Scuttari #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS 5067d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 5167d0d7acSMichele Scuttari } // namespace mlir 5267d0d7acSMichele Scuttari 53a5f9cda1SChristian Sigg using namespace mlir; 54a5f9cda1SChristian Sigg 55a5f9cda1SChristian Sigg namespace { 56039b969bSMichele Scuttari class GpuToLLVMConversionPass 5767d0d7acSMichele Scuttari : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> { 58a5f9cda1SChristian Sigg public: 59cd4ca2d7SMarkus Böck using Base::Base; 609e7b6f46SMehdi Amini void getDependentDialects(DialectRegistry ®istry) const final { 619e7b6f46SMehdi Amini 
Base::getDependentDialects(registry); 629e7b6f46SMehdi Amini registerConvertToLLVMDependentDialectLoading(registry); 639e7b6f46SMehdi Amini } 64a5f9cda1SChristian Sigg // Run the dialect converter on the module. 65a5f9cda1SChristian Sigg void runOnOperation() override; 66a5f9cda1SChristian Sigg }; 67a5f9cda1SChristian Sigg 68a5f9cda1SChristian Sigg template <typename OpTy> 69a5f9cda1SChristian Sigg class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> { 70a5f9cda1SChristian Sigg public: 71ce254598SMatthias Springer explicit ConvertOpToGpuRuntimeCallPattern( 72ce254598SMatthias Springer const LLVMTypeConverter &typeConverter) 73a5f9cda1SChristian Sigg : ConvertOpToLLVMPattern<OpTy>(typeConverter) {} 74a5f9cda1SChristian Sigg 75a5f9cda1SChristian Sigg protected: 76361458b1SLoren Maggiore Value getNumElements(ConversionPatternRewriter &rewriter, Location loc, 77361458b1SLoren Maggiore MemRefType type, MemRefDescriptor desc) const { 78e98e5995SAlex Zinenko Type indexType = ConvertToLLVMPattern::getIndexType(); 79361458b1SLoren Maggiore return type.hasStaticShape() 80620e2bb2SNicolas Vasilache ? ConvertToLLVMPattern::createIndexAttrConstant( 81620e2bb2SNicolas Vasilache rewriter, loc, indexType, type.getNumElements()) 82361458b1SLoren Maggiore // For identity maps (verified by caller), the number of 83361458b1SLoren Maggiore // elements is stride[0] * size[0]. 
84361458b1SLoren Maggiore : rewriter.create<LLVM::MulOp>(loc, 85361458b1SLoren Maggiore desc.stride(rewriter, loc, 0), 86361458b1SLoren Maggiore desc.size(rewriter, loc, 0)); 87361458b1SLoren Maggiore } 88361458b1SLoren Maggiore 89a5f9cda1SChristian Sigg MLIRContext *context = &this->getTypeConverter()->getContext(); 90a5f9cda1SChristian Sigg 91a5f9cda1SChristian Sigg Type llvmVoidType = LLVM::LLVMVoidType::get(context); 92dbd4a0ddSChristian Ulmann LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context); 93a5f9cda1SChristian Sigg Type llvmInt8Type = IntegerType::get(context, 8); 9418cc07aaSNavdeep Katel Type llvmInt16Type = IntegerType::get(context, 16); 95a5f9cda1SChristian Sigg Type llvmInt32Type = IntegerType::get(context, 32); 96a5f9cda1SChristian Sigg Type llvmInt64Type = IntegerType::get(context, 64); 97dfe29429SKun Wu Type llvmFloat32Type = Float32Type::get(context); 98a5f9cda1SChristian Sigg Type llvmIntPtrType = IntegerType::get( 99a5f9cda1SChristian Sigg context, this->getTypeConverter()->getPointerBitwidth(0)); 100a5f9cda1SChristian Sigg 101a5f9cda1SChristian Sigg FunctionCallBuilder streamCreateCallBuilder = { 102a5f9cda1SChristian Sigg "mgpuStreamCreate", llvmPointerType /* void *stream */, {}}; 103a5f9cda1SChristian Sigg FunctionCallBuilder streamDestroyCallBuilder = { 104a5f9cda1SChristian Sigg "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}}; 105a5f9cda1SChristian Sigg FunctionCallBuilder streamSynchronizeCallBuilder = { 106a5f9cda1SChristian Sigg "mgpuStreamSynchronize", 107a5f9cda1SChristian Sigg llvmVoidType, 108a5f9cda1SChristian Sigg {llvmPointerType /* void *stream */}}; 109a5f9cda1SChristian Sigg FunctionCallBuilder streamWaitEventCallBuilder = { 110a5f9cda1SChristian Sigg "mgpuStreamWaitEvent", 111a5f9cda1SChristian Sigg llvmVoidType, 112a5f9cda1SChristian Sigg {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}}; 113a5f9cda1SChristian Sigg FunctionCallBuilder 
eventCreateCallBuilder = { 114a5f9cda1SChristian Sigg "mgpuEventCreate", llvmPointerType /* void *event */, {}}; 115a5f9cda1SChristian Sigg FunctionCallBuilder eventDestroyCallBuilder = { 116a5f9cda1SChristian Sigg "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}}; 117a5f9cda1SChristian Sigg FunctionCallBuilder eventSynchronizeCallBuilder = { 118a5f9cda1SChristian Sigg "mgpuEventSynchronize", 119a5f9cda1SChristian Sigg llvmVoidType, 120a5f9cda1SChristian Sigg {llvmPointerType /* void *event */}}; 121a5f9cda1SChristian Sigg FunctionCallBuilder eventRecordCallBuilder = { 122a5f9cda1SChristian Sigg "mgpuEventRecord", 123a5f9cda1SChristian Sigg llvmVoidType, 124a5f9cda1SChristian Sigg {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}}; 125a5f9cda1SChristian Sigg FunctionCallBuilder hostRegisterCallBuilder = { 126a5f9cda1SChristian Sigg "mgpuMemHostRegisterMemRef", 127a5f9cda1SChristian Sigg llvmVoidType, 128a5f9cda1SChristian Sigg {llvmIntPtrType /* intptr_t rank */, 129a5f9cda1SChristian Sigg llvmPointerType /* void *memrefDesc */, 130a5f9cda1SChristian Sigg llvmIntPtrType /* intptr_t elementSizeBytes */}}; 1318f7c8a6eSmax FunctionCallBuilder hostUnregisterCallBuilder = { 1328f7c8a6eSmax "mgpuMemHostUnregisterMemRef", 1338f7c8a6eSmax llvmVoidType, 1348f7c8a6eSmax {llvmIntPtrType /* intptr_t rank */, 1358f7c8a6eSmax llvmPointerType /* void *memrefDesc */, 1368f7c8a6eSmax llvmIntPtrType /* intptr_t elementSizeBytes */}}; 137a5f9cda1SChristian Sigg FunctionCallBuilder allocCallBuilder = { 138a5f9cda1SChristian Sigg "mgpuMemAlloc", 139a5f9cda1SChristian Sigg llvmPointerType /* void * */, 140a5f9cda1SChristian Sigg {llvmIntPtrType /* intptr_t sizeBytes */, 1411002a1d0SNishant Patel llvmPointerType /* void *stream */, 1421002a1d0SNishant Patel llvmInt8Type /* bool isHostShared */}}; 143a5f9cda1SChristian Sigg FunctionCallBuilder deallocCallBuilder = { 144a5f9cda1SChristian Sigg "mgpuMemFree", 145a5f9cda1SChristian Sigg 
llvmVoidType, 146a5f9cda1SChristian Sigg {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}}; 147a5f9cda1SChristian Sigg FunctionCallBuilder memcpyCallBuilder = { 148a5f9cda1SChristian Sigg "mgpuMemcpy", 149a5f9cda1SChristian Sigg llvmVoidType, 150a5f9cda1SChristian Sigg {llvmPointerType /* void *dst */, llvmPointerType /* void *src */, 151a5f9cda1SChristian Sigg llvmIntPtrType /* intptr_t sizeBytes */, 152a5f9cda1SChristian Sigg llvmPointerType /* void *stream */}}; 15318cc07aaSNavdeep Katel FunctionCallBuilder memset16CallBuilder = { 15418cc07aaSNavdeep Katel "mgpuMemset16", 15518cc07aaSNavdeep Katel llvmVoidType, 15618cc07aaSNavdeep Katel {llvmPointerType /* void *dst */, 15718cc07aaSNavdeep Katel llvmInt16Type /* unsigned short value */, 15818cc07aaSNavdeep Katel llvmIntPtrType /* intptr_t sizeBytes */, 15918cc07aaSNavdeep Katel llvmPointerType /* void *stream */}}; 16018cc07aaSNavdeep Katel FunctionCallBuilder memset32CallBuilder = { 161361458b1SLoren Maggiore "mgpuMemset32", 162361458b1SLoren Maggiore llvmVoidType, 163361458b1SLoren Maggiore {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */, 164361458b1SLoren Maggiore llvmIntPtrType /* intptr_t sizeBytes */, 165361458b1SLoren Maggiore llvmPointerType /* void *stream */}}; 16684718d37SKrzysztof Drewniak FunctionCallBuilder setDefaultDeviceCallBuilder = { 16784718d37SKrzysztof Drewniak "mgpuSetDefaultDevice", 16884718d37SKrzysztof Drewniak llvmVoidType, 16984718d37SKrzysztof Drewniak {llvmInt32Type /* uint32_t devIndex */}}; 170b700a90cSAart Bik FunctionCallBuilder createDnVecCallBuilder = { 171b700a90cSAart Bik "mgpuCreateDnVec", 172b700a90cSAart Bik llvmPointerType, 173b700a90cSAart Bik {llvmIntPtrType, llvmPointerType, llvmInt32Type, 174b700a90cSAart Bik llvmPointerType /* void *stream */}}; 175b700a90cSAart Bik FunctionCallBuilder destroyDnVecCallBuilder = { 176b700a90cSAart Bik "mgpuDestroyDnVec", 177b700a90cSAart Bik llvmVoidType, 178b700a90cSAart Bik 
{llvmPointerType, llvmPointerType /* void *stream */}}; 179981cf167SAart Bik FunctionCallBuilder createDnMatCallBuilder = { 180981cf167SAart Bik "mgpuCreateDnMat", 181981cf167SAart Bik llvmPointerType, 182981cf167SAart Bik {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type, 183981cf167SAart Bik llvmPointerType /* void *stream */}}; 184981cf167SAart Bik FunctionCallBuilder destroyDnMatCallBuilder = { 185981cf167SAart Bik "mgpuDestroyDnMat", 186981cf167SAart Bik llvmVoidType, 187981cf167SAart Bik {llvmPointerType, llvmPointerType /* void *stream */}}; 188b700a90cSAart Bik FunctionCallBuilder createCooCallBuilder = { 189b700a90cSAart Bik "mgpuCreateCoo", 190b700a90cSAart Bik llvmPointerType, 191b700a90cSAart Bik {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, 192b700a90cSAart Bik llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, 193b700a90cSAart Bik llvmPointerType /* void *stream */}}; 1949fc02a7aSAart Bik FunctionCallBuilder createCooAoSCallBuilder = { 1959fc02a7aSAart Bik "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2 1969fc02a7aSAart Bik llvmPointerType, 1979fc02a7aSAart Bik {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, 1989fc02a7aSAart Bik llvmPointerType, llvmInt32Type, llvmInt32Type, 1999fc02a7aSAart Bik llvmPointerType /* void *stream */}}; 200b700a90cSAart Bik FunctionCallBuilder createCsrCallBuilder = { 201b700a90cSAart Bik "mgpuCreateCsr", 202b700a90cSAart Bik llvmPointerType, 203b700a90cSAart Bik {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, 204b700a90cSAart Bik llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, 205b700a90cSAart Bik llvmInt32Type, llvmPointerType /* void *stream */}}; 20639038177SAart Bik FunctionCallBuilder createCscCallBuilder = { 20739038177SAart Bik "mgpuCreateCsc", 20839038177SAart Bik llvmPointerType, 20939038177SAart Bik {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, 21039038177SAart Bik llvmPointerType, 
llvmPointerType, llvmInt32Type, llvmInt32Type, 21139038177SAart Bik llvmInt32Type, llvmPointerType /* void *stream */}}; 21239038177SAart Bik FunctionCallBuilder createBsrCallBuilder = { 21339038177SAart Bik "mgpuCreateBsr", 21439038177SAart Bik llvmPointerType, 21539038177SAart Bik {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, 21639038177SAart Bik llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType, 21739038177SAart Bik llvmInt32Type, llvmInt32Type, llvmInt32Type, 21839038177SAart Bik llvmPointerType /* void *stream */}}; 219b700a90cSAart Bik FunctionCallBuilder destroySpMatCallBuilder = { 220b700a90cSAart Bik "mgpuDestroySpMat", 221b700a90cSAart Bik llvmVoidType, 222b700a90cSAart Bik {llvmPointerType, llvmPointerType /* void *stream */}}; 223b700a90cSAart Bik FunctionCallBuilder spMVBufferSizeCallBuilder = { 224b700a90cSAart Bik "mgpuSpMVBufferSize", 225b700a90cSAart Bik llvmIntPtrType, 226be2dd22bSKun Wu {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType, 227be2dd22bSKun Wu llvmInt32Type, llvmPointerType /* void *stream */}}; 228b700a90cSAart Bik FunctionCallBuilder spMVCallBuilder = { 229b700a90cSAart Bik "mgpuSpMV", 230b700a90cSAart Bik llvmVoidType, 231be2dd22bSKun Wu {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType, 232be2dd22bSKun Wu llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}}; 233ac30f48eSKun Wu FunctionCallBuilder createSpMMBufferSizeCallBuilder = { 234981cf167SAart Bik "mgpuSpMMBufferSize", 235981cf167SAart Bik llvmIntPtrType, 236be2dd22bSKun Wu {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType, 237be2dd22bSKun Wu llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}}; 238ac30f48eSKun Wu FunctionCallBuilder createSpMMCallBuilder = { 239981cf167SAart Bik "mgpuSpMM", 240981cf167SAart Bik llvmVoidType, 241be2dd22bSKun Wu {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType, 242be2dd22bSKun Wu llvmPointerType, llvmInt32Type, 
llvmPointerType, 243235fbe79SKun Wu llvmPointerType /* void *stream */}}; 244ac30f48eSKun Wu FunctionCallBuilder createSDDMMBufferSizeCallBuilder = { 245cf44847bSKun Wu "mgpuSDDMMBufferSize", 246cf44847bSKun Wu llvmIntPtrType, 247be2dd22bSKun Wu {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType, 248be2dd22bSKun Wu llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}}; 249ac30f48eSKun Wu FunctionCallBuilder createSDDMMCallBuilder = { 250cf44847bSKun Wu "mgpuSDDMM", 251cf44847bSKun Wu llvmVoidType, 252be2dd22bSKun Wu {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType, 253be2dd22bSKun Wu llvmPointerType, llvmInt32Type, llvmPointerType, 254cf44847bSKun Wu llvmPointerType /* void *stream */}}; 2558ed59c53SKun Wu FunctionCallBuilder createLtDnMatCallBuilder = { 2568ed59c53SKun Wu "mgpuCreateCuSparseLtDnMat", 2578ed59c53SKun Wu llvmVoidType, 258be2dd22bSKun Wu {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, 259be2dd22bSKun Wu llvmInt32Type, llvmPointerType /* void *stream */}}; 2608ed59c53SKun Wu FunctionCallBuilder destroyCuSparseLtSpMatBuilder = { 2618ed59c53SKun Wu "mgpuDestroyCuSparseLtSpMat", 2628ed59c53SKun Wu llvmVoidType, 2638ed59c53SKun Wu {llvmPointerType, llvmPointerType /* void *stream */}}; 2648ed59c53SKun Wu FunctionCallBuilder destroyCuSparseLtDnMatBuilder = { 2658ed59c53SKun Wu "mgpuDestroyCuSparseLtDnMat", 2668ed59c53SKun Wu llvmVoidType, 2678ed59c53SKun Wu {llvmPointerType, llvmPointerType /* void *stream */}}; 2688ed59c53SKun Wu FunctionCallBuilder create2To4SpMatCallBuilder = { 2698ed59c53SKun Wu "mgpuCusparseLtCreate2To4SpMat", 2708ed59c53SKun Wu llvmVoidType, 271be2dd22bSKun Wu {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, 272be2dd22bSKun Wu llvmInt32Type, llvmPointerType /* void *stream */}}; 273ac30f48eSKun Wu FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = { 2748ed59c53SKun Wu "mgpuCuSparseLtSpMMBufferSize", 2758ed59c53SKun Wu llvmVoidType, 
276be2dd22bSKun Wu {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType, 2771e491c42SKun Wu llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, 2788ed59c53SKun Wu llvmPointerType /*void *stream*/}}; 279ac30f48eSKun Wu FunctionCallBuilder createCuSparseLtSpMMBuilder = { 2808ed59c53SKun Wu "mgpuCuSparseLtSpMM", 2818ed59c53SKun Wu llvmVoidType, 2828ed59c53SKun Wu {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, 283be2dd22bSKun Wu llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}}; 284289f7231SAart Bik FunctionCallBuilder createSpGEMMCreateDescrBuilder = { 285289f7231SAart Bik "mgpuSpGEMMCreateDescr", 286289f7231SAart Bik llvmPointerType, 287289f7231SAart Bik {llvmPointerType /*void *stream*/}}; 288289f7231SAart Bik FunctionCallBuilder createSpGEMMDestroyDescrBuilder = { 289289f7231SAart Bik "mgpuSpGEMMDestroyDescr", 290289f7231SAart Bik llvmVoidType, 291289f7231SAart Bik {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}}; 292dfe29429SKun Wu FunctionCallBuilder createSpGEMMWorkEstimationBuilder = { 293dfe29429SKun Wu "mgpuSpGEMMWorkEstimation", 294dfe29429SKun Wu llvmIntPtrType, 295dfe29429SKun Wu {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/, 296dfe29429SKun Wu llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/, 297e7e4ed0dSAart Bik llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/, 298e7e4ed0dSAart Bik llvmPointerType /*void *stream*/}}; 299dfe29429SKun Wu FunctionCallBuilder createSpGEMMComputeBuilder = { 300dfe29429SKun Wu "mgpuSpGEMMCompute", 301dfe29429SKun Wu llvmIntPtrType, 302dfe29429SKun Wu {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/, 303dfe29429SKun Wu llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/, 304e7e4ed0dSAart Bik llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/, 305e7e4ed0dSAart Bik llvmPointerType /*void *stream*/}}; 306dfe29429SKun Wu FunctionCallBuilder 
createSpGEMMCopyBuilder = { 307dfe29429SKun Wu "mgpuSpGEMMCopy", 308dfe29429SKun Wu llvmVoidType, 309dfe29429SKun Wu {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/, 310dfe29429SKun Wu llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/, 311e7e4ed0dSAart Bik llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}}; 312289f7231SAart Bik FunctionCallBuilder createSpMatGetSizeBuilder = { 313289f7231SAart Bik "mgpuSpMatGetSize", 314dfe29429SKun Wu llvmVoidType, 315dfe29429SKun Wu {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/, 316dfe29429SKun Wu llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}}; 31795a6c509SAart Bik FunctionCallBuilder createSetCsrPointersBuilder = { 31895a6c509SAart Bik "mgpuSetCsrPointers", 31995a6c509SAart Bik llvmVoidType, 32095a6c509SAart Bik {llvmPointerType /*spmat*/, llvmPointerType /*pos*/, 32195a6c509SAart Bik llvmPointerType /*crd*/, llvmPointerType /*val*/, 32295a6c509SAart Bik llvmPointerType /*void *stream*/}}; 323a5f9cda1SChristian Sigg }; 324a5f9cda1SChristian Sigg 325a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime 326a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP). 
327a5f9cda1SChristian Sigg class ConvertHostRegisterOpToGpuRuntimeCallPattern 328a5f9cda1SChristian Sigg : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> { 329a5f9cda1SChristian Sigg public: 330ce254598SMatthias Springer ConvertHostRegisterOpToGpuRuntimeCallPattern( 331ce254598SMatthias Springer const LLVMTypeConverter &typeConverter) 332a5f9cda1SChristian Sigg : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {} 333a5f9cda1SChristian Sigg 334a5f9cda1SChristian Sigg private: 335a5f9cda1SChristian Sigg LogicalResult 336ef976337SRiver Riddle matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor, 337a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const override; 338a5f9cda1SChristian Sigg }; 339a5f9cda1SChristian Sigg 3408f7c8a6eSmax class ConvertHostUnregisterOpToGpuRuntimeCallPattern 3418f7c8a6eSmax : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> { 3428f7c8a6eSmax public: 3438f7c8a6eSmax ConvertHostUnregisterOpToGpuRuntimeCallPattern( 344ce254598SMatthias Springer const LLVMTypeConverter &typeConverter) 3458f7c8a6eSmax : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) { 3468f7c8a6eSmax } 3478f7c8a6eSmax 3488f7c8a6eSmax private: 3498f7c8a6eSmax LogicalResult 3508f7c8a6eSmax matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor, 3518f7c8a6eSmax ConversionPatternRewriter &rewriter) const override; 3528f7c8a6eSmax }; 3538f7c8a6eSmax 354a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime 355a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP). 
356a5f9cda1SChristian Sigg class ConvertAllocOpToGpuRuntimeCallPattern 357a5f9cda1SChristian Sigg : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> { 358a5f9cda1SChristian Sigg public: 359ce254598SMatthias Springer ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) 360a5f9cda1SChristian Sigg : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {} 361a5f9cda1SChristian Sigg 362a5f9cda1SChristian Sigg private: 363a5f9cda1SChristian Sigg LogicalResult 364ef976337SRiver Riddle matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor, 365a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const override; 366a5f9cda1SChristian Sigg }; 367a5f9cda1SChristian Sigg 368a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime 369a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP). 370a5f9cda1SChristian Sigg class ConvertDeallocOpToGpuRuntimeCallPattern 371a5f9cda1SChristian Sigg : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> { 372a5f9cda1SChristian Sigg public: 373ce254598SMatthias Springer ConvertDeallocOpToGpuRuntimeCallPattern( 374ce254598SMatthias Springer const LLVMTypeConverter &typeConverter) 375a5f9cda1SChristian Sigg : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {} 376a5f9cda1SChristian Sigg 377a5f9cda1SChristian Sigg private: 378a5f9cda1SChristian Sigg LogicalResult 379ef976337SRiver Riddle matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor, 380a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const override; 381a5f9cda1SChristian Sigg }; 382a5f9cda1SChristian Sigg 383a5f9cda1SChristian Sigg class ConvertAsyncYieldToGpuRuntimeCallPattern 384a5f9cda1SChristian Sigg : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> { 385a5f9cda1SChristian Sigg public: 386ce254598SMatthias Springer ConvertAsyncYieldToGpuRuntimeCallPattern( 387ce254598SMatthias Springer const LLVMTypeConverter &typeConverter) 
388a5f9cda1SChristian Sigg : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {} 389a5f9cda1SChristian Sigg 390a5f9cda1SChristian Sigg private: 391a5f9cda1SChristian Sigg LogicalResult 392ef976337SRiver Riddle matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor, 393a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const override; 394a5f9cda1SChristian Sigg }; 395a5f9cda1SChristian Sigg 396a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.wait operations into a GPU runtime 397a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP). 398a5f9cda1SChristian Sigg class ConvertWaitOpToGpuRuntimeCallPattern 399a5f9cda1SChristian Sigg : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> { 400a5f9cda1SChristian Sigg public: 401ce254598SMatthias Springer ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) 402a5f9cda1SChristian Sigg : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {} 403a5f9cda1SChristian Sigg 404a5f9cda1SChristian Sigg private: 405a5f9cda1SChristian Sigg LogicalResult 406ef976337SRiver Riddle matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor, 407a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const override; 408a5f9cda1SChristian Sigg }; 409a5f9cda1SChristian Sigg 410a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.wait async operations into a GPU runtime 411a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP). 
412a5f9cda1SChristian Sigg class ConvertWaitAsyncOpToGpuRuntimeCallPattern 413a5f9cda1SChristian Sigg : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> { 414a5f9cda1SChristian Sigg public: 415ce254598SMatthias Springer ConvertWaitAsyncOpToGpuRuntimeCallPattern( 416ce254598SMatthias Springer const LLVMTypeConverter &typeConverter) 417a5f9cda1SChristian Sigg : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {} 418a5f9cda1SChristian Sigg 419a5f9cda1SChristian Sigg private: 420a5f9cda1SChristian Sigg LogicalResult 421ef976337SRiver Riddle matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor, 422a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const override; 423a5f9cda1SChristian Sigg }; 424a5f9cda1SChristian Sigg 4258e12f31bSFabian Mora /// A rewrite patter to legalize gpu.launch_func with LLVM types. 4268e12f31bSFabian Mora class LegalizeLaunchFuncOpPattern 427a5f9cda1SChristian Sigg : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> { 428a5f9cda1SChristian Sigg public: 4298e12f31bSFabian Mora LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter, 43099a562b3SAndrea Faulds bool kernelBarePtrCallConv, 431*733be4edSAndrea Faulds bool kernelIntersperseSizeCallConv) 432a5f9cda1SChristian Sigg : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter), 43399a562b3SAndrea Faulds kernelBarePtrCallConv(kernelBarePtrCallConv), 434*733be4edSAndrea Faulds kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {} 435a5f9cda1SChristian Sigg 436a5f9cda1SChristian Sigg private: 437a5f9cda1SChristian Sigg LogicalResult 438ef976337SRiver Riddle matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, 439a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const override; 440a5f9cda1SChristian Sigg 441c2fc8d9bSKrzysztof Drewniak bool kernelBarePtrCallConv; 442*733be4edSAndrea Faulds bool kernelIntersperseSizeCallConv; 443a5f9cda1SChristian Sigg }; 444a5f9cda1SChristian Sigg 445a5f9cda1SChristian Sigg 
/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime 446a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP). 447a5f9cda1SChristian Sigg class ConvertMemcpyOpToGpuRuntimeCallPattern 448a5f9cda1SChristian Sigg : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> { 449a5f9cda1SChristian Sigg public: 450ce254598SMatthias Springer ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) 451a5f9cda1SChristian Sigg : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {} 452a5f9cda1SChristian Sigg 453a5f9cda1SChristian Sigg private: 454a5f9cda1SChristian Sigg LogicalResult 455ef976337SRiver Riddle matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, 456a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const override; 457a5f9cda1SChristian Sigg }; 458361458b1SLoren Maggiore 459361458b1SLoren Maggiore /// A rewrite pattern to convert gpu.memset operations into a GPU runtime 460361458b1SLoren Maggiore /// call. Currently it supports CUDA and ROCm (HIP). 461361458b1SLoren Maggiore class ConvertMemsetOpToGpuRuntimeCallPattern 462361458b1SLoren Maggiore : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> { 463361458b1SLoren Maggiore public: 464ce254598SMatthias Springer ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) 465361458b1SLoren Maggiore : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {} 466361458b1SLoren Maggiore 467361458b1SLoren Maggiore private: 468361458b1SLoren Maggiore LogicalResult 469ef976337SRiver Riddle matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor, 470361458b1SLoren Maggiore ConversionPatternRewriter &rewriter) const override; 471361458b1SLoren Maggiore }; 47284718d37SKrzysztof Drewniak 47384718d37SKrzysztof Drewniak /// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call. 
/// Currently supports CUDA and ROCm (HIP)
class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
  /// Forwards `typeConverter` to the shared runtime-call pattern base.
  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
            typeConverter) {}

  // NOTE(review): unlike the sibling patterns in this file, this override is
  // declared in the public section (there is no `private:` label here).
  LogicalResult
  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Generic rewriting rule for operation on sparse matrices.
/// Currently supports CUDA (by means of cuSparse and cuSparseLt).
// Expands to the declaration of `Convert<op_name>ToGpuRuntimeCallPattern`, a
// class deriving from `ConvertOpToGpuRuntimeCallPattern<gpu::op_name>`. The
// constructor just forwards the type converter to the base class, and
// `matchAndRewrite` is only declared here, so each instantiation below needs
// a matching out-of-line definition.
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)                \
  class Convert##op_name##ToGpuRuntimeCallPattern                              \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {                \
  public:                                                                      \
    Convert##op_name##ToGpuRuntimeCallPattern(                                 \
        const LLVMTypeConverter &typeConverter)                                \
        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {}     \
                                                                               \
  private:                                                                     \
    LogicalResult                                                              \
    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                        \
                    ConversionPatternRewriter &rewriter) const override;       \
  };

// One pattern class per sparse-related gpu dialect op.
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
519dfe29429SKun Wu DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp) 520dfe29429SKun Wu DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp) 521dfe29429SKun Wu DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp) 522dfe29429SKun Wu DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp) 523289f7231SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp) 52495a6c509SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp) 525dfe29429SKun Wu 526a5f9cda1SChristian Sigg } // namespace 527a5f9cda1SChristian Sigg 528039b969bSMichele Scuttari void GpuToLLVMConversionPass::runOnOperation() { 5299e7b6f46SMehdi Amini MLIRContext *context = &getContext(); 5300693b9e9SMatthias Springer 5310693b9e9SMatthias Springer // Perform progressive lowering of vector transfer operations. 5320693b9e9SMatthias Springer { 5330693b9e9SMatthias Springer RewritePatternSet patterns(&getContext()); 5340693b9e9SMatthias Springer // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 
5350693b9e9SMatthias Springer vector::populateVectorTransferLoweringPatterns(patterns, 5360693b9e9SMatthias Springer /*maxTransferRank=*/1); 53799019060SKazu Hirata if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 5380693b9e9SMatthias Springer return signalPassFailure(); 5390693b9e9SMatthias Springer } 5400693b9e9SMatthias Springer 5419e7b6f46SMehdi Amini LowerToLLVMOptions options(context); 5429e7b6f46SMehdi Amini options.useBarePtrCallConv = hostBarePtrCallConv; 5439e7b6f46SMehdi Amini RewritePatternSet patterns(context); 5449e7b6f46SMehdi Amini ConversionTarget target(*context); 5459e7b6f46SMehdi Amini target.addLegalDialect<LLVM::LLVMDialect>(); 5469e7b6f46SMehdi Amini LLVMTypeConverter converter(context, options); 5479e7b6f46SMehdi Amini 5489e7b6f46SMehdi Amini // Populate all patterns from all dialects that implement the 5499e7b6f46SMehdi Amini // `ConvertToLLVMPatternInterface` interface. 5509e7b6f46SMehdi Amini for (Dialect *dialect : context->getLoadedDialects()) { 5519e7b6f46SMehdi Amini auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); 5529e7b6f46SMehdi Amini if (!iface) 5539e7b6f46SMehdi Amini continue; 5549e7b6f46SMehdi Amini iface->populateConvertToLLVMConversionPatterns(target, converter, patterns); 5559e7b6f46SMehdi Amini } 5569e7b6f46SMehdi Amini 5578e12f31bSFabian Mora // Preserve GPU modules and binaries. Modules are preserved as they can be 5588e12f31bSFabian Mora // converted later by `gpu-module-to-binary`. 5598e12f31bSFabian Mora target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>(); 5608e12f31bSFabian Mora // Accept as legal LaunchFuncOps if the operands have been lowered. 561fcfeb1e5SFabian Mora target.addDynamicallyLegalOp<gpu::LaunchFuncOp>( 5628e12f31bSFabian Mora [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); }); 563abe501f2SChristian Sigg 5649e7b6f46SMehdi Amini // These aren't covered by the ConvertToLLVMPatternInterface right now. 
565a5f9cda1SChristian Sigg populateVectorToLLVMConversionPatterns(converter, patterns); 566cb4ccd38SQuentin Colombet populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); 5673a506b31SChris Lattner populateAsyncStructuralTypeConversionsAndLegality(converter, patterns, 5683a506b31SChris Lattner target); 569*733be4edSAndrea Faulds populateGpuToLLVMConversionPatterns(converter, patterns, 570*733be4edSAndrea Faulds kernelBarePtrCallConv, 571*733be4edSAndrea Faulds kernelIntersperseSizeCallConv); 572a5f9cda1SChristian Sigg 573a5f9cda1SChristian Sigg if (failed( 574a5f9cda1SChristian Sigg applyPartialConversion(getOperation(), target, std::move(patterns)))) 575a5f9cda1SChristian Sigg signalPassFailure(); 576a5f9cda1SChristian Sigg } 577a5f9cda1SChristian Sigg 578a5f9cda1SChristian Sigg LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder, 579a5f9cda1SChristian Sigg ArrayRef<Value> arguments) const { 580a5f9cda1SChristian Sigg auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>(); 581a5f9cda1SChristian Sigg auto function = [&] { 582a5f9cda1SChristian Sigg if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName)) 583a5f9cda1SChristian Sigg return function; 584973ddb7dSMehdi Amini return OpBuilder::atBlockEnd(module.getBody()) 585a5f9cda1SChristian Sigg .create<LLVM::LLVMFuncOp>(loc, functionName, functionType); 586a5f9cda1SChristian Sigg }(); 587c1f719d1SAlex Zinenko return builder.create<LLVM::CallOp>(loc, function, arguments); 588a5f9cda1SChristian Sigg } 589a5f9cda1SChristian Sigg 590cc402de0SKun Wu // Corresponding to cusparseIndexType_t defined in cusparse.h. 
591cc402de0SKun Wu static int32_t getCuSparseIndexTypeFrom(Type type) { 59239038177SAart Bik if (type.isInteger(16)) 59339038177SAart Bik return 1; // CUSPARSE_INDEX_16U 59439038177SAart Bik if (type.isInteger(32)) 595cc402de0SKun Wu return 2; // CUSPARSE_INDEX_32I 59639038177SAart Bik return 3; // CUSPARSE_INDEX_64I 597cc402de0SKun Wu } 598cc402de0SKun Wu 5998ed59c53SKun Wu static int32_t getCuSparseLtDataTypeFrom(Type type) { 6008ed59c53SKun Wu if (type.isF16()) 6018ed59c53SKun Wu return 0; // CUSPARSE_COMPUTE_16F, 6028ed59c53SKun Wu if (type.isInteger(32)) 6038ed59c53SKun Wu return 1; // CUSPARSE_COMPUTE_32I 6048ed59c53SKun Wu llvm_unreachable("unsupported type"); 6058ed59c53SKun Wu // TODO: add support to TF32 6068ed59c53SKun Wu } 6078ed59c53SKun Wu 608cc402de0SKun Wu // Corresponding to cudaDataType_t defined in CUDA library_types.h. 609cc402de0SKun Wu static int32_t getCuSparseDataTypeFrom(Type type) { 610cc402de0SKun Wu if (llvm::isa<ComplexType>(type)) { 611cc402de0SKun Wu // get the element type 612a5757c5bSChristian Sigg auto elementType = cast<ComplexType>(type).getElementType(); 613cc402de0SKun Wu if (elementType.isBF16()) 614cc402de0SKun Wu return 15; // CUDA_C_16BF 615cc402de0SKun Wu if (elementType.isF16()) 616cc402de0SKun Wu return 6; // CUDA_C_16F 617cc402de0SKun Wu if (elementType.isF32()) 618cc402de0SKun Wu return 4; // CUDA_C_32F 619cc402de0SKun Wu if (elementType.isF64()) 620cc402de0SKun Wu return 5; // CUDA_C_64F 621cc402de0SKun Wu if (elementType.isInteger(8)) 622cc402de0SKun Wu return 7; // CUDA_C_8I 623cc402de0SKun Wu if (elementType.isInteger(16)) 624cc402de0SKun Wu return 21; // CUDA_C_16I 625cc402de0SKun Wu if (elementType.isInteger(32)) 626cc402de0SKun Wu return 11; // CUDA_C_32I 627cc402de0SKun Wu } 628cc402de0SKun Wu if (type.isBF16()) 629cc402de0SKun Wu return 14; // CUDA_R_16BF 630cc402de0SKun Wu if (type.isF16()) 631cc402de0SKun Wu return 2; // CUDA_R_16F 632cc402de0SKun Wu if (type.isF32()) 633cc402de0SKun Wu return 0; // 
CUDA_R_32F 634cc402de0SKun Wu if (type.isF64()) 635cc402de0SKun Wu return 1; // CUDA_R_64F 636cc402de0SKun Wu if (type.isInteger(8)) 637cc402de0SKun Wu return 3; // CUDA_R_8I 638cc402de0SKun Wu if (type.isInteger(16)) 639cc402de0SKun Wu return 20; // CUDA_R_16I 640cc402de0SKun Wu if (type.isInteger(32)) 641cc402de0SKun Wu return 10; // CUDA_R_32I 642cc402de0SKun Wu 643cc402de0SKun Wu llvm_unreachable("unsupported element type"); 644cc402de0SKun Wu } 645cc402de0SKun Wu 6461e491c42SKun Wu static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) { 6471e491c42SKun Wu return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag(); 6481e491c42SKun Wu } 64939038177SAart Bik 6508ed59c53SKun Wu // TODO: We may want a run-time (of the mlir compiler) disablement/warning: 6518ed59c53SKun Wu // cusparseLt currently won't work for cuda architecture <8.0 and will trigger a 6528ed59c53SKun Wu // runtime (of the CUDA program) error , but it might be great if we could at 6538ed59c53SKun Wu // least output a warning when we found the target architecture is <8.0 and the 6548ed59c53SKun Wu // user still wants to use cusparseLt. 
to make sure when lowering gpu sparse 6558ed59c53SKun Wu // dialect to llvm calls, the cusparselt calls are disabled for cuda 6568ed59c53SKun Wu // architecture <8.0 6578ed59c53SKun Wu static bool is2To4Sparsity(Value spMat) { 6588ed59c53SKun Wu if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>()) 6598ed59c53SKun Wu return true; 6608ed59c53SKun Wu if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>()) 6618ed59c53SKun Wu return false; 66239038177SAart Bik if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>()) 66339038177SAart Bik return false; 6648ed59c53SKun Wu if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>()) 6658ed59c53SKun Wu return false; 66639038177SAart Bik if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>()) 66739038177SAart Bik return false; 66839038177SAart Bik if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>()) 6698ed59c53SKun Wu return false; 6708ed59c53SKun Wu // Print the spMat defining op 6718ed59c53SKun Wu spMat.getDefiningOp()->print(llvm::errs()); 6728ed59c53SKun Wu llvm_unreachable("cannot find spmat def"); 6738ed59c53SKun Wu } 6748ed59c53SKun Wu 6758ed59c53SKun Wu static bool isSpMMCusparseLtOp(Value op) { 6768ed59c53SKun Wu for (Operation *user : op.getUsers()) { 6778ed59c53SKun Wu auto spmmOp = dyn_cast<gpu::SpMMOp>(user); 6788ed59c53SKun Wu // If the other operator is 50% sparsity then we should use cusparseLt 6798ed59c53SKun Wu if (!spmmOp) 6808ed59c53SKun Wu continue; 6818ed59c53SKun Wu if (is2To4Sparsity(spmmOp.getSpmatA())) 6828ed59c53SKun Wu return true; 6838ed59c53SKun Wu } 6848ed59c53SKun Wu return false; 6858ed59c53SKun Wu } 6868ed59c53SKun Wu 687a5f9cda1SChristian Sigg // Returns whether all operands are of LLVM type. 
688a5f9cda1SChristian Sigg static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, 689a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) { 690a5f9cda1SChristian Sigg if (!llvm::all_of(operands, [](Value value) { 691a5f9cda1SChristian Sigg return LLVM::isCompatibleType(value.getType()); 692a5f9cda1SChristian Sigg })) 693a5f9cda1SChristian Sigg return rewriter.notifyMatchFailure( 694a5f9cda1SChristian Sigg op, "Cannot convert if operands aren't of LLVM type."); 695a5f9cda1SChristian Sigg return success(); 696a5f9cda1SChristian Sigg } 697a5f9cda1SChristian Sigg 698a5f9cda1SChristian Sigg static LogicalResult 699a5f9cda1SChristian Sigg isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, 700a5f9cda1SChristian Sigg gpu::AsyncOpInterface op) { 701a5f9cda1SChristian Sigg if (op.getAsyncDependencies().size() != 1) 702a5f9cda1SChristian Sigg return rewriter.notifyMatchFailure( 703a5f9cda1SChristian Sigg op, "Can only convert with exactly one async dependency."); 704a5f9cda1SChristian Sigg 705a5f9cda1SChristian Sigg if (!op.getAsyncToken()) 706a5f9cda1SChristian Sigg return rewriter.notifyMatchFailure(op, "Can convert only async version."); 707a5f9cda1SChristian Sigg 708a5f9cda1SChristian Sigg return success(); 709a5f9cda1SChristian Sigg } 710a5f9cda1SChristian Sigg 711a5f9cda1SChristian Sigg LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite( 712ef976337SRiver Riddle gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor, 713a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const { 714a5f9cda1SChristian Sigg auto *op = hostRegisterOp.getOperation(); 715ef976337SRiver Riddle if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) 716a5f9cda1SChristian Sigg return failure(); 717a5f9cda1SChristian Sigg 718a5f9cda1SChristian Sigg Location loc = op->getLoc(); 719a5f9cda1SChristian Sigg 72010c04f46SRiver Riddle auto memRefType = hostRegisterOp.getValue().getType(); 7215550c821STres Popp auto 
elementType = cast<UnrankedMemRefType>(memRefType).getElementType(); 722a5f9cda1SChristian Sigg auto elementSize = getSizeInBytes(loc, elementType, rewriter); 723a5f9cda1SChristian Sigg 724ef976337SRiver Riddle auto arguments = getTypeConverter()->promoteOperands( 725ef976337SRiver Riddle loc, op->getOperands(), adaptor.getOperands(), rewriter); 726a5f9cda1SChristian Sigg arguments.push_back(elementSize); 727a5f9cda1SChristian Sigg hostRegisterCallBuilder.create(loc, rewriter, arguments); 728a5f9cda1SChristian Sigg 729a5f9cda1SChristian Sigg rewriter.eraseOp(op); 730a5f9cda1SChristian Sigg return success(); 731a5f9cda1SChristian Sigg } 732a5f9cda1SChristian Sigg 7338f7c8a6eSmax LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite( 7348f7c8a6eSmax gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor, 7358f7c8a6eSmax ConversionPatternRewriter &rewriter) const { 7368f7c8a6eSmax Operation *op = hostUnregisterOp.getOperation(); 7378f7c8a6eSmax if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) 7388f7c8a6eSmax return failure(); 7398f7c8a6eSmax 7408f7c8a6eSmax Location loc = op->getLoc(); 7418f7c8a6eSmax 7428f7c8a6eSmax auto memRefType = hostUnregisterOp.getValue().getType(); 7435550c821STres Popp auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType(); 7448f7c8a6eSmax auto elementSize = getSizeInBytes(loc, elementType, rewriter); 7458f7c8a6eSmax 7468f7c8a6eSmax auto arguments = getTypeConverter()->promoteOperands( 7478f7c8a6eSmax loc, op->getOperands(), adaptor.getOperands(), rewriter); 7488f7c8a6eSmax arguments.push_back(elementSize); 7498f7c8a6eSmax hostUnregisterCallBuilder.create(loc, rewriter, arguments); 7508f7c8a6eSmax 7518f7c8a6eSmax rewriter.eraseOp(op); 7528f7c8a6eSmax return success(); 7538f7c8a6eSmax } 7548f7c8a6eSmax 755a5f9cda1SChristian Sigg LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( 756ef976337SRiver Riddle gpu::AllocOp allocOp, OpAdaptor adaptor, 
757a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const { 758a93ec06aSIvan Butygin 759a5f9cda1SChristian Sigg MemRefType memRefType = allocOp.getType(); 760a5f9cda1SChristian Sigg 761ef976337SRiver Riddle if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) || 7621002a1d0SNishant Patel !isConvertibleAndHasIdentityMaps(memRefType)) 763a5f9cda1SChristian Sigg return failure(); 764a5f9cda1SChristian Sigg 765a5f9cda1SChristian Sigg auto loc = allocOp.getLoc(); 766a5f9cda1SChristian Sigg 7671002a1d0SNishant Patel bool isShared = allocOp.getHostShared(); 7681002a1d0SNishant Patel 7691002a1d0SNishant Patel if (isShared && allocOp.getAsyncToken()) 7701002a1d0SNishant Patel return rewriter.notifyMatchFailure( 7711002a1d0SNishant Patel allocOp, "Host Shared allocation cannot be done async"); 772e204b919SMehdi Amini if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp))) 7731002a1d0SNishant Patel return failure(); 7741002a1d0SNishant Patel 775a5f9cda1SChristian Sigg // Get shape of the memref as values: static sizes are constant 776a5f9cda1SChristian Sigg // values and dynamic sizes are passed to 'alloc' as operands. 777a5f9cda1SChristian Sigg SmallVector<Value, 4> shape; 778a5f9cda1SChristian Sigg SmallVector<Value, 4> strides; 779a5f9cda1SChristian Sigg Value sizeBytes; 78010c04f46SRiver Riddle getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter, 781a5f9cda1SChristian Sigg shape, strides, sizeBytes); 782a5f9cda1SChristian Sigg 783a5f9cda1SChristian Sigg // Allocate the underlying buffer and store a pointer to it in the MemRef 784a5f9cda1SChristian Sigg // descriptor. 785ced9f4f0SNishant Patel auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType); 786ced9f4f0SNishant Patel Value stream = adaptor.getAsyncDependencies().empty() 787ced9f4f0SNishant Patel ? 
nullPtr 788ced9f4f0SNishant Patel : adaptor.getAsyncDependencies().front(); 7891002a1d0SNishant Patel 7901002a1d0SNishant Patel auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>( 7911002a1d0SNishant Patel loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); 7921002a1d0SNishant Patel 793a5f9cda1SChristian Sigg Value allocatedPtr = 7941002a1d0SNishant Patel allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared}) 7951002a1d0SNishant Patel .getResult(); 796a5f9cda1SChristian Sigg 797a5f9cda1SChristian Sigg // No alignment. 798a5f9cda1SChristian Sigg Value alignedPtr = allocatedPtr; 799a5f9cda1SChristian Sigg 800a5f9cda1SChristian Sigg // Create the MemRef descriptor. 801a5f9cda1SChristian Sigg auto memRefDescriptor = this->createMemRefDescriptor( 802a5f9cda1SChristian Sigg loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter); 803a5f9cda1SChristian Sigg 804ced9f4f0SNishant Patel if (allocOp.getAsyncToken()) { 805ced9f4f0SNishant Patel // Async alloc: make dependent ops use the same stream. 
806a5f9cda1SChristian Sigg rewriter.replaceOp(allocOp, {memRefDescriptor, stream}); 807ced9f4f0SNishant Patel } else { 808ced9f4f0SNishant Patel rewriter.replaceOp(allocOp, {memRefDescriptor}); 809ced9f4f0SNishant Patel } 810a5f9cda1SChristian Sigg 811a5f9cda1SChristian Sigg return success(); 812a5f9cda1SChristian Sigg } 813a5f9cda1SChristian Sigg 814a5f9cda1SChristian Sigg LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite( 815ef976337SRiver Riddle gpu::DeallocOp deallocOp, OpAdaptor adaptor, 816a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const { 817ef976337SRiver Riddle if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) || 818a5f9cda1SChristian Sigg failed(isAsyncWithOneDependency(rewriter, deallocOp))) 819a5f9cda1SChristian Sigg return failure(); 820a5f9cda1SChristian Sigg 821a5f9cda1SChristian Sigg Location loc = deallocOp.getLoc(); 822a5f9cda1SChristian Sigg 823a5f9cda1SChristian Sigg Value pointer = 82410c04f46SRiver Riddle MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); 82510c04f46SRiver Riddle Value stream = adaptor.getAsyncDependencies().front(); 8260e5aeae6SMarkus Böck deallocCallBuilder.create(loc, rewriter, {pointer, stream}); 827a5f9cda1SChristian Sigg 828a5f9cda1SChristian Sigg rewriter.replaceOp(deallocOp, {stream}); 829a5f9cda1SChristian Sigg return success(); 830a5f9cda1SChristian Sigg } 831a5f9cda1SChristian Sigg 832a5f9cda1SChristian Sigg static bool isGpuAsyncTokenType(Value value) { 8335550c821STres Popp return isa<gpu::AsyncTokenType>(value.getType()); 834a5f9cda1SChristian Sigg } 835a5f9cda1SChristian Sigg 836a5f9cda1SChristian Sigg // Converts !gpu.async.token operands of `async.yield` to runtime calls. The 837a5f9cda1SChristian Sigg // !gpu.async.token are lowered to stream within the async.execute region, but 838a5f9cda1SChristian Sigg // are passed as events between them. 
For each !gpu.async.token operand, we 839a5f9cda1SChristian Sigg // create an event and record it on the stream. 840a5f9cda1SChristian Sigg LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite( 841ef976337SRiver Riddle async::YieldOp yieldOp, OpAdaptor adaptor, 842a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const { 843b74192b7SRiver Riddle if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType)) 844a5f9cda1SChristian Sigg return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand"); 845a5f9cda1SChristian Sigg 846a5f9cda1SChristian Sigg Location loc = yieldOp.getLoc(); 847ef976337SRiver Riddle SmallVector<Value, 4> newOperands(adaptor.getOperands()); 848a5f9cda1SChristian Sigg llvm::SmallDenseSet<Value> streams; 849a5f9cda1SChristian Sigg for (auto &operand : yieldOp->getOpOperands()) { 850a5f9cda1SChristian Sigg if (!isGpuAsyncTokenType(operand.get())) 851a5f9cda1SChristian Sigg continue; 852a5f9cda1SChristian Sigg auto idx = operand.getOperandNumber(); 853ef976337SRiver Riddle auto stream = adaptor.getOperands()[idx]; 8545e0c3b43SJeff Niu auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(); 855a5f9cda1SChristian Sigg eventRecordCallBuilder.create(loc, rewriter, {event, stream}); 856a5f9cda1SChristian Sigg newOperands[idx] = event; 857a5f9cda1SChristian Sigg streams.insert(stream); 858a5f9cda1SChristian Sigg } 859a5f9cda1SChristian Sigg for (auto stream : streams) 860a5f9cda1SChristian Sigg streamDestroyCallBuilder.create(loc, rewriter, {stream}); 861a5f9cda1SChristian Sigg 8625fcf907bSMatthias Springer rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); }); 863a5f9cda1SChristian Sigg return success(); 864a5f9cda1SChristian Sigg } 865a5f9cda1SChristian Sigg 866a5f9cda1SChristian Sigg // Returns whether `value` is the result of an LLVM::CallOp to `functionName`. 
867a5f9cda1SChristian Sigg static bool isDefinedByCallTo(Value value, StringRef functionName) { 8685550c821STres Popp assert(isa<LLVM::LLVMPointerType>(value.getType())); 869a5f9cda1SChristian Sigg if (auto defOp = value.getDefiningOp<LLVM::CallOp>()) 870dec8055aSKazu Hirata return *defOp.getCallee() == functionName; 871a5f9cda1SChristian Sigg return false; 872a5f9cda1SChristian Sigg } 873a5f9cda1SChristian Sigg 874a5f9cda1SChristian Sigg // Converts `gpu.wait` to runtime calls. The converted op synchronizes the host 875a5f9cda1SChristian Sigg // with the stream/event operands. The operands are destroyed. That is, it 876a5f9cda1SChristian Sigg // assumes that it is not used afterwards or elsewhere. Otherwise we will get a 877a5f9cda1SChristian Sigg // runtime error. Eventually, we should guarantee this property. 878a5f9cda1SChristian Sigg LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite( 879ef976337SRiver Riddle gpu::WaitOp waitOp, OpAdaptor adaptor, 880a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const { 88110c04f46SRiver Riddle if (waitOp.getAsyncToken()) 882a5f9cda1SChristian Sigg return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op."); 883a5f9cda1SChristian Sigg 884a5f9cda1SChristian Sigg Location loc = waitOp.getLoc(); 885a5f9cda1SChristian Sigg 886ef976337SRiver Riddle for (auto operand : adaptor.getOperands()) { 887a5f9cda1SChristian Sigg if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { 888a5f9cda1SChristian Sigg // The converted operand's definition created a stream. 889a5f9cda1SChristian Sigg streamSynchronizeCallBuilder.create(loc, rewriter, {operand}); 890a5f9cda1SChristian Sigg streamDestroyCallBuilder.create(loc, rewriter, {operand}); 891a5f9cda1SChristian Sigg } else { 892a5f9cda1SChristian Sigg // Otherwise the converted operand is an event. This assumes that we use 893a5f9cda1SChristian Sigg // events in control flow code as well. 
894a5f9cda1SChristian Sigg eventSynchronizeCallBuilder.create(loc, rewriter, {operand}); 895a5f9cda1SChristian Sigg eventDestroyCallBuilder.create(loc, rewriter, {operand}); 896a5f9cda1SChristian Sigg } 897a5f9cda1SChristian Sigg } 898a5f9cda1SChristian Sigg 899a5f9cda1SChristian Sigg rewriter.eraseOp(waitOp); 900a5f9cda1SChristian Sigg return success(); 901a5f9cda1SChristian Sigg } 902a5f9cda1SChristian Sigg 903a5f9cda1SChristian Sigg // Converts `gpu.wait async` to runtime calls. The converted op creates a new 904a5f9cda1SChristian Sigg // stream that is synchronized with stream/event operands. The operands are 905a5f9cda1SChristian Sigg // destroyed. That is, it assumes that it is not used afterwards or elsewhere. 906a5f9cda1SChristian Sigg // Otherwise we will get a runtime error. Eventually, we should guarantee this 907a5f9cda1SChristian Sigg // property. 908a5f9cda1SChristian Sigg LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite( 909ef976337SRiver Riddle gpu::WaitOp waitOp, OpAdaptor adaptor, 910a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const { 91110c04f46SRiver Riddle if (!waitOp.getAsyncToken()) 912a5f9cda1SChristian Sigg return rewriter.notifyMatchFailure(waitOp, "Can only convert async op."); 913a5f9cda1SChristian Sigg 914a5f9cda1SChristian Sigg Location loc = waitOp.getLoc(); 915a5f9cda1SChristian Sigg 916a5f9cda1SChristian Sigg auto insertionPoint = rewriter.saveInsertionPoint(); 917a5f9cda1SChristian Sigg SmallVector<Value, 1> events; 918ef976337SRiver Riddle for (auto pair : 91910c04f46SRiver Riddle llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) { 920a5f9cda1SChristian Sigg auto operand = std::get<1>(pair); 921a5f9cda1SChristian Sigg if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { 922a5f9cda1SChristian Sigg // The converted operand's definition created a stream. 
Insert an event 923a5f9cda1SChristian Sigg // into the stream just after the last use of the original token operand. 924a5f9cda1SChristian Sigg auto *defOp = std::get<0>(pair).getDefiningOp(); 925a5f9cda1SChristian Sigg rewriter.setInsertionPointAfter(defOp); 9265e0c3b43SJeff Niu auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(); 927a5f9cda1SChristian Sigg eventRecordCallBuilder.create(loc, rewriter, {event, operand}); 928a5f9cda1SChristian Sigg events.push_back(event); 929a5f9cda1SChristian Sigg } else { 930a5f9cda1SChristian Sigg // Otherwise the converted operand is an event. This assumes that we use 931a5f9cda1SChristian Sigg // events in control flow code as well. 932a5f9cda1SChristian Sigg events.push_back(operand); 933a5f9cda1SChristian Sigg } 934a5f9cda1SChristian Sigg } 935a5f9cda1SChristian Sigg rewriter.restoreInsertionPoint(insertionPoint); 9365e0c3b43SJeff Niu auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(); 937a5f9cda1SChristian Sigg for (auto event : events) 938a5f9cda1SChristian Sigg streamWaitEventCallBuilder.create(loc, rewriter, {stream, event}); 939a5f9cda1SChristian Sigg for (auto event : events) 940a5f9cda1SChristian Sigg eventDestroyCallBuilder.create(loc, rewriter, {event}); 941a5f9cda1SChristian Sigg rewriter.replaceOp(waitOp, {stream}); 942a5f9cda1SChristian Sigg 943a5f9cda1SChristian Sigg return success(); 944a5f9cda1SChristian Sigg } 945a5f9cda1SChristian Sigg 9468e12f31bSFabian Mora // Legalize the op's operands. 
9478e12f31bSFabian Mora LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( 948ef976337SRiver Riddle gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, 949a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const { 950ef976337SRiver Riddle if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter))) 951a5f9cda1SChristian Sigg return failure(); 952a5f9cda1SChristian Sigg 95310c04f46SRiver Riddle if (launchOp.getAsyncDependencies().size() > 1) 954a5f9cda1SChristian Sigg return rewriter.notifyMatchFailure( 955a5f9cda1SChristian Sigg launchOp, "Cannot convert with more than one async dependency."); 956a5f9cda1SChristian Sigg 957a5f9cda1SChristian Sigg // Fail when the synchronous version of the op has async dependencies. The 958a5f9cda1SChristian Sigg // lowering destroys the stream, and we do not want to check that there is no 959a5f9cda1SChristian Sigg // use of the stream after this op. 96010c04f46SRiver Riddle if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty()) 961a5f9cda1SChristian Sigg return rewriter.notifyMatchFailure( 962a5f9cda1SChristian Sigg launchOp, "Cannot convert non-async op with async dependencies."); 963a5f9cda1SChristian Sigg 964a5f9cda1SChristian Sigg Location loc = launchOp.getLoc(); 965a5f9cda1SChristian Sigg 966fcfeb1e5SFabian Mora Value stream = Value(); 967583e78b3SAdrian Kuegel if (!adaptor.getAsyncDependencies().empty()) 968fcfeb1e5SFabian Mora stream = adaptor.getAsyncDependencies().front(); 969fcfeb1e5SFabian Mora // If the async keyword is present and there are no dependencies, then a 970fcfeb1e5SFabian Mora // stream must be created to pass to subsequent operations. 971fcfeb1e5SFabian Mora else if (launchOp.getAsyncToken()) 972fcfeb1e5SFabian Mora stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(); 97399a562b3SAndrea Faulds 974fcfeb1e5SFabian Mora // Lower the kernel operands to match kernel parameters. 
9757f4dbd83SMatthias Springer // Note: If `useBarePtrCallConv` is set in the type converter's options, 9767f4dbd83SMatthias Springer // the value of `kernelBarePtrCallConv` will be ignored. 977*733be4edSAndrea Faulds OperandRange origArguments = launchOp.getKernelOperands(); 978*733be4edSAndrea Faulds SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands( 979*733be4edSAndrea Faulds loc, origArguments, adaptor.getKernelOperands(), rewriter, 9808e12f31bSFabian Mora /*useBarePtrCallConv=*/kernelBarePtrCallConv); 981*733be4edSAndrea Faulds SmallVector<Value, 8> llvmArgumentsWithSizes; 982*733be4edSAndrea Faulds 983*733be4edSAndrea Faulds // Intersperse size information if requested. 984*733be4edSAndrea Faulds if (kernelIntersperseSizeCallConv) { 985*733be4edSAndrea Faulds if (origArguments.size() != llvmArguments.size()) { 986*733be4edSAndrea Faulds // This shouldn't happen if the bare-pointer calling convention is used. 987*733be4edSAndrea Faulds return rewriter.notifyMatchFailure( 988*733be4edSAndrea Faulds launchOp, 989*733be4edSAndrea Faulds "Cannot add sizes to arguments with one-to-many LLVM IR expansion."); 990*733be4edSAndrea Faulds } 991*733be4edSAndrea Faulds 992*733be4edSAndrea Faulds llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2); 993*733be4edSAndrea Faulds for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) { 994*733be4edSAndrea Faulds auto memrefTy = dyn_cast<MemRefType>(origArg.getType()); 995*733be4edSAndrea Faulds if (!memrefTy) { 996*733be4edSAndrea Faulds return rewriter.notifyMatchFailure( 997*733be4edSAndrea Faulds launchOp, "Operand to launch op is not a memref."); 998*733be4edSAndrea Faulds } 999*733be4edSAndrea Faulds 1000*733be4edSAndrea Faulds if (!memrefTy.hasStaticShape() || 1001*733be4edSAndrea Faulds !memrefTy.getElementType().isIntOrFloat()) { 1002*733be4edSAndrea Faulds return rewriter.notifyMatchFailure( 1003*733be4edSAndrea Faulds launchOp, "Operand to launch op is not a memref with a 
static " 1004*733be4edSAndrea Faulds "shape and an integer or float element type."); 1005*733be4edSAndrea Faulds } 1006*733be4edSAndrea Faulds 1007*733be4edSAndrea Faulds unsigned bitwidth = memrefTy.getElementTypeBitWidth(); 1008*733be4edSAndrea Faulds if (bitwidth % 8 != 0) { 1009*733be4edSAndrea Faulds return rewriter.notifyMatchFailure( 1010*733be4edSAndrea Faulds launchOp, "Operand to launch op is not a memref with a " 1011*733be4edSAndrea Faulds "byte-aligned element type."); 1012*733be4edSAndrea Faulds } 1013*733be4edSAndrea Faulds 1014*733be4edSAndrea Faulds uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) * 1015*733be4edSAndrea Faulds static_cast<uint64_t>(memrefTy.getNumElements()); 1016*733be4edSAndrea Faulds 1017*733be4edSAndrea Faulds Value sizeArg = rewriter.create<LLVM::ConstantOp>( 1018*733be4edSAndrea Faulds loc, getIndexType(), rewriter.getIndexAttr(staticSize)); 1019*733be4edSAndrea Faulds llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer. 
1020*733be4edSAndrea Faulds llvmArgumentsWithSizes.push_back(sizeArg); 1021*733be4edSAndrea Faulds } 1022*733be4edSAndrea Faulds } 1023fcfeb1e5SFabian Mora 1024edf5cae7SGuray Ozen std::optional<gpu::KernelDim3> clusterSize = std::nullopt; 1025edf5cae7SGuray Ozen if (launchOp.hasClusterSize()) { 1026edf5cae7SGuray Ozen clusterSize = 1027edf5cae7SGuray Ozen gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(), 1028edf5cae7SGuray Ozen adaptor.getClusterSizeZ()}; 1029edf5cae7SGuray Ozen } 1030fcfeb1e5SFabian Mora rewriter.create<gpu::LaunchFuncOp>( 1031fcfeb1e5SFabian Mora launchOp.getLoc(), launchOp.getKernelAttr(), 1032fcfeb1e5SFabian Mora gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(), 1033fcfeb1e5SFabian Mora adaptor.getGridSizeZ()}, 1034fcfeb1e5SFabian Mora gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), 1035fcfeb1e5SFabian Mora adaptor.getBlockSizeZ()}, 1036*733be4edSAndrea Faulds adaptor.getDynamicSharedMemorySize(), 1037*733be4edSAndrea Faulds llvmArgumentsWithSizes.empty() ? 
llvmArguments : llvmArgumentsWithSizes, 1038*733be4edSAndrea Faulds stream, clusterSize); 1039fcfeb1e5SFabian Mora if (launchOp.getAsyncToken()) 1040fcfeb1e5SFabian Mora rewriter.replaceOp(launchOp, {stream}); 1041fcfeb1e5SFabian Mora else 1042fcfeb1e5SFabian Mora rewriter.eraseOp(launchOp); 1043fcfeb1e5SFabian Mora return success(); 1044fcfeb1e5SFabian Mora } 1045fcfeb1e5SFabian Mora 10460aaf2e3bSMarkus Böck static Value bitAndAddrspaceCast(Location loc, 10470aaf2e3bSMarkus Böck ConversionPatternRewriter &rewriter, 10480aaf2e3bSMarkus Böck LLVM::LLVMPointerType destinationType, 10490aaf2e3bSMarkus Böck Value sourcePtr, 1050ce254598SMatthias Springer const LLVMTypeConverter &typeConverter) { 10515550c821STres Popp auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType()); 10520aaf2e3bSMarkus Böck if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) 10530aaf2e3bSMarkus Böck sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>( 10540aaf2e3bSMarkus Böck loc, 1055dbd4a0ddSChristian Ulmann LLVM::LLVMPointerType::get(rewriter.getContext(), 10560aaf2e3bSMarkus Böck destinationType.getAddressSpace()), 10570aaf2e3bSMarkus Böck sourcePtr); 10580e5aeae6SMarkus Böck return sourcePtr; 10590aaf2e3bSMarkus Böck } 10600aaf2e3bSMarkus Böck 1061a5f9cda1SChristian Sigg LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( 1062ef976337SRiver Riddle gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, 1063a5f9cda1SChristian Sigg ConversionPatternRewriter &rewriter) const { 10645550c821STres Popp auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType()); 1065a5f9cda1SChristian Sigg 1066ef976337SRiver Riddle if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) || 1067a5f9cda1SChristian Sigg !isConvertibleAndHasIdentityMaps(memRefType) || 1068a5f9cda1SChristian Sigg failed(isAsyncWithOneDependency(rewriter, memcpyOp))) 1069a5f9cda1SChristian Sigg return failure(); 1070a5f9cda1SChristian Sigg 1071a5f9cda1SChristian Sigg auto loc = 
memcpyOp.getLoc(); 1072a5f9cda1SChristian Sigg 107310c04f46SRiver Riddle MemRefDescriptor srcDesc(adaptor.getSrc()); 1074361458b1SLoren Maggiore Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc); 1075a5f9cda1SChristian Sigg 1076a5f9cda1SChristian Sigg Type elementPtrType = getElementPtrType(memRefType); 107785175eddSTobias Gysi Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType); 10780e5aeae6SMarkus Böck Value gepPtr = rewriter.create<LLVM::GEPOp>( 10790e5aeae6SMarkus Böck loc, elementPtrType, 10800e5aeae6SMarkus Böck typeConverter->convertType(memRefType.getElementType()), nullPtr, 10810e5aeae6SMarkus Böck numElements); 1082a5f9cda1SChristian Sigg auto sizeBytes = 1083a5f9cda1SChristian Sigg rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); 1084a5f9cda1SChristian Sigg 10850aaf2e3bSMarkus Böck auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, 10860aaf2e3bSMarkus Böck srcDesc.alignedPtr(rewriter, loc), 10870aaf2e3bSMarkus Böck *getTypeConverter()); 10880aaf2e3bSMarkus Böck auto dst = bitAndAddrspaceCast( 10890aaf2e3bSMarkus Böck loc, rewriter, llvmPointerType, 10900aaf2e3bSMarkus Böck MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc), 10910aaf2e3bSMarkus Böck *getTypeConverter()); 1092a5f9cda1SChristian Sigg 109310c04f46SRiver Riddle auto stream = adaptor.getAsyncDependencies().front(); 1094a5f9cda1SChristian Sigg memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream}); 1095a5f9cda1SChristian Sigg 1096a5f9cda1SChristian Sigg rewriter.replaceOp(memcpyOp, {stream}); 1097a5f9cda1SChristian Sigg 1098a5f9cda1SChristian Sigg return success(); 1099a5f9cda1SChristian Sigg } 1100a5f9cda1SChristian Sigg 1101361458b1SLoren Maggiore LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( 1102ef976337SRiver Riddle gpu::MemsetOp memsetOp, OpAdaptor adaptor, 1103361458b1SLoren Maggiore ConversionPatternRewriter &rewriter) const { 11045550c821STres Popp auto memRefType = 
cast<MemRefType>(memsetOp.getDst().getType()); 1105361458b1SLoren Maggiore 1106ef976337SRiver Riddle if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) || 1107361458b1SLoren Maggiore !isConvertibleAndHasIdentityMaps(memRefType) || 1108361458b1SLoren Maggiore failed(isAsyncWithOneDependency(rewriter, memsetOp))) 1109361458b1SLoren Maggiore return failure(); 1110361458b1SLoren Maggiore 1111361458b1SLoren Maggiore auto loc = memsetOp.getLoc(); 1112361458b1SLoren Maggiore 111310c04f46SRiver Riddle Type valueType = adaptor.getValue().getType(); 111418cc07aaSNavdeep Katel unsigned bitWidth = valueType.getIntOrFloatBitWidth(); 111518cc07aaSNavdeep Katel // Ints and floats of 16 or 32 bit width are allowed. 111618cc07aaSNavdeep Katel if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) { 111718cc07aaSNavdeep Katel return rewriter.notifyMatchFailure( 111818cc07aaSNavdeep Katel memsetOp, "value must be a 16 or 32 bit int or float"); 1119361458b1SLoren Maggiore } 1120361458b1SLoren Maggiore 112118cc07aaSNavdeep Katel unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth(); 112218cc07aaSNavdeep Katel Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type; 112318cc07aaSNavdeep Katel 112410c04f46SRiver Riddle MemRefDescriptor dstDesc(adaptor.getDst()); 1125361458b1SLoren Maggiore Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc); 1126361458b1SLoren Maggiore 1127361458b1SLoren Maggiore auto value = 112818cc07aaSNavdeep Katel rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue()); 11290aaf2e3bSMarkus Böck auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, 11300aaf2e3bSMarkus Böck dstDesc.alignedPtr(rewriter, loc), 11310aaf2e3bSMarkus Böck *getTypeConverter()); 1132361458b1SLoren Maggiore 113310c04f46SRiver Riddle auto stream = adaptor.getAsyncDependencies().front(); 113418cc07aaSNavdeep Katel FunctionCallBuilder builder = 113518cc07aaSNavdeep Katel valueTypeWidth == 32 ? 
memset32CallBuilder : memset16CallBuilder; 113618cc07aaSNavdeep Katel builder.create(loc, rewriter, {dst, value, numElements, stream}); 1137361458b1SLoren Maggiore 1138361458b1SLoren Maggiore rewriter.replaceOp(memsetOp, {stream}); 1139361458b1SLoren Maggiore return success(); 1140361458b1SLoren Maggiore } 1141361458b1SLoren Maggiore 114284718d37SKrzysztof Drewniak LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite( 114384718d37SKrzysztof Drewniak gpu::SetDefaultDeviceOp op, OpAdaptor adaptor, 114484718d37SKrzysztof Drewniak ConversionPatternRewriter &rewriter) const { 114584718d37SKrzysztof Drewniak Location loc = op.getLoc(); 114611141bc6SPaul C Fuqua auto call = setDefaultDeviceCallBuilder.create(loc, rewriter, 114711141bc6SPaul C Fuqua {adaptor.getDevIndex()}); 114811141bc6SPaul C Fuqua rewriter.replaceOp(op, call); 114984718d37SKrzysztof Drewniak return success(); 115084718d37SKrzysztof Drewniak } 115184718d37SKrzysztof Drewniak 1152cc402de0SKun Wu template <typename T> 11536ac80a76SMehdi Amini static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) { 1154235fbe79SKun Wu Type llvmInt32Type = builder.getIntegerType(32); 1155235fbe79SKun Wu return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, 11566ac80a76SMehdi Amini static_cast<int32_t>(tValue)); 1157cc402de0SKun Wu } 1158cc402de0SKun Wu 1159dfe29429SKun Wu template <typename T> 11606ac80a76SMehdi Amini static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) { 1161dfe29429SKun Wu Type llvmFloat32Type = builder.getF32Type(); 1162dfe29429SKun Wu return builder.create<LLVM::ConstantOp>( 1163dfe29429SKun Wu loc, llvmFloat32Type, 11646ac80a76SMehdi Amini builder.getF32FloatAttr(static_cast<float>(tValue))); 1165dfe29429SKun Wu } 1166dfe29429SKun Wu 116797f4c22bSKun Wu LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite( 116897f4c22bSKun Wu gpu::CreateDnTensorOp op, OpAdaptor adaptor, 1169b700a90cSAart Bik 
ConversionPatternRewriter &rewriter) const { 1170b700a90cSAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1171b700a90cSAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 1172b700a90cSAart Bik return failure(); 1173b700a90cSAart Bik Location loc = op.getLoc(); 1174b700a90cSAart Bik auto stream = adaptor.getAsyncDependencies().front(); 117597f4c22bSKun Wu Value pTensor = 1176b700a90cSAart Bik MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); 1177cc402de0SKun Wu Type dType = op.getMemref().getType().getElementType(); 1178cc402de0SKun Wu auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); 117997f4c22bSKun Wu 118097f4c22bSKun Wu SmallVector<Value, 4> dims; 118197f4c22bSKun Wu for (Value dim : adaptor.getDims()) { 118297f4c22bSKun Wu dims.push_back(dim); 1183b700a90cSAart Bik } 1184b700a90cSAart Bik 118597f4c22bSKun Wu Value handle; 11868ed59c53SKun Wu // TODO: For now, we track the use of the handle and lower it to cusparse / 11878ed59c53SKun Wu // cusparseLt accordingly. If in a block, both cusparse and cusparseLt are 11888ed59c53SKun Wu // used, we require two separate Creation ops to be the correct logic. In 11898ed59c53SKun Wu // future, we may add support to using one handle in sparse tensor / GPU 11908ed59c53SKun Wu // dialect in both cusparse and cusparseLt. 
use the cusparseLt create call if 11918ed59c53SKun Wu // the dnmat is used with spmat with 2:4 sparsity 119297f4c22bSKun Wu if (dims.size() == 2) { 119397f4c22bSKun Wu if (isSpMMCusparseLtOp(op.getDnTensor())) { 11948ed59c53SKun Wu auto handleSz = rewriter.create<LLVM::ConstantOp>( 11958ed59c53SKun Wu loc, getIndexType(), rewriter.getIndexAttr(11032)); 119686eff489SAart Bik handle = rewriter.create<LLVM::AllocaOp>( 1197dbd4a0ddSChristian Ulmann loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); 11988ed59c53SKun Wu handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle); 11998ed59c53SKun Wu 12008ed59c53SKun Wu createLtDnMatCallBuilder 12018ed59c53SKun Wu .create(loc, rewriter, 1202be2dd22bSKun Wu {handle, dims[0], dims[1], pTensor, dtp, stream}) 12038ed59c53SKun Wu .getResult(); 12048ed59c53SKun Wu } else { 12058ed59c53SKun Wu handle = 1206981cf167SAart Bik createDnMatCallBuilder 120797f4c22bSKun Wu .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream}) 120897f4c22bSKun Wu .getResult(); 120997f4c22bSKun Wu } 121097f4c22bSKun Wu } else { 121197f4c22bSKun Wu assert(dims.size() == 1 && "Only 1D and 2D tensors are supported"); 121297f4c22bSKun Wu handle = createDnVecCallBuilder 121397f4c22bSKun Wu .create(loc, rewriter, {dims[0], pTensor, dtp, stream}) 1214981cf167SAart Bik .getResult(); 12158ed59c53SKun Wu } 1216981cf167SAart Bik rewriter.replaceOp(op, {handle, stream}); 1217981cf167SAart Bik return success(); 1218981cf167SAart Bik } 1219981cf167SAart Bik 122097f4c22bSKun Wu LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite( 122197f4c22bSKun Wu gpu::DestroyDnTensorOp op, OpAdaptor adaptor, 1222981cf167SAart Bik ConversionPatternRewriter &rewriter) const { 1223981cf167SAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1224981cf167SAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 1225981cf167SAart Bik return failure(); 1226981cf167SAart Bik Location loc = op.getLoc(); 
1227981cf167SAart Bik auto stream = adaptor.getAsyncDependencies().front(); 122897f4c22bSKun Wu auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>(); 122997f4c22bSKun Wu SmallVector<Value, 4> dims; 123097f4c22bSKun Wu for (Value dim : definingOp.getDims()) { 123197f4c22bSKun Wu dims.push_back(dim); 123297f4c22bSKun Wu } 123397f4c22bSKun Wu if (dims.size() == 2) { 12348ed59c53SKun Wu // Use the cusparseLt destroy call if the dnmat is used with spmat with 12358ed59c53SKun Wu // 2:4 sparsity 123697f4c22bSKun Wu if (isSpMMCusparseLtOp(op.getDnTensor())) { 12378ed59c53SKun Wu destroyCuSparseLtDnMatBuilder.create(loc, rewriter, 123897f4c22bSKun Wu {adaptor.getDnTensor(), stream}); 12398ed59c53SKun Wu } else { 124097f4c22bSKun Wu destroyDnMatCallBuilder.create(loc, rewriter, 124197f4c22bSKun Wu {adaptor.getDnTensor(), stream}); 124297f4c22bSKun Wu } 124397f4c22bSKun Wu } else { 124497f4c22bSKun Wu assert(dims.size() == 1 && "Only 1D and 2D tensors are supported"); 124597f4c22bSKun Wu destroyDnVecCallBuilder.create(loc, rewriter, 124697f4c22bSKun Wu {adaptor.getDnTensor(), stream}); 12478ed59c53SKun Wu } 1248981cf167SAart Bik rewriter.replaceOp(op, {stream}); 1249981cf167SAart Bik return success(); 1250981cf167SAart Bik } 1251981cf167SAart Bik 1252b700a90cSAart Bik LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite( 1253b700a90cSAart Bik gpu::CreateCooOp op, OpAdaptor adaptor, 1254b700a90cSAart Bik ConversionPatternRewriter &rewriter) const { 1255b700a90cSAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1256b700a90cSAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 1257b700a90cSAart Bik return failure(); 1258b700a90cSAart Bik Location loc = op.getLoc(); 1259b700a90cSAart Bik auto stream = adaptor.getAsyncDependencies().front(); 1260b700a90cSAart Bik Value pRowIdxs = 1261b700a90cSAart Bik MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc); 1262b700a90cSAart Bik Value pColIdxs = 
1263b700a90cSAart Bik MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc); 1264b700a90cSAart Bik Value pValues = 1265b700a90cSAart Bik MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc); 1266cf44847bSKun Wu Type iType = 1267cf44847bSKun Wu llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType(); 1268cf44847bSKun Wu Type dType = 1269cf44847bSKun Wu llvm::cast<MemRefType>(op.getValues().getType()).getElementType(); 1270cc402de0SKun Wu auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType)); 1271cc402de0SKun Wu auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); 1272b700a90cSAart Bik auto handle = 1273b700a90cSAart Bik createCooCallBuilder 1274b700a90cSAart Bik .create(loc, rewriter, 1275b700a90cSAart Bik {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), 1276cc402de0SKun Wu pRowIdxs, pColIdxs, pValues, itp, dtp, stream}) 1277b700a90cSAart Bik .getResult(); 1278b700a90cSAart Bik rewriter.replaceOp(op, {handle, stream}); 1279b700a90cSAart Bik return success(); 1280b700a90cSAart Bik } 1281b700a90cSAart Bik 12829fc02a7aSAart Bik LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite( 12839fc02a7aSAart Bik gpu::CreateCooAoSOp op, OpAdaptor adaptor, 12849fc02a7aSAart Bik ConversionPatternRewriter &rewriter) const { 12859fc02a7aSAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 12869fc02a7aSAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 12879fc02a7aSAart Bik return failure(); 12889fc02a7aSAart Bik Location loc = op.getLoc(); 12899fc02a7aSAart Bik auto stream = adaptor.getAsyncDependencies().front(); 12909fc02a7aSAart Bik Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc); 12919fc02a7aSAart Bik Value pValues = 12929fc02a7aSAart Bik MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc); 12938ed59c53SKun Wu Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType(); 
12949fc02a7aSAart Bik Type dType = 12959fc02a7aSAart Bik llvm::cast<MemRefType>(op.getValues().getType()).getElementType(); 12969fc02a7aSAart Bik auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType)); 12979fc02a7aSAart Bik auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); 12989fc02a7aSAart Bik auto handle = 12999fc02a7aSAart Bik createCooAoSCallBuilder 13009fc02a7aSAart Bik .create(loc, rewriter, 13019fc02a7aSAart Bik {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), 13029fc02a7aSAart Bik pIdxs, pValues, itp, dtp, stream}) 13039fc02a7aSAart Bik .getResult(); 13049fc02a7aSAart Bik rewriter.replaceOp(op, {handle, stream}); 13059fc02a7aSAart Bik return success(); 13069fc02a7aSAart Bik } 13079fc02a7aSAart Bik 1308b700a90cSAart Bik LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite( 1309b700a90cSAart Bik gpu::CreateCsrOp op, OpAdaptor adaptor, 1310b700a90cSAart Bik ConversionPatternRewriter &rewriter) const { 1311b700a90cSAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1312b700a90cSAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 1313b700a90cSAart Bik return failure(); 1314b700a90cSAart Bik Location loc = op.getLoc(); 1315b700a90cSAart Bik auto stream = adaptor.getAsyncDependencies().front(); 1316b700a90cSAart Bik Value pRowPos = 1317b700a90cSAart Bik MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc); 1318b700a90cSAart Bik Value pColIdxs = 1319b700a90cSAart Bik MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc); 1320b700a90cSAart Bik Value pValues = 1321b700a90cSAart Bik MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc); 1322cf44847bSKun Wu Type pType = 1323cf44847bSKun Wu llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType(); 1324cf44847bSKun Wu Type iType = 1325cf44847bSKun Wu llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType(); 1326cf44847bSKun Wu Type dType = 1327cf44847bSKun Wu 
llvm::cast<MemRefType>(op.getValues().getType()).getElementType(); 1328cc402de0SKun Wu auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType)); 1329cc402de0SKun Wu auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType)); 1330cc402de0SKun Wu auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); 1331b700a90cSAart Bik auto handle = 1332b700a90cSAart Bik createCsrCallBuilder 1333b700a90cSAart Bik .create(loc, rewriter, 1334b700a90cSAart Bik {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), 1335cc402de0SKun Wu pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream}) 1336b700a90cSAart Bik .getResult(); 1337b700a90cSAart Bik rewriter.replaceOp(op, {handle, stream}); 1338b700a90cSAart Bik return success(); 1339b700a90cSAart Bik } 1340b700a90cSAart Bik 13418ed59c53SKun Wu LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite( 13428ed59c53SKun Wu gpu::Create2To4SpMatOp op, OpAdaptor adaptor, 13438ed59c53SKun Wu ConversionPatternRewriter &rewriter) const { 13448ed59c53SKun Wu if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 13458ed59c53SKun Wu failed(isAsyncWithOneDependency(rewriter, op))) 13468ed59c53SKun Wu return failure(); 13478ed59c53SKun Wu Location loc = op.getLoc(); 13488ed59c53SKun Wu auto stream = adaptor.getAsyncDependencies().front(); 13498ed59c53SKun Wu Value pMat = 13508ed59c53SKun Wu MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); 13518ed59c53SKun Wu Type dType = 13528ed59c53SKun Wu llvm::cast<MemRefType>(op.getMemref().getType()).getElementType(); 1353ac30f48eSKun Wu auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); 13548ed59c53SKun Wu 1355ac30f48eSKun Wu // CUDA runner asserts the size is 44104 bytes. 
13568ed59c53SKun Wu auto handleSz = rewriter.create<LLVM::ConstantOp>( 13578ed59c53SKun Wu loc, getIndexType(), rewriter.getIndexAttr(44104)); 135886eff489SAart Bik Value handle = rewriter.create<LLVM::AllocaOp>( 1359dbd4a0ddSChristian Ulmann loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); 13608ed59c53SKun Wu handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle); 13618ed59c53SKun Wu 13628ed59c53SKun Wu create2To4SpMatCallBuilder 13638ed59c53SKun Wu .create(loc, rewriter, 1364be2dd22bSKun Wu {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream}) 13658ed59c53SKun Wu .getResult(); 13668ed59c53SKun Wu rewriter.replaceOp(op, {handle, stream}); 13678ed59c53SKun Wu return success(); 13688ed59c53SKun Wu } 13698ed59c53SKun Wu 1370b700a90cSAart Bik LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite( 1371b700a90cSAart Bik gpu::DestroySpMatOp op, OpAdaptor adaptor, 1372b700a90cSAart Bik ConversionPatternRewriter &rewriter) const { 1373b700a90cSAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1374b700a90cSAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 1375b700a90cSAart Bik return failure(); 1376b700a90cSAart Bik Location loc = op.getLoc(); 1377b700a90cSAart Bik auto stream = adaptor.getAsyncDependencies().front(); 13788ed59c53SKun Wu // Use the cusparseLt destroy call if the spmat is 2:4 sparsity 13798ed59c53SKun Wu if (is2To4Sparsity(op.getSpmat())) { 13808ed59c53SKun Wu destroyCuSparseLtSpMatBuilder.create(loc, rewriter, 13818ed59c53SKun Wu {adaptor.getSpmat(), stream}); 13828ed59c53SKun Wu 13838ed59c53SKun Wu } else { 1384b700a90cSAart Bik destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream}); 13858ed59c53SKun Wu } 1386b700a90cSAart Bik rewriter.replaceOp(op, {stream}); 1387b700a90cSAart Bik return success(); 1388b700a90cSAart Bik } 1389b700a90cSAart Bik 1390b700a90cSAart Bik LogicalResult 
ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( 1391b700a90cSAart Bik gpu::SpMVBufferSizeOp op, OpAdaptor adaptor, 1392b700a90cSAart Bik ConversionPatternRewriter &rewriter) const { 1393b700a90cSAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1394b700a90cSAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 1395b700a90cSAart Bik return failure(); 1396b700a90cSAart Bik Location loc = op.getLoc(); 1397cc402de0SKun Wu auto modeA = genConstInt32From(rewriter, loc, op.getModeA()); 1398ac30f48eSKun Wu auto computeType = genConstInt32From( 1399ac30f48eSKun Wu rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); 1400b700a90cSAart Bik auto stream = adaptor.getAsyncDependencies().front(); 1401be2dd22bSKun Wu auto bufferSize = spMVBufferSizeCallBuilder 1402b700a90cSAart Bik .create(loc, rewriter, 1403be2dd22bSKun Wu {modeA, adaptor.getSpmatA(), adaptor.getDnX(), 1404be2dd22bSKun Wu adaptor.getDnY(), computeType, stream}) 1405b700a90cSAart Bik .getResult(); 1406b700a90cSAart Bik rewriter.replaceOp(op, {bufferSize, stream}); 1407b700a90cSAart Bik return success(); 1408b700a90cSAart Bik } 1409b700a90cSAart Bik 1410b700a90cSAart Bik LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite( 1411b700a90cSAart Bik gpu::SpMVOp op, OpAdaptor adaptor, 1412b700a90cSAart Bik ConversionPatternRewriter &rewriter) const { 1413b700a90cSAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1414b700a90cSAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 1415b700a90cSAart Bik return failure(); 1416b700a90cSAart Bik Location loc = op.getLoc(); 1417cc402de0SKun Wu auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); 1418ac30f48eSKun Wu auto computeType = genConstInt32From( 1419ac30f48eSKun Wu rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); 1420b700a90cSAart Bik auto stream = adaptor.getAsyncDependencies().front(); 1421b700a90cSAart Bik Value pBuf = 
1422b700a90cSAart Bik MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); 1423b700a90cSAart Bik spMVCallBuilder.create(loc, rewriter, 1424be2dd22bSKun Wu {modeA, adaptor.getSpmatA(), adaptor.getDnX(), 1425be2dd22bSKun Wu adaptor.getDnY(), computeType, pBuf, stream}); 1426b700a90cSAart Bik rewriter.replaceOp(op, {stream}); 1427b700a90cSAart Bik return success(); 1428b700a90cSAart Bik } 1429b700a90cSAart Bik 1430981cf167SAart Bik LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( 1431981cf167SAart Bik gpu::SpMMBufferSizeOp op, OpAdaptor adaptor, 1432981cf167SAart Bik ConversionPatternRewriter &rewriter) const { 1433981cf167SAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1434981cf167SAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 1435981cf167SAart Bik return failure(); 1436981cf167SAart Bik Location loc = op.getLoc(); 1437cc402de0SKun Wu auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); 1438cc402de0SKun Wu auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); 1439981cf167SAart Bik auto stream = adaptor.getAsyncDependencies().front(); 14408ed59c53SKun Wu Value bufferSize; 14418ed59c53SKun Wu if (is2To4Sparsity(op.getSpmatA())) { 14426ac80a76SMehdi Amini auto pruneFlag = 14431e491c42SKun Wu genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA())); 1444ac30f48eSKun Wu auto computeType = genConstInt32From( 1445ac30f48eSKun Wu rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType())); 14468ed59c53SKun Wu auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), 14478ed59c53SKun Wu rewriter.getIndexAttr(3)); 144886eff489SAart Bik auto bufferSize = rewriter.create<LLVM::AllocaOp>( 1449dbd4a0ddSChristian Ulmann loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16); 1450ac30f48eSKun Wu createCuSparseLtSpMMBufferSizeBuilder 14518ed59c53SKun Wu .create(loc, rewriter, 1452be2dd22bSKun Wu {bufferSize, modeA, modeB, 
adaptor.getSpmatA(), 14531e491c42SKun Wu adaptor.getDnmatB(), adaptor.getDnmatC(), computeType, 14546ac80a76SMehdi Amini pruneFlag, stream}) 14558ed59c53SKun Wu .getResult(); 1456632ccc53SKun Wu 1457632ccc53SKun Wu auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>( 1458dbd4a0ddSChristian Ulmann loc, llvmPointerType, llvmPointerType, bufferSize, 1459632ccc53SKun Wu ValueRange{rewriter.create<LLVM::ConstantOp>( 1460632ccc53SKun Wu loc, getIndexType(), rewriter.getIndexAttr(1))}); 1461632ccc53SKun Wu auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>( 1462dbd4a0ddSChristian Ulmann loc, llvmPointerType, llvmPointerType, bufferSize, 1463632ccc53SKun Wu ValueRange{rewriter.create<LLVM::ConstantOp>( 1464632ccc53SKun Wu loc, getIndexType(), rewriter.getIndexAttr(2))}); 1465632ccc53SKun Wu auto bufferSize0 = 1466632ccc53SKun Wu rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize); 1467632ccc53SKun Wu auto bufferSize1 = 1468632ccc53SKun Wu rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1); 1469632ccc53SKun Wu auto bufferSize2 = 1470632ccc53SKun Wu rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2); 1471632ccc53SKun Wu 1472632ccc53SKun Wu rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream}); 14738ed59c53SKun Wu } else { 1474ac30f48eSKun Wu auto computeType = genConstInt32From( 1475ac30f48eSKun Wu rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); 1476be2dd22bSKun Wu bufferSize = 1477be2dd22bSKun Wu createSpMMBufferSizeCallBuilder 1478981cf167SAart Bik .create(loc, rewriter, 1479be2dd22bSKun Wu {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(), 1480cc402de0SKun Wu adaptor.getDnmatC(), computeType, stream}) 1481981cf167SAart Bik .getResult(); 1482981cf167SAart Bik rewriter.replaceOp(op, {bufferSize, stream}); 14838ed59c53SKun Wu } 1484981cf167SAart Bik return success(); 1485981cf167SAart Bik } 1486981cf167SAart Bik 1487cf44847bSKun Wu LogicalResult 
ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( 1488cf44847bSKun Wu gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor, 1489cf44847bSKun Wu ConversionPatternRewriter &rewriter) const { 1490cf44847bSKun Wu if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1491cf44847bSKun Wu failed(isAsyncWithOneDependency(rewriter, op))) 1492cf44847bSKun Wu return failure(); 1493cf44847bSKun Wu Location loc = op.getLoc(); 1494cc402de0SKun Wu auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); 1495cc402de0SKun Wu auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); 1496ac30f48eSKun Wu auto computeType = genConstInt32From( 1497ac30f48eSKun Wu rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); 1498cf44847bSKun Wu auto stream = adaptor.getAsyncDependencies().front(); 1499be2dd22bSKun Wu auto bufferSize = 1500be2dd22bSKun Wu createSDDMMBufferSizeCallBuilder 1501cf44847bSKun Wu .create(loc, rewriter, 1502be2dd22bSKun Wu {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(), 1503cc402de0SKun Wu adaptor.getSpmatC(), computeType, stream}) 1504cf44847bSKun Wu .getResult(); 1505cf44847bSKun Wu rewriter.replaceOp(op, {bufferSize, stream}); 1506cf44847bSKun Wu return success(); 1507cf44847bSKun Wu } 1508cf44847bSKun Wu 1509981cf167SAart Bik LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite( 1510981cf167SAart Bik gpu::SpMMOp op, OpAdaptor adaptor, 1511981cf167SAart Bik ConversionPatternRewriter &rewriter) const { 1512981cf167SAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1513981cf167SAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 1514981cf167SAart Bik return failure(); 1515981cf167SAart Bik Location loc = op.getLoc(); 1516cc402de0SKun Wu auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); 1517cc402de0SKun Wu auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); 1518ac30f48eSKun Wu auto computeType = genConstInt32From( 1519ac30f48eSKun 
Wu rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); 1520cc402de0SKun Wu 1521981cf167SAart Bik auto stream = adaptor.getAsyncDependencies().front(); 15228ed59c53SKun Wu 15238ed59c53SKun Wu // Lower to cusparseLt if applicable 15248ed59c53SKun Wu if (is2To4Sparsity(op.getSpmatA())) { 15258ed59c53SKun Wu SmallVector<Value> pBufs; 15268ed59c53SKun Wu for (Value buffer : adaptor.getBuffers()) { 15278ed59c53SKun Wu Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc); 15288ed59c53SKun Wu pBufs.push_back(pBuf); 15298ed59c53SKun Wu } 1530ac30f48eSKun Wu createCuSparseLtSpMMBuilder.create( 1531ac30f48eSKun Wu loc, rewriter, 1532be2dd22bSKun Wu {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(), 1533be2dd22bSKun Wu pBufs[0], pBufs[1], pBufs[2], stream}); 15348ed59c53SKun Wu } else { 15358ed59c53SKun Wu Value pBuf = MemRefDescriptor(adaptor.getBuffers().front()) 15368ed59c53SKun Wu .allocatedPtr(rewriter, loc); 1537be2dd22bSKun Wu createSpMMCallBuilder.create(loc, rewriter, 1538be2dd22bSKun Wu {modeA, modeB, adaptor.getSpmatA(), 1539be2dd22bSKun Wu adaptor.getDnmatB(), adaptor.getDnmatC(), 1540be2dd22bSKun Wu computeType, pBuf, stream}); 15418ed59c53SKun Wu } 1542981cf167SAart Bik rewriter.replaceOp(op, {stream}); 1543981cf167SAart Bik return success(); 1544981cf167SAart Bik } 1545981cf167SAart Bik 154686bf710cSKun Wu template <typename T> 154786bf710cSKun Wu static void addOpaquePointerConversion(LLVMTypeConverter &converter) { 154886bf710cSKun Wu converter.addConversion([&converter](T) -> Type { 1549dbd4a0ddSChristian Ulmann return LLVM::LLVMPointerType::get(&converter.getContext()); 155086bf710cSKun Wu }); 155186bf710cSKun Wu } 155286bf710cSKun Wu 1553cf44847bSKun Wu LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite( 1554cf44847bSKun Wu gpu::SDDMMOp op, OpAdaptor adaptor, 1555cf44847bSKun Wu ConversionPatternRewriter &rewriter) const { 1556cf44847bSKun Wu if (failed(areAllLLVMTypes(op, adaptor.getOperands(), 
rewriter)) || 1557cf44847bSKun Wu failed(isAsyncWithOneDependency(rewriter, op))) 1558cf44847bSKun Wu return failure(); 1559cf44847bSKun Wu Location loc = op.getLoc(); 1560ac30f48eSKun Wu auto computeType = genConstInt32From( 1561ac30f48eSKun Wu rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); 1562cc402de0SKun Wu auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); 1563cc402de0SKun Wu auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); 1564cf44847bSKun Wu auto stream = adaptor.getAsyncDependencies().front(); 1565cf44847bSKun Wu Value pBuf = 1566cf44847bSKun Wu MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); 1567be2dd22bSKun Wu createSDDMMCallBuilder.create(loc, rewriter, 1568be2dd22bSKun Wu {modeA, modeB, adaptor.getDnmatA(), 1569be2dd22bSKun Wu adaptor.getDnmatB(), adaptor.getSpmatC(), 1570be2dd22bSKun Wu computeType, pBuf, stream}); 1571cf44847bSKun Wu rewriter.replaceOp(op, {stream}); 1572cf44847bSKun Wu return success(); 1573cf44847bSKun Wu } 1574cf44847bSKun Wu 1575dfe29429SKun Wu LogicalResult 1576dfe29429SKun Wu ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite( 1577dfe29429SKun Wu gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor, 1578dfe29429SKun Wu ConversionPatternRewriter &rewriter) const { 1579dfe29429SKun Wu if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1580dfe29429SKun Wu failed(isAsyncWithOneDependency(rewriter, op))) 1581dfe29429SKun Wu return failure(); 1582dfe29429SKun Wu Location loc = op.getLoc(); 1583dfe29429SKun Wu auto stream = adaptor.getAsyncDependencies().front(); 1584dfe29429SKun Wu Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream}) 1585dfe29429SKun Wu .getResult(); 1586dfe29429SKun Wu rewriter.replaceOp(op, {descr, stream}); 1587dfe29429SKun Wu return success(); 1588dfe29429SKun Wu } 1589dfe29429SKun Wu 1590dfe29429SKun Wu LogicalResult 1591dfe29429SKun Wu 
ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite( 1592dfe29429SKun Wu gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor, 1593dfe29429SKun Wu ConversionPatternRewriter &rewriter) const { 1594dfe29429SKun Wu if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1595dfe29429SKun Wu failed(isAsyncWithOneDependency(rewriter, op))) 1596dfe29429SKun Wu return failure(); 1597dfe29429SKun Wu Location loc = op.getLoc(); 1598dfe29429SKun Wu auto stream = adaptor.getAsyncDependencies().front(); 159995a6c509SAart Bik createSpGEMMDestroyDescrBuilder.create(loc, rewriter, 160095a6c509SAart Bik {adaptor.getDesc(), stream}); 1601dfe29429SKun Wu rewriter.replaceOp(op, {stream}); 1602dfe29429SKun Wu return success(); 1603dfe29429SKun Wu } 1604dfe29429SKun Wu 1605dfe29429SKun Wu LogicalResult 1606dfe29429SKun Wu ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite( 1607dfe29429SKun Wu gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor, 1608dfe29429SKun Wu ConversionPatternRewriter &rewriter) const { 1609dfe29429SKun Wu if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1610dfe29429SKun Wu failed(isAsyncWithOneDependency(rewriter, op))) 1611dfe29429SKun Wu return failure(); 1612dfe29429SKun Wu Location loc = op.getLoc(); 1613dfe29429SKun Wu auto computeType = genConstInt32From( 1614dfe29429SKun Wu rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); 1615dfe29429SKun Wu auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); 1616dfe29429SKun Wu auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); 1617dfe29429SKun Wu auto stream = adaptor.getAsyncDependencies().front(); 1618dfe29429SKun Wu 1619dfe29429SKun Wu Value pBuf = 1620dfe29429SKun Wu MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc); 1621dfe29429SKun Wu Value bufferSizeNew; 1622dfe29429SKun Wu 1623dfe29429SKun Wu if (adaptor.getKind() == 1624dfe29429SKun Wu 
gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) { 1625dfe29429SKun Wu bufferSizeNew = 1626dfe29429SKun Wu createSpGEMMWorkEstimationBuilder 1627dfe29429SKun Wu .create(loc, rewriter, 1628dfe29429SKun Wu {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(), 1629e7e4ed0dSAart Bik adaptor.getSpmatB(), adaptor.getSpmatC(), computeType, 1630dfe29429SKun Wu adaptor.getBufferSz(), pBuf, stream}) 1631dfe29429SKun Wu .getResult(); 1632dfe29429SKun Wu } else { 1633dfe29429SKun Wu bufferSizeNew = 1634dfe29429SKun Wu createSpGEMMComputeBuilder 1635dfe29429SKun Wu .create(loc, rewriter, 1636dfe29429SKun Wu {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(), 1637e7e4ed0dSAart Bik adaptor.getSpmatB(), adaptor.getSpmatC(), computeType, 1638dfe29429SKun Wu adaptor.getBufferSz(), pBuf, stream}) 1639dfe29429SKun Wu .getResult(); 1640dfe29429SKun Wu } 1641dfe29429SKun Wu rewriter.replaceOp(op, {bufferSizeNew, stream}); 1642dfe29429SKun Wu return success(); 1643dfe29429SKun Wu } 1644dfe29429SKun Wu 1645dfe29429SKun Wu LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite( 1646dfe29429SKun Wu gpu::SpGEMMCopyOp op, OpAdaptor adaptor, 1647dfe29429SKun Wu ConversionPatternRewriter &rewriter) const { 1648dfe29429SKun Wu if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1649dfe29429SKun Wu failed(isAsyncWithOneDependency(rewriter, op))) 1650dfe29429SKun Wu return failure(); 1651dfe29429SKun Wu Location loc = op.getLoc(); 1652dfe29429SKun Wu auto computeType = genConstInt32From( 1653dfe29429SKun Wu rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); 1654dfe29429SKun Wu auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); 1655dfe29429SKun Wu auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); 1656dfe29429SKun Wu auto stream = adaptor.getAsyncDependencies().front(); 1657e7e4ed0dSAart Bik createSpGEMMCopyBuilder.create(loc, rewriter, 1658e7e4ed0dSAart Bik {adaptor.getDesc(), modeA, modeB, 
1659e7e4ed0dSAart Bik adaptor.getSpmatA(), adaptor.getSpmatB(), 1660e7e4ed0dSAart Bik adaptor.getSpmatC(), computeType, stream}); 1661dfe29429SKun Wu rewriter.replaceOp(op, {stream}); 1662dfe29429SKun Wu return success(); 1663dfe29429SKun Wu } 1664dfe29429SKun Wu 1665289f7231SAart Bik LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite( 1666289f7231SAart Bik gpu::SpMatGetSizeOp op, OpAdaptor adaptor, 1667dfe29429SKun Wu ConversionPatternRewriter &rewriter) const { 1668dfe29429SKun Wu if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 1669dfe29429SKun Wu failed(isAsyncWithOneDependency(rewriter, op))) 1670dfe29429SKun Wu return failure(); 1671dfe29429SKun Wu Location loc = op.getLoc(); 1672dfe29429SKun Wu auto stream = adaptor.getAsyncDependencies().front(); 1673dfe29429SKun Wu 1674dfe29429SKun Wu auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), 1675dfe29429SKun Wu rewriter.getIndexAttr(3)); 1676dfe29429SKun Wu auto buffer = rewriter.create<LLVM::AllocaOp>( 1677dbd4a0ddSChristian Ulmann loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16); 1678dfe29429SKun Wu 1679dfe29429SKun Wu auto rowsPtr = rewriter.create<LLVM::GEPOp>( 1680dbd4a0ddSChristian Ulmann loc, llvmPointerType, llvmPointerType, buffer, 1681dfe29429SKun Wu ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), 1682dfe29429SKun Wu rewriter.getIndexAttr(0))}); 1683dfe29429SKun Wu auto colsPtr = rewriter.create<LLVM::GEPOp>( 1684dbd4a0ddSChristian Ulmann loc, llvmPointerType, llvmPointerType, buffer, 1685dfe29429SKun Wu ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), 1686dfe29429SKun Wu rewriter.getIndexAttr(1))}); 1687dfe29429SKun Wu auto nnzsPtr = rewriter.create<LLVM::GEPOp>( 1688dbd4a0ddSChristian Ulmann loc, llvmPointerType, llvmPointerType, buffer, 1689dfe29429SKun Wu ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), 1690dfe29429SKun Wu rewriter.getIndexAttr(2))}); 1691289f7231SAart Bik 
createSpMatGetSizeBuilder.create( 1692dfe29429SKun Wu loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream}); 1693dfe29429SKun Wu auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr); 1694dfe29429SKun Wu auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr); 1695dfe29429SKun Wu auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr); 1696dfe29429SKun Wu 1697dfe29429SKun Wu rewriter.replaceOp(op, {rows, cols, nnzs, stream}); 1698dfe29429SKun Wu return success(); 1699dfe29429SKun Wu } 1700dfe29429SKun Wu 170195a6c509SAart Bik LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite( 170295a6c509SAart Bik gpu::SetCsrPointersOp op, OpAdaptor adaptor, 170395a6c509SAart Bik ConversionPatternRewriter &rewriter) const { 170495a6c509SAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 170595a6c509SAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 170695a6c509SAart Bik return failure(); 170795a6c509SAart Bik Location loc = op.getLoc(); 170895a6c509SAart Bik auto stream = adaptor.getAsyncDependencies().front(); 170995a6c509SAart Bik Value pPos = 171095a6c509SAart Bik MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc); 171195a6c509SAart Bik Value pCrd = 171295a6c509SAart Bik MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc); 171395a6c509SAart Bik Value pVal = 171495a6c509SAart Bik MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc); 171595a6c509SAart Bik createSetCsrPointersBuilder.create( 171695a6c509SAart Bik loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream}); 171795a6c509SAart Bik rewriter.replaceOp(op, {stream}); 171895a6c509SAart Bik return success(); 171995a6c509SAart Bik } 172095a6c509SAart Bik 172139038177SAart Bik LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite( 172239038177SAart Bik gpu::CreateCscOp op, OpAdaptor adaptor, 172339038177SAart Bik 
ConversionPatternRewriter &rewriter) const { 172439038177SAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 172539038177SAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 172639038177SAart Bik return failure(); 172739038177SAart Bik Location loc = op.getLoc(); 172839038177SAart Bik auto stream = adaptor.getAsyncDependencies().front(); 172939038177SAart Bik Value pColPos = 173039038177SAart Bik MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc); 173139038177SAart Bik Value pRowIdxs = 173239038177SAart Bik MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc); 173339038177SAart Bik Value pValues = 173439038177SAart Bik MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc); 173539038177SAart Bik Type pType = 173639038177SAart Bik llvm::cast<MemRefType>(op.getColPos().getType()).getElementType(); 173739038177SAart Bik Type iType = 173839038177SAart Bik llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType(); 173939038177SAart Bik Type dType = 174039038177SAart Bik llvm::cast<MemRefType>(op.getValues().getType()).getElementType(); 174139038177SAart Bik auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType)); 174239038177SAart Bik auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType)); 174339038177SAart Bik auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); 174439038177SAart Bik auto handle = 174539038177SAart Bik createCscCallBuilder 174639038177SAart Bik .create(loc, rewriter, 174739038177SAart Bik {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), 174839038177SAart Bik pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream}) 174939038177SAart Bik .getResult(); 175039038177SAart Bik rewriter.replaceOp(op, {handle, stream}); 175139038177SAart Bik return success(); 175239038177SAart Bik } 175339038177SAart Bik 175439038177SAart Bik LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite( 
175539038177SAart Bik gpu::CreateBsrOp op, OpAdaptor adaptor, 175639038177SAart Bik ConversionPatternRewriter &rewriter) const { 175739038177SAart Bik if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || 175839038177SAart Bik failed(isAsyncWithOneDependency(rewriter, op))) 175939038177SAart Bik return failure(); 176039038177SAart Bik Location loc = op.getLoc(); 176139038177SAart Bik auto stream = adaptor.getAsyncDependencies().front(); 176239038177SAart Bik Value pRowPos = 176339038177SAart Bik MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc); 176439038177SAart Bik Value pColIdxs = 176539038177SAart Bik MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc); 176639038177SAart Bik Value pValues = 176739038177SAart Bik MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc); 176839038177SAart Bik Type pType = 176939038177SAart Bik llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType(); 177039038177SAart Bik Type iType = 177139038177SAart Bik llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType(); 177239038177SAart Bik Type dType = 177339038177SAart Bik llvm::cast<MemRefType>(op.getValues().getType()).getElementType(); 177439038177SAart Bik auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType)); 177539038177SAart Bik auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType)); 177639038177SAart Bik auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); 177739038177SAart Bik auto handle = 177839038177SAart Bik createBsrCallBuilder 177939038177SAart Bik .create(loc, rewriter, 178039038177SAart Bik {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(), 178139038177SAart Bik adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos, 178239038177SAart Bik pColIdxs, pValues, ptp, itp, dtp, stream}) 178339038177SAart Bik .getResult(); 178439038177SAart Bik rewriter.replaceOp(op, {handle, stream}); 178539038177SAart Bik return 
success(); 178639038177SAart Bik } 178739038177SAart Bik 1788*733be4edSAndrea Faulds void mlir::populateGpuToLLVMConversionPatterns( 1789*733be4edSAndrea Faulds LLVMTypeConverter &converter, RewritePatternSet &patterns, 1790*733be4edSAndrea Faulds bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) { 179186bf710cSKun Wu addOpaquePointerConversion<gpu::AsyncTokenType>(converter); 179297f4c22bSKun Wu addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter); 179386bf710cSKun Wu addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter); 1794dfe29429SKun Wu addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter); 1795b700a90cSAart Bik 17967d855605SButygin patterns.add<ConvertAllocOpToGpuRuntimeCallPattern, 17977d855605SButygin ConvertDeallocOpToGpuRuntimeCallPattern, 17987d855605SButygin ConvertHostRegisterOpToGpuRuntimeCallPattern, 17998f7c8a6eSmax ConvertHostUnregisterOpToGpuRuntimeCallPattern, 18007d855605SButygin ConvertMemcpyOpToGpuRuntimeCallPattern, 1801361458b1SLoren Maggiore ConvertMemsetOpToGpuRuntimeCallPattern, 180284718d37SKrzysztof Drewniak ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern, 18037d855605SButygin ConvertWaitAsyncOpToGpuRuntimeCallPattern, 18047d855605SButygin ConvertWaitOpToGpuRuntimeCallPattern, 1805b700a90cSAart Bik ConvertAsyncYieldToGpuRuntimeCallPattern, 180697f4c22bSKun Wu ConvertCreateDnTensorOpToGpuRuntimeCallPattern, 180797f4c22bSKun Wu ConvertDestroyDnTensorOpToGpuRuntimeCallPattern, 1808b700a90cSAart Bik ConvertCreateCooOpToGpuRuntimeCallPattern, 18099fc02a7aSAart Bik ConvertCreateCooAoSOpToGpuRuntimeCallPattern, 1810b700a90cSAart Bik ConvertCreateCsrOpToGpuRuntimeCallPattern, 181139038177SAart Bik ConvertCreateCscOpToGpuRuntimeCallPattern, 181239038177SAart Bik ConvertCreateBsrOpToGpuRuntimeCallPattern, 18138ed59c53SKun Wu ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern, 1814b700a90cSAart Bik ConvertDestroySpMatOpToGpuRuntimeCallPattern, 18159dfd3c32SAart Bik 
ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern, 18169dfd3c32SAart Bik ConvertSpMVOpToGpuRuntimeCallPattern, 18179dfd3c32SAart Bik ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern, 18189dfd3c32SAart Bik ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern, 18199dfd3c32SAart Bik ConvertSpMMOpToGpuRuntimeCallPattern, 18209dfd3c32SAart Bik ConvertSDDMMOpToGpuRuntimeCallPattern, 1821dfe29429SKun Wu ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern, 1822dfe29429SKun Wu ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern, 1823dfe29429SKun Wu ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern, 1824dfe29429SKun Wu ConvertSpGEMMCopyOpToGpuRuntimeCallPattern, 1825289f7231SAart Bik ConvertSpMatGetSizeOpToGpuRuntimeCallPattern, 182695a6c509SAart Bik ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter); 182799a562b3SAndrea Faulds patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv, 1828*733be4edSAndrea Faulds kernelIntersperseSizeCallConv); 18297d855605SButygin } 18307498eaa9SFabian Mora 18317498eaa9SFabian Mora //===----------------------------------------------------------------------===// 18327498eaa9SFabian Mora // GPUModuleOp convert to LLVM op interface 18337498eaa9SFabian Mora //===----------------------------------------------------------------------===// 18347498eaa9SFabian Mora 18357498eaa9SFabian Mora namespace { 18367498eaa9SFabian Mora struct GPUModuleOpConvertToLLVMInterface 18377498eaa9SFabian Mora : public ConvertToLLVMOpInterface::ExternalModel< 18387498eaa9SFabian Mora GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> { 18397498eaa9SFabian Mora /// Get the conversion patterns from the target attribute. 
18407498eaa9SFabian Mora void getConvertToLLVMConversionAttrs( 18417498eaa9SFabian Mora Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const; 18427498eaa9SFabian Mora }; 18437498eaa9SFabian Mora } // namespace 18447498eaa9SFabian Mora 18457498eaa9SFabian Mora void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs( 18467498eaa9SFabian Mora Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const { 18477498eaa9SFabian Mora auto module = cast<gpu::GPUModuleOp>(op); 18487498eaa9SFabian Mora ArrayAttr targetsAttr = module.getTargetsAttr(); 18497498eaa9SFabian Mora // Fail if there are no target attributes or there is more than one target. 18507498eaa9SFabian Mora if (!targetsAttr || targetsAttr.size() != 1) 18517498eaa9SFabian Mora return; 18527498eaa9SFabian Mora if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0])) 18537498eaa9SFabian Mora attrs.push_back(patternAttr); 18547498eaa9SFabian Mora } 18557498eaa9SFabian Mora 18567498eaa9SFabian Mora void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry ®istry) { 18577498eaa9SFabian Mora registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) { 18587498eaa9SFabian Mora gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx); 18597498eaa9SFabian Mora }); 18607498eaa9SFabian Mora } 1861