1ee42e236SAart Bik //===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===// 219466ebcSAart Bik // 319466ebcSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 419466ebcSAart Bik // See https://llvm.org/LICENSE.txt for license information. 519466ebcSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 619466ebcSAart Bik // 719466ebcSAart Bik //===----------------------------------------------------------------------===// 819466ebcSAart Bik // 9c43e6274STim Harvey // This is a prototype GPU codegenerator for the sparsifier. 1019466ebcSAart Bik // The objective is to eventually use the right combination of 1119466ebcSAart Bik // direct code generation and libary calls into vendor-specific 1219466ebcSAart Bik // highly optimized sparse libraries (e.g. cuSparse for CUDA). 1319466ebcSAart Bik // 1419466ebcSAart Bik //===----------------------------------------------------------------------===// 1519466ebcSAart Bik 16365777ecSAart Bik #include "Utils/CodegenUtils.h" 17365777ecSAart Bik #include "Utils/LoopEmitter.h" 1819466ebcSAart Bik 1919466ebcSAart Bik #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 2019466ebcSAart Bik #include "mlir/Dialect/GPU/IR/GPUDialect.h" 21ee42e236SAart Bik #include "mlir/Dialect/Linalg/IR/Linalg.h" 22ee42e236SAart Bik #include "mlir/Dialect/Linalg/Utils/Utils.h" 2319466ebcSAart Bik #include "mlir/Dialect/MemRef/IR/MemRef.h" 2419466ebcSAart Bik #include "mlir/Dialect/SCF/IR/SCF.h" 2519466ebcSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 26ee42e236SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 2719466ebcSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 2819466ebcSAart Bik #include "mlir/IR/IRMapping.h" 2919466ebcSAart Bik #include "mlir/IR/Matchers.h" 3019466ebcSAart Bik 3119466ebcSAart Bik using namespace mlir; 3219466ebcSAart Bik using namespace mlir::sparse_tensor; 3319466ebcSAart Bik 3419466ebcSAart 
Bik namespace { 3519466ebcSAart Bik 363231a365SAart Bik // Sparse formats supported by cuSparse. 373231a365SAart Bik enum class CuSparseFormat { 383231a365SAart Bik kNone, 393231a365SAart Bik kCOO, 403231a365SAart Bik kCSR, 413231a365SAart Bik kCSC, 423d89c088SAart Bik kBSR, 433231a365SAart Bik }; 443231a365SAart Bik 4519466ebcSAart Bik //===----------------------------------------------------------------------===// 4619466ebcSAart Bik // Helper methods. 4719466ebcSAart Bik //===----------------------------------------------------------------------===// 4819466ebcSAart Bik 4919466ebcSAart Bik /// Marks the given top module as a GPU container module. 5019466ebcSAart Bik static void markAsGPUContainer(ModuleOp topModule) { 5119466ebcSAart Bik topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(), 5219466ebcSAart Bik UnitAttr::get(topModule->getContext())); 5319466ebcSAart Bik } 5419466ebcSAart Bik 554889214aSAart Bik /// Constructs a new GPU module (for GPU kernels) inside the given top module, 564889214aSAart Bik /// or returns an existing GPU module if one was built previously. 574889214aSAart Bik static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) { 584889214aSAart Bik for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>()) 594889214aSAart Bik return op; // existing 6019466ebcSAart Bik markAsGPUContainer(topModule); 6149df12c0SMatthias Springer builder.setInsertionPointToStart(topModule.getBody()); 624889214aSAart Bik return builder.create<gpu::GPUModuleOp>(topModule->getLoc(), 634889214aSAart Bik "sparse_kernels"); 6419466ebcSAart Bik } 6519466ebcSAart Bik 6619466ebcSAart Bik /// Constructs a new GPU kernel in the given GPU module. 6719466ebcSAart Bik static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule, 684889214aSAart Bik SmallVectorImpl<Value> &args) { 694889214aSAart Bik // Get a unique kernel name. Not very creative, 704889214aSAart Bik // but we simply try kernel0, kernel1, etc. 
714889214aSAart Bik unsigned kernelNumber = 0; 724889214aSAart Bik SmallString<16> kernelName; 734889214aSAart Bik do { 744889214aSAart Bik kernelName.clear(); 754889214aSAart Bik ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName); 764889214aSAart Bik } while (gpuModule.lookupSymbol(kernelName)); 774889214aSAart Bik // Then we insert a new kernel with given arguments into the module. 7849df12c0SMatthias Springer builder.setInsertionPointToStart(gpuModule.getBody()); 7919466ebcSAart Bik SmallVector<Type> argsTp; 8056c385cdSMehdi Amini for (auto arg : args) 8156c385cdSMehdi Amini argsTp.push_back(arg.getType()); 8219466ebcSAart Bik FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {}); 8319466ebcSAart Bik auto gpuFunc = 844889214aSAart Bik builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type); 8519466ebcSAart Bik gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), 8619466ebcSAart Bik builder.getUnitAttr()); 8719466ebcSAart Bik return gpuFunc; 8819466ebcSAart Bik } 8919466ebcSAart Bik 9019466ebcSAart Bik /// Constructs code to launch GPU kernel. 
9186888e42SAart Bik static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc, 9219466ebcSAart Bik SmallVectorImpl<Value> &args, 9386888e42SAart Bik SmallVectorImpl<Value> &tokens, 9419466ebcSAart Bik unsigned numThreads) { 9519466ebcSAart Bik Location loc = gpuFunc->getLoc(); 9619466ebcSAart Bik Value none = TypedValue<::mlir::IntegerType>{}; 9719466ebcSAart Bik Value one = constantIndex(builder, loc, 1); 9819466ebcSAart Bik Value numT = constantIndex(builder, loc, numThreads); 9919466ebcSAart Bik gpu::KernelDim3 gridSize = {one, one, one}; 10019466ebcSAart Bik gpu::KernelDim3 blckSize = {numT, one, one}; 10186888e42SAart Bik return builder 10286888e42SAart Bik .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize, 10386888e42SAart Bik /*dynSharedMemSz*/ none, args, 10486888e42SAart Bik builder.getType<gpu::AsyncTokenType>(), tokens) 10586888e42SAart Bik .getAsyncToken(); 10619466ebcSAart Bik } 10719466ebcSAart Bik 10819466ebcSAart Bik /// Maps the provided ranked host buffer into the device address space. 10919466ebcSAart Bik /// Writes from the host are guaranteed to be visible to device kernels 11019466ebcSAart Bik /// that are launched afterwards. Writes from the device are guaranteed 11119466ebcSAart Bik /// to be visible on the host after synchronizing with the device kernel 11286888e42SAart Bik /// completion. Needs to cast the buffer to a unranked buffer. 
11319466ebcSAart Bik static Value genHostRegisterMemref(OpBuilder &builder, Location loc, 11419466ebcSAart Bik Value mem) { 1155550c821STres Popp MemRefType memTp = cast<MemRefType>(mem.getType()); 11619466ebcSAart Bik UnrankedMemRefType resTp = 11719466ebcSAart Bik UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0); 11819466ebcSAart Bik Value cast = builder.create<memref::CastOp>(loc, resTp, mem); 11919466ebcSAart Bik builder.create<gpu::HostRegisterOp>(loc, cast); 12086888e42SAart Bik return cast; 12186888e42SAart Bik } 12286888e42SAart Bik 12386888e42SAart Bik /// Unmaps the provided buffer, expecting the casted buffer. 12486888e42SAart Bik static void genHostUnregisterMemref(OpBuilder &builder, Location loc, 12586888e42SAart Bik Value cast) { 12686888e42SAart Bik builder.create<gpu::HostUnregisterOp>(loc, cast); 12786888e42SAart Bik } 12886888e42SAart Bik 12986888e42SAart Bik /// Generates first wait in an asynchronous chain. 13086888e42SAart Bik static Value genFirstWait(OpBuilder &builder, Location loc) { 13186888e42SAart Bik Type tokenType = builder.getType<gpu::AsyncTokenType>(); 13286888e42SAart Bik return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange()) 13386888e42SAart Bik .getAsyncToken(); 13486888e42SAart Bik } 13586888e42SAart Bik 13686888e42SAart Bik /// Generates last, blocking wait in an asynchronous chain. 13786888e42SAart Bik static void genBlockingWait(OpBuilder &builder, Location loc, 13886888e42SAart Bik ValueRange operands) { 13986888e42SAart Bik builder.create<gpu::WaitOp>(loc, Type(), operands); 14086888e42SAart Bik } 14186888e42SAart Bik 14286888e42SAart Bik /// Allocates memory on the device. 14386888e42SAart Bik /// TODO: A `host_shared` attribute could be used to indicate that 14486888e42SAart Bik /// the buffer is visible by both host and device, but lowering 14586888e42SAart Bik /// that feature does not seem to be fully supported yet. 
14686888e42SAart Bik static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem, 14786888e42SAart Bik Value token) { 1485550c821STres Popp auto tp = cast<ShapedType>(mem.getType()); 14986888e42SAart Bik auto elemTp = tp.getElementType(); 15086888e42SAart Bik auto shape = tp.getShape(); 15186888e42SAart Bik auto memTp = MemRefType::get(shape, elemTp); 15286888e42SAart Bik SmallVector<Value> dynamicSizes; 15386888e42SAart Bik for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) { 15486888e42SAart Bik if (shape[r] == ShapedType::kDynamic) { 155ee42e236SAart Bik Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r); 15686888e42SAart Bik dynamicSizes.push_back(dimOp); 15786888e42SAart Bik } 15886888e42SAart Bik } 15986888e42SAart Bik return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}), 16086888e42SAart Bik token, dynamicSizes, ValueRange()); 16186888e42SAart Bik } 16286888e42SAart Bik 16376a80a08SAart Bik // Allocates a typed buffer on the host with given size. 16476a80a08SAart Bik static Value genHostBuffer(OpBuilder &builder, Location loc, Type type, 16576a80a08SAart Bik Value size) { 16676a80a08SAart Bik const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); 16776a80a08SAart Bik return builder.create<memref::AllocOp>(loc, memTp, size).getResult(); 16876a80a08SAart Bik } 16976a80a08SAart Bik 17076a80a08SAart Bik // Allocates a typed buffer on the device with given size. 17176a80a08SAart Bik static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type, 17276a80a08SAart Bik Value size, Value token) { 17376a80a08SAart Bik const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); 17476a80a08SAart Bik return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}), 17576a80a08SAart Bik token, size, ValueRange()); 17676a80a08SAart Bik } 17776a80a08SAart Bik 178ee42e236SAart Bik // Allocates a void buffer on the device with given size. 
179ee42e236SAart Bik static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size, 180ee42e236SAart Bik Value token) { 18176a80a08SAart Bik return genAllocBuffer(builder, loc, builder.getI8Type(), size, token); 182ee42e236SAart Bik } 183ee42e236SAart Bik 18486888e42SAart Bik /// Deallocates memory from the device. 18586888e42SAart Bik static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem, 18686888e42SAart Bik Value token) { 18786888e42SAart Bik return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem) 18886888e42SAart Bik .getAsyncToken(); 18986888e42SAart Bik } 19086888e42SAart Bik 19186888e42SAart Bik /// Copies memory between host and device (direction is implicit). 19286888e42SAart Bik static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst, 19386888e42SAart Bik Value src, Value token) { 19486888e42SAart Bik return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src) 19586888e42SAart Bik .getAsyncToken(); 19686888e42SAart Bik } 19786888e42SAart Bik 198ee42e236SAart Bik /// Generates an alloc/copy pair. 199ee42e236SAart Bik static Value genAllocCopy(OpBuilder &builder, Location loc, Value b, 200ee42e236SAart Bik SmallVectorImpl<Value> &tokens) { 201ee42e236SAart Bik Value firstToken = genFirstWait(builder, loc); 202ee42e236SAart Bik auto alloc = genAllocMemRef(builder, loc, b, firstToken); 203ee42e236SAart Bik Value devMem = alloc.getResult(0); 204ee42e236SAart Bik Value depToken = alloc.getAsyncToken(); // copy-after-alloc 205ee42e236SAart Bik tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken)); 206ee42e236SAart Bik return devMem; 207ee42e236SAart Bik } 208ee42e236SAart Bik 209ee42e236SAart Bik /// Generates a memref from tensor operation. 
210ee42e236SAart Bik static Value genTensorToMemref(PatternRewriter &rewriter, Location loc, 211ee42e236SAart Bik Value tensor) { 21268f58812STres Popp auto tensorType = llvm::cast<ShapedType>(tensor.getType()); 213ee42e236SAart Bik auto memrefType = 214ee42e236SAart Bik MemRefType::get(tensorType.getShape(), tensorType.getElementType()); 215ee42e236SAart Bik return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor); 216ee42e236SAart Bik } 217ee42e236SAart Bik 21886888e42SAart Bik /// Prepares the outlined arguments, passing scalars and buffers in. Here we 21986888e42SAart Bik /// assume that the first buffer is the one allocated for output. We create 22086888e42SAart Bik /// a set of properly chained asynchronous allocation/copy pairs to increase 22186888e42SAart Bik /// overlap before launching the kernel. 22286888e42SAart Bik static Value genParametersIn(OpBuilder &builder, Location loc, 22386888e42SAart Bik SmallVectorImpl<Value> &scalars, 22486888e42SAart Bik SmallVectorImpl<Value> &buffers, 22586888e42SAart Bik SmallVectorImpl<Value> &args, 22686888e42SAart Bik SmallVectorImpl<Value> &tokens, 22786888e42SAart Bik bool useHostRegistrationForOut) { 22886888e42SAart Bik Value out; 22986888e42SAart Bik // Scalars are passed by value. 23086888e42SAart Bik for (Value s : scalars) 23186888e42SAart Bik args.push_back(s); 23286888e42SAart Bik // Buffers are need to be made visible on device. 23386888e42SAart Bik for (Value b : buffers) { 23486888e42SAart Bik if (useHostRegistrationForOut) { 23586888e42SAart Bik out = genHostRegisterMemref(builder, loc, b); 23686888e42SAart Bik args.push_back(b); 23786888e42SAart Bik useHostRegistrationForOut = false; 23886888e42SAart Bik continue; 23986888e42SAart Bik } 240ee42e236SAart Bik args.push_back(genAllocCopy(builder, loc, b, tokens)); 24186888e42SAart Bik } 24286888e42SAart Bik return out; 24386888e42SAart Bik } 24486888e42SAart Bik 24586888e42SAart Bik /// Finalizes the outlined arguments. 
The output buffer is copied depending 24686888e42SAart Bik /// on the kernel token and then deallocated. All other buffers are simply 24786888e42SAart Bik /// deallocated. Then we wait for all operations to complete. 24886888e42SAart Bik static void genParametersOut(OpBuilder &builder, Location loc, Value out, 24986888e42SAart Bik Value kernelToken, SmallVectorImpl<Value> &scalars, 25086888e42SAart Bik SmallVectorImpl<Value> &buffers, 25186888e42SAart Bik SmallVectorImpl<Value> &args, 25286888e42SAart Bik SmallVectorImpl<Value> &tokens) { 25386888e42SAart Bik unsigned base = scalars.size(); 25486888e42SAart Bik for (unsigned i = base, e = args.size(); i < e; i++) { 25586888e42SAart Bik Value firstToken; 25686888e42SAart Bik if (i == base) { 25786888e42SAart Bik // Assumed output parameter: unregister or copy-out. 25886888e42SAart Bik if (out) { 25986888e42SAart Bik genHostUnregisterMemref(builder, loc, out); 26086888e42SAart Bik out = Value(); 26186888e42SAart Bik continue; 26286888e42SAart Bik } 26386888e42SAart Bik firstToken = 26486888e42SAart Bik genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken); 26586888e42SAart Bik } else { 26686888e42SAart Bik firstToken = genFirstWait(builder, loc); 26786888e42SAart Bik } 26886888e42SAart Bik tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken)); 26986888e42SAart Bik } 27019466ebcSAart Bik } 27119466ebcSAart Bik 27219466ebcSAart Bik /// Constructs code for new GPU kernel. 
27319466ebcSAart Bik static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc, 27419466ebcSAart Bik scf::ParallelOp forallOp, 27519466ebcSAart Bik SmallVectorImpl<Value> &constants, 27619466ebcSAart Bik SmallVectorImpl<Value> &scalars, 27719466ebcSAart Bik SmallVectorImpl<Value> &buffers) { 27819466ebcSAart Bik Location loc = gpuFunc->getLoc(); 27919466ebcSAart Bik Block &block = gpuFunc.getBody().front(); 28019466ebcSAart Bik rewriter.setInsertionPointToStart(&block); 28119466ebcSAart Bik 28219466ebcSAart Bik // Re-generate the constants, recapture all arguments. 28319466ebcSAart Bik unsigned arg = 0; 28419466ebcSAart Bik IRMapping irMap; 28519466ebcSAart Bik for (Value c : constants) 28619466ebcSAart Bik irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0)); 28719466ebcSAart Bik for (Value s : scalars) 28819466ebcSAart Bik irMap.map(s, block.getArgument(arg++)); 28919466ebcSAart Bik for (Value b : buffers) 29019466ebcSAart Bik irMap.map(b, block.getArgument(arg++)); 29119466ebcSAart Bik 29219466ebcSAart Bik // Assume 1-dimensional grid/block configuration (only x dimension), 29319466ebcSAart Bik // so that: 29419466ebcSAart Bik // row = blockIdx.x * blockDim.x + threadIdx.x 29519466ebcSAart Bik // inc = blockDim.x * gridDim.x 29619466ebcSAart Bik Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x); 29719466ebcSAart Bik Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x); 29819466ebcSAart Bik Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x); 29919466ebcSAart Bik Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x); 30019466ebcSAart Bik Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz); 30119466ebcSAart Bik Value row = rewriter.create<arith::AddIOp>(loc, mul, tid); 30219466ebcSAart Bik Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz); 30319466ebcSAart Bik 30419466ebcSAart Bik // Construct the iteration over the computational space that 30519466ebcSAart Bik 
// accounts for the fact that the total number of threads and 30619466ebcSAart Bik // the amount of work to be done usually do not match precisely. 30719466ebcSAart Bik // for (r = row; r < N; r += inc) { 30819466ebcSAart Bik // <loop-body> 30919466ebcSAart Bik // } 31019466ebcSAart Bik Value upper = irMap.lookup(forallOp.getUpperBound()[0]); 31119466ebcSAart Bik scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc); 312861600f1SMatthias Springer // The scf.for builder creates an empty block. scf.for does not allow multiple 313861600f1SMatthias Springer // blocks in its region, so delete the block before `cloneRegionBefore` adds 314861600f1SMatthias Springer // an additional block. 315861600f1SMatthias Springer rewriter.eraseBlock(forOp.getBody()); 3169b5ef2beSMatthias Springer rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(), 3179b5ef2beSMatthias Springer forOp.getRegion().begin(), irMap); 31810056c82SMatthias Springer // Replace the scf.reduce terminator. 31910056c82SMatthias Springer rewriter.setInsertionPoint(forOp.getBody()->getTerminator()); 32010056c82SMatthias Springer rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator()); 32119466ebcSAart Bik 32219466ebcSAart Bik // Done. 32319466ebcSAart Bik rewriter.setInsertionPointAfter(forOp); 32419466ebcSAart Bik rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc()); 32519466ebcSAart Bik } 32619466ebcSAart Bik 32719466ebcSAart Bik //===----------------------------------------------------------------------===// 328ee42e236SAart Bik // Library helper methods. 32919466ebcSAart Bik //===----------------------------------------------------------------------===// 33019466ebcSAart Bik 331f14c8eb5SAart Bik /// Helper to detect a + b with arguments taken from given block. 
332f14c8eb5SAart Bik static bool matchAddOfArgs(Block *block, Value val) { 333ee42e236SAart Bik if (auto *def = val.getDefiningOp()) { 334f14c8eb5SAart Bik if (isa<arith::AddFOp, arith::AddIOp>(def)) { 335f14c8eb5SAart Bik Value a = block->getArguments()[0]; 336f14c8eb5SAart Bik Value b = block->getArguments()[1]; 337f14c8eb5SAart Bik return (def->getOperand(0) == a && def->getOperand(1) == b) || 338f14c8eb5SAart Bik (def->getOperand(0) == b && def->getOperand(1) == a); 339f14c8eb5SAart Bik } 340f14c8eb5SAart Bik } 341f14c8eb5SAart Bik return false; 342f14c8eb5SAart Bik } 343f14c8eb5SAart Bik 344f14c8eb5SAart Bik /// Helper to detect a * b with arguments taken from given block. 345f14c8eb5SAart Bik static bool matchMulOfArgs(Block *block, Value val) { 346f14c8eb5SAart Bik if (auto *def = val.getDefiningOp()) { 347f14c8eb5SAart Bik if (isa<arith::MulFOp, arith::MulIOp>(def)) { 348f14c8eb5SAart Bik Value a = block->getArguments()[0]; 349f14c8eb5SAart Bik Value b = block->getArguments()[1]; 350ee42e236SAart Bik return (def->getOperand(0) == a && def->getOperand(1) == b) || 351ee42e236SAart Bik (def->getOperand(0) == b && def->getOperand(1) == a); 352ee42e236SAart Bik } 353ee42e236SAart Bik } 354ee42e236SAart Bik return false; 355ee42e236SAart Bik } 356ee42e236SAart Bik 357ee42e236SAart Bik /// Helper to detect x = x + a * b 358ee42e236SAart Bik static bool matchSumOfMultOfArgs(linalg::GenericOp op) { 359ee42e236SAart Bik auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); 360ee42e236SAart Bik if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { 361f14c8eb5SAart Bik if (isa<arith::AddFOp, arith::AddIOp>(def)) { 362ee42e236SAart Bik Value x = op.getBlock()->getArguments()[2]; 363ee42e236SAart Bik return (def->getOperand(0) == x && 364f14c8eb5SAart Bik matchMulOfArgs(op.getBlock(), def->getOperand(1))) || 365ee42e236SAart Bik (def->getOperand(1) == x && 366f14c8eb5SAart Bik matchMulOfArgs(op.getBlock(), def->getOperand(0))); 367ee42e236SAart 
Bik } 368ee42e236SAart Bik } 369ee42e236SAart Bik return false; 370ee42e236SAart Bik } 371ee42e236SAart Bik 372f14c8eb5SAart Bik // Helper to detect c += spy(s) x (a * b) 3739167dd46SKun Wu static bool matchSumReductionOfMulUnary(linalg::GenericOp op) { 3749167dd46SKun Wu auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); 375f14c8eb5SAart Bik // The linalg yields a custom reduce result. 3769167dd46SKun Wu Value s_out = op.getBlock()->getArguments()[2]; 377f14c8eb5SAart Bik if (auto redOp = 378f14c8eb5SAart Bik yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) { 379f14c8eb5SAart Bik // The reduce consumes the output. 380f14c8eb5SAart Bik Value other; 381f14c8eb5SAart Bik if (s_out == redOp->getOperand(0)) 382f14c8eb5SAart Bik other = redOp->getOperand(1); 383f14c8eb5SAart Bik else if (s_out == redOp->getOperand(1)) 384f14c8eb5SAart Bik other = redOp->getOperand(0); 385f14c8eb5SAart Bik else 3869167dd46SKun Wu return false; 387f14c8eb5SAart Bik // The reduce op also consumes an unary which also consumes the output 388f14c8eb5SAart Bik // and does not define an absent value. 389f14c8eb5SAart Bik if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) { 390f14c8eb5SAart Bik if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty()) 3919167dd46SKun Wu return false; 392f14c8eb5SAart Bik // And the bodies are as expected. 393f14c8eb5SAart Bik auto yieldUn = cast<sparse_tensor::YieldOp>( 394f14c8eb5SAart Bik unOp.getRegion(0).front().getTerminator()); 395f14c8eb5SAart Bik auto yieldRed = cast<sparse_tensor::YieldOp>( 396f14c8eb5SAart Bik redOp.getRegion().front().getTerminator()); 397f14c8eb5SAart Bik return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) && 398f14c8eb5SAart Bik matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0)); 3999167dd46SKun Wu } 4009167dd46SKun Wu } 401f14c8eb5SAart Bik return false; 4029167dd46SKun Wu } 4039167dd46SKun Wu 4043231a365SAart Bik /// Test for dense tensor. 
405e37fc3ccSK-Wu static bool isDenseTensor(Value v) { 4063231a365SAart Bik auto sTp = getSparseTensorType(v); 4073231a365SAart Bik return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense(); 408e37fc3ccSK-Wu } 409e37fc3ccSK-Wu 4103231a365SAart Bik /// Test for suitable positions/coordinates width. 4113231a365SAart Bik static bool isAdmissibleMetaData(SparseTensorType &aTp) { 4123231a365SAart Bik return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) && 4133231a365SAart Bik (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16); 4143231a365SAart Bik } 4153231a365SAart Bik 4163231a365SAart Bik /// Test for sorted COO matrix with suitable metadata. 417ee42e236SAart Bik static bool isAdmissibleCOO(SparseTensorType &aTp) { 4183231a365SAart Bik return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() && 4193231a365SAart Bik aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) && 420ee42e236SAart Bik aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && 4213231a365SAart Bik isAdmissibleMetaData(aTp); 422ee42e236SAart Bik } 423ee42e236SAart Bik 4243231a365SAart Bik /// Test for CSR matrix with suitable metadata. 425ee42e236SAart Bik static bool isAdmissibleCSR(SparseTensorType &aTp) { 4263231a365SAart Bik return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() && 4273231a365SAart Bik aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && 4283231a365SAart Bik aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp); 429ee42e236SAart Bik } 430ee42e236SAart Bik 4313231a365SAart Bik /// Test for CSC matrix with suitable metadata. 
4323231a365SAart Bik static bool isAdmissibleCSC(SparseTensorType &aTp) { 4333231a365SAart Bik return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() && 4343231a365SAart Bik aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && 4353231a365SAart Bik aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp); 436b75d6a40SAart Bik } 4373231a365SAart Bik 4383d89c088SAart Bik /// Test for BSR matrix with suitable metadata. 4393d89c088SAart Bik static bool isAdmissibleBSR(SparseTensorType &aTp) { 4403d89c088SAart Bik if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) && 4413d89c088SAart Bik aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && 4423d89c088SAart Bik aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) { 4433d89c088SAart Bik // CuSparse only supports "square" blocks currently. 4443d89c088SAart Bik SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl()); 4453d89c088SAart Bik assert(dims.size() == 2); 446e35b6062SMaksim Levental return dims[0] == dims[1] && dims[0] > 1; 4473d89c088SAart Bik } 4483d89c088SAart Bik return false; 4493d89c088SAart Bik } 4503d89c088SAart Bik 45141a07e66SAart Bik /// Test for 2:4 matrix with suitable metadata. 45241a07e66SAart Bik static bool isAdmissible24(SparseTensorType &aTp) { 45341a07e66SAart Bik return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) && 454e5924d64SYinying Li aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp); 45541a07e66SAart Bik } 45641a07e66SAart Bik 45741a07e66SAart Bik /// Test for conversion into 2:4 matrix. 
45841a07e66SAart Bik static bool isConversionInto24(Value v) { 45941a07e66SAart Bik if (auto cnv = v.getDefiningOp<ConvertOp>()) { 46041a07e66SAart Bik Value a = cnv.getResult(); 46141a07e66SAart Bik Value d = cnv.getSource(); 46241a07e66SAart Bik SparseTensorType aTp = getSparseTensorType(a); 46341a07e66SAart Bik return isDenseTensor(d) && isAdmissible24(aTp); 46441a07e66SAart Bik } 46541a07e66SAart Bik return false; 46641a07e66SAart Bik } 46741a07e66SAart Bik 4683231a365SAart Bik /// Returns a suitable sparse format for the operation and given operand 4693231a365SAart Bik /// types with cuSparse, or kNone if none is available. 4703231a365SAart Bik static CuSparseFormat getCuSparseFormat(SparseTensorType aTp, 4713231a365SAart Bik SparseTensorType bTp, 4723231a365SAart Bik SparseTensorType cTp, bool enableRT, 4733231a365SAart Bik bool isMatVec) { 4743231a365SAart Bik // The other operands have a dense type. 4753231a365SAart Bik if (bTp.hasEncoding() || cTp.hasEncoding()) 4763231a365SAart Bik return CuSparseFormat::kNone; 4773231a365SAart Bik // Now check for suitable operand type for the main operand. 4783231a365SAart Bik if (isAdmissibleCOO(aTp)) 4793231a365SAart Bik #ifdef CUSPARSE_COO_AOS 4803231a365SAart Bik return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone; 4813231a365SAart Bik #else 4823231a365SAart Bik return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone; 4833231a365SAart Bik #endif 4843231a365SAart Bik if (isAdmissibleCSR(aTp)) 4853231a365SAart Bik return CuSparseFormat::kCSR; 4863231a365SAart Bik if (isAdmissibleCSC(aTp)) 4873231a365SAart Bik return CuSparseFormat::kCSC; 4883d89c088SAart Bik if (isAdmissibleBSR(aTp)) 4893d89c088SAart Bik return CuSparseFormat::kBSR; 4903231a365SAart Bik return CuSparseFormat::kNone; 491b75d6a40SAart Bik } 492b75d6a40SAart Bik 493ee42e236SAart Bik /// Generates the first positions/coordinates of a sparse matrix. 
494ee42e236SAart Bik static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a, 4953231a365SAart Bik CuSparseFormat format, bool enableRT) { 4963231a365SAart Bik if (format == CuSparseFormat::kCOO) { 497ee42e236SAart Bik // Library uses SoA COO, direct IR uses AoS COO. 498ee42e236SAart Bik if (enableRT) 4991a0986f0SPeiming Liu return builder.create<ToCoordinatesOp>(loc, a, 0); 5001a0986f0SPeiming Liu return builder.create<ToCoordinatesBufferOp>(loc, a); 501ee42e236SAart Bik } 5023231a365SAart Bik // Formats CSR/CSC and BSR use positions at 1. 5031a0986f0SPeiming Liu return builder.create<ToPositionsOp>(loc, a, 1); 504ee42e236SAart Bik } 505ee42e236SAart Bik 506ee42e236SAart Bik /// Generates the second coordinates of a sparse matrix. 507ee42e236SAart Bik static Value genSecondCrds(OpBuilder &builder, Location loc, Value a, 5083231a365SAart Bik CuSparseFormat format, bool enableRT) { 5093231a365SAart Bik bool isCOO = format == CuSparseFormat::kCOO; 510ee42e236SAart Bik if (isCOO && !enableRT) 511ee42e236SAart Bik return Value(); // nothing needed 5123231a365SAart Bik // Formats CSR/CSC and BSR use coordinates at 1. 5131a0986f0SPeiming Liu return builder.create<ToCoordinatesOp>(loc, a, 1); 514ee42e236SAart Bik } 515ee42e236SAart Bik 5163231a365SAart Bik /// Generates the sparse matrix handle. 5173d89c088SAart Bik static Operation *genSpMat(OpBuilder &builder, Location loc, 5183d89c088SAart Bik SparseTensorType &aTp, Type handleTp, Type tokenTp, 5193d89c088SAart Bik Value token, Value sz1, Value sz2, Value nseA, 5203d89c088SAart Bik Value rowA, Value colA, Value valA, 5213231a365SAart Bik CuSparseFormat format, bool enableRT) { 5223231a365SAart Bik if (format == CuSparseFormat::kCOO) { 523ee42e236SAart Bik // Library uses SoA COO, direct IR uses AoS COO. 
524bcb698bfSAart Bik if (enableRT) { 525bcb698bfSAart Bik assert(colA); 526ee42e236SAart Bik return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token, 527bcb698bfSAart Bik sz1, sz2, nseA, rowA, colA, valA); 528bcb698bfSAart Bik } 5299fc02a7aSAart Bik #ifdef CUSPARSE_COO_AOS 5309fc02a7aSAart Bik assert(!colA); 5319fc02a7aSAart Bik return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token, 5329fc02a7aSAart Bik sz1, sz2, nseA, rowA, valA); 5339fc02a7aSAart Bik #else 534ee42e236SAart Bik llvm_unreachable("gpu::CreateCooAoSOp is deprecated"); 5359fc02a7aSAart Bik #endif 536ee42e236SAart Bik } 537bcb698bfSAart Bik assert(colA); 5383231a365SAart Bik if (format == CuSparseFormat::kCSR) 539bcb698bfSAart Bik return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1, 540bcb698bfSAart Bik sz2, nseA, rowA, colA, valA); 5413d89c088SAart Bik if (format == CuSparseFormat::kCSC) 5423231a365SAart Bik return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1, 5433231a365SAart Bik sz2, nseA, rowA, colA, valA); 5443d89c088SAart Bik // BSR requires a bit more work since we need to pass in the block size 5453d89c088SAart Bik // and all others sizes in terms of blocks (#block-rows, #block-cols, 5463d89c088SAart Bik // #nonzero-blocks). 
  // BSR: all sizes are given in terms of blocks (#block-rows, #block-cols,
  // #nonzero-blocks), so divide the element-wise sizes by the block size.
  assert(format == CuSparseFormat::kBSR);
  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
  assert(dims.size() == 2 && dims[0] == dims[1]); // only square blocks
  uint64_t b = dims[0];
  Value bSz = constantIndex(builder, loc, b);
  Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
  Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
  Value bNum = builder.create<arith::DivUIOp>(
      loc, nseA, constantIndex(builder, loc, b * b));
  return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
                                          bCols, bNum, bSz, bSz, rowA, colA,
                                          valA);
}

