//===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a prototype GPU code generator for the sparsifier.
// The objective is to eventually use the right combination of
// direct code generation and library calls into vendor-specific
// highly optimized sparse libraries (e.g. cuSparse for CUDA).
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "LoopEmitter.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Sparse formats supported by cuSparse.
enum class CuSparseFormat {
  kNone,
  kCOO,
  kCSR,
  kCSC,
  kBSR,
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Marks the given top module as a GPU container module.
static void markAsGPUContainer(ModuleOp topModule) {
  topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
                     UnitAttr::get(topModule->getContext()));
}

/// Constructs a new GPU module (for GPU kernels) inside the given top module,
/// or returns an existing GPU module if one was built previously.
static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
  for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
    return op; // existing
  markAsGPUContainer(topModule);
  builder.setInsertionPointToStart(&topModule.getBodyRegion().front());
  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
                                          "sparse_kernels");
}

/// Constructs a new GPU kernel in the given GPU module.
static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
                                 SmallVectorImpl<Value> &args) {
  // Get a unique kernel name. Not very creative,
  // but we simply try kernel0, kernel1, etc.
  unsigned kernelNumber = 0;
  SmallString<16> kernelName;
  do {
    kernelName.clear();
    ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
  } while (gpuModule.lookupSymbol(kernelName));
  // Then we insert a new kernel with given arguments into the module.
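  // For example, the first outlined kernel typically ends up as a gpu.func
  // named "kernel0" inside the gpu.module "sparse_kernels" created above
  // (the names are illustrative; only uniqueness within the module matters).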
  builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
  SmallVector<Type> argsTp;
  for (unsigned i = 0, e = args.size(); i < e; i++)
    argsTp.push_back(args[i].getType());
  FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
  auto gpuFunc =
      builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
  gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                   builder.getUnitAttr());
  return gpuFunc;
}

/// Constructs code to launch GPU kernel.
static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
                              SmallVectorImpl<Value> &args,
                              SmallVectorImpl<Value> &tokens,
                              unsigned numThreads) {
  Location loc = gpuFunc->getLoc();
  Value none = TypedValue<::mlir::IntegerType>{};
  Value one = constantIndex(builder, loc, 1);
  Value numT = constantIndex(builder, loc, numThreads);
  gpu::KernelDim3 gridSize = {one, one, one};
  gpu::KernelDim3 blckSize = {numT, one, one};
  return builder
      .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
                                 /*dynSharedMemSz*/ none, args,
                                 builder.getType<gpu::AsyncTokenType>(), tokens)
      .getAsyncToken();
}

/// Maps the provided ranked host buffer into the device address space.
/// Writes from the host are guaranteed to be visible to device kernels
/// that are launched afterwards. Writes from the device are guaranteed
/// to be visible on the host after synchronizing with the device kernel
/// completion. Needs to cast the buffer to an unranked buffer.
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
                                   Value mem) {
  MemRefType memTp = cast<MemRefType>(mem.getType());
  UnrankedMemRefType resTp =
      UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
  Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
  builder.create<gpu::HostRegisterOp>(loc, cast);
  return cast;
}

/// Unmaps the provided buffer, expecting the casted buffer.
static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
                                    Value cast) {
  builder.create<gpu::HostUnregisterOp>(loc, cast);
}

/// Generates first wait in an asynchronous chain.
static Value genFirstWait(OpBuilder &builder, Location loc) {
  Type tokenType = builder.getType<gpu::AsyncTokenType>();
  return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
      .getAsyncToken();
}

/// Generates last, blocking wait in an asynchronous chain.
static void genBlockingWait(OpBuilder &builder, Location loc,
                            ValueRange operands) {
  builder.create<gpu::WaitOp>(loc, Type(), operands);
}

/// Allocates memory on the device.
/// TODO: A `host_shared` attribute could be used to indicate that
/// the buffer is visible by both host and device, but lowering
/// that feature does not seem to be fully supported yet.
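/// Until then, data movement between host and device is made explicit with
/// the gpu.memcpy operations generated by genCopyMemRef() below.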
146 static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem, 147 Value token) { 148 auto tp = cast<ShapedType>(mem.getType()); 149 auto elemTp = tp.getElementType(); 150 auto shape = tp.getShape(); 151 auto memTp = MemRefType::get(shape, elemTp); 152 SmallVector<Value> dynamicSizes; 153 for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) { 154 if (shape[r] == ShapedType::kDynamic) { 155 Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r); 156 dynamicSizes.push_back(dimOp); 157 } 158 } 159 return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}), 160 token, dynamicSizes, ValueRange()); 161 } 162 163 // Allocates a typed buffer on the host with given size. 164 static Value genHostBuffer(OpBuilder &builder, Location loc, Type type, 165 Value size) { 166 const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); 167 return builder.create<memref::AllocOp>(loc, memTp, size).getResult(); 168 } 169 170 // Allocates a typed buffer on the device with given size. 171 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type, 172 Value size, Value token) { 173 const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); 174 return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}), 175 token, size, ValueRange()); 176 } 177 178 // Allocates a void buffer on the device with given size. 179 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size, 180 Value token) { 181 return genAllocBuffer(builder, loc, builder.getI8Type(), size, token); 182 } 183 184 /// Deallocates memory from the device. 185 static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem, 186 Value token) { 187 return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem) 188 .getAsyncToken(); 189 } 190 191 /// Copies memory between host and device (direction is implicit). 192 static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst, 193 Value src, Value token) { 194 return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src) 195 .getAsyncToken(); 196 } 197 198 /// Generates an alloc/copy pair. 199 static Value genAllocCopy(OpBuilder &builder, Location loc, Value b, 200 SmallVectorImpl<Value> &tokens) { 201 Value firstToken = genFirstWait(builder, loc); 202 auto alloc = genAllocMemRef(builder, loc, b, firstToken); 203 Value devMem = alloc.getResult(0); 204 Value depToken = alloc.getAsyncToken(); // copy-after-alloc 205 tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken)); 206 return devMem; 207 } 208 209 /// Generates a memref from tensor operation. 210 static Value genTensorToMemref(PatternRewriter &rewriter, Location loc, 211 Value tensor) { 212 auto tensorType = llvm::cast<ShapedType>(tensor.getType()); 213 auto memrefType = 214 MemRefType::get(tensorType.getShape(), tensorType.getElementType()); 215 return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor); 216 } 217 218 /// Prepares the outlined arguments, passing scalars and buffers in. Here we 219 /// assume that the first buffer is the one allocated for output. We create 220 /// a set of properly chained asynchronous allocation/copy pairs to increase 221 /// overlap before launching the kernel. 
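/// A sketch of the asynchronous chain generated per buffer:
///   %t0 = gpu.wait async
///   %d, %t1 = gpu.alloc async [%t0] ...
///   %t2 = gpu.memcpy async [%t1] %d, %host
/// with the resulting tokens collected for a later blocking gpu.wait.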
static Value genParametersIn(OpBuilder &builder, Location loc,
                             SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens,
                             bool useHostRegistrationForOut) {
  Value out;
  // Scalars are passed by value.
  for (Value s : scalars)
    args.push_back(s);
  // Buffers need to be made visible on the device.
  for (Value b : buffers) {
    if (useHostRegistrationForOut) {
      out = genHostRegisterMemref(builder, loc, b);
      args.push_back(b);
      useHostRegistrationForOut = false;
      continue;
    }
    args.push_back(genAllocCopy(builder, loc, b, tokens));
  }
  return out;
}

/// Finalizes the outlined arguments. The output buffer is copied depending
/// on the kernel token and then deallocated. All other buffers are simply
/// deallocated. Then we wait for all operations to complete.
static void genParametersOut(OpBuilder &builder, Location loc, Value out,
                             Value kernelToken, SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens) {
  unsigned base = scalars.size();
  for (unsigned i = base, e = args.size(); i < e; i++) {
    Value firstToken;
    if (i == base) {
      // Assumed output parameter: unregister or copy-out.
      if (out) {
        genHostUnregisterMemref(builder, loc, out);
        out = Value();
        continue;
      }
      firstToken =
          genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
    } else {
      firstToken = genFirstWait(builder, loc);
    }
    tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
  }
}

/// Constructs code for new GPU kernel.
static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
                       scf::ParallelOp forallOp,
                       SmallVectorImpl<Value> &constants,
                       SmallVectorImpl<Value> &scalars,
                       SmallVectorImpl<Value> &buffers) {
  Location loc = gpuFunc->getLoc();
  Block &block = gpuFunc.getBody().front();
  rewriter.setInsertionPointToStart(&block);

  // Re-generate the constants, recapture all arguments.
  unsigned arg = 0;
  IRMapping irMap;
  for (Value c : constants)
    irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
  for (Value s : scalars)
    irMap.map(s, block.getArgument(arg++));
  for (Value b : buffers)
    irMap.map(b, block.getArgument(arg++));

  // Assume 1-dimensional grid/block configuration (only x dimension),
  // so that:
  //   row = blockIdx.x * blockDim.x + threadIdx.x
  //   inc = blockDim.x * gridDim.x
  Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
  Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
  Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
  Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
  Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
  Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);

  // Construct the iteration over the computational space that
  // accounts for the fact that the total number of threads and
  // the amount of work to be done usually do not match precisely.
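  // The kernel therefore uses a standard grid-stride loop: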
  //   for (r = row; r < N; r += inc) {
  //     <loop-body>
  //   }
  Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
  rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
                             forOp.getRegion().begin(), irMap);

  // Done.
  rewriter.setInsertionPointAfter(forOp);
  rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
}

//===----------------------------------------------------------------------===//
// Library helper methods.
//===----------------------------------------------------------------------===//

/// Helper to detect a + b with arguments taken from given block.
static bool matchAddOfArgs(Block *block, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
      Value a = block->getArguments()[0];
      Value b = block->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}

/// Helper to detect a * b with arguments taken from given block.
static bool matchMulOfArgs(Block *block, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp, arith::MulIOp>(def)) {
      Value a = block->getArguments()[0];
      Value b = block->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}

/// Helper to detect x = x + a * b
static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments()[2];
      return (def->getOperand(0) == x &&
              matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
             (def->getOperand(1) == x &&
              matchMulOfArgs(op.getBlock(), def->getOperand(0)));
    }
  }
  return false;
}

// Helper to detect c += spy(s) x (a * b)
static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  // The linalg op yields a custom reduce result.
  Value s_out = op.getBlock()->getArguments()[2];
  if (auto redOp =
          yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
    // The reduce consumes the output.
    Value other;
    if (s_out == redOp->getOperand(0))
      other = redOp->getOperand(1);
    else if (s_out == redOp->getOperand(1))
      other = redOp->getOperand(0);
    else
      return false;
    // The reduce op also consumes a unary op that consumes the output
    // as well and does not define an absent value.
    if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
      if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
        return false;
      // And the bodies are as expected.
      auto yieldUn = cast<sparse_tensor::YieldOp>(
          unOp.getRegion(0).front().getTerminator());
      auto yieldRed = cast<sparse_tensor::YieldOp>(
          redOp.getRegion().front().getTerminator());
      return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
             matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
    }
  }
  return false;
}

/// Test for dense tensor.
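/// (same dimension and level rank, with all levels dense).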
398 static bool isDenseTensor(Value v) { 399 auto sTp = getSparseTensorType(v); 400 return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense(); 401 } 402 403 /// Test for suitable positions/coordinates width. 404 static bool isAdmissibleMetaData(SparseTensorType &aTp) { 405 return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) && 406 (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16); 407 } 408 409 /// Test for sorted COO matrix with suitable metadata. 410 static bool isAdmissibleCOO(SparseTensorType &aTp) { 411 return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() && 412 aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) && 413 aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && 414 isAdmissibleMetaData(aTp); 415 } 416 417 /// Test for CSR matrix with suitable metadata. 418 static bool isAdmissibleCSR(SparseTensorType &aTp) { 419 return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() && 420 aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && 421 aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp); 422 } 423 424 /// Test for CSC matrix with suitable metadata. 425 static bool isAdmissibleCSC(SparseTensorType &aTp) { 426 return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() && 427 aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && 428 aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp); 429 } 430 431 /// Test for BSR matrix with suitable metadata. 432 static bool isAdmissibleBSR(SparseTensorType &aTp) { 433 if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) && 434 aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && 435 aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) { 436 // CuSparse only supports "square" blocks currently. 437 SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl()); 438 assert(dims.size() == 2); 439 return dims[0] == dims[1] && dims[0] > 1; 440 } 441 return false; 442 } 443 444 /// Returns a suitable sparse format for the operation and given operand 445 /// types with cuSparse, or kNone if none is available. 446 static CuSparseFormat getCuSparseFormat(SparseTensorType aTp, 447 SparseTensorType bTp, 448 SparseTensorType cTp, bool enableRT, 449 bool isMatVec) { 450 // The other operands have a dense type. 451 if (bTp.hasEncoding() || cTp.hasEncoding()) 452 return CuSparseFormat::kNone; 453 // Now check for suitable operand type for the main operand. 454 if (isAdmissibleCOO(aTp)) 455 #ifdef CUSPARSE_COO_AOS 456 return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone; 457 #else 458 return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone; 459 #endif 460 if (isAdmissibleCSR(aTp)) 461 return CuSparseFormat::kCSR; 462 if (isAdmissibleCSC(aTp)) 463 return CuSparseFormat::kCSC; 464 if (isAdmissibleBSR(aTp)) 465 return CuSparseFormat::kBSR; 466 return CuSparseFormat::kNone; 467 } 468 469 /// Generates the first positions/coordinates of a sparse matrix. 470 static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a, 471 CuSparseFormat format, bool enableRT) { 472 if (format == CuSparseFormat::kCOO) { 473 // Library uses SoA COO, direct IR uses AoS COO. 474 if (enableRT) 475 return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0); 476 return genToCoordinatesBuffer(builder, loc, a); 477 } 478 // Formats CSR/CSC and BSR use positions at 1. 
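  // (level 0 is dense in these formats and has no stored positions).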
  return genToPositions(builder, loc, a, 1);
}

/// Generates the second coordinates of a sparse matrix.
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
                           CuSparseFormat format, bool enableRT) {
  bool isCOO = format == CuSparseFormat::kCOO;
  if (isCOO && !enableRT)
    return Value(); // nothing needed
  // Formats CSR/CSC and BSR use coordinates at 1.
  return genToCoordinates(builder, loc, a, 1, /*cooStart=*/isCOO ? 0 : 2);
}

/// Generates the sparse matrix handle.
static Operation *genSpMat(OpBuilder &builder, Location loc,
                           SparseTensorType &aTp, Type handleTp, Type tokenTp,
                           Value token, Value sz1, Value sz2, Value nseA,
                           Value rowA, Value colA, Value valA,
                           CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT) {
      assert(colA);
      return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
                                              sz1, sz2, nseA, rowA, colA, valA);
    }
#ifdef CUSPARSE_COO_AOS
    assert(!colA);
    return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token,
                                               sz1, sz2, nseA, rowA, valA);
#else
    llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
#endif
  }
  assert(colA);
  if (format == CuSparseFormat::kCSR)
    return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  if (format == CuSparseFormat::kCSC)
    return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  // BSR requires a bit more work since we need to pass in the block size
  // and all other sizes in terms of blocks (#block-rows, #block-cols,
  // #nonzero-blocks).
  assert(format == CuSparseFormat::kBSR);
  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
  assert(dims.size() == 2 && dims[0] == dims[1]);
  uint64_t b = dims[0];
  Value bSz = constantIndex(builder, loc, b);
  Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
  Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
  Value bNum = builder.create<arith::DivUIOp>(
      loc, nseA, constantIndex(builder, loc, b * b));
  return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
                                          bCols, bNum, bSz, bSz, rowA, colA,
                                          valA);
}

/// Match and rewrite SpMV kernel.
static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value x = op.getOperand(1);
  Value y = op.getOperand(2); // we have y = Ax
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format and dense vectors (no BSR).
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType xTp = getSparseTensorType(x);
  SparseTensorType yTp = getSparseTensorType(y);
  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Start sparse kernel and copy data from host to device.
  // a : memR/memC/memV -> rowA,colA,valA
  // x : memX -> vecX
  // y : memY -> vecY
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
  Value memV = genToValues(rewriter, loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value memX = genTensorToMemref(rewriter, loc, x);
  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
  Value memY = genTensorToMemref(rewriter, loc, y);
  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense vector handles.
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
  Value dnX = dvecX.getResult(0);
  token = dvecX.getAsyncToken();
  auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
  Value dnY = dvecY.getResult(0);
  token = dvecY.getAsyncToken();
  auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();

  // Precompute buffersize for SpMV.
  auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
      /*computeType=*/dnYType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SpMV.
  auto spmvComp = rewriter.create<gpu::SpMVOp>(
      loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer);
  token = spmvComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, vecX, token);
  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
  token = genDeallocMemRef(rewriter, loc, vecY, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
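  // The updated host buffer for y becomes the tensor result of the kernel.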
630 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY); 631 return success(); 632 } 633 634 /// Match and rewrite SpMM kernel. 635 static LogicalResult rewriteSpMM(PatternRewriter &rewriter, 636 linalg::GenericOp op, bool enableRT) { 637 Location loc = op.getLoc(); 638 Value a = op.getOperand(0); 639 Value b = op.getOperand(1); 640 Value c = op.getOperand(2); // we have C = AB 641 SmallVector<Value> tokens; 642 643 // Only admissible sparse matrix format and dense matrices (no BSR). 644 SparseTensorType aTp = getSparseTensorType(a); 645 SparseTensorType bTp = getSparseTensorType(b); 646 SparseTensorType cTp = getSparseTensorType(c); 647 auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false); 648 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR) 649 return failure(); 650 651 // Start sparse kernel and copy data from host to device. 652 // a : memR/memC/memV -> rowA,colA,valA 653 // b : bufB -> matB 654 // c : bufC -> matC 655 Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); 656 Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); 657 Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); 658 Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); 659 Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); 660 Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty 661 Value memV = genToValues(rewriter, loc, a); 662 Value rowA = genAllocCopy(rewriter, loc, memR, tokens); 663 Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); 664 Value valA = genAllocCopy(rewriter, loc, memV, tokens); 665 Value bufB = genTensorToMemref(rewriter, loc, b); 666 Value matB = genAllocCopy(rewriter, loc, bufB, tokens); 667 Value bufC = genTensorToMemref(rewriter, loc, c); 668 Value matC = genAllocCopy(rewriter, loc, bufC, tokens); 669 genBlockingWait(rewriter, loc, tokens); 670 tokens.clear(); 671 672 // Create sparse environment and sparse matrix/dense matrix handles. 673 Type indexTp = rewriter.getIndexType(); 674 Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); 675 Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 676 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 677 Value token = genFirstWait(rewriter, loc); 678 Operation *spGenA = 679 genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk, 680 nseA, rowA, colA, valA, format, enableRT); 681 Value spMatA = spGenA->getResult(0); 682 token = spGenA->getResult(1); 683 auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( 684 loc, dnTensorHandleTp, tokenTp, token, matB, 685 SmallVector<Value>{szk, szn}); 686 Value dnB = dmatB.getResult(0); 687 token = dmatB.getAsyncToken(); 688 auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( 689 loc, dnTensorHandleTp, tokenTp, token, matC, 690 SmallVector<Value>{szm, szn}); 691 Value dnC = dmatC.getResult(0); 692 token = dmatC.getAsyncToken(); 693 auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType(); 694 695 // Precompute buffersize for SpMM. 696 auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( 697 loc, indexTp, tokenTp, token, spMatA, dnB, dnC, 698 /*computeType=*/dmatCType); 699 Value bufferSz = bufferComp.getResult(0); 700 token = bufferComp.getAsyncToken(); 701 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); 702 Value buffer = buf.getResult(0); 703 token = buf.getAsyncToken(); 704 auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType(); 705 706 // Perform the SpMM. 
  auto spmmComp = rewriter.create<gpu::SpMMOp>(
      loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer);
  token = spmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
  return success();
}

// Match and rewrite SpGEMM kernel.
static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
                                   linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // Only CSR <- CSR x CSR supported.
  auto format = CuSparseFormat::kCSR;
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
    return failure();

  // Start sparse kernel and copy data from host to device.
  // a : amemR/amemC/amemV -> rowA,colA,valA
  // b : bmemR/bmemC/bmemV -> rowB,colB,valB
  // c : materializes
  auto dnCType = cTp.getElementType();
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
  Value amemV = genToValues(rewriter, loc, a);
  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
  Value bmemV = genToValues(rewriter, loc, b);
  Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
  Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
  Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
  Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
  Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
  Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix handles.
778 Type indexTp = rewriter.getIndexType(); 779 Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 780 Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>(); 781 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 782 Value token = genFirstWait(rewriter, loc); 783 Operation *spGenA = 784 genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk, 785 nseA, rowA, colA, valA, format, enableRT); 786 Value spMatA = spGenA->getResult(0); 787 token = spGenA->getResult(1); 788 Operation *spGenB = 789 genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn, 790 nseB, rowB, colB, valB, format, enableRT); 791 Value spMatB = spGenB->getResult(0); 792 token = spGenB->getResult(1); 793 794 // Sparse matrix C materializes (also assumes beta == 0). 795 Value zero = constantIndex(rewriter, loc, 0); 796 Value one = constantIndex(rewriter, loc, 1); 797 Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one); 798 auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token); 799 Value rowC = e1.getResult(0); 800 token = e1.getAsyncToken(); 801 auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token); 802 Value colC = e2.getResult(0); // no free needed 803 token = e2.getAsyncToken(); 804 auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token); 805 Value valC = e3.getResult(0); // no free needed 806 token = e3.getAsyncToken(); 807 Operation *spGenC = 808 genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn, 809 zero, rowC, colC, valC, format, enableRT); 810 Value spMatC = spGenC->getResult(0); 811 token = spGenC->getResult(1); 812 813 // Precompute buffersizes for SpGEMM. 814 Operation *descOp = 815 rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token); 816 Value desc = descOp->getResult(0); 817 token = descOp->getResult(1); 818 Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( 819 loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, 820 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero, 821 valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION); 822 Value bufferSz1 = work1->getResult(0); 823 token = work1->getResult(1); 824 auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token); 825 Value buffer1 = buf1.getResult(0); 826 token = buf1.getAsyncToken(); 827 Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( 828 loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, 829 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, 830 bufferSz1, buffer1, 831 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION); 832 token = work2->getResult(1); 833 834 // Compute step. 
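  // Like the work estimation, the compute step is issued twice: first with a
  // zero-sized buffer to query the required buffer size, then with the
  // allocated buffer to perform the actual computation.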
  Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  Value bufferSz2 = compute1->getResult(0);
  token = compute1->getResult(1);
  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
  Value buffer2 = buf2.getResult(0);
  token = buf2.getAsyncToken();
  Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  token = compute2->getResult(1);

  // Get sizes.
  Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
      loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
  Value nnz = sizes->getResult(2);
  token = sizes->getResult(3);
  auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
  colC = a2.getResult(0);
  token = a2.getAsyncToken();
  auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
  valC = a3.getResult(0);
  token = a3.getAsyncToken();

  // Update C with new pointers and copy final product back into C.
  Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
      loc, tokenTp, token, spMatC, rowC, colC, valC);
  token = update->getResult(0);
  Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
      loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
  token = copy->getResult(0);

  // Allocate buffers on host.
  Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
  Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
  Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
  token = genCopyMemRef(rewriter, loc, colH, colC, token);
  token = genCopyMemRef(rewriter, loc, valH, valC, token);
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, rowB, token);
  token = genDeallocMemRef(rewriter, loc, colB, token);
  token = genDeallocMemRef(rewriter, loc, valB, token);
  token = genDeallocMemRef(rewriter, loc, rowC, token);
  token = genDeallocMemRef(rewriter, loc, colC, token);
  token = genDeallocMemRef(rewriter, loc, valC, token);
  token = genDeallocMemRef(rewriter, loc, buffer1, token);
  token = genDeallocMemRef(rewriter, loc, buffer2, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
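  // Assemble the CSR output from the individual position, coordinate, and
  // value buffers that were copied back to the host.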
  Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
  Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
  Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
                                          ValueRange{rt, ct});
  return success();
}

// Match and rewrite 2:4 SpMM kernel.
static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
                                     linalg::GenericOp op) {
  Location loc = op.getLoc();
  Value A = op.getOperand(0);
  Value B = op.getOperand(1);
  Value C = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // All input should be dense tensors.
  if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
    return failure();

  // Start sparse kernel and copy data from host to device.
  // a : bufA -> matA
  // b : bufB -> matB
  // c : bufC -> matC
  Value bufA = genTensorToMemref(rewriter, loc, A);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, B);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, C);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense matrix handles.
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
      loc, spMatHandleTp, tokenTp, token, szm, szk,
      gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matC,
      SmallVector<Value>{szm, szn});
  Value dnC = dmatC.getResult(0);
  token = dmatC.getAsyncToken();
  auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();

  // Precompute buffersize for SpMM.
  SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
  TypeRange bufferTypes(bufferTypes_);
  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
      loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
      /*computeType=*/dmatCType);
  token = bufferComp.getAsyncToken();

  // Allocate temporary buffers on the device.
974 Value bufferSz1 = bufferComp.getResult(0); 975 auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token); 976 Value buffer1 = buf1.getResult(0); 977 token = buf1.getAsyncToken(); 978 Value bufferSz2 = bufferComp.getResult(1); 979 auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token); 980 Value buffer2 = buf2.getResult(0); 981 token = buf2.getAsyncToken(); 982 Value bufferSz3 = bufferComp.getResult(2); 983 auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token); 984 Value buffer3 = buf3.getResult(0); 985 token = buf3.getAsyncToken(); 986 987 // Perform the SpMM. 988 auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType(); 989 auto spmmComp = rewriter.create<gpu::SpMMOp>( 990 loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, 991 SmallVector<Value>{buffer1, buffer2, buffer3}); 992 token = spmmComp.getAsyncToken(); 993 994 // Copy data back to host and free all the resources. 995 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) 996 .getAsyncToken(); 997 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) 998 .getAsyncToken(); 999 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC) 1000 .getAsyncToken(); 1001 SmallVector<Value> newDynamicSizes; 1002 token = genDeallocMemRef(rewriter, loc, buffer1, token); 1003 token = genDeallocMemRef(rewriter, loc, buffer2, token); 1004 token = genDeallocMemRef(rewriter, loc, buffer3, token); 1005 token = genDeallocMemRef(rewriter, loc, matA, token); 1006 token = genDeallocMemRef(rewriter, loc, matB, token); 1007 token = genCopyMemRef(rewriter, loc, bufC, matC, token); 1008 token = genDeallocMemRef(rewriter, loc, matC, token); 1009 tokens.push_back(token); 1010 genBlockingWait(rewriter, loc, tokens); 1011 tokens.clear(); 1012 1013 // Done. 1014 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC); 1015 return success(); 1016 } 1017 1018 /// Match and rewrite SDDMM kernel. 1019 static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, 1020 linalg::GenericOp op, bool enableRT) { 1021 Location loc = op.getLoc(); 1022 Value a = op.getOperand(0); 1023 Value b = op.getOperand(1); 1024 Value c = op.getOperand(2); 1025 SmallVector<Value> tokens; 1026 1027 // Only admissible sparse matrix format (no COO/CSC) and dense matrices. 1028 SparseTensorType aTp = getSparseTensorType(a); 1029 SparseTensorType bTp = getSparseTensorType(b); 1030 SparseTensorType cTp = getSparseTensorType(c); 1031 auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false); 1032 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO || 1033 format == CuSparseFormat::kCSC) 1034 return failure(); 1035 1036 // The SDDMM does the in-place operation. 1037 // Start sparse kernel and copy data from host to device. 
  // a : bufA -> matA
  // b : bufB -> matB
  // c : memR/memC/memV -> rowC,colC,valC
  Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value bufA = genTensorToMemref(rewriter, loc, a);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
  Value memV = genToValues(rewriter, loc, c);
  Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
  Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valC = genAllocCopy(rewriter, loc, memV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense matrix handles.
  Type indexTp = rewriter.getIndexType();
  Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
  Value dnA = dmatA.getResult(0);
  token = dmatA.getAsyncToken();
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  Operation *spGenC =
      genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
               nseC, rowC, colC, valC, format, enableRT);
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);
  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Precompute buffersize for SDDMM.
  auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
      loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SDDMM.
  auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
                                                 spMatC, dnCType, buffer);
  token = sddmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
1094 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA) 1095 .getAsyncToken(); 1096 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) 1097 .getAsyncToken(); 1098 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC) 1099 .getAsyncToken(); 1100 token = genDeallocMemRef(rewriter, loc, buffer, token); 1101 token = genDeallocMemRef(rewriter, loc, matA, token); 1102 token = genDeallocMemRef(rewriter, loc, matB, token); 1103 token = genDeallocMemRef(rewriter, loc, rowC, token); 1104 if (colC) 1105 token = genDeallocMemRef(rewriter, loc, colC, token); 1106 token = genCopyMemRef(rewriter, loc, memV, valC, token); 1107 token = genDeallocMemRef(rewriter, loc, valC, token); 1108 tokens.push_back(token); 1109 genBlockingWait(rewriter, loc, tokens); 1110 tokens.clear(); 1111 1112 // Done. 1113 rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c); 1114 return success(); 1115 } 1116 1117 //===----------------------------------------------------------------------===// 1118 // Rewriting rules for direct code generation. 1119 //===----------------------------------------------------------------------===// 1120 1121 /// Proof-of-concept rewriter. This rule generates a GPU implementation 1122 /// for each outermost forall loop generated by the sparsifier. 1123 /// TODO: right now works with parallelization-strategy=dense-outer-loop 1124 /// but give this its own flags in the future 1125 struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> { 1126 using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; 1127 1128 ForallRewriter(MLIRContext *context, unsigned nT) 1129 : OpRewritePattern(context), numThreads(nT){}; 1130 1131 LogicalResult matchAndRewrite(scf::ParallelOp forallOp, 1132 PatternRewriter &rewriter) const override { 1133 // Reject inadmissible loop form. 1134 // Essentially only accept a loop, generated by the sparsifier, 1135 // of the form 1136 // forall (i = 0; i < N; i++) 1137 // so that cyclic scheduling over the threads is easy. 1138 if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) || 1139 forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 || 1140 !matchPattern(forallOp.getLowerBound()[0], m_Zero()) || 1141 !matchPattern(forallOp.getStep()[0], m_One())) 1142 return failure(); 1143 // Collect every value that is computed outside the parallel loop. 1144 SetVector<Value> invariants; // stable iteration! 1145 forallOp->walk([&](Operation *op) { 1146 // Collect all values of admissible ops. 1147 for (OpOperand &o : op->getOpOperands()) { 1148 Value val = o.get(); 1149 Block *block; 1150 if (auto arg = dyn_cast<BlockArgument>(val)) 1151 block = arg.getOwner(); 1152 else 1153 block = val.getDefiningOp()->getBlock(); 1154 if (!isNestedIn(block, forallOp)) 1155 invariants.insert(val); 1156 } 1157 }); 1158 // Outline the outside values as proper parameters. Fail when sharing 1159 // value between host and device is not straightforward. 1160 SmallVector<Value> constants; 1161 SmallVector<Value> scalars; 1162 SmallVector<Value> buffers; 1163 for (Value val : invariants) { 1164 Type tp = val.getType(); 1165 if (val.getDefiningOp<arith::ConstantOp>()) 1166 constants.push_back(val); 1167 else if (isa<FloatType>(tp) || tp.isIntOrIndex()) 1168 scalars.push_back(val); 1169 else if (isa<MemRefType>(tp)) 1170 buffers.push_back(val); 1171 else 1172 return failure(); // don't know how to share 1173 } 1174 // Pass outlined non-constant values. 
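    // Scalars are passed by value and buffers through the asynchronous
    // device copies created by genParametersIn() below.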
    // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
    // keep the feature at all (either through a heuristic or compiler
    // option for gpu codegen).
    Location loc = forallOp->getLoc();
    SmallVector<Value> args;
    SmallVector<Value> tokens;
    Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
                                /*useHostRegistrationForOut=*/false);
    // Set up GPU module and construct GPU function.
    auto saveIp = rewriter.saveInsertionPoint();
    ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
    auto gpuModule = genGPUModule(rewriter, topModule);
    auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
    genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
    // Generate code that launches the kernel asynchronously, blocking on all
    // open tokens and yielding a new token for the output.
    // TODO: Passing tokens into the launch op does not seem to be properly
    // lowered by cubin yet, hence the current blocking wait.
    rewriter.restoreInsertionPoint(saveIp);
    genBlockingWait(rewriter, loc, tokens);
    tokens.clear();
    Value kernelToken =
        genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
    // Finalize the outlined arguments.
    genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
                     tokens);
    genBlockingWait(rewriter, loc, tokens);
    rewriter.eraseOp(forallOp);
    return success();
  }

private:
  // Helper method to see if block appears in given loop.
  static bool isNestedIn(Block *block, scf::ParallelOp forallOp) {
    for (Operation *o = block->getParentOp(); o; o = o->getParentOp()) {
      if (o == forallOp)
        return true;
    }
    return false;
  }

  unsigned numThreads;
};

//===----------------------------------------------------------------------===//
// Rewriting rules for library recognition and code generation.
//===----------------------------------------------------------------------===//

/// Proof-of-concept rewriter. This rule recognizes certain math kernels
/// and replaces these with corresponding calls into a sparse library.
struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LinalgOpRewriter(MLIRContext *context, bool rt)
      : OpRewritePattern(context), enableRT(rt) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumDpsInits() != 1)
      return failure(); // reject multi-output

    const unsigned numLoops = op.getNumLoops();
    const unsigned numTensors = op->getNumOperands();
    const auto iteratorTypes = op.getIteratorTypesArray();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr i, j, k;
    bindDims(getContext(), i, j, k);

    // TODO: more robust patterns, transposed versions, more kernels,
    // identify alpha and beta and pass them to the CUDA calls.

    // Recognize a SpMV kernel.
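    // Looks for y(i) += A(i,j) * x(j): indexing maps {(i,j), (j), (i)} with a
    // parallel i-loop and a reduction j-loop.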
    if (numLoops == 2 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isReductionIterator(iteratorTypes[1]) &&
        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
      return rewriteSpMV(rewriter, op, enableRT);
    }

    // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
      if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
        return rewriteSpGEMM(rewriter, op, enableRT);
      if (op->getAttr("DENSE24"))
        return rewrite2To4SpMM(rewriter, op);
      return rewriteSpMM(rewriter, op, enableRT);
    }

    // Recognize a SDDMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) &&
        matchSumReductionOfMulUnary(op)) {
      return rewriteSDDMM(rewriter, op, enableRT);
    }

    return failure();
  }

private:
  bool enableRT;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating GPU rewriting rules.
//
// Currently two sets of rewriting rules are made available. The first set
// implements direct code generation, currently by means of converting the
// outermost parallel loop into GPU threads. The second set implements
// library recognition of a set of sparse operations. Eventually, the right
// combination of these two approaches has to be found.
//===----------------------------------------------------------------------===//

void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                            unsigned numThreads) {
  patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
}

void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
                                           bool enableRT) {
  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
}