//===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a prototype GPU code generator for the sparsifier.
// The objective is to eventually use the right combination of
// direct code generation and library calls into vendor-specific
// highly optimized sparse libraries (e.g. cuSparse for CUDA).
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Sparse formats supported by cuSparse.
enum class CuSparseFormat {
  kNone,
  kCOO,
  kCSR,
  kCSC,
  kBSR,
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Marks the given top module as a GPU container module.
static void markAsGPUContainer(ModuleOp topModule) {
  topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
                     UnitAttr::get(topModule->getContext()));
}

/// Constructs a new GPU module (for GPU kernels) inside the given top module,
/// or returns an existing GPU module if one was built previously.
static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
  for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
    return op; // existing
  markAsGPUContainer(topModule);
  builder.setInsertionPointToStart(&topModule.getBodyRegion().front());
  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
                                          "sparse_kernels");
}

/// Constructs a new GPU kernel in the given GPU module.
static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
                                 SmallVectorImpl<Value> &args) {
  // Get a unique kernel name. Not very creative,
  // but we simply try kernel0, kernel1, etc.
  unsigned kernelNumber = 0;
  SmallString<16> kernelName;
  do {
    kernelName.clear();
    ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
  } while (gpuModule.lookupSymbol(kernelName));
  // Then we insert a new kernel with the given arguments into the module.
  builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
  SmallVector<Type> argsTp;
  for (unsigned i = 0, e = args.size(); i < e; i++)
    argsTp.push_back(args[i].getType());
  FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
  auto gpuFunc =
      builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
  gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                   builder.getUnitAttr());
  return gpuFunc;
}

/// Constructs code to launch the GPU kernel.
static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
                              SmallVectorImpl<Value> &args,
                              SmallVectorImpl<Value> &tokens,
                              unsigned numThreads) {
  Location loc = gpuFunc->getLoc();
  Value none = TypedValue<::mlir::IntegerType>{};
  Value one = constantIndex(builder, loc, 1);
  Value numT = constantIndex(builder, loc, numThreads);
  gpu::KernelDim3 gridSize = {one, one, one};
  gpu::KernelDim3 blckSize = {numT, one, one};
  return builder
      .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
                                 /*dynSharedMemSz*/ none, args,
                                 builder.getType<gpu::AsyncTokenType>(), tokens)
      .getAsyncToken();
}

/// Maps the provided ranked host buffer into the device address space.
/// Writes from the host are guaranteed to be visible to device kernels
/// that are launched afterwards. Writes from the device are guaranteed
/// to be visible on the host after synchronizing with the device kernel
/// completion. Needs to cast the buffer to an unranked buffer.
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
                                   Value mem) {
  MemRefType memTp = cast<MemRefType>(mem.getType());
  UnrankedMemRefType resTp =
      UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
  Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
  builder.create<gpu::HostRegisterOp>(loc, cast);
  return cast;
}

/// Unmaps the provided buffer, expecting the casted buffer.
static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
                                    Value cast) {
  builder.create<gpu::HostUnregisterOp>(loc, cast);
}

/// Generates first wait in an asynchronous chain.
static Value genFirstWait(OpBuilder &builder, Location loc) {
  Type tokenType = builder.getType<gpu::AsyncTokenType>();
  return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
      .getAsyncToken();
}

/// Generates last, blocking wait in an asynchronous chain.
static void genBlockingWait(OpBuilder &builder, Location loc,
                            ValueRange operands) {
  builder.create<gpu::WaitOp>(loc, Type(), operands);
}

/// Allocates memory on the device.
/// TODO: A `host_shared` attribute could be used to indicate that
/// the buffer is visible by both host and device, but lowering
/// that feature does not seem to be fully supported yet.
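/// As a rough illustration, for a buffer with one dynamic dimension this
/// emits something like (SSA names and types are made up for exposition):
///   %d = memref.dim %mem, %c0 : memref<?xf64>
///   %devMem, %t = gpu.alloc async [%token] (%d) : memref<?xf64>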
static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
                                   Value token) {
  auto tp = cast<ShapedType>(mem.getType());
  auto elemTp = tp.getElementType();
  auto shape = tp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
    if (shape[r] == ShapedType::kDynamic) {
      Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
      dynamicSizes.push_back(dimOp);
    }
  }
  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
                                      token, dynamicSizes, ValueRange());
}

// Allocates a typed buffer on the host with given size.
static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
                           Value size) {
  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
  return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
}

// Allocates a typed buffer on the device with given size.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
                                   Value size, Value token) {
  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
                                      token, size, ValueRange());
}

// Allocates a void buffer on the device with given size.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
                                   Value token) {
  return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
}

/// Deallocates memory from the device.
static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
                              Value token) {
  return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
      .getAsyncToken();
}

/// Copies memory between host and device (direction is implicit).
static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
                           Value src, Value token) {
  return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
      .getAsyncToken();
}

/// Generates an alloc/copy pair.
static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
                          SmallVectorImpl<Value> &tokens) {
  Value firstToken = genFirstWait(builder, loc);
  auto alloc = genAllocMemRef(builder, loc, b, firstToken);
  Value devMem = alloc.getResult(0);
  Value depToken = alloc.getAsyncToken(); // copy-after-alloc
  tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
  return devMem;
}

/// Generates a memref from tensor operation.
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
                               Value tensor) {
  auto tensorType = llvm::cast<ShapedType>(tensor.getType());
  auto memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
}

/// Prepares the outlined arguments, passing scalars and buffers in. Here we
/// assume that the first buffer is the one allocated for output. We create
/// a set of properly chained asynchronous allocation/copy pairs to increase
/// overlap before launching the kernel.
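/// Per buffer, the emitted chain is roughly (illustrative names only):
///   %t0 = gpu.wait async
///   %dev, %t1 = gpu.alloc async [%t0] (...) : memref<...>
///   %t2 = gpu.memcpy async [%t1] %dev, %host : memref<...>, memref<...>
/// where each copy token %t2 is collected so that a single blocking wait
/// can cover all outstanding transfers before the kernel launch.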
static Value genParametersIn(OpBuilder &builder, Location loc,
                             SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens,
                             bool useHostRegistrationForOut) {
  Value out;
  // Scalars are passed by value.
  for (Value s : scalars)
    args.push_back(s);
  // Buffers need to be made visible on the device.
  for (Value b : buffers) {
    if (useHostRegistrationForOut) {
      out = genHostRegisterMemref(builder, loc, b);
      args.push_back(b);
      useHostRegistrationForOut = false;
      continue;
    }
    args.push_back(genAllocCopy(builder, loc, b, tokens));
  }
  return out;
}

/// Finalizes the outlined arguments. The output buffer is copied depending
/// on the kernel token and then deallocated. All other buffers are simply
/// deallocated. Then we wait for all operations to complete.
static void genParametersOut(OpBuilder &builder, Location loc, Value out,
                             Value kernelToken, SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens) {
  unsigned base = scalars.size();
  for (unsigned i = base, e = args.size(); i < e; i++) {
    Value firstToken;
    if (i == base) {
      // Assumed output parameter: unregister or copy-out.
      if (out) {
        genHostUnregisterMemref(builder, loc, out);
        out = Value();
        continue;
      }
      firstToken =
          genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
    } else {
      firstToken = genFirstWait(builder, loc);
    }
    tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
  }
}

/// Constructs code for new GPU kernel.
static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
                       scf::ParallelOp forallOp,
                       SmallVectorImpl<Value> &constants,
                       SmallVectorImpl<Value> &scalars,
                       SmallVectorImpl<Value> &buffers) {
  Location loc = gpuFunc->getLoc();
  Block &block = gpuFunc.getBody().front();
  rewriter.setInsertionPointToStart(&block);

  // Re-generate the constants, recapture all arguments.
  unsigned arg = 0;
  IRMapping irMap;
  for (Value c : constants)
    irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
  for (Value s : scalars)
    irMap.map(s, block.getArgument(arg++));
  for (Value b : buffers)
    irMap.map(b, block.getArgument(arg++));

  // Assume 1-dimensional grid/block configuration (only x dimension),
  // so that:
  //   row = blockIdx.x * blockDim.x + threadIdx.x
  //   inc = blockDim.x * gridDim.x
  Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
  Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
  Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
  Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
  Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
  Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);

  // Construct the iteration over the computational space that
  // accounts for the fact that the total number of threads and
  // the amount of work to be done usually do not match precisely.
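  // This is the usual grid-stride loop idiom: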
  //   for (r = row; r < N; r += inc) {
  //     <loop-body>
  //   }
  Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
  // The scf.for builder creates an empty block. scf.for does not allow
  // multiple blocks in its region, so delete the block before
  // `cloneRegionBefore` adds an additional block.
  rewriter.eraseBlock(forOp.getBody());
  rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
                             forOp.getRegion().begin(), irMap);
  // Replace the scf.reduce terminator.
  rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
  rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator());

  // Done.
  rewriter.setInsertionPointAfter(forOp);
  rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
}

//===----------------------------------------------------------------------===//
// Library helper methods.
//===----------------------------------------------------------------------===//

/// Helper to detect a + b with arguments taken from the given block.
static bool matchAddOfArgs(Block *block, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
      Value a = block->getArguments()[0];
      Value b = block->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}

/// Helper to detect a * b with arguments taken from the given block.
static bool matchMulOfArgs(Block *block, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp, arith::MulIOp>(def)) {
      Value a = block->getArguments()[0];
      Value b = block->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}

/// Helper to detect x = x + a * b.
static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments()[2];
      return (def->getOperand(0) == x &&
              matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
             (def->getOperand(1) == x &&
              matchMulOfArgs(op.getBlock(), def->getOperand(0)));
    }
  }
  return false;
}

// Helper to detect c += spy(s) x (a * b)
static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  // The linalg op yields a custom reduce result.
  Value s_out = op.getBlock()->getArguments()[2];
  if (auto redOp =
          yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
    // The reduce consumes the output.
    Value other;
    if (s_out == redOp->getOperand(0))
      other = redOp->getOperand(1);
    else if (s_out == redOp->getOperand(1))
      other = redOp->getOperand(0);
    else
      return false;
    // The reduce op also consumes a unary which also consumes the output
    // and does not define an absent value.
    if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
      if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
        return false;
      // And the bodies are as expected.
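      // That is, the unary's present region yields a * b of the linalg
      // block arguments, and the reduce region yields the sum of its own
      // two block arguments.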
      auto yieldUn = cast<sparse_tensor::YieldOp>(
          unOp.getRegion(0).front().getTerminator());
      auto yieldRed = cast<sparse_tensor::YieldOp>(
          redOp.getRegion().front().getTerminator());
      return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
             matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
    }
  }
  return false;
}

/// Test for dense tensor.
static bool isDenseTensor(Value v) {
  auto sTp = getSparseTensorType(v);
  return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
}

/// Test for suitable positions/coordinates width.
static bool isAdmissibleMetaData(SparseTensorType &aTp) {
  return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
}

/// Test for sorted COO matrix with suitable metadata.
static bool isAdmissibleCOO(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
         aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
         aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
         isAdmissibleMetaData(aTp);
}

/// Test for CSR matrix with suitable metadata.
static bool isAdmissibleCSR(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
         aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
         aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}

/// Test for CSC matrix with suitable metadata.
static bool isAdmissibleCSC(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
         aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
         aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}

/// Test for BSR matrix with suitable metadata.
static bool isAdmissibleBSR(SparseTensorType &aTp) {
  if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
      aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
      aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
    // CuSparse only supports "square" blocks currently.
    SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
    assert(dims.size() == 2);
    return dims[0] == dims[1] && dims[0] > 1;
  }
  return false;
}

/// Returns a suitable sparse format for the operation and given operand
/// types with cuSparse, or kNone if none is available.
static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
                                        SparseTensorType bTp,
                                        SparseTensorType cTp, bool enableRT,
                                        bool isMatVec) {
  // The other operands must have a dense type.
  if (bTp.hasEncoding() || cTp.hasEncoding())
    return CuSparseFormat::kNone;
  // Now check for a suitable operand type for the main operand.
  if (isAdmissibleCOO(aTp))
#ifdef CUSPARSE_COO_AOS
    return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#else
    return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#endif
  if (isAdmissibleCSR(aTp))
    return CuSparseFormat::kCSR;
  if (isAdmissibleCSC(aTp))
    return CuSparseFormat::kCSC;
  if (isAdmissibleBSR(aTp))
    return CuSparseFormat::kBSR;
  return CuSparseFormat::kNone;
}

/// Generates the first positions/coordinates of a sparse matrix.
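/// For CSR/CSC/BSR this is the positions buffer at level 1; for COO it is
/// the level-0 coordinates, either as an SoA buffer (when using the runtime
/// library) or as a single AoS buffer (direct IR codegen).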
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
                               CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT)
      return genToCoordinates(builder, loc, a, 0);
    return genToCoordinatesBuffer(builder, loc, a);
  }
  // Formats CSR/CSC and BSR use positions at 1.
  return genToPositions(builder, loc, a, 1);
}

/// Generates the second coordinates of a sparse matrix.
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
                           CuSparseFormat format, bool enableRT) {
  bool isCOO = format == CuSparseFormat::kCOO;
  if (isCOO && !enableRT)
    return Value(); // nothing needed
  // Formats CSR/CSC and BSR use coordinates at 1.
  return genToCoordinates(builder, loc, a, 1);
}

/// Generates the sparse matrix handle.
static Operation *genSpMat(OpBuilder &builder, Location loc,
                           SparseTensorType &aTp, Type handleTp, Type tokenTp,
                           Value token, Value sz1, Value sz2, Value nseA,
                           Value rowA, Value colA, Value valA,
                           CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT) {
      assert(colA);
      return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
                                              sz1, sz2, nseA, rowA, colA, valA);
    }
#ifdef CUSPARSE_COO_AOS
    assert(!colA);
    return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token,
                                               sz1, sz2, nseA, rowA, valA);
#else
    llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
#endif
  }
  assert(colA);
  if (format == CuSparseFormat::kCSR)
    return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  if (format == CuSparseFormat::kCSC)
    return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  // BSR requires a bit more work since we need to pass in the block size
  // and all other sizes in terms of blocks (#block-rows, #block-cols,
  // #nonzero-blocks).
  assert(format == CuSparseFormat::kBSR);
  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
  assert(dims.size() == 2 && dims[0] == dims[1]);
  uint64_t b = dims[0];
  Value bSz = constantIndex(builder, loc, b);
  Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
  Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
  Value bNum = builder.create<arith::DivUIOp>(
      loc, nseA, constantIndex(builder, loc, b * b));
  return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
                                          bCols, bNum, bSz, bSz, rowA, colA,
                                          valA);
}

/// Match and rewrite SpMV kernel.
static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value x = op.getOperand(1);
  Value y = op.getOperand(2); // we have y = Ax
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format and dense vectors (no BSR).
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType xTp = getSparseTensorType(x);
  SparseTensorType yTp = getSparseTensorType(y);
  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Start sparse kernel and copy data from host to device.
  // a : memR/memC/memV -> rowA,colA,valA
  // x : memX -> vecX
  // y : memY -> vecY
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
  Value memV = genToValues(rewriter, loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value memX = genTensorToMemref(rewriter, loc, x);
  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
  Value memY = genTensorToMemref(rewriter, loc, y);
  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense vector handles.
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
  Value dnX = dvecX.getResult(0);
  token = dvecX.getAsyncToken();
  auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
  Value dnY = dvecY.getResult(0);
  token = dvecY.getAsyncToken();
  auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();

  // Precompute buffersize for SpMV.
  auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
      /*computeType=*/dnYType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SpMV.
  auto spmvComp = rewriter.create<gpu::SpMVOp>(
      loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer);
  token = spmvComp.getAsyncToken();

  // Copy data back to host and free all the resources.
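  // The teardown below is chained on a single token: destroy the cuSparse
  // handles, deallocate the device buffers, copy the result vector back
  // into memY, and only then release the device copy of y.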
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, vecX, token);
  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
  token = genDeallocMemRef(rewriter, loc, vecY, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
  return success();
}

/// Match and rewrite SpMM kernel.
static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format and dense matrices (no BSR).
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Start sparse kernel and copy data from host to device.
  // a : memR/memC/memV -> rowA,colA,valA
  // b : bufB -> matB
  // c : bufC -> matC
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
  Value memV = genToValues(rewriter, loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, c);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense matrix handles.
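  // A rough sketch of the handle creation emitted below, for the CSR case
  // (SSA names are illustrative only):
  //   %spA, %t = gpu.create_csr async [%t] %szm, %szk, %nseA, %rowA, %colA, %valA
  //   %dnB, %t = gpu.create_dn_tensor async [%t] %matB, %szk, %szn
  //   %dnC, %t = gpu.create_dn_tensor async [%t] %matC, %szm, %szn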
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matC,
      SmallVector<Value>{szm, szn});
  Value dnC = dmatC.getResult(0);
  token = dmatC.getAsyncToken();
  auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Precompute buffersize for SpMM.
  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
      /*computeType=*/dmatCType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();
  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Perform the SpMM.
  auto spmmComp = rewriter.create<gpu::SpMMOp>(
      loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer);
  token = spmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
  return success();
}

// Match and rewrite SpGEMM kernel.
static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
                                   linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // Only CSR <- CSR x CSR supported.
  auto format = CuSparseFormat::kCSR;
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
    return failure();

  // Start sparse kernel and copy data from host to device.
  // a : amemR/amemC/amemV -> rowA,colA,valA
  // b : bmemR/bmemC/bmemV -> rowB,colB,valB
  // c : materializes
  auto dnCType = cTp.getElementType();
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
  Value amemV = genToValues(rewriter, loc, a);
  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
  Value bmemV = genToValues(rewriter, loc, b);
  Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
  Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
  Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
  Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
  Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
  Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix handles.
  Type indexTp = rewriter.getIndexType();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  Operation *spGenB =
      genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
               nseB, rowB, colB, valB, format, enableRT);
  Value spMatB = spGenB->getResult(0);
  token = spGenB->getResult(1);

  // Sparse matrix C materializes (also assumes beta == 0).
  Value zero = constantIndex(rewriter, loc, 0);
  Value one = constantIndex(rewriter, loc, 1);
  Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
  auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
  Value rowC = e1.getResult(0);
  token = e1.getAsyncToken();
  auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
  Value colC = e2.getResult(0); // no free needed
  token = e2.getAsyncToken();
  auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
  Value valC = e3.getResult(0); // no free needed
  token = e3.getAsyncToken();
  Operation *spGenC =
      genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
               zero, rowC, colC, valC, format, enableRT);
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);

  // Precompute buffersizes for SpGEMM.
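  // cuSPARSE SpGEMM uses a two-phase protocol for both the work-estimation
  // and the compute step: each op below is first issued with a zero-sized
  // buffer to query the required size, and then issued again with a buffer
  // of that size.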
  Operation *descOp =
      rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
  Value desc = descOp->getResult(0);
  token = descOp->getResult(1);
  Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  Value bufferSz1 = work1->getResult(0);
  token = work1->getResult(1);
  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
  Value buffer1 = buf1.getResult(0);
  token = buf1.getAsyncToken();
  Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz1, buffer1,
      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  token = work2->getResult(1);

  // Compute step.
  Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  Value bufferSz2 = compute1->getResult(0);
  token = compute1->getResult(1);
  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
  Value buffer2 = buf2.getResult(0);
  token = buf2.getAsyncToken();
  Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  token = compute2->getResult(1);

  // Get sizes.
  Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
      loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
  Value nnz = sizes->getResult(2);
  token = sizes->getResult(3);
  auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
  colC = a2.getResult(0);
  token = a2.getAsyncToken();
  auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
  valC = a3.getResult(0);
  token = a3.getAsyncToken();

  // Update C with new pointers and copy final product back into C.
  Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
      loc, tokenTp, token, spMatC, rowC, colC, valC);
  token = update->getResult(0);
  Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
      loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
  token = copy->getResult(0);

  // Allocate buffers on host.
  Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
  Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
  Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
  token = genCopyMemRef(rewriter, loc, colH, colC, token);
  token = genCopyMemRef(rewriter, loc, valH, valC, token);
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, rowB, token);
  token = genDeallocMemRef(rewriter, loc, colB, token);
  token = genDeallocMemRef(rewriter, loc, valB, token);
  token = genDeallocMemRef(rewriter, loc, rowC, token);
  token = genDeallocMemRef(rewriter, loc, colC, token);
  token = genDeallocMemRef(rewriter, loc, valC, token);
  token = genDeallocMemRef(rewriter, loc, buffer1, token);
  token = genDeallocMemRef(rewriter, loc, buffer2, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
  Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
  Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
                                          ValueRange{rt, ct});
  return success();
}

// Match and rewrite 2:4 SpMM kernel.
static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
                                     linalg::GenericOp op) {
  Location loc = op.getLoc();
  Value A = op.getOperand(0);
  Value B = op.getOperand(1);
  Value C = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // All inputs should be dense tensors.
  if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
    return failure();

  // Start sparse kernel and copy data from host to device.
  // a : bufA -> matA
  // b : bufB -> matB
  // c : bufC -> matC
  Value bufA = genTensorToMemref(rewriter, loc, A);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, B);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, C);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense matrix handles.
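  // For 2:4 structured sparsity, A is passed in as a dense buffer and
  // pruned on the fly (PRUNE_AND_CHECK). This path queries three work
  // buffer sizes below and hands all three buffers to the SpMM.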
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
      loc, spMatHandleTp, tokenTp, token, szm, szk,
      gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matC,
      SmallVector<Value>{szm, szn});
  Value dnC = dmatC.getResult(0);
  token = dmatC.getAsyncToken();
  auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();

  // Precompute buffersize for SpMM.
  SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
  TypeRange bufferTypes(bufferTypes_);
  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
      loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
      /*computeType=*/dmatCType);
  token = bufferComp.getAsyncToken();

  // Allocate work buffers on the device.
  Value bufferSz1 = bufferComp.getResult(0);
  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
  Value buffer1 = buf1.getResult(0);
  token = buf1.getAsyncToken();
  Value bufferSz2 = bufferComp.getResult(1);
  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
  Value buffer2 = buf2.getResult(0);
  token = buf2.getAsyncToken();
  Value bufferSz3 = bufferComp.getResult(2);
  auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
  Value buffer3 = buf3.getResult(0);
  token = buf3.getAsyncToken();

  // Perform the SpMM.
  auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
  auto spmmComp = rewriter.create<gpu::SpMMOp>(
      loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
      SmallVector<Value>{buffer1, buffer2, buffer3});
  token = spmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  SmallVector<Value> newDynamicSizes;
  token = genDeallocMemRef(rewriter, loc, buffer1, token);
  token = genDeallocMemRef(rewriter, loc, buffer2, token);
  token = genDeallocMemRef(rewriter, loc, buffer3, token);
  token = genDeallocMemRef(rewriter, loc, matA, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
  return success();
}

/// Match and rewrite SDDMM kernel.
static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
                                  linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2);
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
      format == CuSparseFormat::kCSC)
    return failure();

  // The SDDMM updates the sparse output matrix in place.
  // Start sparse kernel and copy data from host to device.
  // a : bufA -> matA
  // b : bufB -> matB
  // c : memR/memC/memV -> rowC,colC,valC
  Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value bufA = genTensorToMemref(rewriter, loc, a);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
  Value memV = genToValues(rewriter, loc, c);
  Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
  Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valC = genAllocCopy(rewriter, loc, memV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense matrix handles.
  Type indexTp = rewriter.getIndexType();
  Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
  Value dnA = dmatA.getResult(0);
  token = dmatA.getAsyncToken();
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  Operation *spGenC =
      genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
               nseC, rowC, colC, valC, format, enableRT);
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);
  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Precompute buffersize for SDDMM.
  auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
      loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SDDMM.
  auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
                                                 spMatC, dnCType, buffer);
  token = sddmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, matA, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genDeallocMemRef(rewriter, loc, rowC, token);
  if (colC)
    token = genDeallocMemRef(rewriter, loc, colC, token);
  token = genCopyMemRef(rewriter, loc, memV, valC, token);
  token = genDeallocMemRef(rewriter, loc, valC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
  return success();
}

//===----------------------------------------------------------------------===//
// Rewriting rules for direct code generation.
//===----------------------------------------------------------------------===//

/// Proof-of-concept rewriter. This rule generates a GPU implementation
/// for each outermost forall loop generated by the sparsifier.
/// TODO: right now this works with parallelization-strategy=dense-outer-loop;
/// give it its own flag in the future.
struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;

  ForallRewriter(MLIRContext *context, unsigned nT)
      : OpRewritePattern(context), numThreads(nT) {}

  LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
                                PatternRewriter &rewriter) const override {
    // Reject inadmissible loop form.
    // Essentially only accept a loop, generated by the sparsifier,
    // of the form
    //   forall (i = 0; i < N; i++)
    // so that cyclic scheduling over the threads is easy.
    if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
        forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
        !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
        !matchPattern(forallOp.getStep()[0], m_One()))
      return failure();
    // Collect every value that is computed outside the parallel loop.
    SetVector<Value> invariants; // stable iteration!
    forallOp->walk([&](Operation *op) {
      // Collect all values of admissible ops.
      for (OpOperand &o : op->getOpOperands()) {
        Value val = o.get();
        Block *block;
        if (auto arg = dyn_cast<BlockArgument>(val))
          block = arg.getOwner();
        else
          block = val.getDefiningOp()->getBlock();
        if (!forallOp.getRegion().findAncestorBlockInRegion(*block))
          invariants.insert(val);
      }
    });
    // Outline the outside values as proper parameters. Fail when sharing
    // a value between host and device is not straightforward.
    SmallVector<Value> constants;
    SmallVector<Value> scalars;
    SmallVector<Value> buffers;
    for (Value val : invariants) {
      Type tp = val.getType();
      if (val.getDefiningOp<arith::ConstantOp>())
        constants.push_back(val);
      else if (isa<FloatType>(tp) || tp.isIntOrIndex())
        scalars.push_back(val);
      else if (isa<MemRefType>(tp))
        buffers.push_back(val);
      else
        return failure(); // don't know how to share
    }
    // Pass outlined non-constant values.
    // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
    // keep the feature at all (either through a heuristic or compiler
    // option for gpu codegen).
    Location loc = forallOp->getLoc();
    SmallVector<Value> args;
    SmallVector<Value> tokens;
    Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
                                /*useHostRegistrationForOut=*/false);
    // Set up GPU module and construct GPU function.
    auto saveIp = rewriter.saveInsertionPoint();
    ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
    auto gpuModule = genGPUModule(rewriter, topModule);
    auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
    genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
    // Generate code that launches the kernel asynchronously, blocking on all
    // open tokens and yielding a new token for the output.
    // TODO: Passing tokens into the launch does not seem to be properly
    // lowered to cubin yet, hence the current blocking wait.
    rewriter.restoreInsertionPoint(saveIp);
    genBlockingWait(rewriter, loc, tokens);
    tokens.clear();
    Value kernelToken =
        genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
    // Finalize the outlined arguments.
    genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
                     tokens);
    genBlockingWait(rewriter, loc, tokens);
    rewriter.eraseOp(forallOp);
    return success();
  }

private:
  unsigned numThreads;
};

//===----------------------------------------------------------------------===//
// Rewriting rules for library recognition and code generation.
//===----------------------------------------------------------------------===//

/// Proof-of-concept rewriter. This rule recognizes certain math kernels
/// and replaces these with corresponding calls into a sparse library.
struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LinalgOpRewriter(MLIRContext *context, bool rt)
      : OpRewritePattern(context), enableRT(rt) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumDpsInits() != 1)
      return failure(); // reject multi-output

    const unsigned numLoops = op.getNumLoops();
    const unsigned numTensors = op->getNumOperands();
    const auto iteratorTypes = op.getIteratorTypesArray();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr i, j, k;
    bindDims(getContext(), i, j, k);

    // TODO: more robust patterns, transposed versions, more kernels,
    // identify alpha and beta and pass them to the CUDA calls.
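
    // In index notation, the matchers below recognize:
    //   SpMV  : y(i)   += A(i,j) * x(j)
    //   SpMM  : C(i,j) += A(i,k) * B(k,j)   (also SpGEMM and 2:4 SpMM)
    //   SDDMM : S(i,j) += spy(S(i,j)) x sum_k(A(i,k) * B(k,j))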

    // Recognize a SpMV kernel.
    if (numLoops == 2 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isReductionIterator(iteratorTypes[1]) &&
        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
      return rewriteSpMV(rewriter, op, enableRT);
    }

    // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
      if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
        return rewriteSpGEMM(rewriter, op, enableRT);
      if (op->getAttr("DENSE24"))
        return rewrite2To4SpMM(rewriter, op);
      return rewriteSpMM(rewriter, op, enableRT);
    }

    // Recognize a SDDMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) &&
        matchSumReductionOfMulUnary(op)) {
      return rewriteSDDMM(rewriter, op, enableRT);
    }

    return failure();
  }

private:
  bool enableRT;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating GPU rewriting rules.
//
// Currently two sets of rewriting rules are made available. The first set
// implements direct code generation, currently by means of converting the
// outermost parallel loop into GPU threads. The second set implements
// library recognition of a set of sparse operations. Eventually, the right
// combination of these two approaches has to be found.
//===----------------------------------------------------------------------===//

void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                            unsigned numThreads) {
  patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
}

void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
                                           bool enableRT) {
  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
}