//===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a prototype GPU code generator for the sparsifier.
// The objective is to eventually use the right combination of
// direct code generation and library calls into vendor-specific
// highly optimized sparse libraries (e.g. cuSparse for CUDA).
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Sparse formats supported by cuSparse.
enum class CuSparseFormat {
  kNone,
  kCOO,
  kCSR,
  kCSC,
  kBSR,
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Marks the given top module as a GPU container module.
static void markAsGPUContainer(ModuleOp topModule) {
  topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
                     UnitAttr::get(topModule->getContext()));
}

/// Constructs a new GPU module (for GPU kernels) inside the given top module,
/// or returns an existing GPU module if one was built previously.
static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
  for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
    return op; // existing
  markAsGPUContainer(topModule);
  builder.setInsertionPointToStart(topModule.getBody());
  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
                                          "sparse_kernels");
}

/// Constructs a new GPU kernel in the given GPU module.
static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
                                 SmallVectorImpl<Value> &args) {
  // Get a unique kernel name. Not very creative,
  // but we simply try kernel0, kernel1, etc.
  unsigned kernelNumber = 0;
  SmallString<16> kernelName;
  do {
    kernelName.clear();
    ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
  } while (gpuModule.lookupSymbol(kernelName));
  // Then we insert a new kernel with given arguments into the module.
  builder.setInsertionPointToStart(gpuModule.getBody());
  SmallVector<Type> argsTp;
  for (auto arg : args)
    argsTp.push_back(arg.getType());
  FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
  auto gpuFunc =
      builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
  gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                   builder.getUnitAttr());
  return gpuFunc;
}

/// Constructs code to launch the GPU kernel.
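/// As a rough sketch (the kernel symbol and operands below are placeholders),
/// the emitted launch looks like:
///   %t = gpu.launch_func async [%tokens] @sparse_kernels::@kernelN
///            blocks in (%c1, %c1, %c1) threads in (%numThreads, %c1, %c1)
///            args(%arg0, ..., %argM : ...)
/// and the resulting async token %t is returned to the caller.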
static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
                              SmallVectorImpl<Value> &args,
                              SmallVectorImpl<Value> &tokens,
                              unsigned numThreads) {
  Location loc = gpuFunc->getLoc();
  Value none = TypedValue<::mlir::IntegerType>{};
  Value one = constantIndex(builder, loc, 1);
  Value numT = constantIndex(builder, loc, numThreads);
  gpu::KernelDim3 gridSize = {one, one, one};
  gpu::KernelDim3 blckSize = {numT, one, one};
  return builder
      .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
                                 /*dynSharedMemSz*/ none, args,
                                 builder.getType<gpu::AsyncTokenType>(), tokens)
      .getAsyncToken();
}

/// Maps the provided ranked host buffer into the device address space.
/// Writes from the host are guaranteed to be visible to device kernels
/// that are launched afterwards. Writes from the device are guaranteed
/// to be visible on the host after synchronizing with the device kernel
/// completion. Needs to cast the buffer to an unranked buffer.
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
                                   Value mem) {
  MemRefType memTp = cast<MemRefType>(mem.getType());
  UnrankedMemRefType resTp =
      UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
  Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
  builder.create<gpu::HostRegisterOp>(loc, cast);
  return cast;
}

/// Unmaps the provided buffer, expecting the previously cast (unranked) buffer.
static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
                                    Value cast) {
  builder.create<gpu::HostUnregisterOp>(loc, cast);
}

/// Generates the first wait in an asynchronous chain.
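/// In IR this is simply a dependency-free `%token = gpu.wait async`.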
static Value genFirstWait(OpBuilder &builder, Location loc) {
  Type tokenType = builder.getType<gpu::AsyncTokenType>();
  return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
      .getAsyncToken();
}

/// Generates the last, blocking wait in an asynchronous chain.
static void genBlockingWait(OpBuilder &builder, Location loc,
                            ValueRange operands) {
  builder.create<gpu::WaitOp>(loc, Type(), operands);
}

/// Allocates memory on the device.
/// TODO: A `host_shared` attribute could be used to indicate that
///       the buffer is visible to both host and device, but lowering
///       that feature does not seem to be fully supported yet.
static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
                                   Value token) {
  auto tp = cast<ShapedType>(mem.getType());
  auto elemTp = tp.getElementType();
  auto shape = tp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
    if (shape[r] == ShapedType::kDynamic) {
      Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
      dynamicSizes.push_back(dimOp);
    }
  }
  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
                                      token, dynamicSizes, ValueRange());
}

/// Allocates a typed buffer on the host with the given size.
static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
                           Value size) {
  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
  return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
}

/// Allocates a typed buffer on the device with the given size.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
                                   Value size, Value token) {
  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
                                      token, size, ValueRange());
}

/// Allocates a void buffer on the device with the given size.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
                                   Value token) {
  return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
}

/// Deallocates memory from the device.
static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
                              Value token) {
  return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
      .getAsyncToken();
}

/// Copies memory between host and device (direction is implicit).
static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
                           Value src, Value token) {
  return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
      .getAsyncToken();
}

/// Generates an alloc/copy pair.
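/// A rough sketch of the async chain this generates (types elided):
///   %t0 = gpu.wait async
///   %mem, %t1 = gpu.alloc async [%t0] ...
///   %t2 = gpu.memcpy async [%t1] %mem, %b
/// where %t2 is appended to `tokens` and %mem is returned.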
static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
                          SmallVectorImpl<Value> &tokens) {
  Value firstToken = genFirstWait(builder, loc);
  auto alloc = genAllocMemRef(builder, loc, b, firstToken);
  Value devMem = alloc.getResult(0);
  Value depToken = alloc.getAsyncToken(); // copy-after-alloc
  tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
  return devMem;
}

/// Generates a memref-from-tensor operation.
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
                               Value tensor) {
  auto tensorType = llvm::cast<ShapedType>(tensor.getType());
  auto memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
}

/// Prepares the outlined arguments, passing scalars and buffers in. Here we
/// assume that the first buffer is the one allocated for output. We create
/// a set of properly chained asynchronous allocation/copy pairs to increase
/// overlap before launching the kernel.
static Value genParametersIn(OpBuilder &builder, Location loc,
                             SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens,
                             bool useHostRegistrationForOut) {
  Value out;
  // Scalars are passed by value.
  for (Value s : scalars)
    args.push_back(s);
  // Buffers need to be made visible on the device.
  for (Value b : buffers) {
    if (useHostRegistrationForOut) {
      out = genHostRegisterMemref(builder, loc, b);
      args.push_back(b);
      useHostRegistrationForOut = false;
      continue;
    }
    args.push_back(genAllocCopy(builder, loc, b, tokens));
  }
  return out;
}

/// Finalizes the outlined arguments. The output buffer is copied depending
/// on the kernel token and then deallocated. All other buffers are simply
/// deallocated. Then we wait for all operations to complete.
static void genParametersOut(OpBuilder &builder, Location loc, Value out,
                             Value kernelToken, SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens) {
  unsigned base = scalars.size();
  for (unsigned i = base, e = args.size(); i < e; i++) {
    Value firstToken;
    if (i == base) {
      // Assumed output parameter: unregister or copy-out.
      if (out) {
        genHostUnregisterMemref(builder, loc, out);
        out = Value();
        continue;
      }
      firstToken =
          genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
    } else {
      firstToken = genFirstWait(builder, loc);
    }
    tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
  }
}

/// Constructs code for a new GPU kernel.
static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
                       scf::ParallelOp forallOp,
                       SmallVectorImpl<Value> &constants,
                       SmallVectorImpl<Value> &scalars,
                       SmallVectorImpl<Value> &buffers) {
  Location loc = gpuFunc->getLoc();
  Block &block = gpuFunc.getBody().front();
  rewriter.setInsertionPointToStart(&block);

  // Re-generate the constants, recapture all arguments.
  unsigned arg = 0;
  IRMapping irMap;
  for (Value c : constants)
    irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
  for (Value s : scalars)
    irMap.map(s, block.getArgument(arg++));
  for (Value b : buffers)
    irMap.map(b, block.getArgument(arg++));

  // Assume 1-dimensional grid/block configuration (only x dimension),
  // so that:
  //   row = blockIdx.x * blockDim.x + threadIdx.x
  //   inc = blockDim.x * gridDim.x
  Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
  Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
  Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
  Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
  Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
  Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);

  // Construct the iteration over the computational space that
  // accounts for the fact that the total number of threads and
  // the amount of work to be done usually do not match precisely.
  //   for (r = row; r < N; r += inc) {
  //     <loop-body>
  //   }
  Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
  // The scf.for builder creates an empty block. scf.for does not allow multiple
  // blocks in its region, so delete the block before `cloneRegionBefore` adds
  // an additional block.
  rewriter.eraseBlock(forOp.getBody());
  rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
                             forOp.getRegion().begin(), irMap);
  // Replace the scf.reduce terminator.
  rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
  rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator());

  // Done.
  rewriter.setInsertionPointAfter(forOp);
  rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
}

//===----------------------------------------------------------------------===//
// Library helper methods.
//===----------------------------------------------------------------------===//

/// Helper to detect a + b with arguments taken from the given block.
static bool matchAddOfArgs(Block *block, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
      Value a = block->getArguments()[0];
      Value b = block->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}

/// Helper to detect a * b with arguments taken from the given block.
static bool matchMulOfArgs(Block *block, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp, arith::MulIOp>(def)) {
      Value a = block->getArguments()[0];
      Value b = block->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}

/// Helper to detect x = x + a * b.
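/// As an illustration, a linalg.generic body of the form below matches
/// (block arguments %a, %b, %x; the add and mul operands may appear in
/// either order):
///   %0 = arith.mulf %a, %b : f64
///   %1 = arith.addf %x, %0 : f64
///   linalg.yield %1 : f64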
static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments()[2];
      return (def->getOperand(0) == x &&
              matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
             (def->getOperand(1) == x &&
              matchMulOfArgs(op.getBlock(), def->getOperand(0)));
    }
  }
  return false;
}

/// Helper to detect c += spy(s) x (a * b).
static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  // The linalg yields a custom reduce result.
  Value s_out = op.getBlock()->getArguments()[2];
  if (auto redOp =
          yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
    // The reduce consumes the output.
    Value other;
    if (s_out == redOp->getOperand(0))
      other = redOp->getOperand(1);
    else if (s_out == redOp->getOperand(1))
      other = redOp->getOperand(0);
    else
      return false;
    // The reduce op also consumes a unary that also consumes the output
    // and does not define an absent value.
    if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
      if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
        return false;
      // And the bodies are as expected.
      auto yieldUn = cast<sparse_tensor::YieldOp>(
          unOp.getRegion(0).front().getTerminator());
      auto yieldRed = cast<sparse_tensor::YieldOp>(
          redOp.getRegion().front().getTerminator());
      return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
             matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
    }
  }
  return false;
}

/// Test for dense tensor.
static bool isDenseTensor(Value v) {
  auto sTp = getSparseTensorType(v);
  return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
}

/// Test for suitable positions/coordinates width.
static bool isAdmissibleMetaData(SparseTensorType &aTp) {
  return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
}

/// Test for sorted COO matrix with suitable metadata.
static bool isAdmissibleCOO(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
         aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
         aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
         isAdmissibleMetaData(aTp);
}

/// Test for CSR matrix with suitable metadata.
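/// For example, an encoding of the form (sketch, using the current sparse
/// tensor dialect syntax; the name #CSR is illustrative):
///   #CSR = #sparse_tensor.encoding<{
///     map = (i, j) -> (i : dense, j : compressed)
///   }>
/// is admissible, provided its position/coordinate widths are 0 or >= 16.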
static bool isAdmissibleCSR(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
         aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
         aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}

/// Test for CSC matrix with suitable metadata.
static bool isAdmissibleCSC(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
         aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
         aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}

/// Test for BSR matrix with suitable metadata.
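/// For example, an encoding with 2x2 blocks of the form (sketch; the name
/// #BSR and the block size are illustrative):
///   #BSR = #sparse_tensor.encoding<{
///     map = (i, j) -> (i floordiv 2 : dense,
///                      j floordiv 2 : compressed,
///                      i mod 2 : dense,
///                      j mod 2 : dense)
///   }>
/// yields level rank 4 with the dense/compressed/dense/dense pattern tested
/// below.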
static bool isAdmissibleBSR(SparseTensorType &aTp) {
  if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
      aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
      aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
    // CuSparse only supports "square" blocks currently.
    SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
    assert(dims.size() == 2);
    return dims[0] == dims[1] && dims[0] > 1;
  }
  return false;
}

/// Test for 2:4 matrix with suitable metadata.
static bool isAdmissible24(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
         aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp);
}

/// Test for conversion into 2:4 matrix.
static bool isConversionInto24(Value v) {
  if (auto cnv = v.getDefiningOp<ConvertOp>()) {
    Value a = cnv.getResult();
    Value d = cnv.getSource();
    SparseTensorType aTp = getSparseTensorType(a);
    return isDenseTensor(d) && isAdmissible24(aTp);
  }
  return false;
}

/// Returns a suitable sparse format for the operation and given operand
/// types with cuSparse, or kNone if none is available.
static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
                                        SparseTensorType bTp,
                                        SparseTensorType cTp, bool enableRT,
                                        bool isMatVec) {
  // The other operands have a dense type.
  if (bTp.hasEncoding() || cTp.hasEncoding())
    return CuSparseFormat::kNone;
  // Now check for a suitable type for the main operand.
  if (isAdmissibleCOO(aTp))
#ifdef CUSPARSE_COO_AOS
    return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#else
    return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#endif
  if (isAdmissibleCSR(aTp))
    return CuSparseFormat::kCSR;
  if (isAdmissibleCSC(aTp))
    return CuSparseFormat::kCSC;
  if (isAdmissibleBSR(aTp))
    return CuSparseFormat::kBSR;
  return CuSparseFormat::kNone;
}

/// Generates the first positions/coordinates of a sparse matrix.
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
                               CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT)
      return builder.create<ToCoordinatesOp>(loc, a, 0);
    return builder.create<ToCoordinatesBufferOp>(loc, a);
  }
  // Formats CSR/CSC and BSR use positions at 1.
  return builder.create<ToPositionsOp>(loc, a, 1);
}

/// Generates the second coordinates of a sparse matrix.
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
                           CuSparseFormat format, bool enableRT) {
  bool isCOO = format == CuSparseFormat::kCOO;
  if (isCOO && !enableRT)
    return Value(); // nothing needed
  // Formats CSR/CSC and BSR use coordinates at 1.
  return builder.create<ToCoordinatesOp>(loc, a, 1);
}

/// Generates the sparse matrix handle.
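/// Depending on `format`, this creates a gpu::CreateCooOp (or the
/// gpu::CreateCooAoSOp variant when the runtime library is disabled),
/// gpu::CreateCsrOp, gpu::CreateCscOp, or gpu::CreateBsrOp.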
static Operation *genSpMat(OpBuilder &builder, Location loc,
                           SparseTensorType &aTp, Type handleTp, Type tokenTp,
                           Value token, Value sz1, Value sz2, Value nseA,
                           Value rowA, Value colA, Value valA,
                           CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT) {
      assert(colA);
      return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
                                              sz1, sz2, nseA, rowA, colA, valA);
    }
#ifdef CUSPARSE_COO_AOS
    assert(!colA);
    return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token,
                                               sz1, sz2, nseA, rowA, valA);
#else
    llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
#endif
  }
  assert(colA);
  if (format == CuSparseFormat::kCSR)
    return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  if (format == CuSparseFormat::kCSC)
    return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  // BSR requires a bit more work since we need to pass in the block size
  // and all other sizes in terms of blocks (#block-rows, #block-cols,
  // #nonzero-blocks).
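  // For example, an 8x8 matrix stored with 2x2 blocks and nseA = 16 stored
  // entries yields bRows = bCols = 4 and bNum = 16 / (2 * 2) = 4.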
  assert(format == CuSparseFormat::kBSR);
  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
  assert(dims.size() == 2 && dims[0] == dims[1]);
  uint64_t b = dims[0];
  Value bSz = constantIndex(builder, loc, b);
  Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
  Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
  Value bNum = builder.create<arith::DivUIOp>(
      loc, nseA, constantIndex(builder, loc, b * b));
  return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
                                          bCols, bNum, bSz, bSz, rowA, colA,
                                          valA);
}

/// Match and rewrite SpMV kernel.
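/// That is, a kernel computing y = A x (i.e. y(i) += A(i,j) * x(j)) with a
/// sparse matrix A and dense vectors x and y is lowered onto the cuSparse
/// SpMV path below.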
static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value x = op.getOperand(1);
  Value y = op.getOperand(2); // we have y = Ax
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format and dense vectors (no BSR).
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType xTp = getSparseTensorType(x);
  SparseTensorType yTp = getSparseTensorType(y);
  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Start sparse kernel and copy data from host to device.
  //   a : memR/memC/memV -> rowA,colA,valA
  //   x : memX           -> vecX
  //   y : memY           -> vecY
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
  Value memV = rewriter.create<ToValuesOp>(loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value memX = genTensorToMemref(rewriter, loc, x);
  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
  Value memY = genTensorToMemref(rewriter, loc, y);
  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense vector handles.
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
  Value dnX = dvecX.getResult(0);
  token = dvecX.getAsyncToken();
  auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
  Value dnY = dvecY.getResult(0);
  token = dvecY.getAsyncToken();
  auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();

  // Precompute buffer size for SpMV.
  auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
      /*computeType=*/dnYType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SpMV.
  auto spmvComp = rewriter.create<gpu::SpMVOp>(
      loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer);
  token = spmvComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, vecX, token);
  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
  token = genDeallocMemRef(rewriter, loc, vecY, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
  return success();
}

/// Match and rewrite SpMM kernel.
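/// That is, a kernel computing C = A B (i.e. C(i,j) += A(i,k) * B(k,j)) with a
/// sparse matrix A and dense matrices B and C is lowered onto the cuSparse
/// SpMM path below.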
static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format and dense matrices (no BSR).
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Start sparse kernel and copy data from host to device.
  //   a : memR/memC/memV -> rowA,colA,valA
  //   b : bufB           -> matB
  //   c : bufC           -> matC
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
  Value memV = rewriter.create<ToValuesOp>(loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, c);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense matrix handles.
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matC,
      SmallVector<Value>{szm, szn});
  Value dnC = dmatC.getResult(0);
  token = dmatC.getAsyncToken();
  auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Precompute buffer size for SpMM.
  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
      /*computeType=*/dmatCType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();
  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Perform the SpMM.
  auto spmmComp = rewriter.create<gpu::SpMMOp>(
      loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer);
  token = spmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
  return success();
}

/// Match and rewrite SpGEMM kernel.
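/// That is, C = A B where A, B, and the result C are all sparse (CSR) matrices
/// is lowered onto the cuSparse SpGEMM path below.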
7605ef44679SAart Bik static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
7615ef44679SAart Bik                                    linalg::GenericOp op, bool enableRT) {
76276a80a08SAart Bik   Location loc = op.getLoc();
76376a80a08SAart Bik   Value a = op.getOperand(0);
76476a80a08SAart Bik   Value b = op.getOperand(1);
76576a80a08SAart Bik   Value c = op.getOperand(2); // we have C = AB
76676a80a08SAart Bik   SmallVector<Value> tokens;
76776a80a08SAart Bik 
76876a80a08SAart Bik   // Only CSR <- CSR x CSR supported.
7693231a365SAart Bik   auto format = CuSparseFormat::kCSR;
77076a80a08SAart Bik   SparseTensorType aTp = getSparseTensorType(a);
77176a80a08SAart Bik   SparseTensorType bTp = getSparseTensorType(b);
77276a80a08SAart Bik   SparseTensorType cTp = getSparseTensorType(c);
77376a80a08SAart Bik   if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
77476a80a08SAart Bik     return failure();
77576a80a08SAart Bik 
77676a80a08SAart Bik   // Start sparse kernel and copy data from host to device.
77776a80a08SAart Bik   //   a : amemR/amemC/amemV -> rowA,colA,valA
77876a80a08SAart Bik   //   b : bmemR/bmemC/bmemV -> rowB,colB,valB
77976a80a08SAart Bik   //   c : materializes
78076a80a08SAart Bik   auto dnCType = cTp.getElementType();
78176a80a08SAart Bik   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
78276a80a08SAart Bik   Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
78376a80a08SAart Bik   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
78476a80a08SAart Bik   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
78576a80a08SAart Bik   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
7863231a365SAart Bik   Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
7875ef44679SAart Bik   Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
7881a0986f0SPeiming Liu   Value amemV = rewriter.create<ToValuesOp>(loc, a);
7893231a365SAart Bik   Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
7905ef44679SAart Bik   Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
7911a0986f0SPeiming Liu   Value bmemV = rewriter.create<ToValuesOp>(loc, b);
79276a80a08SAart Bik   Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
79376a80a08SAart Bik   Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
79476a80a08SAart Bik   Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
79576a80a08SAart Bik   Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
79676a80a08SAart Bik   Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
79776a80a08SAart Bik   Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
79876a80a08SAart Bik   genBlockingWait(rewriter, loc, tokens);
79976a80a08SAart Bik   tokens.clear();
80076a80a08SAart Bik 
80176a80a08SAart Bik   // Create sparse environment and sparse matrix/dense vector handles.
80276a80a08SAart Bik   Type indexTp = rewriter.getIndexType();
80376a80a08SAart Bik   Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
80476a80a08SAart Bik   Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
80576a80a08SAart Bik   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
80676a80a08SAart Bik   Value token = genFirstWait(rewriter, loc);
80776a80a08SAart Bik   Operation *spGenA =
8083d89c088SAart Bik       genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
8093d89c088SAart Bik                nseA, rowA, colA, valA, format, enableRT);
81076a80a08SAart Bik   Value spMatA = spGenA->getResult(0);
81176a80a08SAart Bik   token = spGenA->getResult(1);
81276a80a08SAart Bik   Operation *spGenB =
8133d89c088SAart Bik       genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
8143d89c088SAart Bik                nseB, rowB, colB, valB, format, enableRT);
81576a80a08SAart Bik   Value spMatB = spGenB->getResult(0);
81676a80a08SAart Bik   token = spGenB->getResult(1);
81776a80a08SAart Bik 
81876a80a08SAart Bik   // Sparse matrix C materializes (also assumes beta == 0).
81976a80a08SAart Bik   Value zero = constantIndex(rewriter, loc, 0);
82076a80a08SAart Bik   Value one = constantIndex(rewriter, loc, 1);
82176a80a08SAart Bik   Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
82276a80a08SAart Bik   auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
82376a80a08SAart Bik   Value rowC = e1.getResult(0);
82476a80a08SAart Bik   token = e1.getAsyncToken();
82576a80a08SAart Bik   auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
826619a888dSAart Bik   Value colC = e2.getResult(0); // no free needed
82776a80a08SAart Bik   token = e2.getAsyncToken();
82876a80a08SAart Bik   auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
829619a888dSAart Bik   Value valC = e3.getResult(0); // no free needed
83076a80a08SAart Bik   token = e3.getAsyncToken();
83176a80a08SAart Bik   Operation *spGenC =
8323d89c088SAart Bik       genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
8333d89c088SAart Bik                zero, rowC, colC, valC, format, enableRT);
83476a80a08SAart Bik   Value spMatC = spGenC->getResult(0);
83576a80a08SAart Bik   token = spGenC->getResult(1);
83676a80a08SAart Bik 
83776a80a08SAart Bik   // Precompute buffersizes for SpGEMM.
83876a80a08SAart Bik   Operation *descOp =
83976a80a08SAart Bik       rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
84076a80a08SAart Bik   Value desc = descOp->getResult(0);
84176a80a08SAart Bik   token = descOp->getResult(1);
84276a80a08SAart Bik   Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
84376a80a08SAart Bik       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
84476a80a08SAart Bik       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
84576a80a08SAart Bik       valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
84676a80a08SAart Bik   Value bufferSz1 = work1->getResult(0);
84776a80a08SAart Bik   token = work1->getResult(1);
84876a80a08SAart Bik   auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
84976a80a08SAart Bik   Value buffer1 = buf1.getResult(0);
85076a80a08SAart Bik   token = buf1.getAsyncToken();
85176a80a08SAart Bik   Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
85276a80a08SAart Bik       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
85376a80a08SAart Bik       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
85476a80a08SAart Bik       bufferSz1, buffer1,
85576a80a08SAart Bik       gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
85676a80a08SAart Bik   token = work2->getResult(1);
85776a80a08SAart Bik 
85876a80a08SAart Bik   // Compute step.
85976a80a08SAart Bik   Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
86076a80a08SAart Bik       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
86176a80a08SAart Bik       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
86276a80a08SAart Bik       valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
86376a80a08SAart Bik   Value bufferSz2 = compute1->getResult(0);
86476a80a08SAart Bik   token = compute1->getResult(1);
86576a80a08SAart Bik   auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
86676a80a08SAart Bik   Value buffer2 = buf2.getResult(0);
86776a80a08SAart Bik   token = buf2.getAsyncToken();
86876a80a08SAart Bik   Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
86976a80a08SAart Bik       loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
87076a80a08SAart Bik       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
87176a80a08SAart Bik       bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
87276a80a08SAart Bik   token = compute2->getResult(1);
87376a80a08SAart Bik 
87476a80a08SAart Bik   // Get sizes.
875289f7231SAart Bik   Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
87676a80a08SAart Bik       loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
87776a80a08SAart Bik   Value nnz = sizes->getResult(2);
87876a80a08SAart Bik   token = sizes->getResult(3);
87976a80a08SAart Bik   auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
88076a80a08SAart Bik   colC = a2.getResult(0);
88176a80a08SAart Bik   token = a2.getAsyncToken();
88276a80a08SAart Bik   auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
88376a80a08SAart Bik   valC = a3.getResult(0);
88476a80a08SAart Bik   token = a3.getAsyncToken();
88576a80a08SAart Bik 
88676a80a08SAart Bik   // Update C with new pointers and copy final product back into C.
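  // Rebind the descriptor of C to the freshly allocated coordinate and value
  // buffers so that the SpGEMM copy below materializes the product into them.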
88776a80a08SAart Bik   Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
88876a80a08SAart Bik       loc, tokenTp, token, spMatC, rowC, colC, valC);
88976a80a08SAart Bik   token = update->getResult(0);
89076a80a08SAart Bik   Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
89176a80a08SAart Bik       loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
89276a80a08SAart Bik       gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
89376a80a08SAart Bik   token = copy->getResult(0);
89476a80a08SAart Bik 
89576a80a08SAart Bik   // Allocate buffers on host.
89676a80a08SAart Bik   Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
89776a80a08SAart Bik   Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
89876a80a08SAart Bik   Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
89976a80a08SAart Bik 
90076a80a08SAart Bik   // Copy data back to host and free all the resources.
90176a80a08SAart Bik   token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
90276a80a08SAart Bik               .getAsyncToken();
90376a80a08SAart Bik   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
90476a80a08SAart Bik               .getAsyncToken();
90576a80a08SAart Bik   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
90676a80a08SAart Bik               .getAsyncToken();
90776a80a08SAart Bik   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
90876a80a08SAart Bik               .getAsyncToken();
90976a80a08SAart Bik   token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
91076a80a08SAart Bik   token = genCopyMemRef(rewriter, loc, colH, colC, token);
91176a80a08SAart Bik   token = genCopyMemRef(rewriter, loc, valH, valC, token);
912619a888dSAart Bik   token = genDeallocMemRef(rewriter, loc, rowA, token);
913619a888dSAart Bik   token = genDeallocMemRef(rewriter, loc, colA, token);
914619a888dSAart Bik   token = genDeallocMemRef(rewriter, loc, valA, token);
915619a888dSAart Bik   token = genDeallocMemRef(rewriter, loc, rowB, token);
916619a888dSAart Bik   token = genDeallocMemRef(rewriter, loc, colB, token);
917619a888dSAart Bik   token = genDeallocMemRef(rewriter, loc, valB, token);
918619a888dSAart Bik   token = genDeallocMemRef(rewriter, loc, rowC, token);
919619a888dSAart Bik   token = genDeallocMemRef(rewriter, loc, colC, token);
920619a888dSAart Bik   token = genDeallocMemRef(rewriter, loc, valC, token);
921619a888dSAart Bik   token = genDeallocMemRef(rewriter, loc, buffer1, token);
922619a888dSAart Bik   token = genDeallocMemRef(rewriter, loc, buffer2, token);
92376a80a08SAart Bik   tokens.push_back(token);
92476a80a08SAart Bik   genBlockingWait(rewriter, loc, tokens);
92576a80a08SAart Bik   tokens.clear();
92676a80a08SAart Bik 
92776a80a08SAart Bik   // Done.
92876a80a08SAart Bik   Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
92976a80a08SAart Bik   Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
93076a80a08SAart Bik   Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
931fc9f1d49SPeiming Liu   rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
932fc9f1d49SPeiming Liu                                           vt);
93376a80a08SAart Bik   return success();
93476a80a08SAart Bik }
93576a80a08SAart Bik 
93676a80a08SAart Bik /// Match and rewrite 2:4 SpMM kernel.
9375ef44679SAart Bik static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
9385ef44679SAart Bik                                      linalg::GenericOp op) {
939e37fc3ccSK-Wu   Location loc = op.getLoc();
940e37fc3ccSK-Wu   Value A = op.getOperand(0);
941e37fc3ccSK-Wu   Value B = op.getOperand(1);
942e37fc3ccSK-Wu   Value C = op.getOperand(2); // we have C = AB
943e37fc3ccSK-Wu   SmallVector<Value> tokens;
944e37fc3ccSK-Wu 
94541a07e66SAart Bik   // The cuSparseLt API currently only allows pruning and compression
94641a07e66SAart Bik   // to occur on the device. So we recognize the pattern
94741a07e66SAart Bik   //    A' = convert A  ; dense to 2:4
94841a07e66SAart Bik   //    C  = A'B        ; 2:4 matrix mult
94941a07e66SAart Bik   // and then perform compression and matrix multiplication on device.
95041a07e66SAart Bik   auto cnv = A.getDefiningOp<ConvertOp>();
95141a07e66SAart Bik   assert(cnv);
95241a07e66SAart Bik   A = cnv.getSource();
95341a07e66SAart Bik 
954e37fc3ccSK-Wu   // All inputs should be dense tensors.
955e37fc3ccSK-Wu   if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
956e37fc3ccSK-Wu     return failure();
957e37fc3ccSK-Wu 
9585ef44679SAart Bik   // Start sparse kernel and copy data from host to device.
9595ef44679SAart Bik   //   a : bufA -> matA
9605ef44679SAart Bik   //   b : bufB -> matB
9615ef44679SAart Bik   //   c : bufC -> matC
962e37fc3ccSK-Wu   Value bufA = genTensorToMemref(rewriter, loc, A);
9635ef44679SAart Bik   Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
964e37fc3ccSK-Wu   Value bufB = genTensorToMemref(rewriter, loc, B);
9655ef44679SAart Bik   Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
966e37fc3ccSK-Wu   Value bufC = genTensorToMemref(rewriter, loc, C);
967e37fc3ccSK-Wu   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
968e37fc3ccSK-Wu   genBlockingWait(rewriter, loc, tokens);
969e37fc3ccSK-Wu   tokens.clear();
97076a80a08SAart Bik 
97176a80a08SAart Bik   // Create sparse environment and sparse matrix/dense matrix handles.
972e37fc3ccSK-Wu   Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
973e37fc3ccSK-Wu   Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
974e37fc3ccSK-Wu   Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
975e37fc3ccSK-Wu   Type indexTp = rewriter.getIndexType();
976e37fc3ccSK-Wu   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
977e37fc3ccSK-Wu   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
978e37fc3ccSK-Wu   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
979e37fc3ccSK-Wu   Value token = genFirstWait(rewriter, loc);
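  // Wrap the dense buffer of A in a 2:4 structured-sparse handle; the
  // PRUNE_AND_CHECK flag requests on-device pruning to the 2:4 pattern
  // together with a validity check.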
980e37fc3ccSK-Wu   Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
9811e491c42SKun Wu       loc, spMatHandleTp, tokenTp, token, szm, szk,
9821e491c42SKun Wu       gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
983e37fc3ccSK-Wu   Value spMatA = spGenA->getResult(0);
984e37fc3ccSK-Wu   token = spGenA->getResult(1);
985e37fc3ccSK-Wu   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
986e37fc3ccSK-Wu       loc, dnTensorHandleTp, tokenTp, token, matB,
987e37fc3ccSK-Wu       SmallVector<Value>{szk, szn});
988e37fc3ccSK-Wu   Value dnB = dmatB.getResult(0);
989e37fc3ccSK-Wu   token = dmatB.getAsyncToken();
990e37fc3ccSK-Wu   auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
991e37fc3ccSK-Wu       loc, dnTensorHandleTp, tokenTp, token, matC,
992e37fc3ccSK-Wu       SmallVector<Value>{szm, szn});
993e37fc3ccSK-Wu   Value dnC = dmatC.getResult(0);
994e37fc3ccSK-Wu   token = dmatC.getAsyncToken();
995e37fc3ccSK-Wu   auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
996e37fc3ccSK-Wu 
997e37fc3ccSK-Wu   // Precompute buffer sizes for SpMM.
998e37fc3ccSK-Wu   SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
999e37fc3ccSK-Wu   TypeRange bufferTypes(bufferTypes_);
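  // This 2:4 SpMM path requires three workspace buffers, so the buffer-size
  // query returns three sizes (one per buffer).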
1000e37fc3ccSK-Wu   auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
1001e37fc3ccSK-Wu       loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
1002e37fc3ccSK-Wu       gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
1003e37fc3ccSK-Wu       /*computeType=*/dmatCType);
1004e37fc3ccSK-Wu   token = bufferComp.getAsyncToken();
100576a80a08SAart Bik 
10065ef44679SAart Bik   // Allocate temporary buffers on the device.
10075ef44679SAart Bik   Value bufferSz1 = bufferComp.getResult(0);
10085ef44679SAart Bik   auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
10095ef44679SAart Bik   Value buffer1 = buf1.getResult(0);
10105ef44679SAart Bik   token = buf1.getAsyncToken();
1011e37fc3ccSK-Wu   Value bufferSz2 = bufferComp.getResult(1);
1012e37fc3ccSK-Wu   auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
1013e37fc3ccSK-Wu   Value buffer2 = buf2.getResult(0);
1014e37fc3ccSK-Wu   token = buf2.getAsyncToken();
1015e37fc3ccSK-Wu   Value bufferSz3 = bufferComp.getResult(2);
1016e37fc3ccSK-Wu   auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
1017e37fc3ccSK-Wu   Value buffer3 = buf3.getResult(0);
1018e37fc3ccSK-Wu   token = buf3.getAsyncToken();
1019e37fc3ccSK-Wu 
1020e37fc3ccSK-Wu   // Perform the SpMM.
10215ef44679SAart Bik   auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
1022e37fc3ccSK-Wu   auto spmmComp = rewriter.create<gpu::SpMMOp>(
1023e37fc3ccSK-Wu       loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
10245ef44679SAart Bik       SmallVector<Value>{buffer1, buffer2, buffer3});
1025e37fc3ccSK-Wu   token = spmmComp.getAsyncToken();
1026e37fc3ccSK-Wu 
1027e37fc3ccSK-Wu   // Copy data back to host and free all the resources.
1028e37fc3ccSK-Wu   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
1029e37fc3ccSK-Wu               .getAsyncToken();
1030e37fc3ccSK-Wu   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
1031e37fc3ccSK-Wu               .getAsyncToken();
1032e37fc3ccSK-Wu   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
1033e37fc3ccSK-Wu               .getAsyncToken();
10355ef44679SAart Bik   token = genDeallocMemRef(rewriter, loc, buffer1, token);
1036e37fc3ccSK-Wu   token = genDeallocMemRef(rewriter, loc, buffer2, token);
1037e37fc3ccSK-Wu   token = genDeallocMemRef(rewriter, loc, buffer3, token);
1038e37fc3ccSK-Wu   token = genDeallocMemRef(rewriter, loc, matA, token);
1039e37fc3ccSK-Wu   token = genDeallocMemRef(rewriter, loc, matB, token);
1040e37fc3ccSK-Wu   token = genCopyMemRef(rewriter, loc, bufC, matC, token);
1041e37fc3ccSK-Wu   token = genDeallocMemRef(rewriter, loc, matC, token);
1042e37fc3ccSK-Wu   tokens.push_back(token);
1043e37fc3ccSK-Wu   genBlockingWait(rewriter, loc, tokens);
104476a80a08SAart Bik   tokens.clear();
104576a80a08SAart Bik 
104676a80a08SAart Bik   // Done.
1047e37fc3ccSK-Wu   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
1048e37fc3ccSK-Wu   return success();
1049e37fc3ccSK-Wu }
1050e37fc3ccSK-Wu 
10519167dd46SKun Wu /// Match and rewrite SDDMM kernel.
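/// SDDMM computes the dense-dense product A*B only at positions that are
/// nonzero in the sparse output C (sampled dense-dense matrix multiplication).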
10525ef44679SAart Bik static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
10535ef44679SAart Bik                                   linalg::GenericOp op, bool enableRT) {
10549167dd46SKun Wu   Location loc = op.getLoc();
10559167dd46SKun Wu   Value a = op.getOperand(0);
10569167dd46SKun Wu   Value b = op.getOperand(1);
10579167dd46SKun Wu   Value c = op.getOperand(2);
10589167dd46SKun Wu   SmallVector<Value> tokens;
10599167dd46SKun Wu 
10603231a365SAart Bik   // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
10619167dd46SKun Wu   SparseTensorType aTp = getSparseTensorType(a);
10629167dd46SKun Wu   SparseTensorType bTp = getSparseTensorType(b);
10639167dd46SKun Wu   SparseTensorType cTp = getSparseTensorType(c);
10643231a365SAart Bik   auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
10653231a365SAart Bik   if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
10663231a365SAart Bik       format == CuSparseFormat::kCSC)
10679167dd46SKun Wu     return failure();
10689167dd46SKun Wu 
10699167dd46SKun Wu   // The SDDMM updates the values of the sparse output matrix in place.
10709167dd46SKun Wu   // Start sparse kernel and copy data from host to device.
10719167dd46SKun Wu   //   a : bufA           -> matA
10725ef44679SAart Bik   //   b : bufB           -> matB
10739167dd46SKun Wu   //   c : memR/memC/memV -> rowC,colC,valC
10749167dd46SKun Wu   Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
10759167dd46SKun Wu   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
10769167dd46SKun Wu   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
10779167dd46SKun Wu   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
10789167dd46SKun Wu   Value bufA = genTensorToMemref(rewriter, loc, a);
10795ef44679SAart Bik   Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
10809167dd46SKun Wu   Value bufB = genTensorToMemref(rewriter, loc, b);
10815ef44679SAart Bik   Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
10823231a365SAart Bik   Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
10835ef44679SAart Bik   Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
10841a0986f0SPeiming Liu   Value memV = rewriter.create<ToValuesOp>(loc, c);
10859167dd46SKun Wu   Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
10869167dd46SKun Wu   Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
10879167dd46SKun Wu   Value valC = genAllocCopy(rewriter, loc, memV, tokens);
10889167dd46SKun Wu   genBlockingWait(rewriter, loc, tokens);
10899167dd46SKun Wu   tokens.clear();
10909167dd46SKun Wu 
10919167dd46SKun Wu   // Create sparse environment and sparse matrix/dense matrix handles.
10929167dd46SKun Wu   Type indexTp = rewriter.getIndexType();
10939167dd46SKun Wu   Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
10949167dd46SKun Wu   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
10959167dd46SKun Wu   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
10969167dd46SKun Wu   Value token = genFirstWait(rewriter, loc);
10979167dd46SKun Wu   auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
1098be2dd22bSKun Wu       loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
10999167dd46SKun Wu   Value dnA = dmatA.getResult(0);
11009167dd46SKun Wu   token = dmatA.getAsyncToken();
11019167dd46SKun Wu   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
1102be2dd22bSKun Wu       loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
11039167dd46SKun Wu   Value dnB = dmatB.getResult(0);
11049167dd46SKun Wu   token = dmatB.getAsyncToken();
11059167dd46SKun Wu   Operation *spGenC =
11063d89c088SAart Bik       genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
11073d89c088SAart Bik                nseC, rowC, colC, valC, format, enableRT);
11089167dd46SKun Wu   Value spMatC = spGenC->getResult(0);
11099167dd46SKun Wu   token = spGenC->getResult(1);
11109167dd46SKun Wu   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
111176a80a08SAart Bik 
11129167dd46SKun Wu   // Precompute buffer size for SDDMM.
11139167dd46SKun Wu   auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
1114be2dd22bSKun Wu       loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
11159167dd46SKun Wu   Value bufferSz = bufferComp.getResult(0);
11169167dd46SKun Wu   token = bufferComp.getAsyncToken();
11179167dd46SKun Wu   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
11189167dd46SKun Wu   Value buffer = buf.getResult(0);
11199167dd46SKun Wu   token = buf.getAsyncToken();
11209167dd46SKun Wu 
11219167dd46SKun Wu   // Perform the SDDMM.
1122be2dd22bSKun Wu   auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
1123be2dd22bSKun Wu                                                  spMatC, dnCType, buffer);
11249167dd46SKun Wu   token = sddmmComp.getAsyncToken();
11259167dd46SKun Wu 
11269167dd46SKun Wu   // Copy data back to host and free all the resources.
11279167dd46SKun Wu   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
11289167dd46SKun Wu               .getAsyncToken();
11299167dd46SKun Wu   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
11309167dd46SKun Wu               .getAsyncToken();
11319167dd46SKun Wu   token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
11329167dd46SKun Wu               .getAsyncToken();
11339167dd46SKun Wu   token = genDeallocMemRef(rewriter, loc, buffer, token);
11349167dd46SKun Wu   token = genDeallocMemRef(rewriter, loc, matA, token);
11359167dd46SKun Wu   token = genDeallocMemRef(rewriter, loc, matB, token);
11369167dd46SKun Wu   token = genDeallocMemRef(rewriter, loc, rowC, token);
11379167dd46SKun Wu   if (colC)
11389167dd46SKun Wu     token = genDeallocMemRef(rewriter, loc, colC, token);
11399167dd46SKun Wu   token = genCopyMemRef(rewriter, loc, memV, valC, token);
11409167dd46SKun Wu   token = genDeallocMemRef(rewriter, loc, valC, token);
11419167dd46SKun Wu   tokens.push_back(token);
11429167dd46SKun Wu   genBlockingWait(rewriter, loc, tokens);
114376a80a08SAart Bik   tokens.clear();
11449167dd46SKun Wu 
1145f14c8eb5SAart Bik   // Done.
11469167dd46SKun Wu   rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
11479167dd46SKun Wu   return success();
11489167dd46SKun Wu }
11499167dd46SKun Wu 
1150ee42e236SAart Bik //===----------------------------------------------------------------------===//
1151ee42e236SAart Bik // Rewriting rules for direct code generation.
1152ee42e236SAart Bik //===----------------------------------------------------------------------===//
1153ee42e236SAart Bik 
1154ee42e236SAart Bik /// Proof-of-concept rewriter. This rule generates a GPU implementation
1155c43e6274STim Harvey /// for each outermost forall loop generated by the sparsifier.
115676a80a08SAart Bik /// TODO: right now works with parallelization-strategy=dense-outer-loop
115786888e42SAart Bik ///       but this should be given its own flag in the future
115819466ebcSAart Bik struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
115919466ebcSAart Bik   using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
116019466ebcSAart Bik 
116119466ebcSAart Bik   ForallRewriter(MLIRContext *context, unsigned nT)
116219466ebcSAart Bik       : OpRewritePattern(context), numThreads(nT) {}
116319466ebcSAart Bik 
116419466ebcSAart Bik   LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
116519466ebcSAart Bik                                 PatternRewriter &rewriter) const override {
116619466ebcSAart Bik     // Reject inadmissible loop form.
1167c43e6274STim Harvey     // Essentially only accept a loop, generated by the sparsifier,
116819466ebcSAart Bik     // of the form
116919466ebcSAart Bik     //   forall (i = 0; i < N; i++)
117019466ebcSAart Bik     // so that cyclic scheduling over the threads is easy.
117119466ebcSAart Bik     if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
117219466ebcSAart Bik         forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
117319466ebcSAart Bik         !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
117419466ebcSAart Bik         !matchPattern(forallOp.getStep()[0], m_One()))
117519466ebcSAart Bik       return failure();
117619466ebcSAart Bik     // Collect every value that is computed outside the parallel loop.
117719466ebcSAart Bik     SetVector<Value> invariants; // stable iteration!
117819466ebcSAart Bik     forallOp->walk([&](Operation *op) {
117919466ebcSAart Bik       // Collect all values of admissible ops.
118019466ebcSAart Bik       for (OpOperand &o : op->getOpOperands()) {
118119466ebcSAart Bik         Value val = o.get();
118219466ebcSAart Bik         Block *block;
11835550c821STres Popp         if (auto arg = dyn_cast<BlockArgument>(val))
118419466ebcSAart Bik           block = arg.getOwner();
118519466ebcSAart Bik         else
118619466ebcSAart Bik           block = val.getDefiningOp()->getBlock();
1187ea979b24SMatthias Springer         if (!forallOp.getRegion().findAncestorBlockInRegion(*block))
118819466ebcSAart Bik           invariants.insert(val);
118919466ebcSAart Bik       }
119019466ebcSAart Bik     });
119119466ebcSAart Bik     // Outline the outside values as proper parameters. Fail when sharing a
119219466ebcSAart Bik     // value between host and device is not straightforward.
119319466ebcSAart Bik     SmallVector<Value> constants;
119419466ebcSAart Bik     SmallVector<Value> scalars;
119519466ebcSAart Bik     SmallVector<Value> buffers;
119619466ebcSAart Bik     for (Value val : invariants) {
119719466ebcSAart Bik       Type tp = val.getType();
119819466ebcSAart Bik       if (val.getDefiningOp<arith::ConstantOp>())
119919466ebcSAart Bik         constants.push_back(val);
12005550c821STres Popp       else if (isa<FloatType>(tp) || tp.isIntOrIndex())
120119466ebcSAart Bik         scalars.push_back(val);
120219466ebcSAart Bik       else if (isa<MemRefType>(tp))
120319466ebcSAart Bik         buffers.push_back(val);
120419466ebcSAart Bik       else
120519466ebcSAart Bik         return failure(); // don't know how to share
120619466ebcSAart Bik     }
120786888e42SAart Bik     // Pass outlined non-constant values.
120886888e42SAart Bik     // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
120986888e42SAart Bik     //       keep the feature at all (either through a heuristic or compiler
121086888e42SAart Bik     //       option for gpu codegen).
121119466ebcSAart Bik     Location loc = forallOp->getLoc();
121219466ebcSAart Bik     SmallVector<Value> args;
121386888e42SAart Bik     SmallVector<Value> tokens;
121486888e42SAart Bik     Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
121586888e42SAart Bik                                 /*useHostRegistrationForOut=*/false);
121619466ebcSAart Bik     // Set up GPU module and construct GPU function.
121786888e42SAart Bik     auto saveIp = rewriter.saveInsertionPoint();
121819466ebcSAart Bik     ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
12194889214aSAart Bik     auto gpuModule = genGPUModule(rewriter, topModule);
12204889214aSAart Bik     auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
122119466ebcSAart Bik     genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
122286888e42SAart Bik     // Generate code that launches the kernel asynchronously, blocking on all
122386888e42SAart Bik     // open tokens and yielding a new token for the output.
122486888e42SAart Bik     // TODO: Passing tokens into the launch op does not seem to be properly
122586888e42SAart Bik     //       lowered to cubin yet, hence the current blocking wait.
122619466ebcSAart Bik     rewriter.restoreInsertionPoint(saveIp);
122786888e42SAart Bik     genBlockingWait(rewriter, loc, tokens);
122886888e42SAart Bik     tokens.clear();
122986888e42SAart Bik     Value kernelToken =
123086888e42SAart Bik         genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
123186888e42SAart Bik     // Finalize the outlined arguments.
123286888e42SAart Bik     genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
123386888e42SAart Bik                      tokens);
123486888e42SAart Bik     genBlockingWait(rewriter, loc, tokens);
123519466ebcSAart Bik     rewriter.eraseOp(forallOp);
123619466ebcSAart Bik     return success();
123719466ebcSAart Bik   }
123819466ebcSAart Bik 
123919466ebcSAart Bik private:
124019466ebcSAart Bik   unsigned numThreads;
124119466ebcSAart Bik };
124219466ebcSAart Bik 
1243ee42e236SAart Bik //===----------------------------------------------------------------------===//
1244ee42e236SAart Bik // Rewriting rules for library recognition and code generation.
1245ee42e236SAart Bik //===----------------------------------------------------------------------===//
1246ee42e236SAart Bik 
1247ee42e236SAart Bik /// Proof-of-concept rewriter. This rule recognizes certain math kernels
1248b75d6a40SAart Bik /// and replaces these with corresponding calls into a sparse library.
1249ee42e236SAart Bik struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
1250ee42e236SAart Bik   using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
1251ee42e236SAart Bik 
12525ef44679SAart Bik   LinalgOpRewriter(MLIRContext *context, bool rt)
12535ef44679SAart Bik       : OpRewritePattern(context), enableRT(rt) {}
1254ee42e236SAart Bik 
1255ee42e236SAart Bik   LogicalResult matchAndRewrite(linalg::GenericOp op,
1256ee42e236SAart Bik                                 PatternRewriter &rewriter) const override {
1257ee42e236SAart Bik     if (op.getNumDpsInits() != 1)
1258ee42e236SAart Bik       return failure(); // reject multi-output
1259ee42e236SAart Bik 
1260ee42e236SAart Bik     const unsigned numLoops = op.getNumLoops();
1261ee42e236SAart Bik     const unsigned numTensors = op->getNumOperands();
1262ee42e236SAart Bik     const auto iteratorTypes = op.getIteratorTypesArray();
1263ee42e236SAart Bik     SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1264ee42e236SAart Bik 
1265ee42e236SAart Bik     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1266fe8a62c4SUday Bondhugula     auto infer = [&](MapList m) {
1267fe8a62c4SUday Bondhugula       return AffineMap::inferFromExprList(m, op.getContext());
1268fe8a62c4SUday Bondhugula     };
1269ee42e236SAart Bik     AffineExpr i, j, k;
1270ee42e236SAart Bik     bindDims(getContext(), i, j, k);
1271ee42e236SAart Bik 
1272*aa295216SJay Foad     // TODO: more robust patterns, transposed versions, more kernels,
127376a80a08SAart Bik     //       identify alpha and beta and pass them to the CUDA calls.
1274ee42e236SAart Bik 
1275ee42e236SAart Bik     // Recognize a SpMV kernel.
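    // That is, indexing maps (i,j), (j), (i) with a sum-of-products body:
    //   x(i) += A(i,j) * b(j)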
1276ee42e236SAart Bik     if (numLoops == 2 && numTensors == 3 &&
1277ee42e236SAart Bik         linalg::isParallelIterator(iteratorTypes[0]) &&
1278ee42e236SAart Bik         linalg::isReductionIterator(iteratorTypes[1]) &&
1279ee42e236SAart Bik         maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
12805ef44679SAart Bik       return rewriteSpMV(rewriter, op, enableRT);
1281ee42e236SAart Bik     }
1282ee42e236SAart Bik 
128376a80a08SAart Bik     // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
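    // That is, indexing maps (i,k), (k,j), (i,j) with a sum-of-products body:
    //   C(i,j) += A(i,k) * B(k,j)
    // dispatched on operand sparsity: two sparse inputs map to SpGEMM, a
    // dense-to-2:4 conversion feeding A maps to 2:4 SpMM, and otherwise SpMM.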
1284ee42e236SAart Bik     if (numLoops == 3 && numTensors == 3 &&
1285ee42e236SAart Bik         linalg::isParallelIterator(iteratorTypes[0]) &&
1286ee42e236SAart Bik         linalg::isParallelIterator(iteratorTypes[1]) &&
1287ee42e236SAart Bik         linalg::isReductionIterator(iteratorTypes[2]) &&
1288ee42e236SAart Bik         maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
128976a80a08SAart Bik       if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
12905ef44679SAart Bik         return rewriteSpGEMM(rewriter, op, enableRT);
129141a07e66SAart Bik       if (isConversionInto24(op.getOperand(0)))
12925ef44679SAart Bik         return rewrite2To4SpMM(rewriter, op);
12935ef44679SAart Bik       return rewriteSpMM(rewriter, op, enableRT);
1294ee42e236SAart Bik     }
1295ee42e236SAart Bik 
12969167dd46SKun Wu     // Recognize a SDDMM kernel.
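    // Same indexing maps as SpMM, but with a sum reduction of a multiplication
    // guarded by a unary op (the sampling), which is how the sparsifier
    // expresses SDDMM.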
12979167dd46SKun Wu     if (numLoops == 3 && numTensors == 3 &&
12989167dd46SKun Wu         linalg::isParallelIterator(iteratorTypes[0]) &&
12999167dd46SKun Wu         linalg::isParallelIterator(iteratorTypes[1]) &&
13009167dd46SKun Wu         linalg::isReductionIterator(iteratorTypes[2]) &&
13019167dd46SKun Wu         maps == infer({{i, k}, {k, j}, {i, j}}) &&
13029167dd46SKun Wu         matchSumReductionOfMulUnary(op)) {
13035ef44679SAart Bik       return rewriteSDDMM(rewriter, op, enableRT);
13049167dd46SKun Wu     }
13059167dd46SKun Wu 
1306ee42e236SAart Bik     return failure();
1307ee42e236SAart Bik   }
1308ee42e236SAart Bik 
1309ee42e236SAart Bik private:
1310ee42e236SAart Bik   bool enableRT;
1311ee42e236SAart Bik };
1312ee42e236SAart Bik 
131319466ebcSAart Bik } // namespace
131419466ebcSAart Bik 
131519466ebcSAart Bik //===----------------------------------------------------------------------===//
131619466ebcSAart Bik // Public method for populating GPU rewriting rules.
1317ee42e236SAart Bik //
1318ee42e236SAart Bik // Currently two sets of rewriting rules are made available. The first set
1319ee42e236SAart Bik // implements direct code generation, currently by means of converting the
1320ee42e236SAart Bik // outermost parallel loop into GPU threads. The second set implements
1321ee42e236SAart Bik // library recognition of a set of sparse operations. Eventually, the right
1322ee42e236SAart Bik // combination of these two approaches has to be found.
132319466ebcSAart Bik //===----------------------------------------------------------------------===//
132419466ebcSAart Bik 
132519466ebcSAart Bik void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
132619466ebcSAart Bik                                             unsigned numThreads) {
132719466ebcSAart Bik   patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
132819466ebcSAart Bik }
1329ee42e236SAart Bik 
13305ef44679SAart Bik void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
13315ef44679SAart Bik                                            bool enableRT) {
13325ef44679SAart Bik   patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
1333ee42e236SAart Bik }
1334