//===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a prototype GPU code generator for the sparsifier.
// The objective is to eventually use the right combination of
// direct code generation and library calls into vendor-specific
// highly optimized sparse libraries (e.g. cuSparse for CUDA).
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Sparse formats supported by cuSparse.
enum class CuSparseFormat {
  kNone,
  kCOO,
  kCSR,
  kCSC,
  kBSR,
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Marks the given top module as a GPU container module.
static void markAsGPUContainer(ModuleOp topModule) {
  topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
                     UnitAttr::get(topModule->getContext()));
}

/// Constructs a new GPU module (for GPU kernels) inside the given top module,
/// or returns an existing GPU module if one was built previously.
static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
  for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
    return op; // existing
  markAsGPUContainer(topModule);
  builder.setInsertionPointToStart(&topModule.getBodyRegion().front());
  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
                                          "sparse_kernels");
}

/// Constructs a new GPU kernel in the given GPU module.
static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
                                 SmallVectorImpl<Value> &args) {
  // Get a unique kernel name. Not very creative,
  // but we simply try kernel0, kernel1, etc.
  unsigned kernelNumber = 0;
  SmallString<16> kernelName;
  do {
    kernelName.clear();
    ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
  } while (gpuModule.lookupSymbol(kernelName));
  // Then we insert a new kernel with given arguments into the module.
  builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
  SmallVector<Type> argsTp;
  for (unsigned i = 0, e = args.size(); i < e; i++)
    argsTp.push_back(args[i].getType());
  FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
  auto gpuFunc =
      builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
  gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                   builder.getUnitAttr());
  return gpuFunc;
}

/// Constructs code to launch GPU kernel.
static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
                              SmallVectorImpl<Value> &args,
                              SmallVectorImpl<Value> &tokens,
                              unsigned numThreads) {
  Location loc = gpuFunc->getLoc();
  Value none = TypedValue<::mlir::IntegerType>{};
  Value one = constantIndex(builder, loc, 1);
  Value numT = constantIndex(builder, loc, numThreads);
  gpu::KernelDim3 gridSize = {one, one, one};
  gpu::KernelDim3 blckSize = {numT, one, one};
  return builder
      .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
                                 /*dynSharedMemSz*/ none, args,
                                 builder.getType<gpu::AsyncTokenType>(), tokens)
      .getAsyncToken();
}

/// Maps the provided ranked host buffer into the device address space.
/// Writes from the host are guaranteed to be visible to device kernels
/// that are launched afterwards. Writes from the device are guaranteed
/// to be visible on the host after synchronizing with the device kernel
/// completion. Needs to cast the buffer to an unranked buffer.
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
                                   Value mem) {
  MemRefType memTp = cast<MemRefType>(mem.getType());
  UnrankedMemRefType resTp =
      UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
  Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
  builder.create<gpu::HostRegisterOp>(loc, cast);
  return cast;
}

/// Unmaps the provided buffer, expecting the casted buffer.
static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
                                    Value cast) {
  builder.create<gpu::HostUnregisterOp>(loc, cast);
}

/// Generates first wait in an asynchronous chain.
static Value genFirstWait(OpBuilder &builder, Location loc) {
  Type tokenType = builder.getType<gpu::AsyncTokenType>();
  return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
      .getAsyncToken();
}

/// Generates last, blocking wait in an asynchronous chain.
static void genBlockingWait(OpBuilder &builder, Location loc,
                            ValueRange operands) {
  builder.create<gpu::WaitOp>(loc, Type(), operands);
}

/// Allocates memory on the device.
/// TODO: A `host_shared` attribute could be used to indicate that
/// the buffer is visible by both host and device, but lowering
/// that feature does not seem to be fully supported yet.
static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
                                   Value token) {
  auto tp = cast<ShapedType>(mem.getType());
  auto elemTp = tp.getElementType();
  auto shape = tp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
    if (shape[r] == ShapedType::kDynamic) {
      Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
      dynamicSizes.push_back(dimOp);
    }
  }
  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
                                      token, dynamicSizes, ValueRange());
}

// Allocates a typed buffer on the host with given size.
static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
                           Value size) {
  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
  return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
}

// Allocates a typed buffer on the device with given size.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
                                   Value size, Value token) {
  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
                                      token, size, ValueRange());
}

// Allocates a void buffer on the device with given size.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
                                   Value token) {
  return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
}

/// Deallocates memory from the device.
static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
                              Value token) {
  return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
      .getAsyncToken();
}

/// Copies memory between host and device (direction is implicit).
static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
                           Value src, Value token) {
  return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
      .getAsyncToken();
}

/// Generates an alloc/copy pair.
static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
                          SmallVectorImpl<Value> &tokens) {
  Value firstToken = genFirstWait(builder, loc);
  auto alloc = genAllocMemRef(builder, loc, b, firstToken);
  Value devMem = alloc.getResult(0);
  Value depToken = alloc.getAsyncToken(); // copy-after-alloc
  tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
  return devMem;
}

/// Generates a memref from tensor operation.
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
                               Value tensor) {
  auto tensorType = llvm::cast<ShapedType>(tensor.getType());
  auto memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
}

/// Prepares the outlined arguments, passing scalars and buffers in. Here we
/// assume that the first buffer is the one allocated for output. We create
/// a set of properly chained asynchronous allocation/copy pairs to increase
/// overlap before launching the kernel.
static Value genParametersIn(OpBuilder &builder, Location loc,
                             SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens,
                             bool useHostRegistrationForOut) {
  Value out;
  // Scalars are passed by value.
  for (Value s : scalars)
    args.push_back(s);
  // Buffers need to be made visible on the device.
  for (Value b : buffers) {
    if (useHostRegistrationForOut) {
      out = genHostRegisterMemref(builder, loc, b);
      args.push_back(b);
      useHostRegistrationForOut = false;
      continue;
    }
    args.push_back(genAllocCopy(builder, loc, b, tokens));
  }
  return out;
}

/// Finalizes the outlined arguments. The output buffer is copied depending
/// on the kernel token and then deallocated. All other buffers are simply
/// deallocated. Then we wait for all operations to complete.
static void genParametersOut(OpBuilder &builder, Location loc, Value out,
                             Value kernelToken, SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens) {
  unsigned base = scalars.size();
  for (unsigned i = base, e = args.size(); i < e; i++) {
    Value firstToken;
    if (i == base) {
      // Assumed output parameter: unregister or copy-out.
      if (out) {
        genHostUnregisterMemref(builder, loc, out);
        out = Value();
        continue;
      }
      firstToken =
          genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
    } else {
      firstToken = genFirstWait(builder, loc);
    }
    tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
  }
}

/// Constructs code for new GPU kernel.
static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
                       scf::ParallelOp forallOp,
                       SmallVectorImpl<Value> &constants,
                       SmallVectorImpl<Value> &scalars,
                       SmallVectorImpl<Value> &buffers) {
  Location loc = gpuFunc->getLoc();
  Block &block = gpuFunc.getBody().front();
  rewriter.setInsertionPointToStart(&block);

  // Re-generate the constants, recapture all arguments.
  unsigned arg = 0;
  IRMapping irMap;
  for (Value c : constants)
    irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
  for (Value s : scalars)
    irMap.map(s, block.getArgument(arg++));
  for (Value b : buffers)
    irMap.map(b, block.getArgument(arg++));

  // Assume 1-dimensional grid/block configuration (only x dimension),
  // so that:
  //   row = blockIdx.x * blockDim.x + threadIdx.x
  //   inc = blockDim.x * gridDim.x
  Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
  Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
  Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
  Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
  Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
  Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);

  // Construct the iteration over the computational space that
  // accounts for the fact that the total number of threads and
  // the amount of work to be done usually do not match precisely.
  //   for (r = row; r < N; r += inc) {
  //     <loop-body>
  //   }
  Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
  // The scf.for builder creates an empty block. scf.for does not allow multiple
  // blocks in its region, so delete the block before `cloneRegionBefore` adds
  // an additional block.
  rewriter.eraseBlock(forOp.getBody());
  rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
                             forOp.getRegion().begin(), irMap);

  // Done.
  rewriter.setInsertionPointAfter(forOp);
  rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
}

//===----------------------------------------------------------------------===//
// Library helper methods.
//===----------------------------------------------------------------------===//

/// Helper to detect a + b with arguments taken from given block.
static bool matchAddOfArgs(Block *block, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
      Value a = block->getArguments()[0];
      Value b = block->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}

/// Helper to detect a * b with arguments taken from given block.
static bool matchMulOfArgs(Block *block, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp, arith::MulIOp>(def)) {
      Value a = block->getArguments()[0];
      Value b = block->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}

/// Helper to detect x = x + a * b
static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments()[2];
      return (def->getOperand(0) == x &&
              matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
             (def->getOperand(1) == x &&
              matchMulOfArgs(op.getBlock(), def->getOperand(0)));
    }
  }
  return false;
}

// Helper to detect c += spy(s) x (a * b)
static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  // The linalg yields a custom reduce result.
  Value s_out = op.getBlock()->getArguments()[2];
  if (auto redOp =
          yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
    // The reduce consumes the output.
    Value other;
    if (s_out == redOp->getOperand(0))
      other = redOp->getOperand(1);
    else if (s_out == redOp->getOperand(1))
      other = redOp->getOperand(0);
    else
      return false;
    // The reduce op also consumes a unary which also consumes the output
    // and does not define an absent value.
    if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
      if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
        return false;
      // And the bodies are as expected.
      auto yieldUn = cast<sparse_tensor::YieldOp>(
          unOp.getRegion(0).front().getTerminator());
      auto yieldRed = cast<sparse_tensor::YieldOp>(
          redOp.getRegion().front().getTerminator());
      return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
             matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
    }
  }
  return false;
}

/// Test for dense tensor.
static bool isDenseTensor(Value v) {
  auto sTp = getSparseTensorType(v);
  return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
}

/// Test for suitable positions/coordinates width.
static bool isAdmissibleMetaData(SparseTensorType &aTp) {
  return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
}

/// Test for sorted COO matrix with suitable metadata.
static bool isAdmissibleCOO(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
         aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
         aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
         isAdmissibleMetaData(aTp);
}

/// Test for CSR matrix with suitable metadata.
static bool isAdmissibleCSR(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
         aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
         aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}

/// Test for CSC matrix with suitable metadata.
static bool isAdmissibleCSC(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
         aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
         aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}

/// Test for BSR matrix with suitable metadata.
static bool isAdmissibleBSR(SparseTensorType &aTp) {
  if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
      aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
      aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
    // CuSparse only supports "square" blocks currently.
    SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
    assert(dims.size() == 2);
    return dims[0] == dims[1] && dims[0] > 1;
  }
  return false;
}

/// Returns a suitable sparse format for the operation and given operand
/// types with cuSparse, or kNone if none is available.
static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
                                        SparseTensorType bTp,
                                        SparseTensorType cTp, bool enableRT,
                                        bool isMatVec) {
  // The other operands must have a dense type.
  if (bTp.hasEncoding() || cTp.hasEncoding())
    return CuSparseFormat::kNone;
  // Now check for a suitable type for the main (sparse) operand.
  if (isAdmissibleCOO(aTp))
#ifdef CUSPARSE_COO_AOS
    return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#else
    return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#endif
  if (isAdmissibleCSR(aTp))
    return CuSparseFormat::kCSR;
  if (isAdmissibleCSC(aTp))
    return CuSparseFormat::kCSC;
  if (isAdmissibleBSR(aTp))
    return CuSparseFormat::kBSR;
  return CuSparseFormat::kNone;
}

/// Generates the first positions/coordinates of a sparse matrix.
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
                               CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT)
      return genToCoordinates(builder, loc, a, 0);
    return genToCoordinatesBuffer(builder, loc, a);
  }
  // Formats CSR/CSC and BSR use positions at 1.
  return genToPositions(builder, loc, a, 1);
}

/// Generates the second coordinates of a sparse matrix.
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
                           CuSparseFormat format, bool enableRT) {
  bool isCOO = format == CuSparseFormat::kCOO;
  if (isCOO && !enableRT)
    return Value(); // nothing needed
  // Formats CSR/CSC and BSR use coordinates at 1.
  return genToCoordinates(builder, loc, a, 1);
}

/// Generates the sparse matrix handle.
static Operation *genSpMat(OpBuilder &builder, Location loc,
                           SparseTensorType &aTp, Type handleTp, Type tokenTp,
                           Value token, Value sz1, Value sz2, Value nseA,
                           Value rowA, Value colA, Value valA,
                           CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT) {
      assert(colA);
      return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
                                              sz1, sz2, nseA, rowA, colA, valA);
    }
#ifdef CUSPARSE_COO_AOS
    assert(!colA);
    return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token,
                                               sz1, sz2, nseA, rowA, valA);
#else
    llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
#endif
  }
  assert(colA);
  if (format == CuSparseFormat::kCSR)
    return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  if (format == CuSparseFormat::kCSC)
    return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  // BSR requires a bit more work since we need to pass in the block size
  // and all other sizes in terms of blocks (#block-rows, #block-cols,
  // #nonzero-blocks).
  assert(format == CuSparseFormat::kBSR);
  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
  assert(dims.size() == 2 && dims[0] == dims[1]);
  uint64_t b = dims[0];
  Value bSz = constantIndex(builder, loc, b);
  Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
  Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
  Value bNum = builder.create<arith::DivUIOp>(
      loc, nseA, constantIndex(builder, loc, b * b));
  return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
                                          bCols, bNum, bSz, bSz, rowA, colA,
                                          valA);
}

/// Match and rewrite SpMV kernel.
static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value x = op.getOperand(1);
  Value y = op.getOperand(2); // we have y = Ax
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format and dense vectors (no BSR).
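  // As an illustration only (a sketch, not required verbatim), a matrix
  // operand annotated with a CSR encoding such as
  //   #sparse_tensor.encoding<{
  //     map = (d0, d1) -> (d0 : dense, d1 : compressed),
  //     posWidth = 32, crdWidth = 32
  //   }>
  // passes the admissibility tests used by getCuSparseFormat(), whereas
  // position/coordinate widths narrower than 16 bits are rejected.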
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType xTp = getSparseTensorType(x);
  SparseTensorType yTp = getSparseTensorType(y);
  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Start sparse kernel and copy data from host to device.
  //   a : memR/memC/memV -> rowA,colA,valA
  //   x : memX -> vecX
  //   y : memY -> vecY
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
  Value memV = genToValues(rewriter, loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value memX = genTensorToMemref(rewriter, loc, x);
  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
  Value memY = genTensorToMemref(rewriter, loc, y);
  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense vector handles.
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
  Value dnX = dvecX.getResult(0);
  token = dvecX.getAsyncToken();
  auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
  Value dnY = dvecY.getResult(0);
  token = dvecY.getAsyncToken();
  auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();

  // Precompute buffersize for SpMV.
  auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
      /*computeType=*/dnYType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SpMV.
  auto spmvComp = rewriter.create<gpu::SpMVOp>(
      loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer);
  token = spmvComp.getAsyncToken();

  // Copy data back to host and free all the resources.
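  // Note that every destroy, dealloc, and copy below is chained on the token
  // of the previous operation, so the copy-out of vecY into memY can only
  // start after the SpMV and all preceding releases have completed.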
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, vecX, token);
  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
  token = genDeallocMemRef(rewriter, loc, vecY, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
  return success();
}

/// Match and rewrite SpMM kernel.
static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format and dense matrices (no BSR).
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Start sparse kernel and copy data from host to device.
  //   a : memR/memC/memV -> rowA,colA,valA
  //   b : bufB -> matB
  //   c : bufC -> matC
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
  Value memV = genToValues(rewriter, loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, c);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense matrix handles.
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matC,
      SmallVector<Value>{szm, szn});
  Value dnC = dmatC.getResult(0);
  token = dmatC.getAsyncToken();
  auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Precompute buffersize for SpMM.
  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
      /*computeType=*/dmatCType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();
  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Perform the SpMM.
  auto spmmComp = rewriter.create<gpu::SpMMOp>(
      loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer);
  token = spmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
  return success();
}

// Match and rewrite SpGEMM kernel.
static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
                                   linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // Only CSR <- CSR x CSR supported.
  auto format = CuSparseFormat::kCSR;
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
    return failure();

  // Start sparse kernel and copy data from host to device.
  //   a : amemR/amemC/amemV -> rowA,colA,valA
  //   b : bmemR/bmemC/bmemV -> rowB,colB,valB
  //   c : materializes
  auto dnCType = cTp.getElementType();
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
  Value amemV = genToValues(rewriter, loc, a);
  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
  Value bmemV = genToValues(rewriter, loc, b);
  Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
  Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
  Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
  Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
  Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
  Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense vector handles.
  Type indexTp = rewriter.getIndexType();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  Operation *spGenB =
      genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
               nseB, rowB, colB, valB, format, enableRT);
  Value spMatB = spGenB->getResult(0);
  token = spGenB->getResult(1);

  // Sparse matrix C materializes (also assumes beta == 0).
  Value zero = constantIndex(rewriter, loc, 0);
  Value one = constantIndex(rewriter, loc, 1);
  Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
  auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
  Value rowC = e1.getResult(0);
  token = e1.getAsyncToken();
  auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
  Value colC = e2.getResult(0); // no free needed
  token = e2.getAsyncToken();
  auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
  Value valC = e3.getResult(0); // no free needed
  token = e3.getAsyncToken();
  Operation *spGenC =
      genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
               zero, rowC, colC, valC, format, enableRT);
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);

  // Precompute buffersizes for SpGEMM.
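  // Note: both the work-estimation phase and the compute phase below follow a
  // two-call protocol; the first call (issued with a zero-sized buffer) merely
  // queries the required buffer size, and the second call performs the actual
  // phase with the buffer that is allocated in between.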
  Operation *descOp =
      rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
  Value desc = descOp->getResult(0);
  token = descOp->getResult(1);
  Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  Value bufferSz1 = work1->getResult(0);
  token = work1->getResult(1);
  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
  Value buffer1 = buf1.getResult(0);
  token = buf1.getAsyncToken();
  Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz1, buffer1,
      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  token = work2->getResult(1);

  // Compute step.
  Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  Value bufferSz2 = compute1->getResult(0);
  token = compute1->getResult(1);
  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
  Value buffer2 = buf2.getResult(0);
  token = buf2.getAsyncToken();
  Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  token = compute2->getResult(1);

  // Get sizes.
  Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
      loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
  Value nnz = sizes->getResult(2);
  token = sizes->getResult(3);
  auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
  colC = a2.getResult(0);
  token = a2.getAsyncToken();
  auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
  valC = a3.getResult(0);
  token = a3.getAsyncToken();

  // Update C with new pointers and copy final product back into C.
  Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
      loc, tokenTp, token, spMatC, rowC, colC, valC);
  token = update->getResult(0);
  Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
      loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
  token = copy->getResult(0);

  // Allocate buffers on host.
  Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
  Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
  Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
  token = genCopyMemRef(rewriter, loc, colH, colC, token);
  token = genCopyMemRef(rewriter, loc, valH, valC, token);
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, rowB, token);
  token = genDeallocMemRef(rewriter, loc, colB, token);
  token = genDeallocMemRef(rewriter, loc, valB, token);
  token = genDeallocMemRef(rewriter, loc, rowC, token);
  token = genDeallocMemRef(rewriter, loc, colC, token);
  token = genDeallocMemRef(rewriter, loc, valC, token);
  token = genDeallocMemRef(rewriter, loc, buffer1, token);
  token = genDeallocMemRef(rewriter, loc, buffer2, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
  Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
  Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
                                          ValueRange{rt, ct});
  return success();
}

// Match and rewrite 2:4 SpMM kernel.
static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
                                     linalg::GenericOp op) {
  Location loc = op.getLoc();
  Value A = op.getOperand(0);
  Value B = op.getOperand(1);
  Value C = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // All inputs should be dense tensors.
  if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
    return failure();

  // Start sparse kernel and copy data from host to device.
  //   a : bufA -> matA
  //   b : bufB -> matB
  //   c : bufC -> matC
  Value bufA = genTensorToMemref(rewriter, loc, A);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, B);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, C);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense vector handles.
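  // Note: operand A is wrapped below in a 2:4 structured sparse matrix handle,
  // i.e. at most two nonzero values in every group of four consecutive
  // elements; the PRUNE_AND_CHECK flag asks the library to prune the dense
  // input to that pattern and to verify the result.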
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
      loc, spMatHandleTp, tokenTp, token, szm, szk,
      gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matC,
      SmallVector<Value>{szm, szn});
  Value dnC = dmatC.getResult(0);
  token = dmatC.getAsyncToken();
  auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();

  // Precompute buffersize for SpMM.
  SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
  TypeRange bufferTypes(bufferTypes_);
  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
      loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
      /*computeType=*/dmatCType);
  token = bufferComp.getAsyncToken();

  // Allocate buffers on the device.
  Value bufferSz1 = bufferComp.getResult(0);
  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
  Value buffer1 = buf1.getResult(0);
  token = buf1.getAsyncToken();
  Value bufferSz2 = bufferComp.getResult(1);
  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
  Value buffer2 = buf2.getResult(0);
  token = buf2.getAsyncToken();
  Value bufferSz3 = bufferComp.getResult(2);
  auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
  Value buffer3 = buf3.getResult(0);
  token = buf3.getAsyncToken();

  // Perform the SpMM.
  auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
  auto spmmComp = rewriter.create<gpu::SpMMOp>(
      loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
      SmallVector<Value>{buffer1, buffer2, buffer3});
  token = spmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  SmallVector<Value> newDynamicSizes;
  token = genDeallocMemRef(rewriter, loc, buffer1, token);
  token = genDeallocMemRef(rewriter, loc, buffer2, token);
  token = genDeallocMemRef(rewriter, loc, buffer3, token);
  token = genDeallocMemRef(rewriter, loc, matA, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
  return success();
}

/// Match and rewrite SDDMM kernel.
static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
                                  linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2);
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
      format == CuSparseFormat::kCSC)
    return failure();

  // The SDDMM operates in place on the sparse output.
  // Start sparse kernel and copy data from host to device.
  //   a : bufA -> matA
  //   b : bufB -> matB
  //   c : memR/memC/memV -> rowC,colC,valC
  Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value bufA = genTensorToMemref(rewriter, loc, a);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
  Value memV = genToValues(rewriter, loc, c);
  Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
  Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valC = genAllocCopy(rewriter, loc, memV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense matrix handles.
  Type indexTp = rewriter.getIndexType();
  Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
  Value dnA = dmatA.getResult(0);
  token = dmatA.getAsyncToken();
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  Operation *spGenC =
      genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
               nseC, rowC, colC, valC, format, enableRT);
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);
  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Precompute buffersize for SDDMM.
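  // Recall that SDDMM only computes the product A*B at the nonzero positions
  // sampled by the sparse matrix C, so no dense intermediate product is ever
  // materialized; like the other cuSPARSE calls, it needs an explicitly sized
  // workspace, which is queried first and then allocated on the device.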
  auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
      loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SDDMM.
  auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
                                                 spMatC, dnCType, buffer);
  token = sddmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, matA, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genDeallocMemRef(rewriter, loc, rowC, token);
  if (colC)
    token = genDeallocMemRef(rewriter, loc, colC, token);
  token = genCopyMemRef(rewriter, loc, memV, valC, token);
  token = genDeallocMemRef(rewriter, loc, valC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
  return success();
}

//===----------------------------------------------------------------------===//
// Rewriting rules for direct code generation.
//===----------------------------------------------------------------------===//

/// Proof-of-concept rewriter. This rule generates a GPU implementation
/// for each outermost forall loop generated by the sparsifier.
/// TODO: right now works with parallelization-strategy=dense-outer-loop
///       but give this its own flags in the future
struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;

  ForallRewriter(MLIRContext *context, unsigned nT)
      : OpRewritePattern(context), numThreads(nT) {}

  LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
                                PatternRewriter &rewriter) const override {
    // Reject inadmissible loop form.
    // Essentially only accept a loop, generated by the sparsifier,
    // of the form
    //   forall (i = 0; i < N; i++)
    // so that cyclic scheduling over the threads is easy.
    if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
        forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
        !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
        !matchPattern(forallOp.getStep()[0], m_One()))
      return failure();
    // Collect every value that is computed outside the parallel loop.
    SetVector<Value> invariants; // stable iteration!
    forallOp->walk([&](Operation *op) {
      // Collect all values of admissible ops.
      for (OpOperand &o : op->getOpOperands()) {
        Value val = o.get();
        Block *block;
        if (auto arg = dyn_cast<BlockArgument>(val))
          block = arg.getOwner();
        else
          block = val.getDefiningOp()->getBlock();
        if (!forallOp.getRegion().findAncestorBlockInRegion(*block))
          invariants.insert(val);
      }
    });
    // Outline the outside values as proper parameters. Fail when sharing
    // value between host and device is not straightforward.
    SmallVector<Value> constants;
    SmallVector<Value> scalars;
    SmallVector<Value> buffers;
    for (Value val : invariants) {
      Type tp = val.getType();
      if (val.getDefiningOp<arith::ConstantOp>())
        constants.push_back(val);
      else if (isa<FloatType>(tp) || tp.isIntOrIndex())
        scalars.push_back(val);
      else if (isa<MemRefType>(tp))
        buffers.push_back(val);
      else
        return failure(); // don't know how to share
    }
    // Pass outlined non-constant values.
    // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
    //       keep the feature at all (either through a heuristic or compiler
    //       option for gpu codegen).
    Location loc = forallOp->getLoc();
    SmallVector<Value> args;
    SmallVector<Value> tokens;
    Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
                                /*useHostRegistrationForOut=*/false);
    // Set up GPU module and construct GPU function.
    auto saveIp = rewriter.saveInsertionPoint();
    ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
    auto gpuModule = genGPUModule(rewriter, topModule);
    auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
    genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
    // Generate code that launches the kernel asynchronously, blocking on all
    // open tokens and yielding a new token for the output.
    // TODO: Passing in tokens to the launch op does not seem to be properly
    //       lowered by cubin yet, hence the current blocking wait.
    rewriter.restoreInsertionPoint(saveIp);
    genBlockingWait(rewriter, loc, tokens);
    tokens.clear();
    Value kernelToken =
        genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
    // Finalize the outlined arguments.
    genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
                     tokens);
    genBlockingWait(rewriter, loc, tokens);
    rewriter.eraseOp(forallOp);
    return success();
  }

private:
  unsigned numThreads;
};

//===----------------------------------------------------------------------===//
// Rewriting rules for library recognition and code generation.
//===----------------------------------------------------------------------===//

/// Proof-of-concept rewriter. This rule recognizes certain math kernels
/// and replaces these with corresponding calls into a sparse library.
struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LinalgOpRewriter(MLIRContext *context, bool rt)
      : OpRewritePattern(context), enableRT(rt) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumDpsInits() != 1)
      return failure(); // reject multi-output

    const unsigned numLoops = op.getNumLoops();
    const unsigned numTensors = op->getNumOperands();
    const auto iteratorTypes = op.getIteratorTypesArray();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr i, j, k;
    bindDims(getContext(), i, j, k);

    // TODO: more robust patterns, transposed versions, more kernels,
    //       identify alpha and beta and pass them to the CUDA calls.
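    //
    // For reference, the SpMV case recognized below corresponds,
    // schematically (an illustrative sketch, not verbatim MLIR), to:
    //
    //   linalg.generic {
    //       indexing_maps = [(i,j) -> (i,j), (i,j) -> (j), (i,j) -> (i)],
    //       iterator_types = ["parallel", "reduction"] }
    //     ins(%A, %x) outs(%y) { ^bb0(%a, %b, %acc): yield %acc + %a * %b }
    //
    // where the first operand carries a sparse tensor encoding.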

    // Recognize a SpMV kernel.
    if (numLoops == 2 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isReductionIterator(iteratorTypes[1]) &&
        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
      return rewriteSpMV(rewriter, op, enableRT);
    }

    // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
      if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
        return rewriteSpGEMM(rewriter, op, enableRT);
      if (op->getAttr("DENSE24"))
        return rewrite2To4SpMM(rewriter, op);
      return rewriteSpMM(rewriter, op, enableRT);
    }

    // Recognize a SDDMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) &&
        matchSumReductionOfMulUnary(op)) {
      return rewriteSDDMM(rewriter, op, enableRT);
    }

    return failure();
  }

private:
  bool enableRT;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating GPU rewriting rules.
//
// Currently two sets of rewriting rules are made available. The first set
// implements direct code generation, currently by means of converting the
// outermost parallel loop into GPU threads. The second set implements
// library recognition of a set of sparse operations. Eventually, the right
// combination of these two approaches has to be found.
//===----------------------------------------------------------------------===//

void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                            unsigned numThreads) {
  patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
}

void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
                                           bool enableRT) {
  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
}