xref: /llvm-project/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (revision 733be4ed7dcf976719f424c0cb81b77a14f91f5a)
1a5f9cda1SChristian Sigg //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2a5f9cda1SChristian Sigg //
3a5f9cda1SChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a5f9cda1SChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
5a5f9cda1SChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a5f9cda1SChristian Sigg //
7a5f9cda1SChristian Sigg //===----------------------------------------------------------------------===//
8a5f9cda1SChristian Sigg //
9a5f9cda1SChristian Sigg // This file implements a pass to convert gpu.launch_func op into a sequence of
10a5f9cda1SChristian Sigg // GPU runtime calls. As most GPU runtimes do not have a stable published
11a5f9cda1SChristian Sigg // ABI, this pass uses a slim runtime layer that builds on top of the public
12a5f9cda1SChristian Sigg // API from GPU runtime headers.
13a5f9cda1SChristian Sigg //
14a5f9cda1SChristian Sigg //===----------------------------------------------------------------------===//
15a5f9cda1SChristian Sigg 
16a5f9cda1SChristian Sigg #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
17a5f9cda1SChristian Sigg 
18abc362a1SJakub Kuderski #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
19a5f9cda1SChristian Sigg #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
20330838ebSRiver Riddle #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
219e7b6f46SMehdi Amini #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
229e7b6f46SMehdi Amini #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
235a7b9194SRiver Riddle #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
245a7b9194SRiver Riddle #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
257498eaa9SFabian Mora #include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
2675e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
27684dfe8aSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
2875e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
29a5f9cda1SChristian Sigg #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
30a5f9cda1SChristian Sigg #include "mlir/Dialect/Async/IR/Async.h"
31d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
32d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/Passes.h"
33a5f9cda1SChristian Sigg #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
3416f8d17fSXiang Li #include "mlir/Dialect/MemRef/IR/MemRef.h"
350693b9e9SMatthias Springer #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
36a5f9cda1SChristian Sigg #include "mlir/IR/Attributes.h"
37a5f9cda1SChristian Sigg #include "mlir/IR/Builders.h"
38a5f9cda1SChristian Sigg #include "mlir/IR/BuiltinOps.h"
39a5f9cda1SChristian Sigg #include "mlir/IR/BuiltinTypes.h"
400693b9e9SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
41a5f9cda1SChristian Sigg 
42a5f9cda1SChristian Sigg #include "llvm/ADT/STLExtras.h"
43a5f9cda1SChristian Sigg #include "llvm/Support/Error.h"
44a5f9cda1SChristian Sigg #include "llvm/Support/FormatVariadic.h"
45a5f9cda1SChristian Sigg 
469e7b6f46SMehdi Amini #define DEBUG_TYPE "gpu-to-llvm"
479e7b6f46SMehdi Amini 
4867d0d7acSMichele Scuttari namespace mlir {
4967d0d7acSMichele Scuttari #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
5067d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
5167d0d7acSMichele Scuttari } // namespace mlir
5267d0d7acSMichele Scuttari 
53a5f9cda1SChristian Sigg using namespace mlir;
54a5f9cda1SChristian Sigg 
55a5f9cda1SChristian Sigg namespace {
56039b969bSMichele Scuttari class GpuToLLVMConversionPass
5767d0d7acSMichele Scuttari     : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
58a5f9cda1SChristian Sigg public:
59cd4ca2d7SMarkus Böck   using Base::Base;
609e7b6f46SMehdi Amini   void getDependentDialects(DialectRegistry &registry) const final {
619e7b6f46SMehdi Amini     Base::getDependentDialects(registry);
629e7b6f46SMehdi Amini     registerConvertToLLVMDependentDialectLoading(registry);
639e7b6f46SMehdi Amini   }
64a5f9cda1SChristian Sigg   // Run the dialect converter on the module.
65a5f9cda1SChristian Sigg   void runOnOperation() override;
66a5f9cda1SChristian Sigg };
67a5f9cda1SChristian Sigg 
68a5f9cda1SChristian Sigg template <typename OpTy>
69a5f9cda1SChristian Sigg class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
70a5f9cda1SChristian Sigg public:
71ce254598SMatthias Springer   explicit ConvertOpToGpuRuntimeCallPattern(
72ce254598SMatthias Springer       const LLVMTypeConverter &typeConverter)
73a5f9cda1SChristian Sigg       : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
74a5f9cda1SChristian Sigg 
75a5f9cda1SChristian Sigg protected:
76361458b1SLoren Maggiore   Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
77361458b1SLoren Maggiore                        MemRefType type, MemRefDescriptor desc) const {
78e98e5995SAlex Zinenko     Type indexType = ConvertToLLVMPattern::getIndexType();
79361458b1SLoren Maggiore     return type.hasStaticShape()
80620e2bb2SNicolas Vasilache                ? ConvertToLLVMPattern::createIndexAttrConstant(
81620e2bb2SNicolas Vasilache                      rewriter, loc, indexType, type.getNumElements())
82361458b1SLoren Maggiore                // For identity maps (verified by caller), the number of
83361458b1SLoren Maggiore                // elements is stride[0] * size[0].
84361458b1SLoren Maggiore                : rewriter.create<LLVM::MulOp>(loc,
85361458b1SLoren Maggiore                                               desc.stride(rewriter, loc, 0),
86361458b1SLoren Maggiore                                               desc.size(rewriter, loc, 0));
87361458b1SLoren Maggiore   }
88361458b1SLoren Maggiore 
89a5f9cda1SChristian Sigg   MLIRContext *context = &this->getTypeConverter()->getContext();
90a5f9cda1SChristian Sigg 
91a5f9cda1SChristian Sigg   Type llvmVoidType = LLVM::LLVMVoidType::get(context);
92dbd4a0ddSChristian Ulmann   LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
93a5f9cda1SChristian Sigg   Type llvmInt8Type = IntegerType::get(context, 8);
9418cc07aaSNavdeep Katel   Type llvmInt16Type = IntegerType::get(context, 16);
95a5f9cda1SChristian Sigg   Type llvmInt32Type = IntegerType::get(context, 32);
96a5f9cda1SChristian Sigg   Type llvmInt64Type = IntegerType::get(context, 64);
97dfe29429SKun Wu   Type llvmFloat32Type = Float32Type::get(context);
98a5f9cda1SChristian Sigg   Type llvmIntPtrType = IntegerType::get(
99a5f9cda1SChristian Sigg       context, this->getTypeConverter()->getPointerBitwidth(0));
100a5f9cda1SChristian Sigg 
101a5f9cda1SChristian Sigg   FunctionCallBuilder streamCreateCallBuilder = {
102a5f9cda1SChristian Sigg       "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
103a5f9cda1SChristian Sigg   FunctionCallBuilder streamDestroyCallBuilder = {
104a5f9cda1SChristian Sigg       "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
105a5f9cda1SChristian Sigg   FunctionCallBuilder streamSynchronizeCallBuilder = {
106a5f9cda1SChristian Sigg       "mgpuStreamSynchronize",
107a5f9cda1SChristian Sigg       llvmVoidType,
108a5f9cda1SChristian Sigg       {llvmPointerType /* void *stream */}};
109a5f9cda1SChristian Sigg   FunctionCallBuilder streamWaitEventCallBuilder = {
110a5f9cda1SChristian Sigg       "mgpuStreamWaitEvent",
111a5f9cda1SChristian Sigg       llvmVoidType,
112a5f9cda1SChristian Sigg       {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
113a5f9cda1SChristian Sigg   FunctionCallBuilder eventCreateCallBuilder = {
114a5f9cda1SChristian Sigg       "mgpuEventCreate", llvmPointerType /* void *event */, {}};
115a5f9cda1SChristian Sigg   FunctionCallBuilder eventDestroyCallBuilder = {
116a5f9cda1SChristian Sigg       "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
117a5f9cda1SChristian Sigg   FunctionCallBuilder eventSynchronizeCallBuilder = {
118a5f9cda1SChristian Sigg       "mgpuEventSynchronize",
119a5f9cda1SChristian Sigg       llvmVoidType,
120a5f9cda1SChristian Sigg       {llvmPointerType /* void *event */}};
121a5f9cda1SChristian Sigg   FunctionCallBuilder eventRecordCallBuilder = {
122a5f9cda1SChristian Sigg       "mgpuEventRecord",
123a5f9cda1SChristian Sigg       llvmVoidType,
124a5f9cda1SChristian Sigg       {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
125a5f9cda1SChristian Sigg   FunctionCallBuilder hostRegisterCallBuilder = {
126a5f9cda1SChristian Sigg       "mgpuMemHostRegisterMemRef",
127a5f9cda1SChristian Sigg       llvmVoidType,
128a5f9cda1SChristian Sigg       {llvmIntPtrType /* intptr_t rank */,
129a5f9cda1SChristian Sigg        llvmPointerType /* void *memrefDesc */,
130a5f9cda1SChristian Sigg        llvmIntPtrType /* intptr_t elementSizeBytes */}};
1318f7c8a6eSmax   FunctionCallBuilder hostUnregisterCallBuilder = {
1328f7c8a6eSmax       "mgpuMemHostUnregisterMemRef",
1338f7c8a6eSmax       llvmVoidType,
1348f7c8a6eSmax       {llvmIntPtrType /* intptr_t rank */,
1358f7c8a6eSmax        llvmPointerType /* void *memrefDesc */,
1368f7c8a6eSmax        llvmIntPtrType /* intptr_t elementSizeBytes */}};
137a5f9cda1SChristian Sigg   FunctionCallBuilder allocCallBuilder = {
138a5f9cda1SChristian Sigg       "mgpuMemAlloc",
139a5f9cda1SChristian Sigg       llvmPointerType /* void * */,
140a5f9cda1SChristian Sigg       {llvmIntPtrType /* intptr_t sizeBytes */,
1411002a1d0SNishant Patel        llvmPointerType /* void *stream */,
1421002a1d0SNishant Patel        llvmInt8Type /* bool isHostShared */}};
143a5f9cda1SChristian Sigg   FunctionCallBuilder deallocCallBuilder = {
144a5f9cda1SChristian Sigg       "mgpuMemFree",
145a5f9cda1SChristian Sigg       llvmVoidType,
146a5f9cda1SChristian Sigg       {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
147a5f9cda1SChristian Sigg   FunctionCallBuilder memcpyCallBuilder = {
148a5f9cda1SChristian Sigg       "mgpuMemcpy",
149a5f9cda1SChristian Sigg       llvmVoidType,
150a5f9cda1SChristian Sigg       {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
151a5f9cda1SChristian Sigg        llvmIntPtrType /* intptr_t sizeBytes */,
152a5f9cda1SChristian Sigg        llvmPointerType /* void *stream */}};
15318cc07aaSNavdeep Katel   FunctionCallBuilder memset16CallBuilder = {
15418cc07aaSNavdeep Katel       "mgpuMemset16",
15518cc07aaSNavdeep Katel       llvmVoidType,
15618cc07aaSNavdeep Katel       {llvmPointerType /* void *dst */,
15718cc07aaSNavdeep Katel        llvmInt16Type /* unsigned short value */,
15818cc07aaSNavdeep Katel        llvmIntPtrType /* intptr_t sizeBytes */,
15918cc07aaSNavdeep Katel        llvmPointerType /* void *stream */}};
16018cc07aaSNavdeep Katel   FunctionCallBuilder memset32CallBuilder = {
161361458b1SLoren Maggiore       "mgpuMemset32",
162361458b1SLoren Maggiore       llvmVoidType,
163361458b1SLoren Maggiore       {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
164361458b1SLoren Maggiore        llvmIntPtrType /* intptr_t sizeBytes */,
165361458b1SLoren Maggiore        llvmPointerType /* void *stream */}};
16684718d37SKrzysztof Drewniak   FunctionCallBuilder setDefaultDeviceCallBuilder = {
16784718d37SKrzysztof Drewniak       "mgpuSetDefaultDevice",
16884718d37SKrzysztof Drewniak       llvmVoidType,
16984718d37SKrzysztof Drewniak       {llvmInt32Type /* uint32_t devIndex */}};
170b700a90cSAart Bik   FunctionCallBuilder createDnVecCallBuilder = {
171b700a90cSAart Bik       "mgpuCreateDnVec",
172b700a90cSAart Bik       llvmPointerType,
173b700a90cSAart Bik       {llvmIntPtrType, llvmPointerType, llvmInt32Type,
174b700a90cSAart Bik        llvmPointerType /* void *stream */}};
175b700a90cSAart Bik   FunctionCallBuilder destroyDnVecCallBuilder = {
176b700a90cSAart Bik       "mgpuDestroyDnVec",
177b700a90cSAart Bik       llvmVoidType,
178b700a90cSAart Bik       {llvmPointerType, llvmPointerType /* void *stream */}};
179981cf167SAart Bik   FunctionCallBuilder createDnMatCallBuilder = {
180981cf167SAart Bik       "mgpuCreateDnMat",
181981cf167SAart Bik       llvmPointerType,
182981cf167SAart Bik       {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
183981cf167SAart Bik        llvmPointerType /* void *stream */}};
184981cf167SAart Bik   FunctionCallBuilder destroyDnMatCallBuilder = {
185981cf167SAart Bik       "mgpuDestroyDnMat",
186981cf167SAart Bik       llvmVoidType,
187981cf167SAart Bik       {llvmPointerType, llvmPointerType /* void *stream */}};
188b700a90cSAart Bik   FunctionCallBuilder createCooCallBuilder = {
189b700a90cSAart Bik       "mgpuCreateCoo",
190b700a90cSAart Bik       llvmPointerType,
191b700a90cSAart Bik       {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
192b700a90cSAart Bik        llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
193b700a90cSAart Bik        llvmPointerType /* void *stream */}};
1949fc02a7aSAart Bik   FunctionCallBuilder createCooAoSCallBuilder = {
1959fc02a7aSAart Bik       "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
1969fc02a7aSAart Bik       llvmPointerType,
1979fc02a7aSAart Bik       {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
1989fc02a7aSAart Bik        llvmPointerType, llvmInt32Type, llvmInt32Type,
1999fc02a7aSAart Bik        llvmPointerType /* void *stream */}};
200b700a90cSAart Bik   FunctionCallBuilder createCsrCallBuilder = {
201b700a90cSAart Bik       "mgpuCreateCsr",
202b700a90cSAart Bik       llvmPointerType,
203b700a90cSAart Bik       {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
204b700a90cSAart Bik        llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
205b700a90cSAart Bik        llvmInt32Type, llvmPointerType /* void *stream */}};
20639038177SAart Bik   FunctionCallBuilder createCscCallBuilder = {
20739038177SAart Bik       "mgpuCreateCsc",
20839038177SAart Bik       llvmPointerType,
20939038177SAart Bik       {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
21039038177SAart Bik        llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
21139038177SAart Bik        llvmInt32Type, llvmPointerType /* void *stream */}};
21239038177SAart Bik   FunctionCallBuilder createBsrCallBuilder = {
21339038177SAart Bik       "mgpuCreateBsr",
21439038177SAart Bik       llvmPointerType,
21539038177SAart Bik       {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
21639038177SAart Bik        llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
21739038177SAart Bik        llvmInt32Type, llvmInt32Type, llvmInt32Type,
21839038177SAart Bik        llvmPointerType /* void *stream */}};
219b700a90cSAart Bik   FunctionCallBuilder destroySpMatCallBuilder = {
220b700a90cSAart Bik       "mgpuDestroySpMat",
221b700a90cSAart Bik       llvmVoidType,
222b700a90cSAart Bik       {llvmPointerType, llvmPointerType /* void *stream */}};
223b700a90cSAart Bik   FunctionCallBuilder spMVBufferSizeCallBuilder = {
224b700a90cSAart Bik       "mgpuSpMVBufferSize",
225b700a90cSAart Bik       llvmIntPtrType,
226be2dd22bSKun Wu       {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
227be2dd22bSKun Wu        llvmInt32Type, llvmPointerType /* void *stream */}};
228b700a90cSAart Bik   FunctionCallBuilder spMVCallBuilder = {
229b700a90cSAart Bik       "mgpuSpMV",
230b700a90cSAart Bik       llvmVoidType,
231be2dd22bSKun Wu       {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
232be2dd22bSKun Wu        llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
233ac30f48eSKun Wu   FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
234981cf167SAart Bik       "mgpuSpMMBufferSize",
235981cf167SAart Bik       llvmIntPtrType,
236be2dd22bSKun Wu       {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
237be2dd22bSKun Wu        llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
238ac30f48eSKun Wu   FunctionCallBuilder createSpMMCallBuilder = {
239981cf167SAart Bik       "mgpuSpMM",
240981cf167SAart Bik       llvmVoidType,
241be2dd22bSKun Wu       {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
242be2dd22bSKun Wu        llvmPointerType, llvmInt32Type, llvmPointerType,
243235fbe79SKun Wu        llvmPointerType /* void *stream */}};
244ac30f48eSKun Wu   FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
245cf44847bSKun Wu       "mgpuSDDMMBufferSize",
246cf44847bSKun Wu       llvmIntPtrType,
247be2dd22bSKun Wu       {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
248be2dd22bSKun Wu        llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
249ac30f48eSKun Wu   FunctionCallBuilder createSDDMMCallBuilder = {
250cf44847bSKun Wu       "mgpuSDDMM",
251cf44847bSKun Wu       llvmVoidType,
252be2dd22bSKun Wu       {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
253be2dd22bSKun Wu        llvmPointerType, llvmInt32Type, llvmPointerType,
254cf44847bSKun Wu        llvmPointerType /* void *stream */}};
2558ed59c53SKun Wu   FunctionCallBuilder createLtDnMatCallBuilder = {
2568ed59c53SKun Wu       "mgpuCreateCuSparseLtDnMat",
2578ed59c53SKun Wu       llvmVoidType,
258be2dd22bSKun Wu       {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
259be2dd22bSKun Wu        llvmInt32Type, llvmPointerType /* void *stream */}};
2608ed59c53SKun Wu   FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
2618ed59c53SKun Wu       "mgpuDestroyCuSparseLtSpMat",
2628ed59c53SKun Wu       llvmVoidType,
2638ed59c53SKun Wu       {llvmPointerType, llvmPointerType /* void *stream */}};
2648ed59c53SKun Wu   FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
2658ed59c53SKun Wu       "mgpuDestroyCuSparseLtDnMat",
2668ed59c53SKun Wu       llvmVoidType,
2678ed59c53SKun Wu       {llvmPointerType, llvmPointerType /* void *stream */}};
2688ed59c53SKun Wu   FunctionCallBuilder create2To4SpMatCallBuilder = {
2698ed59c53SKun Wu       "mgpuCusparseLtCreate2To4SpMat",
2708ed59c53SKun Wu       llvmVoidType,
271be2dd22bSKun Wu       {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
272be2dd22bSKun Wu        llvmInt32Type, llvmPointerType /* void *stream */}};
273ac30f48eSKun Wu   FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
2748ed59c53SKun Wu       "mgpuCuSparseLtSpMMBufferSize",
2758ed59c53SKun Wu       llvmVoidType,
276be2dd22bSKun Wu       {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
2771e491c42SKun Wu        llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
2788ed59c53SKun Wu        llvmPointerType /*void *stream*/}};
279ac30f48eSKun Wu   FunctionCallBuilder createCuSparseLtSpMMBuilder = {
2808ed59c53SKun Wu       "mgpuCuSparseLtSpMM",
2818ed59c53SKun Wu       llvmVoidType,
2828ed59c53SKun Wu       {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
283be2dd22bSKun Wu        llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
284289f7231SAart Bik   FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
285289f7231SAart Bik       "mgpuSpGEMMCreateDescr",
286289f7231SAart Bik       llvmPointerType,
287289f7231SAart Bik       {llvmPointerType /*void *stream*/}};
288289f7231SAart Bik   FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
289289f7231SAart Bik       "mgpuSpGEMMDestroyDescr",
290289f7231SAart Bik       llvmVoidType,
291289f7231SAart Bik       {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
292dfe29429SKun Wu   FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
293dfe29429SKun Wu       "mgpuSpGEMMWorkEstimation",
294dfe29429SKun Wu       llvmIntPtrType,
295dfe29429SKun Wu       {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
296dfe29429SKun Wu        llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
297e7e4ed0dSAart Bik        llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
298e7e4ed0dSAart Bik        llvmPointerType /*void *stream*/}};
299dfe29429SKun Wu   FunctionCallBuilder createSpGEMMComputeBuilder = {
300dfe29429SKun Wu       "mgpuSpGEMMCompute",
301dfe29429SKun Wu       llvmIntPtrType,
302dfe29429SKun Wu       {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
303dfe29429SKun Wu        llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
304e7e4ed0dSAart Bik        llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
305e7e4ed0dSAart Bik        llvmPointerType /*void *stream*/}};
306dfe29429SKun Wu   FunctionCallBuilder createSpGEMMCopyBuilder = {
307dfe29429SKun Wu       "mgpuSpGEMMCopy",
308dfe29429SKun Wu       llvmVoidType,
309dfe29429SKun Wu       {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
310dfe29429SKun Wu        llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
311e7e4ed0dSAart Bik        llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
312289f7231SAart Bik   FunctionCallBuilder createSpMatGetSizeBuilder = {
313289f7231SAart Bik       "mgpuSpMatGetSize",
314dfe29429SKun Wu       llvmVoidType,
315dfe29429SKun Wu       {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
316dfe29429SKun Wu        llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
31795a6c509SAart Bik   FunctionCallBuilder createSetCsrPointersBuilder = {
31895a6c509SAart Bik       "mgpuSetCsrPointers",
31995a6c509SAart Bik       llvmVoidType,
32095a6c509SAart Bik       {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
32195a6c509SAart Bik        llvmPointerType /*crd*/, llvmPointerType /*val*/,
32295a6c509SAart Bik        llvmPointerType /*void *stream*/}};
323a5f9cda1SChristian Sigg };
324a5f9cda1SChristian Sigg 
325a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
326a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP).
327a5f9cda1SChristian Sigg class ConvertHostRegisterOpToGpuRuntimeCallPattern
328a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
329a5f9cda1SChristian Sigg public:
330ce254598SMatthias Springer   ConvertHostRegisterOpToGpuRuntimeCallPattern(
331ce254598SMatthias Springer       const LLVMTypeConverter &typeConverter)
332a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
333a5f9cda1SChristian Sigg 
334a5f9cda1SChristian Sigg private:
335a5f9cda1SChristian Sigg   LogicalResult
336ef976337SRiver Riddle   matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
337a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
338a5f9cda1SChristian Sigg };
339a5f9cda1SChristian Sigg 
3408f7c8a6eSmax class ConvertHostUnregisterOpToGpuRuntimeCallPattern
3418f7c8a6eSmax     : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
3428f7c8a6eSmax public:
3438f7c8a6eSmax   ConvertHostUnregisterOpToGpuRuntimeCallPattern(
344ce254598SMatthias Springer       const LLVMTypeConverter &typeConverter)
3458f7c8a6eSmax       : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
3468f7c8a6eSmax   }
3478f7c8a6eSmax 
3488f7c8a6eSmax private:
3498f7c8a6eSmax   LogicalResult
3508f7c8a6eSmax   matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
3518f7c8a6eSmax                   ConversionPatternRewriter &rewriter) const override;
3528f7c8a6eSmax };
3538f7c8a6eSmax 
354a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
355a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP).
356a5f9cda1SChristian Sigg class ConvertAllocOpToGpuRuntimeCallPattern
357a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
358a5f9cda1SChristian Sigg public:
359ce254598SMatthias Springer   ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
360a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
361a5f9cda1SChristian Sigg 
362a5f9cda1SChristian Sigg private:
363a5f9cda1SChristian Sigg   LogicalResult
364ef976337SRiver Riddle   matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
365a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
366a5f9cda1SChristian Sigg };
367a5f9cda1SChristian Sigg 
368a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
369a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP).
370a5f9cda1SChristian Sigg class ConvertDeallocOpToGpuRuntimeCallPattern
371a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
372a5f9cda1SChristian Sigg public:
373ce254598SMatthias Springer   ConvertDeallocOpToGpuRuntimeCallPattern(
374ce254598SMatthias Springer       const LLVMTypeConverter &typeConverter)
375a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
376a5f9cda1SChristian Sigg 
377a5f9cda1SChristian Sigg private:
378a5f9cda1SChristian Sigg   LogicalResult
379ef976337SRiver Riddle   matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
380a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
381a5f9cda1SChristian Sigg };
382a5f9cda1SChristian Sigg 
383a5f9cda1SChristian Sigg class ConvertAsyncYieldToGpuRuntimeCallPattern
384a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
385a5f9cda1SChristian Sigg public:
386ce254598SMatthias Springer   ConvertAsyncYieldToGpuRuntimeCallPattern(
387ce254598SMatthias Springer       const LLVMTypeConverter &typeConverter)
388a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
389a5f9cda1SChristian Sigg 
390a5f9cda1SChristian Sigg private:
391a5f9cda1SChristian Sigg   LogicalResult
392ef976337SRiver Riddle   matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
393a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
394a5f9cda1SChristian Sigg };
395a5f9cda1SChristian Sigg 
396a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.wait operations into a GPU runtime
397a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP).
398a5f9cda1SChristian Sigg class ConvertWaitOpToGpuRuntimeCallPattern
399a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
400a5f9cda1SChristian Sigg public:
401ce254598SMatthias Springer   ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
402a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
403a5f9cda1SChristian Sigg 
404a5f9cda1SChristian Sigg private:
405a5f9cda1SChristian Sigg   LogicalResult
406ef976337SRiver Riddle   matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
407a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
408a5f9cda1SChristian Sigg };
409a5f9cda1SChristian Sigg 
410a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
411a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP).
412a5f9cda1SChristian Sigg class ConvertWaitAsyncOpToGpuRuntimeCallPattern
413a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
414a5f9cda1SChristian Sigg public:
415ce254598SMatthias Springer   ConvertWaitAsyncOpToGpuRuntimeCallPattern(
416ce254598SMatthias Springer       const LLVMTypeConverter &typeConverter)
417a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
418a5f9cda1SChristian Sigg 
419a5f9cda1SChristian Sigg private:
420a5f9cda1SChristian Sigg   LogicalResult
421ef976337SRiver Riddle   matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
422a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
423a5f9cda1SChristian Sigg };
424a5f9cda1SChristian Sigg 
4258e12f31bSFabian Mora /// A rewrite patter to legalize gpu.launch_func with LLVM types.
4268e12f31bSFabian Mora class LegalizeLaunchFuncOpPattern
427a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
428a5f9cda1SChristian Sigg public:
4298e12f31bSFabian Mora   LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
43099a562b3SAndrea Faulds                               bool kernelBarePtrCallConv,
431*733be4edSAndrea Faulds                               bool kernelIntersperseSizeCallConv)
432a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
43399a562b3SAndrea Faulds         kernelBarePtrCallConv(kernelBarePtrCallConv),
434*733be4edSAndrea Faulds         kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
435a5f9cda1SChristian Sigg 
436a5f9cda1SChristian Sigg private:
437a5f9cda1SChristian Sigg   LogicalResult
438ef976337SRiver Riddle   matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
439a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
440a5f9cda1SChristian Sigg 
441c2fc8d9bSKrzysztof Drewniak   bool kernelBarePtrCallConv;
442*733be4edSAndrea Faulds   bool kernelIntersperseSizeCallConv;
443a5f9cda1SChristian Sigg };
444a5f9cda1SChristian Sigg 
445a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
446a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP).
447a5f9cda1SChristian Sigg class ConvertMemcpyOpToGpuRuntimeCallPattern
448a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
449a5f9cda1SChristian Sigg public:
450ce254598SMatthias Springer   ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
451a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
452a5f9cda1SChristian Sigg 
453a5f9cda1SChristian Sigg private:
454a5f9cda1SChristian Sigg   LogicalResult
455ef976337SRiver Riddle   matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
456a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
457a5f9cda1SChristian Sigg };
458361458b1SLoren Maggiore 
459361458b1SLoren Maggiore /// A rewrite pattern to convert gpu.memset operations into a GPU runtime
460361458b1SLoren Maggiore /// call. Currently it supports CUDA and ROCm (HIP).
461361458b1SLoren Maggiore class ConvertMemsetOpToGpuRuntimeCallPattern
462361458b1SLoren Maggiore     : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
463361458b1SLoren Maggiore public:
464ce254598SMatthias Springer   ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
465361458b1SLoren Maggiore       : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
466361458b1SLoren Maggiore 
467361458b1SLoren Maggiore private:
468361458b1SLoren Maggiore   LogicalResult
469ef976337SRiver Riddle   matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
470361458b1SLoren Maggiore                   ConversionPatternRewriter &rewriter) const override;
471361458b1SLoren Maggiore };
47284718d37SKrzysztof Drewniak 
47384718d37SKrzysztof Drewniak /// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
47484718d37SKrzysztof Drewniak /// Currently supports CUDA and ROCm (HIP)
47584718d37SKrzysztof Drewniak class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
47684718d37SKrzysztof Drewniak     : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
47784718d37SKrzysztof Drewniak public:
47884718d37SKrzysztof Drewniak   ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
479ce254598SMatthias Springer       const LLVMTypeConverter &typeConverter)
48084718d37SKrzysztof Drewniak       : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
48184718d37SKrzysztof Drewniak             typeConverter) {}
48284718d37SKrzysztof Drewniak 
48384718d37SKrzysztof Drewniak   LogicalResult
48484718d37SKrzysztof Drewniak   matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
48584718d37SKrzysztof Drewniak                   ConversionPatternRewriter &rewriter) const override;
48684718d37SKrzysztof Drewniak };
487b700a90cSAart Bik 
/// Generic rewriting rule for operations on sparse matrices: expands to the
/// declaration of a `Convert<OpName>ToGpuRuntimeCallPattern` class deriving
/// from `ConvertOpToGpuRuntimeCallPattern<gpu::OpName>`. The expansion already
/// ends with `};`, so uses of the macro take no trailing semicolon.
/// Currently supports CUDA (by means of cuSparse and cuSparseLt).
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(OpName)                 \
  class Convert##OpName##ToGpuRuntimeCallPattern                               \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::OpName> {                 \
  public:                                                                      \
    Convert##OpName##ToGpuRuntimeCallPattern(                                  \
        const LLVMTypeConverter &converter)                                    \
        : ConvertOpToGpuRuntimeCallPattern<gpu::OpName>(converter) {}          \
                                                                               \
  private:                                                                     \
    LogicalResult                                                              \
    matchAndRewrite(gpu::OpName op, OpAdaptor adaptor,                         \
                    ConversionPatternRewriter &rewriter) const override;       \
  };