/// Match and rewrite SpMV kernel (y = Ax with sparse A, dense x/y).
/// Replaces the linalg.generic with an async chain of GPU sparse ops:
/// host-to-device copies, handle creation, buffer-size query, the SpMV
/// itself, and device-to-host copy-back plus cleanup. Every GPU op is
/// threaded through the async `token`, so the statement order below
/// defines the execution order on the device.
static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value x = op.getOperand(1);
  Value y = op.getOperand(2); // we have y = Ax
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format and dense vectors (no BSR).
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType xTp = getSparseTensorType(x);
  SparseTensorType yTp = getSparseTensorType(y);
  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Start sparse kernel and copy data from host to device.
  //   a : memR/memC/memV -> rowA,colA,valA
  //   x : memX -> vecX
  //   y : memY -> vecY
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
  Value memV = rewriter.create<ToValuesOp>(loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  // colA is only materialized for formats with a second coordinate array.
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value memX = genTensorToMemref(rewriter, loc, x);
  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
  Value memY = genTensorToMemref(rewriter, loc, y);
  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
  genBlockingWait(rewriter, loc, tokens); // wait for all copies to land
  tokens.clear();

  // Create sparse environment and sparse matrix/dense vector handles.
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
  Value dnX = dvecX.getResult(0);
  token = dvecX.getAsyncToken();
  auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
  Value dnY = dvecY.getResult(0);
  token = dvecY.getAsyncToken();
  auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();

  // Precompute buffersize for SpMV.
  auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
      /*computeType=*/dnYType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SpMV.
  auto spmvComp = rewriter.create<gpu::SpMVOp>(
      loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer);
  token = spmvComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, vecX, token);
  // Result vector is copied back into the original host buffer memY.
  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
  token = genDeallocMemRef(rewriter, loc, vecY, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
  return success();
}

