xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp (revision 6ca47eb49ded6281e887fbdb26323deea45df44e)
1 //===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a prototype GPU code generator for the sparse compiler.
10 // The objective is to eventually use the right combination of
11 // direct code generation and library calls into vendor-specific
12 // highly optimized sparse libraries (e.g., cuSPARSE for CUDA).
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "CodegenUtils.h"
17 #include "LoopEmitter.h"
18 
19 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/MemRef/IR/MemRef.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
26 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
27 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
28 #include "mlir/IR/IRMapping.h"
29 #include "mlir/IR/Matchers.h"
30 
31 using namespace mlir;
32 using namespace mlir::sparse_tensor;
33 
34 namespace {
35 
36 //===----------------------------------------------------------------------===//
37 // Helper methods.
38 //===----------------------------------------------------------------------===//
39 
40 /// Marks the given top module as a GPU container module.
41 static void markAsGPUContainer(ModuleOp topModule) {
42   topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
43                      UnitAttr::get(topModule->getContext()));
44 }
45 
46 /// Constructs a new GPU module (for GPU kernels) inside the given top module,
47 /// or returns an existing GPU module if one was built previously.
48 static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
49   for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
50     return op; // existing
51   markAsGPUContainer(topModule);
52   builder.setInsertionPointToStart(&topModule.getBodyRegion().front());
53   return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
54                                           "sparse_kernels");
55 }
56 
57 /// Constructs a new GPU kernel in the given GPU module.
58 static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
59                                  SmallVectorImpl<Value> &args) {
60   // Get a unique kernel name. Not very creative,
61   // but we simply try kernel0, kernel1, etc.
62   unsigned kernelNumber = 0;
63   SmallString<16> kernelName;
64   do {
65     kernelName.clear();
66     ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
67   } while (gpuModule.lookupSymbol(kernelName));
68   // Then we insert a new kernel with given arguments into the module.
69   builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
70   SmallVector<Type> argsTp;
71   for (unsigned i = 0, e = args.size(); i < e; i++)
72     argsTp.push_back(args[i].getType());
73   FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
74   auto gpuFunc =
75       builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
76   gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
77                    builder.getUnitAttr());
78   return gpuFunc;
79 }
80 
81 /// Constructs code to launch GPU kernel.
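/// For illustration only, a typical use looks like the following sketch (it
/// mirrors the ForallRewriter below; the concrete numThreads value is a
/// hypothetical choice for the example):
///
///   SmallVector<Value> args, tokens;
///   Value kernelToken =
///       genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, /*numThreads=*/1024);
///   tokens.push_back(kernelToken);
///   genBlockingWait(rewriter, loc, tokens);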
82 static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
83                               SmallVectorImpl<Value> &args,
84                               SmallVectorImpl<Value> &tokens,
85                               unsigned numThreads) {
86   Location loc = gpuFunc->getLoc();
87   Value none = TypedValue<::mlir::IntegerType>{};
88   Value one = constantIndex(builder, loc, 1);
89   Value numT = constantIndex(builder, loc, numThreads);
90   gpu::KernelDim3 gridSize = {one, one, one};
91   gpu::KernelDim3 blckSize = {numT, one, one};
92   return builder
93       .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
94                                  /*dynSharedMemSz*/ none, args,
95                                  builder.getType<gpu::AsyncTokenType>(), tokens)
96       .getAsyncToken();
97 }
98 
99 /// Maps the provided ranked host buffer into the device address space.
100 /// Writes from the host are guaranteed to be visible to device kernels
101 /// that are launched afterwards. Writes from the device are guaranteed
102 /// to be visible on the host after synchronizing with the device kernel
103 /// completion. Needs to cast the buffer to an unranked buffer.
104 static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
105                                    Value mem) {
106   MemRefType memTp = cast<MemRefType>(mem.getType());
107   UnrankedMemRefType resTp =
108       UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
109   Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
110   builder.create<gpu::HostRegisterOp>(loc, cast);
111   return cast;
112 }
113 
114 /// Unmaps the provided buffer, expecting the previously cast (unranked) buffer.
115 static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
116                                     Value cast) {
117   builder.create<gpu::HostUnregisterOp>(loc, cast);
118 }
119 
120 /// Generates first wait in an asynchronous chain.
121 static Value genFirstWait(OpBuilder &builder, Location loc) {
122   Type tokenType = builder.getType<gpu::AsyncTokenType>();
123   return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
124       .getAsyncToken();
125 }
126 
127 /// Generates last, blocking wait in an asynchronous chain.
128 static void genBlockingWait(OpBuilder &builder, Location loc,
129                             ValueRange operands) {
130   builder.create<gpu::WaitOp>(loc, Type(), operands);
131 }
132 
133 /// Allocates memory on the device.
134 /// TODO: A `host_shared` attribute could be used to indicate that
135 ///       the buffer is visible by both host and device, but lowering
136 ///       that feature does not seem to be fully supported yet.
137 static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
138                                    Value token) {
139   auto tp = cast<ShapedType>(mem.getType());
140   auto elemTp = tp.getElementType();
141   auto shape = tp.getShape();
142   auto memTp = MemRefType::get(shape, elemTp);
143   SmallVector<Value> dynamicSizes;
144   for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
145     if (shape[r] == ShapedType::kDynamic) {
146       Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
147       dynamicSizes.push_back(dimOp);
148     }
149   }
150   return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
151                                       token, dynamicSizes, ValueRange());
152 }
153 
154 // Allocates a typed buffer on the host with given size.
155 static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
156                            Value size) {
157   const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
158   return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
159 }
160 
161 // Allocates a typed buffer on the device with given size.
162 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
163                                    Value size, Value token) {
164   const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
165   return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
166                                       token, size, ValueRange());
167 }
168 
169 // Allocates a void buffer on the device with given size.
170 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
171                                    Value token) {
172   return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
173 }
174 
175 /// Deallocates memory from the device.
176 static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
177                               Value token) {
178   return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
179       .getAsyncToken();
180 }
181 
182 /// Copies memory between host and device (direction is implicit).
183 static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
184                            Value src, Value token) {
185   return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
186       .getAsyncToken();
187 }
188 
189 /// Generates an alloc/copy pair.
190 static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
191                           SmallVectorImpl<Value> &tokens) {
192   Value firstToken = genFirstWait(builder, loc);
193   auto alloc = genAllocMemRef(builder, loc, b, firstToken);
194   Value devMem = alloc.getResult(0);
195   Value depToken = alloc.getAsyncToken(); // copy-after-alloc
196   tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
197   return devMem;
198 }
199 
200 /// Generates a memref from tensor operation.
201 static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
202                                Value tensor) {
203   auto tensorType = llvm::cast<ShapedType>(tensor.getType());
204   auto memrefType =
205       MemRefType::get(tensorType.getShape(), tensorType.getElementType());
206   return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
207 }
208 
209 /// Prepares the outlined arguments, passing scalars and buffers in. Here we
210 /// assume that the first buffer is the one allocated for output. We create
211 /// a set of properly chained asynchronous allocation/copy pairs to increase
212 /// overlap before launching the kernel.
213 static Value genParametersIn(OpBuilder &builder, Location loc,
214                              SmallVectorImpl<Value> &scalars,
215                              SmallVectorImpl<Value> &buffers,
216                              SmallVectorImpl<Value> &args,
217                              SmallVectorImpl<Value> &tokens,
218                              bool useHostRegistrationForOut) {
219   Value out;
220   // Scalars are passed by value.
221   for (Value s : scalars)
222     args.push_back(s);
223   // Buffers need to be made visible on the device.
224   for (Value b : buffers) {
225     if (useHostRegistrationForOut) {
226       out = genHostRegisterMemref(builder, loc, b);
227       args.push_back(b);
228       useHostRegistrationForOut = false;
229       continue;
230     }
231     args.push_back(genAllocCopy(builder, loc, b, tokens));
232   }
233   return out;
234 }
235 
236 /// Finalizes the outlined arguments. The output buffer is copied back (the
237 /// copy is chained on the kernel token) and then deallocated. All other
238 /// buffers are simply deallocated. Then we wait for all operations to complete.
239 static void genParametersOut(OpBuilder &builder, Location loc, Value out,
240                              Value kernelToken, SmallVectorImpl<Value> &scalars,
241                              SmallVectorImpl<Value> &buffers,
242                              SmallVectorImpl<Value> &args,
243                              SmallVectorImpl<Value> &tokens) {
244   unsigned base = scalars.size();
245   for (unsigned i = base, e = args.size(); i < e; i++) {
246     Value firstToken;
247     if (i == base) {
248       // Assumed output parameter: unregister or copy-out.
249       if (out) {
250         genHostUnregisterMemref(builder, loc, out);
251         out = Value();
252         continue;
253       }
254       firstToken =
255           genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
256     } else {
257       firstToken = genFirstWait(builder, loc);
258     }
259     tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
260   }
261 }
262 
263 /// Constructs code for new GPU kernel.
264 static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
265                        scf::ParallelOp forallOp,
266                        SmallVectorImpl<Value> &constants,
267                        SmallVectorImpl<Value> &scalars,
268                        SmallVectorImpl<Value> &buffers) {
269   Location loc = gpuFunc->getLoc();
270   Block &block = gpuFunc.getBody().front();
271   rewriter.setInsertionPointToStart(&block);
272 
273   // Re-generate the constants, recapture all arguments.
274   unsigned arg = 0;
275   IRMapping irMap;
276   for (Value c : constants)
277     irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
278   for (Value s : scalars)
279     irMap.map(s, block.getArgument(arg++));
280   for (Value b : buffers)
281     irMap.map(b, block.getArgument(arg++));
282 
283   // Assume 1-dimensional grid/block configuration (only x dimension),
284   // so that:
285   //   row = blockIdx.x * blockDim.x + threadIdx.x
286   //   inc = blockDim.x * gridDim.x
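  // For example, with gridDim.x = 2 and blockDim.x = 256, the thread with
  // blockIdx.x = 1 and threadIdx.x = 3 starts at row 1 * 256 + 3 = 259 and
  // strides by inc = 256 * 2 = 512.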
287   Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
288   Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
289   Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
290   Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
291   Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
292   Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
293   Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);
294 
295   // Construct the iteration over the computational space that
296   // accounts for the fact that the total number of threads and
297   // the amount of work to be done usually do not match precisely.
298   //   for (r = row; r < N; r += inc) {
299   //     <loop-body>
300   //   }
301   Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
302   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
303   rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
304                              forOp.getRegion().begin(), irMap);
305 
306   // Done.
307   rewriter.setInsertionPointAfter(forOp);
308   rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // Library helper methods.
313 //===----------------------------------------------------------------------===//
314 
315 /// Helper to detect a + b with arguments taken from given block.
316 static bool matchAddOfArgs(Block *block, Value val) {
317   if (auto *def = val.getDefiningOp()) {
318     if (isa<arith::AddFOp, arith::AddIOp>(def)) {
319       Value a = block->getArguments()[0];
320       Value b = block->getArguments()[1];
321       return (def->getOperand(0) == a && def->getOperand(1) == b) ||
322              (def->getOperand(0) == b && def->getOperand(1) == a);
323     }
324   }
325   return false;
326 }
327 
328 /// Helper to detect a * b with arguments taken from given block.
329 static bool matchMulOfArgs(Block *block, Value val) {
330   if (auto *def = val.getDefiningOp()) {
331     if (isa<arith::MulFOp, arith::MulIOp>(def)) {
332       Value a = block->getArguments()[0];
333       Value b = block->getArguments()[1];
334       return (def->getOperand(0) == a && def->getOperand(1) == b) ||
335              (def->getOperand(0) == b && def->getOperand(1) == a);
336     }
337   }
338   return false;
339 }
340 
341 /// Helper to detect x = x + a * b
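/// For illustration, the kind of linalg.generic body this accepts looks
/// schematically like the following sketch (block arguments %a, %b, %x; not
/// meant as exact IR syntax):
///
///   %0 = arith.mulf %a, %b : f64
///   %1 = arith.addf %x, %0 : f64
///   linalg.yield %1 : f64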
342 static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
343   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
344   if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
345     if (isa<arith::AddFOp, arith::AddIOp>(def)) {
346       Value x = op.getBlock()->getArguments()[2];
347       return (def->getOperand(0) == x &&
348               matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
349              (def->getOperand(1) == x &&
350               matchMulOfArgs(op.getBlock(), def->getOperand(0)));
351     }
352   }
353   return false;
354 }
355 
356 // Helper to detect c += spy(s) x (a * b)
357 static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
358   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
359   // The linalg op yields a custom reduce result.
360   Value s_out = op.getBlock()->getArguments()[2];
361   if (auto redOp =
362           yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
363     // The reduce consumes the output.
364     Value other;
365     if (s_out == redOp->getOperand(0))
366       other = redOp->getOperand(1);
367     else if (s_out == redOp->getOperand(1))
368       other = redOp->getOperand(0);
369     else
370       return false;
371     // The reduce op also consumes a unary op which, in turn, consumes the
372     // output and does not define an absent value.
373     if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
374       if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
375         return false;
376       // And the bodies are as expected.
377       auto yieldUn = cast<sparse_tensor::YieldOp>(
378           unOp.getRegion(0).front().getTerminator());
379       auto yieldRed = cast<sparse_tensor::YieldOp>(
380           redOp.getRegion().front().getTerminator());
381       return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
382              matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
383     }
384   }
385   return false;
386 }
387 
388 /// Determines if the given value is a dense tensor instead of a sparse one.
389 static bool isDenseTensor(Value v) {
390   return (sparse_tensor::getSparseTensorType(v).isAllDense());
391 }
392 
393 /// Test for sorted COO with suitable data and coordinate types.
394 static bool isAdmissibleCOO(SparseTensorType &aTp) {
395   return aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
396          aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
397          (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
398          (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
399           aTp.getCrdWidth() == 64);
400 }
401 
402 /// Test for CSR with suitable data and coordinate types.
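/// For illustration, an admissible CSR operand has a type along the lines of
/// the following sketch (the exact encoding syntax depends on the version of
/// the sparse_tensor dialect in use):
///
///   #CSR = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
///   tensor<?x?xf64, #CSR>   // f32 also admissible; 0/32/64-bit coordinates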
403 static bool isAdmissibleCSR(SparseTensorType &aTp) {
404   return aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
405          aTp.isUniqueLvl(1) &&
406          (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
407          (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
408           aTp.getCrdWidth() == 64);
409 }
410 
411 /// Test for admissible types on operands (with output parameter `isCOO`).
412 static bool areAdmissibleTypes(SparseTensorType aTp, SparseTensorType bTp,
413                                SparseTensorType cTp, bool enableRT,
414                                bool isMatVec, bool &isCOO) {
415   if (bTp.hasEncoding() || cTp.hasEncoding())
416     return false;
417   if (isAdmissibleCOO(aTp)) {
418     isCOO = true;
419 #ifdef CUSPARSE_COO_AOS
420     return isMatVec;
421 #else
422     return enableRT;
423 #endif
424   }
425   return isAdmissibleCSR(aTp);
426 }
427 
428 /// Generates the first positions/coordinates of a sparse matrix.
429 static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
430                                bool isCOO, bool enableRT) {
431   if (isCOO) {
432     // Library uses SoA COO, direct IR uses AoS COO.
433     if (enableRT)
434       return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
435     return genToCoordinatesBuffer(builder, loc, a);
436   }
437   // CSR uses positions.
438   return genToPositions(builder, loc, a, 1);
439 }
440 
441 /// Generates the second coordinates of a sparse matrix.
442 static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
443                            bool isCOO, bool enableRT) {
444   if (isCOO && !enableRT)
445     return Value(); // nothing needed
446   return genToCoordinates(builder, loc, a, 1, /*cooStart=*/isCOO ? 0 : 2);
447 }
448 
449 /// Generates the sparse matrix handle (COO or CSR) for the given operands.
450 static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
451                            Type tokenTp, Value token, Value sz1, Value sz2,
452                            Value nseA, Value rowA, Value colA, Value valA,
453                            bool isCOO, bool enableRT) {
454   if (isCOO) {
455     // Library uses SoA COO, direct IR uses AoS COO.
456     if (enableRT) {
457       assert(colA);
458       return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
459                                               sz1, sz2, nseA, rowA, colA, valA);
460     }
461 #ifdef CUSPARSE_COO_AOS
462     assert(!colA);
463     return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token,
464                                                sz1, sz2, nseA, rowA, valA);
465 #else
466     llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
467 #endif
468   }
469   assert(colA);
470   return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
471                                           sz2, nseA, rowA, colA, valA);
472 }
473 
474 /// Match and rewrite SpMV kernel.
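/// In index notation, the kernel rewritten here computes (schematically)
///   y(i) += A(i,j) * x(j)
/// with A a sparse matrix and x, y dense vectors.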
475 static LogicalResult
476 rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
477             GPUDataTransferStrategy gpuDataTransferStrategy) {
478   Location loc = op.getLoc();
479   Value a = op.getOperand(0);
480   Value x = op.getOperand(1);
481   Value y = op.getOperand(2); // we have y = Ax
482   SmallVector<Value> tokens;
483 
484   bool isZeroCopy =
485       gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
486 
487   // Only admissible sparse matrix format and dense vectors.
488   bool isCOO = false;
489   SparseTensorType aTp = getSparseTensorType(a);
490   SparseTensorType xTp = getSparseTensorType(x);
491   SparseTensorType yTp = getSparseTensorType(y);
492   if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, /*isMatVec=*/true, isCOO))
493     return failure();
494 
495   // Start sparse kernel and copy data from host to device.
496   //   a : memR/memC/memV -> rowA,colA,valA
497   //   x : memX           -> vecX
498   //   y : memY           -> vecY
499   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
500   Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
501   Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
502   Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
503   Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
504   Value memV = genToValues(rewriter, loc, a);
505   Value memX, memY;
506   Value castR, castC, castV, castX, castY;
507   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
508     memX = genTensorToMemref(rewriter, loc, x);
509     memY = genTensorToMemref(rewriter, loc, y);
510     castR = genHostRegisterMemref(rewriter, loc, memR);
511     if (memC)
512       castC = genHostRegisterMemref(rewriter, loc, memC);
513     castV = genHostRegisterMemref(rewriter, loc, memV);
514     castX = genHostRegisterMemref(rewriter, loc, memX);
515     castY = genHostRegisterMemref(rewriter, loc, memY);
516   }
517 
518   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
519   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
520   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
521   if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
522     memX = genTensorToMemref(rewriter, loc, x);
523   Value vecX = isZeroCopy ? memX : genAllocCopy(rewriter, loc, memX, tokens);
524   if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
525     memY = genTensorToMemref(rewriter, loc, y);
526   Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
527   genBlockingWait(rewriter, loc, tokens);
528   tokens.clear();
529 
530   // Create sparse environment and sparse matrix/dense vector handles.
531   Type indexTp = rewriter.getIndexType();
532   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
533   Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
534   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
535   Value token = genFirstWait(rewriter, loc);
536   Operation *spGenA =
537       genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA,
538                rowA, colA, valA, isCOO, enableRT);
539   Value spMatA = spGenA->getResult(0);
540   token = spGenA->getResult(1);
541   auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
542       loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
543   Value dnX = dvecX.getResult(0);
544   token = dvecX.getAsyncToken();
545   auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
546       loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
547   Value dnY = dvecY.getResult(0);
548   token = dvecY.getAsyncToken();
549 
550   auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
551 
552   // Precompute buffersize for SpMV.
553   auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
554       loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
555       /*computeType=*/dnYType);
556   Value bufferSz = bufferComp.getResult(0);
557   token = bufferComp.getAsyncToken();
558   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
559   Value buffer = buf.getResult(0);
560   token = buf.getAsyncToken();
561 
562   // Perform the SpMV.
563   auto spmvComp = rewriter.create<gpu::SpMVOp>(
564       loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer);
565   token = spmvComp.getAsyncToken();
566 
567   // Copy data back to host and free all the resources.
568   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
569               .getAsyncToken();
570   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
571               .getAsyncToken();
572   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
573               .getAsyncToken();
574   token = genDeallocMemRef(rewriter, loc, rowA, token);
575   if (colA)
576     token = genDeallocMemRef(rewriter, loc, colA, token);
577   token = genDeallocMemRef(rewriter, loc, valA, token);
578   token = genDeallocMemRef(rewriter, loc, buffer, token);
579   if (!isZeroCopy)
580     token = genDeallocMemRef(rewriter, loc, vecX, token);
581   token = genCopyMemRef(rewriter, loc, memY, vecY, token);
582   token = genDeallocMemRef(rewriter, loc, vecY, token);
583   tokens.push_back(token);
584   genBlockingWait(rewriter, loc, tokens);
585   tokens.clear();
586   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
587     genHostUnregisterMemref(rewriter, loc, castR);
588     if (memC)
589       genHostUnregisterMemref(rewriter, loc, castC);
590     genHostUnregisterMemref(rewriter, loc, castV);
591     genHostUnregisterMemref(rewriter, loc, castX);
592     genHostUnregisterMemref(rewriter, loc, castY);
593   }
594 
595   // Done.
596   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
597   return success();
598 }
599 
600 /// Match and rewrite SpMM kernel.
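/// In index notation, the kernel rewritten here computes (schematically)
///   C(i,j) += A(i,k) * B(k,j)
/// with A a sparse matrix and B, C dense matrices.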
601 static LogicalResult
602 rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
603             GPUDataTransferStrategy gpuDataTransferStrategy) {
604   Location loc = op.getLoc();
605   Value a = op.getOperand(0);
606   Value b = op.getOperand(1);
607   Value c = op.getOperand(2); // we have C = AB
608   SmallVector<Value> tokens;
609 
610   bool isZeroCopy =
611       gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
612 
613   // Only admissible sparse matrix format and dense matrices.
614   bool isCOO = false;
615   SparseTensorType aTp = getSparseTensorType(a);
616   SparseTensorType bTp = getSparseTensorType(b);
617   SparseTensorType cTp = getSparseTensorType(c);
618   if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, /*isMatVec=*/false, isCOO))
619     return failure();
620 
621   // Start sparse kernel and copy data from host to device.
622   //   a : memR/memC/memV -> rowA,colA,valA
623   //   b : bufB           -> matB
624   //   c : bufC           -> matC
625   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
626   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
627   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
628   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
629   Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
630   Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
631   Value memV = genToValues(rewriter, loc, a);
632   Value bufB, bufC;
633   Value castR, castC, castV, castB, castBufC;
634   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
635     bufB = genTensorToMemref(rewriter, loc, b);
636     bufC = genTensorToMemref(rewriter, loc, c);
637     castR = genHostRegisterMemref(rewriter, loc, memR);
638     if (memC)
639       castC = genHostRegisterMemref(rewriter, loc, memC);
640     castV = genHostRegisterMemref(rewriter, loc, memV);
641     castB = genHostRegisterMemref(rewriter, loc, bufB);
642     castBufC = genHostRegisterMemref(rewriter, loc, bufC);
643   }
644   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
645   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
646   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
647   if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
648     bufB = genTensorToMemref(rewriter, loc, b);
649   Value matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens);
650   if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA)
651     bufC = genTensorToMemref(rewriter, loc, c);
652   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
653   genBlockingWait(rewriter, loc, tokens);
654   tokens.clear();
655 
656   // Create sparse environment and sparse matrix/dense matrix handles.
657   Type indexTp = rewriter.getIndexType();
658   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
659   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
660   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
661   Value token = genFirstWait(rewriter, loc);
662   Operation *spGenA =
663       genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
664                rowA, colA, valA, isCOO, enableRT);
665   Value spMatA = spGenA->getResult(0);
666   token = spGenA->getResult(1);
667   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
668       loc, dnTensorHandleTp, tokenTp, token, matB,
669       SmallVector<Value>{szk, szn});
670   Value dnB = dmatB.getResult(0);
671   token = dmatB.getAsyncToken();
672   auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
673       loc, dnTensorHandleTp, tokenTp, token, matC,
674       SmallVector<Value>{szm, szn});
675   Value dnC = dmatC.getResult(0);
676   token = dmatC.getAsyncToken();
677 
678   auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();
679 
680   // Precompute buffersize for SpMM.
681   auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
682       loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
683       /*computeType=*/dmatCType);
684   Value bufferSz = bufferComp.getResult(0);
685   token = bufferComp.getAsyncToken();
686   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
687   Value buffer = buf.getResult(0);
688   token = buf.getAsyncToken();
689 
690   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
691 
692   // Perform the SpMM.
693   auto spmmComp = rewriter.create<gpu::SpMMOp>(
694       loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer);
695   token = spmmComp.getAsyncToken();
696 
697   // Copy data back to host and free all the resources.
698   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
699               .getAsyncToken();
700   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
701               .getAsyncToken();
702   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
703               .getAsyncToken();
704   token = genDeallocMemRef(rewriter, loc, rowA, token);
705   if (colA)
706     token = genDeallocMemRef(rewriter, loc, colA, token);
707   token = genDeallocMemRef(rewriter, loc, valA, token);
708   token = genDeallocMemRef(rewriter, loc, buffer, token);
709   if (!isZeroCopy)
710     token = genDeallocMemRef(rewriter, loc, matB, token);
711   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
712   token = genDeallocMemRef(rewriter, loc, matC, token);
713   tokens.push_back(token);
714   genBlockingWait(rewriter, loc, tokens);
715   tokens.clear();
716   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
717     genHostUnregisterMemref(rewriter, loc, castR);
718     if (memC)
719       genHostUnregisterMemref(rewriter, loc, castC);
720     genHostUnregisterMemref(rewriter, loc, castV);
721     genHostUnregisterMemref(rewriter, loc, castB);
722     genHostUnregisterMemref(rewriter, loc, castBufC);
723   }
724 
725   // Done.
726   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
727   return success();
728 }
729 
730 /// Match and rewrite SpGEMM kernel.
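/// All three matrices use CSR here; in index notation the computation is
/// (schematically) C(i,j) += A(i,k) * B(k,j), with the sparse C materialized
/// on the fly.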
731 static LogicalResult
732 rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
733               GPUDataTransferStrategy gpuDataTransferStrategy) {
734   Location loc = op.getLoc();
735   Value a = op.getOperand(0);
736   Value b = op.getOperand(1);
737   Value c = op.getOperand(2); // we have C = AB
738   SmallVector<Value> tokens;
739 
740   // Only CSR <- CSR x CSR supported.
741   bool isCOO = false;
742   SparseTensorType aTp = getSparseTensorType(a);
743   SparseTensorType bTp = getSparseTensorType(b);
744   SparseTensorType cTp = getSparseTensorType(c);
745   if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
746     return failure();
747 
748   // Start sparse kernel and copy data from host to device.
749   //   a : amemR/amemC/amemV -> rowA,colA,valA
750   //   b : bmemR/bmemC/bmemV -> rowB,colB,valB
751   //   c : materializes
752   auto dnCType = cTp.getElementType();
753   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
754   Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
755   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
756   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
757   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
758   Value amemR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
759   Value amemC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
760   Value amemV = genToValues(rewriter, loc, a);
761   Value bmemR = genFirstPosOrCrds(rewriter, loc, b, isCOO, enableRT);
762   Value bmemC = genSecondCrds(rewriter, loc, b, isCOO, enableRT);
763   Value bmemV = genToValues(rewriter, loc, b);
764   Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
765   Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
766   Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
767   Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
768   Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
769   Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
770   genBlockingWait(rewriter, loc, tokens);
771   tokens.clear();
772 
773   // Create sparse environment and sparse matrix/dense vector handles.
774   Type indexTp = rewriter.getIndexType();
775   Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
776   Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
777   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
778   Value token = genFirstWait(rewriter, loc);
779   Operation *spGenA =
780       genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
781                rowA, colA, valA, isCOO, enableRT);
782   Value spMatA = spGenA->getResult(0);
783   token = spGenA->getResult(1);
784   Operation *spGenB =
785       genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
786                rowB, colB, valB, isCOO, enableRT);
787   Value spMatB = spGenB->getResult(0);
788   token = spGenB->getResult(1);
789 
790   // Sparse matrix C materializes (also assumes beta == 0).
791   Value zero = constantIndex(rewriter, loc, 0);
792   Value one = constantIndex(rewriter, loc, 1);
793   Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
794   auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
795   Value rowC = e1.getResult(0);
796   token = e1.getAsyncToken();
797   auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
798   Value colC = e2.getResult(0); // no free needed
799   token = e2.getAsyncToken();
800   auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
801   Value valC = e3.getResult(0); // no free needed
802   token = e3.getAsyncToken();
803   Operation *spGenC =
804       genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
805                rowC, colC, valC, isCOO, enableRT);
806   Value spMatC = spGenC->getResult(0);
807   token = spGenC->getResult(1);
808 
809   // Precompute buffersizes for SpGEMM.
810   Operation *descOp =
811       rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
812   Value desc = descOp->getResult(0);
813   token = descOp->getResult(1);
814   Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
815       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
816       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
817       valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
818   Value bufferSz1 = work1->getResult(0);
819   token = work1->getResult(1);
820   auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
821   Value buffer1 = buf1.getResult(0);
822   token = buf1.getAsyncToken();
823   Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
824       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
825       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
826       bufferSz1, buffer1,
827       gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
828   token = work2->getResult(1);
829 
830   // Compute step.
831   Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
832       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
833       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
834       valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
835   Value bufferSz2 = compute1->getResult(0);
836   token = compute1->getResult(1);
837   auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
838   Value buffer2 = buf2.getResult(0);
839   token = buf2.getAsyncToken();
840   Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
841       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
842       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
843       bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
844   token = compute2->getResult(1);
845 
846   // Get sizes.
847   Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
848       loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
849   Value nnz = sizes->getResult(2);
850   token = sizes->getResult(3);
851   auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
852   colC = a2.getResult(0);
853   token = a2.getAsyncToken();
854   auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
855   valC = a3.getResult(0);
856   token = a3.getAsyncToken();
857 
858   // Update C with new pointers and copy final product back into C.
859   Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
860       loc, tokenTp, token, spMatC, rowC, colC, valC);
861   token = update->getResult(0);
862   Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
863       loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
864       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
865   token = copy->getResult(0);
866 
867   // Allocate buffers on host.
868   Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
869   Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
870   Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
871 
872   // Copy data back to host and free all the resources.
873   token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
874               .getAsyncToken();
875   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
876               .getAsyncToken();
877   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
878               .getAsyncToken();
879   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
880               .getAsyncToken();
881   token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
882   token = genCopyMemRef(rewriter, loc, colH, colC, token);
883   token = genCopyMemRef(rewriter, loc, valH, valC, token);
884   token = genDeallocMemRef(rewriter, loc, rowA, token);
885   token = genDeallocMemRef(rewriter, loc, colA, token);
886   token = genDeallocMemRef(rewriter, loc, valA, token);
887   token = genDeallocMemRef(rewriter, loc, rowB, token);
888   token = genDeallocMemRef(rewriter, loc, colB, token);
889   token = genDeallocMemRef(rewriter, loc, valB, token);
890   token = genDeallocMemRef(rewriter, loc, rowC, token);
891   token = genDeallocMemRef(rewriter, loc, colC, token);
892   token = genDeallocMemRef(rewriter, loc, valC, token);
893   token = genDeallocMemRef(rewriter, loc, buffer1, token);
894   token = genDeallocMemRef(rewriter, loc, buffer2, token);
895   tokens.push_back(token);
896   genBlockingWait(rewriter, loc, tokens);
897   tokens.clear();
898 
899   // Done.
900   Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
901   Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
902   Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
903   rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
904                                           ValueRange{rt, ct});
905   return success();
906 }
907 
908 /// Match and rewrite 2:4 SpMM kernel.
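/// Here A is a dense matrix that is pruned on the device to 2:4 structured
/// sparsity (at most two nonzeros in every group of four consecutive values),
/// and B, C are dense; schematically C(i,j) += A(i,k) * B(k,j).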
909 static LogicalResult
910 rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
911                 GPUDataTransferStrategy gpuDataTransferStrategy) {
912   Location loc = op.getLoc();
913   Value A = op.getOperand(0);
914   Value B = op.getOperand(1);
915   Value C = op.getOperand(2); // we have C = AB
916   SmallVector<Value> tokens;
917 
918   bool isZeroCopy =
919       gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
920 
921   // All inputs should be dense tensors.
922   if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
923     return failure();
924 
925   Value matA, matB;
926   Value bufA = genTensorToMemref(rewriter, loc, A);
927   if (!isZeroCopy)
928     matA = genAllocCopy(rewriter, loc, bufA, tokens);
929   Value bufB = genTensorToMemref(rewriter, loc, B);
930   if (!isZeroCopy)
931     matB = genAllocCopy(rewriter, loc, bufB, tokens);
932   Value bufC = genTensorToMemref(rewriter, loc, C);
933   Value castA, castB, castC;
934   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
935     castA = genHostRegisterMemref(rewriter, loc, bufA);
936     castB = genHostRegisterMemref(rewriter, loc, bufB);
937     castC = genHostRegisterMemref(rewriter, loc, bufC);
938   }
939   if (isZeroCopy) {
940     matA = bufA;
941     matB = bufB;
942   }
943   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
944   genBlockingWait(rewriter, loc, tokens);
945   tokens.clear();
946 
947   // Create sparse environment and sparse matrix/dense vector handles.
948   Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
949   Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
950   Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
951   Type indexTp = rewriter.getIndexType();
952   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
953   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
954   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
955   Value token = genFirstWait(rewriter, loc);
956   Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
957       loc, spMatHandleTp, tokenTp, token, szm, szk,
958       gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
959   Value spMatA = spGenA->getResult(0);
960   token = spGenA->getResult(1);
961   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
962       loc, dnTensorHandleTp, tokenTp, token, matB,
963       SmallVector<Value>{szk, szn});
964   Value dnB = dmatB.getResult(0);
965   token = dmatB.getAsyncToken();
966   auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
967       loc, dnTensorHandleTp, tokenTp, token, matC,
968       SmallVector<Value>{szm, szn});
969   Value dnC = dmatC.getResult(0);
970   token = dmatC.getAsyncToken();
971   auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
972 
973   // Precompute buffersize for SpMM.
974   SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
975   TypeRange bufferTypes(bufferTypes_);
976   auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
977       loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
978       gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
979       /*computeType=*/dmatCType);
980   token = bufferComp.getAsyncToken();
981 
982   Value bufferSz = bufferComp.getResult(0);
983   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
984   Value buffer = buf.getResult(0);
985   token = buf.getAsyncToken();
986 
987   Value bufferSz2 = bufferComp.getResult(1);
988   auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
989   Value buffer2 = buf2.getResult(0);
990   token = buf2.getAsyncToken();
991 
992   Value bufferSz3 = bufferComp.getResult(2);
993   auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
994   Value buffer3 = buf3.getResult(0);
995   token = buf3.getAsyncToken();
996 
997   auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
998 
999   // Perform the SpMM.
1000   auto spmmComp = rewriter.create<gpu::SpMMOp>(
1001       loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
1002       SmallVector<Value>{buffer, buffer2, buffer3});
1003   token = spmmComp.getAsyncToken();
1004 
1005   // Copy data back to host and free all the resources.
1006   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
1007               .getAsyncToken();
1008   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
1009               .getAsyncToken();
1010   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
1011               .getAsyncToken();
1013   token = genDeallocMemRef(rewriter, loc, buffer, token);
1014   token = genDeallocMemRef(rewriter, loc, buffer2, token);
1015   token = genDeallocMemRef(rewriter, loc, buffer3, token);
1016   if (!isZeroCopy)
1017     token = genDeallocMemRef(rewriter, loc, matA, token);
1018   if (!isZeroCopy)
1019     token = genDeallocMemRef(rewriter, loc, matB, token);
1020   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
1021   token = genDeallocMemRef(rewriter, loc, matC, token);
1022   tokens.push_back(token);
1023   genBlockingWait(rewriter, loc, tokens);
1024   tokens.clear();
1025   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
1026     genHostUnregisterMemref(rewriter, loc, castA);
1027     genHostUnregisterMemref(rewriter, loc, castB);
1028     genHostUnregisterMemref(rewriter, loc, castC);
1029   }
1030 
1031   // Done.
1032   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
1033   return success();
1034 }
1035 
1036 /// Match and rewrite SDDMM kernel.
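/// In index notation, the sampled dense-dense matrix multiplication computes
/// (schematically) C(i,j) += spy(C(i,j)) * sum_k A(i,k) * B(k,j), i.e. the
/// dense product A * B is only evaluated at positions where the sparse C is
/// nonzero, updating C in place.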
1037 static LogicalResult
1038 rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
1039              GPUDataTransferStrategy gpuDataTransferStrategy) {
1040   Location loc = op.getLoc();
1041   Value a = op.getOperand(0);
1042   Value b = op.getOperand(1);
1043   Value c = op.getOperand(2);
1044   SmallVector<Value> tokens;
1045 
1046   bool isZeroCopy =
1047       gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
1048 
1049   // Only admissible sparse matrix format and dense matrices, no COO.
1050   bool isCOO = false;
1051   SparseTensorType aTp = getSparseTensorType(a);
1052   SparseTensorType bTp = getSparseTensorType(b);
1053   SparseTensorType cTp = getSparseTensorType(c);
1054   if (!areAdmissibleTypes(cTp, bTp, aTp, enableRT, /*isMatVec=*/false, isCOO))
1055     return failure();
1056   if (isCOO)
1057     return failure();
1058 
1059   // The SDDMM performs the operation in place on the sparse output.
1060   // Start sparse kernel and copy data from host to device.
1061   //   a : bufA           -> matA
1062   //   b : bufB           -> matB
1063   //   c : memR/memC/memV -> rowC,colC,valC
1064   Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
1065   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
1066   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
1067   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
1068   Value matA, matB;
1069   Value bufA = genTensorToMemref(rewriter, loc, a);
1070   if (!isZeroCopy)
1071     matA = genAllocCopy(rewriter, loc, bufA, tokens);
1072   Value bufB = genTensorToMemref(rewriter, loc, b);
1073   if (!isZeroCopy)
1074     matB = genAllocCopy(rewriter, loc, bufB, tokens);
1075   Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT);
1076   Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT);
1077   Value memV = genToValues(rewriter, loc, c);
1078   Value castB, castA, castR, castC, castV;
1079   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
1080     castB = genHostRegisterMemref(rewriter, loc, bufB);
1081     castA = genHostRegisterMemref(rewriter, loc, bufA);
1082     castR = genHostRegisterMemref(rewriter, loc, memR);
1083     if (memC)
1084       castC = genHostRegisterMemref(rewriter, loc, memC);
1085     castV = genHostRegisterMemref(rewriter, loc, memV);
1086   }
1087   if (isZeroCopy) {
1088     matA = bufA;
1089     matB = bufB;
1090   }
1091   Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
1092   Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
1093   Value valC = genAllocCopy(rewriter, loc, memV, tokens);
1094   genBlockingWait(rewriter, loc, tokens);
1095   tokens.clear();
1096 
1097   // Create sparse environment and sparse matrix/dense matrix handles.
1098   Type indexTp = rewriter.getIndexType();
1099   Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
1100   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
1101   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
1102   Value token = genFirstWait(rewriter, loc);
1103   auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
1104       loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
1105   Value dnA = dmatA.getResult(0);
1106   token = dmatA.getAsyncToken();
1107   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
1108       loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
1109   Value dnB = dmatB.getResult(0);
1110   token = dmatB.getAsyncToken();
1111 
1112   Operation *spGenC =
1113       genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szn, nseC,
1114                rowC, colC, valC, isCOO, enableRT);
1115   Value spMatC = spGenC->getResult(0);
1116   token = spGenC->getResult(1);
1117   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
1118 
1119   // Precompute buffersize for SDDMM.
1120   auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
1121       loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
1122   Value bufferSz = bufferComp.getResult(0);
1123   token = bufferComp.getAsyncToken();
1124   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
1125   Value buffer = buf.getResult(0);
1126   token = buf.getAsyncToken();
1127 
1128   // Perform the SDDMM.
1129   auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
1130                                                  spMatC, dnCType, buffer);
1131   token = sddmmComp.getAsyncToken();
1132 
1133   // Copy data back to host and free all the resources.
1134   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
1135               .getAsyncToken();
1136   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
1137               .getAsyncToken();
1138   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
1139               .getAsyncToken();
1140   token = genDeallocMemRef(rewriter, loc, buffer, token);
1141   if (!isZeroCopy) {
1142     token = genDeallocMemRef(rewriter, loc, matA, token);
1143     token = genDeallocMemRef(rewriter, loc, matB, token);
1144   }
1145   token = genDeallocMemRef(rewriter, loc, rowC, token);
1146   if (colC)
1147     token = genDeallocMemRef(rewriter, loc, colC, token);
1148   token = genCopyMemRef(rewriter, loc, memV, valC, token);
1149   token = genDeallocMemRef(rewriter, loc, valC, token);
1150   tokens.push_back(token);
1151   genBlockingWait(rewriter, loc, tokens);
1152   tokens.clear();
1153   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
1154     genHostUnregisterMemref(rewriter, loc, castB);
1155     genHostUnregisterMemref(rewriter, loc, castA);
1156     genHostUnregisterMemref(rewriter, loc, castR);
1157     if (memC)
1158       genHostUnregisterMemref(rewriter, loc, castC);
1159     genHostUnregisterMemref(rewriter, loc, castV);
1160   }
1161 
1162   // Done.
1163   rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
1164   return success();
1165 }
1166 
1167 //===----------------------------------------------------------------------===//
1168 // Rewriting rules for direct code generation.
1169 //===----------------------------------------------------------------------===//
1170 
1171 /// Proof-of-concept rewriter. This rule generates a GPU implementation
1172 /// for each outermost forall loop generated by the sparse compiler.
1173 /// TODO: right now works with parallelization-strategy=dense-outer-loop
1174 ///       but give this its own flags in the future
1175 struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
1176   using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
1177 
1178   ForallRewriter(MLIRContext *context, unsigned nT)
1179       : OpRewritePattern(context), numThreads(nT) {}
1180 
1181   LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
1182                                 PatternRewriter &rewriter) const override {
1183     // Reject inadmissible loop form.
1184     // Essentially only accept a loop, generated by the sparse compiler,
1185     // of the form
1186     //   forall (i = 0; i < N; i++)
1187     // so that cyclic scheduling over the threads is easy.
1188     if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
1189         forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
1190         !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
1191         !matchPattern(forallOp.getStep()[0], m_One()))
1192       return failure();
1193     // Collect every value that is computed outside the parallel loop.
1194     SetVector<Value> invariants; // stable iteration!
1195     forallOp->walk([&](Operation *op) {
1196       // Collect all values of admissible ops.
1197       for (OpOperand &o : op->getOpOperands()) {
1198         Value val = o.get();
1199         Block *block;
1200         if (auto arg = dyn_cast<BlockArgument>(val))
1201           block = arg.getOwner();
1202         else
1203           block = val.getDefiningOp()->getBlock();
1204         if (!isNestedIn(block, forallOp))
1205           invariants.insert(val);
1206       }
1207     });
1208     // Outline the outside values as proper parameters. Fail when sharing
1209     // a value between host and device is not straightforward.
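    // Note: constants are not outlined as kernel arguments but handed to
    // genGPUCode below, scalar float/int/index values and memref buffers
    // become kernel arguments, and any other kind of value makes the
    // rewrite bail out.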
1210     SmallVector<Value> constants;
1211     SmallVector<Value> scalars;
1212     SmallVector<Value> buffers;
1213     for (Value val : invariants) {
1214       Type tp = val.getType();
1215       if (val.getDefiningOp<arith::ConstantOp>())
1216         constants.push_back(val);
1217       else if (isa<FloatType>(tp) || tp.isIntOrIndex())
1218         scalars.push_back(val);
1219       else if (isa<MemRefType>(tp))
1220         buffers.push_back(val);
1221       else
1222         return failure(); // don't know how to share
1223     }
1224     // Pass outlined non-constant values.
1225     // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
1226     //       keep the feature at all (either through a heuristic or compiler
1227     //       option for gpu codegen).
1228     Location loc = forallOp->getLoc();
1229     SmallVector<Value> args;
1230     SmallVector<Value> tokens;
1231     Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
1232                                 /*useHostRegistrationForOut=*/false);
1233     // Set up GPU module and construct GPU function.
1234     auto saveIp = rewriter.saveInsertionPoint();
1235     ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
1236     auto gpuModule = genGPUModule(rewriter, topModule);
1237     auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
1238     genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
1239     // Generate code that launches the kernel asynchronously, blocking on all
1240     // open tokens and yielding a new token for the output.
1241     // TODO: Passing tokens into the launch op does not seem to be properly
1242     //       lowered along the cubin path yet, hence the current blocking wait.
1243     rewriter.restoreInsertionPoint(saveIp);
1244     genBlockingWait(rewriter, loc, tokens);
1245     tokens.clear();
1246     Value kernelToken =
1247         genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
1248     // Finalize the outlined arguments.
1249     genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
1250                      tokens);
1251     genBlockingWait(rewriter, loc, tokens);
1252     rewriter.eraseOp(forallOp);
1253     return success();
1254   }
1255 
1256 private:
1257   // Helper method to check whether the given block is nested inside the loop.
1258   static bool isNestedIn(Block *block, scf::ParallelOp forallOp) {
1259     for (Operation *o = block->getParentOp(); o; o = o->getParentOp()) {
1260       if (o == forallOp)
1261         return true;
1262     }
1263     return false;
1264   }
1265 
1266   unsigned numThreads;
1267 };
1268 
1269 //===----------------------------------------------------------------------===//
1270 // Rewriting rules for library recognition and code generation.
1271 //===----------------------------------------------------------------------===//
1272 
1273 /// Proof-of-concept rewriter. This rule recognizes certain math kernels
1274 /// and replaces these with corresponding calls into a sparse library.
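/// The kernels currently recognized are SpMV, SpMM (including SpGEMM and a
/// 2:4 structured variant), and SDDMM; see matchAndRewrite below for the
/// exact matching conditions.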
1275 struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
1276   using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
1277 
1278   LinalgOpRewriter(MLIRContext *context, bool rt, GPUDataTransferStrategy t)
1279       : OpRewritePattern(context), enableRT(rt), gpuDataTransferStrategy(t) {}
1280 
1281   LogicalResult matchAndRewrite(linalg::GenericOp op,
1282                                 PatternRewriter &rewriter) const override {
1283     if (op.getNumDpsInits() != 1)
1284       return failure(); // reject multi-output
1285 
1286     const unsigned numLoops = op.getNumLoops();
1287     const unsigned numTensors = op->getNumOperands();
1288     const auto iteratorTypes = op.getIteratorTypesArray();
1289     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1290 
1291     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1292     auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1293     AffineExpr i, j, k;
1294     bindDims(getContext(), i, j, k);
1295 
1296     // TODO: more robust patterns, transposed versions, more kernels,
1297     //       identify alpha and beta and pass them to the CUDA calls.
1298 
1299     // Recognize an SpMV kernel.
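    // As an illustrative sketch (with a hypothetical #CSR sparse encoding),
    // the matched IR has the following shape:
    //
    //   %0 = linalg.generic {
    //            indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
    //                             affine_map<(d0, d1) -> (d1)>,
    //                             affine_map<(d0, d1) -> (d0)>],
    //            iterator_types = ["parallel", "reduction"]}
    //        ins(%A, %x : tensor<?x?xf64, #CSR>, tensor<?xf64>)
    //        outs(%y : tensor<?xf64>) {
    //        ^bb0(%a: f64, %b: f64, %out: f64):
    //          %p = arith.mulf %a, %b : f64
    //          %s = arith.addf %out, %p : f64
    //          linalg.yield %s : f64
    //        } -> tensor<?xf64>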
1300     if (numLoops == 2 && numTensors == 3 &&
1301         linalg::isParallelIterator(iteratorTypes[0]) &&
1302         linalg::isReductionIterator(iteratorTypes[1]) &&
1303         maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
1304       return rewriteSpMV(rewriter, op, enableRT, gpuDataTransferStrategy);
1305     }
1306 
1307     // Recognize an SpGEMM, 2:4-SpMM, or SpMM kernel.
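    // If neither input is a dense tensor the kernel is rewritten as SpGEMM,
    // if the op carries the "DENSE24" attribute it is rewritten as 2:4
    // structured SpMM, and otherwise it is rewritten as a regular SpMM.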
1308     if (numLoops == 3 && numTensors == 3 &&
1309         linalg::isParallelIterator(iteratorTypes[0]) &&
1310         linalg::isParallelIterator(iteratorTypes[1]) &&
1311         linalg::isReductionIterator(iteratorTypes[2]) &&
1312         maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
1313       if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
1314         return rewriteSpGEMM(rewriter, op, enableRT, gpuDataTransferStrategy);
1315       if (op->getAttr("DENSE24"))
1316         return rewrite2To4SpMM(rewriter, op, gpuDataTransferStrategy);
1317       return rewriteSpMM(rewriter, op, enableRT, gpuDataTransferStrategy);
1318     }
1319 
1320     // Recognize an SDDMM kernel.
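    // That is, a sampled dense-dense matrix multiplication: the dense-dense
    // product is computed only at the positions selected by the sparse
    // output pattern.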
1321     if (numLoops == 3 && numTensors == 3 &&
1322         linalg::isParallelIterator(iteratorTypes[0]) &&
1323         linalg::isParallelIterator(iteratorTypes[1]) &&
1324         linalg::isReductionIterator(iteratorTypes[2]) &&
1325         maps == infer({{i, k}, {k, j}, {i, j}}) &&
1326         matchSumReductionOfMulUnary(op)) {
1327       return rewriteSDDMM(rewriter, op, enableRT, gpuDataTransferStrategy);
1328     }
1329 
1330     return failure();
1331   }
1332 
1333 private:
1334   bool enableRT;
1335   GPUDataTransferStrategy gpuDataTransferStrategy;
1336 };
1337 
1338 } // namespace
1339 
1340 //===----------------------------------------------------------------------===//
1341 // Public method for populating GPU rewriting rules.
1342 //
1343 // Currently two sets of rewriting rules are made available. The first set
1344 // implements direct code generation, currently by means of converting the
1345 // outermost parallel loop into GPU threads. The second set implements
1346 // library recognition of a set of sparse operations. Eventually, the right
1347 // combination of these two approaches has to be found.
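//
// A minimal usage sketch for a hypothetical driver pass (the option values
// shown are illustrative assumptions, not defaults taken from this file):
//
//   RewritePatternSet patterns(module.getContext());
//   populateSparseGPULibgenPatterns(patterns, /*enableRT=*/true,
//                                   GPUDataTransferStrategy::kRegularDMA);
//   populateSparseGPUCodegenPatterns(patterns, /*numThreads=*/1024);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));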
1348 //===----------------------------------------------------------------------===//
1349 
1350 void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
1351                                             unsigned numThreads) {
1352   patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
1353 }
1354 
1355 void mlir::populateSparseGPULibgenPatterns(
1356     RewritePatternSet &patterns, bool enableRT,
1357     GPUDataTransferStrategy gpuDataTransfer) {
1358   patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT,
1359                                  gpuDataTransfer);
1360 }
1361