//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the gpu.launch_func op into a
// sequence of GPU runtime calls. As most GPU runtimes do not have a stable
// published ABI, this pass uses a slim runtime layer that builds on top of
// the public API from GPU runtime headers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
class GpuToLLVMConversionPass
    : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  using Base::Base;
  void getDependentDialects(DialectRegistry &registry) const final {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }
  // Run the dialect converter on the module.
  void runOnOperation() override;
};
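// Note: a minimal invocation sketch, assuming the standard pass registration
// generated from Passes.td:
//
//   mlir-opt --gpu-to-llvm input.mlir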

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    Type indexType = ConvertToLLVMPattern::getIndexType();
    return type.hasStaticShape()
               ? ConvertToLLVMPattern::createIndexAttrConstant(
                     rewriter, loc, indexType, type.getNumElements())
               // For identity maps (verified by caller), the number of
               // elements is stride[0] * size[0].
               : rewriter.create<LLVM::MulOp>(loc,
                                              desc.stride(rewriter, loc, 0),
                                              desc.size(rewriter, loc, 0));
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt16Type = IntegerType::get(context, 16);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmFloat32Type = Float32Type::get(context);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder hostUnregisterCallBuilder = {
      "mgpuMemHostUnregisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */,
       llvmInt8Type /* bool isHostShared */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset16CallBuilder = {
      "mgpuMemset16",
      llvmVoidType,
      {llvmPointerType /* void *dst */,
       llvmInt16Type /* unsigned short value */,
       llvmIntPtrType /* intptr_t count (elements, not bytes) */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset32CallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t count (elements, not bytes) */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder setDefaultDeviceCallBuilder = {
      "mgpuSetDefaultDevice",
      llvmVoidType,
      {llvmInt32Type /* uint32_t devIndex */}};
  FunctionCallBuilder createDnVecCallBuilder = {
      "mgpuCreateDnVec",
      llvmPointerType,
      {llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnVecCallBuilder = {
      "mgpuDestroyDnVec",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createDnMatCallBuilder = {
      "mgpuCreateDnMat",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnMatCallBuilder = {
      "mgpuDestroyDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooCallBuilder = {
      "mgpuCreateCoo",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooAoSCallBuilder = {
      "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCsrCallBuilder = {
      "mgpuCreateCsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCscCallBuilder = {
      "mgpuCreateCsc",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createBsrCallBuilder = {
      "mgpuCreateBsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
       llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroySpMatCallBuilder = {
      "mgpuDestroySpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVBufferSizeCallBuilder = {
      "mgpuSpMVBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVCallBuilder = {
      "mgpuSpMV",
      llvmVoidType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
      "mgpuSpMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMCallBuilder = {
      "mgpuSpMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
      "mgpuSDDMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMCallBuilder = {
      "mgpuSDDMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createLtDnMatCallBuilder = {
      "mgpuCreateCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
      "mgpuDestroyCuSparseLtSpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
      "mgpuDestroyCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder create2To4SpMatCallBuilder = {
      "mgpuCusparseLtCreate2To4SpMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
      "mgpuCuSparseLtSpMMBufferSize",
      llvmVoidType,
      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
      "mgpuCuSparseLtSpMM",
      llvmVoidType,
      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
      "mgpuSpGEMMCreateDescr",
      llvmPointerType,
      {llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
      "mgpuSpGEMMDestroyDescr",
      llvmVoidType,
      {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
      "mgpuSpGEMMWorkEstimation",
      llvmIntPtrType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMComputeBuilder = {
      "mgpuSpGEMMCompute",
      llvmIntPtrType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMCopyBuilder = {
      "mgpuSpGEMMCopy",
      llvmVoidType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpMatGetSizeBuilder = {
      "mgpuSpMatGetSize",
      llvmVoidType,
      {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
       llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSetCsrPointersBuilder = {
      "mgpuSetCsrPointers",
      llvmVoidType,
      {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
       llvmPointerType /*crd*/, llvmPointerType /*val*/,
       llvmPointerType /*void *stream*/}};
};

/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertHostUnregisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
public:
  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
  }

private:
  LogicalResult
  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to legalize gpu.launch_func with LLVM types.
class LegalizeLaunchFuncOpPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
                              bool kernelBarePtrCallConv,
                              bool kernelIntersperseSizeCallConv)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        kernelBarePtrCallConv(kernelBarePtrCallConv),
        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}

private:
  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  bool kernelBarePtrCallConv;
  bool kernelIntersperseSizeCallConv;
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
/// Currently supports CUDA and ROCm (HIP).
class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
            typeConverter) {}

  LogicalResult
  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Generic rewriting rule for operations on sparse matrices.
/// Currently supports CUDA (by means of cuSparse and cuSparseLt).
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)                \
  class Convert##op_name##ToGpuRuntimeCallPattern                              \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {                \
  public:                                                                      \
    Convert##op_name##ToGpuRuntimeCallPattern(                                 \
        const LLVMTypeConverter &typeConverter)                                \
        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {}     \
                                                                               \
  private:                                                                     \
    LogicalResult                                                              \
    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                        \
                    ConversionPatternRewriter &rewriter) const override;       \
  };

DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)

} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  MLIRContext *context = &getContext();

  // Perform progressive lowering of vector transfer operations.
  {
    RewritePatternSet patterns(&getContext());
    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
    vector::populateVectorTransferLoweringPatterns(patterns,
                                                   /*maxTransferRank=*/1);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }

  LowerToLLVMOptions options(context);
  options.useBarePtrCallConv = hostBarePtrCallConv;
  RewritePatternSet patterns(context);
  ConversionTarget target(*context);
  target.addLegalDialect<LLVM::LLVMDialect>();
  LLVMTypeConverter converter(context, options);

  // Populate all patterns from all dialects that implement the
  // `ConvertToLLVMPatternInterface` interface.
  for (Dialect *dialect : context->getLoadedDialects()) {
    auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
    if (!iface)
      continue;
    iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
  }

  // Preserve GPU modules and binaries. Modules are preserved as they can be
  // converted later by `gpu-module-to-binary`.
  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
  // Accept gpu.launch_func ops as legal once their operands have been lowered.
  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
      [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });

  // These aren't covered by the ConvertToLLVMPatternInterface right now.
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns,
                                      kernelBarePtrCallConv,
                                      kernelIntersperseSizeCallConv);

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(loc, function, arguments);
}
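
// For example, a call built through one of the `FunctionCallBuilder`s above
// lowers to IR roughly like the following sketch (the llvm.func declaration
// is only created on first use):
//
//   llvm.func @mgpuStreamCreate() -> !llvm.ptr
//   ...
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr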

// Corresponding to cusparseIndexType_t defined in cusparse.h.
static int32_t getCuSparseIndexTypeFrom(Type type) {
  if (type.isInteger(16))
    return 1; // CUSPARSE_INDEX_16U
  if (type.isInteger(32))
    return 2; // CUSPARSE_INDEX_32I
  return 3;   // CUSPARSE_INDEX_64I
}

static int32_t getCuSparseLtDataTypeFrom(Type type) {
  if (type.isF16())
    return 0; // CUSPARSE_COMPUTE_16F
  if (type.isInteger(32))
    return 1; // CUSPARSE_COMPUTE_32I
  // TODO: add support for TF32.
  llvm_unreachable("unsupported type");
}

// Corresponding to cudaDataType_t defined in CUDA library_types.h.
static int32_t getCuSparseDataTypeFrom(Type type) {
  if (llvm::isa<ComplexType>(type)) {
    // Get the element type.
    auto elementType = cast<ComplexType>(type).getElementType();
    if (elementType.isBF16())
      return 15; // CUDA_C_16BF
    if (elementType.isF16())
      return 6; // CUDA_C_16F
    if (elementType.isF32())
      return 4; // CUDA_C_32F
    if (elementType.isF64())
      return 5; // CUDA_C_64F
    if (elementType.isInteger(8))
      return 7; // CUDA_C_8I
    if (elementType.isInteger(16))
      return 21; // CUDA_C_16I
    if (elementType.isInteger(32))
      return 11; // CUDA_C_32I
  }
  if (type.isBF16())
    return 14; // CUDA_R_16BF
  if (type.isF16())
    return 2; // CUDA_R_16F
  if (type.isF32())
    return 0; // CUDA_R_32F
  if (type.isF64())
    return 1; // CUDA_R_64F
  if (type.isInteger(8))
    return 3; // CUDA_R_8I
  if (type.isInteger(16))
    return 20; // CUDA_R_16I
  if (type.isInteger(32))
    return 10; // CUDA_R_32I

  llvm_unreachable("unsupported element type");
}

static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
}

// TODO: We may want a compile-time (in the MLIR compiler) warning or
// disablement: cusparseLt currently does not work for CUDA architectures
// < 8.0 and triggers a runtime error in the CUDA program. When lowering the
// GPU sparse dialect to LLVM calls, it would be good to at least emit a
// warning when the target architecture is < 8.0 and the user still requests
// cusparseLt, or to disable the cusparseLt calls for such architectures.
static bool is2To4Sparsity(Value spMat) {
  if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
    return true;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
    return false;
  // Print the defining op of spMat before aborting.
  spMat.getDefiningOp()->print(llvm::errs());
  llvm_unreachable("cannot find spmat def");
}

static bool isSpMMCusparseLtOp(Value op) {
  for (Operation *user : op.getUsers()) {
    auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
    if (!spmmOp)
      continue;
    // If the A operand has 2:4 (50%) sparsity, we should use cusparseLt.
    if (is2To4Sparsity(spmmOp.getSpmatA()))
      return true;
  }
  return false;
}

// Returns success if all operands are of LLVM-compatible type; otherwise
// notifies a match failure.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}

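// Lowering sketch for gpu.host_register: the unranked-memref operand is
// promoted to a (rank, descriptor-pointer) pair and the element size is
// appended, yielding roughly:
//
//   llvm.call @mgpuMemHostRegisterMemRef(%rank, %descPtr, %elemSizeBytes)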
LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Operation *op = hostUnregisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostUnregisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostUnregisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

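// Lowering sketch for gpu.alloc, e.g.
//
//   %memref, %t = gpu.alloc async [%dep] (%n) : memref<?xf32>
//
// becomes roughly (descriptor construction omitted):
//
//   %ptr = llvm.call @mgpuMemAlloc(%sizeBytes, %stream, %isHostShared)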
LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType))
    return failure();

  auto loc = allocOp.getLoc();

  bool isShared = allocOp.getHostShared();

  if (isShared && allocOp.getAsyncToken())
    return rewriter.notifyMatchFailure(
        allocOp, "Host Shared allocation cannot be done async");
  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
  Value stream = adaptor.getAsyncDependencies().empty()
                     ? nullPtr
                     : adaptor.getAsyncDependencies().front();

  auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));

  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
          .getResult();

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  if (allocOp.getAsyncToken()) {
    // Async alloc: make dependent ops use the same stream.
    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
  } else {
    rewriter.replaceOp(allocOp, {memRefDescriptor});
  }

  return success();
}

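// Lowering sketch for gpu.dealloc: the allocated pointer is extracted from
// the MemRef descriptor and freed on the stream, roughly:
//
//   llvm.call @mgpuMemFree(%allocatedPtr, %stream)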
LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Value stream = adaptor.getAsyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {pointer, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return isa<gpu::AsyncTokenType>(value.getType());
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token operands are lowered to streams within the async.execute
// region, but are passed between regions as events. For each !gpu.async.token
// operand, we create an event and record it on the stream.
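// For example, roughly:
//
//   async.yield %stream : !gpu.async.token
//
// becomes:
//
//   %event = llvm.call @mgpuEventCreate() : () -> !llvm.ptr
//   llvm.call @mgpuEventRecord(%event, %stream)
//   llvm.call @mgpuStreamDestroy(%stream)
//   async.yield %event : !llvm.ptr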
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(isa<LLVM::LLVMPointerType>(value.getType()));
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return *defOp.getCallee() == functionName;
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the
// host with the stream/event operands, which are then destroyed. That is, it
// assumes that they are not used afterwards or elsewhere; otherwise we would
// get a runtime error. Eventually, we should guarantee this property.
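// For example, a wait on a stream-defining token lowers roughly to:
//
//   llvm.call @mgpuStreamSynchronize(%stream)
//   llvm.call @mgpuStreamDestroy(%stream)
//
// and an event operand to the analogous mgpuEventSynchronize/Destroy pair.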
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands, which are then
// destroyed. That is, it assumes that they are not used afterwards or
// elsewhere; otherwise we would get a runtime error. Eventually, we should
// guarantee this property.
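// For example, roughly:
//
//   %stream = llvm.call @mgpuStreamCreate()
//   llvm.call @mgpuStreamWaitEvent(%stream, %event)
//   llvm.call @mgpuEventDestroy(%event)
//
// with one wait/destroy pair per operand (stream operands are first wrapped
// in a recorded event).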
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}

// Legalize the op's operands.
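// For example, with `kernelIntersperseSizeCallConv` and a bare-pointer
// memref<16xf32> kernel operand, the rewritten launch passes (a sketch):
//
//   gpu.launch_func @mod::@kernel ... args(%ptr : !llvm.ptr, %c64 : i64)
//
// where %c64 is the static buffer size in bytes (16 elements * 4 bytes).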
LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  if (launchOp.getAsyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  Value stream = Value();
  if (!adaptor.getAsyncDependencies().empty())
    stream = adaptor.getAsyncDependencies().front();
  // If the async keyword is present and there are no dependencies, then a
  // stream must be created to pass to subsequent operations.
  else if (launchOp.getAsyncToken())
    stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();

  // Lower the kernel operands to match kernel parameters.
  // Note: If `useBarePtrCallConv` is set in the type converter's options,
  // the value of `kernelBarePtrCallConv` will be ignored.
  OperandRange origArguments = launchOp.getKernelOperands();
  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
      loc, origArguments, adaptor.getKernelOperands(), rewriter,
      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
  SmallVector<Value, 8> llvmArgumentsWithSizes;

  // Intersperse size information if requested.
  if (kernelIntersperseSizeCallConv) {
    if (origArguments.size() != llvmArguments.size()) {
      // This shouldn't happen if the bare-pointer calling convention is used.
      return rewriter.notifyMatchFailure(
          launchOp,
          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
    }

    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
    for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
      if (!memrefTy) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref.");
      }

      if (!memrefTy.hasStaticShape() ||
          !memrefTy.getElementType().isIntOrFloat()) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a static "
                      "shape and an integer or float element type.");
      }

      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
      if (bitwidth % 8 != 0) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a "
                      "byte-aligned element type.");
      }

      uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
                            static_cast<uint64_t>(memrefTy.getNumElements());

      Value sizeArg = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(staticSize));
      llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
      llvmArgumentsWithSizes.push_back(sizeArg);
    }
  }

  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
  if (launchOp.hasClusterSize()) {
    clusterSize =
        gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
                        adaptor.getClusterSizeZ()};
  }
  rewriter.create<gpu::LaunchFuncOp>(
      launchOp.getLoc(), launchOp.getKernelAttr(),
      gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                      adaptor.getGridSizeZ()},
      gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                      adaptor.getBlockSizeZ()},
      adaptor.getDynamicSharedMemorySize(),
      llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
      stream, clusterSize);
  if (launchOp.getAsyncToken())
    rewriter.replaceOp(launchOp, {stream});
  else
    rewriter.eraseOp(launchOp);
  return success();
}

static Value bitAndAddrspaceCast(Location loc,
                                 ConversionPatternRewriter &rewriter,
                                 LLVM::LLVMPointerType destinationType,
                                 Value sourcePtr,
                                 const LLVMTypeConverter &typeConverter) {
  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
    sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
        loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(),
                                   destinationType.getAddressSpace()),
        sourcePtr);
  return sourcePtr;
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());

  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.getSrc());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

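  // Compute the size in bytes via the "GEP off a null pointer" idiom: the
  // address of element #numElements relative to a null base, converted with
  // ptrtoint, equals numElements * sizeof(element).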
1077   Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
1078   Value gepPtr = rewriter.create<LLVM::GEPOp>(
1079       loc, elementPtrType,
1080       typeConverter->convertType(memRefType.getElementType()), nullPtr,
1081       numElements);
1082   auto sizeBytes =
1083       rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
1084 
1085   auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1086                                  srcDesc.alignedPtr(rewriter, loc),
1087                                  *getTypeConverter());
1088   auto dst = bitAndAddrspaceCast(
1089       loc, rewriter, llvmPointerType,
1090       MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1091       *getTypeConverter());
1092 
1093   auto stream = adaptor.getAsyncDependencies().front();
1094   memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1095 
1096   rewriter.replaceOp(memcpyOp, {stream});
1097 
1098   return success();
1099 }
1100 
LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());

  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  // Only ints and floats of 16 or 32 bit width are allowed. Check
  // isIntOrFloat first: getIntOrFloatBitWidth asserts on other types.
  Type valueType = adaptor.getValue().getType();
  if (!valueType.isIntOrFloat()) {
    return rewriter.notifyMatchFailure(
        memsetOp, "value must be a 16 or 32 bit int or float");
  }
  unsigned bitWidth = valueType.getIntOrFloatBitWidth();
  if (bitWidth != 16 && bitWidth != 32) {
    return rewriter.notifyMatchFailure(
        memsetOp, "value must be a 16 or 32 bit int or float");
  }

  Type bitCastType = bitWidth == 32 ? llvmInt32Type : llvmInt16Type;

  MemRefDescriptor dstDesc(adaptor.getDst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 dstDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  FunctionCallBuilder builder =
      bitWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
  builder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}

LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
                                                 {adaptor.getDevIndex()});
  rewriter.replaceOp(op, call);
  return success();
}

template <typename T>
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmInt32Type = builder.getIntegerType(32);
  return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                          static_cast<int32_t>(tValue));
}

template <typename T>
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmFloat32Type = builder.getF32Type();
  return builder.create<LLVM::ConstantOp>(
      loc, llvmFloat32Type,
      builder.getF32FloatAttr(static_cast<float>(tValue)));
}

LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pTensor =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType = op.getMemref().getType().getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));

  SmallVector<Value, 4> dims;
  for (Value dim : adaptor.getDims()) {
    dims.push_back(dim);
  }

  Value handle;
  // TODO: For now, we track each use of the handle and lower it to a cusparse
  // or cusparseLt call accordingly. If both cusparse and cusparseLt are used
  // within one block, two separate creation ops are required for the logic to
  // be correct. In the future, we may support using a single handle for both
  // cusparse and cusparseLt in the sparse tensor / GPU dialects. Use the
  // cusparseLt create call if the dense matrix is used with a 2:4 sparse
  // matrix in SpMM.
  if (dims.size() == 2) {
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
      auto handleSz = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(11032));
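      // Note: 11032 is assumed to match the size of the cusparseLt
      // dense-matrix handle struct used by the GPU runtime wrappers.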
      handle = rewriter.create<LLVM::AllocaOp>(
          loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
      handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);

      createLtDnMatCallBuilder
          .create(loc, rewriter,
                  {handle, dims[0], dims[1], pTensor, dtp, stream})
          .getResult();
    } else {
      handle =
          createDnMatCallBuilder
              .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
              .getResult();
    }
  } else {
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
    handle = createDnVecCallBuilder
                 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
                 .getResult();
  }
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
  SmallVector<Value, 4> dims;
  for (Value dim : definingOp.getDims()) {
    dims.push_back(dim);
  }
  if (dims.size() == 2) {
    // Use the cusparseLt destroy call if the dense matrix is used with a
    // 2:4 sparse matrix in SpMM.
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
      destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
                                           {adaptor.getDnTensor(), stream});
    } else {
      destroyDnMatCallBuilder.create(loc, rewriter,
                                     {adaptor.getDnTensor(), stream});
    }
  } else {
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
    destroyDnVecCallBuilder.create(loc, rewriter,
                                   {adaptor.getDnTensor(), stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}
1251 
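// Lowers gpu.create_coo to a runtime call taking the raw allocated pointers
// of the index and value memrefs plus cuSPARSE index/data type codes. A rough
// sketch of the input IR (approximate syntax, names illustrative):
//
//   %spmat, %t = gpu.create_coo async [%dep] %rows, %cols, %nnz,
//       %rowIdxs, %colIdxs, %values
//       : memref<?xindex>, memref<?xindex>, memref<?xf64>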
LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCooOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type iType =
      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCooCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCooAoSOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCooAoSCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pIdxs, pValues, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

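// Same scheme for CSR, with an extra index type code for the row-position
// array (approximate syntax, names illustrative):
//
//   %spmat, %t = gpu.create_csr async [%dep] %rows, %cols, %nnz,
//       %rowPos, %colIdxs, %values
//       : memref<?xindex>, memref<?xindex>, memref<?xf64>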
LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCsrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowPos =
      MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCsrCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

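// Unlike the cuSPARSE handles above, which are returned by value from the
// runtime, the cusparseLt 2:4 spmat descriptor is an opaque blob that this
// pattern stack-allocates and hands to the runtime by pointer.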
LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pMat =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType =
      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));

  // CUDA runner asserts the size is 44104 bytes.
  auto handleSz = rewriter.create<LLVM::ConstantOp>(
      loc, getIndexType(), rewriter.getIndexAttr(44104));
  Value handle = rewriter.create<LLVM::AllocaOp>(
      loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);

  create2To4SpMatCallBuilder.create(
      loc, rewriter,
      {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream});
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DestroySpMatOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  // Use the cuSPARSELt destroy call if the spmat has 2:4 sparsity.
  if (is2To4Sparsity(op.getSpmat())) {
    destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
                                         {adaptor.getSpmat(), stream});
  } else {
    destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}

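// The SpMV/SpMM/SDDMM patterns mirror the two-phase cuSPARSE API: a
// *_buffer_size op queries the required workspace, the host allocates it,
// and the compute op consumes it. A rough sketch of the idiom being lowered
// (approximate syntax, names illustrative):
//
//   %sz, %t0 = gpu.spmv_buffer_size async [%dep] %spmatA, %dnX, %dnY into f64
//   %buf, %t1 = gpu.alloc async [%t0] (%sz) : memref<?xi8>
//   %t2 = gpu.spmv async [%t1] %spmatA, %dnX, %dnY, %buf
//       : memref<?xi8> into f64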
LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  auto bufferSize = spMVBufferSizeCallBuilder
                        .create(loc, rewriter,
                                {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
                                 adaptor.getDnY(), computeType, stream})
                        .getResult();
  rewriter.replaceOp(op, {bufferSize, stream});
  return success();
}

LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMVOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  spMVCallBuilder.create(loc, rewriter,
                         {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
                          adaptor.getDnY(), computeType, pBuf, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
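  // The cuSPARSELt path reports three workspace sizes rather than one; they
  // are written back through a small stack array and loaded below. The exact
  // meaning of each size is an implementation detail of the runtime wrappers.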
  if (is2To4Sparsity(op.getSpmatA())) {
    auto pruneFlag =
        genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
    auto computeType = genConstInt32From(
        rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
    auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(3));
    auto bufferSizes = rewriter.create<LLVM::AllocaOp>(
        loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
    createCuSparseLtSpMMBufferSizeBuilder.create(
        loc, rewriter,
        {bufferSizes, modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
         adaptor.getDnmatC(), computeType, pruneFlag, stream});

    auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
        loc, llvmPointerType, llvmInt64Type, bufferSizes,
        ValueRange{rewriter.create<LLVM::ConstantOp>(
            loc, getIndexType(), rewriter.getIndexAttr(1))});
    auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
        loc, llvmPointerType, llvmInt64Type, bufferSizes,
        ValueRange{rewriter.create<LLVM::ConstantOp>(
            loc, getIndexType(), rewriter.getIndexAttr(2))});
    auto bufferSize0 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizes);
    auto bufferSize1 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
    auto bufferSize2 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);

    rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
  } else {
    auto computeType = genConstInt32From(
        rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
    Value bufferSize =
        createSpMMBufferSizeCallBuilder
            .create(loc, rewriter,
                    {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
                     adaptor.getDnmatC(), computeType, stream})
            .getResult();
    rewriter.replaceOp(op, {bufferSize, stream});
  }
  return success();
}

LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  auto bufferSize =
      createSDDMMBufferSizeCallBuilder
          .create(loc, rewriter,
                  {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
                   adaptor.getSpmatC(), computeType, stream})
          .getResult();
  rewriter.replaceOp(op, {bufferSize, stream});
  return success();
}

LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));

  auto stream = adaptor.getAsyncDependencies().front();

  // Lower to cuSPARSELt if applicable.
  if (is2To4Sparsity(op.getSpmatA())) {
    SmallVector<Value> pBufs;
    for (Value buffer : adaptor.getBuffers()) {
      Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
      pBufs.push_back(pBuf);
    }
    createCuSparseLtSpMMBuilder.create(
        loc, rewriter,
        {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
         pBufs[0], pBufs[1], pBufs[2], stream});
  } else {
    Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
                     .allocatedPtr(rewriter, loc);
    createSpMMCallBuilder.create(loc, rewriter,
                                 {modeA, modeB, adaptor.getSpmatA(),
                                  adaptor.getDnmatB(), adaptor.getDnmatC(),
                                  computeType, pBuf, stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}

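// All opaque GPU handle types (async tokens, sparse handles) lower to a bare
// !llvm.ptr; the handles only ever flow into the runtime calls above.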
template <typename T>
static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
  converter.addConversion([&converter](T) -> Type {
    return LLVM::LLVMPointerType::get(&converter.getContext());
  });
}

LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SDDMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  createSDDMMCallBuilder.create(loc, rewriter,
                                {modeA, modeB, adaptor.getDnmatA(),
                                 adaptor.getDnmatB(), adaptor.getSpmatC(),
                                 computeType, pBuf, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

LogicalResult
ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
                    .getResult();
  rewriter.replaceOp(op, {descr, stream});
  return success();
}

LogicalResult
ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
                                         {adaptor.getDesc(), stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

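// gpu.spgemm_work_estimation_or_compute folds two phases of the cuSPARSE
// SpGEMM protocol into one op, selected by its `kind` attribute. The overall
// sequence this family of patterns expects is roughly:
//
//   1. gpu.spgemm_create_descr
//   2. work estimation (kind = WORK_ESTIMATION), growing the buffer as needed
//   3. compute (kind = COMPUTE), again growing the buffer as needed
//   4. gpu.spmat_get_size, allocate the C arrays, gpu.set_csr_pointers
//   5. gpu.spgemm_copy
//   6. gpu.spgemm_destroy_descr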
LogicalResult
ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();

  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  Value bufferSizeNew;

  if (adaptor.getKind() ==
      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
    bufferSizeNew =
        createSpGEMMWorkEstimationBuilder
            .create(loc, rewriter,
                    {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                     adaptor.getBufferSz(), pBuf, stream})
            .getResult();
  } else {
    bufferSizeNew =
        createSpGEMMComputeBuilder
            .create(loc, rewriter,
                    {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                     adaptor.getBufferSz(), pBuf, stream})
            .getResult();
  }
  rewriter.replaceOp(op, {bufferSizeNew, stream});
  return success();
}

LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  createSpGEMMCopyBuilder.create(loc, rewriter,
                                 {adaptor.getDesc(), modeA, modeB,
                                  adaptor.getSpmatA(), adaptor.getSpmatB(),
                                  adaptor.getSpmatC(), computeType, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

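// The runtime call reports rows/cols/nnz through out-parameters, so the
// pattern stages them in a 3 x i64 stack buffer and loads the results back.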
LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();

  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(3));
  auto buffer = rewriter.create<LLVM::AllocaOp>(
      loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);

  auto rowsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmInt64Type, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(0))});
  auto colsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmInt64Type, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(1))});
  auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmInt64Type, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(2))});
  createSpMatGetSizeBuilder.create(
      loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
  auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
  auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
  auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);

  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
  return success();
}

LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetCsrPointersOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pPos =
      MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
  Value pCrd =
      MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
  Value pVal =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  createSetCsrPointersBuilder.create(
      loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

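// CSC creation mirrors the CSR case with the roles of the position and index
// arrays transposed (column positions, row indices).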
LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCscOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pColPos =
      MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCscCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

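// BSR creation additionally forwards the block row/column counts and the
// block sizes; the row/col/nnz operands here count blocks, not scalars.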
LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateBsrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowPos =
      MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createBsrCallBuilder
          .create(loc, rewriter,
                  {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
                   adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
                   pColIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

void mlir::populateGpuToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
  addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);

  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertHostUnregisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern,
               ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
               ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
               ConvertCreateCooOpToGpuRuntimeCallPattern,
               ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
               ConvertCreateCsrOpToGpuRuntimeCallPattern,
               ConvertCreateCscOpToGpuRuntimeCallPattern,
               ConvertCreateBsrOpToGpuRuntimeCallPattern,
               ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
               ConvertDestroySpMatOpToGpuRuntimeCallPattern,
               ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSpMVOpToGpuRuntimeCallPattern,
               ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSpMMOpToGpuRuntimeCallPattern,
               ConvertSDDMMOpToGpuRuntimeCallPattern,
               ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
               ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
               ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
               ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
               ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
               ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
                                            kernelIntersperseSizeCallConv);
}
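
// A minimal usage sketch (hypothetical client code; assumes the standard
// func-to-LLVM and memref-to-LLVM patterns are populated elsewhere):
//
//   LLVMTypeConverter converter(&getContext());
//   RewritePatternSet patterns(&getContext());
//   populateGpuToLLVMConversionPatterns(
//       converter, patterns, /*kernelBarePtrCallConv=*/false,
//       /*kernelIntersperseSizeCallConv=*/false);
//   LLVMConversionTarget target(getContext());
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();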

//===----------------------------------------------------------------------===//
// GPUModuleOp convert to LLVM op interface
//===----------------------------------------------------------------------===//

namespace {
struct GPUModuleOpConvertToLLVMInterface
    : public ConvertToLLVMOpInterface::ExternalModel<
          GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
  /// Get the conversion patterns from the target attribute.
  void getConvertToLLVMConversionAttrs(
      Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
};
} // namespace

void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
    Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
  auto module = cast<gpu::GPUModuleOp>(op);
  ArrayAttr targetsAttr = module.getTargetsAttr();
  // Bail out (leaving `attrs` empty) if there are no target attributes or if
  // there is more than one target.
  if (!targetsAttr || targetsAttr.size() != 1)
    return;
  if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
    attrs.push_back(patternAttr);
}

void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
    gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
  });
}
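
// Hedged usage sketch: clients attach the external model by adding the
// extension to their registry before creating the context, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<gpu::GPUDialect>();
//   gpu::registerConvertGpuToLLVMInterface(registry);
//   MLIRContext context(registry);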