/// Match and rewrite SpMM kernel.
6595ef44679SAart Bik static LogicalResult rewriteSpMM(PatternRewriter &rewriter, 6605ef44679SAart Bik linalg::GenericOp op, bool enableRT) { 661b75d6a40SAart Bik Location loc = op.getLoc(); 662b75d6a40SAart Bik Value a = op.getOperand(0); 663b75d6a40SAart Bik Value b = op.getOperand(1); 664b75d6a40SAart Bik Value c = op.getOperand(2); // we have C = AB 665b75d6a40SAart Bik SmallVector<Value> tokens; 666b75d6a40SAart Bik 6673231a365SAart Bik // Only admissible sparse matrix format and dense matrices (no BSR). 668b75d6a40SAart Bik SparseTensorType aTp = getSparseTensorType(a); 669b75d6a40SAart Bik SparseTensorType bTp = getSparseTensorType(b); 670b75d6a40SAart Bik SparseTensorType cTp = getSparseTensorType(c); 6713231a365SAart Bik auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false); 6723231a365SAart Bik if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR) 673b75d6a40SAart Bik return failure(); 674b75d6a40SAart Bik 675b75d6a40SAart Bik // Start sparse kernel and copy data from host to device. 676b75d6a40SAart Bik // a : memR/memC/memV -> rowA,colA,valA 6775ef44679SAart Bik // b : bufB -> matB 678b75d6a40SAart Bik // c : bufC -> matC 679b75d6a40SAart Bik Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); 680b75d6a40SAart Bik Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); 681b75d6a40SAart Bik Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); 682b75d6a40SAart Bik Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); 6833231a365SAart Bik Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); 6845ef44679SAart Bik Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty 6851a0986f0SPeiming Liu Value memV = rewriter.create<ToValuesOp>(loc, a); 686b75d6a40SAart Bik Value rowA = genAllocCopy(rewriter, loc, memR, tokens); 687b75d6a40SAart Bik Value colA = memC ? 
genAllocCopy(rewriter, loc, memC, tokens) : Value(); 688b75d6a40SAart Bik Value valA = genAllocCopy(rewriter, loc, memV, tokens); 6895ef44679SAart Bik Value bufB = genTensorToMemref(rewriter, loc, b); 6905ef44679SAart Bik Value matB = genAllocCopy(rewriter, loc, bufB, tokens); 6915ef44679SAart Bik Value bufC = genTensorToMemref(rewriter, loc, c); 692b75d6a40SAart Bik Value matC = genAllocCopy(rewriter, loc, bufC, tokens); 693b75d6a40SAart Bik genBlockingWait(rewriter, loc, tokens); 694b75d6a40SAart Bik tokens.clear(); 695b75d6a40SAart Bik 696b75d6a40SAart Bik // Create sparse environment and sparse matrix/dense matrix handles. 697b75d6a40SAart Bik Type indexTp = rewriter.getIndexType(); 69897f4c22bSKun Wu Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); 69986bf710cSKun Wu Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 700b75d6a40SAart Bik Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 701b75d6a40SAart Bik Value token = genFirstWait(rewriter, loc); 70286bf710cSKun Wu Operation *spGenA = 7033d89c088SAart Bik genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk, 7043d89c088SAart Bik nseA, rowA, colA, valA, format, enableRT); 705b75d6a40SAart Bik Value spMatA = spGenA->getResult(0); 706b75d6a40SAart Bik token = spGenA->getResult(1); 70797f4c22bSKun Wu auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( 708be2dd22bSKun Wu loc, dnTensorHandleTp, tokenTp, token, matB, 70997f4c22bSKun Wu SmallVector<Value>{szk, szn}); 710b75d6a40SAart Bik Value dnB = dmatB.getResult(0); 711b75d6a40SAart Bik token = dmatB.getAsyncToken(); 71297f4c22bSKun Wu auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( 713be2dd22bSKun Wu loc, dnTensorHandleTp, tokenTp, token, matC, 71497f4c22bSKun Wu SmallVector<Value>{szm, szn}); 715b75d6a40SAart Bik Value dnC = dmatC.getResult(0); 716b75d6a40SAart Bik token = dmatC.getAsyncToken(); 717fa98bdbdSKun Wu auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType(); 
718fa98bdbdSKun Wu 719b75d6a40SAart Bik // Precompute buffersize for SpMM. 720b75d6a40SAart Bik auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( 721be2dd22bSKun Wu loc, indexTp, tokenTp, token, spMatA, dnB, dnC, 722fa98bdbdSKun Wu /*computeType=*/dmatCType); 723b75d6a40SAart Bik Value bufferSz = bufferComp.getResult(0); 724b75d6a40SAart Bik token = bufferComp.getAsyncToken(); 725b75d6a40SAart Bik auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); 726b75d6a40SAart Bik Value buffer = buf.getResult(0); 727b75d6a40SAart Bik token = buf.getAsyncToken(); 728fa98bdbdSKun Wu auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType(); 729fa98bdbdSKun Wu 730b75d6a40SAart Bik // Perform the SpMM. 731be2dd22bSKun Wu auto spmmComp = rewriter.create<gpu::SpMMOp>( 732be2dd22bSKun Wu loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer); 733b75d6a40SAart Bik token = spmmComp.getAsyncToken(); 734b75d6a40SAart Bik 735b75d6a40SAart Bik // Copy data back to host and free all the resoures. 
736b75d6a40SAart Bik token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) 737b75d6a40SAart Bik .getAsyncToken(); 73897f4c22bSKun Wu token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) 739b75d6a40SAart Bik .getAsyncToken(); 74097f4c22bSKun Wu token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC) 741b75d6a40SAart Bik .getAsyncToken(); 74203125e68SAart Bik token = genDeallocMemRef(rewriter, loc, rowA, token); 743b75d6a40SAart Bik if (colA) 744b75d6a40SAart Bik token = genDeallocMemRef(rewriter, loc, colA, token); 745b75d6a40SAart Bik token = genDeallocMemRef(rewriter, loc, valA, token); 746b75d6a40SAart Bik token = genDeallocMemRef(rewriter, loc, buffer, token); 747b75d6a40SAart Bik token = genDeallocMemRef(rewriter, loc, matB, token); 748bcb698bfSAart Bik token = genCopyMemRef(rewriter, loc, bufC, matC, token); 749b75d6a40SAart Bik token = genDeallocMemRef(rewriter, loc, matC, token); 750b75d6a40SAart Bik tokens.push_back(token); 751b75d6a40SAart Bik genBlockingWait(rewriter, loc, tokens); 75276a80a08SAart Bik tokens.clear(); 753b75d6a40SAart Bik 754b75d6a40SAart Bik // Done. 75522caafc9SAart Bik rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC); 756b75d6a40SAart Bik return success(); 757ee42e236SAart Bik } 758ee42e236SAart Bik 75976a80a08SAart Bik // Match and rewrite SpGEMM kernel. 7605ef44679SAart Bik static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, 7615ef44679SAart Bik linalg::GenericOp op, bool enableRT) { 76276a80a08SAart Bik Location loc = op.getLoc(); 76376a80a08SAart Bik Value a = op.getOperand(0); 76476a80a08SAart Bik Value b = op.getOperand(1); 76576a80a08SAart Bik Value c = op.getOperand(2); // we have C = AB 76676a80a08SAart Bik SmallVector<Value> tokens; 76776a80a08SAart Bik 76876a80a08SAart Bik // Only CSR <- CSR x CSR supported. 
  auto format = CuSparseFormat::kCSR;
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
    return failure();

  // Start sparse kernel and copy data from host to device.
  //   a : amemR/amemC/amemV -> rowA,colA,valA
  //   b : bmemR/bmemC/bmemV -> rowB,colB,valB
  //   c : materializes (output is allocated on device below)
  auto dnCType = cTp.getElementType();
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
  Value amemV = rewriter.create<ToValuesOp>(loc, a);
  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
  Value bmemV = rewriter.create<ToValuesOp>(loc, b);
  Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
  Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
  Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
  Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
  Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
  Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
  genBlockingWait(rewriter, loc, tokens); // wait for all copies to land
  tokens.clear();

  // Create sparse environment and sparse matrix handles for A and B.
  Type indexTp = rewriter.getIndexType();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  Operation *spGenB =
      genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
               nseB, rowB, colB, valB, format, enableRT);
  Value spMatB = spGenB->getResult(0);
  token = spGenB->getResult(1);

  // Sparse matrix C materializes (also assumes beta == 0).
  // Positions get m+1 entries up front; coordinates/values start empty and
  // are re-allocated once the actual nnz of C is known further below.
  Value zero = constantIndex(rewriter, loc, 0);
  Value one = constantIndex(rewriter, loc, 1);
  Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
  auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
  Value rowC = e1.getResult(0);
  token = e1.getAsyncToken();
  auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
  Value colC = e2.getResult(0); // no free needed
  token = e2.getAsyncToken();
  auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
  Value valC = e3.getResult(0); // no free needed
  token = e3.getAsyncToken();
  Operation *spGenC =
      genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
               zero, rowC, colC, valC, format, enableRT);
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);

  // Precompute buffersizes for SpGEMM.
  // The work-estimation op is issued twice: first with a zero-sized buffer
  // to query the required size, then with the allocated buffer.
  Operation *descOp =
      rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
  Value desc = descOp->getResult(0);
  token = descOp->getResult(1);
  Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  Value bufferSz1 = work1->getResult(0);
  token = work1->getResult(1);
  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
  Value buffer1 = buf1.getResult(0);
  token = buf1.getAsyncToken();
  Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz1, buffer1,
      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  token = work2->getResult(1);

  // Compute step (same query-then-run protocol as the estimation step).
  Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  Value bufferSz2 = compute1->getResult(0);
  token = compute1->getResult(1);
  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
  Value buffer2 = buf2.getResult(0);
  token = buf2.getAsyncToken();
  Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  token = compute2->getResult(1);

  // Get sizes (nnz of C), then allocate the real coordinate/value arrays.
  Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
      loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
  Value nnz = sizes->getResult(2);
  token = sizes->getResult(3);
  auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
  colC = a2.getResult(0);
  token = a2.getAsyncToken();
  auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
  valC = a3.getResult(0);
  token = a3.getAsyncToken();

  // Update C with new pointers and copy final product back into C.
  Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
      loc, tokenTp, token, spMatC, rowC, colC, valC);
  token = update->getResult(0);
  Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
      loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
  token = copy->getResult(0);

  // Allocate buffers on host.
  Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
  Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
  Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
  token = genCopyMemRef(rewriter, loc, colH, colC, token);
  token = genCopyMemRef(rewriter, loc, valH, valC, token);
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, rowB, token);
916619a888dSAart Bik token = genDeallocMemRef(rewriter, loc, colB, token); 917619a888dSAart Bik token = genDeallocMemRef(rewriter, loc, valB, token); 918619a888dSAart Bik token = genDeallocMemRef(rewriter, loc, rowC, token); 919619a888dSAart Bik token = genDeallocMemRef(rewriter, loc, colC, token); 920619a888dSAart Bik token = genDeallocMemRef(rewriter, loc, valC, token); 921619a888dSAart Bik token = genDeallocMemRef(rewriter, loc, buffer1, token); 922619a888dSAart Bik token = genDeallocMemRef(rewriter, loc, buffer2, token); 92376a80a08SAart Bik tokens.push_back(token); 92476a80a08SAart Bik genBlockingWait(rewriter, loc, tokens); 92576a80a08SAart Bik tokens.clear(); 92676a80a08SAart Bik 92776a80a08SAart Bik // Done. 92876a80a08SAart Bik Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH); 92976a80a08SAart Bik Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH); 93076a80a08SAart Bik Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH); 931fc9f1d49SPeiming Liu rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct}, 932fc9f1d49SPeiming Liu vt); 93376a80a08SAart Bik return success(); 93476a80a08SAart Bik } 93576a80a08SAart Bik 93676a80a08SAart Bik // Match and rewrite 2:4 SpMM kernel. 9375ef44679SAart Bik static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, 9385ef44679SAart Bik linalg::GenericOp op) { 939e37fc3ccSK-Wu Location loc = op.getLoc(); 940e37fc3ccSK-Wu Value A = op.getOperand(0); 941e37fc3ccSK-Wu Value B = op.getOperand(1); 942e37fc3ccSK-Wu Value C = op.getOperand(2); // we have C = AB 943e37fc3ccSK-Wu SmallVector<Value> tokens; 944e37fc3ccSK-Wu 94541a07e66SAart Bik // The cuSparselt API currently only allows pruning and compression 94641a07e66SAart Bik // to occur on the device. 
So we recognize the pattern 94741a07e66SAart Bik // A' = convert A ; dense to 2:4 94841a07e66SAart Bik // C = A'B ; 2:4 matrix mult 94941a07e66SAart Bik // and then perform compression and matrix multiplication on device. 95041a07e66SAart Bik auto cnv = A.getDefiningOp<ConvertOp>(); 95141a07e66SAart Bik assert(cnv); 95241a07e66SAart Bik A = cnv.getSource(); 95341a07e66SAart Bik 954e37fc3ccSK-Wu // All input should be dense tensors. 955e37fc3ccSK-Wu if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C)) 956e37fc3ccSK-Wu return failure(); 957e37fc3ccSK-Wu 9585ef44679SAart Bik // Start sparse kernel and copy data from host to device. 9595ef44679SAart Bik // a : bufA -> matA 9605ef44679SAart Bik // b : bufB -> matB 9615ef44679SAart Bik // c : bufC -> matC 962e37fc3ccSK-Wu Value bufA = genTensorToMemref(rewriter, loc, A); 9635ef44679SAart Bik Value matA = genAllocCopy(rewriter, loc, bufA, tokens); 964e37fc3ccSK-Wu Value bufB = genTensorToMemref(rewriter, loc, B); 9655ef44679SAart Bik Value matB = genAllocCopy(rewriter, loc, bufB, tokens); 966e37fc3ccSK-Wu Value bufC = genTensorToMemref(rewriter, loc, C); 967e37fc3ccSK-Wu Value matC = genAllocCopy(rewriter, loc, bufC, tokens); 968e37fc3ccSK-Wu genBlockingWait(rewriter, loc, tokens); 969e37fc3ccSK-Wu tokens.clear(); 97076a80a08SAart Bik 97176a80a08SAart Bik // Create sparse environment and sparse matrix/dense vector handles. 
972e37fc3ccSK-Wu Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0); 973e37fc3ccSK-Wu Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0); 974e37fc3ccSK-Wu Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1); 975e37fc3ccSK-Wu Type indexTp = rewriter.getIndexType(); 976e37fc3ccSK-Wu Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); 977e37fc3ccSK-Wu Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 978e37fc3ccSK-Wu Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 979e37fc3ccSK-Wu Value token = genFirstWait(rewriter, loc); 980e37fc3ccSK-Wu Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>( 9811e491c42SKun Wu loc, spMatHandleTp, tokenTp, token, szm, szk, 9821e491c42SKun Wu gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA); 983e37fc3ccSK-Wu Value spMatA = spGenA->getResult(0); 984e37fc3ccSK-Wu token = spGenA->getResult(1); 985e37fc3ccSK-Wu auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( 986e37fc3ccSK-Wu loc, dnTensorHandleTp, tokenTp, token, matB, 987e37fc3ccSK-Wu SmallVector<Value>{szk, szn}); 988e37fc3ccSK-Wu Value dnB = dmatB.getResult(0); 989e37fc3ccSK-Wu token = dmatB.getAsyncToken(); 990e37fc3ccSK-Wu auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( 991e37fc3ccSK-Wu loc, dnTensorHandleTp, tokenTp, token, matC, 992e37fc3ccSK-Wu SmallVector<Value>{szm, szn}); 993e37fc3ccSK-Wu Value dnC = dmatC.getResult(0); 994e37fc3ccSK-Wu token = dmatC.getAsyncToken(); 995e37fc3ccSK-Wu auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType(); 996e37fc3ccSK-Wu 997e37fc3ccSK-Wu // Precompute buffersize for SpMM. 
998e37fc3ccSK-Wu SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp}; 999e37fc3ccSK-Wu TypeRange bufferTypes(bufferTypes_); 1000e37fc3ccSK-Wu auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( 1001e37fc3ccSK-Wu loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE, 1002e37fc3ccSK-Wu gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC, 1003e37fc3ccSK-Wu /*computeType=*/dmatCType); 1004e37fc3ccSK-Wu token = bufferComp.getAsyncToken(); 100576a80a08SAart Bik 10065ef44679SAart Bik // Allocate buffers on host. 10075ef44679SAart Bik Value bufferSz1 = bufferComp.getResult(0); 10085ef44679SAart Bik auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token); 10095ef44679SAart Bik Value buffer1 = buf1.getResult(0); 10105ef44679SAart Bik token = buf1.getAsyncToken(); 1011e37fc3ccSK-Wu Value bufferSz2 = bufferComp.getResult(1); 1012e37fc3ccSK-Wu auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token); 1013e37fc3ccSK-Wu Value buffer2 = buf2.getResult(0); 1014e37fc3ccSK-Wu token = buf2.getAsyncToken(); 1015e37fc3ccSK-Wu Value bufferSz3 = bufferComp.getResult(2); 1016e37fc3ccSK-Wu auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token); 1017e37fc3ccSK-Wu Value buffer3 = buf3.getResult(0); 1018e37fc3ccSK-Wu token = buf3.getAsyncToken(); 1019e37fc3ccSK-Wu 1020e37fc3ccSK-Wu // Perform the SpMM. 10215ef44679SAart Bik auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType(); 1022e37fc3ccSK-Wu auto spmmComp = rewriter.create<gpu::SpMMOp>( 1023e37fc3ccSK-Wu loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, 10245ef44679SAart Bik SmallVector<Value>{buffer1, buffer2, buffer3}); 1025e37fc3ccSK-Wu token = spmmComp.getAsyncToken(); 1026e37fc3ccSK-Wu 1027e37fc3ccSK-Wu // Copy data back to host and free all the resources. 
1028e37fc3ccSK-Wu token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) 1029e37fc3ccSK-Wu .getAsyncToken(); 1030e37fc3ccSK-Wu token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) 1031e37fc3ccSK-Wu .getAsyncToken(); 1032e37fc3ccSK-Wu token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC) 1033e37fc3ccSK-Wu .getAsyncToken(); 1034e37fc3ccSK-Wu SmallVector<Value> newDynamicSizes; 10355ef44679SAart Bik token = genDeallocMemRef(rewriter, loc, buffer1, token); 1036e37fc3ccSK-Wu token = genDeallocMemRef(rewriter, loc, buffer2, token); 1037e37fc3ccSK-Wu token = genDeallocMemRef(rewriter, loc, buffer3, token); 1038e37fc3ccSK-Wu token = genDeallocMemRef(rewriter, loc, matA, token); 1039e37fc3ccSK-Wu token = genDeallocMemRef(rewriter, loc, matB, token); 1040e37fc3ccSK-Wu token = genCopyMemRef(rewriter, loc, bufC, matC, token); 1041e37fc3ccSK-Wu token = genDeallocMemRef(rewriter, loc, matC, token); 1042e37fc3ccSK-Wu tokens.push_back(token); 1043e37fc3ccSK-Wu genBlockingWait(rewriter, loc, tokens); 104476a80a08SAart Bik tokens.clear(); 104576a80a08SAart Bik 104676a80a08SAart Bik // Done. 1047e37fc3ccSK-Wu rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC); 1048e37fc3ccSK-Wu return success(); 1049e37fc3ccSK-Wu } 1050e37fc3ccSK-Wu 10519167dd46SKun Wu /// Match and rewrite SDDMM kernel. 10525ef44679SAart Bik static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, 10535ef44679SAart Bik linalg::GenericOp op, bool enableRT) { 10549167dd46SKun Wu Location loc = op.getLoc(); 10559167dd46SKun Wu Value a = op.getOperand(0); 10569167dd46SKun Wu Value b = op.getOperand(1); 10579167dd46SKun Wu Value c = op.getOperand(2); 10589167dd46SKun Wu SmallVector<Value> tokens; 10599167dd46SKun Wu 10603231a365SAart Bik // Only admissible sparse matrix format (no COO/CSC) and dense matrices. 
10619167dd46SKun Wu SparseTensorType aTp = getSparseTensorType(a); 10629167dd46SKun Wu SparseTensorType bTp = getSparseTensorType(b); 10639167dd46SKun Wu SparseTensorType cTp = getSparseTensorType(c); 10643231a365SAart Bik auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false); 10653231a365SAart Bik if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO || 10663231a365SAart Bik format == CuSparseFormat::kCSC) 10679167dd46SKun Wu return failure(); 10689167dd46SKun Wu 10699167dd46SKun Wu // The SDDMM does the in-place operation. 10709167dd46SKun Wu // Start sparse kernel and copy data from host to device. 10719167dd46SKun Wu // a : bufA -> matA 10725ef44679SAart Bik // b : bufB -> matB 10739167dd46SKun Wu // c : memR/memC/memV -> rowC,colC,valC 10749167dd46SKun Wu Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c); 10759167dd46SKun Wu Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); 10769167dd46SKun Wu Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); 10779167dd46SKun Wu Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); 10789167dd46SKun Wu Value bufA = genTensorToMemref(rewriter, loc, a); 10795ef44679SAart Bik Value matA = genAllocCopy(rewriter, loc, bufA, tokens); 10809167dd46SKun Wu Value bufB = genTensorToMemref(rewriter, loc, b); 10815ef44679SAart Bik Value matB = genAllocCopy(rewriter, loc, bufB, tokens); 10823231a365SAart Bik Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT); 10835ef44679SAart Bik Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty 10841a0986f0SPeiming Liu Value memV = rewriter.create<ToValuesOp>(loc, c); 10859167dd46SKun Wu Value rowC = genAllocCopy(rewriter, loc, memR, tokens); 10869167dd46SKun Wu Value colC = memC ? 
genAllocCopy(rewriter, loc, memC, tokens) : Value(); 10879167dd46SKun Wu Value valC = genAllocCopy(rewriter, loc, memV, tokens); 10889167dd46SKun Wu genBlockingWait(rewriter, loc, tokens); 10899167dd46SKun Wu tokens.clear(); 10909167dd46SKun Wu 10919167dd46SKun Wu // Create sparse environment and sparse matrix/dense matrix handles. 10929167dd46SKun Wu Type indexTp = rewriter.getIndexType(); 10939167dd46SKun Wu Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); 10949167dd46SKun Wu Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 10959167dd46SKun Wu Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 10969167dd46SKun Wu Value token = genFirstWait(rewriter, loc); 10979167dd46SKun Wu auto dmatA = rewriter.create<gpu::CreateDnTensorOp>( 1098be2dd22bSKun Wu loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk}); 10999167dd46SKun Wu Value dnA = dmatA.getResult(0); 11009167dd46SKun Wu token = dmatA.getAsyncToken(); 11019167dd46SKun Wu auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( 1102be2dd22bSKun Wu loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn}); 11039167dd46SKun Wu Value dnB = dmatB.getResult(0); 11049167dd46SKun Wu token = dmatB.getAsyncToken(); 11059167dd46SKun Wu Operation *spGenC = 11063d89c088SAart Bik genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn, 11073d89c088SAart Bik nseC, rowC, colC, valC, format, enableRT); 11089167dd46SKun Wu Value spMatC = spGenC->getResult(0); 11099167dd46SKun Wu token = spGenC->getResult(1); 11109167dd46SKun Wu auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType(); 111176a80a08SAart Bik 11129167dd46SKun Wu // Precompute buffersize for SDDMM. 
11139167dd46SKun Wu auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>( 1114be2dd22bSKun Wu loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType); 11159167dd46SKun Wu Value bufferSz = bufferComp.getResult(0); 11169167dd46SKun Wu token = bufferComp.getAsyncToken(); 11179167dd46SKun Wu auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); 11189167dd46SKun Wu Value buffer = buf.getResult(0); 11199167dd46SKun Wu token = buf.getAsyncToken(); 11209167dd46SKun Wu 11219167dd46SKun Wu // Perform the SDDMM. 1122be2dd22bSKun Wu auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB, 1123be2dd22bSKun Wu spMatC, dnCType, buffer); 11249167dd46SKun Wu token = sddmmComp.getAsyncToken(); 11259167dd46SKun Wu 11269167dd46SKun Wu // Copy data back to host and free all the resoures. 11279167dd46SKun Wu token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA) 11289167dd46SKun Wu .getAsyncToken(); 11299167dd46SKun Wu token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) 11309167dd46SKun Wu .getAsyncToken(); 11319167dd46SKun Wu token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC) 11329167dd46SKun Wu .getAsyncToken(); 11339167dd46SKun Wu token = genDeallocMemRef(rewriter, loc, buffer, token); 11349167dd46SKun Wu token = genDeallocMemRef(rewriter, loc, matA, token); 11359167dd46SKun Wu token = genDeallocMemRef(rewriter, loc, matB, token); 11369167dd46SKun Wu token = genDeallocMemRef(rewriter, loc, rowC, token); 11379167dd46SKun Wu if (colC) 11389167dd46SKun Wu token = genDeallocMemRef(rewriter, loc, colC, token); 11399167dd46SKun Wu token = genCopyMemRef(rewriter, loc, memV, valC, token); 11409167dd46SKun Wu token = genDeallocMemRef(rewriter, loc, valC, token); 11419167dd46SKun Wu tokens.push_back(token); 11429167dd46SKun Wu genBlockingWait(rewriter, loc, tokens); 114376a80a08SAart Bik tokens.clear(); 11449167dd46SKun Wu 1145f14c8eb5SAart Bik // Done. 
11469167dd46SKun Wu rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c); 11479167dd46SKun Wu return success(); 11489167dd46SKun Wu } 11499167dd46SKun Wu 1150ee42e236SAart Bik //===----------------------------------------------------------------------===// 1151ee42e236SAart Bik // Rewriting rules for direct code generation. 1152ee42e236SAart Bik //===----------------------------------------------------------------------===// 1153ee42e236SAart Bik 1154ee42e236SAart Bik /// Proof-of-concept rewriter. This rule generates a GPU implementation 1155c43e6274STim Harvey /// for each outermost forall loop generated by the sparsifier. 115676a80a08SAart Bik /// TODO: right now works with parallelization-strategy=dense-outer-loop 115786888e42SAart Bik /// but give this its own flags in the future 115819466ebcSAart Bik struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> { 115919466ebcSAart Bik using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; 116019466ebcSAart Bik 116119466ebcSAart Bik ForallRewriter(MLIRContext *context, unsigned nT) 116219466ebcSAart Bik : OpRewritePattern(context), numThreads(nT){}; 116319466ebcSAart Bik 116419466ebcSAart Bik LogicalResult matchAndRewrite(scf::ParallelOp forallOp, 116519466ebcSAart Bik PatternRewriter &rewriter) const override { 116619466ebcSAart Bik // Reject inadmissible loop form. 1167c43e6274STim Harvey // Essentially only accept a loop, generated by the sparsifier, 116819466ebcSAart Bik // of the form 116919466ebcSAart Bik // forall (i = 0; i < N; i++) 117019466ebcSAart Bik // so that cyclic scheduling over the threads is easy. 
117119466ebcSAart Bik if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) || 117219466ebcSAart Bik forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 || 117319466ebcSAart Bik !matchPattern(forallOp.getLowerBound()[0], m_Zero()) || 117419466ebcSAart Bik !matchPattern(forallOp.getStep()[0], m_One())) 117519466ebcSAart Bik return failure(); 117619466ebcSAart Bik // Collect every value that is computed outside the parallel loop. 117719466ebcSAart Bik SetVector<Value> invariants; // stable iteration! 117819466ebcSAart Bik forallOp->walk([&](Operation *op) { 117919466ebcSAart Bik // Collect all values of admissible ops. 118019466ebcSAart Bik for (OpOperand &o : op->getOpOperands()) { 118119466ebcSAart Bik Value val = o.get(); 118219466ebcSAart Bik Block *block; 11835550c821STres Popp if (auto arg = dyn_cast<BlockArgument>(val)) 118419466ebcSAart Bik block = arg.getOwner(); 118519466ebcSAart Bik else 118619466ebcSAart Bik block = val.getDefiningOp()->getBlock(); 1187ea979b24SMatthias Springer if (!forallOp.getRegion().findAncestorBlockInRegion(*block)) 118819466ebcSAart Bik invariants.insert(val); 118919466ebcSAart Bik } 119019466ebcSAart Bik }); 119119466ebcSAart Bik // Outline the outside values as proper parameters. Fail when sharing 119219466ebcSAart Bik // value between host and device is not straightforward. 
119319466ebcSAart Bik SmallVector<Value> constants; 119419466ebcSAart Bik SmallVector<Value> scalars; 119519466ebcSAart Bik SmallVector<Value> buffers; 119619466ebcSAart Bik for (Value val : invariants) { 119719466ebcSAart Bik Type tp = val.getType(); 119819466ebcSAart Bik if (val.getDefiningOp<arith::ConstantOp>()) 119919466ebcSAart Bik constants.push_back(val); 12005550c821STres Popp else if (isa<FloatType>(tp) || tp.isIntOrIndex()) 120119466ebcSAart Bik scalars.push_back(val); 120219466ebcSAart Bik else if (isa<MemRefType>(tp)) 120319466ebcSAart Bik buffers.push_back(val); 120419466ebcSAart Bik else 120519466ebcSAart Bik return failure(); // don't know how to share 120619466ebcSAart Bik } 120786888e42SAart Bik // Pass outlined non-constant values. 120886888e42SAart Bik // TODO: Experiment with `useHostRegistrationForOut` to see if we want to 120986888e42SAart Bik // keep the feature at all (either through a heuristic or compiler 121086888e42SAart Bik // option for gpu codegen). 121119466ebcSAart Bik Location loc = forallOp->getLoc(); 121219466ebcSAart Bik SmallVector<Value> args; 121386888e42SAart Bik SmallVector<Value> tokens; 121486888e42SAart Bik Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens, 121586888e42SAart Bik /*useHostRegistrationForOut=*/false); 121619466ebcSAart Bik // Set up GPU module and construct GPU function. 121786888e42SAart Bik auto saveIp = rewriter.saveInsertionPoint(); 121819466ebcSAart Bik ModuleOp topModule = forallOp->getParentOfType<ModuleOp>(); 12194889214aSAart Bik auto gpuModule = genGPUModule(rewriter, topModule); 12204889214aSAart Bik auto gpuFunc = genGPUFunc(rewriter, gpuModule, args); 122119466ebcSAart Bik genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers); 122286888e42SAart Bik // Generate code that launches the kernel asynchronously, blocking on all 122386888e42SAart Bik // opens tokens and yielding a new token for the output. 
122486888e42SAart Bik // TODO: Passing in tokens to launch up does not seem to be properly lowered 122586888e42SAart Bik // by cubin yet, hence the current blocking wait. 122619466ebcSAart Bik rewriter.restoreInsertionPoint(saveIp); 122786888e42SAart Bik genBlockingWait(rewriter, loc, tokens); 122886888e42SAart Bik tokens.clear(); 122986888e42SAart Bik Value kernelToken = 123086888e42SAart Bik genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads); 123186888e42SAart Bik // Finalize the outlined arguments. 123286888e42SAart Bik genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args, 123386888e42SAart Bik tokens); 123486888e42SAart Bik genBlockingWait(rewriter, loc, tokens); 123519466ebcSAart Bik rewriter.eraseOp(forallOp); 123619466ebcSAart Bik return success(); 123719466ebcSAart Bik } 123819466ebcSAart Bik 123919466ebcSAart Bik private: 124019466ebcSAart Bik unsigned numThreads; 124119466ebcSAart Bik }; 124219466ebcSAart Bik 1243ee42e236SAart Bik //===----------------------------------------------------------------------===// 1244ee42e236SAart Bik // Rewriting rules for library recognition and code generation. 1245ee42e236SAart Bik //===----------------------------------------------------------------------===// 1246ee42e236SAart Bik 1247ee42e236SAart Bik /// Proof-of-concept rewriter. This rule recognizes certain math kernels 1248b75d6a40SAart Bik /// and replaces these with corresponding calls into a sparse library. 
1249ee42e236SAart Bik struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> { 1250ee42e236SAart Bik using OpRewritePattern<linalg::GenericOp>::OpRewritePattern; 1251ee42e236SAart Bik 12525ef44679SAart Bik LinalgOpRewriter(MLIRContext *context, bool rt) 12535ef44679SAart Bik : OpRewritePattern(context), enableRT(rt) {} 1254ee42e236SAart Bik 1255ee42e236SAart Bik LogicalResult matchAndRewrite(linalg::GenericOp op, 1256ee42e236SAart Bik PatternRewriter &rewriter) const override { 1257ee42e236SAart Bik if (op.getNumDpsInits() != 1) 1258ee42e236SAart Bik return failure(); // reject multi-output 1259ee42e236SAart Bik 1260ee42e236SAart Bik const unsigned numLoops = op.getNumLoops(); 1261ee42e236SAart Bik const unsigned numTensors = op->getNumOperands(); 1262ee42e236SAart Bik const auto iteratorTypes = op.getIteratorTypesArray(); 1263ee42e236SAart Bik SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray(); 1264ee42e236SAart Bik 1265ee42e236SAart Bik using MapList = ArrayRef<ArrayRef<AffineExpr>>; 1266fe8a62c4SUday Bondhugula auto infer = [&](MapList m) { 1267fe8a62c4SUday Bondhugula return AffineMap::inferFromExprList(m, op.getContext()); 1268fe8a62c4SUday Bondhugula }; 1269ee42e236SAart Bik AffineExpr i, j, k; 1270ee42e236SAart Bik bindDims(getContext(), i, j, k); 1271ee42e236SAart Bik 1272*aa295216SJay Foad // TODO: more robust patterns, transposed versions, more kernels, 127376a80a08SAart Bik // identify alpha and beta and pass them to the CUDA calls. 1274ee42e236SAart Bik 1275ee42e236SAart Bik // Recognize a SpMV kernel. 
1276ee42e236SAart Bik if (numLoops == 2 && numTensors == 3 && 1277ee42e236SAart Bik linalg::isParallelIterator(iteratorTypes[0]) && 1278ee42e236SAart Bik linalg::isReductionIterator(iteratorTypes[1]) && 1279ee42e236SAart Bik maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) { 12805ef44679SAart Bik return rewriteSpMV(rewriter, op, enableRT); 1281ee42e236SAart Bik } 1282ee42e236SAart Bik 128376a80a08SAart Bik // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel. 1284ee42e236SAart Bik if (numLoops == 3 && numTensors == 3 && 1285ee42e236SAart Bik linalg::isParallelIterator(iteratorTypes[0]) && 1286ee42e236SAart Bik linalg::isParallelIterator(iteratorTypes[1]) && 1287ee42e236SAart Bik linalg::isReductionIterator(iteratorTypes[2]) && 1288ee42e236SAart Bik maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) { 128976a80a08SAart Bik if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1))) 12905ef44679SAart Bik return rewriteSpGEMM(rewriter, op, enableRT); 129141a07e66SAart Bik if (isConversionInto24(op.getOperand(0))) 12925ef44679SAart Bik return rewrite2To4SpMM(rewriter, op); 12935ef44679SAart Bik return rewriteSpMM(rewriter, op, enableRT); 1294ee42e236SAart Bik } 1295ee42e236SAart Bik 12969167dd46SKun Wu // Recognize a SDDMM kernel. 
12979167dd46SKun Wu if (numLoops == 3 && numTensors == 3 && 12989167dd46SKun Wu linalg::isParallelIterator(iteratorTypes[0]) && 12999167dd46SKun Wu linalg::isParallelIterator(iteratorTypes[1]) && 13009167dd46SKun Wu linalg::isReductionIterator(iteratorTypes[2]) && 13019167dd46SKun Wu maps == infer({{i, k}, {k, j}, {i, j}}) && 13029167dd46SKun Wu matchSumReductionOfMulUnary(op)) { 13035ef44679SAart Bik return rewriteSDDMM(rewriter, op, enableRT); 13049167dd46SKun Wu } 13059167dd46SKun Wu 1306ee42e236SAart Bik return failure(); 1307ee42e236SAart Bik } 1308ee42e236SAart Bik 1309ee42e236SAart Bik private: 1310ee42e236SAart Bik bool enableRT; 1311ee42e236SAart Bik }; 1312ee42e236SAart Bik 131319466ebcSAart Bik } // namespace 131419466ebcSAart Bik 131519466ebcSAart Bik //===----------------------------------------------------------------------===// 131619466ebcSAart Bik // Public method for populating GPU rewriting rules. 1317ee42e236SAart Bik // 1318ee42e236SAart Bik // Currently two set of rewriting rules are made available. The first set 1319ee42e236SAart Bik // implements direct code generation, currently by means of convering the 1320ee42e236SAart Bik // outermost paralell loop into GPU threads. The second set implements 1321ee42e236SAart Bik // libary recognition of a set of sparse operations. Eventually, the right 1322ee42e236SAart Bik // combination of these two approaches has to be found. 
132319466ebcSAart Bik //===----------------------------------------------------------------------===// 132419466ebcSAart Bik 132519466ebcSAart Bik void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, 132619466ebcSAart Bik unsigned numThreads) { 132719466ebcSAart Bik patterns.add<ForallRewriter>(patterns.getContext(), numThreads); 132819466ebcSAart Bik } 1329ee42e236SAart Bik 13305ef44679SAart Bik void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns, 13315ef44679SAart Bik bool enableRT) { 13325ef44679SAart Bik patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT); 1333ee42e236SAart Bik } 1334