503dfe29429SKun Wu 
5049dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
5059dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
5069dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
5079dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
5089dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
50939038177SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
51039038177SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
5119dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
5129dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
5139dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
5149dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
5159dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
5169dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
5179dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
5189dfd3c32SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
519dfe29429SKun Wu DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
520dfe29429SKun Wu DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
521dfe29429SKun Wu DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
522dfe29429SKun Wu DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
523289f7231SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
52495a6c509SAart Bik DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
525dfe29429SKun Wu 
526a5f9cda1SChristian Sigg } // namespace
527a5f9cda1SChristian Sigg 
528039b969bSMichele Scuttari void GpuToLLVMConversionPass::runOnOperation() {
5299e7b6f46SMehdi Amini   MLIRContext *context = &getContext();
5300693b9e9SMatthias Springer 
5310693b9e9SMatthias Springer   // Perform progressive lowering of vector transfer operations.
5320693b9e9SMatthias Springer   {
5330693b9e9SMatthias Springer     RewritePatternSet patterns(&getContext());
5340693b9e9SMatthias Springer     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
5350693b9e9SMatthias Springer     vector::populateVectorTransferLoweringPatterns(patterns,
5360693b9e9SMatthias Springer                                                    /*maxTransferRank=*/1);
53799019060SKazu Hirata     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
5380693b9e9SMatthias Springer       return signalPassFailure();
5390693b9e9SMatthias Springer   }
5400693b9e9SMatthias Springer 
5419e7b6f46SMehdi Amini   LowerToLLVMOptions options(context);
5429e7b6f46SMehdi Amini   options.useBarePtrCallConv = hostBarePtrCallConv;
5439e7b6f46SMehdi Amini   RewritePatternSet patterns(context);
5449e7b6f46SMehdi Amini   ConversionTarget target(*context);
5459e7b6f46SMehdi Amini   target.addLegalDialect<LLVM::LLVMDialect>();
5469e7b6f46SMehdi Amini   LLVMTypeConverter converter(context, options);
5479e7b6f46SMehdi Amini 
5489e7b6f46SMehdi Amini   // Populate all patterns from all dialects that implement the
5499e7b6f46SMehdi Amini   // `ConvertToLLVMPatternInterface` interface.
5509e7b6f46SMehdi Amini   for (Dialect *dialect : context->getLoadedDialects()) {
5519e7b6f46SMehdi Amini     auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
5529e7b6f46SMehdi Amini     if (!iface)
5539e7b6f46SMehdi Amini       continue;
5549e7b6f46SMehdi Amini     iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
5559e7b6f46SMehdi Amini   }
5569e7b6f46SMehdi Amini 
5578e12f31bSFabian Mora   // Preserve GPU modules and binaries. Modules are preserved as they can be
5588e12f31bSFabian Mora   // converted later by `gpu-module-to-binary`.
5598e12f31bSFabian Mora   target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
5608e12f31bSFabian Mora   // Accept as legal LaunchFuncOps if the operands have been lowered.
561fcfeb1e5SFabian Mora   target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
5628e12f31bSFabian Mora       [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });
563abe501f2SChristian Sigg 
5649e7b6f46SMehdi Amini   // These aren't covered by the ConvertToLLVMPatternInterface right now.
565a5f9cda1SChristian Sigg   populateVectorToLLVMConversionPatterns(converter, patterns);
566cb4ccd38SQuentin Colombet   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
5673a506b31SChris Lattner   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
5683a506b31SChris Lattner                                                     target);
569*733be4edSAndrea Faulds   populateGpuToLLVMConversionPatterns(converter, patterns,
570*733be4edSAndrea Faulds                                       kernelBarePtrCallConv,
571*733be4edSAndrea Faulds                                       kernelIntersperseSizeCallConv);
572a5f9cda1SChristian Sigg 
573a5f9cda1SChristian Sigg   if (failed(
574a5f9cda1SChristian Sigg           applyPartialConversion(getOperation(), target, std::move(patterns))))
575a5f9cda1SChristian Sigg     signalPassFailure();
576a5f9cda1SChristian Sigg }
577a5f9cda1SChristian Sigg 
578a5f9cda1SChristian Sigg LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
579a5f9cda1SChristian Sigg                                          ArrayRef<Value> arguments) const {
580a5f9cda1SChristian Sigg   auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
581a5f9cda1SChristian Sigg   auto function = [&] {
582a5f9cda1SChristian Sigg     if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
583a5f9cda1SChristian Sigg       return function;
584973ddb7dSMehdi Amini     return OpBuilder::atBlockEnd(module.getBody())
585a5f9cda1SChristian Sigg         .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
586a5f9cda1SChristian Sigg   }();
587c1f719d1SAlex Zinenko   return builder.create<LLVM::CallOp>(loc, function, arguments);
588a5f9cda1SChristian Sigg }
589a5f9cda1SChristian Sigg 
590cc402de0SKun Wu // Corresponding to cusparseIndexType_t defined in cusparse.h.
591cc402de0SKun Wu static int32_t getCuSparseIndexTypeFrom(Type type) {
59239038177SAart Bik   if (type.isInteger(16))
59339038177SAart Bik     return 1; // CUSPARSE_INDEX_16U
59439038177SAart Bik   if (type.isInteger(32))
595cc402de0SKun Wu     return 2; // CUSPARSE_INDEX_32I
59639038177SAart Bik   return 3;   // CUSPARSE_INDEX_64I
597cc402de0SKun Wu }
598cc402de0SKun Wu 
5998ed59c53SKun Wu static int32_t getCuSparseLtDataTypeFrom(Type type) {
6008ed59c53SKun Wu   if (type.isF16())
6018ed59c53SKun Wu     return 0; // CUSPARSE_COMPUTE_16F,
6028ed59c53SKun Wu   if (type.isInteger(32))
6038ed59c53SKun Wu     return 1; // CUSPARSE_COMPUTE_32I
6048ed59c53SKun Wu   llvm_unreachable("unsupported type");
6058ed59c53SKun Wu   // TODO: add support to TF32
6068ed59c53SKun Wu }
6078ed59c53SKun Wu 
608cc402de0SKun Wu // Corresponding to cudaDataType_t defined in CUDA library_types.h.
609cc402de0SKun Wu static int32_t getCuSparseDataTypeFrom(Type type) {
610cc402de0SKun Wu   if (llvm::isa<ComplexType>(type)) {
611cc402de0SKun Wu     // get the element type
612a5757c5bSChristian Sigg     auto elementType = cast<ComplexType>(type).getElementType();
613cc402de0SKun Wu     if (elementType.isBF16())
614cc402de0SKun Wu       return 15; // CUDA_C_16BF
615cc402de0SKun Wu     if (elementType.isF16())
616cc402de0SKun Wu       return 6; // CUDA_C_16F
617cc402de0SKun Wu     if (elementType.isF32())
618cc402de0SKun Wu       return 4; // CUDA_C_32F
619cc402de0SKun Wu     if (elementType.isF64())
620cc402de0SKun Wu       return 5; // CUDA_C_64F
621cc402de0SKun Wu     if (elementType.isInteger(8))
622cc402de0SKun Wu       return 7; // CUDA_C_8I
623cc402de0SKun Wu     if (elementType.isInteger(16))
624cc402de0SKun Wu       return 21; // CUDA_C_16I
625cc402de0SKun Wu     if (elementType.isInteger(32))
626cc402de0SKun Wu       return 11; // CUDA_C_32I
627cc402de0SKun Wu   }
628cc402de0SKun Wu   if (type.isBF16())
629cc402de0SKun Wu     return 14; // CUDA_R_16BF
630cc402de0SKun Wu   if (type.isF16())
631cc402de0SKun Wu     return 2; // CUDA_R_16F
632cc402de0SKun Wu   if (type.isF32())
633cc402de0SKun Wu     return 0; // CUDA_R_32F
634cc402de0SKun Wu   if (type.isF64())
635cc402de0SKun Wu     return 1; // CUDA_R_64F
636cc402de0SKun Wu   if (type.isInteger(8))
637cc402de0SKun Wu     return 3; // CUDA_R_8I
638cc402de0SKun Wu   if (type.isInteger(16))
639cc402de0SKun Wu     return 20; // CUDA_R_16I
640cc402de0SKun Wu   if (type.isInteger(32))
641cc402de0SKun Wu     return 10; // CUDA_R_32I
642cc402de0SKun Wu 
643cc402de0SKun Wu   llvm_unreachable("unsupported element type");
644cc402de0SKun Wu }
645cc402de0SKun Wu 
6461e491c42SKun Wu static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
6471e491c42SKun Wu   return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
6481e491c42SKun Wu }
64939038177SAart Bik 
6508ed59c53SKun Wu // TODO:  We may want a run-time (of the mlir compiler) disablement/warning:
6518ed59c53SKun Wu // cusparseLt currently won't work for cuda architecture <8.0 and will trigger a
6528ed59c53SKun Wu // runtime (of the CUDA program) error , but it might be great if we could at
6538ed59c53SKun Wu // least output a warning when we found the target architecture is <8.0 and the
6548ed59c53SKun Wu // user still wants to use cusparseLt. to make sure when lowering gpu sparse
6558ed59c53SKun Wu // dialect to llvm calls, the cusparselt calls are disabled for cuda
6568ed59c53SKun Wu // architecture <8.0
6578ed59c53SKun Wu static bool is2To4Sparsity(Value spMat) {
6588ed59c53SKun Wu   if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
6598ed59c53SKun Wu     return true;
6608ed59c53SKun Wu   if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
6618ed59c53SKun Wu     return false;
66239038177SAart Bik   if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
66339038177SAart Bik     return false;
6648ed59c53SKun Wu   if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
6658ed59c53SKun Wu     return false;
66639038177SAart Bik   if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
66739038177SAart Bik     return false;
66839038177SAart Bik   if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
6698ed59c53SKun Wu     return false;
6708ed59c53SKun Wu   // Print the spMat defining op
6718ed59c53SKun Wu   spMat.getDefiningOp()->print(llvm::errs());
6728ed59c53SKun Wu   llvm_unreachable("cannot find spmat def");
6738ed59c53SKun Wu }
6748ed59c53SKun Wu 
6758ed59c53SKun Wu static bool isSpMMCusparseLtOp(Value op) {
6768ed59c53SKun Wu   for (Operation *user : op.getUsers()) {
6778ed59c53SKun Wu     auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
6788ed59c53SKun Wu     // If the other operator is 50% sparsity then we should use cusparseLt
6798ed59c53SKun Wu     if (!spmmOp)
6808ed59c53SKun Wu       continue;
6818ed59c53SKun Wu     if (is2To4Sparsity(spmmOp.getSpmatA()))
6828ed59c53SKun Wu       return true;
6838ed59c53SKun Wu   }
6848ed59c53SKun Wu   return false;
6858ed59c53SKun Wu }
6868ed59c53SKun Wu 
687a5f9cda1SChristian Sigg // Returns whether all operands are of LLVM type.
688a5f9cda1SChristian Sigg static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
689a5f9cda1SChristian Sigg                                      ConversionPatternRewriter &rewriter) {
690a5f9cda1SChristian Sigg   if (!llvm::all_of(operands, [](Value value) {
691a5f9cda1SChristian Sigg         return LLVM::isCompatibleType(value.getType());
692a5f9cda1SChristian Sigg       }))
693a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(
694a5f9cda1SChristian Sigg         op, "Cannot convert if operands aren't of LLVM type.");
695a5f9cda1SChristian Sigg   return success();
696a5f9cda1SChristian Sigg }
697a5f9cda1SChristian Sigg 
698a5f9cda1SChristian Sigg static LogicalResult
699a5f9cda1SChristian Sigg isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
700a5f9cda1SChristian Sigg                          gpu::AsyncOpInterface op) {
701a5f9cda1SChristian Sigg   if (op.getAsyncDependencies().size() != 1)
702a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(
703a5f9cda1SChristian Sigg         op, "Can only convert with exactly one async dependency.");
704a5f9cda1SChristian Sigg 
705a5f9cda1SChristian Sigg   if (!op.getAsyncToken())
706a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(op, "Can convert only async version.");
707a5f9cda1SChristian Sigg 
708a5f9cda1SChristian Sigg   return success();
709a5f9cda1SChristian Sigg }
710a5f9cda1SChristian Sigg 
711a5f9cda1SChristian Sigg LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
712ef976337SRiver Riddle     gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
713a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
714a5f9cda1SChristian Sigg   auto *op = hostRegisterOp.getOperation();
715ef976337SRiver Riddle   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
716a5f9cda1SChristian Sigg     return failure();
717a5f9cda1SChristian Sigg 
718a5f9cda1SChristian Sigg   Location loc = op->getLoc();
719a5f9cda1SChristian Sigg 
72010c04f46SRiver Riddle   auto memRefType = hostRegisterOp.getValue().getType();
7215550c821STres Popp   auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
722a5f9cda1SChristian Sigg   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
723a5f9cda1SChristian Sigg 
724ef976337SRiver Riddle   auto arguments = getTypeConverter()->promoteOperands(
725ef976337SRiver Riddle       loc, op->getOperands(), adaptor.getOperands(), rewriter);
726a5f9cda1SChristian Sigg   arguments.push_back(elementSize);
727a5f9cda1SChristian Sigg   hostRegisterCallBuilder.create(loc, rewriter, arguments);
728a5f9cda1SChristian Sigg 
729a5f9cda1SChristian Sigg   rewriter.eraseOp(op);
730a5f9cda1SChristian Sigg   return success();
731a5f9cda1SChristian Sigg }
732a5f9cda1SChristian Sigg 
7338f7c8a6eSmax LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
7348f7c8a6eSmax     gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
7358f7c8a6eSmax     ConversionPatternRewriter &rewriter) const {
7368f7c8a6eSmax   Operation *op = hostUnregisterOp.getOperation();
7378f7c8a6eSmax   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
7388f7c8a6eSmax     return failure();
7398f7c8a6eSmax 
7408f7c8a6eSmax   Location loc = op->getLoc();
7418f7c8a6eSmax 
7428f7c8a6eSmax   auto memRefType = hostUnregisterOp.getValue().getType();
7435550c821STres Popp   auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
7448f7c8a6eSmax   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
7458f7c8a6eSmax 
7468f7c8a6eSmax   auto arguments = getTypeConverter()->promoteOperands(
7478f7c8a6eSmax       loc, op->getOperands(), adaptor.getOperands(), rewriter);
7488f7c8a6eSmax   arguments.push_back(elementSize);
7498f7c8a6eSmax   hostUnregisterCallBuilder.create(loc, rewriter, arguments);
7508f7c8a6eSmax 
7518f7c8a6eSmax   rewriter.eraseOp(op);
7528f7c8a6eSmax   return success();
7538f7c8a6eSmax }
7548f7c8a6eSmax 
755a5f9cda1SChristian Sigg LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
756ef976337SRiver Riddle     gpu::AllocOp allocOp, OpAdaptor adaptor,
757a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
758a93ec06aSIvan Butygin 
759a5f9cda1SChristian Sigg   MemRefType memRefType = allocOp.getType();
760a5f9cda1SChristian Sigg 
761ef976337SRiver Riddle   if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
7621002a1d0SNishant Patel       !isConvertibleAndHasIdentityMaps(memRefType))
763a5f9cda1SChristian Sigg     return failure();
764a5f9cda1SChristian Sigg 
765a5f9cda1SChristian Sigg   auto loc = allocOp.getLoc();
766a5f9cda1SChristian Sigg 
7671002a1d0SNishant Patel   bool isShared = allocOp.getHostShared();
7681002a1d0SNishant Patel 
7691002a1d0SNishant Patel   if (isShared && allocOp.getAsyncToken())
7701002a1d0SNishant Patel     return rewriter.notifyMatchFailure(
7711002a1d0SNishant Patel         allocOp, "Host Shared allocation cannot be done async");
772e204b919SMehdi Amini   if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
7731002a1d0SNishant Patel     return failure();
7741002a1d0SNishant Patel 
775a5f9cda1SChristian Sigg   // Get shape of the memref as values: static sizes are constant
776a5f9cda1SChristian Sigg   // values and dynamic sizes are passed to 'alloc' as operands.
777a5f9cda1SChristian Sigg   SmallVector<Value, 4> shape;
778a5f9cda1SChristian Sigg   SmallVector<Value, 4> strides;
779a5f9cda1SChristian Sigg   Value sizeBytes;
78010c04f46SRiver Riddle   getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
781a5f9cda1SChristian Sigg                            shape, strides, sizeBytes);
782a5f9cda1SChristian Sigg 
783a5f9cda1SChristian Sigg   // Allocate the underlying buffer and store a pointer to it in the MemRef
784a5f9cda1SChristian Sigg   // descriptor.
785ced9f4f0SNishant Patel   auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
786ced9f4f0SNishant Patel   Value stream = adaptor.getAsyncDependencies().empty()
787ced9f4f0SNishant Patel                      ? nullPtr
788ced9f4f0SNishant Patel                      : adaptor.getAsyncDependencies().front();
7891002a1d0SNishant Patel 
7901002a1d0SNishant Patel   auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
7911002a1d0SNishant Patel       loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
7921002a1d0SNishant Patel 
793a5f9cda1SChristian Sigg   Value allocatedPtr =
7941002a1d0SNishant Patel       allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
7951002a1d0SNishant Patel           .getResult();
796a5f9cda1SChristian Sigg 
797a5f9cda1SChristian Sigg   // No alignment.
798a5f9cda1SChristian Sigg   Value alignedPtr = allocatedPtr;
799a5f9cda1SChristian Sigg 
800a5f9cda1SChristian Sigg   // Create the MemRef descriptor.
801a5f9cda1SChristian Sigg   auto memRefDescriptor = this->createMemRefDescriptor(
802a5f9cda1SChristian Sigg       loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
803a5f9cda1SChristian Sigg 
804ced9f4f0SNishant Patel   if (allocOp.getAsyncToken()) {
805ced9f4f0SNishant Patel     // Async alloc: make dependent ops use the same stream.
806a5f9cda1SChristian Sigg     rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
807ced9f4f0SNishant Patel   } else {
808ced9f4f0SNishant Patel     rewriter.replaceOp(allocOp, {memRefDescriptor});
809ced9f4f0SNishant Patel   }
810a5f9cda1SChristian Sigg 
811a5f9cda1SChristian Sigg   return success();
812a5f9cda1SChristian Sigg }
813a5f9cda1SChristian Sigg 
814a5f9cda1SChristian Sigg LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
815ef976337SRiver Riddle     gpu::DeallocOp deallocOp, OpAdaptor adaptor,
816a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
817ef976337SRiver Riddle   if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
818a5f9cda1SChristian Sigg       failed(isAsyncWithOneDependency(rewriter, deallocOp)))
819a5f9cda1SChristian Sigg     return failure();
820a5f9cda1SChristian Sigg 
821a5f9cda1SChristian Sigg   Location loc = deallocOp.getLoc();
822a5f9cda1SChristian Sigg 
823a5f9cda1SChristian Sigg   Value pointer =
82410c04f46SRiver Riddle       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
82510c04f46SRiver Riddle   Value stream = adaptor.getAsyncDependencies().front();
8260e5aeae6SMarkus Böck   deallocCallBuilder.create(loc, rewriter, {pointer, stream});
827a5f9cda1SChristian Sigg 
828a5f9cda1SChristian Sigg   rewriter.replaceOp(deallocOp, {stream});
829a5f9cda1SChristian Sigg   return success();
830a5f9cda1SChristian Sigg }
831a5f9cda1SChristian Sigg 
832a5f9cda1SChristian Sigg static bool isGpuAsyncTokenType(Value value) {
8335550c821STres Popp   return isa<gpu::AsyncTokenType>(value.getType());
834a5f9cda1SChristian Sigg }
835a5f9cda1SChristian Sigg 
836a5f9cda1SChristian Sigg // Converts !gpu.async.token operands of `async.yield` to runtime calls. The
837a5f9cda1SChristian Sigg // !gpu.async.token are lowered to stream within the async.execute region, but
838a5f9cda1SChristian Sigg // are passed as events between them. For each !gpu.async.token operand, we
839a5f9cda1SChristian Sigg // create an event and record it on the stream.
840a5f9cda1SChristian Sigg LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
841ef976337SRiver Riddle     async::YieldOp yieldOp, OpAdaptor adaptor,
842a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
843b74192b7SRiver Riddle   if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
844a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
845a5f9cda1SChristian Sigg 
846a5f9cda1SChristian Sigg   Location loc = yieldOp.getLoc();
847ef976337SRiver Riddle   SmallVector<Value, 4> newOperands(adaptor.getOperands());
848a5f9cda1SChristian Sigg   llvm::SmallDenseSet<Value> streams;
849a5f9cda1SChristian Sigg   for (auto &operand : yieldOp->getOpOperands()) {
850a5f9cda1SChristian Sigg     if (!isGpuAsyncTokenType(operand.get()))
851a5f9cda1SChristian Sigg       continue;
852a5f9cda1SChristian Sigg     auto idx = operand.getOperandNumber();
853ef976337SRiver Riddle     auto stream = adaptor.getOperands()[idx];
8545e0c3b43SJeff Niu     auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
855a5f9cda1SChristian Sigg     eventRecordCallBuilder.create(loc, rewriter, {event, stream});
856a5f9cda1SChristian Sigg     newOperands[idx] = event;
857a5f9cda1SChristian Sigg     streams.insert(stream);
858a5f9cda1SChristian Sigg   }
859a5f9cda1SChristian Sigg   for (auto stream : streams)
860a5f9cda1SChristian Sigg     streamDestroyCallBuilder.create(loc, rewriter, {stream});
861a5f9cda1SChristian Sigg 
8625fcf907bSMatthias Springer   rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
863a5f9cda1SChristian Sigg   return success();
864a5f9cda1SChristian Sigg }
865a5f9cda1SChristian Sigg 
866a5f9cda1SChristian Sigg // Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
867a5f9cda1SChristian Sigg static bool isDefinedByCallTo(Value value, StringRef functionName) {
8685550c821STres Popp   assert(isa<LLVM::LLVMPointerType>(value.getType()));
869a5f9cda1SChristian Sigg   if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
870dec8055aSKazu Hirata     return *defOp.getCallee() == functionName;
871a5f9cda1SChristian Sigg   return false;
872a5f9cda1SChristian Sigg }
873a5f9cda1SChristian Sigg 
874a5f9cda1SChristian Sigg // Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
875a5f9cda1SChristian Sigg // with the stream/event operands. The operands are destroyed. That is, it
876a5f9cda1SChristian Sigg // assumes that it is not used afterwards or elsewhere. Otherwise we will get a
877a5f9cda1SChristian Sigg // runtime error. Eventually, we should guarantee this property.
878a5f9cda1SChristian Sigg LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
879ef976337SRiver Riddle     gpu::WaitOp waitOp, OpAdaptor adaptor,
880a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
88110c04f46SRiver Riddle   if (waitOp.getAsyncToken())
882a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
883a5f9cda1SChristian Sigg 
884a5f9cda1SChristian Sigg   Location loc = waitOp.getLoc();
885a5f9cda1SChristian Sigg 
886ef976337SRiver Riddle   for (auto operand : adaptor.getOperands()) {
887a5f9cda1SChristian Sigg     if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
888a5f9cda1SChristian Sigg       // The converted operand's definition created a stream.
889a5f9cda1SChristian Sigg       streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
890a5f9cda1SChristian Sigg       streamDestroyCallBuilder.create(loc, rewriter, {operand});
891a5f9cda1SChristian Sigg     } else {
892a5f9cda1SChristian Sigg       // Otherwise the converted operand is an event. This assumes that we use
893a5f9cda1SChristian Sigg       // events in control flow code as well.
894a5f9cda1SChristian Sigg       eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
895a5f9cda1SChristian Sigg       eventDestroyCallBuilder.create(loc, rewriter, {operand});
896a5f9cda1SChristian Sigg     }
897a5f9cda1SChristian Sigg   }
898a5f9cda1SChristian Sigg 
899a5f9cda1SChristian Sigg   rewriter.eraseOp(waitOp);
900a5f9cda1SChristian Sigg   return success();
901a5f9cda1SChristian Sigg }
902a5f9cda1SChristian Sigg 
903a5f9cda1SChristian Sigg // Converts `gpu.wait async` to runtime calls. The converted op creates a new
904a5f9cda1SChristian Sigg // stream that is synchronized with stream/event operands. The operands are
905a5f9cda1SChristian Sigg // destroyed. That is, it assumes that it is not used afterwards or elsewhere.
906a5f9cda1SChristian Sigg // Otherwise we will get a runtime error. Eventually, we should guarantee this
907a5f9cda1SChristian Sigg // property.
908a5f9cda1SChristian Sigg LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
909ef976337SRiver Riddle     gpu::WaitOp waitOp, OpAdaptor adaptor,
910a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
91110c04f46SRiver Riddle   if (!waitOp.getAsyncToken())
912a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
913a5f9cda1SChristian Sigg 
914a5f9cda1SChristian Sigg   Location loc = waitOp.getLoc();
915a5f9cda1SChristian Sigg 
916a5f9cda1SChristian Sigg   auto insertionPoint = rewriter.saveInsertionPoint();
917a5f9cda1SChristian Sigg   SmallVector<Value, 1> events;
918ef976337SRiver Riddle   for (auto pair :
91910c04f46SRiver Riddle        llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
920a5f9cda1SChristian Sigg     auto operand = std::get<1>(pair);
921a5f9cda1SChristian Sigg     if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
922a5f9cda1SChristian Sigg       // The converted operand's definition created a stream. Insert an event
923a5f9cda1SChristian Sigg       // into the stream just after the last use of the original token operand.
924a5f9cda1SChristian Sigg       auto *defOp = std::get<0>(pair).getDefiningOp();
925a5f9cda1SChristian Sigg       rewriter.setInsertionPointAfter(defOp);
9265e0c3b43SJeff Niu       auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
927a5f9cda1SChristian Sigg       eventRecordCallBuilder.create(loc, rewriter, {event, operand});
928a5f9cda1SChristian Sigg       events.push_back(event);
929a5f9cda1SChristian Sigg     } else {
930a5f9cda1SChristian Sigg       // Otherwise the converted operand is an event. This assumes that we use
931a5f9cda1SChristian Sigg       // events in control flow code as well.
932a5f9cda1SChristian Sigg       events.push_back(operand);
933a5f9cda1SChristian Sigg     }
934a5f9cda1SChristian Sigg   }
935a5f9cda1SChristian Sigg   rewriter.restoreInsertionPoint(insertionPoint);
9365e0c3b43SJeff Niu   auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
937a5f9cda1SChristian Sigg   for (auto event : events)
938a5f9cda1SChristian Sigg     streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
939a5f9cda1SChristian Sigg   for (auto event : events)
940a5f9cda1SChristian Sigg     eventDestroyCallBuilder.create(loc, rewriter, {event});
941a5f9cda1SChristian Sigg   rewriter.replaceOp(waitOp, {stream});
942a5f9cda1SChristian Sigg 
943a5f9cda1SChristian Sigg   return success();
944a5f9cda1SChristian Sigg }
945a5f9cda1SChristian Sigg 
9468e12f31bSFabian Mora // Legalize the op's operands.
9478e12f31bSFabian Mora LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
948ef976337SRiver Riddle     gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
949a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
950ef976337SRiver Riddle   if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
951a5f9cda1SChristian Sigg     return failure();
952a5f9cda1SChristian Sigg 
95310c04f46SRiver Riddle   if (launchOp.getAsyncDependencies().size() > 1)
954a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(
955a5f9cda1SChristian Sigg         launchOp, "Cannot convert with more than one async dependency.");
956a5f9cda1SChristian Sigg 
957a5f9cda1SChristian Sigg   // Fail when the synchronous version of the op has async dependencies. The
958a5f9cda1SChristian Sigg   // lowering destroys the stream, and we do not want to check that there is no
959a5f9cda1SChristian Sigg   // use of the stream after this op.
96010c04f46SRiver Riddle   if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
961a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(
962a5f9cda1SChristian Sigg         launchOp, "Cannot convert non-async op with async dependencies.");
963a5f9cda1SChristian Sigg 
964a5f9cda1SChristian Sigg   Location loc = launchOp.getLoc();
965a5f9cda1SChristian Sigg 
966fcfeb1e5SFabian Mora   Value stream = Value();
967583e78b3SAdrian Kuegel   if (!adaptor.getAsyncDependencies().empty())
968fcfeb1e5SFabian Mora     stream = adaptor.getAsyncDependencies().front();
969fcfeb1e5SFabian Mora   // If the async keyword is present and there are no dependencies, then a
970fcfeb1e5SFabian Mora   // stream must be created to pass to subsequent operations.
971fcfeb1e5SFabian Mora   else if (launchOp.getAsyncToken())
972fcfeb1e5SFabian Mora     stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
97399a562b3SAndrea Faulds 
974fcfeb1e5SFabian Mora   // Lower the kernel operands to match kernel parameters.
9757f4dbd83SMatthias Springer   // Note: If `useBarePtrCallConv` is set in the type converter's options,
9767f4dbd83SMatthias Springer   // the value of `kernelBarePtrCallConv` will be ignored.
977*733be4edSAndrea Faulds   OperandRange origArguments = launchOp.getKernelOperands();
978*733be4edSAndrea Faulds   SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
979*733be4edSAndrea Faulds       loc, origArguments, adaptor.getKernelOperands(), rewriter,
9808e12f31bSFabian Mora       /*useBarePtrCallConv=*/kernelBarePtrCallConv);
981*733be4edSAndrea Faulds   SmallVector<Value, 8> llvmArgumentsWithSizes;
982*733be4edSAndrea Faulds 
983*733be4edSAndrea Faulds   // Intersperse size information if requested.
984*733be4edSAndrea Faulds   if (kernelIntersperseSizeCallConv) {
985*733be4edSAndrea Faulds     if (origArguments.size() != llvmArguments.size()) {
986*733be4edSAndrea Faulds       // This shouldn't happen if the bare-pointer calling convention is used.
987*733be4edSAndrea Faulds       return rewriter.notifyMatchFailure(
988*733be4edSAndrea Faulds           launchOp,
989*733be4edSAndrea Faulds           "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
990*733be4edSAndrea Faulds     }
991*733be4edSAndrea Faulds 
992*733be4edSAndrea Faulds     llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
993*733be4edSAndrea Faulds     for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
994*733be4edSAndrea Faulds       auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
995*733be4edSAndrea Faulds       if (!memrefTy) {
996*733be4edSAndrea Faulds         return rewriter.notifyMatchFailure(
997*733be4edSAndrea Faulds             launchOp, "Operand to launch op is not a memref.");
998*733be4edSAndrea Faulds       }
999*733be4edSAndrea Faulds 
1000*733be4edSAndrea Faulds       if (!memrefTy.hasStaticShape() ||
1001*733be4edSAndrea Faulds           !memrefTy.getElementType().isIntOrFloat()) {
1002*733be4edSAndrea Faulds         return rewriter.notifyMatchFailure(
1003*733be4edSAndrea Faulds             launchOp, "Operand to launch op is not a memref with a static "
1004*733be4edSAndrea Faulds                       "shape and an integer or float element type.");
1005*733be4edSAndrea Faulds       }
1006*733be4edSAndrea Faulds 
1007*733be4edSAndrea Faulds       unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1008*733be4edSAndrea Faulds       if (bitwidth % 8 != 0) {
1009*733be4edSAndrea Faulds         return rewriter.notifyMatchFailure(
1010*733be4edSAndrea Faulds             launchOp, "Operand to launch op is not a memref with a "
1011*733be4edSAndrea Faulds                       "byte-aligned element type.");
1012*733be4edSAndrea Faulds       }
1013*733be4edSAndrea Faulds 
1014*733be4edSAndrea Faulds       uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
1015*733be4edSAndrea Faulds                             static_cast<uint64_t>(memrefTy.getNumElements());
1016*733be4edSAndrea Faulds 
1017*733be4edSAndrea Faulds       Value sizeArg = rewriter.create<LLVM::ConstantOp>(
1018*733be4edSAndrea Faulds           loc, getIndexType(), rewriter.getIndexAttr(staticSize));
1019*733be4edSAndrea Faulds       llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
1020*733be4edSAndrea Faulds       llvmArgumentsWithSizes.push_back(sizeArg);
1021*733be4edSAndrea Faulds     }
1022*733be4edSAndrea Faulds   }
1023fcfeb1e5SFabian Mora 
1024edf5cae7SGuray Ozen   std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1025edf5cae7SGuray Ozen   if (launchOp.hasClusterSize()) {
1026edf5cae7SGuray Ozen     clusterSize =
1027edf5cae7SGuray Ozen         gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1028edf5cae7SGuray Ozen                         adaptor.getClusterSizeZ()};
1029edf5cae7SGuray Ozen   }
1030fcfeb1e5SFabian Mora   rewriter.create<gpu::LaunchFuncOp>(
1031fcfeb1e5SFabian Mora       launchOp.getLoc(), launchOp.getKernelAttr(),
1032fcfeb1e5SFabian Mora       gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1033fcfeb1e5SFabian Mora                       adaptor.getGridSizeZ()},
1034fcfeb1e5SFabian Mora       gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1035fcfeb1e5SFabian Mora                       adaptor.getBlockSizeZ()},
1036*733be4edSAndrea Faulds       adaptor.getDynamicSharedMemorySize(),
1037*733be4edSAndrea Faulds       llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1038*733be4edSAndrea Faulds       stream, clusterSize);
1039fcfeb1e5SFabian Mora   if (launchOp.getAsyncToken())
1040fcfeb1e5SFabian Mora     rewriter.replaceOp(launchOp, {stream});
1041fcfeb1e5SFabian Mora   else
1042fcfeb1e5SFabian Mora     rewriter.eraseOp(launchOp);
1043fcfeb1e5SFabian Mora   return success();
1044fcfeb1e5SFabian Mora }
1045fcfeb1e5SFabian Mora 
10460aaf2e3bSMarkus Böck static Value bitAndAddrspaceCast(Location loc,
10470aaf2e3bSMarkus Böck                                  ConversionPatternRewriter &rewriter,
10480aaf2e3bSMarkus Böck                                  LLVM::LLVMPointerType destinationType,
10490aaf2e3bSMarkus Böck                                  Value sourcePtr,
1050ce254598SMatthias Springer                                  const LLVMTypeConverter &typeConverter) {
10515550c821STres Popp   auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
10520aaf2e3bSMarkus Böck   if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
10530aaf2e3bSMarkus Böck     sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
10540aaf2e3bSMarkus Böck         loc,
1055dbd4a0ddSChristian Ulmann         LLVM::LLVMPointerType::get(rewriter.getContext(),
10560aaf2e3bSMarkus Böck                                    destinationType.getAddressSpace()),
10570aaf2e3bSMarkus Böck         sourcePtr);
10580e5aeae6SMarkus Böck   return sourcePtr;
10590aaf2e3bSMarkus Böck }
10600aaf2e3bSMarkus Böck 
1061a5f9cda1SChristian Sigg LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1062ef976337SRiver Riddle     gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1063a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
10645550c821STres Popp   auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1065a5f9cda1SChristian Sigg 
1066ef976337SRiver Riddle   if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
1067a5f9cda1SChristian Sigg       !isConvertibleAndHasIdentityMaps(memRefType) ||
1068a5f9cda1SChristian Sigg       failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
1069a5f9cda1SChristian Sigg     return failure();
1070a5f9cda1SChristian Sigg 
1071a5f9cda1SChristian Sigg   auto loc = memcpyOp.getLoc();
1072a5f9cda1SChristian Sigg 
107310c04f46SRiver Riddle   MemRefDescriptor srcDesc(adaptor.getSrc());
1074361458b1SLoren Maggiore   Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
1075a5f9cda1SChristian Sigg 
1076a5f9cda1SChristian Sigg   Type elementPtrType = getElementPtrType(memRefType);
107785175eddSTobias Gysi   Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
10780e5aeae6SMarkus Böck   Value gepPtr = rewriter.create<LLVM::GEPOp>(
10790e5aeae6SMarkus Böck       loc, elementPtrType,
10800e5aeae6SMarkus Böck       typeConverter->convertType(memRefType.getElementType()), nullPtr,
10810e5aeae6SMarkus Böck       numElements);
1082a5f9cda1SChristian Sigg   auto sizeBytes =
1083a5f9cda1SChristian Sigg       rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
1084a5f9cda1SChristian Sigg 
10850aaf2e3bSMarkus Böck   auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
10860aaf2e3bSMarkus Böck                                  srcDesc.alignedPtr(rewriter, loc),
10870aaf2e3bSMarkus Böck                                  *getTypeConverter());
10880aaf2e3bSMarkus Böck   auto dst = bitAndAddrspaceCast(
10890aaf2e3bSMarkus Böck       loc, rewriter, llvmPointerType,
10900aaf2e3bSMarkus Böck       MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
10910aaf2e3bSMarkus Böck       *getTypeConverter());
1092a5f9cda1SChristian Sigg 
109310c04f46SRiver Riddle   auto stream = adaptor.getAsyncDependencies().front();
1094a5f9cda1SChristian Sigg   memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1095a5f9cda1SChristian Sigg 
1096a5f9cda1SChristian Sigg   rewriter.replaceOp(memcpyOp, {stream});
1097a5f9cda1SChristian Sigg 
1098a5f9cda1SChristian Sigg   return success();
1099a5f9cda1SChristian Sigg }
1100a5f9cda1SChristian Sigg 
1101361458b1SLoren Maggiore LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1102ef976337SRiver Riddle     gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1103361458b1SLoren Maggiore     ConversionPatternRewriter &rewriter) const {
11045550c821STres Popp   auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1105361458b1SLoren Maggiore 
1106ef976337SRiver Riddle   if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1107361458b1SLoren Maggiore       !isConvertibleAndHasIdentityMaps(memRefType) ||
1108361458b1SLoren Maggiore       failed(isAsyncWithOneDependency(rewriter, memsetOp)))
1109361458b1SLoren Maggiore     return failure();
1110361458b1SLoren Maggiore 
1111361458b1SLoren Maggiore   auto loc = memsetOp.getLoc();
1112361458b1SLoren Maggiore 
111310c04f46SRiver Riddle   Type valueType = adaptor.getValue().getType();
111418cc07aaSNavdeep Katel   unsigned bitWidth = valueType.getIntOrFloatBitWidth();
111518cc07aaSNavdeep Katel   // Ints and floats of 16 or 32 bit width are allowed.
111618cc07aaSNavdeep Katel   if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
111718cc07aaSNavdeep Katel     return rewriter.notifyMatchFailure(
111818cc07aaSNavdeep Katel         memsetOp, "value must be a 16 or 32 bit int or float");
1119361458b1SLoren Maggiore   }
1120361458b1SLoren Maggiore 
112118cc07aaSNavdeep Katel   unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
112218cc07aaSNavdeep Katel   Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
112318cc07aaSNavdeep Katel 
112410c04f46SRiver Riddle   MemRefDescriptor dstDesc(adaptor.getDst());
1125361458b1SLoren Maggiore   Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
1126361458b1SLoren Maggiore 
1127361458b1SLoren Maggiore   auto value =
112818cc07aaSNavdeep Katel       rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
11290aaf2e3bSMarkus Böck   auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
11300aaf2e3bSMarkus Böck                                  dstDesc.alignedPtr(rewriter, loc),
11310aaf2e3bSMarkus Böck                                  *getTypeConverter());
1132361458b1SLoren Maggiore 
113310c04f46SRiver Riddle   auto stream = adaptor.getAsyncDependencies().front();
113418cc07aaSNavdeep Katel   FunctionCallBuilder builder =
113518cc07aaSNavdeep Katel       valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
113618cc07aaSNavdeep Katel   builder.create(loc, rewriter, {dst, value, numElements, stream});
1137361458b1SLoren Maggiore 
1138361458b1SLoren Maggiore   rewriter.replaceOp(memsetOp, {stream});
1139361458b1SLoren Maggiore   return success();
1140361458b1SLoren Maggiore }
1141361458b1SLoren Maggiore 
114284718d37SKrzysztof Drewniak LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
114384718d37SKrzysztof Drewniak     gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
114484718d37SKrzysztof Drewniak     ConversionPatternRewriter &rewriter) const {
114584718d37SKrzysztof Drewniak   Location loc = op.getLoc();
114611141bc6SPaul C Fuqua   auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
114711141bc6SPaul C Fuqua                                                  {adaptor.getDevIndex()});
114811141bc6SPaul C Fuqua   rewriter.replaceOp(op, call);
114984718d37SKrzysztof Drewniak   return success();
115084718d37SKrzysztof Drewniak }
115184718d37SKrzysztof Drewniak 
1152cc402de0SKun Wu template <typename T>
11536ac80a76SMehdi Amini static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
1154235fbe79SKun Wu   Type llvmInt32Type = builder.getIntegerType(32);
1155235fbe79SKun Wu   return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
11566ac80a76SMehdi Amini                                           static_cast<int32_t>(tValue));
1157cc402de0SKun Wu }
1158cc402de0SKun Wu 
1159dfe29429SKun Wu template <typename T>
11606ac80a76SMehdi Amini static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
1161dfe29429SKun Wu   Type llvmFloat32Type = builder.getF32Type();
1162dfe29429SKun Wu   return builder.create<LLVM::ConstantOp>(
1163dfe29429SKun Wu       loc, llvmFloat32Type,
11646ac80a76SMehdi Amini       builder.getF32FloatAttr(static_cast<float>(tValue)));
1165dfe29429SKun Wu }
1166dfe29429SKun Wu 
116797f4c22bSKun Wu LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
116897f4c22bSKun Wu     gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1169b700a90cSAart Bik     ConversionPatternRewriter &rewriter) const {
1170b700a90cSAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1171b700a90cSAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
1172b700a90cSAart Bik     return failure();
1173b700a90cSAart Bik   Location loc = op.getLoc();
1174b700a90cSAart Bik   auto stream = adaptor.getAsyncDependencies().front();
117597f4c22bSKun Wu   Value pTensor =
1176b700a90cSAart Bik       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1177cc402de0SKun Wu   Type dType = op.getMemref().getType().getElementType();
1178cc402de0SKun Wu   auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
117997f4c22bSKun Wu 
118097f4c22bSKun Wu   SmallVector<Value, 4> dims;
118197f4c22bSKun Wu   for (Value dim : adaptor.getDims()) {
118297f4c22bSKun Wu     dims.push_back(dim);
1183b700a90cSAart Bik   }
1184b700a90cSAart Bik 
118597f4c22bSKun Wu   Value handle;
11868ed59c53SKun Wu   // TODO: For now, we track the use of the handle and lower it to cusparse /
11878ed59c53SKun Wu   // cusparseLt accordingly. If in a block, both cusparse and cusparseLt are
11888ed59c53SKun Wu   // used, we require two separate Creation ops to be the correct logic. In
11898ed59c53SKun Wu   // future, we may add support to using one handle in sparse tensor / GPU
11908ed59c53SKun Wu   // dialect in both cusparse and cusparseLt. use the cusparseLt create call if
11918ed59c53SKun Wu   // the dnmat is used with spmat with 2:4 sparsity
119297f4c22bSKun Wu   if (dims.size() == 2) {
119397f4c22bSKun Wu     if (isSpMMCusparseLtOp(op.getDnTensor())) {
11948ed59c53SKun Wu       auto handleSz = rewriter.create<LLVM::ConstantOp>(
11958ed59c53SKun Wu           loc, getIndexType(), rewriter.getIndexAttr(11032));
119686eff489SAart Bik       handle = rewriter.create<LLVM::AllocaOp>(
1197dbd4a0ddSChristian Ulmann           loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
11988ed59c53SKun Wu       handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
11998ed59c53SKun Wu 
12008ed59c53SKun Wu       createLtDnMatCallBuilder
12018ed59c53SKun Wu           .create(loc, rewriter,
1202be2dd22bSKun Wu                   {handle, dims[0], dims[1], pTensor, dtp, stream})
12038ed59c53SKun Wu           .getResult();
12048ed59c53SKun Wu     } else {
12058ed59c53SKun Wu       handle =
1206981cf167SAart Bik           createDnMatCallBuilder
120797f4c22bSKun Wu               .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
120897f4c22bSKun Wu               .getResult();
120997f4c22bSKun Wu     }
121097f4c22bSKun Wu   } else {
121197f4c22bSKun Wu     assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
121297f4c22bSKun Wu     handle = createDnVecCallBuilder
121397f4c22bSKun Wu                  .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1214981cf167SAart Bik                  .getResult();
12158ed59c53SKun Wu   }
1216981cf167SAart Bik   rewriter.replaceOp(op, {handle, stream});
1217981cf167SAart Bik   return success();
1218981cf167SAart Bik }
1219981cf167SAart Bik 
122097f4c22bSKun Wu LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
122197f4c22bSKun Wu     gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1222981cf167SAart Bik     ConversionPatternRewriter &rewriter) const {
1223981cf167SAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1224981cf167SAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
1225981cf167SAart Bik     return failure();
1226981cf167SAart Bik   Location loc = op.getLoc();
1227981cf167SAart Bik   auto stream = adaptor.getAsyncDependencies().front();
122897f4c22bSKun Wu   auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
122997f4c22bSKun Wu   SmallVector<Value, 4> dims;
123097f4c22bSKun Wu   for (Value dim : definingOp.getDims()) {
123197f4c22bSKun Wu     dims.push_back(dim);
123297f4c22bSKun Wu   }
123397f4c22bSKun Wu   if (dims.size() == 2) {
12348ed59c53SKun Wu     // Use the cusparseLt destroy call if the dnmat is used with spmat with
12358ed59c53SKun Wu     // 2:4 sparsity
123697f4c22bSKun Wu     if (isSpMMCusparseLtOp(op.getDnTensor())) {
12378ed59c53SKun Wu       destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
123897f4c22bSKun Wu                                            {adaptor.getDnTensor(), stream});
12398ed59c53SKun Wu     } else {
124097f4c22bSKun Wu       destroyDnMatCallBuilder.create(loc, rewriter,
124197f4c22bSKun Wu                                      {adaptor.getDnTensor(), stream});
124297f4c22bSKun Wu     }
124397f4c22bSKun Wu   } else {
124497f4c22bSKun Wu     assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
124597f4c22bSKun Wu     destroyDnVecCallBuilder.create(loc, rewriter,
124697f4c22bSKun Wu                                    {adaptor.getDnTensor(), stream});
12478ed59c53SKun Wu   }
1248981cf167SAart Bik   rewriter.replaceOp(op, {stream});
1249981cf167SAart Bik   return success();
1250981cf167SAart Bik }
1251981cf167SAart Bik 
1252b700a90cSAart Bik LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1253b700a90cSAart Bik     gpu::CreateCooOp op, OpAdaptor adaptor,
1254b700a90cSAart Bik     ConversionPatternRewriter &rewriter) const {
1255b700a90cSAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1256b700a90cSAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
1257b700a90cSAart Bik     return failure();
1258b700a90cSAart Bik   Location loc = op.getLoc();
1259b700a90cSAart Bik   auto stream = adaptor.getAsyncDependencies().front();
1260b700a90cSAart Bik   Value pRowIdxs =
1261b700a90cSAart Bik       MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1262b700a90cSAart Bik   Value pColIdxs =
1263b700a90cSAart Bik       MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1264b700a90cSAart Bik   Value pValues =
1265b700a90cSAart Bik       MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1266cf44847bSKun Wu   Type iType =
1267cf44847bSKun Wu       llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1268cf44847bSKun Wu   Type dType =
1269cf44847bSKun Wu       llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1270cc402de0SKun Wu   auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1271cc402de0SKun Wu   auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1272b700a90cSAart Bik   auto handle =
1273b700a90cSAart Bik       createCooCallBuilder
1274b700a90cSAart Bik           .create(loc, rewriter,
1275b700a90cSAart Bik                   {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1276cc402de0SKun Wu                    pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1277b700a90cSAart Bik           .getResult();
1278b700a90cSAart Bik   rewriter.replaceOp(op, {handle, stream});
1279b700a90cSAart Bik   return success();
1280b700a90cSAart Bik }
1281b700a90cSAart Bik 
12829fc02a7aSAart Bik LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
12839fc02a7aSAart Bik     gpu::CreateCooAoSOp op, OpAdaptor adaptor,
12849fc02a7aSAart Bik     ConversionPatternRewriter &rewriter) const {
12859fc02a7aSAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
12869fc02a7aSAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
12879fc02a7aSAart Bik     return failure();
12889fc02a7aSAart Bik   Location loc = op.getLoc();
12899fc02a7aSAart Bik   auto stream = adaptor.getAsyncDependencies().front();
12909fc02a7aSAart Bik   Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
12919fc02a7aSAart Bik   Value pValues =
12929fc02a7aSAart Bik       MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
12938ed59c53SKun Wu   Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
12949fc02a7aSAart Bik   Type dType =
12959fc02a7aSAart Bik       llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
12969fc02a7aSAart Bik   auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
12979fc02a7aSAart Bik   auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
12989fc02a7aSAart Bik   auto handle =
12999fc02a7aSAart Bik       createCooAoSCallBuilder
13009fc02a7aSAart Bik           .create(loc, rewriter,
13019fc02a7aSAart Bik                   {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
13029fc02a7aSAart Bik                    pIdxs, pValues, itp, dtp, stream})
13039fc02a7aSAart Bik           .getResult();
13049fc02a7aSAart Bik   rewriter.replaceOp(op, {handle, stream});
13059fc02a7aSAart Bik   return success();
13069fc02a7aSAart Bik }
13079fc02a7aSAart Bik 
1308b700a90cSAart Bik LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1309b700a90cSAart Bik     gpu::CreateCsrOp op, OpAdaptor adaptor,
1310b700a90cSAart Bik     ConversionPatternRewriter &rewriter) const {
1311b700a90cSAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1312b700a90cSAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
1313b700a90cSAart Bik     return failure();
1314b700a90cSAart Bik   Location loc = op.getLoc();
1315b700a90cSAart Bik   auto stream = adaptor.getAsyncDependencies().front();
1316b700a90cSAart Bik   Value pRowPos =
1317b700a90cSAart Bik       MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1318b700a90cSAart Bik   Value pColIdxs =
1319b700a90cSAart Bik       MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1320b700a90cSAart Bik   Value pValues =
1321b700a90cSAart Bik       MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1322cf44847bSKun Wu   Type pType =
1323cf44847bSKun Wu       llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1324cf44847bSKun Wu   Type iType =
1325cf44847bSKun Wu       llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1326cf44847bSKun Wu   Type dType =
1327cf44847bSKun Wu       llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1328cc402de0SKun Wu   auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1329cc402de0SKun Wu   auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1330cc402de0SKun Wu   auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1331b700a90cSAart Bik   auto handle =
1332b700a90cSAart Bik       createCsrCallBuilder
1333b700a90cSAart Bik           .create(loc, rewriter,
1334b700a90cSAart Bik                   {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1335cc402de0SKun Wu                    pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1336b700a90cSAart Bik           .getResult();
1337b700a90cSAart Bik   rewriter.replaceOp(op, {handle, stream});
1338b700a90cSAart Bik   return success();
1339b700a90cSAart Bik }
1340b700a90cSAart Bik 
13418ed59c53SKun Wu LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
13428ed59c53SKun Wu     gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
13438ed59c53SKun Wu     ConversionPatternRewriter &rewriter) const {
13448ed59c53SKun Wu   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
13458ed59c53SKun Wu       failed(isAsyncWithOneDependency(rewriter, op)))
13468ed59c53SKun Wu     return failure();
13478ed59c53SKun Wu   Location loc = op.getLoc();
13488ed59c53SKun Wu   auto stream = adaptor.getAsyncDependencies().front();
13498ed59c53SKun Wu   Value pMat =
13508ed59c53SKun Wu       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
13518ed59c53SKun Wu   Type dType =
13528ed59c53SKun Wu       llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1353ac30f48eSKun Wu   auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
13548ed59c53SKun Wu 
1355ac30f48eSKun Wu   // CUDA runner asserts the size is 44104 bytes.
13568ed59c53SKun Wu   auto handleSz = rewriter.create<LLVM::ConstantOp>(
13578ed59c53SKun Wu       loc, getIndexType(), rewriter.getIndexAttr(44104));
135886eff489SAart Bik   Value handle = rewriter.create<LLVM::AllocaOp>(
1359dbd4a0ddSChristian Ulmann       loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
13608ed59c53SKun Wu   handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
13618ed59c53SKun Wu 
13628ed59c53SKun Wu   create2To4SpMatCallBuilder
13638ed59c53SKun Wu       .create(loc, rewriter,
1364be2dd22bSKun Wu               {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
13658ed59c53SKun Wu       .getResult();
13668ed59c53SKun Wu   rewriter.replaceOp(op, {handle, stream});
13678ed59c53SKun Wu   return success();
13688ed59c53SKun Wu }
13698ed59c53SKun Wu 
1370b700a90cSAart Bik LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1371b700a90cSAart Bik     gpu::DestroySpMatOp op, OpAdaptor adaptor,
1372b700a90cSAart Bik     ConversionPatternRewriter &rewriter) const {
1373b700a90cSAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1374b700a90cSAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
1375b700a90cSAart Bik     return failure();
1376b700a90cSAart Bik   Location loc = op.getLoc();
1377b700a90cSAart Bik   auto stream = adaptor.getAsyncDependencies().front();
13788ed59c53SKun Wu   // Use the cusparseLt destroy call if the spmat is 2:4 sparsity
13798ed59c53SKun Wu   if (is2To4Sparsity(op.getSpmat())) {
13808ed59c53SKun Wu     destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
13818ed59c53SKun Wu                                          {adaptor.getSpmat(), stream});
13828ed59c53SKun Wu 
13838ed59c53SKun Wu   } else {
1384b700a90cSAart Bik     destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
13858ed59c53SKun Wu   }
1386b700a90cSAart Bik   rewriter.replaceOp(op, {stream});
1387b700a90cSAart Bik   return success();
1388b700a90cSAart Bik }
1389b700a90cSAart Bik 
1390b700a90cSAart Bik LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1391b700a90cSAart Bik     gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1392b700a90cSAart Bik     ConversionPatternRewriter &rewriter) const {
1393b700a90cSAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1394b700a90cSAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
1395b700a90cSAart Bik     return failure();
1396b700a90cSAart Bik   Location loc = op.getLoc();
1397cc402de0SKun Wu   auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1398ac30f48eSKun Wu   auto computeType = genConstInt32From(
1399ac30f48eSKun Wu       rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1400b700a90cSAart Bik   auto stream = adaptor.getAsyncDependencies().front();
1401be2dd22bSKun Wu   auto bufferSize = spMVBufferSizeCallBuilder
1402b700a90cSAart Bik                         .create(loc, rewriter,
1403be2dd22bSKun Wu                                 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1404be2dd22bSKun Wu                                  adaptor.getDnY(), computeType, stream})
1405b700a90cSAart Bik                         .getResult();
1406b700a90cSAart Bik   rewriter.replaceOp(op, {bufferSize, stream});
1407b700a90cSAart Bik   return success();
1408b700a90cSAart Bik }
1409b700a90cSAart Bik 
1410b700a90cSAart Bik LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1411b700a90cSAart Bik     gpu::SpMVOp op, OpAdaptor adaptor,
1412b700a90cSAart Bik     ConversionPatternRewriter &rewriter) const {
1413b700a90cSAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1414b700a90cSAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
1415b700a90cSAart Bik     return failure();
1416b700a90cSAart Bik   Location loc = op.getLoc();
1417cc402de0SKun Wu   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1418ac30f48eSKun Wu   auto computeType = genConstInt32From(
1419ac30f48eSKun Wu       rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1420b700a90cSAart Bik   auto stream = adaptor.getAsyncDependencies().front();
1421b700a90cSAart Bik   Value pBuf =
1422b700a90cSAart Bik       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1423b700a90cSAart Bik   spMVCallBuilder.create(loc, rewriter,
1424be2dd22bSKun Wu                          {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1425be2dd22bSKun Wu                           adaptor.getDnY(), computeType, pBuf, stream});
1426b700a90cSAart Bik   rewriter.replaceOp(op, {stream});
1427b700a90cSAart Bik   return success();
1428b700a90cSAart Bik }
1429b700a90cSAart Bik 
1430981cf167SAart Bik LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1431981cf167SAart Bik     gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1432981cf167SAart Bik     ConversionPatternRewriter &rewriter) const {
1433981cf167SAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1434981cf167SAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
1435981cf167SAart Bik     return failure();
1436981cf167SAart Bik   Location loc = op.getLoc();
1437cc402de0SKun Wu   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1438cc402de0SKun Wu   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1439981cf167SAart Bik   auto stream = adaptor.getAsyncDependencies().front();
14408ed59c53SKun Wu   Value bufferSize;
14418ed59c53SKun Wu   if (is2To4Sparsity(op.getSpmatA())) {
14426ac80a76SMehdi Amini     auto pruneFlag =
14431e491c42SKun Wu         genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
1444ac30f48eSKun Wu     auto computeType = genConstInt32From(
1445ac30f48eSKun Wu         rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
14468ed59c53SKun Wu     auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
14478ed59c53SKun Wu                                                    rewriter.getIndexAttr(3));
144886eff489SAart Bik     auto bufferSize = rewriter.create<LLVM::AllocaOp>(
1449dbd4a0ddSChristian Ulmann         loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
1450ac30f48eSKun Wu     createCuSparseLtSpMMBufferSizeBuilder
14518ed59c53SKun Wu         .create(loc, rewriter,
1452be2dd22bSKun Wu                 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
14531e491c42SKun Wu                  adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
14546ac80a76SMehdi Amini                  pruneFlag, stream})
14558ed59c53SKun Wu         .getResult();
1456632ccc53SKun Wu 
1457632ccc53SKun Wu     auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
1458dbd4a0ddSChristian Ulmann         loc, llvmPointerType, llvmPointerType, bufferSize,
1459632ccc53SKun Wu         ValueRange{rewriter.create<LLVM::ConstantOp>(
1460632ccc53SKun Wu             loc, getIndexType(), rewriter.getIndexAttr(1))});
1461632ccc53SKun Wu     auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
1462dbd4a0ddSChristian Ulmann         loc, llvmPointerType, llvmPointerType, bufferSize,
1463632ccc53SKun Wu         ValueRange{rewriter.create<LLVM::ConstantOp>(
1464632ccc53SKun Wu             loc, getIndexType(), rewriter.getIndexAttr(2))});
1465632ccc53SKun Wu     auto bufferSize0 =
1466632ccc53SKun Wu         rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
1467632ccc53SKun Wu     auto bufferSize1 =
1468632ccc53SKun Wu         rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
1469632ccc53SKun Wu     auto bufferSize2 =
1470632ccc53SKun Wu         rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
1471632ccc53SKun Wu 
1472632ccc53SKun Wu     rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
14738ed59c53SKun Wu   } else {
1474ac30f48eSKun Wu     auto computeType = genConstInt32From(
1475ac30f48eSKun Wu         rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1476be2dd22bSKun Wu     bufferSize =
1477be2dd22bSKun Wu         createSpMMBufferSizeCallBuilder
1478981cf167SAart Bik             .create(loc, rewriter,
1479be2dd22bSKun Wu                     {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1480cc402de0SKun Wu                      adaptor.getDnmatC(), computeType, stream})
1481981cf167SAart Bik             .getResult();
1482981cf167SAart Bik     rewriter.replaceOp(op, {bufferSize, stream});
14838ed59c53SKun Wu   }
1484981cf167SAart Bik   return success();
1485981cf167SAart Bik }
1486981cf167SAart Bik 
1487cf44847bSKun Wu LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1488cf44847bSKun Wu     gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1489cf44847bSKun Wu     ConversionPatternRewriter &rewriter) const {
1490cf44847bSKun Wu   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1491cf44847bSKun Wu       failed(isAsyncWithOneDependency(rewriter, op)))
1492cf44847bSKun Wu     return failure();
1493cf44847bSKun Wu   Location loc = op.getLoc();
1494cc402de0SKun Wu   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1495cc402de0SKun Wu   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1496ac30f48eSKun Wu   auto computeType = genConstInt32From(
1497ac30f48eSKun Wu       rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1498cf44847bSKun Wu   auto stream = adaptor.getAsyncDependencies().front();
1499be2dd22bSKun Wu   auto bufferSize =
1500be2dd22bSKun Wu       createSDDMMBufferSizeCallBuilder
1501cf44847bSKun Wu           .create(loc, rewriter,
1502be2dd22bSKun Wu                   {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1503cc402de0SKun Wu                    adaptor.getSpmatC(), computeType, stream})
1504cf44847bSKun Wu           .getResult();
1505cf44847bSKun Wu   rewriter.replaceOp(op, {bufferSize, stream});
1506cf44847bSKun Wu   return success();
1507cf44847bSKun Wu }
1508cf44847bSKun Wu 
1509981cf167SAart Bik LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1510981cf167SAart Bik     gpu::SpMMOp op, OpAdaptor adaptor,
1511981cf167SAart Bik     ConversionPatternRewriter &rewriter) const {
1512981cf167SAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1513981cf167SAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
1514981cf167SAart Bik     return failure();
1515981cf167SAart Bik   Location loc = op.getLoc();
1516cc402de0SKun Wu   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1517cc402de0SKun Wu   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1518ac30f48eSKun Wu   auto computeType = genConstInt32From(
1519ac30f48eSKun Wu       rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1520cc402de0SKun Wu 
1521981cf167SAart Bik   auto stream = adaptor.getAsyncDependencies().front();
15228ed59c53SKun Wu 
15238ed59c53SKun Wu   // Lower to cusparseLt if applicable
15248ed59c53SKun Wu   if (is2To4Sparsity(op.getSpmatA())) {
15258ed59c53SKun Wu     SmallVector<Value> pBufs;
15268ed59c53SKun Wu     for (Value buffer : adaptor.getBuffers()) {
15278ed59c53SKun Wu       Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
15288ed59c53SKun Wu       pBufs.push_back(pBuf);
15298ed59c53SKun Wu     }
1530ac30f48eSKun Wu     createCuSparseLtSpMMBuilder.create(
1531ac30f48eSKun Wu         loc, rewriter,
1532be2dd22bSKun Wu         {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1533be2dd22bSKun Wu          pBufs[0], pBufs[1], pBufs[2], stream});
15348ed59c53SKun Wu   } else {
15358ed59c53SKun Wu     Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
15368ed59c53SKun Wu                      .allocatedPtr(rewriter, loc);
1537be2dd22bSKun Wu     createSpMMCallBuilder.create(loc, rewriter,
1538be2dd22bSKun Wu                                  {modeA, modeB, adaptor.getSpmatA(),
1539be2dd22bSKun Wu                                   adaptor.getDnmatB(), adaptor.getDnmatC(),
1540be2dd22bSKun Wu                                   computeType, pBuf, stream});
15418ed59c53SKun Wu   }
1542981cf167SAart Bik   rewriter.replaceOp(op, {stream});
1543981cf167SAart Bik   return success();
1544981cf167SAart Bik }
1545981cf167SAart Bik 
154686bf710cSKun Wu template <typename T>
154786bf710cSKun Wu static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
154886bf710cSKun Wu   converter.addConversion([&converter](T) -> Type {
1549dbd4a0ddSChristian Ulmann     return LLVM::LLVMPointerType::get(&converter.getContext());
155086bf710cSKun Wu   });
155186bf710cSKun Wu }
155286bf710cSKun Wu 
1553cf44847bSKun Wu LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1554cf44847bSKun Wu     gpu::SDDMMOp op, OpAdaptor adaptor,
1555cf44847bSKun Wu     ConversionPatternRewriter &rewriter) const {
1556cf44847bSKun Wu   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1557cf44847bSKun Wu       failed(isAsyncWithOneDependency(rewriter, op)))
1558cf44847bSKun Wu     return failure();
1559cf44847bSKun Wu   Location loc = op.getLoc();
1560ac30f48eSKun Wu   auto computeType = genConstInt32From(
1561ac30f48eSKun Wu       rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1562cc402de0SKun Wu   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1563cc402de0SKun Wu   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1564cf44847bSKun Wu   auto stream = adaptor.getAsyncDependencies().front();
1565cf44847bSKun Wu   Value pBuf =
1566cf44847bSKun Wu       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1567be2dd22bSKun Wu   createSDDMMCallBuilder.create(loc, rewriter,
1568be2dd22bSKun Wu                                 {modeA, modeB, adaptor.getDnmatA(),
1569be2dd22bSKun Wu                                  adaptor.getDnmatB(), adaptor.getSpmatC(),
1570be2dd22bSKun Wu                                  computeType, pBuf, stream});
1571cf44847bSKun Wu   rewriter.replaceOp(op, {stream});
1572cf44847bSKun Wu   return success();
1573cf44847bSKun Wu }
1574cf44847bSKun Wu 
1575dfe29429SKun Wu LogicalResult
1576dfe29429SKun Wu ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1577dfe29429SKun Wu     gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1578dfe29429SKun Wu     ConversionPatternRewriter &rewriter) const {
1579dfe29429SKun Wu   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1580dfe29429SKun Wu       failed(isAsyncWithOneDependency(rewriter, op)))
1581dfe29429SKun Wu     return failure();
1582dfe29429SKun Wu   Location loc = op.getLoc();
1583dfe29429SKun Wu   auto stream = adaptor.getAsyncDependencies().front();
1584dfe29429SKun Wu   Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1585dfe29429SKun Wu                     .getResult();
1586dfe29429SKun Wu   rewriter.replaceOp(op, {descr, stream});
1587dfe29429SKun Wu   return success();
1588dfe29429SKun Wu }
1589dfe29429SKun Wu 
1590dfe29429SKun Wu LogicalResult
1591dfe29429SKun Wu ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1592dfe29429SKun Wu     gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1593dfe29429SKun Wu     ConversionPatternRewriter &rewriter) const {
1594dfe29429SKun Wu   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1595dfe29429SKun Wu       failed(isAsyncWithOneDependency(rewriter, op)))
1596dfe29429SKun Wu     return failure();
1597dfe29429SKun Wu   Location loc = op.getLoc();
1598dfe29429SKun Wu   auto stream = adaptor.getAsyncDependencies().front();
159995a6c509SAart Bik   createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
160095a6c509SAart Bik                                          {adaptor.getDesc(), stream});
1601dfe29429SKun Wu   rewriter.replaceOp(op, {stream});
1602dfe29429SKun Wu   return success();
1603dfe29429SKun Wu }
1604dfe29429SKun Wu 
1605dfe29429SKun Wu LogicalResult
1606dfe29429SKun Wu ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1607dfe29429SKun Wu     gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1608dfe29429SKun Wu     ConversionPatternRewriter &rewriter) const {
1609dfe29429SKun Wu   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1610dfe29429SKun Wu       failed(isAsyncWithOneDependency(rewriter, op)))
1611dfe29429SKun Wu     return failure();
1612dfe29429SKun Wu   Location loc = op.getLoc();
1613dfe29429SKun Wu   auto computeType = genConstInt32From(
1614dfe29429SKun Wu       rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1615dfe29429SKun Wu   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1616dfe29429SKun Wu   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1617dfe29429SKun Wu   auto stream = adaptor.getAsyncDependencies().front();
1618dfe29429SKun Wu 
1619dfe29429SKun Wu   Value pBuf =
1620dfe29429SKun Wu       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1621dfe29429SKun Wu   Value bufferSizeNew;
1622dfe29429SKun Wu 
1623dfe29429SKun Wu   if (adaptor.getKind() ==
1624dfe29429SKun Wu       gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1625dfe29429SKun Wu     bufferSizeNew =
1626dfe29429SKun Wu         createSpGEMMWorkEstimationBuilder
1627dfe29429SKun Wu             .create(loc, rewriter,
1628dfe29429SKun Wu                     {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1629e7e4ed0dSAart Bik                      adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1630dfe29429SKun Wu                      adaptor.getBufferSz(), pBuf, stream})
1631dfe29429SKun Wu             .getResult();
1632dfe29429SKun Wu   } else {
1633dfe29429SKun Wu     bufferSizeNew =
1634dfe29429SKun Wu         createSpGEMMComputeBuilder
1635dfe29429SKun Wu             .create(loc, rewriter,
1636dfe29429SKun Wu                     {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1637e7e4ed0dSAart Bik                      adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1638dfe29429SKun Wu                      adaptor.getBufferSz(), pBuf, stream})
1639dfe29429SKun Wu             .getResult();
1640dfe29429SKun Wu   }
1641dfe29429SKun Wu   rewriter.replaceOp(op, {bufferSizeNew, stream});
1642dfe29429SKun Wu   return success();
1643dfe29429SKun Wu }
1644dfe29429SKun Wu 
1645dfe29429SKun Wu LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1646dfe29429SKun Wu     gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1647dfe29429SKun Wu     ConversionPatternRewriter &rewriter) const {
1648dfe29429SKun Wu   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1649dfe29429SKun Wu       failed(isAsyncWithOneDependency(rewriter, op)))
1650dfe29429SKun Wu     return failure();
1651dfe29429SKun Wu   Location loc = op.getLoc();
1652dfe29429SKun Wu   auto computeType = genConstInt32From(
1653dfe29429SKun Wu       rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1654dfe29429SKun Wu   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1655dfe29429SKun Wu   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1656dfe29429SKun Wu   auto stream = adaptor.getAsyncDependencies().front();
1657e7e4ed0dSAart Bik   createSpGEMMCopyBuilder.create(loc, rewriter,
1658e7e4ed0dSAart Bik                                  {adaptor.getDesc(), modeA, modeB,
1659e7e4ed0dSAart Bik                                   adaptor.getSpmatA(), adaptor.getSpmatB(),
1660e7e4ed0dSAart Bik                                   adaptor.getSpmatC(), computeType, stream});
1661dfe29429SKun Wu   rewriter.replaceOp(op, {stream});
1662dfe29429SKun Wu   return success();
1663dfe29429SKun Wu }
1664dfe29429SKun Wu 
1665289f7231SAart Bik LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1666289f7231SAart Bik     gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1667dfe29429SKun Wu     ConversionPatternRewriter &rewriter) const {
1668dfe29429SKun Wu   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1669dfe29429SKun Wu       failed(isAsyncWithOneDependency(rewriter, op)))
1670dfe29429SKun Wu     return failure();
1671dfe29429SKun Wu   Location loc = op.getLoc();
1672dfe29429SKun Wu   auto stream = adaptor.getAsyncDependencies().front();
1673dfe29429SKun Wu 
1674dfe29429SKun Wu   auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1675dfe29429SKun Wu                                                  rewriter.getIndexAttr(3));
1676dfe29429SKun Wu   auto buffer = rewriter.create<LLVM::AllocaOp>(
1677dbd4a0ddSChristian Ulmann       loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
1678dfe29429SKun Wu 
1679dfe29429SKun Wu   auto rowsPtr = rewriter.create<LLVM::GEPOp>(
1680dbd4a0ddSChristian Ulmann       loc, llvmPointerType, llvmPointerType, buffer,
1681dfe29429SKun Wu       ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1682dfe29429SKun Wu                                                    rewriter.getIndexAttr(0))});
1683dfe29429SKun Wu   auto colsPtr = rewriter.create<LLVM::GEPOp>(
1684dbd4a0ddSChristian Ulmann       loc, llvmPointerType, llvmPointerType, buffer,
1685dfe29429SKun Wu       ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1686dfe29429SKun Wu                                                    rewriter.getIndexAttr(1))});
1687dfe29429SKun Wu   auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
1688dbd4a0ddSChristian Ulmann       loc, llvmPointerType, llvmPointerType, buffer,
1689dfe29429SKun Wu       ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1690dfe29429SKun Wu                                                    rewriter.getIndexAttr(2))});
1691289f7231SAart Bik   createSpMatGetSizeBuilder.create(
1692dfe29429SKun Wu       loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1693dfe29429SKun Wu   auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
1694dfe29429SKun Wu   auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
1695dfe29429SKun Wu   auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
1696dfe29429SKun Wu 
1697dfe29429SKun Wu   rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1698dfe29429SKun Wu   return success();
1699dfe29429SKun Wu }
1700dfe29429SKun Wu 
170195a6c509SAart Bik LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
170295a6c509SAart Bik     gpu::SetCsrPointersOp op, OpAdaptor adaptor,
170395a6c509SAart Bik     ConversionPatternRewriter &rewriter) const {
170495a6c509SAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
170595a6c509SAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
170695a6c509SAart Bik     return failure();
170795a6c509SAart Bik   Location loc = op.getLoc();
170895a6c509SAart Bik   auto stream = adaptor.getAsyncDependencies().front();
170995a6c509SAart Bik   Value pPos =
171095a6c509SAart Bik       MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
171195a6c509SAart Bik   Value pCrd =
171295a6c509SAart Bik       MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
171395a6c509SAart Bik   Value pVal =
171495a6c509SAart Bik       MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
171595a6c509SAart Bik   createSetCsrPointersBuilder.create(
171695a6c509SAart Bik       loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
171795a6c509SAart Bik   rewriter.replaceOp(op, {stream});
171895a6c509SAart Bik   return success();
171995a6c509SAart Bik }
172095a6c509SAart Bik 
172139038177SAart Bik LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
172239038177SAart Bik     gpu::CreateCscOp op, OpAdaptor adaptor,
172339038177SAart Bik     ConversionPatternRewriter &rewriter) const {
172439038177SAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
172539038177SAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
172639038177SAart Bik     return failure();
172739038177SAart Bik   Location loc = op.getLoc();
172839038177SAart Bik   auto stream = adaptor.getAsyncDependencies().front();
172939038177SAart Bik   Value pColPos =
173039038177SAart Bik       MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
173139038177SAart Bik   Value pRowIdxs =
173239038177SAart Bik       MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
173339038177SAart Bik   Value pValues =
173439038177SAart Bik       MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
173539038177SAart Bik   Type pType =
173639038177SAart Bik       llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
173739038177SAart Bik   Type iType =
173839038177SAart Bik       llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
173939038177SAart Bik   Type dType =
174039038177SAart Bik       llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
174139038177SAart Bik   auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
174239038177SAart Bik   auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
174339038177SAart Bik   auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
174439038177SAart Bik   auto handle =
174539038177SAart Bik       createCscCallBuilder
174639038177SAart Bik           .create(loc, rewriter,
174739038177SAart Bik                   {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
174839038177SAart Bik                    pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
174939038177SAart Bik           .getResult();
175039038177SAart Bik   rewriter.replaceOp(op, {handle, stream});
175139038177SAart Bik   return success();
175239038177SAart Bik }
175339038177SAart Bik 
175439038177SAart Bik LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
175539038177SAart Bik     gpu::CreateBsrOp op, OpAdaptor adaptor,
175639038177SAart Bik     ConversionPatternRewriter &rewriter) const {
175739038177SAart Bik   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
175839038177SAart Bik       failed(isAsyncWithOneDependency(rewriter, op)))
175939038177SAart Bik     return failure();
176039038177SAart Bik   Location loc = op.getLoc();
176139038177SAart Bik   auto stream = adaptor.getAsyncDependencies().front();
176239038177SAart Bik   Value pRowPos =
176339038177SAart Bik       MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
176439038177SAart Bik   Value pColIdxs =
176539038177SAart Bik       MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
176639038177SAart Bik   Value pValues =
176739038177SAart Bik       MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
176839038177SAart Bik   Type pType =
176939038177SAart Bik       llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
177039038177SAart Bik   Type iType =
177139038177SAart Bik       llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
177239038177SAart Bik   Type dType =
177339038177SAart Bik       llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
177439038177SAart Bik   auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
177539038177SAart Bik   auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
177639038177SAart Bik   auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
177739038177SAart Bik   auto handle =
177839038177SAart Bik       createBsrCallBuilder
177939038177SAart Bik           .create(loc, rewriter,
178039038177SAart Bik                   {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
178139038177SAart Bik                    adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
178239038177SAart Bik                    pColIdxs, pValues, ptp, itp, dtp, stream})
178339038177SAart Bik           .getResult();
178439038177SAart Bik   rewriter.replaceOp(op, {handle, stream});
178539038177SAart Bik   return success();
178639038177SAart Bik }
178739038177SAart Bik 
1788*733be4edSAndrea Faulds void mlir::populateGpuToLLVMConversionPatterns(
1789*733be4edSAndrea Faulds     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1790*733be4edSAndrea Faulds     bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
179186bf710cSKun Wu   addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
179297f4c22bSKun Wu   addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
179386bf710cSKun Wu   addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1794dfe29429SKun Wu   addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1795b700a90cSAart Bik 
17967d855605SButygin   patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
17977d855605SButygin                ConvertDeallocOpToGpuRuntimeCallPattern,
17987d855605SButygin                ConvertHostRegisterOpToGpuRuntimeCallPattern,
17998f7c8a6eSmax                ConvertHostUnregisterOpToGpuRuntimeCallPattern,
18007d855605SButygin                ConvertMemcpyOpToGpuRuntimeCallPattern,
1801361458b1SLoren Maggiore                ConvertMemsetOpToGpuRuntimeCallPattern,
180284718d37SKrzysztof Drewniak                ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
18037d855605SButygin                ConvertWaitAsyncOpToGpuRuntimeCallPattern,
18047d855605SButygin                ConvertWaitOpToGpuRuntimeCallPattern,
1805b700a90cSAart Bik                ConvertAsyncYieldToGpuRuntimeCallPattern,
180697f4c22bSKun Wu                ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
180797f4c22bSKun Wu                ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1808b700a90cSAart Bik                ConvertCreateCooOpToGpuRuntimeCallPattern,
18099fc02a7aSAart Bik                ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1810b700a90cSAart Bik                ConvertCreateCsrOpToGpuRuntimeCallPattern,
181139038177SAart Bik                ConvertCreateCscOpToGpuRuntimeCallPattern,
181239038177SAart Bik                ConvertCreateBsrOpToGpuRuntimeCallPattern,
18138ed59c53SKun Wu                ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1814b700a90cSAart Bik                ConvertDestroySpMatOpToGpuRuntimeCallPattern,
18159dfd3c32SAart Bik                ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
18169dfd3c32SAart Bik                ConvertSpMVOpToGpuRuntimeCallPattern,
18179dfd3c32SAart Bik                ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
18189dfd3c32SAart Bik                ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
18199dfd3c32SAart Bik                ConvertSpMMOpToGpuRuntimeCallPattern,
18209dfd3c32SAart Bik                ConvertSDDMMOpToGpuRuntimeCallPattern,
1821dfe29429SKun Wu                ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1822dfe29429SKun Wu                ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1823dfe29429SKun Wu                ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1824dfe29429SKun Wu                ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1825289f7231SAart Bik                ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
182695a6c509SAart Bik                ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
182799a562b3SAndrea Faulds   patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1828*733be4edSAndrea Faulds                                             kernelIntersperseSizeCallConv);
18297d855605SButygin }
18307498eaa9SFabian Mora 
18317498eaa9SFabian Mora //===----------------------------------------------------------------------===//
18327498eaa9SFabian Mora // GPUModuleOp convert to LLVM op interface
18337498eaa9SFabian Mora //===----------------------------------------------------------------------===//
18347498eaa9SFabian Mora 
18357498eaa9SFabian Mora namespace {
18367498eaa9SFabian Mora struct GPUModuleOpConvertToLLVMInterface
18377498eaa9SFabian Mora     : public ConvertToLLVMOpInterface::ExternalModel<
18387498eaa9SFabian Mora           GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
18397498eaa9SFabian Mora   /// Get the conversion patterns from the target attribute.
18407498eaa9SFabian Mora   void getConvertToLLVMConversionAttrs(
18417498eaa9SFabian Mora       Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
18427498eaa9SFabian Mora };
18437498eaa9SFabian Mora } // namespace
18447498eaa9SFabian Mora 
18457498eaa9SFabian Mora void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
18467498eaa9SFabian Mora     Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
18477498eaa9SFabian Mora   auto module = cast<gpu::GPUModuleOp>(op);
18487498eaa9SFabian Mora   ArrayAttr targetsAttr = module.getTargetsAttr();
18497498eaa9SFabian Mora   // Fail if there are no target attributes or there is more than one target.
18507498eaa9SFabian Mora   if (!targetsAttr || targetsAttr.size() != 1)
18517498eaa9SFabian Mora     return;
18527498eaa9SFabian Mora   if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
18537498eaa9SFabian Mora     attrs.push_back(patternAttr);
18547498eaa9SFabian Mora }
18557498eaa9SFabian Mora 
18567498eaa9SFabian Mora void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
18577498eaa9SFabian Mora   registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
18587498eaa9SFabian Mora     gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
18597498eaa9SFabian Mora   });
18607498eaa9SFabian Mora }
1861