1 //===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a prototype GPU code generator for the sparsifier.
10 // The objective is to eventually use the right combination of
11 // direct code generation and library calls into vendor-specific
12 // highly optimized sparse libraries (e.g. cuSparse for CUDA).
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "Utils/CodegenUtils.h"
17 #include "Utils/LoopEmitter.h"
18 
19 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/MemRef/IR/MemRef.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
26 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
28 #include "mlir/IR/IRMapping.h"
29 #include "mlir/IR/Matchers.h"
30 
31 using namespace mlir;
32 using namespace mlir::sparse_tensor;
33 
34 namespace {
35 
36 // Sparse formats supported by cuSparse.
37 enum class CuSparseFormat {
38   kNone,
39   kCOO,
40   kCSR,
41   kCSC,
42   kBSR,
43 };
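// For orientation, these formats roughly correspond to sparse_tensor encodings
// of the following shape (an illustrative sketch, not an exhaustive mapping):
//   CSR: #sparse_tensor.encoding<{ map = (i, j) -> (i : dense, j : compressed) }>
//   CSC: #sparse_tensor.encoding<{ map = (i, j) -> (j : dense, i : compressed) }>
//   COO: #sparse_tensor.encoding<{ map = (i, j) -> (i : compressed(nonunique), j : singleton) }>
// The isAdmissible* helpers below verify these shapes level by level.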
44 
45 //===----------------------------------------------------------------------===//
46 // Helper methods.
47 //===----------------------------------------------------------------------===//
48 
49 /// Marks the given top module as a GPU container module.
50 static void markAsGPUContainer(ModuleOp topModule) {
51   topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
52                      UnitAttr::get(topModule->getContext()));
53 }
54 
55 /// Constructs a new GPU module (for GPU kernels) inside the given top module,
56 /// or returns an existing GPU module if one was built previously.
57 static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
58   for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
59     return op; // existing
60   markAsGPUContainer(topModule);
61   builder.setInsertionPointToStart(&topModule.getBodyRegion().front());
62   return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
63                                           "sparse_kernels");
64 }
65 
66 /// Constructs a new GPU kernel in the given GPU module.
67 static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
68                                  SmallVectorImpl<Value> &args) {
69   // Get a unique kernel name. Not very creative,
70   // but we simply try kernel0, kernel1, etc.
71   unsigned kernelNumber = 0;
72   SmallString<16> kernelName;
73   do {
74     kernelName.clear();
75     ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
76   } while (gpuModule.lookupSymbol(kernelName));
77   // Then we insert a new kernel with given arguments into the module.
78   builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
79   SmallVector<Type> argsTp;
80   for (unsigned i = 0, e = args.size(); i < e; i++)
81     argsTp.push_back(args[i].getType());
82   FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
83   auto gpuFunc =
84       builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
85   gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
86                    builder.getUnitAttr());
87   return gpuFunc;
88 }
89 
90 /// Constructs code to launch GPU kernel.
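/// A rough sketch of the IR this produces (names illustrative, assuming a
/// one-dimensional launch configuration):
///   %t = gpu.launch_func async [%tokens] @sparse_kernels::@kernel0
///            blocks in (%c1, %c1, %c1) threads in (%numT, %c1, %c1)
///            args(%arg0 : memref<?xf64>, ...)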
91 static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
92                               SmallVectorImpl<Value> &args,
93                               SmallVectorImpl<Value> &tokens,
94                               unsigned numThreads) {
95   Location loc = gpuFunc->getLoc();
96   Value none = TypedValue<::mlir::IntegerType>{};
97   Value one = constantIndex(builder, loc, 1);
98   Value numT = constantIndex(builder, loc, numThreads);
99   gpu::KernelDim3 gridSize = {one, one, one};
100   gpu::KernelDim3 blckSize = {numT, one, one};
101   return builder
102       .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
103                                  /*dynSharedMemSz*/ none, args,
104                                  builder.getType<gpu::AsyncTokenType>(), tokens)
105       .getAsyncToken();
106 }
107 
108 /// Maps the provided ranked host buffer into the device address space.
109 /// Writes from the host are guaranteed to be visible to device kernels
110 /// that are launched afterwards. Writes from the device are guaranteed
111 /// to be visible on the host after synchronizing with the device kernel
112 /// completion. Needs to cast the buffer to an unranked buffer.
113 static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
114                                    Value mem) {
115   MemRefType memTp = cast<MemRefType>(mem.getType());
116   UnrankedMemRefType resTp =
117       UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
118   Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
119   builder.create<gpu::HostRegisterOp>(loc, cast);
120   return cast;
121 }
122 
123 /// Unmaps the provided buffer, expecting the unranked buffer produced by the cast.
124 static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
125                                     Value cast) {
126   builder.create<gpu::HostUnregisterOp>(loc, cast);
127 }
128 
129 /// Generates first wait in an asynchronous chain.
130 static Value genFirstWait(OpBuilder &builder, Location loc) {
131   Type tokenType = builder.getType<gpu::AsyncTokenType>();
132   return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
133       .getAsyncToken();
134 }
135 
136 /// Generates last, blocking wait in an asynchronous chain.
137 static void genBlockingWait(OpBuilder &builder, Location loc,
138                             ValueRange operands) {
139   builder.create<gpu::WaitOp>(loc, Type(), operands);
140 }
141 
142 /// Allocates memory on the device.
143 /// TODO: A `host_shared` attribute could be used to indicate that
144 ///       the buffer is visible to both host and device, but lowering
145 ///       that feature does not seem to be fully supported yet.
146 static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
147                                    Value token) {
148   auto tp = cast<ShapedType>(mem.getType());
149   auto elemTp = tp.getElementType();
150   auto shape = tp.getShape();
151   auto memTp = MemRefType::get(shape, elemTp);
152   SmallVector<Value> dynamicSizes;
153   for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
154     if (shape[r] == ShapedType::kDynamic) {
155       Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
156       dynamicSizes.push_back(dimOp);
157     }
158   }
159   return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
160                                       token, dynamicSizes, ValueRange());
161 }
162 
163 /// Allocates a typed buffer on the host with the given size.
164 static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
165                            Value size) {
166   const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
167   return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
168 }
169 
170 /// Allocates a typed buffer on the device with the given size.
171 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
172                                    Value size, Value token) {
173   const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
174   return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
175                                       token, size, ValueRange());
176 }
177 
178 /// Allocates a void buffer on the device with the given size.
179 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
180                                    Value token) {
181   return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
182 }
183 
184 /// Deallocates memory from the device.
185 static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
186                               Value token) {
187   return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
188       .getAsyncToken();
189 }
190 
191 /// Copies memory between host and device (direction is implicit).
192 static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
193                            Value src, Value token) {
194   return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
195       .getAsyncToken();
196 }
197 
198 /// Generates an alloc/copy pair.
199 static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
200                           SmallVectorImpl<Value> &tokens) {
201   Value firstToken = genFirstWait(builder, loc);
202   auto alloc = genAllocMemRef(builder, loc, b, firstToken);
203   Value devMem = alloc.getResult(0);
204   Value depToken = alloc.getAsyncToken(); // copy-after-alloc
205   tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
206   return devMem;
207 }
208 
209 /// Generates a memref from tensor operation.
210 static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
211                                Value tensor) {
212   auto tensorType = llvm::cast<ShapedType>(tensor.getType());
213   auto memrefType =
214       MemRefType::get(tensorType.getShape(), tensorType.getElementType());
215   return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
216 }
217 
218 /// Prepares the outlined arguments, passing scalars and buffers in. Here we
219 /// assume that the first buffer is the one allocated for output. We create
220 /// a set of properly chained asynchronous allocation/copy pairs to increase
221 /// overlap before launching the kernel.
222 static Value genParametersIn(OpBuilder &builder, Location loc,
223                              SmallVectorImpl<Value> &scalars,
224                              SmallVectorImpl<Value> &buffers,
225                              SmallVectorImpl<Value> &args,
226                              SmallVectorImpl<Value> &tokens,
227                              bool useHostRegistrationForOut) {
228   Value out;
229   // Scalars are passed by value.
230   for (Value s : scalars)
231     args.push_back(s);
232   // Buffers need to be made visible on the device.
233   for (Value b : buffers) {
234     if (useHostRegistrationForOut) {
235       out = genHostRegisterMemref(builder, loc, b);
236       args.push_back(b);
237       useHostRegistrationForOut = false;
238       continue;
239     }
240     args.push_back(genAllocCopy(builder, loc, b, tokens));
241   }
242   return out;
243 }
244 
245 /// Finalizes the outlined arguments. The output buffer is copied depending
246 /// on the kernel token and then deallocated. All other buffers are simply
247 /// deallocated. Then we wait for all operations to complete.
248 static void genParametersOut(OpBuilder &builder, Location loc, Value out,
249                              Value kernelToken, SmallVectorImpl<Value> &scalars,
250                              SmallVectorImpl<Value> &buffers,
251                              SmallVectorImpl<Value> &args,
252                              SmallVectorImpl<Value> &tokens) {
253   unsigned base = scalars.size();
254   for (unsigned i = base, e = args.size(); i < e; i++) {
255     Value firstToken;
256     if (i == base) {
257       // Assumed output parameter: unregister or copy-out.
258       if (out) {
259         genHostUnregisterMemref(builder, loc, out);
260         out = Value();
261         continue;
262       }
263       firstToken =
264           genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
265     } else {
266       firstToken = genFirstWait(builder, loc);
267     }
268     tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
269   }
270 }
271 
272 /// Constructs code for new GPU kernel.
273 static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
274                        scf::ParallelOp forallOp,
275                        SmallVectorImpl<Value> &constants,
276                        SmallVectorImpl<Value> &scalars,
277                        SmallVectorImpl<Value> &buffers) {
278   Location loc = gpuFunc->getLoc();
279   Block &block = gpuFunc.getBody().front();
280   rewriter.setInsertionPointToStart(&block);
281 
282   // Re-generate the constants, recapture all arguments.
283   unsigned arg = 0;
284   IRMapping irMap;
285   for (Value c : constants)
286     irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
287   for (Value s : scalars)
288     irMap.map(s, block.getArgument(arg++));
289   for (Value b : buffers)
290     irMap.map(b, block.getArgument(arg++));
291 
292   // Assume 1-dimensional grid/block configuration (only x dimension),
293   // so that:
294   //   row = blockIdx.x * blockDim.x + threadIdx.x
295   //   inc = blockDim.x * gridDim.x
296   Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
297   Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
298   Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
299   Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
300   Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
301   Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
302   Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);
303 
304   // Construct the iteration over the computational space that
305   // accounts for the fact that the total number of threads and
306   // the amount of work to be done usually do not match precisely.
307   //   for (r = row; r < N; r += inc) {
308   //     <loop-body>
309   //   }
310   Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
311   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
312   // The scf.for builder creates an empty block. scf.for does not allow multiple
313   // blocks in its region, so delete the block before `cloneRegionBefore` adds
314   // an additional block.
315   rewriter.eraseBlock(forOp.getBody());
316   rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
317                              forOp.getRegion().begin(), irMap);
318   // Replace the scf.reduce terminator.
319   rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
320   rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator());
321 
322   // Done.
323   rewriter.setInsertionPointAfter(forOp);
324   rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // Library helper methods.
329 //===----------------------------------------------------------------------===//
330 
331 /// Helper to detect a + b with arguments taken from given block.
332 static bool matchAddOfArgs(Block *block, Value val) {
333   if (auto *def = val.getDefiningOp()) {
334     if (isa<arith::AddFOp, arith::AddIOp>(def)) {
335       Value a = block->getArguments()[0];
336       Value b = block->getArguments()[1];
337       return (def->getOperand(0) == a && def->getOperand(1) == b) ||
338              (def->getOperand(0) == b && def->getOperand(1) == a);
339     }
340   }
341   return false;
342 }
343 
344 /// Helper to detect a * b with arguments taken from given block.
345 static bool matchMulOfArgs(Block *block, Value val) {
346   if (auto *def = val.getDefiningOp()) {
347     if (isa<arith::MulFOp, arith::MulIOp>(def)) {
348       Value a = block->getArguments()[0];
349       Value b = block->getArguments()[1];
350       return (def->getOperand(0) == a && def->getOperand(1) == b) ||
351              (def->getOperand(0) == b && def->getOperand(1) == a);
352     }
353   }
354   return false;
355 }
356 
357 /// Helper to detect x = x + a * b.
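/// As an illustrative example (types hypothetical), the matched body of the
/// linalg.generic looks like:
///   ^bb0(%a: f64, %b: f64, %x: f64):
///     %0 = arith.mulf %a, %b : f64
///     %1 = arith.addf %x, %0 : f64
///     linalg.yield %1 : f64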
358 static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
359   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
360   if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
361     if (isa<arith::AddFOp, arith::AddIOp>(def)) {
362       Value x = op.getBlock()->getArguments()[2];
363       return (def->getOperand(0) == x &&
364               matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
365              (def->getOperand(1) == x &&
366               matchMulOfArgs(op.getBlock(), def->getOperand(0)));
367     }
368   }
369   return false;
370 }
371 
372 /// Helper to detect c += spy(s) x (a * b).
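/// Very roughly (a pseudo-IR sketch, names hypothetical), the matched body is:
///   %u = sparse_tensor.unary %s_out  { present: yield a * b } { absent: }
///   %r = sparse_tensor.reduce %s_out, %u, %id { ^bb0(%p, %q): yield %p + %q }
///   linalg.yield %r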
373 static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
374   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
375   // The linalg yields a custom reduce result.
376   Value s_out = op.getBlock()->getArguments()[2];
377   if (auto redOp =
378           yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
379     // The reduce consumes the output.
380     Value other;
381     if (s_out == redOp->getOperand(0))
382       other = redOp->getOperand(1);
383     else if (s_out == redOp->getOperand(1))
384       other = redOp->getOperand(0);
385     else
386       return false;
387     // The reduce op also consumes a unary op that consumes the output
388     // and does not define an absent value.
389     if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
390       if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
391         return false;
392       // And the bodies are as expected.
393       auto yieldUn = cast<sparse_tensor::YieldOp>(
394           unOp.getRegion(0).front().getTerminator());
395       auto yieldRed = cast<sparse_tensor::YieldOp>(
396           redOp.getRegion().front().getTerminator());
397       return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
398              matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
399     }
400   }
401   return false;
402 }
403 
404 /// Test for dense tensor.
405 static bool isDenseTensor(Value v) {
406   auto sTp = getSparseTensorType(v);
407   return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
408 }
409 
410 /// Test for suitable positions/coordinates width.
411 static bool isAdmissibleMetaData(SparseTensorType &aTp) {
412   return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
413          (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
414 }
415 
416 /// Test for sorted COO matrix with suitable metadata.
417 static bool isAdmissibleCOO(SparseTensorType &aTp) {
418   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
419          aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
420          aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
421          isAdmissibleMetaData(aTp);
422 }
423 
424 /// Test for CSR matrix with suitable metadata.
425 static bool isAdmissibleCSR(SparseTensorType &aTp) {
426   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
427          aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
428          aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
429 }
430 
431 /// Test for CSC matrix with suitable metadata.
432 static bool isAdmissibleCSC(SparseTensorType &aTp) {
433   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
434          aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
435          aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
436 }
437 
438 /// Test for BSR matrix with suitable metadata.
439 static bool isAdmissibleBSR(SparseTensorType &aTp) {
440   if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
441       aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
442       aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
443     // CuSparse only supports "square" blocks currently.
444     SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
445     assert(dims.size() == 2);
446     return dims[0] == dims[1] && dims[0] > 1;
447   }
448   return false;
449 }
450 
451 /// Returns a suitable sparse format for the operation and given operand
452 /// types with cuSparse, or kNone if none is available.
453 static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
454                                         SparseTensorType bTp,
455                                         SparseTensorType cTp, bool enableRT,
456                                         bool isMatVec) {
457   // The other operands have a dense type.
458   if (bTp.hasEncoding() || cTp.hasEncoding())
459     return CuSparseFormat::kNone;
460   // Now check for suitable operand type for the main operand.
461   if (isAdmissibleCOO(aTp))
462 #ifdef CUSPARSE_COO_AOS
463     return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
464 #else
465     return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
466 #endif
467   if (isAdmissibleCSR(aTp))
468     return CuSparseFormat::kCSR;
469   if (isAdmissibleCSC(aTp))
470     return CuSparseFormat::kCSC;
471   if (isAdmissibleBSR(aTp))
472     return CuSparseFormat::kBSR;
473   return CuSparseFormat::kNone;
474 }
475 
476 /// Generates the first positions/coordinates of a sparse matrix.
477 static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
478                                CuSparseFormat format, bool enableRT) {
479   if (format == CuSparseFormat::kCOO) {
480     // Library uses SoA COO, direct IR uses AoS COO.
481     if (enableRT)
482       return genToCoordinates(builder, loc, a, 0);
483     return genToCoordinatesBuffer(builder, loc, a);
484   }
485   // Formats CSR/CSC and BSR use positions at 1.
486   return genToPositions(builder, loc, a, 1);
487 }
488 
489 /// Generates the second coordinates of a sparse matrix.
490 static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
491                            CuSparseFormat format, bool enableRT) {
492   bool isCOO = format == CuSparseFormat::kCOO;
493   if (isCOO && !enableRT)
494     return Value(); // nothing needed
495   // Formats CSR/CSC and BSR use coordinates at 1.
496   return genToCoordinates(builder, loc, a, 1);
497 }
498 
499 /// Generates the sparse matrix handle.
500 static Operation *genSpMat(OpBuilder &builder, Location loc,
501                            SparseTensorType &aTp, Type handleTp, Type tokenTp,
502                            Value token, Value sz1, Value sz2, Value nseA,
503                            Value rowA, Value colA, Value valA,
504                            CuSparseFormat format, bool enableRT) {
505   if (format == CuSparseFormat::kCOO) {
506     // Library uses SoA COO, direct IR uses AoS COO.
507     if (enableRT) {
508       assert(colA);
509       return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
510                                               sz1, sz2, nseA, rowA, colA, valA);
511     }
512 #ifdef CUSPARSE_COO_AOS
513     assert(!colA);
514     return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token,
515                                                sz1, sz2, nseA, rowA, valA);
516 #else
517     llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
518 #endif
519   }
520   assert(colA);
521   if (format == CuSparseFormat::kCSR)
522     return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
523                                             sz2, nseA, rowA, colA, valA);
524   if (format == CuSparseFormat::kCSC)
525     return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
526                                             sz2, nseA, rowA, colA, valA);
527   // BSR requires a bit more work since we need to pass in the block size
528   // and all other sizes in terms of blocks (#block-rows, #block-cols,
529   // #nonzero-blocks).
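  // For example (plain arithmetic), with a 2x2 block size and sz1 = sz2 = 8,
  // nseA = 16, this yields bRows = bCols = 4 and bNum = 16 / (2 * 2) = 4.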
530   assert(format == CuSparseFormat::kBSR);
531   SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
532   assert(dims.size() == 2 && dims[0] == dims[1]);
533   uint64_t b = dims[0];
534   Value bSz = constantIndex(builder, loc, b);
535   Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
536   Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
537   Value bNum = builder.create<arith::DivUIOp>(
538       loc, nseA, constantIndex(builder, loc, b * b));
539   return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
540                                           bCols, bNum, bSz, bSz, rowA, colA,
541                                           valA);
542 }
543 
544 /// Match and rewrite SpMV kernel.
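/// At a high level, the generated code follows this op sequence (a summary of
/// the code below, not a literal IR dump):
///   copy positions/coordinates/values of A and the vectors x, y to device
///   gpu.create_coo / gpu.create_csr / ...  ->  gpu.create_dn_tensor (x, y)
///   gpu.spmv_buffer_size  ->  gpu.alloc  ->  gpu.spmv
///   destroy handles, copy y back to host, deallocate device buffers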
545 static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
546                                  linalg::GenericOp op, bool enableRT) {
547   Location loc = op.getLoc();
548   Value a = op.getOperand(0);
549   Value x = op.getOperand(1);
550   Value y = op.getOperand(2); // we have y = Ax
551   SmallVector<Value> tokens;
552 
553   // Only admissible sparse matrix format and dense vectors (no BSR).
554   SparseTensorType aTp = getSparseTensorType(a);
555   SparseTensorType xTp = getSparseTensorType(x);
556   SparseTensorType yTp = getSparseTensorType(y);
557   auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
558   if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
559     return failure();
560 
561   // Start sparse kernel and copy data from host to device.
562   //   a : memR/memC/memV -> rowA,colA,valA
563   //   x : memX           -> vecX
564   //   y : memY           -> vecY
565   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
566   Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
567   Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
568   Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
569   Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
570   Value memV = genToValues(rewriter, loc, a);
571   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
572   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
573   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
574   Value memX = genTensorToMemref(rewriter, loc, x);
575   Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
576   Value memY = genTensorToMemref(rewriter, loc, y);
577   Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
578   genBlockingWait(rewriter, loc, tokens);
579   tokens.clear();
580 
581   // Create sparse environment and sparse matrix/dense vector handles.
582   Type indexTp = rewriter.getIndexType();
583   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
584   Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
585   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
586   Value token = genFirstWait(rewriter, loc);
587   Operation *spGenA =
588       genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
589                nseA, rowA, colA, valA, format, enableRT);
590   Value spMatA = spGenA->getResult(0);
591   token = spGenA->getResult(1);
592   auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
593       loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
594   Value dnX = dvecX.getResult(0);
595   token = dvecX.getAsyncToken();
596   auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
597       loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
598   Value dnY = dvecY.getResult(0);
599   token = dvecY.getAsyncToken();
600   auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
601 
602   // Precompute buffersize for SpMV.
603   auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
604       loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
605       /*computeType=*/dnYType);
606   Value bufferSz = bufferComp.getResult(0);
607   token = bufferComp.getAsyncToken();
608   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
609   Value buffer = buf.getResult(0);
610   token = buf.getAsyncToken();
611 
612   // Perform the SpMV.
613   auto spmvComp = rewriter.create<gpu::SpMVOp>(
614       loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer);
615   token = spmvComp.getAsyncToken();
616 
617   // Copy data back to host and free all the resources.
618   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
619               .getAsyncToken();
620   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
621               .getAsyncToken();
622   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
623               .getAsyncToken();
624   token = genDeallocMemRef(rewriter, loc, rowA, token);
625   if (colA)
626     token = genDeallocMemRef(rewriter, loc, colA, token);
627   token = genDeallocMemRef(rewriter, loc, valA, token);
628   token = genDeallocMemRef(rewriter, loc, buffer, token);
629   token = genDeallocMemRef(rewriter, loc, vecX, token);
630   token = genCopyMemRef(rewriter, loc, memY, vecY, token);
631   token = genDeallocMemRef(rewriter, loc, vecY, token);
632   tokens.push_back(token);
633   genBlockingWait(rewriter, loc, tokens);
634   tokens.clear();
635 
636   // Done.
637   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
638   return success();
639 }
640 
641 /// Match and rewrite SpMM kernel.
642 static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
643                                  linalg::GenericOp op, bool enableRT) {
644   Location loc = op.getLoc();
645   Value a = op.getOperand(0);
646   Value b = op.getOperand(1);
647   Value c = op.getOperand(2); // we have C = AB
648   SmallVector<Value> tokens;
649 
650   // Only admissible sparse matrix format and dense matrices (no BSR).
651   SparseTensorType aTp = getSparseTensorType(a);
652   SparseTensorType bTp = getSparseTensorType(b);
653   SparseTensorType cTp = getSparseTensorType(c);
654   auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
655   if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
656     return failure();
657 
658   // Start sparse kernel and copy data from host to device.
659   //   a : memR/memC/memV -> rowA,colA,valA
660   //   b : bufB           -> matB
661   //   c : bufC           -> matC
662   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
663   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
664   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
665   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
666   Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
667   Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
668   Value memV = genToValues(rewriter, loc, a);
669   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
670   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
671   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
672   Value bufB = genTensorToMemref(rewriter, loc, b);
673   Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
674   Value bufC = genTensorToMemref(rewriter, loc, c);
675   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
676   genBlockingWait(rewriter, loc, tokens);
677   tokens.clear();
678 
679   // Create sparse environment and sparse matrix/dense matrix handles.
680   Type indexTp = rewriter.getIndexType();
681   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
682   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
683   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
684   Value token = genFirstWait(rewriter, loc);
685   Operation *spGenA =
686       genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
687                nseA, rowA, colA, valA, format, enableRT);
688   Value spMatA = spGenA->getResult(0);
689   token = spGenA->getResult(1);
690   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
691       loc, dnTensorHandleTp, tokenTp, token, matB,
692       SmallVector<Value>{szk, szn});
693   Value dnB = dmatB.getResult(0);
694   token = dmatB.getAsyncToken();
695   auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
696       loc, dnTensorHandleTp, tokenTp, token, matC,
697       SmallVector<Value>{szm, szn});
698   Value dnC = dmatC.getResult(0);
699   token = dmatC.getAsyncToken();
700   auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();
701 
702   // Precompute buffersize for SpMM.
703   auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
704       loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
705       /*computeType=*/dmatCType);
706   Value bufferSz = bufferComp.getResult(0);
707   token = bufferComp.getAsyncToken();
708   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
709   Value buffer = buf.getResult(0);
710   token = buf.getAsyncToken();
711   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
712 
713   // Perform the SpMM.
714   auto spmmComp = rewriter.create<gpu::SpMMOp>(
715       loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer);
716   token = spmmComp.getAsyncToken();
717 
718   // Copy data back to host and free all the resources.
719   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
720               .getAsyncToken();
721   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
722               .getAsyncToken();
723   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
724               .getAsyncToken();
725   token = genDeallocMemRef(rewriter, loc, rowA, token);
726   if (colA)
727     token = genDeallocMemRef(rewriter, loc, colA, token);
728   token = genDeallocMemRef(rewriter, loc, valA, token);
729   token = genDeallocMemRef(rewriter, loc, buffer, token);
730   token = genDeallocMemRef(rewriter, loc, matB, token);
731   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
732   token = genDeallocMemRef(rewriter, loc, matC, token);
733   tokens.push_back(token);
734   genBlockingWait(rewriter, loc, tokens);
735   tokens.clear();
736 
737   // Done.
738   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
739   return success();
740 }
741 
742 // Match and rewrite SpGEMM kernel.
743 static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
744                                    linalg::GenericOp op, bool enableRT) {
745   Location loc = op.getLoc();
746   Value a = op.getOperand(0);
747   Value b = op.getOperand(1);
748   Value c = op.getOperand(2); // we have C = AB
749   SmallVector<Value> tokens;
750 
751   // Only CSR <- CSR x CSR supported.
752   auto format = CuSparseFormat::kCSR;
753   SparseTensorType aTp = getSparseTensorType(a);
754   SparseTensorType bTp = getSparseTensorType(b);
755   SparseTensorType cTp = getSparseTensorType(c);
756   if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
757     return failure();
758 
759   // Start sparse kernel and copy data from host to device.
760   //   a : amemR/amemC/amemV -> rowA,colA,valA
761   //   b : bmemR/bmemC/bmemV -> rowB,colB,valB
762   //   c : materializes
763   auto dnCType = cTp.getElementType();
764   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
765   Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
766   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
767   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
768   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
769   Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
770   Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
771   Value amemV = genToValues(rewriter, loc, a);
772   Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
773   Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
774   Value bmemV = genToValues(rewriter, loc, b);
775   Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
776   Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
777   Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
778   Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
779   Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
780   Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
781   genBlockingWait(rewriter, loc, tokens);
782   tokens.clear();
783 
784   // Create sparse environment and sparse matrix handles.
785   Type indexTp = rewriter.getIndexType();
786   Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
787   Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
788   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
789   Value token = genFirstWait(rewriter, loc);
790   Operation *spGenA =
791       genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
792                nseA, rowA, colA, valA, format, enableRT);
793   Value spMatA = spGenA->getResult(0);
794   token = spGenA->getResult(1);
795   Operation *spGenB =
796       genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
797                nseB, rowB, colB, valB, format, enableRT);
798   Value spMatB = spGenB->getResult(0);
799   token = spGenB->getResult(1);
800 
801   // Sparse matrix C materializes (also assumes beta == 0).
802   Value zero = constantIndex(rewriter, loc, 0);
803   Value one = constantIndex(rewriter, loc, 1);
804   Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
805   auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
806   Value rowC = e1.getResult(0);
807   token = e1.getAsyncToken();
808   auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
809   Value colC = e2.getResult(0); // no free needed
810   token = e2.getAsyncToken();
811   auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
812   Value valC = e3.getResult(0); // no free needed
813   token = e3.getAsyncToken();
814   Operation *spGenC =
815       genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
816                zero, rowC, colC, valC, format, enableRT);
817   Value spMatC = spGenC->getResult(0);
818   token = spGenC->getResult(1);
819 
820   // Precompute buffersizes for SpGEMM.
821   Operation *descOp =
822       rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
823   Value desc = descOp->getResult(0);
824   token = descOp->getResult(1);
825   Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
826       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
827       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
828       valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
829   Value bufferSz1 = work1->getResult(0);
830   token = work1->getResult(1);
831   auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
832   Value buffer1 = buf1.getResult(0);
833   token = buf1.getAsyncToken();
834   Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
835       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
836       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
837       bufferSz1, buffer1,
838       gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
839   token = work2->getResult(1);
840 
841   // Compute step.
842   Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
843       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
844       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
845       valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
846   Value bufferSz2 = compute1->getResult(0);
847   token = compute1->getResult(1);
848   auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
849   Value buffer2 = buf2.getResult(0);
850   token = buf2.getAsyncToken();
851   Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
852       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
853       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
854       bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
855   token = compute2->getResult(1);
856 
857   // Get sizes.
858   Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
859       loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
860   Value nnz = sizes->getResult(2);
861   token = sizes->getResult(3);
862   auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
863   colC = a2.getResult(0);
864   token = a2.getAsyncToken();
865   auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
866   valC = a3.getResult(0);
867   token = a3.getAsyncToken();
868 
869   // Update C with new pointers and copy final product back into C.
870   Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
871       loc, tokenTp, token, spMatC, rowC, colC, valC);
872   token = update->getResult(0);
873   Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
874       loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
875       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
876   token = copy->getResult(0);
877 
878   // Allocate buffers on host.
879   Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
880   Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
881   Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
882 
883   // Copy data back to host and free all the resources.
884   token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
885               .getAsyncToken();
886   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
887               .getAsyncToken();
888   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
889               .getAsyncToken();
890   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
891               .getAsyncToken();
892   token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
893   token = genCopyMemRef(rewriter, loc, colH, colC, token);
894   token = genCopyMemRef(rewriter, loc, valH, valC, token);
895   token = genDeallocMemRef(rewriter, loc, rowA, token);
896   token = genDeallocMemRef(rewriter, loc, colA, token);
897   token = genDeallocMemRef(rewriter, loc, valA, token);
898   token = genDeallocMemRef(rewriter, loc, rowB, token);
899   token = genDeallocMemRef(rewriter, loc, colB, token);
900   token = genDeallocMemRef(rewriter, loc, valB, token);
901   token = genDeallocMemRef(rewriter, loc, rowC, token);
902   token = genDeallocMemRef(rewriter, loc, colC, token);
903   token = genDeallocMemRef(rewriter, loc, valC, token);
904   token = genDeallocMemRef(rewriter, loc, buffer1, token);
905   token = genDeallocMemRef(rewriter, loc, buffer2, token);
906   tokens.push_back(token);
907   genBlockingWait(rewriter, loc, tokens);
908   tokens.clear();
909 
910   // Done.
911   Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
912   Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
913   Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
914   rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
915                                           ValueRange{rt, ct});
916   return success();
917 }
918 
919 /// Match and rewrite 2:4 SpMM kernel.
920 static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
921                                      linalg::GenericOp op) {
922   Location loc = op.getLoc();
923   Value A = op.getOperand(0);
924   Value B = op.getOperand(1);
925   Value C = op.getOperand(2); // we have C = AB
926   SmallVector<Value> tokens;
927 
928   // All inputs should be dense tensors.
929   if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
930     return failure();
931 
932   // Start sparse kernel and copy data from host to device.
933   //   a : bufA -> matA
934   //   b : bufB -> matB
935   //   c : bufC -> matC
936   Value bufA = genTensorToMemref(rewriter, loc, A);
937   Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
938   Value bufB = genTensorToMemref(rewriter, loc, B);
939   Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
940   Value bufC = genTensorToMemref(rewriter, loc, C);
941   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
942   genBlockingWait(rewriter, loc, tokens);
943   tokens.clear();
944 
945   // Create sparse environment and sparse matrix/dense matrix handles.
946   Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
947   Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
948   Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
949   Type indexTp = rewriter.getIndexType();
950   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
951   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
952   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
953   Value token = genFirstWait(rewriter, loc);
954   Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
955       loc, spMatHandleTp, tokenTp, token, szm, szk,
956       gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
957   Value spMatA = spGenA->getResult(0);
958   token = spGenA->getResult(1);
959   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
960       loc, dnTensorHandleTp, tokenTp, token, matB,
961       SmallVector<Value>{szk, szn});
962   Value dnB = dmatB.getResult(0);
963   token = dmatB.getAsyncToken();
964   auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
965       loc, dnTensorHandleTp, tokenTp, token, matC,
966       SmallVector<Value>{szm, szn});
967   Value dnC = dmatC.getResult(0);
968   token = dmatC.getAsyncToken();
969   auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
970 
971   // Precompute buffersize for SpMM.
972   SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
973   TypeRange bufferTypes(bufferTypes_);
974   auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
975       loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
976       gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
977       /*computeType=*/dmatCType);
978   token = bufferComp.getAsyncToken();
979 
980   // Allocate buffers on host.
981   Value bufferSz1 = bufferComp.getResult(0);
982   auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
983   Value buffer1 = buf1.getResult(0);
984   token = buf1.getAsyncToken();
985   Value bufferSz2 = bufferComp.getResult(1);
986   auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
987   Value buffer2 = buf2.getResult(0);
988   token = buf2.getAsyncToken();
989   Value bufferSz3 = bufferComp.getResult(2);
990   auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
991   Value buffer3 = buf3.getResult(0);
992   token = buf3.getAsyncToken();
993 
994   // Perform the SpMM.
995   auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
996   auto spmmComp = rewriter.create<gpu::SpMMOp>(
997       loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
998       SmallVector<Value>{buffer1, buffer2, buffer3});
999   token = spmmComp.getAsyncToken();
1000 
1001   // Copy data back to host and free all the resources.
1002   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
1003               .getAsyncToken();
1004   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
1005               .getAsyncToken();
1006   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
1007               .getAsyncToken();
1008   SmallVector<Value> newDynamicSizes;
1009   token = genDeallocMemRef(rewriter, loc, buffer1, token);
1010   token = genDeallocMemRef(rewriter, loc, buffer2, token);
1011   token = genDeallocMemRef(rewriter, loc, buffer3, token);
1012   token = genDeallocMemRef(rewriter, loc, matA, token);
1013   token = genDeallocMemRef(rewriter, loc, matB, token);
1014   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
1015   token = genDeallocMemRef(rewriter, loc, matC, token);
1016   tokens.push_back(token);
1017   genBlockingWait(rewriter, loc, tokens);
1018   tokens.clear();
1019 
1020   // Done.
1021   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
1022   return success();
1023 }
1024 
1025 /// Match and rewrite SDDMM kernel.
1026 static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
1027                                   linalg::GenericOp op, bool enableRT) {
1028   Location loc = op.getLoc();
1029   Value a = op.getOperand(0);
1030   Value b = op.getOperand(1);
1031   Value c = op.getOperand(2);
1032   SmallVector<Value> tokens;
1033 
1034   // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
1035   SparseTensorType aTp = getSparseTensorType(a);
1036   SparseTensorType bTp = getSparseTensorType(b);
1037   SparseTensorType cTp = getSparseTensorType(c);
1038   auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
1039   if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
1040       format == CuSparseFormat::kCSC)
1041     return failure();
1042 
1043   // SDDMM updates the sparse output matrix in place.
1044   // Start sparse kernel and copy data from host to device.
1045   //   a : bufA           -> matA
1046   //   b : bufB           -> matB
1047   //   c : memR/memC/memV -> rowC,colC,valC
1048   Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
1049   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
1050   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
1051   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
1052   Value bufA = genTensorToMemref(rewriter, loc, a);
1053   Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
1054   Value bufB = genTensorToMemref(rewriter, loc, b);
1055   Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
1056   Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
1057   Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
1058   Value memV = genToValues(rewriter, loc, c);
1059   Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
1060   Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
1061   Value valC = genAllocCopy(rewriter, loc, memV, tokens);
1062   genBlockingWait(rewriter, loc, tokens);
1063   tokens.clear();
1064 
1065   // Create sparse environment and sparse matrix/dense matrix handles.
1066   Type indexTp = rewriter.getIndexType();
1067   Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
1068   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
1069   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
1070   Value token = genFirstWait(rewriter, loc);
1071   auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
1072       loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
1073   Value dnA = dmatA.getResult(0);
1074   token = dmatA.getAsyncToken();
1075   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
1076       loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
1077   Value dnB = dmatB.getResult(0);
1078   token = dmatB.getAsyncToken();
1079   Operation *spGenC =
1080       genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
1081                nseC, rowC, colC, valC, format, enableRT);
1082   Value spMatC = spGenC->getResult(0);
1083   token = spGenC->getResult(1);
1084   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
1085 
1086   // Precompute buffersize for SDDMM.
1087   auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
1088       loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
1089   Value bufferSz = bufferComp.getResult(0);
1090   token = bufferComp.getAsyncToken();
1091   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
1092   Value buffer = buf.getResult(0);
1093   token = buf.getAsyncToken();
1094 
1095   // Perform the SDDMM.
1096   auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
1097                                                  spMatC, dnCType, buffer);
1098   token = sddmmComp.getAsyncToken();
1099 
1100   // Copy data back to host and free all the resources.
1101   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
1102               .getAsyncToken();
1103   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
1104               .getAsyncToken();
1105   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
1106               .getAsyncToken();
1107   token = genDeallocMemRef(rewriter, loc, buffer, token);
1108   token = genDeallocMemRef(rewriter, loc, matA, token);
1109   token = genDeallocMemRef(rewriter, loc, matB, token);
1110   token = genDeallocMemRef(rewriter, loc, rowC, token);
1111   if (colC)
1112     token = genDeallocMemRef(rewriter, loc, colC, token);
1113   token = genCopyMemRef(rewriter, loc, memV, valC, token);
1114   token = genDeallocMemRef(rewriter, loc, valC, token);
1115   tokens.push_back(token);
1116   genBlockingWait(rewriter, loc, tokens);
1117   tokens.clear();
1118 
1119   // Done.
1120   rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
1121   return success();
1122 }
1123 
1124 //===----------------------------------------------------------------------===//
1125 // Rewriting rules for direct code generation.
1126 //===----------------------------------------------------------------------===//
1127 
1128 /// Proof-of-concept rewriter. This rule generates a GPU implementation
1129 /// for each outermost forall loop generated by the sparsifier.
1130 /// TODO: right now this works with parallelization-strategy=dense-outer-loop,
1131 ///       but it should get its own flag in the future.
1132 struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
1133   using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
1134 
1135   ForallRewriter(MLIRContext *context, unsigned nT)
1136       : OpRewritePattern(context), numThreads(nT) {}
1137 
1138   LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
1139                                 PatternRewriter &rewriter) const override {
1140     // Reject inadmissible loop form.
1141     // Essentially only accept a loop, generated by the sparsifier,
1142     // of the form
1143     //   forall (i = 0; i < N; i++)
1144     // so that cyclic scheduling over the threads is easy.
1145     if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
1146         forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
1147         !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
1148         !matchPattern(forallOp.getStep()[0], m_One()))
1149       return failure();
1150     // Collect every value that is computed outside the parallel loop.
1151     SetVector<Value> invariants; // stable iteration!
1152     forallOp->walk([&](Operation *op) {
1153       // Collect all values of admissible ops.
1154       for (OpOperand &o : op->getOpOperands()) {
1155         Value val = o.get();
1156         Block *block;
1157         if (auto arg = dyn_cast<BlockArgument>(val))
1158           block = arg.getOwner();
1159         else
1160           block = val.getDefiningOp()->getBlock();
1161         if (!forallOp.getRegion().findAncestorBlockInRegion(*block))
1162           invariants.insert(val);
1163       }
1164     });
1165     // Outline the outside values as proper parameters. Fail when sharing
1166     // a value between host and device is not straightforward.
1167     SmallVector<Value> constants;
1168     SmallVector<Value> scalars;
1169     SmallVector<Value> buffers;
1170     for (Value val : invariants) {
1171       Type tp = val.getType();
1172       if (val.getDefiningOp<arith::ConstantOp>())
1173         constants.push_back(val);
1174       else if (isa<FloatType>(tp) || tp.isIntOrIndex())
1175         scalars.push_back(val);
1176       else if (isa<MemRefType>(tp))
1177         buffers.push_back(val);
1178       else
1179         return failure(); // don't know how to share
1180     }
1181     // Pass outlined non-constant values.
1182     // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
1183     //       keep the feature at all (either through a heuristic or compiler
1184     //       option for gpu codegen).
1185     Location loc = forallOp->getLoc();
1186     SmallVector<Value> args;
1187     SmallVector<Value> tokens;
1188     Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
1189                                 /*useHostRegistrationForOut=*/false);
1190     // Set up GPU module and construct GPU function.
1191     auto saveIp = rewriter.saveInsertionPoint();
1192     ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
1193     auto gpuModule = genGPUModule(rewriter, topModule);
1194     auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
1195     genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
1196     // Generate code that launches the kernel asynchronously, blocking on all
1197     // open tokens and yielding a new token for the output.
1198     // TODO: Passing tokens into the launch does not seem to be properly lowered
1199     //       by the cubin path yet, hence the current blocking wait.
1200     rewriter.restoreInsertionPoint(saveIp);
1201     genBlockingWait(rewriter, loc, tokens);
1202     tokens.clear();
1203     Value kernelToken =
1204         genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
1205     // Finalize the outlined arguments.
1206     genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
1207                      tokens);
1208     genBlockingWait(rewriter, loc, tokens);
1209     rewriter.eraseOp(forallOp);
1210     return success();
1211   }
1212 
1213 private:
1214   unsigned numThreads;
1215 };
1216 
1217 //===----------------------------------------------------------------------===//
1218 // Rewriting rules for library recognition and code generation.
1219 //===----------------------------------------------------------------------===//
1220 
1221 /// Proof-of-concept rewriter. This rule recognizes certain math kernels
1222 /// and replaces these with corresponding calls into a sparse library.
1223 struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
1224   using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
1225 
1226   LinalgOpRewriter(MLIRContext *context, bool rt)
1227       : OpRewritePattern(context), enableRT(rt) {}
1228 
1229   LogicalResult matchAndRewrite(linalg::GenericOp op,
1230                                 PatternRewriter &rewriter) const override {
1231     if (op.getNumDpsInits() != 1)
1232       return failure(); // reject multi-output
1233 
1234     const unsigned numLoops = op.getNumLoops();
1235     const unsigned numTensors = op->getNumOperands();
1236     const auto iteratorTypes = op.getIteratorTypesArray();
1237     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1238 
1239     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1240     auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1241     AffineExpr i, j, k;
1242     bindDims(getContext(), i, j, k);
1243 
1244     // TODO: more robust patterns, transposed versions, more kernels,
1245     //       identify alpha and beta and pass them to the CUDA calls.
1246 
1247     // Recognize a SpMV kernel.
1248     if (numLoops == 2 && numTensors == 3 &&
1249         linalg::isParallelIterator(iteratorTypes[0]) &&
1250         linalg::isReductionIterator(iteratorTypes[1]) &&
1251         maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
1252       return rewriteSpMV(rewriter, op, enableRT);
1253     }
1254 
1255     // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
1256     if (numLoops == 3 && numTensors == 3 &&
1257         linalg::isParallelIterator(iteratorTypes[0]) &&
1258         linalg::isParallelIterator(iteratorTypes[1]) &&
1259         linalg::isReductionIterator(iteratorTypes[2]) &&
1260         maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
1261       if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
1262         return rewriteSpGEMM(rewriter, op, enableRT);
1263       if (op->getAttr("DENSE24"))
1264         return rewrite2To4SpMM(rewriter, op);
1265       return rewriteSpMM(rewriter, op, enableRT);
1266     }
1267 
1268     // Recognize a SDDMM kernel.
1269     if (numLoops == 3 && numTensors == 3 &&
1270         linalg::isParallelIterator(iteratorTypes[0]) &&
1271         linalg::isParallelIterator(iteratorTypes[1]) &&
1272         linalg::isReductionIterator(iteratorTypes[2]) &&
1273         maps == infer({{i, k}, {k, j}, {i, j}}) &&
1274         matchSumReductionOfMulUnary(op)) {
1275       return rewriteSDDMM(rewriter, op, enableRT);
1276     }
1277 
1278     return failure();
1279   }
1280 
1281 private:
1282   bool enableRT;
1283 };
1284 
1285 } // namespace
1286 
1287 //===----------------------------------------------------------------------===//
1288 // Public method for populating GPU rewriting rules.
1289 //
1290 // Currently two sets of rewriting rules are made available. The first set
1291 // implements direct code generation, currently by means of converting the
1292 // outermost parallel loop into GPU threads. The second set implements
1293 // library recognition of a set of sparse operations. Eventually, the right
1294 // combination of these two approaches has to be found.
1295 //===----------------------------------------------------------------------===//
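// For example (a sketch only; actual pass wiring not shown), a pass could
// populate both sets before applying them greedily:
//   RewritePatternSet patterns(ctx);
//   populateSparseGPULibgenPatterns(patterns, /*enableRT=*/true);
//   populateSparseGPUCodegenPatterns(patterns, /*numThreads=*/1024);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));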
1296 
1297 void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
1298                                             unsigned numThreads) {
1299   patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
1300 }
1301 
1302 void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
1303                                            bool enableRT) {
1304   patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
1305 }
1306