1 //===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a prototype GPU code generator for the sparsifier.
10 // The objective is to eventually use the right combination of
11 // direct code generation and library calls into vendor-specific
12 // highly optimized sparse libraries (e.g. cuSparse for CUDA).
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "Utils/CodegenUtils.h"
17 #include "Utils/LoopEmitter.h"
18 
19 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/MemRef/IR/MemRef.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
26 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
28 #include "mlir/IR/IRMapping.h"
29 #include "mlir/IR/Matchers.h"
30 
31 using namespace mlir;
32 using namespace mlir::sparse_tensor;
33 
34 namespace {
35 
36 // Sparse formats supported by cuSparse.
37 enum class CuSparseFormat {
38   kNone,
39   kCOO,
40   kCSR,
41   kCSC,
42   kBSR,
43 };
44 
45 //===----------------------------------------------------------------------===//
46 // Helper methods.
47 //===----------------------------------------------------------------------===//
48 
49 /// Marks the given top module as a GPU container module.
50 static void markAsGPUContainer(ModuleOp topModule) {
51   topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
52                      UnitAttr::get(topModule->getContext()));
53 }
54 
55 /// Constructs a new GPU module (for GPU kernels) inside the given top module,
56 /// or returns an existing GPU module if one was built previously.
57 static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
58   for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
59     return op; // existing
60   markAsGPUContainer(topModule);
61   builder.setInsertionPointToStart(topModule.getBody());
62   return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
63                                           "sparse_kernels");
64 }
65 
66 /// Constructs a new GPU kernel in the given GPU module.
67 static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
68                                  SmallVectorImpl<Value> &args) {
69   // Get a unique kernel name. Not very creative,
70   // but we simply try kernel0, kernel1, etc.
71   unsigned kernelNumber = 0;
72   SmallString<16> kernelName;
73   do {
74     kernelName.clear();
75     ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
76   } while (gpuModule.lookupSymbol(kernelName));
77   // Then we insert a new kernel with given arguments into the module.
78   builder.setInsertionPointToStart(gpuModule.getBody());
79   SmallVector<Type> argsTp;
80   for (auto arg : args)
81     argsTp.push_back(arg.getType());
82   FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
83   auto gpuFunc =
84       builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
85   gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
86                    builder.getUnitAttr());
87   return gpuFunc;
88 }
89 
90 /// Constructs code to launch GPU kernel.
91 static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
92                               SmallVectorImpl<Value> &args,
93                               SmallVectorImpl<Value> &tokens,
94                               unsigned numThreads) {
95   Location loc = gpuFunc->getLoc();
96   Value none = TypedValue<::mlir::IntegerType>{};
97   Value one = constantIndex(builder, loc, 1);
98   Value numT = constantIndex(builder, loc, numThreads);
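  // Launch a single block of numThreads threads (1x1x1 grid, numThreads x 1 x 1
  // block); the stride loop emitted by genGPUCode() below covers any remaining
  // iterations of the original parallel loop.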
99   gpu::KernelDim3 gridSize = {one, one, one};
100   gpu::KernelDim3 blckSize = {numT, one, one};
101   return builder
102       .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
103                                  /*dynSharedMemSz*/ none, args,
104                                  builder.getType<gpu::AsyncTokenType>(), tokens)
105       .getAsyncToken();
106 }
107 
108 /// Maps the provided ranked host buffer into the device address space.
109 /// Writes from the host are guaranteed to be visible to device kernels
110 /// that are launched afterwards. Writes from the device are guaranteed
111 /// to be visible on the host after synchronizing with the device kernel
112 /// completion. Needs to cast the buffer to an unranked buffer.
113 static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
114                                    Value mem) {
115   MemRefType memTp = cast<MemRefType>(mem.getType());
116   UnrankedMemRefType resTp =
117       UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
118   Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
119   builder.create<gpu::HostRegisterOp>(loc, cast);
120   return cast;
121 }
122 
123 /// Unmaps the provided buffer, expecting the cast (unranked) buffer.
124 static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
125                                     Value cast) {
126   builder.create<gpu::HostUnregisterOp>(loc, cast);
127 }
128 
129 /// Generates first wait in an asynchronous chain.
130 static Value genFirstWait(OpBuilder &builder, Location loc) {
131   Type tokenType = builder.getType<gpu::AsyncTokenType>();
132   return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
133       .getAsyncToken();
134 }
135 
136 /// Generates last, blocking wait in an asynchronous chain.
137 static void genBlockingWait(OpBuilder &builder, Location loc,
138                             ValueRange operands) {
139   builder.create<gpu::WaitOp>(loc, Type(), operands);
140 }
141 
142 /// Allocates memory on the device.
143 /// TODO: A `host_shared` attribute could be used to indicate that
144 ///       the buffer is visible to both host and device, but lowering
145 ///       that feature does not seem to be fully supported yet.
146 static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
147                                    Value token) {
148   auto tp = cast<ShapedType>(mem.getType());
149   auto elemTp = tp.getElementType();
150   auto shape = tp.getShape();
151   auto memTp = MemRefType::get(shape, elemTp);
152   SmallVector<Value> dynamicSizes;
153   for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
154     if (shape[r] == ShapedType::kDynamic) {
155       Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
156       dynamicSizes.push_back(dimOp);
157     }
158   }
159   return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
160                                       token, dynamicSizes, ValueRange());
161 }
162 
163 // Allocates a typed buffer on the host with given size.
164 static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
165                            Value size) {
166   const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
167   return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
168 }
169 
170 // Allocates a typed buffer on the device with given size.
171 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
172                                    Value size, Value token) {
173   const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
174   return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
175                                       token, size, ValueRange());
176 }
177 
178 // Allocates a void buffer on the device with given size.
179 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
180                                    Value token) {
181   return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
182 }
183 
184 /// Deallocates memory from the device.
185 static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
186                               Value token) {
187   return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
188       .getAsyncToken();
189 }
190 
191 /// Copies memory between host and device (direction is implicit).
192 static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
193                            Value src, Value token) {
194   return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
195       .getAsyncToken();
196 }
197 
198 /// Generates an alloc/copy pair.
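/// Emits the chain gpu.wait async -> gpu.alloc async -> gpu.memcpy async; the
/// token of the final copy is appended to `tokens` so that callers can
/// synchronize on it later.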
199 static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
200                           SmallVectorImpl<Value> &tokens) {
201   Value firstToken = genFirstWait(builder, loc);
202   auto alloc = genAllocMemRef(builder, loc, b, firstToken);
203   Value devMem = alloc.getResult(0);
204   Value depToken = alloc.getAsyncToken(); // copy-after-alloc
205   tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
206   return devMem;
207 }
208 
209 /// Generates a memref from tensor operation.
210 static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
211                                Value tensor) {
212   auto tensorType = llvm::cast<ShapedType>(tensor.getType());
213   auto memrefType =
214       MemRefType::get(tensorType.getShape(), tensorType.getElementType());
215   return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
216 }
217 
218 /// Prepares the outlined arguments, passing scalars and buffers in. Here we
219 /// assume that the first buffer is the one allocated for output. We create
220 /// a set of properly chained asynchronous allocation/copy pairs to increase
221 /// overlap before launching the kernel.
222 static Value genParametersIn(OpBuilder &builder, Location loc,
223                              SmallVectorImpl<Value> &scalars,
224                              SmallVectorImpl<Value> &buffers,
225                              SmallVectorImpl<Value> &args,
226                              SmallVectorImpl<Value> &tokens,
227                              bool useHostRegistrationForOut) {
228   Value out;
229   // Scalars are passed by value.
230   for (Value s : scalars)
231     args.push_back(s);
232   // Buffers need to be made visible on the device.
233   for (Value b : buffers) {
234     if (useHostRegistrationForOut) {
235       out = genHostRegisterMemref(builder, loc, b);
236       args.push_back(b);
237       useHostRegistrationForOut = false;
238       continue;
239     }
240     args.push_back(genAllocCopy(builder, loc, b, tokens));
241   }
242   return out;
243 }
244 
245 /// Finalizes the outlined arguments. The output buffer is copied depending
246 /// on the kernel token and then deallocated. All other buffers are simply
247 /// deallocated. Then we wait for all operations to complete.
248 static void genParametersOut(OpBuilder &builder, Location loc, Value out,
249                              Value kernelToken, SmallVectorImpl<Value> &scalars,
250                              SmallVectorImpl<Value> &buffers,
251                              SmallVectorImpl<Value> &args,
252                              SmallVectorImpl<Value> &tokens) {
253   unsigned base = scalars.size();
254   for (unsigned i = base, e = args.size(); i < e; i++) {
255     Value firstToken;
256     if (i == base) {
257       // Assumed output parameter: unregister or copy-out.
258       if (out) {
259         genHostUnregisterMemref(builder, loc, out);
260         out = Value();
261         continue;
262       }
263       firstToken =
264           genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
265     } else {
266       firstToken = genFirstWait(builder, loc);
267     }
268     tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
269   }
270 }
271 
272 /// Constructs code for new GPU kernel.
273 static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
274                        scf::ParallelOp forallOp,
275                        SmallVectorImpl<Value> &constants,
276                        SmallVectorImpl<Value> &scalars,
277                        SmallVectorImpl<Value> &buffers) {
278   Location loc = gpuFunc->getLoc();
279   Block &block = gpuFunc.getBody().front();
280   rewriter.setInsertionPointToStart(&block);
281 
282   // Re-generate the constants, recapture all arguments.
283   unsigned arg = 0;
284   IRMapping irMap;
285   for (Value c : constants)
286     irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
287   for (Value s : scalars)
288     irMap.map(s, block.getArgument(arg++));
289   for (Value b : buffers)
290     irMap.map(b, block.getArgument(arg++));
291 
292   // Assume 1-dimensional grid/block configuration (only x dimension),
293   // so that:
294   //   row = blockIdx.x * blockDim.x + threadIdx.x
295   //   inc = blockDim.x * gridDim.x
296   Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
297   Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
298   Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
299   Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
300   Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
301   Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
302   Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);
303 
304   // Construct the iteration over the computational space that
305   // accounts for the fact that the total number of threads and
306   // the amount of work to be done usually do not match precisely.
307   //   for (r = row; r < N; r += inc) {
308   //     <loop-body>
309   //   }
310   Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
311   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
312   // The scf.for builder creates an empty block. scf.for does not allow multiple
313   // blocks in its region, so delete the block before `cloneRegionBefore` adds
314   // an additional block.
315   rewriter.eraseBlock(forOp.getBody());
316   rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
317                              forOp.getRegion().begin(), irMap);
318   // Replace the scf.reduce terminator.
319   rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
320   rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator());
321 
322   // Done.
323   rewriter.setInsertionPointAfter(forOp);
324   rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // Library helper methods.
329 //===----------------------------------------------------------------------===//
330 
331 /// Helper to detect a + b with arguments taken from given block.
332 static bool matchAddOfArgs(Block *block, Value val) {
333   if (auto *def = val.getDefiningOp()) {
334     if (isa<arith::AddFOp, arith::AddIOp>(def)) {
335       Value a = block->getArguments()[0];
336       Value b = block->getArguments()[1];
337       return (def->getOperand(0) == a && def->getOperand(1) == b) ||
338              (def->getOperand(0) == b && def->getOperand(1) == a);
339     }
340   }
341   return false;
342 }
343 
344 /// Helper to detect a * b with arguments taken from given block.
345 static bool matchMulOfArgs(Block *block, Value val) {
346   if (auto *def = val.getDefiningOp()) {
347     if (isa<arith::MulFOp, arith::MulIOp>(def)) {
348       Value a = block->getArguments()[0];
349       Value b = block->getArguments()[1];
350       return (def->getOperand(0) == a && def->getOperand(1) == b) ||
351              (def->getOperand(0) == b && def->getOperand(1) == a);
352     }
353   }
354   return false;
355 }
356 
357 /// Helper to detect x = x + a * b
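/// For a three-argument linalg.generic block, an accepted body looks roughly
/// like (illustrative, f64 case):
///   ^bb0(%a: f64, %b: f64, %x: f64):
///     %0 = arith.mulf %a, %b : f64
///     %1 = arith.addf %x, %0 : f64
///     linalg.yield %1 : f64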
358 static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
359   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
360   if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
361     if (isa<arith::AddFOp, arith::AddIOp>(def)) {
362       Value x = op.getBlock()->getArguments()[2];
363       return (def->getOperand(0) == x &&
364               matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
365              (def->getOperand(1) == x &&
366               matchMulOfArgs(op.getBlock(), def->getOperand(0)));
367     }
368   }
369   return false;
370 }
371 
372 // Helper to detect c += spy(s) x (a * b)
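// Expected body shape (sketch): the output block argument feeds a
// sparse_tensor.unary whose present region yields a * b (and whose absent
// region is empty); the unary result and the output then feed a
// sparse_tensor.reduce whose region adds its two arguments.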
373 static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
374   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
375   // The linalg yields a custom reduce result.
376   Value s_out = op.getBlock()->getArguments()[2];
377   if (auto redOp =
378           yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
379     // The reduce consumes the output.
380     Value other;
381     if (s_out == redOp->getOperand(0))
382       other = redOp->getOperand(1);
383     else if (s_out == redOp->getOperand(1))
384       other = redOp->getOperand(0);
385     else
386       return false;
387     // The reduce op also consumes a unary that itself consumes the output
388     // and does not define an absent value.
389     if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
390       if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
391         return false;
392       // And the bodies are as expected.
393       auto yieldUn = cast<sparse_tensor::YieldOp>(
394           unOp.getRegion(0).front().getTerminator());
395       auto yieldRed = cast<sparse_tensor::YieldOp>(
396           redOp.getRegion().front().getTerminator());
397       return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
398              matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
399     }
400   }
401   return false;
402 }
403 
404 /// Test for dense tensor.
405 static bool isDenseTensor(Value v) {
406   auto sTp = getSparseTensorType(v);
407   return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
408 }
409 
410 /// Test for suitable positions/coordinates width.
411 static bool isAdmissibleMetaData(SparseTensorType &aTp) {
412   return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
413          (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
414 }
415 
416 /// Test for sorted COO matrix with suitable metadata.
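/// An admissible encoding looks roughly like (illustrative):
///   #SortedCOO = #sparse_tensor.encoding<{
///     map = (i, j) -> (i : compressed(nonunique), j : singleton)
///   }>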
417 static bool isAdmissibleCOO(SparseTensorType &aTp) {
418   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
419          aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
420          aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
421          isAdmissibleMetaData(aTp);
422 }
423 
424 /// Test for CSR matrix with suitable metadata.
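/// For instance (illustrative):
///   #CSR = #sparse_tensor.encoding<{
///     map = (i, j) -> (i : dense, j : compressed)
///   }>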
425 static bool isAdmissibleCSR(SparseTensorType &aTp) {
426   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
427          aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
428          aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
429 }
430 
431 /// Test for CSC matrix with suitable metadata.
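/// For instance (illustrative):
///   #CSC = #sparse_tensor.encoding<{
///     map = (i, j) -> (j : dense, i : compressed)
///   }>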
432 static bool isAdmissibleCSC(SparseTensorType &aTp) {
433   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
434          aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
435          aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
436 }
437 
438 /// Test for BSR matrix with suitable metadata.
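/// For instance, with 2x2 blocks (illustrative):
///   #BSR = #sparse_tensor.encoding<{
///     map = (i, j) -> (i floordiv 2 : dense, j floordiv 2 : compressed,
///                      i mod 2 : dense, j mod 2 : dense)
///   }>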
439 static bool isAdmissibleBSR(SparseTensorType &aTp) {
440   if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
441       aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
442       aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
443     // CuSparse only supports "square" blocks currently.
444     SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
445     assert(dims.size() == 2);
446     return dims[0] == dims[1] && dims[0] > 1;
447   }
448   return false;
449 }
450 
451 /// Test for 2:4 matrix with suitable metadata.
452 static bool isAdmissible24(SparseTensorType &aTp) {
453   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
454          aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp);
455 }
456 
457 /// Test for conversion into 2:4 matrix.
458 static bool isConversionInto24(Value v) {
459   if (auto cnv = v.getDefiningOp<ConvertOp>()) {
460     Value a = cnv.getResult();
461     Value d = cnv.getSource();
462     SparseTensorType aTp = getSparseTensorType(a);
463     return isDenseTensor(d) && isAdmissible24(aTp);
464   }
465   return false;
466 }
467 
468 /// Returns a suitable sparse format for the operation and given operand
469 /// types with cuSparse, or kNone if none is available.
470 static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
471                                         SparseTensorType bTp,
472                                         SparseTensorType cTp, bool enableRT,
473                                         bool isMatVec) {
474   // The other operands must have a dense type.
475   if (bTp.hasEncoding() || cTp.hasEncoding())
476     return CuSparseFormat::kNone;
477   // Now check for a suitable sparse format for the main operand.
478   if (isAdmissibleCOO(aTp))
479 #ifdef CUSPARSE_COO_AOS
480     return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
481 #else
482     return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
483 #endif
484   if (isAdmissibleCSR(aTp))
485     return CuSparseFormat::kCSR;
486   if (isAdmissibleCSC(aTp))
487     return CuSparseFormat::kCSC;
488   if (isAdmissibleBSR(aTp))
489     return CuSparseFormat::kBSR;
490   return CuSparseFormat::kNone;
491 }
492 
493 /// Generates the first positions/coordinates of a sparse matrix.
494 static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
495                                CuSparseFormat format, bool enableRT) {
496   if (format == CuSparseFormat::kCOO) {
497     // Library uses SoA COO, direct IR uses AoS COO.
498     if (enableRT)
499       return builder.create<ToCoordinatesOp>(loc, a, 0);
500     return builder.create<ToCoordinatesBufferOp>(loc, a);
501   }
502   // Formats CSR/CSC and BSR use positions at 1.
503   return builder.create<ToPositionsOp>(loc, a, 1);
504 }
505 
506 /// Generates the second coordinates of a sparse matrix.
507 static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
508                            CuSparseFormat format, bool enableRT) {
509   bool isCOO = format == CuSparseFormat::kCOO;
510   if (isCOO && !enableRT)
511     return Value(); // nothing needed
512   // Formats CSR/CSC and BSR use coordinates at 1.
513   return builder.create<ToCoordinatesOp>(loc, a, 1);
514 }
515 
516 /// Generates the sparse matrix handle.
517 static Operation *genSpMat(OpBuilder &builder, Location loc,
518                            SparseTensorType &aTp, Type handleTp, Type tokenTp,
519                            Value token, Value sz1, Value sz2, Value nseA,
520                            Value rowA, Value colA, Value valA,
521                            CuSparseFormat format, bool enableRT) {
522   if (format == CuSparseFormat::kCOO) {
523     // Library uses SoA COO, direct IR uses AoS COO.
524     if (enableRT) {
525       assert(colA);
526       return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
527                                               sz1, sz2, nseA, rowA, colA, valA);
528     }
529 #ifdef CUSPARSE_COO_AOS
530     assert(!colA);
531     return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token,
532                                                sz1, sz2, nseA, rowA, valA);
533 #else
534     llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
535 #endif
536   }
537   assert(colA);
538   if (format == CuSparseFormat::kCSR)
539     return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
540                                             sz2, nseA, rowA, colA, valA);
541   if (format == CuSparseFormat::kCSC)
542     return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
543                                             sz2, nseA, rowA, colA, valA);
544   // BSR requires a bit more work since we need to pass in the block size
545   // and all other sizes in terms of blocks (#block-rows, #block-cols,
546   // #nonzero-blocks).
547   assert(format == CuSparseFormat::kBSR);
548   SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
549   assert(dims.size() == 2 && dims[0] == dims[1]);
550   uint64_t b = dims[0];
551   Value bSz = constantIndex(builder, loc, b);
552   Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
553   Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
554   Value bNum = builder.create<arith::DivUIOp>(
555       loc, nseA, constantIndex(builder, loc, b * b));
556   return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
557                                           bCols, bNum, bSz, bSz, rowA, colA,
558                                           valA);
559 }
560 
561 /// Match and rewrite SpMV kernel.
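/// The generated code copies A, x, and y to the device, creates the sparse
/// matrix and dense vector handles, queries the workspace size with
/// gpu::SpMVBufferSizeOp, performs the multiplication with gpu::SpMVOp, and
/// finally destroys the handles, copies y back to the host, and releases the
/// device memory.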
562 static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
563                                  linalg::GenericOp op, bool enableRT) {
564   Location loc = op.getLoc();
565   Value a = op.getOperand(0);
566   Value x = op.getOperand(1);
567   Value y = op.getOperand(2); // we have y = Ax
568   SmallVector<Value> tokens;
569 
570   // Only admissible sparse matrix format and dense vectors (no BSR).
571   SparseTensorType aTp = getSparseTensorType(a);
572   SparseTensorType xTp = getSparseTensorType(x);
573   SparseTensorType yTp = getSparseTensorType(y);
574   auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
575   if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
576     return failure();
577 
578   // Start sparse kernel and copy data from host to device.
579   //   a : memR/memC/memV -> rowA,colA,valA
580   //   x : memX           -> vecX
581   //   y : memY           -> vecY
582   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
583   Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
584   Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
585   Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
586   Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
587   Value memV = rewriter.create<ToValuesOp>(loc, a);
588   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
589   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
590   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
591   Value memX = genTensorToMemref(rewriter, loc, x);
592   Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
593   Value memY = genTensorToMemref(rewriter, loc, y);
594   Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
595   genBlockingWait(rewriter, loc, tokens);
596   tokens.clear();
597 
598   // Create sparse environment and sparse matrix/dense vector handles.
599   Type indexTp = rewriter.getIndexType();
600   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
601   Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
602   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
603   Value token = genFirstWait(rewriter, loc);
604   Operation *spGenA =
605       genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
606                nseA, rowA, colA, valA, format, enableRT);
607   Value spMatA = spGenA->getResult(0);
608   token = spGenA->getResult(1);
609   auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
610       loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
611   Value dnX = dvecX.getResult(0);
612   token = dvecX.getAsyncToken();
613   auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
614       loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
615   Value dnY = dvecY.getResult(0);
616   token = dvecY.getAsyncToken();
617   auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
618 
619   // Precompute buffer size for SpMV.
620   auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
621       loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
622       /*computeType=*/dnYType);
623   Value bufferSz = bufferComp.getResult(0);
624   token = bufferComp.getAsyncToken();
625   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
626   Value buffer = buf.getResult(0);
627   token = buf.getAsyncToken();
628 
629   // Perform the SpMV.
630   auto spmvComp = rewriter.create<gpu::SpMVOp>(
631       loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer);
632   token = spmvComp.getAsyncToken();
633 
634   // Copy data back to host and free all the resources.
635   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
636               .getAsyncToken();
637   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
638               .getAsyncToken();
639   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
640               .getAsyncToken();
641   token = genDeallocMemRef(rewriter, loc, rowA, token);
642   if (colA)
643     token = genDeallocMemRef(rewriter, loc, colA, token);
644   token = genDeallocMemRef(rewriter, loc, valA, token);
645   token = genDeallocMemRef(rewriter, loc, buffer, token);
646   token = genDeallocMemRef(rewriter, loc, vecX, token);
647   token = genCopyMemRef(rewriter, loc, memY, vecY, token);
648   token = genDeallocMemRef(rewriter, loc, vecY, token);
649   tokens.push_back(token);
650   genBlockingWait(rewriter, loc, tokens);
651   tokens.clear();
652 
653   // Done.
654   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
655   return success();
656 }
657 
658 /// Match and rewrite SpMM kernel.
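/// Structurally the same as rewriteSpMV, but with dense matrix handles for B
/// and C, and gpu::SpMMBufferSizeOp/gpu::SpMMOp performing the computation.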
659 static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
660                                  linalg::GenericOp op, bool enableRT) {
661   Location loc = op.getLoc();
662   Value a = op.getOperand(0);
663   Value b = op.getOperand(1);
664   Value c = op.getOperand(2); // we have C = AB
665   SmallVector<Value> tokens;
666 
667   // Only admissible sparse matrix format and dense matrices (no BSR).
668   SparseTensorType aTp = getSparseTensorType(a);
669   SparseTensorType bTp = getSparseTensorType(b);
670   SparseTensorType cTp = getSparseTensorType(c);
671   auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
672   if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
673     return failure();
674 
675   // Start sparse kernel and copy data from host to device.
676   //   a : memR/memC/memV -> rowA,colA,valA
677   //   b : bufB           -> matB
678   //   c : bufC           -> matC
679   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
680   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
681   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
682   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
683   Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
684   Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
685   Value memV = rewriter.create<ToValuesOp>(loc, a);
686   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
687   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
688   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
689   Value bufB = genTensorToMemref(rewriter, loc, b);
690   Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
691   Value bufC = genTensorToMemref(rewriter, loc, c);
692   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
693   genBlockingWait(rewriter, loc, tokens);
694   tokens.clear();
695 
696   // Create sparse environment and sparse matrix/dense matrix handles.
697   Type indexTp = rewriter.getIndexType();
698   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
699   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
700   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
701   Value token = genFirstWait(rewriter, loc);
702   Operation *spGenA =
703       genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
704                nseA, rowA, colA, valA, format, enableRT);
705   Value spMatA = spGenA->getResult(0);
706   token = spGenA->getResult(1);
707   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
708       loc, dnTensorHandleTp, tokenTp, token, matB,
709       SmallVector<Value>{szk, szn});
710   Value dnB = dmatB.getResult(0);
711   token = dmatB.getAsyncToken();
712   auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
713       loc, dnTensorHandleTp, tokenTp, token, matC,
714       SmallVector<Value>{szm, szn});
715   Value dnC = dmatC.getResult(0);
716   token = dmatC.getAsyncToken();
717   auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();
718 
719   // Precompute buffer size for SpMM.
720   auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
721       loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
722       /*computeType=*/dmatCType);
723   Value bufferSz = bufferComp.getResult(0);
724   token = bufferComp.getAsyncToken();
725   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
726   Value buffer = buf.getResult(0);
727   token = buf.getAsyncToken();
728   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
729 
730   // Perform the SpMM.
731   auto spmmComp = rewriter.create<gpu::SpMMOp>(
732       loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer);
733   token = spmmComp.getAsyncToken();
734 
735   // Copy data back to host and free all the resources.
736   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
737               .getAsyncToken();
738   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
739               .getAsyncToken();
740   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
741               .getAsyncToken();
742   token = genDeallocMemRef(rewriter, loc, rowA, token);
743   if (colA)
744     token = genDeallocMemRef(rewriter, loc, colA, token);
745   token = genDeallocMemRef(rewriter, loc, valA, token);
746   token = genDeallocMemRef(rewriter, loc, buffer, token);
747   token = genDeallocMemRef(rewriter, loc, matB, token);
748   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
749   token = genDeallocMemRef(rewriter, loc, matC, token);
750   tokens.push_back(token);
751   genBlockingWait(rewriter, loc, tokens);
752   tokens.clear();
753 
754   // Done.
755   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
756   return success();
757 }
758 
759 // Match and rewrite SpGEMM kernel.
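// Follows the staged SpGEMM protocol of the GPU dialect: create a descriptor,
// run the WORK_ESTIMATION phase twice (first to query the buffer size, then
// with the allocated buffer), run the COMPUTE phase the same way, query the
// result sizes, allocate the final coordinate/value buffers, set the CSR
// pointers, and copy the product into C before assembling the result tensor.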
760 static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
761                                    linalg::GenericOp op, bool enableRT) {
762   Location loc = op.getLoc();
763   Value a = op.getOperand(0);
764   Value b = op.getOperand(1);
765   Value c = op.getOperand(2); // we have C = AB
766   SmallVector<Value> tokens;
767 
768   // Only CSR <- CSR x CSR supported.
769   auto format = CuSparseFormat::kCSR;
770   SparseTensorType aTp = getSparseTensorType(a);
771   SparseTensorType bTp = getSparseTensorType(b);
772   SparseTensorType cTp = getSparseTensorType(c);
773   if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
774     return failure();
775 
776   // Start sparse kernel and copy data from host to device.
777   //   a : amemR/amemC/amemV -> rowA,colA,valA
778   //   b : bmemR/bmemC/bmemV -> rowB,colB,valB
779   //   c : materializes
780   auto dnCType = cTp.getElementType();
781   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
782   Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
783   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
784   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
785   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
786   Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
787   Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
788   Value amemV = rewriter.create<ToValuesOp>(loc, a);
789   Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
790   Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
791   Value bmemV = rewriter.create<ToValuesOp>(loc, b);
792   Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
793   Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
794   Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
795   Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
796   Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
797   Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
798   genBlockingWait(rewriter, loc, tokens);
799   tokens.clear();
800 
801   // Create sparse environment and sparse matrix/dense vector handles.
802   Type indexTp = rewriter.getIndexType();
803   Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
804   Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
805   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
806   Value token = genFirstWait(rewriter, loc);
807   Operation *spGenA =
808       genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
809                nseA, rowA, colA, valA, format, enableRT);
810   Value spMatA = spGenA->getResult(0);
811   token = spGenA->getResult(1);
812   Operation *spGenB =
813       genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
814                nseB, rowB, colB, valB, format, enableRT);
815   Value spMatB = spGenB->getResult(0);
816   token = spGenB->getResult(1);
817 
818   // Sparse matrix C materializes (also assumes beta == 0).
819   Value zero = constantIndex(rewriter, loc, 0);
820   Value one = constantIndex(rewriter, loc, 1);
821   Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
822   auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
823   Value rowC = e1.getResult(0);
824   token = e1.getAsyncToken();
825   auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
826   Value colC = e2.getResult(0); // no free needed
827   token = e2.getAsyncToken();
828   auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
829   Value valC = e3.getResult(0); // no free needed
830   token = e3.getAsyncToken();
831   Operation *spGenC =
832       genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
833                zero, rowC, colC, valC, format, enableRT);
834   Value spMatC = spGenC->getResult(0);
835   token = spGenC->getResult(1);
836 
837   // Precompute buffer sizes for SpGEMM.
838   Operation *descOp =
839       rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
840   Value desc = descOp->getResult(0);
841   token = descOp->getResult(1);
842   Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
843       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
844       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
845       valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
846   Value bufferSz1 = work1->getResult(0);
847   token = work1->getResult(1);
848   auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
849   Value buffer1 = buf1.getResult(0);
850   token = buf1.getAsyncToken();
851   Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
852       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
853       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
854       bufferSz1, buffer1,
855       gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
856   token = work2->getResult(1);
857 
858   // Compute step.
859   Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
860       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
861       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
862       valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
863   Value bufferSz2 = compute1->getResult(0);
864   token = compute1->getResult(1);
865   auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
866   Value buffer2 = buf2.getResult(0);
867   token = buf2.getAsyncToken();
868   Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
869       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
870       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
871       bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
872   token = compute2->getResult(1);
873 
874   // Get sizes.
875   Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
876       loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
877   Value nnz = sizes->getResult(2);
878   token = sizes->getResult(3);
879   auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
880   colC = a2.getResult(0);
881   token = a2.getAsyncToken();
882   auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
883   valC = a3.getResult(0);
884   token = a3.getAsyncToken();
885 
886   // Update C with new pointers and copy final product back into C.
887   Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
888       loc, tokenTp, token, spMatC, rowC, colC, valC);
889   token = update->getResult(0);
890   Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
891       loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
892       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
893   token = copy->getResult(0);
894 
895   // Allocate buffers on host.
896   Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
897   Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
898   Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
899 
900   // Copy data back to host and free all the resources.
901   token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
902               .getAsyncToken();
903   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
904               .getAsyncToken();
905   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
906               .getAsyncToken();
907   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
908               .getAsyncToken();
909   token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
910   token = genCopyMemRef(rewriter, loc, colH, colC, token);
911   token = genCopyMemRef(rewriter, loc, valH, valC, token);
912   token = genDeallocMemRef(rewriter, loc, rowA, token);
913   token = genDeallocMemRef(rewriter, loc, colA, token);
914   token = genDeallocMemRef(rewriter, loc, valA, token);
915   token = genDeallocMemRef(rewriter, loc, rowB, token);
916   token = genDeallocMemRef(rewriter, loc, colB, token);
917   token = genDeallocMemRef(rewriter, loc, valB, token);
918   token = genDeallocMemRef(rewriter, loc, rowC, token);
919   token = genDeallocMemRef(rewriter, loc, colC, token);
920   token = genDeallocMemRef(rewriter, loc, valC, token);
921   token = genDeallocMemRef(rewriter, loc, buffer1, token);
922   token = genDeallocMemRef(rewriter, loc, buffer2, token);
923   tokens.push_back(token);
924   genBlockingWait(rewriter, loc, tokens);
925   tokens.clear();
926 
927   // Done.
928   Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
929   Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
930   Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
931   rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
932                                           vt);
933   return success();
934 }
935 
936 // Match and rewrite 2:4 SpMM kernel.
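// Unlike the generic SpMM path, the handle for A is created directly from the
// dense matrix with gpu::Create2To4SpMatOp (pruning and compression happen on
// the device), and the buffer-size query yields three workspace buffers that
// are all passed to gpu::SpMMOp.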
937 static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
938                                      linalg::GenericOp op) {
939   Location loc = op.getLoc();
940   Value A = op.getOperand(0);
941   Value B = op.getOperand(1);
942   Value C = op.getOperand(2); // we have C = AB
943   SmallVector<Value> tokens;
944   // The cuSPARSELt API currently only allows pruning and compression
945   // The cuSparselt API currently only allows pruning and compression
946   // to occur on the device. So we recognize the pattern
947   //    A' = convert A  ; dense to 2:4
948   //    C  = A'B        ; 2:4 matrix mult
949   // and then perform compression and matrix multiplication on device.
950   auto cnv = A.getDefiningOp<ConvertOp>();
951   assert(cnv);
952   A = cnv.getSource();
953 
954   // All input should be dense tensors.
955   if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
956     return failure();
957 
958   // Start sparse kernel and copy data from host to device.
959   //   a : bufA -> matA
960   //   b : bufB -> matB
961   //   c : bufC -> matC
962   Value bufA = genTensorToMemref(rewriter, loc, A);
963   Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
964   Value bufB = genTensorToMemref(rewriter, loc, B);
965   Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
966   Value bufC = genTensorToMemref(rewriter, loc, C);
967   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
968   genBlockingWait(rewriter, loc, tokens);
969   tokens.clear();
970 
971   // Create sparse environment and sparse matrix/dense vector handles.
972   Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
973   Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
974   Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
975   Type indexTp = rewriter.getIndexType();
976   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
977   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
978   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
979   Value token = genFirstWait(rewriter, loc);
980   Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
981       loc, spMatHandleTp, tokenTp, token, szm, szk,
982       gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
983   Value spMatA = spGenA->getResult(0);
984   token = spGenA->getResult(1);
985   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
986       loc, dnTensorHandleTp, tokenTp, token, matB,
987       SmallVector<Value>{szk, szn});
988   Value dnB = dmatB.getResult(0);
989   token = dmatB.getAsyncToken();
990   auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
991       loc, dnTensorHandleTp, tokenTp, token, matC,
992       SmallVector<Value>{szm, szn});
993   Value dnC = dmatC.getResult(0);
994   token = dmatC.getAsyncToken();
995   auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
996 
997   // Precompute buffer size for SpMM.
998   SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
999   TypeRange bufferTypes(bufferTypes_);
1000   auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
1001       loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
1002       gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
1003       /*computeType=*/dmatCType);
1004   token = bufferComp.getAsyncToken();
1005 
1006   // Allocate buffers on host.
1007   Value bufferSz1 = bufferComp.getResult(0);
1008   auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
1009   Value buffer1 = buf1.getResult(0);
1010   token = buf1.getAsyncToken();
1011   Value bufferSz2 = bufferComp.getResult(1);
1012   auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
1013   Value buffer2 = buf2.getResult(0);
1014   token = buf2.getAsyncToken();
1015   Value bufferSz3 = bufferComp.getResult(2);
1016   auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
1017   Value buffer3 = buf3.getResult(0);
1018   token = buf3.getAsyncToken();
1019 
1020   // Perform the SpMM.
1021   auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
1022   auto spmmComp = rewriter.create<gpu::SpMMOp>(
1023       loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
1024       SmallVector<Value>{buffer1, buffer2, buffer3});
1025   token = spmmComp.getAsyncToken();
1026 
1027   // Copy data back to host and free all the resources.
1028   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
1029               .getAsyncToken();
1030   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
1031               .getAsyncToken();
1032   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
1033               .getAsyncToken();
1034   SmallVector<Value> newDynamicSizes;
1035   token = genDeallocMemRef(rewriter, loc, buffer1, token);
1036   token = genDeallocMemRef(rewriter, loc, buffer2, token);
1037   token = genDeallocMemRef(rewriter, loc, buffer3, token);
1038   token = genDeallocMemRef(rewriter, loc, matA, token);
1039   token = genDeallocMemRef(rewriter, loc, matB, token);
1040   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
1041   token = genDeallocMemRef(rewriter, loc, matC, token);
1042   tokens.push_back(token);
1043   genBlockingWait(rewriter, loc, tokens);
1044   tokens.clear();
1045 
1046   // Done.
1047   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
1048   return success();
1049 }
1050 
1051 /// Match and rewrite SDDMM kernel.
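/// A and B are dense inputs, while the sparse matrix C is updated in place:
/// only its values buffer is copied back before the final sparse_tensor.load.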
1052 static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
1053                                   linalg::GenericOp op, bool enableRT) {
1054   Location loc = op.getLoc();
1055   Value a = op.getOperand(0);
1056   Value b = op.getOperand(1);
1057   Value c = op.getOperand(2);
1058   SmallVector<Value> tokens;
1059 
1060   // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
1061   SparseTensorType aTp = getSparseTensorType(a);
1062   SparseTensorType bTp = getSparseTensorType(b);
1063   SparseTensorType cTp = getSparseTensorType(c);
1064   auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
1065   if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
1066       format == CuSparseFormat::kCSC)
1067     return failure();
1068 
1069   // SDDMM performs the operation in place.
1070   // Start sparse kernel and copy data from host to device.
1071   //   a : bufA           -> matA
1072   //   b : bufB           -> matB
1073   //   c : memR/memC/memV -> rowC,colC,valC
1074   Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
1075   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
1076   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
1077   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
1078   Value bufA = genTensorToMemref(rewriter, loc, a);
1079   Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
1080   Value bufB = genTensorToMemref(rewriter, loc, b);
1081   Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
1082   Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
1083   Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
1084   Value memV = rewriter.create<ToValuesOp>(loc, c);
1085   Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
1086   Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
1087   Value valC = genAllocCopy(rewriter, loc, memV, tokens);
1088   genBlockingWait(rewriter, loc, tokens);
1089   tokens.clear();
1090 
1091   // Create sparse environment and sparse matrix/dense matrix handles.
1092   Type indexTp = rewriter.getIndexType();
1093   Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
1094   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
1095   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
1096   Value token = genFirstWait(rewriter, loc);
1097   auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
1098       loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
1099   Value dnA = dmatA.getResult(0);
1100   token = dmatA.getAsyncToken();
1101   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
1102       loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
1103   Value dnB = dmatB.getResult(0);
1104   token = dmatB.getAsyncToken();
1105   Operation *spGenC =
1106       genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
1107                nseC, rowC, colC, valC, format, enableRT);
1108   Value spMatC = spGenC->getResult(0);
1109   token = spGenC->getResult(1);
1110   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
1111 
1112   // Precompute buffer size for SDDMM.
1113   auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
1114       loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
1115   Value bufferSz = bufferComp.getResult(0);
1116   token = bufferComp.getAsyncToken();
1117   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
1118   Value buffer = buf.getResult(0);
1119   token = buf.getAsyncToken();
1120 
1121   // Perform the SDDMM.
1122   auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
1123                                                  spMatC, dnCType, buffer);
1124   token = sddmmComp.getAsyncToken();
1125 
1126   // Copy data back to host and free all the resources.
1127   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
1128               .getAsyncToken();
1129   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
1130               .getAsyncToken();
1131   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
1132               .getAsyncToken();
1133   token = genDeallocMemRef(rewriter, loc, buffer, token);
1134   token = genDeallocMemRef(rewriter, loc, matA, token);
1135   token = genDeallocMemRef(rewriter, loc, matB, token);
1136   token = genDeallocMemRef(rewriter, loc, rowC, token);
1137   if (colC)
1138     token = genDeallocMemRef(rewriter, loc, colC, token);
1139   token = genCopyMemRef(rewriter, loc, memV, valC, token);
1140   token = genDeallocMemRef(rewriter, loc, valC, token);
1141   tokens.push_back(token);
1142   genBlockingWait(rewriter, loc, tokens);
1143   tokens.clear();
1144 
1145   // Done.
1146   rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
1147   return success();
1148 }
1149 
1150 //===----------------------------------------------------------------------===//
1151 // Rewriting rules for direct code generation.
1152 //===----------------------------------------------------------------------===//
1153 
1154 /// Proof-of-concept rewriter. This rule generates a GPU implementation
1155 /// for each outermost forall loop generated by the sparsifier.
1156 /// TODO: right now works with parallelization-strategy=dense-outer-loop
1157 ///       but give this its own flags in the future
1158 struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
1159   using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
1160 
1161   ForallRewriter(MLIRContext *context, unsigned nT)
1162       : OpRewritePattern(context), numThreads(nT) {}
1163 
1164   LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
1165                                 PatternRewriter &rewriter) const override {
1166     // Reject inadmissible loop form.
1167     // Essentially only accept a loop, generated by the sparsifier,
1168     // of the form
1169     //   forall (i = 0; i < N; i++)
1170     // so that cyclic scheduling over the threads is easy.
1171     if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
1172         forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
1173         !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
1174         !matchPattern(forallOp.getStep()[0], m_One()))
1175       return failure();
1176     // Collect every value that is computed outside the parallel loop.
1177     SetVector<Value> invariants; // stable iteration!
1178     forallOp->walk([&](Operation *op) {
1179       // Collect all values of admissible ops.
1180       for (OpOperand &o : op->getOpOperands()) {
1181         Value val = o.get();
1182         Block *block;
1183         if (auto arg = dyn_cast<BlockArgument>(val))
1184           block = arg.getOwner();
1185         else
1186           block = val.getDefiningOp()->getBlock();
1187         if (!forallOp.getRegion().findAncestorBlockInRegion(*block))
1188           invariants.insert(val);
1189       }
1190     });
1191     // Outline the outside values as proper parameters. Fail when sharing a
1192     // value between host and device is not straightforward.
1193     SmallVector<Value> constants;
1194     SmallVector<Value> scalars;
1195     SmallVector<Value> buffers;
1196     for (Value val : invariants) {
1197       Type tp = val.getType();
1198       if (val.getDefiningOp<arith::ConstantOp>())
1199         constants.push_back(val);
1200       else if (isa<FloatType>(tp) || tp.isIntOrIndex())
1201         scalars.push_back(val);
1202       else if (isa<MemRefType>(tp))
1203         buffers.push_back(val);
1204       else
1205         return failure(); // don't know how to share
1206     }
1207     // Pass outlined non-constant values.
1208     // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
1209     //       keep the feature at all (either through a heuristic or compiler
1210     //       option for gpu codegen).
1211     Location loc = forallOp->getLoc();
1212     SmallVector<Value> args;
1213     SmallVector<Value> tokens;
1214     Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
1215                                 /*useHostRegistrationForOut=*/false);
1216     // Set up GPU module and construct GPU function.
1217     auto saveIp = rewriter.saveInsertionPoint();
1218     ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
1219     auto gpuModule = genGPUModule(rewriter, topModule);
1220     auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
1221     genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
1222     // Generate code that launches the kernel asynchronously, blocking on all
1223     // open tokens and yielding a new token for the output.
1224     // TODO: Passing tokens into the kernel launch does not seem to be
1225     //       properly lowered by cubin yet, hence the current blocking wait.
1226     rewriter.restoreInsertionPoint(saveIp);
1227     genBlockingWait(rewriter, loc, tokens);
1228     tokens.clear();
1229     Value kernelToken =
1230         genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
1231     // Finalize the outlined arguments.
1232     genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
1233                      tokens);
1234     genBlockingWait(rewriter, loc, tokens);
1235     rewriter.eraseOp(forallOp);
1236     return success();
1237   }
1238 
1239 private:
1240   unsigned numThreads;
1241 };
1242 
1243 //===----------------------------------------------------------------------===//
1244 // Rewriting rules for library recognition and code generation.
1245 //===----------------------------------------------------------------------===//
1246 
1247 /// Proof-of-concept rewriter. This rule recognizes certain math kernels
1248 /// and replaces these with corresponding calls into a sparse library.
1249 struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
1250   using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
1251 
1252   LinalgOpRewriter(MLIRContext *context, bool rt)
1253       : OpRewritePattern(context), enableRT(rt) {}
1254 
1255   LogicalResult matchAndRewrite(linalg::GenericOp op,
1256                                 PatternRewriter &rewriter) const override {
1257     if (op.getNumDpsInits() != 1)
1258       return failure(); // reject multi-output
1259 
1260     const unsigned numLoops = op.getNumLoops();
1261     const unsigned numTensors = op->getNumOperands();
1262     const auto iteratorTypes = op.getIteratorTypesArray();
1263     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1264 
1265     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1266     auto infer = [&](MapList m) {
1267       return AffineMap::inferFromExprList(m, op.getContext());
1268     };
1269     AffineExpr i, j, k;
1270     bindDims(getContext(), i, j, k);
1271 
1272     // TODO: more robust patterns, transposed versions, more kernels,
1273     //       identify alpha and beta and pass them to the CUDA calls.
1274 
1275     // Recognize a SpMV kernel.
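    // Indexing maps {(i,j), (j), (i)} with (parallel, reduction) iterators
    // encode y(i) += A(i,j) * x(j).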
1276     if (numLoops == 2 && numTensors == 3 &&
1277         linalg::isParallelIterator(iteratorTypes[0]) &&
1278         linalg::isReductionIterator(iteratorTypes[1]) &&
1279         maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
1280       return rewriteSpMV(rewriter, op, enableRT);
1281     }
1282 
1283     // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
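    // Indexing maps {(i,k), (k,j), (i,j)} encode C(i,j) += A(i,k) * B(k,j);
    // the sparsity of A and B selects between SpGEMM, 2:4-SpMM, and SpMM.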
1284     if (numLoops == 3 && numTensors == 3 &&
1285         linalg::isParallelIterator(iteratorTypes[0]) &&
1286         linalg::isParallelIterator(iteratorTypes[1]) &&
1287         linalg::isReductionIterator(iteratorTypes[2]) &&
1288         maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
1289       if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
1290         return rewriteSpGEMM(rewriter, op, enableRT);
1291       if (isConversionInto24(op.getOperand(0)))
1292         return rewrite2To4SpMM(rewriter, op);
1293       return rewriteSpMM(rewriter, op, enableRT);
1294     }
1295 
1296     // Recognize a SDDMM kernel.
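    // Same maps, but the body matches c += spy(s) x (a * b): a dense-dense
    // matrix multiplication sampled by the sparsity pattern of the output.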
1297     if (numLoops == 3 && numTensors == 3 &&
1298         linalg::isParallelIterator(iteratorTypes[0]) &&
1299         linalg::isParallelIterator(iteratorTypes[1]) &&
1300         linalg::isReductionIterator(iteratorTypes[2]) &&
1301         maps == infer({{i, k}, {k, j}, {i, j}}) &&
1302         matchSumReductionOfMulUnary(op)) {
1303       return rewriteSDDMM(rewriter, op, enableRT);
1304     }
1305 
1306     return failure();
1307   }
1308 
1309 private:
1310   bool enableRT;
1311 };
1312 
1313 } // namespace
1314 
1315 //===----------------------------------------------------------------------===//
1316 // Public method for populating GPU rewriting rules.
1317 //
1318 // Currently two sets of rewriting rules are made available. The first set
1319 // implements direct code generation, currently by means of converting the
1320 // outermost parallel loop into GPU threads. The second set implements
1321 // library recognition of a set of sparse operations. Eventually, the right
1322 // combination of these two approaches has to be found.
1323 //===----------------------------------------------------------------------===//
1324 
1325 void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
1326                                             unsigned numThreads) {
1327   patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
1328 }
1329 
1330 void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
1331                                            bool enableRT) {
1332   patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
1333 }
1334