xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp (revision 3231a365c185094abe8af5970ccdbcbcd2bc001a)
1 //===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a prototype GPU code generator for the sparse compiler.
10 // The objective is to eventually use the right combination of
11 // direct code generation and library calls into vendor-specific
12 // highly optimized sparse libraries (e.g. cuSparse for CUDA).
13 //
14 //===----------------------------------------------------------------------===//
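
// For orientation: linalg kernels that match a cuSPARSE-supported pattern
// (SpMV, SpMM, SpGEMM, 2:4 SpMM, SDDMM) are rewritten into sparse ops of the
// GPU dialect (see the rewrite* methods below), while remaining outermost
// parallel "forall" loops are outlined into GPU kernels (see ForallRewriter).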
15 
16 #include "CodegenUtils.h"
17 #include "LoopEmitter.h"
18 
19 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/MemRef/IR/MemRef.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
26 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
28 #include "mlir/IR/IRMapping.h"
29 #include "mlir/IR/Matchers.h"
30 
31 using namespace mlir;
32 using namespace mlir::sparse_tensor;
33 
34 namespace {
35 
36 // Sparse formats supported by cuSparse.
37 enum class CuSparseFormat {
38   kNone,
39   kCOO,
40   kCSR,
41   kCSC,
42   kBSR, // TODO: coming soon!
43 };
44 
45 //===----------------------------------------------------------------------===//
46 // Helper methods.
47 //===----------------------------------------------------------------------===//
48 
49 /// Marks the given top module as a GPU container module.
50 static void markAsGPUContainer(ModuleOp topModule) {
51   topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
52                      UnitAttr::get(topModule->getContext()));
53 }
54 
55 /// Constructs a new GPU module (for GPU kernels) inside the given top module,
56 /// or returns an existing GPU module if one was built previously.
57 static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
58   for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
59     return op; // existing
60   markAsGPUContainer(topModule);
61   builder.setInsertionPointToStart(&topModule.getBodyRegion().front());
62   return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
63                                           "sparse_kernels");
64 }
65 
66 /// Constructs a new GPU kernel in the given GPU module.
67 static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
68                                  SmallVectorImpl<Value> &args) {
69   // Get a unique kernel name. Not very creative,
70   // but we simply try kernel0, kernel1, etc.
71   unsigned kernelNumber = 0;
72   SmallString<16> kernelName;
73   do {
74     kernelName.clear();
75     ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
76   } while (gpuModule.lookupSymbol(kernelName));
77   // Then we insert a new kernel with given arguments into the module.
78   builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
79   SmallVector<Type> argsTp;
80   for (unsigned i = 0, e = args.size(); i < e; i++)
81     argsTp.push_back(args[i].getType());
82   FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
83   auto gpuFunc =
84       builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
85   gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
86                    builder.getUnitAttr());
87   return gpuFunc;
88 }
89 
90 /// Constructs code to launch GPU kernel.
91 static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
92                               SmallVectorImpl<Value> &args,
93                               SmallVectorImpl<Value> &tokens,
94                               unsigned numThreads) {
95   Location loc = gpuFunc->getLoc();
96   Value none = TypedValue<::mlir::IntegerType>{};
97   Value one = constantIndex(builder, loc, 1);
98   Value numT = constantIndex(builder, loc, numThreads);
99   gpu::KernelDim3 gridSize = {one, one, one};
100   gpu::KernelDim3 blckSize = {numT, one, one};
101   return builder
102       .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
103                                  /*dynSharedMemSz*/ none, args,
104                                  builder.getType<gpu::AsyncTokenType>(), tokens)
105       .getAsyncToken();
106 }
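
// For illustration, the launch constructed above looks roughly as follows
// (a sketch only; operand types and the dynamic shared memory size operand
// are omitted):
//
//   %token = gpu.launch_func async [%tokens...]
//                @sparse_kernels::@kernel0
//                blocks in (%c1, %c1, %c1) threads in (%numT, %c1, %c1)
//                args(%arg0 : ..., %arg1 : ...)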
107 
108 /// Maps the provided ranked host buffer into the device address space.
109 /// Writes from the host are guaranteed to be visible to device kernels
110 /// that are launched afterwards. Writes from the device are guaranteed
111 /// to be visible on the host after synchronizing with the device kernel
112 /// completion. Needs to cast the buffer to an unranked buffer.
113 static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
114                                    Value mem) {
115   MemRefType memTp = cast<MemRefType>(mem.getType());
116   UnrankedMemRefType resTp =
117       UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
118   Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
119   builder.create<gpu::HostRegisterOp>(loc, cast);
120   return cast;
121 }
122 
123 /// Unmaps the provided buffer, expecting the previously cast buffer.
124 static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
125                                     Value cast) {
126   builder.create<gpu::HostUnregisterOp>(loc, cast);
127 }
128 
129 /// Generates first wait in an asynchronous chain.
130 static Value genFirstWait(OpBuilder &builder, Location loc) {
131   Type tokenType = builder.getType<gpu::AsyncTokenType>();
132   return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
133       .getAsyncToken();
134 }
135 
136 /// Generates last, blocking wait in an asynchronous chain.
137 static void genBlockingWait(OpBuilder &builder, Location loc,
138                             ValueRange operands) {
139   builder.create<gpu::WaitOp>(loc, Type(), operands);
140 }
141 
142 /// Allocates memory on the device.
143 /// TODO: A `host_shared` attribute could be used to indicate that
144 ///       the buffer is visible to both host and device, but lowering
145 ///       that feature does not seem to be fully supported yet.
146 static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
147                                    Value token) {
148   auto tp = cast<ShapedType>(mem.getType());
149   auto elemTp = tp.getElementType();
150   auto shape = tp.getShape();
151   auto memTp = MemRefType::get(shape, elemTp);
152   SmallVector<Value> dynamicSizes;
153   for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
154     if (shape[r] == ShapedType::kDynamic) {
155       Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
156       dynamicSizes.push_back(dimOp);
157     }
158   }
159   return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
160                                       token, dynamicSizes, ValueRange());
161 }
162 
163 /// Allocates a typed buffer on the host with given size.
164 static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
165                            Value size) {
166   const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
167   return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
168 }
169 
170 /// Allocates a typed buffer on the device with given size.
171 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
172                                    Value size, Value token) {
173   const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
174   return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
175                                       token, size, ValueRange());
176 }
177 
178 /// Allocates a void buffer on the device with given size.
179 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
180                                    Value token) {
181   return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
182 }
183 
184 /// Deallocates memory from the device.
185 static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
186                               Value token) {
187   return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
188       .getAsyncToken();
189 }
190 
191 /// Copies memory between host and device (direction is implicit).
192 static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
193                            Value src, Value token) {
194   return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
195       .getAsyncToken();
196 }
197 
198 /// Generates an alloc/copy pair.
199 static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
200                           SmallVectorImpl<Value> &tokens) {
201   Value firstToken = genFirstWait(builder, loc);
202   auto alloc = genAllocMemRef(builder, loc, b, firstToken);
203   Value devMem = alloc.getResult(0);
204   Value depToken = alloc.getAsyncToken(); // copy-after-alloc
205   tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
206   return devMem;
207 }
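
// Schematically, for each host buffer %b, genAllocCopy emits an asynchronous
// wait/alloc/copy chain along the lines of (sketch only):
//
//   %t0 = gpu.wait async
//   %mem, %t1 = gpu.alloc async [%t0] (%dims) : memref<...>
//   %t2 = gpu.memcpy async [%t1] %mem, %b : memref<...>, memref<...>
//
// where %t2 is appended to `tokens` so that later operations can depend on
// the completed copy.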
208 
209 /// Generates a memref-from-tensor (bufferization.to_memref) operation.
210 static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
211                                Value tensor) {
212   auto tensorType = llvm::cast<ShapedType>(tensor.getType());
213   auto memrefType =
214       MemRefType::get(tensorType.getShape(), tensorType.getElementType());
215   return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
216 }
217 
218 /// Prepares the outlined arguments, passing scalars and buffers in. Here we
219 /// assume that the first buffer is the one allocated for output. We create
220 /// a set of properly chained asynchronous allocation/copy pairs to increase
221 /// overlap before launching the kernel.
222 static Value genParametersIn(OpBuilder &builder, Location loc,
223                              SmallVectorImpl<Value> &scalars,
224                              SmallVectorImpl<Value> &buffers,
225                              SmallVectorImpl<Value> &args,
226                              SmallVectorImpl<Value> &tokens,
227                              bool useHostRegistrationForOut) {
228   Value out;
229   // Scalars are passed by value.
230   for (Value s : scalars)
231     args.push_back(s);
232   // Buffers need to be made visible on the device.
233   for (Value b : buffers) {
234     if (useHostRegistrationForOut) {
235       out = genHostRegisterMemref(builder, loc, b);
236       args.push_back(b);
237       useHostRegistrationForOut = false;
238       continue;
239     }
240     args.push_back(genAllocCopy(builder, loc, b, tokens));
241   }
242   return out;
243 }
244 
245 /// Finalizes the outlined arguments. The output buffer is copied depending
246 /// on the kernel token and then deallocated. All other buffers are simply
247 /// deallocated. Then we wait for all operations to complete.
248 static void genParametersOut(OpBuilder &builder, Location loc, Value out,
249                              Value kernelToken, SmallVectorImpl<Value> &scalars,
250                              SmallVectorImpl<Value> &buffers,
251                              SmallVectorImpl<Value> &args,
252                              SmallVectorImpl<Value> &tokens) {
253   unsigned base = scalars.size();
254   for (unsigned i = base, e = args.size(); i < e; i++) {
255     Value firstToken;
256     if (i == base) {
257       // Assumed output parameter: unregister or copy-out.
258       if (out) {
259         genHostUnregisterMemref(builder, loc, out);
260         out = Value();
261         continue;
262       }
263       firstToken =
264           genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
265     } else {
266       firstToken = genFirstWait(builder, loc);
267     }
268     tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
269   }
270 }
271 
272 /// Constructs code for new GPU kernel.
273 static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
274                        scf::ParallelOp forallOp,
275                        SmallVectorImpl<Value> &constants,
276                        SmallVectorImpl<Value> &scalars,
277                        SmallVectorImpl<Value> &buffers) {
278   Location loc = gpuFunc->getLoc();
279   Block &block = gpuFunc.getBody().front();
280   rewriter.setInsertionPointToStart(&block);
281 
282   // Re-generate the constants, recapture all arguments.
283   unsigned arg = 0;
284   IRMapping irMap;
285   for (Value c : constants)
286     irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
287   for (Value s : scalars)
288     irMap.map(s, block.getArgument(arg++));
289   for (Value b : buffers)
290     irMap.map(b, block.getArgument(arg++));
291 
292   // Assume 1-dimensional grid/block configuration (only x dimension),
293   // so that:
294   //   row = blockIdx.x * blockDim.x + threadIdx.x
295   //   inc = blockDim.x * gridDim.x
296   Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
297   Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
298   Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
299   Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
300   Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
301   Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
302   Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);
303 
304   // Construct the iteration over the computational space that
305   // accounts for the fact that the total number of threads and
306   // the amount of work to be done usually do not match precisely.
307   //   for (r = row; r < N; r += inc) {
308   //     <loop-body>
309   //   }
310   Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
311   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
312   rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
313                              forOp.getRegion().begin(), irMap);
314 
315   // Done.
316   rewriter.setInsertionPointAfter(forOp);
317   rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
318 }
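
// As a sketch, the outlined kernel produced above has the following shape
// (argument list and loop body elided):
//
//   gpu.module @sparse_kernels {
//     gpu.func @kernelN(%scalars..., %buffers...) kernel {
//       %bid = gpu.block_id x
//       %bsz = gpu.block_dim x
//       %tid = gpu.thread_id x
//       %gsz = gpu.grid_dim x
//       %mul = arith.muli %bid, %bsz : index
//       %row = arith.addi %mul, %tid : index
//       %inc = arith.muli %bsz, %gsz : index
//       scf.for %r = %row to %upper step %inc {
//         <loop-body>
//       }
//       gpu.return
//     }
//   }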
319 
320 //===----------------------------------------------------------------------===//
321 // Library helper methods.
322 //===----------------------------------------------------------------------===//
323 
324 /// Helper to detect a + b with arguments taken from given block.
325 static bool matchAddOfArgs(Block *block, Value val) {
326   if (auto *def = val.getDefiningOp()) {
327     if (isa<arith::AddFOp, arith::AddIOp>(def)) {
328       Value a = block->getArguments()[0];
329       Value b = block->getArguments()[1];
330       return (def->getOperand(0) == a && def->getOperand(1) == b) ||
331              (def->getOperand(0) == b && def->getOperand(1) == a);
332     }
333   }
334   return false;
335 }
336 
337 /// Helper to detect a * b with arguments taken from given block.
338 static bool matchMulOfArgs(Block *block, Value val) {
339   if (auto *def = val.getDefiningOp()) {
340     if (isa<arith::MulFOp, arith::MulIOp>(def)) {
341       Value a = block->getArguments()[0];
342       Value b = block->getArguments()[1];
343       return (def->getOperand(0) == a && def->getOperand(1) == b) ||
344              (def->getOperand(0) == b && def->getOperand(1) == a);
345     }
346   }
347   return false;
348 }
349 
350 /// Helper to detect x = x + a * b
351 static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
352   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
353   if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
354     if (isa<arith::AddFOp, arith::AddIOp>(def)) {
355       Value x = op.getBlock()->getArguments()[2];
356       return (def->getOperand(0) == x &&
357               matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
358              (def->getOperand(1) == x &&
359               matchMulOfArgs(op.getBlock(), def->getOperand(0)));
360     }
361   }
362   return false;
363 }
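
// For example, a matmul-style linalg.generic body of the following form is
// accepted (sketch; %a, %b, %x denote the three block arguments):
//
//   ^bb0(%a: f64, %b: f64, %x: f64):
//     %p = arith.mulf %a, %b : f64
//     %s = arith.addf %x, %p : f64
//     linalg.yield %s : f64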
364 
365 /// Helper to detect c += spy(s) x (a * b)
366 static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
367   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
368   // The linalg yields a custom reduce result.
369   Value s_out = op.getBlock()->getArguments()[2];
370   if (auto redOp =
371           yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
372     // The reduce consumes the output.
373     Value other;
374     if (s_out == redOp->getOperand(0))
375       other = redOp->getOperand(1);
376     else if (s_out == redOp->getOperand(1))
377       other = redOp->getOperand(0);
378     else
379       return false;
380     // The reduce op also consumes a unary op, which in turn consumes the
381     // output and does not define an absent value.
382     if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
383       if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
384         return false;
385       // And the bodies are as expected.
386       auto yieldUn = cast<sparse_tensor::YieldOp>(
387           unOp.getRegion(0).front().getTerminator());
388       auto yieldRed = cast<sparse_tensor::YieldOp>(
389           redOp.getRegion().front().getTerminator());
390       return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
391              matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
392     }
393   }
394   return false;
395 }
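
// The accepted body has roughly the following shape (a sketch; the exact
// sparse_tensor.unary/reduce syntax is abbreviated), where %c is the output
// block argument:
//
//   ^bb0(%a: f64, %b: f64, %c: f64):
//     %u = sparse_tensor.unary %c
//            present = { ^bb0(%p: f64):
//                          %m = arith.mulf %a, %b : f64
//                          sparse_tensor.yield %m : f64 }
//            absent = {}
//     %r = sparse_tensor.reduce %u, %c, %id {
//            ^bb0(%x: f64, %y: f64):
//              %s = arith.addf %x, %y : f64
//              sparse_tensor.yield %s : f64 }
//     linalg.yield %r : f64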
396 
397 /// Test for dense tensor.
398 static bool isDenseTensor(Value v) {
399   auto sTp = getSparseTensorType(v);
400   return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
401 }
402 
403 /// Test for suitable positions/coordinates width.
404 static bool isAdmissibleMetaData(SparseTensorType &aTp) {
405   return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
406          (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
407 }
408 
409 /// Test for sorted COO matrix with suitable metadata.
410 static bool isAdmissibleCOO(SparseTensorType &aTp) {
411   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
412          aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
413          aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
414          isAdmissibleMetaData(aTp);
415 }
416 
417 /// Test for CSR matrix with suitable metadata.
418 static bool isAdmissibleCSR(SparseTensorType &aTp) {
419   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
420          aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
421          aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
422 }
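
// For illustration, a typical admissible CSR operand is a 2-d tensor with an
// encoding along the lines of (a sketch; the exact encoding syntax depends on
// the MLIR version):
//
//   #CSR = #sparse_tensor.encoding<{
//     lvlTypes = [ "dense", "compressed" ],
//     posWidth = 32,
//     crdWidth = 32
//   }>
//   ... tensor<?x?xf64, #CSR> ...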
423 
424 /// Test for CSC matrix with suitable metadata.
425 static bool isAdmissibleCSC(SparseTensorType &aTp) {
426   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
427          aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
428          aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
429 }
430 
431 /// Returns a suitable cuSparse format for the operation and given operand
432 /// types, or kNone if no admissible format is available.
433 static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
434                                         SparseTensorType bTp,
435                                         SparseTensorType cTp, bool enableRT,
436                                         bool isMatVec) {
437   // The other operands must have a dense type.
438   if (bTp.hasEncoding() || cTp.hasEncoding())
439     return CuSparseFormat::kNone;
440   // Now check for suitable operand type for the main operand.
441   if (isAdmissibleCOO(aTp))
442 #ifdef CUSPARSE_COO_AOS
443     return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
444 #else
445     return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
446 #endif
447   if (isAdmissibleCSR(aTp))
448     return CuSparseFormat::kCSR;
449   if (isAdmissibleCSC(aTp))
450     return CuSparseFormat::kCSC;
451   return CuSparseFormat::kNone;
452 }
453 
454 /// Generates the first positions/coordinates of a sparse matrix.
455 static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
456                                CuSparseFormat format, bool enableRT) {
457   if (format == CuSparseFormat::kCOO) {
458     // Library uses SoA COO, direct IR uses AoS COO.
459     if (enableRT)
460       return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
461     return genToCoordinatesBuffer(builder, loc, a);
462   }
463   // Formats CSR/CSC and BSR use positions at 1.
464   return genToPositions(builder, loc, a, 1);
465 }
466 
467 /// Generates the second coordinates of a sparse matrix.
468 static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
469                            CuSparseFormat format, bool enableRT) {
470   bool isCOO = format == CuSparseFormat::kCOO;
471   if (isCOO && !enableRT)
472     return Value(); // nothing needed
473   // Formats CSR/CSC and BSR use coordinates at 1.
474   return genToCoordinates(builder, loc, a, 1, /*cooStart=*/isCOO ? 0 : 2);
475 }
476 
477 /// Generates the sparse matrix handle.
478 static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
479                            Type tokenTp, Value token, Value sz1, Value sz2,
480                            Value nseA, Value rowA, Value colA, Value valA,
481                            CuSparseFormat format, bool enableRT) {
482   if (format == CuSparseFormat::kCOO) {
483     // Library uses SoA COO, direct IR uses AoS COO.
484     if (enableRT) {
485       assert(colA);
486       return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
487                                               sz1, sz2, nseA, rowA, colA, valA);
488     }
489 #ifdef CUSPARSE_COO_AOS
490     assert(!colA);
491     return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token,
492                                                sz1, sz2, nseA, rowA, valA);
493 #else
494     llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
495 #endif
496   }
497   assert(colA);
498   if (format == CuSparseFormat::kCSR)
499     return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
500                                             sz2, nseA, rowA, colA, valA);
501   assert(format == CuSparseFormat::kCSC);
502   return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
503                                           sz2, nseA, rowA, colA, valA);
504 }
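
// In the CSR case, for instance, genSpMat materializes something like the
// following (sketch; the memref element types depend on the tensor's
// position/coordinate/value types):
//
//   %spmat, %t = gpu.create_csr async [%token] %rows, %cols, %nnz,
//                    %rowA, %colA, %valA
//                    : memref<?xindex>, memref<?xindex>, memref<?xf64>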
505 
506 /// Match and rewrite SpMV kernel.
507 static LogicalResult
508 rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
509             GPUDataTransferStrategy gpuDataTransferStrategy) {
510   Location loc = op.getLoc();
511   Value a = op.getOperand(0);
512   Value x = op.getOperand(1);
513   Value y = op.getOperand(2); // we have y = Ax
514   SmallVector<Value> tokens;
515 
516   bool isZeroCopy =
517       gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
518 
519   // Only admissible sparse matrix format and dense vectors (no BSR).
520   SparseTensorType aTp = getSparseTensorType(a);
521   SparseTensorType xTp = getSparseTensorType(x);
522   SparseTensorType yTp = getSparseTensorType(y);
523   auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
524   if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
525     return failure();
526 
527   // Start sparse kernel and copy data from host to device.
528   //   a : memR/memC/memV -> rowA,colA,valA
529   //   x : memX           -> vecX
530   //   y : memY           -> vecY
531   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
532   Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
533   Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
534   Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
535   Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
536   Value memV = genToValues(rewriter, loc, a);
537   Value memX, memY;
538   Value castR, castC, castV, castX, castY;
539   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
540     memX = genTensorToMemref(rewriter, loc, x);
541     memY = genTensorToMemref(rewriter, loc, y);
542     castR = genHostRegisterMemref(rewriter, loc, memR);
543     if (memC)
544       castC = genHostRegisterMemref(rewriter, loc, memC);
545     castV = genHostRegisterMemref(rewriter, loc, memV);
546     castX = genHostRegisterMemref(rewriter, loc, memX);
547     castY = genHostRegisterMemref(rewriter, loc, memY);
548   }
549 
550   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
551   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
552   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
553   if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
554     memX = genTensorToMemref(rewriter, loc, x);
555   Value vecX = isZeroCopy ? memX : genAllocCopy(rewriter, loc, memX, tokens);
556   if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
557     memY = genTensorToMemref(rewriter, loc, y);
558   Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
559   genBlockingWait(rewriter, loc, tokens);
560   tokens.clear();
561 
562   // Create sparse matrix/dense vector handles.
563   Type indexTp = rewriter.getIndexType();
564   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
565   Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
566   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
567   Value token = genFirstWait(rewriter, loc);
568   Operation *spGenA =
569       genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA,
570                rowA, colA, valA, format, enableRT);
571   Value spMatA = spGenA->getResult(0);
572   token = spGenA->getResult(1);
573   auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
574       loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
575   Value dnX = dvecX.getResult(0);
576   token = dvecX.getAsyncToken();
577   auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
578       loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
579   Value dnY = dvecY.getResult(0);
580   token = dvecY.getAsyncToken();
581   auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
582 
583   // Precompute buffersize for SpMV.
584   auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
585       loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
586       /*computeType=*/dnYType);
587   Value bufferSz = bufferComp.getResult(0);
588   token = bufferComp.getAsyncToken();
589   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
590   Value buffer = buf.getResult(0);
591   token = buf.getAsyncToken();
592 
593   // Perform the SpMV.
594   auto spmvComp = rewriter.create<gpu::SpMVOp>(
595       loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer);
596   token = spmvComp.getAsyncToken();
597 
598   // Copy data back to host and free all the resources.
599   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
600               .getAsyncToken();
601   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
602               .getAsyncToken();
603   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
604               .getAsyncToken();
605   token = genDeallocMemRef(rewriter, loc, rowA, token);
606   if (colA)
607     token = genDeallocMemRef(rewriter, loc, colA, token);
608   token = genDeallocMemRef(rewriter, loc, valA, token);
609   token = genDeallocMemRef(rewriter, loc, buffer, token);
610   if (!isZeroCopy)
611     token = genDeallocMemRef(rewriter, loc, vecX, token);
612   token = genCopyMemRef(rewriter, loc, memY, vecY, token);
613   token = genDeallocMemRef(rewriter, loc, vecY, token);
614   tokens.push_back(token);
615   genBlockingWait(rewriter, loc, tokens);
616   tokens.clear();
617   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
618     genHostUnregisterMemref(rewriter, loc, castR);
619     if (memC)
620       genHostUnregisterMemref(rewriter, loc, castC);
621     genHostUnregisterMemref(rewriter, loc, castV);
622     genHostUnregisterMemref(rewriter, loc, castX);
623     genHostUnregisterMemref(rewriter, loc, castY);
624   }
625 
626   // Done.
627   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
628   return success();
629 }
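
// Putting it together, the SpMV rewrite above emits a single asynchronous
// chain on the host, roughly of the form (sketch; tokens, types, and the
// cleanup details are elided):
//
//   gpu.alloc / gpu.memcpy           ... rowA, colA, valA, vecX, vecY
//   gpu.create_csr (or create_coo)   ... %spMatA
//   gpu.create_dn_tensor             ... %dnX, %dnY
//   gpu.spmv_buffer_size             ... %bufferSz
//   gpu.alloc                        ... %buffer
//   gpu.spmv                         ... y = A x
//   gpu.destroy_sp_mat / gpu.destroy_dn_tensor / gpu.dealloc / gpu.memcpy
//   gpu.wait                         ... blocking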
630 
631 /// Match and rewrite SpMM kernel.
632 static LogicalResult
633 rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
634             GPUDataTransferStrategy gpuDataTransferStrategy) {
635   Location loc = op.getLoc();
636   Value a = op.getOperand(0);
637   Value b = op.getOperand(1);
638   Value c = op.getOperand(2); // we have C = AB
639   SmallVector<Value> tokens;
640 
641   bool isZeroCopy =
642       gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
643 
644   // Only admissible sparse matrix format and dense matrices (no BSR).
645   SparseTensorType aTp = getSparseTensorType(a);
646   SparseTensorType bTp = getSparseTensorType(b);
647   SparseTensorType cTp = getSparseTensorType(c);
648   auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
649   if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
650     return failure();
651 
652   // Start sparse kernel and copy data from host to device.
653   //   a : memR/memC/memV -> rowA,colA,valA
654   //   b : bufB           -> matB
655   //   c : bufC           -> matC
656   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
657   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
658   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
659   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
660   Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
661   Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
662   Value memV = genToValues(rewriter, loc, a);
663   Value bufB, bufC;
664   Value castR, castC, castV, castB, castBufC;
665   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
666     bufB = genTensorToMemref(rewriter, loc, b);
667     bufC = genTensorToMemref(rewriter, loc, c);
668     castR = genHostRegisterMemref(rewriter, loc, memR);
669     if (memC)
670       castC = genHostRegisterMemref(rewriter, loc, memC);
671     castV = genHostRegisterMemref(rewriter, loc, memV);
672     castB = genHostRegisterMemref(rewriter, loc, bufB);
673     castBufC = genHostRegisterMemref(rewriter, loc, bufC);
674   }
675   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
676   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
677   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
678   if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
679     bufB = genTensorToMemref(rewriter, loc, b);
680   Value matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens);
681   if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
682     bufC = genTensorToMemref(rewriter, loc, c);
683   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
684   genBlockingWait(rewriter, loc, tokens);
685   tokens.clear();
686 
687   // Create sparse matrix/dense matrix handles.
688   Type indexTp = rewriter.getIndexType();
689   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
690   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
691   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
692   Value token = genFirstWait(rewriter, loc);
693   Operation *spGenA =
694       genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
695                rowA, colA, valA, format, enableRT);
696   Value spMatA = spGenA->getResult(0);
697   token = spGenA->getResult(1);
698   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
699       loc, dnTensorHandleTp, tokenTp, token, matB,
700       SmallVector<Value>{szk, szn});
701   Value dnB = dmatB.getResult(0);
702   token = dmatB.getAsyncToken();
703   auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
704       loc, dnTensorHandleTp, tokenTp, token, matC,
705       SmallVector<Value>{szm, szn});
706   Value dnC = dmatC.getResult(0);
707   token = dmatC.getAsyncToken();
708   auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();
709 
710   // Precompute buffersize for SpMM.
711   auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
712       loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
713       /*computeType=*/dmatCType);
714   Value bufferSz = bufferComp.getResult(0);
715   token = bufferComp.getAsyncToken();
716   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
717   Value buffer = buf.getResult(0);
718   token = buf.getAsyncToken();
719   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
720 
721   // Perform the SpMM.
722   auto spmmComp = rewriter.create<gpu::SpMMOp>(
723       loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer);
724   token = spmmComp.getAsyncToken();
725 
726   // Copy data back to host and free all the resources.
727   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
728               .getAsyncToken();
729   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
730               .getAsyncToken();
731   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
732               .getAsyncToken();
733   token = genDeallocMemRef(rewriter, loc, rowA, token);
734   if (colA)
735     token = genDeallocMemRef(rewriter, loc, colA, token);
736   token = genDeallocMemRef(rewriter, loc, valA, token);
737   token = genDeallocMemRef(rewriter, loc, buffer, token);
738   if (!isZeroCopy)
739     token = genDeallocMemRef(rewriter, loc, matB, token);
740   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
741   token = genDeallocMemRef(rewriter, loc, matC, token);
742   tokens.push_back(token);
743   genBlockingWait(rewriter, loc, tokens);
744   tokens.clear();
745   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
746     genHostUnregisterMemref(rewriter, loc, castR);
747     if (memC)
748       genHostUnregisterMemref(rewriter, loc, castC);
749     genHostUnregisterMemref(rewriter, loc, castV);
750     genHostUnregisterMemref(rewriter, loc, castB);
751     genHostUnregisterMemref(rewriter, loc, castBufC);
752   }
753 
754   // Done.
755   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
756   return success();
757 }
758 
759 /// Match and rewrite SpGEMM kernel.
760 static LogicalResult
761 rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
762               GPUDataTransferStrategy gpuDataTransferStrategy) {
763   Location loc = op.getLoc();
764   Value a = op.getOperand(0);
765   Value b = op.getOperand(1);
766   Value c = op.getOperand(2); // we have C = AB
767   SmallVector<Value> tokens;
768 
769   // Only CSR <- CSR x CSR supported.
770   auto format = CuSparseFormat::kCSR;
771   SparseTensorType aTp = getSparseTensorType(a);
772   SparseTensorType bTp = getSparseTensorType(b);
773   SparseTensorType cTp = getSparseTensorType(c);
774   if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
775     return failure();
776 
777   // Start sparse kernel and copy data from host to device.
778   //   a : amemR/amemC/amemV -> rowA,colA,valA
779   //   b : bmemR/bmemC/bmemV -> rowB,colB,valB
780   //   c : materializes
781   auto dnCType = cTp.getElementType();
782   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
783   Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
784   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
785   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
786   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
787   Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
788   Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT);
789   Value amemV = genToValues(rewriter, loc, a);
790   Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
791   Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT);
792   Value bmemV = genToValues(rewriter, loc, b);
793   Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
794   Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
795   Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
796   Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
797   Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
798   Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
799   genBlockingWait(rewriter, loc, tokens);
800   tokens.clear();
801 
802   // Create sparse matrix handles.
803   Type indexTp = rewriter.getIndexType();
804   Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
805   Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
806   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
807   Value token = genFirstWait(rewriter, loc);
808   Operation *spGenA =
809       genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
810                rowA, colA, valA, format, enableRT);
811   Value spMatA = spGenA->getResult(0);
812   token = spGenA->getResult(1);
813   Operation *spGenB =
814       genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
815                rowB, colB, valB, format, enableRT);
816   Value spMatB = spGenB->getResult(0);
817   token = spGenB->getResult(1);
818 
819   // Sparse matrix C materializes (also assumes beta == 0).
820   Value zero = constantIndex(rewriter, loc, 0);
821   Value one = constantIndex(rewriter, loc, 1);
822   Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
823   auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
824   Value rowC = e1.getResult(0);
825   token = e1.getAsyncToken();
826   auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
827   Value colC = e2.getResult(0); // no free needed
828   token = e2.getAsyncToken();
829   auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
830   Value valC = e3.getResult(0); // no free needed
831   token = e3.getAsyncToken();
832   Operation *spGenC =
833       genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
834                rowC, colC, valC, format, enableRT);
835   Value spMatC = spGenC->getResult(0);
836   token = spGenC->getResult(1);
837 
838   // Precompute buffersizes for SpGEMM.
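  // Note: following the cuSPARSE SpGEMM protocol, both the work-estimation
  // and the compute phase are issued twice below: the first call (with a
  // zero-sized buffer) merely queries the required buffer size, and the
  // second call, after allocating that buffer, performs the actual phase.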
839   Operation *descOp =
840       rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
841   Value desc = descOp->getResult(0);
842   token = descOp->getResult(1);
843   Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
844       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
845       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
846       valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
847   Value bufferSz1 = work1->getResult(0);
848   token = work1->getResult(1);
849   auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
850   Value buffer1 = buf1.getResult(0);
851   token = buf1.getAsyncToken();
852   Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
853       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
854       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
855       bufferSz1, buffer1,
856       gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
857   token = work2->getResult(1);
858 
859   // Compute step.
860   Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
861       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
862       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
863       valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
864   Value bufferSz2 = compute1->getResult(0);
865   token = compute1->getResult(1);
866   auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
867   Value buffer2 = buf2.getResult(0);
868   token = buf2.getAsyncToken();
869   Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
870       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
871       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
872       bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
873   token = compute2->getResult(1);
874 
875   // Get sizes.
876   Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
877       loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
878   Value nnz = sizes->getResult(2);
879   token = sizes->getResult(3);
880   auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
881   colC = a2.getResult(0);
882   token = a2.getAsyncToken();
883   auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
884   valC = a3.getResult(0);
885   token = a3.getAsyncToken();
886 
887   // Update C with new pointers and copy final product back into C.
888   Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
889       loc, tokenTp, token, spMatC, rowC, colC, valC);
890   token = update->getResult(0);
891   Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
892       loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
893       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
894   token = copy->getResult(0);
895 
896   // Allocate buffers on host.
897   Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
898   Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
899   Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
900 
901   // Copy data back to host and free all the resources.
902   token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
903               .getAsyncToken();
904   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
905               .getAsyncToken();
906   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
907               .getAsyncToken();
908   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
909               .getAsyncToken();
910   token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
911   token = genCopyMemRef(rewriter, loc, colH, colC, token);
912   token = genCopyMemRef(rewriter, loc, valH, valC, token);
913   token = genDeallocMemRef(rewriter, loc, rowA, token);
914   token = genDeallocMemRef(rewriter, loc, colA, token);
915   token = genDeallocMemRef(rewriter, loc, valA, token);
916   token = genDeallocMemRef(rewriter, loc, rowB, token);
917   token = genDeallocMemRef(rewriter, loc, colB, token);
918   token = genDeallocMemRef(rewriter, loc, valB, token);
919   token = genDeallocMemRef(rewriter, loc, rowC, token);
920   token = genDeallocMemRef(rewriter, loc, colC, token);
921   token = genDeallocMemRef(rewriter, loc, valC, token);
922   token = genDeallocMemRef(rewriter, loc, buffer1, token);
923   token = genDeallocMemRef(rewriter, loc, buffer2, token);
924   tokens.push_back(token);
925   genBlockingWait(rewriter, loc, tokens);
926   tokens.clear();
927 
928   // Done.
929   Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
930   Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
931   Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
932   rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
933                                           ValueRange{rt, ct});
934   return success();
935 }
936 
937 /// Match and rewrite 2:4 SpMM kernel.
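/// In 2:4 ("two-out-of-four") structured sparsity, every group of four
/// consecutive elements along a row holds at most two nonzeros, a pattern
/// that NVIDIA sparse tensor cores accelerate directly. Here A is provided
/// as a dense tensor and is pruned into that form on the device
/// (PRUNE_AND_CHECK below).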
938 static LogicalResult
939 rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
940                 GPUDataTransferStrategy gpuDataTransferStrategy) {
941   Location loc = op.getLoc();
942   Value A = op.getOperand(0);
943   Value B = op.getOperand(1);
944   Value C = op.getOperand(2); // we have C = AB
945   SmallVector<Value> tokens;
946 
947   bool isZeroCopy =
948       gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
949 
950   // All input should be dense tensors.
951   if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
952     return failure();
953 
954   Value matA, matB;
955   Value bufA = genTensorToMemref(rewriter, loc, A);
956   if (!isZeroCopy)
957     matA = genAllocCopy(rewriter, loc, bufA, tokens);
958   Value bufB = genTensorToMemref(rewriter, loc, B);
959   if (!isZeroCopy)
960     matB = genAllocCopy(rewriter, loc, bufB, tokens);
961   Value bufC = genTensorToMemref(rewriter, loc, C);
962   Value castA, castB, castC;
963   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
964     castA = genHostRegisterMemref(rewriter, loc, bufA);
965     castB = genHostRegisterMemref(rewriter, loc, bufB);
966     castC = genHostRegisterMemref(rewriter, loc, bufC);
967   }
968   if (isZeroCopy) {
969     matA = bufA;
970     matB = bufB;
971   }
972   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
973   genBlockingWait(rewriter, loc, tokens);
974   tokens.clear();
975 
976   // Create sparse matrix/dense matrix handles.
977   Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
978   Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
979   Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
980   Type indexTp = rewriter.getIndexType();
981   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
982   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
983   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
984   Value token = genFirstWait(rewriter, loc);
985   Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
986       loc, spMatHandleTp, tokenTp, token, szm, szk,
987       gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
988   Value spMatA = spGenA->getResult(0);
989   token = spGenA->getResult(1);
990   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
991       loc, dnTensorHandleTp, tokenTp, token, matB,
992       SmallVector<Value>{szk, szn});
993   Value dnB = dmatB.getResult(0);
994   token = dmatB.getAsyncToken();
995   auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
996       loc, dnTensorHandleTp, tokenTp, token, matC,
997       SmallVector<Value>{szm, szn});
998   Value dnC = dmatC.getResult(0);
999   token = dmatC.getAsyncToken();
1000   auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
1001 
1002   // Precompute buffersize for SpMM.
1003   SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
1004   TypeRange bufferTypes(bufferTypes_);
1005   auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
1006       loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
1007       gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
1008       /*computeType=*/dmatCType);
1009   token = bufferComp.getAsyncToken();
1010 
1011   Value bufferSz = bufferComp.getResult(0);
1012   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
1013   Value buffer = buf.getResult(0);
1014   token = buf.getAsyncToken();
1015 
1016   Value bufferSz2 = bufferComp.getResult(1);
1017   auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
1018   Value buffer2 = buf2.getResult(0);
1019   token = buf2.getAsyncToken();
1020 
1021   Value bufferSz3 = bufferComp.getResult(2);
1022   auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
1023   Value buffer3 = buf3.getResult(0);
1024   token = buf3.getAsyncToken();
1025 
1026   auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
1027 
1028   // Perform the SpMM.
1029   auto spmmComp = rewriter.create<gpu::SpMMOp>(
1030       loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
1031       SmallVector<Value>{buffer, buffer2, buffer3});
1032   token = spmmComp.getAsyncToken();
1033 
1034   // Copy data back to host and free all the resources.
1035   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
1036               .getAsyncToken();
1037   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
1038               .getAsyncToken();
1039   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
1040               .getAsyncToken();
1042   token = genDeallocMemRef(rewriter, loc, buffer, token);
1043   token = genDeallocMemRef(rewriter, loc, buffer2, token);
1044   token = genDeallocMemRef(rewriter, loc, buffer3, token);
1045   if (!isZeroCopy)
1046     token = genDeallocMemRef(rewriter, loc, matA, token);
1047   if (!isZeroCopy)
1048     token = genDeallocMemRef(rewriter, loc, matB, token);
1049   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
1050   token = genDeallocMemRef(rewriter, loc, matC, token);
1051   tokens.push_back(token);
1052   genBlockingWait(rewriter, loc, tokens);
1053   tokens.clear();
1054   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
1055     genHostUnregisterMemref(rewriter, loc, castA);
1056     genHostUnregisterMemref(rewriter, loc, castB);
1057     genHostUnregisterMemref(rewriter, loc, castC);
1058   }
1059 
1060   // Done.
1061   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
1062   return success();
1063 }
1064 
1065 /// Match and rewrite SDDMM kernel.
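/// SDDMM ("sampled dense-dense matrix multiplication") evaluates the dense
/// product A * B only at the nonzero positions of the sparse matrix C and
/// accumulates into it, i.e. c += spy(c) x (a * b), cf.
/// matchSumReductionOfMulUnary; the result is written back into C's values
/// in place.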
1066 static LogicalResult
1067 rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
1068              GPUDataTransferStrategy gpuDataTransferStrategy) {
1069   Location loc = op.getLoc();
1070   Value a = op.getOperand(0);
1071   Value b = op.getOperand(1);
1072   Value c = op.getOperand(2);
1073   SmallVector<Value> tokens;
1074 
1075   bool isZeroCopy =
1076       gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
1077 
1078   // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
1079   SparseTensorType aTp = getSparseTensorType(a);
1080   SparseTensorType bTp = getSparseTensorType(b);
1081   SparseTensorType cTp = getSparseTensorType(c);
1082   auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
1083   if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
1084       format == CuSparseFormat::kCSC)
1085     return failure();
1086 
1087   // The SDDMM performs its operation in place on the sparse output c.
1088   // Start sparse kernel and copy data from host to device.
1089   //   a : bufA           -> matA
1090   //   b : bufB           -> matB
1091   //   c : memR/memC/memV -> rowC,colC,valC
1092   Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
1093   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
1094   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
1095   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
1096   Value matA, matB;
1097   Value bufA = genTensorToMemref(rewriter, loc, a);
1098   if (!isZeroCopy)
1099     matA = genAllocCopy(rewriter, loc, bufA, tokens);
1100   Value bufB = genTensorToMemref(rewriter, loc, b);
1101   if (!isZeroCopy)
1102     matB = genAllocCopy(rewriter, loc, bufB, tokens);
1103   Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
1104   Value memC = genSecondCrds(rewriter, loc, c, format, enableRT);
1105   Value memV = genToValues(rewriter, loc, c);
1106   Value castB, castA, castR, castC, castV;
1107   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
1108     castB = genHostRegisterMemref(rewriter, loc, bufB);
1109     castA = genHostRegisterMemref(rewriter, loc, bufA);
1110     castR = genHostRegisterMemref(rewriter, loc, memR);
1111     if (memC)
1112       castC = genHostRegisterMemref(rewriter, loc, memC);
1113     castV = genHostRegisterMemref(rewriter, loc, memV);
1114   }
1115   if (isZeroCopy) {
1116     matA = bufA;
1117     matB = bufB;
1118   }
1119   Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
1120   Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
1121   Value valC = genAllocCopy(rewriter, loc, memV, tokens);
1122   genBlockingWait(rewriter, loc, tokens);
1123   tokens.clear();
1124 
1125   // Create sparse matrix/dense matrix handles.
1126   Type indexTp = rewriter.getIndexType();
1127   Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
1128   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
1129   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
1130   Value token = genFirstWait(rewriter, loc);
1131   auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
1132       loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
1133   Value dnA = dmatA.getResult(0);
1134   token = dmatA.getAsyncToken();
1135   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
1136       loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
1137   Value dnB = dmatB.getResult(0);
1138   token = dmatB.getAsyncToken();
1139   Operation *spGenC =
1140       genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szn, nseC,
1141                rowC, colC, valC, format, enableRT);
1142   Value spMatC = spGenC->getResult(0);
1143   token = spGenC->getResult(1);
1144   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
1145 
1146   // Precompute buffersize for SDDMM.
1147   auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
1148       loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
1149   Value bufferSz = bufferComp.getResult(0);
1150   token = bufferComp.getAsyncToken();
1151   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
1152   Value buffer = buf.getResult(0);
1153   token = buf.getAsyncToken();
1154 
1155   // Perform the SDDMM.
1156   auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
1157                                                  spMatC, dnCType, buffer);
1158   token = sddmmComp.getAsyncToken();
1159 
1160   // Copy data back to host and free all the resources.
1161   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
1162               .getAsyncToken();
1163   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
1164               .getAsyncToken();
1165   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
1166               .getAsyncToken();
1167   token = genDeallocMemRef(rewriter, loc, buffer, token);
1168   if (!isZeroCopy) {
1169     token = genDeallocMemRef(rewriter, loc, matA, token);
1170     token = genDeallocMemRef(rewriter, loc, matB, token);
1171   }
1172   token = genDeallocMemRef(rewriter, loc, rowC, token);
1173   if (colC)
1174     token = genDeallocMemRef(rewriter, loc, colC, token);
1175   token = genCopyMemRef(rewriter, loc, memV, valC, token);
1176   token = genDeallocMemRef(rewriter, loc, valC, token);
1177   tokens.push_back(token);
1178   genBlockingWait(rewriter, loc, tokens);
1179   tokens.clear();
1180   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
1181     genHostUnregisterMemref(rewriter, loc, castB);
1182     genHostUnregisterMemref(rewriter, loc, castA);
1183     genHostUnregisterMemref(rewriter, loc, castR);
1184     if (memC)
1185       genHostUnregisterMemref(rewriter, loc, castC);
1186     genHostUnregisterMemref(rewriter, loc, castV);
1187   }
1188 
1189   // Done.
1190   rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
1191   return success();
1192 }
1193 
1194 //===----------------------------------------------------------------------===//
1195 // Rewriting rules for direct code generation.
1196 //===----------------------------------------------------------------------===//
1197 
1198 /// Proof-of-concept rewriter. This rule generates a GPU implementation
1199 /// for each outermost forall loop generated by the sparse compiler.
1200 /// TODO: right now this works with parallelization-strategy=dense-outer-loop;
1201 ///       give this rewriting its own flag in the future.
1202 struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
1203   using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
1204 
1205   ForallRewriter(MLIRContext *context, unsigned nT)
1206       : OpRewritePattern(context), numThreads(nT) {}
1207 
1208   LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
1209                                 PatternRewriter &rewriter) const override {
1210     // Reject inadmissible loop form.
1211     // Essentially only accept a loop, generated by the sparse compiler,
1212     // of the form
1213     //   forall (i = 0; i < N; i++)
1214     // so that cyclic scheduling over the threads is easy.
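    // For illustration only, the accepted loops look roughly as follows
    // (a hand-written sketch, not verbatim sparse-compiler output; the exact
    // marker attribute comes from LoopEmitter::getLoopEmitterLoopAttrName()):
    //
    //   scf.parallel (%i) = (%c0) to (%N) step (%c1) {
    //     ... body, scheduled cyclically over the GPU threads ...
    //   } {<loop-emitter marker attribute>}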
1215     if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
1216         forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
1217         !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
1218         !matchPattern(forallOp.getStep()[0], m_One()))
1219       return failure();
1220     // Collect every value that is computed outside the parallel loop.
1221     SetVector<Value> invariants; // stable iteration!
1222     forallOp->walk([&](Operation *op) {
1223       // Collect all values of admissible ops.
1224       for (OpOperand &o : op->getOpOperands()) {
1225         Value val = o.get();
1226         Block *block;
1227         if (auto arg = dyn_cast<BlockArgument>(val))
1228           block = arg.getOwner();
1229         else
1230           block = val.getDefiningOp()->getBlock();
1231         if (!isNestedIn(block, forallOp))
1232           invariants.insert(val);
1233       }
1234     });
1235     // Outline the outside values as proper parameters. Fail when sharing
1236     // a value between host and device is not straightforward.
1237     SmallVector<Value> constants;
1238     SmallVector<Value> scalars;
1239     SmallVector<Value> buffers;
1240     for (Value val : invariants) {
1241       Type tp = val.getType();
1242       if (val.getDefiningOp<arith::ConstantOp>())
1243         constants.push_back(val);
1244       else if (isa<FloatType>(tp) || tp.isIntOrIndex())
1245         scalars.push_back(val);
1246       else if (isa<MemRefType>(tp))
1247         buffers.push_back(val);
1248       else
1249         return failure(); // don't know how to share
1250     }
1251     // Pass outlined non-constant values.
1252     // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
1253     //       keep the feature at all (either through a heuristic or compiler
1254     //       option for gpu codegen).
1255     Location loc = forallOp->getLoc();
1256     SmallVector<Value> args;
1257     SmallVector<Value> tokens;
1258     Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
1259                                 /*useHostRegistrationForOut=*/false);
1260     // Set up GPU module and construct GPU function.
1261     auto saveIp = rewriter.saveInsertionPoint();
1262     ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
1263     auto gpuModule = genGPUModule(rewriter, topModule);
1264     auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
1265     genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
1266     // Generate code that launches the kernel asynchronously, blocking on all
1267     // open tokens and yielding a new token for the output.
1268     // TODO: Passing in tokens to the launch op does not seem to be properly
1269     //       lowered by cubin yet, hence the current blocking wait.
1270     rewriter.restoreInsertionPoint(saveIp);
1271     genBlockingWait(rewriter, loc, tokens);
1272     tokens.clear();
1273     Value kernelToken =
1274         genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
1275     // Finalize the outlined arguments.
1276     genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
1277                      tokens);
1278     genBlockingWait(rewriter, loc, tokens);
1279     rewriter.eraseOp(forallOp);
1280     return success();
1281   }
1282 
1283 private:
1284   // Helper method to see if the given block is nested inside the given loop.
1285   static bool isNestedIn(Block *block, scf::ParallelOp forallOp) {
1286     for (Operation *o = block->getParentOp(); o; o = o->getParentOp()) {
1287       if (o == forallOp)
1288         return true;
1289     }
1290     return false;
1291   }
1292 
1293   unsigned numThreads;
1294 };
1295 
1296 //===----------------------------------------------------------------------===//
1297 // Rewriting rules for library recognition and code generation.
1298 //===----------------------------------------------------------------------===//
1299 
1300 /// Proof-of-concept rewriter. This rule recognizes certain math kernels
1301 /// and replaces these with corresponding calls into a sparse library.
1302 struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
1303   using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
1304 
1305   LinalgOpRewriter(MLIRContext *context, bool rt, GPUDataTransferStrategy t)
1306       : OpRewritePattern(context), enableRT(rt), gpuDataTransferStrategy(t) {}
1307 
1308   LogicalResult matchAndRewrite(linalg::GenericOp op,
1309                                 PatternRewriter &rewriter) const override {
1310     if (op.getNumDpsInits() != 1)
1311       return failure(); // reject multi-output
1312 
1313     const unsigned numLoops = op.getNumLoops();
1314     const unsigned numTensors = op->getNumOperands();
1315     const auto iteratorTypes = op.getIteratorTypesArray();
1316     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1317 
1318     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1319     auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1320     AffineExpr i, j, k;
1321     bindDims(getContext(), i, j, k);
1322 
1323     // TODO: more robust patterns, transposed versions, more kernels,
1324     //       identify alpha and beta and pass them to the CUDA calls.
1325 
1326     // Recognize a SpMV kernel.
1327     if (numLoops == 2 && numTensors == 3 &&
1328         linalg::isParallelIterator(iteratorTypes[0]) &&
1329         linalg::isReductionIterator(iteratorTypes[1]) &&
1330         maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
1331       return rewriteSpMV(rewriter, op, enableRT, gpuDataTransferStrategy);
1332     }
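    // For illustration only, the SpMV case above matches linalg.generic ops
    // of roughly the following shape (a hand-written sketch; the f64 element
    // type and the #CSR encoding alias are assumptions, not requirements):
    //
    //   %0 = linalg.generic {
    //          indexing_maps = [affine_map<(i, j) -> (i, j)>,
    //                           affine_map<(i, j) -> (j)>,
    //                           affine_map<(i, j) -> (i)>],
    //          iterator_types = ["parallel", "reduction"]}
    //        ins(%A, %x : tensor<?x?xf64, #CSR>, tensor<?xf64>)
    //        outs(%y : tensor<?xf64>) {
    //        ^bb0(%a: f64, %b: f64, %acc: f64):
    //          %p = arith.mulf %a, %b : f64
    //          %s = arith.addf %acc, %p : f64
    //          linalg.yield %s : f64
    //   } -> tensor<?xf64>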
1333 
1334     // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
1335     if (numLoops == 3 && numTensors == 3 &&
1336         linalg::isParallelIterator(iteratorTypes[0]) &&
1337         linalg::isParallelIterator(iteratorTypes[1]) &&
1338         linalg::isReductionIterator(iteratorTypes[2]) &&
1339         maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
1340       if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
1341         return rewriteSpGEMM(rewriter, op, enableRT, gpuDataTransferStrategy);
1342       if (op->getAttr("DENSE24"))
1343         return rewrite2To4SpMM(rewriter, op, gpuDataTransferStrategy);
1344       return rewriteSpMM(rewriter, op, enableRT, gpuDataTransferStrategy);
1345     }
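    // For illustration only, the SpMM-style cases above differ from the SpMV
    // sketch earlier mainly in their maps and iterators (same mul/add body):
    //
    //   indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
    //                    affine_map<(i, j, k) -> (k, j)>,
    //                    affine_map<(i, j, k) -> (i, j)>],
    //   iterator_types = ["parallel", "parallel", "reduction"]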
1346 
1347     // Recognize a SDDMM kernel.
1348     if (numLoops == 3 && numTensors == 3 &&
1349         linalg::isParallelIterator(iteratorTypes[0]) &&
1350         linalg::isParallelIterator(iteratorTypes[1]) &&
1351         linalg::isReductionIterator(iteratorTypes[2]) &&
1352         maps == infer({{i, k}, {k, j}, {i, j}}) &&
1353         matchSumReductionOfMulUnary(op)) {
1354       return rewriteSDDMM(rewriter, op, enableRT, gpuDataTransferStrategy);
1355     }
1356 
1357     return failure();
1358   }
1359 
1360 private:
1361   bool enableRT;
1362   GPUDataTransferStrategy gpuDataTransferStrategy;
1363 };
1364 
1365 } // namespace
1366 
1367 //===----------------------------------------------------------------------===//
1368 // Public methods for populating GPU rewriting rules.
1369 //
1370 // Currently two sets of rewriting rules are made available. The first set
1371 // implements direct code generation by converting the outermost parallel
1372 // loop into GPU threads. The second set implements library recognition of
1373 // a set of sparse operations. Eventually, the right combination of these
1374 // two approaches has to be found.
1375 //===----------------------------------------------------------------------===//
1376 
1377 void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
1378                                             unsigned numThreads) {
1379   patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
1380 }
1381 
1382 void mlir::populateSparseGPULibgenPatterns(
1383     RewritePatternSet &patterns, bool enableRT,
1384     GPUDataTransferStrategy gpuDataTransfer) {
1385   patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT,
1386                                  gpuDataTransfer);
1387 }
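
// For reference only, a minimal sketch of how a pass could wire up these
// patterns (assumed usage; the surrounding pass, the numThreads value, and
// the greedy-driver call are illustrative, not defined in this file):
//
//   RewritePatternSet patterns(ctx);
//   populateSparseGPUCodegenPatterns(patterns, /*numThreads=*/1024);
//   populateSparseGPULibgenPatterns(patterns, /*enableRT=*/true,
//                                   GPUDataTransferStrategy::kRegularDMA);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));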
1388