//===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a prototype GPU code generator for the sparsifier.
// The objective is to eventually use the right combination of
// direct code generation and library calls into vendor-specific
// highly optimized sparse libraries (e.g. cuSparse for CUDA).
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Sparse formats supported by cuSparse.
enum class CuSparseFormat {
  kNone,
  kCOO,
  kCSR,
  kCSC,
  kBSR,
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Marks the given top module as a GPU container module.
static void markAsGPUContainer(ModuleOp topModule) {
  topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
                     UnitAttr::get(topModule->getContext()));
}

/// Constructs a new GPU module (for GPU kernels) inside the given top module,
/// or returns an existing GPU module if one was built previously.
static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
  for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
    return op; // existing
  markAsGPUContainer(topModule);
  builder.setInsertionPointToStart(topModule.getBody());
  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
                                          "sparse_kernels");
}

/// Constructs a new GPU kernel in the given GPU module.
static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
                                 SmallVectorImpl<Value> &args) {
  // Get a unique kernel name. Not very creative,
  // but we simply try kernel0, kernel1, etc.
  unsigned kernelNumber = 0;
  SmallString<16> kernelName;
  do {
    kernelName.clear();
    ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
  } while (gpuModule.lookupSymbol(kernelName));
  // Then we insert a new kernel with given arguments into the module.
  builder.setInsertionPointToStart(gpuModule.getBody());
  SmallVector<Type> argsTp;
  for (auto arg : args)
    argsTp.push_back(arg.getType());
  FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
  auto gpuFunc =
      builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
  gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                   builder.getUnitAttr());
  return gpuFunc;
}

/// Constructs code to launch GPU kernel.
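/// The emitted launch is an async `gpu.launch_func` with a 1-D grid/block
/// configuration, roughly of the form (illustrative sketch only):
///   %t = gpu.launch_func async [%deps] @sparse_kernels::@kernel0
///            blocks in (%c1, %c1, %c1) threads in (%numT, %c1, %c1)
///            args(%arg0 : memref<?xf64>, ...)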
91 static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc, 92 SmallVectorImpl<Value> &args, 93 SmallVectorImpl<Value> &tokens, 94 unsigned numThreads) { 95 Location loc = gpuFunc->getLoc(); 96 Value none = TypedValue<::mlir::IntegerType>{}; 97 Value one = constantIndex(builder, loc, 1); 98 Value numT = constantIndex(builder, loc, numThreads); 99 gpu::KernelDim3 gridSize = {one, one, one}; 100 gpu::KernelDim3 blckSize = {numT, one, one}; 101 return builder 102 .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize, 103 /*dynSharedMemSz*/ none, args, 104 builder.getType<gpu::AsyncTokenType>(), tokens) 105 .getAsyncToken(); 106 } 107 108 /// Maps the provided ranked host buffer into the device address space. 109 /// Writes from the host are guaranteed to be visible to device kernels 110 /// that are launched afterwards. Writes from the device are guaranteed 111 /// to be visible on the host after synchronizing with the device kernel 112 /// completion. Needs to cast the buffer to a unranked buffer. 113 static Value genHostRegisterMemref(OpBuilder &builder, Location loc, 114 Value mem) { 115 MemRefType memTp = cast<MemRefType>(mem.getType()); 116 UnrankedMemRefType resTp = 117 UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0); 118 Value cast = builder.create<memref::CastOp>(loc, resTp, mem); 119 builder.create<gpu::HostRegisterOp>(loc, cast); 120 return cast; 121 } 122 123 /// Unmaps the provided buffer, expecting the casted buffer. 124 static void genHostUnregisterMemref(OpBuilder &builder, Location loc, 125 Value cast) { 126 builder.create<gpu::HostUnregisterOp>(loc, cast); 127 } 128 129 /// Generates first wait in an asynchronous chain. 130 static Value genFirstWait(OpBuilder &builder, Location loc) { 131 Type tokenType = builder.getType<gpu::AsyncTokenType>(); 132 return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange()) 133 .getAsyncToken(); 134 } 135 136 /// Generates last, blocking wait in an asynchronous chain. 137 static void genBlockingWait(OpBuilder &builder, Location loc, 138 ValueRange operands) { 139 builder.create<gpu::WaitOp>(loc, Type(), operands); 140 } 141 142 /// Allocates memory on the device. 143 /// TODO: A `host_shared` attribute could be used to indicate that 144 /// the buffer is visible by both host and device, but lowering 145 /// that feature does not seem to be fully supported yet. 146 static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem, 147 Value token) { 148 auto tp = cast<ShapedType>(mem.getType()); 149 auto elemTp = tp.getElementType(); 150 auto shape = tp.getShape(); 151 auto memTp = MemRefType::get(shape, elemTp); 152 SmallVector<Value> dynamicSizes; 153 for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) { 154 if (shape[r] == ShapedType::kDynamic) { 155 Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r); 156 dynamicSizes.push_back(dimOp); 157 } 158 } 159 return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}), 160 token, dynamicSizes, ValueRange()); 161 } 162 163 // Allocates a typed buffer on the host with given size. 164 static Value genHostBuffer(OpBuilder &builder, Location loc, Type type, 165 Value size) { 166 const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); 167 return builder.create<memref::AllocOp>(loc, memTp, size).getResult(); 168 } 169 170 // Allocates a typed buffer on the device with given size. 
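// The buffer is created with an async `gpu.alloc`, e.g. (sketch, where T
// stands for the requested element type):
//   %buf, %t = gpu.alloc async [%dep] (%size) : memref<?xT>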
171 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type, 172 Value size, Value token) { 173 const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); 174 return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}), 175 token, size, ValueRange()); 176 } 177 178 // Allocates a void buffer on the device with given size. 179 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size, 180 Value token) { 181 return genAllocBuffer(builder, loc, builder.getI8Type(), size, token); 182 } 183 184 /// Deallocates memory from the device. 185 static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem, 186 Value token) { 187 return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem) 188 .getAsyncToken(); 189 } 190 191 /// Copies memory between host and device (direction is implicit). 192 static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst, 193 Value src, Value token) { 194 return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src) 195 .getAsyncToken(); 196 } 197 198 /// Generates an alloc/copy pair. 199 static Value genAllocCopy(OpBuilder &builder, Location loc, Value b, 200 SmallVectorImpl<Value> &tokens) { 201 Value firstToken = genFirstWait(builder, loc); 202 auto alloc = genAllocMemRef(builder, loc, b, firstToken); 203 Value devMem = alloc.getResult(0); 204 Value depToken = alloc.getAsyncToken(); // copy-after-alloc 205 tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken)); 206 return devMem; 207 } 208 209 /// Generates a memref from tensor operation. 210 static Value genTensorToMemref(PatternRewriter &rewriter, Location loc, 211 Value tensor) { 212 auto tensorType = llvm::cast<ShapedType>(tensor.getType()); 213 auto memrefType = 214 MemRefType::get(tensorType.getShape(), tensorType.getElementType()); 215 return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor); 216 } 217 218 /// Prepares the outlined arguments, passing scalars and buffers in. Here we 219 /// assume that the first buffer is the one allocated for output. We create 220 /// a set of properly chained asynchronous allocation/copy pairs to increase 221 /// overlap before launching the kernel. 222 static Value genParametersIn(OpBuilder &builder, Location loc, 223 SmallVectorImpl<Value> &scalars, 224 SmallVectorImpl<Value> &buffers, 225 SmallVectorImpl<Value> &args, 226 SmallVectorImpl<Value> &tokens, 227 bool useHostRegistrationForOut) { 228 Value out; 229 // Scalars are passed by value. 230 for (Value s : scalars) 231 args.push_back(s); 232 // Buffers are need to be made visible on device. 233 for (Value b : buffers) { 234 if (useHostRegistrationForOut) { 235 out = genHostRegisterMemref(builder, loc, b); 236 args.push_back(b); 237 useHostRegistrationForOut = false; 238 continue; 239 } 240 args.push_back(genAllocCopy(builder, loc, b, tokens)); 241 } 242 return out; 243 } 244 245 /// Finalizes the outlined arguments. The output buffer is copied depending 246 /// on the kernel token and then deallocated. All other buffers are simply 247 /// deallocated. Then we wait for all operations to complete. 
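/// For the assumed output buffer, the generated chain is roughly (sketch):
///   %t0 = gpu.memcpy async [%kernelDone] %host, %dev : memref<?xf64>, memref<?xf64>
///   %t1 = gpu.dealloc async [%t0] %dev : memref<?xf64>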
248 static void genParametersOut(OpBuilder &builder, Location loc, Value out, 249 Value kernelToken, SmallVectorImpl<Value> &scalars, 250 SmallVectorImpl<Value> &buffers, 251 SmallVectorImpl<Value> &args, 252 SmallVectorImpl<Value> &tokens) { 253 unsigned base = scalars.size(); 254 for (unsigned i = base, e = args.size(); i < e; i++) { 255 Value firstToken; 256 if (i == base) { 257 // Assumed output parameter: unregister or copy-out. 258 if (out) { 259 genHostUnregisterMemref(builder, loc, out); 260 out = Value(); 261 continue; 262 } 263 firstToken = 264 genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken); 265 } else { 266 firstToken = genFirstWait(builder, loc); 267 } 268 tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken)); 269 } 270 } 271 272 /// Constructs code for new GPU kernel. 273 static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc, 274 scf::ParallelOp forallOp, 275 SmallVectorImpl<Value> &constants, 276 SmallVectorImpl<Value> &scalars, 277 SmallVectorImpl<Value> &buffers) { 278 Location loc = gpuFunc->getLoc(); 279 Block &block = gpuFunc.getBody().front(); 280 rewriter.setInsertionPointToStart(&block); 281 282 // Re-generate the constants, recapture all arguments. 283 unsigned arg = 0; 284 IRMapping irMap; 285 for (Value c : constants) 286 irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0)); 287 for (Value s : scalars) 288 irMap.map(s, block.getArgument(arg++)); 289 for (Value b : buffers) 290 irMap.map(b, block.getArgument(arg++)); 291 292 // Assume 1-dimensional grid/block configuration (only x dimension), 293 // so that: 294 // row = blockIdx.x * blockDim.x + threadIdx.x 295 // inc = blockDim.x * gridDim.x 296 Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x); 297 Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x); 298 Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x); 299 Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x); 300 Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz); 301 Value row = rewriter.create<arith::AddIOp>(loc, mul, tid); 302 Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz); 303 304 // Construct the iteration over the computational space that 305 // accounts for the fact that the total number of threads and 306 // the amount of work to be done usually do not match precisely. 307 // for (r = row; r < N; r += inc) { 308 // <loop-body> 309 // } 310 Value upper = irMap.lookup(forallOp.getUpperBound()[0]); 311 scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc); 312 // The scf.for builder creates an empty block. scf.for does not allow multiple 313 // blocks in its region, so delete the block before `cloneRegionBefore` adds 314 // an additional block. 315 rewriter.eraseBlock(forOp.getBody()); 316 rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(), 317 forOp.getRegion().begin(), irMap); 318 // Replace the scf.reduce terminator. 319 rewriter.setInsertionPoint(forOp.getBody()->getTerminator()); 320 rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator()); 321 322 // Done. 323 rewriter.setInsertionPointAfter(forOp); 324 rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc()); 325 } 326 327 //===----------------------------------------------------------------------===// 328 // Library helper methods. 329 //===----------------------------------------------------------------------===// 330 331 /// Helper to detect a + b with arguments taken from given block. 
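/// For block arguments %a and %b this accepts, in either operand order,
/// e.g. (sketch): %0 = arith.addf %a, %b : f32 (or the arith.addi variant).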
332 static bool matchAddOfArgs(Block *block, Value val) { 333 if (auto *def = val.getDefiningOp()) { 334 if (isa<arith::AddFOp, arith::AddIOp>(def)) { 335 Value a = block->getArguments()[0]; 336 Value b = block->getArguments()[1]; 337 return (def->getOperand(0) == a && def->getOperand(1) == b) || 338 (def->getOperand(0) == b && def->getOperand(1) == a); 339 } 340 } 341 return false; 342 } 343 344 /// Helper to detect a * b with arguments taken from given block. 345 static bool matchMulOfArgs(Block *block, Value val) { 346 if (auto *def = val.getDefiningOp()) { 347 if (isa<arith::MulFOp, arith::MulIOp>(def)) { 348 Value a = block->getArguments()[0]; 349 Value b = block->getArguments()[1]; 350 return (def->getOperand(0) == a && def->getOperand(1) == b) || 351 (def->getOperand(0) == b && def->getOperand(1) == a); 352 } 353 } 354 return false; 355 } 356 357 /// Helper to detect x = x + a * b 358 static bool matchSumOfMultOfArgs(linalg::GenericOp op) { 359 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); 360 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { 361 if (isa<arith::AddFOp, arith::AddIOp>(def)) { 362 Value x = op.getBlock()->getArguments()[2]; 363 return (def->getOperand(0) == x && 364 matchMulOfArgs(op.getBlock(), def->getOperand(1))) || 365 (def->getOperand(1) == x && 366 matchMulOfArgs(op.getBlock(), def->getOperand(0))); 367 } 368 } 369 return false; 370 } 371 372 // Helper to detect c += spy(s) x (a * b) 373 static bool matchSumReductionOfMulUnary(linalg::GenericOp op) { 374 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); 375 // The linalg yields a custom reduce result. 376 Value s_out = op.getBlock()->getArguments()[2]; 377 if (auto redOp = 378 yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) { 379 // The reduce consumes the output. 380 Value other; 381 if (s_out == redOp->getOperand(0)) 382 other = redOp->getOperand(1); 383 else if (s_out == redOp->getOperand(1)) 384 other = redOp->getOperand(0); 385 else 386 return false; 387 // The reduce op also consumes an unary which also consumes the output 388 // and does not define an absent value. 389 if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) { 390 if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty()) 391 return false; 392 // And the bodies are as expected. 393 auto yieldUn = cast<sparse_tensor::YieldOp>( 394 unOp.getRegion(0).front().getTerminator()); 395 auto yieldRed = cast<sparse_tensor::YieldOp>( 396 redOp.getRegion().front().getTerminator()); 397 return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) && 398 matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0)); 399 } 400 } 401 return false; 402 } 403 404 /// Test for dense tensor. 405 static bool isDenseTensor(Value v) { 406 auto sTp = getSparseTensorType(v); 407 return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense(); 408 } 409 410 /// Test for suitable positions/coordinates width. 411 static bool isAdmissibleMetaData(SparseTensorType &aTp) { 412 return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) && 413 (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16); 414 } 415 416 /// Test for sorted COO matrix with suitable metadata. 
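/// An admissible type is, roughly, a 2-D tensor with an encoding such as
/// (illustrative sketch):
///   #SortedCOO = #sparse_tensor.encoding<{
///     map = (i, j) -> (i : compressed(nonunique), j : singleton)
///   }>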
417 static bool isAdmissibleCOO(SparseTensorType &aTp) { 418 return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() && 419 aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) && 420 aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && 421 isAdmissibleMetaData(aTp); 422 } 423 424 /// Test for CSR matrix with suitable metadata. 425 static bool isAdmissibleCSR(SparseTensorType &aTp) { 426 return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() && 427 aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && 428 aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp); 429 } 430 431 /// Test for CSC matrix with suitable metadata. 432 static bool isAdmissibleCSC(SparseTensorType &aTp) { 433 return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() && 434 aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && 435 aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp); 436 } 437 438 /// Test for BSR matrix with suitable metadata. 439 static bool isAdmissibleBSR(SparseTensorType &aTp) { 440 if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) && 441 aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && 442 aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) { 443 // CuSparse only supports "square" blocks currently. 444 SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl()); 445 assert(dims.size() == 2); 446 return dims[0] == dims[1] && dims[0] > 1; 447 } 448 return false; 449 } 450 451 /// Test for 2:4 matrix with suitable metadata. 452 static bool isAdmissible24(SparseTensorType &aTp) { 453 return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) && 454 aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp); 455 } 456 457 /// Test for conversion into 2:4 matrix. 458 static bool isConversionInto24(Value v) { 459 if (auto cnv = v.getDefiningOp<ConvertOp>()) { 460 Value a = cnv.getResult(); 461 Value d = cnv.getSource(); 462 SparseTensorType aTp = getSparseTensorType(a); 463 return isDenseTensor(d) && isAdmissible24(aTp); 464 } 465 return false; 466 } 467 468 /// Returns a suitable sparse format for the operation and given operand 469 /// types with cuSparse, or kNone if none is available. 470 static CuSparseFormat getCuSparseFormat(SparseTensorType aTp, 471 SparseTensorType bTp, 472 SparseTensorType cTp, bool enableRT, 473 bool isMatVec) { 474 // The other operands have a dense type. 475 if (bTp.hasEncoding() || cTp.hasEncoding()) 476 return CuSparseFormat::kNone; 477 // Now check for suitable operand type for the main operand. 478 if (isAdmissibleCOO(aTp)) 479 #ifdef CUSPARSE_COO_AOS 480 return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone; 481 #else 482 return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone; 483 #endif 484 if (isAdmissibleCSR(aTp)) 485 return CuSparseFormat::kCSR; 486 if (isAdmissibleCSC(aTp)) 487 return CuSparseFormat::kCSC; 488 if (isAdmissibleBSR(aTp)) 489 return CuSparseFormat::kBSR; 490 return CuSparseFormat::kNone; 491 } 492 493 /// Generates the first positions/coordinates of a sparse matrix. 494 static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a, 495 CuSparseFormat format, bool enableRT) { 496 if (format == CuSparseFormat::kCOO) { 497 // Library uses SoA COO, direct IR uses AoS COO. 
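    // With the runtime library (SoA storage) the level-0 coordinates can be
    // queried directly; the direct-IR path instead returns the single
    // interleaved AoS coordinate buffer.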
498 if (enableRT) 499 return builder.create<ToCoordinatesOp>(loc, a, 0); 500 return builder.create<ToCoordinatesBufferOp>(loc, a); 501 } 502 // Formats CSR/CSC and BSR use positions at 1. 503 return builder.create<ToPositionsOp>(loc, a, 1); 504 } 505 506 /// Generates the second coordinates of a sparse matrix. 507 static Value genSecondCrds(OpBuilder &builder, Location loc, Value a, 508 CuSparseFormat format, bool enableRT) { 509 bool isCOO = format == CuSparseFormat::kCOO; 510 if (isCOO && !enableRT) 511 return Value(); // nothing needed 512 // Formats CSR/CSC and BSR use coordinates at 1. 513 return builder.create<ToCoordinatesOp>(loc, a, 1); 514 } 515 516 /// Generates the sparse matrix handle. 517 static Operation *genSpMat(OpBuilder &builder, Location loc, 518 SparseTensorType &aTp, Type handleTp, Type tokenTp, 519 Value token, Value sz1, Value sz2, Value nseA, 520 Value rowA, Value colA, Value valA, 521 CuSparseFormat format, bool enableRT) { 522 if (format == CuSparseFormat::kCOO) { 523 // Library uses SoA COO, direct IR uses AoS COO. 524 if (enableRT) { 525 assert(colA); 526 return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token, 527 sz1, sz2, nseA, rowA, colA, valA); 528 } 529 #ifdef CUSPARSE_COO_AOS 530 assert(!colA); 531 return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token, 532 sz1, sz2, nseA, rowA, valA); 533 #else 534 llvm_unreachable("gpu::CreateCooAoSOp is deprecated"); 535 #endif 536 } 537 assert(colA); 538 if (format == CuSparseFormat::kCSR) 539 return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1, 540 sz2, nseA, rowA, colA, valA); 541 if (format == CuSparseFormat::kCSC) 542 return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1, 543 sz2, nseA, rowA, colA, valA); 544 // BSR requires a bit more work since we need to pass in the block size 545 // and all others sizes in terms of blocks (#block-rows, #block-cols, 546 // #nonzero-blocks). 547 assert(format == CuSparseFormat::kBSR); 548 SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl()); 549 assert(dims.size() == 2 && dims[0] == dims[1]); 550 uint64_t b = dims[0]; 551 Value bSz = constantIndex(builder, loc, b); 552 Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz); 553 Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz); 554 Value bNum = builder.create<arith::DivUIOp>( 555 loc, nseA, constantIndex(builder, loc, b * b)); 556 return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows, 557 bCols, bNum, bSz, bSz, rowA, colA, 558 valA); 559 } 560 561 /// Match and rewrite SpMV kernel. 562 static LogicalResult rewriteSpMV(PatternRewriter &rewriter, 563 linalg::GenericOp op, bool enableRT) { 564 Location loc = op.getLoc(); 565 Value a = op.getOperand(0); 566 Value x = op.getOperand(1); 567 Value y = op.getOperand(2); // we have y = Ax 568 SmallVector<Value> tokens; 569 570 // Only admissible sparse matrix format and dense vectors (no BSR). 571 SparseTensorType aTp = getSparseTensorType(a); 572 SparseTensorType xTp = getSparseTensorType(x); 573 SparseTensorType yTp = getSparseTensorType(y); 574 auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true); 575 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR) 576 return failure(); 577 578 // Start sparse kernel and copy data from host to device. 
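  // Roughly, the rewrite emits async gpu.alloc/gpu.memcpy pairs per operand,
  // creates the gpu sparse/dense handles, queries gpu.spmv_buffer_size, runs
  // gpu.spmv, and finally copies y back and releases all resources (sketch).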
579 // a : memR/memC/memV -> rowA,colA,valA 580 // x : memX -> vecX 581 // y : memY -> vecY 582 Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); 583 Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0); 584 Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1); 585 Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); 586 Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty 587 Value memV = rewriter.create<ToValuesOp>(loc, a); 588 Value rowA = genAllocCopy(rewriter, loc, memR, tokens); 589 Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); 590 Value valA = genAllocCopy(rewriter, loc, memV, tokens); 591 Value memX = genTensorToMemref(rewriter, loc, x); 592 Value vecX = genAllocCopy(rewriter, loc, memX, tokens); 593 Value memY = genTensorToMemref(rewriter, loc, y); 594 Value vecY = genAllocCopy(rewriter, loc, memY, tokens); 595 genBlockingWait(rewriter, loc, tokens); 596 tokens.clear(); 597 598 // Create sparse environment and sparse matrix/dense vector handles. 599 Type indexTp = rewriter.getIndexType(); 600 Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); 601 Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 602 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 603 Value token = genFirstWait(rewriter, loc); 604 Operation *spGenA = 605 genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX, 606 nseA, rowA, colA, valA, format, enableRT); 607 Value spMatA = spGenA->getResult(0); 608 token = spGenA->getResult(1); 609 auto dvecX = rewriter.create<gpu::CreateDnTensorOp>( 610 loc, dnTensorHandleTp, tokenTp, token, vecX, szX); 611 Value dnX = dvecX.getResult(0); 612 token = dvecX.getAsyncToken(); 613 auto dvecY = rewriter.create<gpu::CreateDnTensorOp>( 614 loc, dnTensorHandleTp, tokenTp, token, vecY, szY); 615 Value dnY = dvecY.getResult(0); 616 token = dvecY.getAsyncToken(); 617 auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType(); 618 619 // Precompute buffersize for SpMV. 620 auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>( 621 loc, indexTp, tokenTp, token, spMatA, dnX, dnY, 622 /*computeType=*/dnYType); 623 Value bufferSz = bufferComp.getResult(0); 624 token = bufferComp.getAsyncToken(); 625 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); 626 Value buffer = buf.getResult(0); 627 token = buf.getAsyncToken(); 628 629 // Perform the SpMV. 630 auto spmvComp = rewriter.create<gpu::SpMVOp>( 631 loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer); 632 token = spmvComp.getAsyncToken(); 633 634 // Copy data back to host and free all the resoures. 635 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) 636 .getAsyncToken(); 637 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX) 638 .getAsyncToken(); 639 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY) 640 .getAsyncToken(); 641 token = genDeallocMemRef(rewriter, loc, rowA, token); 642 if (colA) 643 token = genDeallocMemRef(rewriter, loc, colA, token); 644 token = genDeallocMemRef(rewriter, loc, valA, token); 645 token = genDeallocMemRef(rewriter, loc, buffer, token); 646 token = genDeallocMemRef(rewriter, loc, vecX, token); 647 token = genCopyMemRef(rewriter, loc, memY, vecY, token); 648 token = genDeallocMemRef(rewriter, loc, vecY, token); 649 tokens.push_back(token); 650 genBlockingWait(rewriter, loc, tokens); 651 tokens.clear(); 652 653 // Done. 
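  // The original linalg.generic is replaced by a tensor view of the host
  // buffer for y, which holds the SpMV result after the copy-back above.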
654 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY); 655 return success(); 656 } 657 658 /// Match and rewrite SpMM kernel. 659 static LogicalResult rewriteSpMM(PatternRewriter &rewriter, 660 linalg::GenericOp op, bool enableRT) { 661 Location loc = op.getLoc(); 662 Value a = op.getOperand(0); 663 Value b = op.getOperand(1); 664 Value c = op.getOperand(2); // we have C = AB 665 SmallVector<Value> tokens; 666 667 // Only admissible sparse matrix format and dense matrices (no BSR). 668 SparseTensorType aTp = getSparseTensorType(a); 669 SparseTensorType bTp = getSparseTensorType(b); 670 SparseTensorType cTp = getSparseTensorType(c); 671 auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false); 672 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR) 673 return failure(); 674 675 // Start sparse kernel and copy data from host to device. 676 // a : memR/memC/memV -> rowA,colA,valA 677 // b : bufB -> matB 678 // c : bufC -> matC 679 Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); 680 Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); 681 Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); 682 Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); 683 Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); 684 Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty 685 Value memV = rewriter.create<ToValuesOp>(loc, a); 686 Value rowA = genAllocCopy(rewriter, loc, memR, tokens); 687 Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); 688 Value valA = genAllocCopy(rewriter, loc, memV, tokens); 689 Value bufB = genTensorToMemref(rewriter, loc, b); 690 Value matB = genAllocCopy(rewriter, loc, bufB, tokens); 691 Value bufC = genTensorToMemref(rewriter, loc, c); 692 Value matC = genAllocCopy(rewriter, loc, bufC, tokens); 693 genBlockingWait(rewriter, loc, tokens); 694 tokens.clear(); 695 696 // Create sparse environment and sparse matrix/dense matrix handles. 697 Type indexTp = rewriter.getIndexType(); 698 Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); 699 Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 700 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 701 Value token = genFirstWait(rewriter, loc); 702 Operation *spGenA = 703 genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk, 704 nseA, rowA, colA, valA, format, enableRT); 705 Value spMatA = spGenA->getResult(0); 706 token = spGenA->getResult(1); 707 auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( 708 loc, dnTensorHandleTp, tokenTp, token, matB, 709 SmallVector<Value>{szk, szn}); 710 Value dnB = dmatB.getResult(0); 711 token = dmatB.getAsyncToken(); 712 auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( 713 loc, dnTensorHandleTp, tokenTp, token, matC, 714 SmallVector<Value>{szm, szn}); 715 Value dnC = dmatC.getResult(0); 716 token = dmatC.getAsyncToken(); 717 auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType(); 718 719 // Precompute buffersize for SpMM. 720 auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( 721 loc, indexTp, tokenTp, token, spMatA, dnB, dnC, 722 /*computeType=*/dmatCType); 723 Value bufferSz = bufferComp.getResult(0); 724 token = bufferComp.getAsyncToken(); 725 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); 726 Value buffer = buf.getResult(0); 727 token = buf.getAsyncToken(); 728 auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType(); 729 730 // Perform the SpMM. 
731 auto spmmComp = rewriter.create<gpu::SpMMOp>( 732 loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer); 733 token = spmmComp.getAsyncToken(); 734 735 // Copy data back to host and free all the resoures. 736 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) 737 .getAsyncToken(); 738 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) 739 .getAsyncToken(); 740 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC) 741 .getAsyncToken(); 742 token = genDeallocMemRef(rewriter, loc, rowA, token); 743 if (colA) 744 token = genDeallocMemRef(rewriter, loc, colA, token); 745 token = genDeallocMemRef(rewriter, loc, valA, token); 746 token = genDeallocMemRef(rewriter, loc, buffer, token); 747 token = genDeallocMemRef(rewriter, loc, matB, token); 748 token = genCopyMemRef(rewriter, loc, bufC, matC, token); 749 token = genDeallocMemRef(rewriter, loc, matC, token); 750 tokens.push_back(token); 751 genBlockingWait(rewriter, loc, tokens); 752 tokens.clear(); 753 754 // Done. 755 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC); 756 return success(); 757 } 758 759 // Match and rewrite SpGEMM kernel. 760 static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, 761 linalg::GenericOp op, bool enableRT) { 762 Location loc = op.getLoc(); 763 Value a = op.getOperand(0); 764 Value b = op.getOperand(1); 765 Value c = op.getOperand(2); // we have C = AB 766 SmallVector<Value> tokens; 767 768 // Only CSR <- CSR x CSR supported. 769 auto format = CuSparseFormat::kCSR; 770 SparseTensorType aTp = getSparseTensorType(a); 771 SparseTensorType bTp = getSparseTensorType(b); 772 SparseTensorType cTp = getSparseTensorType(c); 773 if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp)) 774 return failure(); 775 776 // Start sparse kernel and copy data from host to device. 777 // a : amemR/amemC/amemV -> rowA,colA,valA 778 // b : bmemR/bmemC/bmemV -> rowB,colB,valB 779 // c : materializes 780 auto dnCType = cTp.getElementType(); 781 Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); 782 Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b); 783 Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); 784 Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); 785 Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); 786 Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); 787 Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty 788 Value amemV = rewriter.create<ToValuesOp>(loc, a); 789 Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT); 790 Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty 791 Value bmemV = rewriter.create<ToValuesOp>(loc, b); 792 Value rowA = genAllocCopy(rewriter, loc, amemR, tokens); 793 Value colA = genAllocCopy(rewriter, loc, amemC, tokens); 794 Value valA = genAllocCopy(rewriter, loc, amemV, tokens); 795 Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens); 796 Value colB = genAllocCopy(rewriter, loc, bmemC, tokens); 797 Value valB = genAllocCopy(rewriter, loc, bmemV, tokens); 798 genBlockingWait(rewriter, loc, tokens); 799 tokens.clear(); 800 801 // Create sparse environment and sparse matrix/dense vector handles. 
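  // For SpGEMM both A and B get sparse matrix handles; C starts out with
  // (mostly empty) buffers and is re-pointed once its nonzero count is known.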
802 Type indexTp = rewriter.getIndexType(); 803 Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 804 Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>(); 805 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 806 Value token = genFirstWait(rewriter, loc); 807 Operation *spGenA = 808 genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk, 809 nseA, rowA, colA, valA, format, enableRT); 810 Value spMatA = spGenA->getResult(0); 811 token = spGenA->getResult(1); 812 Operation *spGenB = 813 genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn, 814 nseB, rowB, colB, valB, format, enableRT); 815 Value spMatB = spGenB->getResult(0); 816 token = spGenB->getResult(1); 817 818 // Sparse matrix C materializes (also assumes beta == 0). 819 Value zero = constantIndex(rewriter, loc, 0); 820 Value one = constantIndex(rewriter, loc, 1); 821 Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one); 822 auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token); 823 Value rowC = e1.getResult(0); 824 token = e1.getAsyncToken(); 825 auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token); 826 Value colC = e2.getResult(0); // no free needed 827 token = e2.getAsyncToken(); 828 auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token); 829 Value valC = e3.getResult(0); // no free needed 830 token = e3.getAsyncToken(); 831 Operation *spGenC = 832 genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn, 833 zero, rowC, colC, valC, format, enableRT); 834 Value spMatC = spGenC->getResult(0); 835 token = spGenC->getResult(1); 836 837 // Precompute buffersizes for SpGEMM. 838 Operation *descOp = 839 rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token); 840 Value desc = descOp->getResult(0); 841 token = descOp->getResult(1); 842 Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( 843 loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, 844 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero, 845 valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION); 846 Value bufferSz1 = work1->getResult(0); 847 token = work1->getResult(1); 848 auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token); 849 Value buffer1 = buf1.getResult(0); 850 token = buf1.getAsyncToken(); 851 Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( 852 loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, 853 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, 854 bufferSz1, buffer1, 855 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION); 856 token = work2->getResult(1); 857 858 // Compute step. 
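  // Same op as above but with kind COMPUTE: the first call returns the
  // required buffer size, the second performs the actual numeric work.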
859 Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( 860 loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, 861 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero, 862 valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE); 863 Value bufferSz2 = compute1->getResult(0); 864 token = compute1->getResult(1); 865 auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token); 866 Value buffer2 = buf2.getResult(0); 867 token = buf2.getAsyncToken(); 868 Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( 869 loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, 870 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, 871 bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE); 872 token = compute2->getResult(1); 873 874 // Get sizes. 875 Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>( 876 loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC); 877 Value nnz = sizes->getResult(2); 878 token = sizes->getResult(3); 879 auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token); 880 colC = a2.getResult(0); 881 token = a2.getAsyncToken(); 882 auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token); 883 valC = a3.getResult(0); 884 token = a3.getAsyncToken(); 885 886 // Update C with new pointers and copy final product back into C. 887 Operation *update = rewriter.create<gpu::SetCsrPointersOp>( 888 loc, tokenTp, token, spMatC, rowC, colC, valC); 889 token = update->getResult(0); 890 Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>( 891 loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, 892 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType); 893 token = copy->getResult(0); 894 895 // Allocate buffers on host. 896 Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1); 897 Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz); 898 Value valH = genHostBuffer(rewriter, loc, dnCType, nnz); 899 900 // Copy data back to host and free all the resoures. 901 token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc) 902 .getAsyncToken(); 903 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) 904 .getAsyncToken(); 905 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB) 906 .getAsyncToken(); 907 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC) 908 .getAsyncToken(); 909 token = genCopyMemRef(rewriter, loc, rowH, rowC, token); 910 token = genCopyMemRef(rewriter, loc, colH, colC, token); 911 token = genCopyMemRef(rewriter, loc, valH, valC, token); 912 token = genDeallocMemRef(rewriter, loc, rowA, token); 913 token = genDeallocMemRef(rewriter, loc, colA, token); 914 token = genDeallocMemRef(rewriter, loc, valA, token); 915 token = genDeallocMemRef(rewriter, loc, rowB, token); 916 token = genDeallocMemRef(rewriter, loc, colB, token); 917 token = genDeallocMemRef(rewriter, loc, valB, token); 918 token = genDeallocMemRef(rewriter, loc, rowC, token); 919 token = genDeallocMemRef(rewriter, loc, colC, token); 920 token = genDeallocMemRef(rewriter, loc, valC, token); 921 token = genDeallocMemRef(rewriter, loc, buffer1, token); 922 token = genDeallocMemRef(rewriter, loc, buffer2, token); 923 tokens.push_back(token); 924 genBlockingWait(rewriter, loc, tokens); 925 tokens.clear(); 926 927 // Done. 
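  // The three host buffers (positions, coordinates, values) are packed back
  // into a sparse tensor with sparse_tensor.assemble, replacing the original op.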
928 Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH); 929 Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH); 930 Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH); 931 rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct}, 932 vt); 933 return success(); 934 } 935 936 // Match and rewrite 2:4 SpMM kernel. 937 static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, 938 linalg::GenericOp op) { 939 Location loc = op.getLoc(); 940 Value A = op.getOperand(0); 941 Value B = op.getOperand(1); 942 Value C = op.getOperand(2); // we have C = AB 943 SmallVector<Value> tokens; 944 945 // The cuSparselt API currently only allows pruning and compression 946 // to occur on the device. So we recognize the pattern 947 // A' = convert A ; dense to 2:4 948 // C = A'B ; 2:4 matrix mult 949 // and then perform compression and matrix multiplication on device. 950 auto cnv = A.getDefiningOp<ConvertOp>(); 951 assert(cnv); 952 A = cnv.getSource(); 953 954 // All input should be dense tensors. 955 if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C)) 956 return failure(); 957 958 // Start sparse kernel and copy data from host to device. 959 // a : bufA -> matA 960 // b : bufB -> matB 961 // c : bufC -> matC 962 Value bufA = genTensorToMemref(rewriter, loc, A); 963 Value matA = genAllocCopy(rewriter, loc, bufA, tokens); 964 Value bufB = genTensorToMemref(rewriter, loc, B); 965 Value matB = genAllocCopy(rewriter, loc, bufB, tokens); 966 Value bufC = genTensorToMemref(rewriter, loc, C); 967 Value matC = genAllocCopy(rewriter, loc, bufC, tokens); 968 genBlockingWait(rewriter, loc, tokens); 969 tokens.clear(); 970 971 // Create sparse environment and sparse matrix/dense vector handles. 972 Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0); 973 Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0); 974 Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1); 975 Type indexTp = rewriter.getIndexType(); 976 Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); 977 Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 978 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 979 Value token = genFirstWait(rewriter, loc); 980 Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>( 981 loc, spMatHandleTp, tokenTp, token, szm, szk, 982 gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA); 983 Value spMatA = spGenA->getResult(0); 984 token = spGenA->getResult(1); 985 auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( 986 loc, dnTensorHandleTp, tokenTp, token, matB, 987 SmallVector<Value>{szk, szn}); 988 Value dnB = dmatB.getResult(0); 989 token = dmatB.getAsyncToken(); 990 auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( 991 loc, dnTensorHandleTp, tokenTp, token, matC, 992 SmallVector<Value>{szm, szn}); 993 Value dnC = dmatC.getResult(0); 994 token = dmatC.getAsyncToken(); 995 auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType(); 996 997 // Precompute buffersize for SpMM. 998 SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp}; 999 TypeRange bufferTypes(bufferTypes_); 1000 auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( 1001 loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE, 1002 gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC, 1003 /*computeType=*/dmatCType); 1004 token = bufferComp.getAsyncToken(); 1005 1006 // Allocate buffers on host. 
1007 Value bufferSz1 = bufferComp.getResult(0); 1008 auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token); 1009 Value buffer1 = buf1.getResult(0); 1010 token = buf1.getAsyncToken(); 1011 Value bufferSz2 = bufferComp.getResult(1); 1012 auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token); 1013 Value buffer2 = buf2.getResult(0); 1014 token = buf2.getAsyncToken(); 1015 Value bufferSz3 = bufferComp.getResult(2); 1016 auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token); 1017 Value buffer3 = buf3.getResult(0); 1018 token = buf3.getAsyncToken(); 1019 1020 // Perform the SpMM. 1021 auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType(); 1022 auto spmmComp = rewriter.create<gpu::SpMMOp>( 1023 loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, 1024 SmallVector<Value>{buffer1, buffer2, buffer3}); 1025 token = spmmComp.getAsyncToken(); 1026 1027 // Copy data back to host and free all the resources. 1028 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) 1029 .getAsyncToken(); 1030 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) 1031 .getAsyncToken(); 1032 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC) 1033 .getAsyncToken(); 1034 SmallVector<Value> newDynamicSizes; 1035 token = genDeallocMemRef(rewriter, loc, buffer1, token); 1036 token = genDeallocMemRef(rewriter, loc, buffer2, token); 1037 token = genDeallocMemRef(rewriter, loc, buffer3, token); 1038 token = genDeallocMemRef(rewriter, loc, matA, token); 1039 token = genDeallocMemRef(rewriter, loc, matB, token); 1040 token = genCopyMemRef(rewriter, loc, bufC, matC, token); 1041 token = genDeallocMemRef(rewriter, loc, matC, token); 1042 tokens.push_back(token); 1043 genBlockingWait(rewriter, loc, tokens); 1044 tokens.clear(); 1045 1046 // Done. 1047 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC); 1048 return success(); 1049 } 1050 1051 /// Match and rewrite SDDMM kernel. 1052 static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, 1053 linalg::GenericOp op, bool enableRT) { 1054 Location loc = op.getLoc(); 1055 Value a = op.getOperand(0); 1056 Value b = op.getOperand(1); 1057 Value c = op.getOperand(2); 1058 SmallVector<Value> tokens; 1059 1060 // Only admissible sparse matrix format (no COO/CSC) and dense matrices. 1061 SparseTensorType aTp = getSparseTensorType(a); 1062 SparseTensorType bTp = getSparseTensorType(b); 1063 SparseTensorType cTp = getSparseTensorType(c); 1064 auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false); 1065 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO || 1066 format == CuSparseFormat::kCSC) 1067 return failure(); 1068 1069 // The SDDMM does the in-place operation. 1070 // Start sparse kernel and copy data from host to device. 
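  // That is, c += spy(c) x (a * b): the dense product a*b is only sampled
  // at the nonzero positions of the sparse output c.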
1071 // a : bufA -> matA 1072 // b : bufB -> matB 1073 // c : memR/memC/memV -> rowC,colC,valC 1074 Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c); 1075 Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); 1076 Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); 1077 Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); 1078 Value bufA = genTensorToMemref(rewriter, loc, a); 1079 Value matA = genAllocCopy(rewriter, loc, bufA, tokens); 1080 Value bufB = genTensorToMemref(rewriter, loc, b); 1081 Value matB = genAllocCopy(rewriter, loc, bufB, tokens); 1082 Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT); 1083 Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty 1084 Value memV = rewriter.create<ToValuesOp>(loc, c); 1085 Value rowC = genAllocCopy(rewriter, loc, memR, tokens); 1086 Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); 1087 Value valC = genAllocCopy(rewriter, loc, memV, tokens); 1088 genBlockingWait(rewriter, loc, tokens); 1089 tokens.clear(); 1090 1091 // Create sparse environment and sparse matrix/dense matrix handles. 1092 Type indexTp = rewriter.getIndexType(); 1093 Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); 1094 Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 1095 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 1096 Value token = genFirstWait(rewriter, loc); 1097 auto dmatA = rewriter.create<gpu::CreateDnTensorOp>( 1098 loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk}); 1099 Value dnA = dmatA.getResult(0); 1100 token = dmatA.getAsyncToken(); 1101 auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( 1102 loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn}); 1103 Value dnB = dmatB.getResult(0); 1104 token = dmatB.getAsyncToken(); 1105 Operation *spGenC = 1106 genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn, 1107 nseC, rowC, colC, valC, format, enableRT); 1108 Value spMatC = spGenC->getResult(0); 1109 token = spGenC->getResult(1); 1110 auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType(); 1111 1112 // Precompute buffersize for SDDMM. 1113 auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>( 1114 loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType); 1115 Value bufferSz = bufferComp.getResult(0); 1116 token = bufferComp.getAsyncToken(); 1117 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); 1118 Value buffer = buf.getResult(0); 1119 token = buf.getAsyncToken(); 1120 1121 // Perform the SDDMM. 1122 auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB, 1123 spMatC, dnCType, buffer); 1124 token = sddmmComp.getAsyncToken(); 1125 1126 // Copy data back to host and free all the resoures. 
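  // Only the values buffer of C is copied back; SDDMM leaves the sparsity
  // pattern (positions/coordinates) of C unchanged.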
1127 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA) 1128 .getAsyncToken(); 1129 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) 1130 .getAsyncToken(); 1131 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC) 1132 .getAsyncToken(); 1133 token = genDeallocMemRef(rewriter, loc, buffer, token); 1134 token = genDeallocMemRef(rewriter, loc, matA, token); 1135 token = genDeallocMemRef(rewriter, loc, matB, token); 1136 token = genDeallocMemRef(rewriter, loc, rowC, token); 1137 if (colC) 1138 token = genDeallocMemRef(rewriter, loc, colC, token); 1139 token = genCopyMemRef(rewriter, loc, memV, valC, token); 1140 token = genDeallocMemRef(rewriter, loc, valC, token); 1141 tokens.push_back(token); 1142 genBlockingWait(rewriter, loc, tokens); 1143 tokens.clear(); 1144 1145 // Done. 1146 rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c); 1147 return success(); 1148 } 1149 1150 //===----------------------------------------------------------------------===// 1151 // Rewriting rules for direct code generation. 1152 //===----------------------------------------------------------------------===// 1153 1154 /// Proof-of-concept rewriter. This rule generates a GPU implementation 1155 /// for each outermost forall loop generated by the sparsifier. 1156 /// TODO: right now works with parallelization-strategy=dense-outer-loop 1157 /// but give this its own flags in the future 1158 struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> { 1159 using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; 1160 1161 ForallRewriter(MLIRContext *context, unsigned nT) 1162 : OpRewritePattern(context), numThreads(nT){}; 1163 1164 LogicalResult matchAndRewrite(scf::ParallelOp forallOp, 1165 PatternRewriter &rewriter) const override { 1166 // Reject inadmissible loop form. 1167 // Essentially only accept a loop, generated by the sparsifier, 1168 // of the form 1169 // forall (i = 0; i < N; i++) 1170 // so that cyclic scheduling over the threads is easy. 1171 if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) || 1172 forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 || 1173 !matchPattern(forallOp.getLowerBound()[0], m_Zero()) || 1174 !matchPattern(forallOp.getStep()[0], m_One())) 1175 return failure(); 1176 // Collect every value that is computed outside the parallel loop. 1177 SetVector<Value> invariants; // stable iteration! 1178 forallOp->walk([&](Operation *op) { 1179 // Collect all values of admissible ops. 1180 for (OpOperand &o : op->getOpOperands()) { 1181 Value val = o.get(); 1182 Block *block; 1183 if (auto arg = dyn_cast<BlockArgument>(val)) 1184 block = arg.getOwner(); 1185 else 1186 block = val.getDefiningOp()->getBlock(); 1187 if (!forallOp.getRegion().findAncestorBlockInRegion(*block)) 1188 invariants.insert(val); 1189 } 1190 }); 1191 // Outline the outside values as proper parameters. Fail when sharing 1192 // value between host and device is not straightforward. 1193 SmallVector<Value> constants; 1194 SmallVector<Value> scalars; 1195 SmallVector<Value> buffers; 1196 for (Value val : invariants) { 1197 Type tp = val.getType(); 1198 if (val.getDefiningOp<arith::ConstantOp>()) 1199 constants.push_back(val); 1200 else if (isa<FloatType>(tp) || tp.isIntOrIndex()) 1201 scalars.push_back(val); 1202 else if (isa<MemRefType>(tp)) 1203 buffers.push_back(val); 1204 else 1205 return failure(); // don't know how to share 1206 } 1207 // Pass outlined non-constant values. 
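    // Constants are re-materialized inside the kernel, scalars are passed by
    // value, and memrefs are copied into device allocations (genParametersIn).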
    // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
    // keep the feature at all (either through a heuristic or compiler
    // option for gpu codegen).
    Location loc = forallOp->getLoc();
    SmallVector<Value> args;
    SmallVector<Value> tokens;
    Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
                                /*useHostRegistrationForOut=*/false);
    // Set up GPU module and construct GPU function.
    auto saveIp = rewriter.saveInsertionPoint();
    ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
    auto gpuModule = genGPUModule(rewriter, topModule);
    auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
    genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
    // Generate code that launches the kernel asynchronously, blocking on all
    // open tokens and yielding a new token for the output.
    // TODO: Passing tokens into the launch op does not seem to be properly
    // lowered by cubin yet, hence the current blocking wait.
    rewriter.restoreInsertionPoint(saveIp);
    genBlockingWait(rewriter, loc, tokens);
    tokens.clear();
    Value kernelToken =
        genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
    // Finalize the outlined arguments.
    genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
                     tokens);
    genBlockingWait(rewriter, loc, tokens);
    rewriter.eraseOp(forallOp);
    return success();
  }

private:
  unsigned numThreads;
};

//===----------------------------------------------------------------------===//
// Rewriting rules for library recognition and code generation.
//===----------------------------------------------------------------------===//

/// Proof-of-concept rewriter. This rule recognizes certain math kernels
/// and replaces these with corresponding calls into a sparse library.
struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LinalgOpRewriter(MLIRContext *context, bool rt)
      : OpRewritePattern(context), enableRT(rt) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumDpsInits() != 1)
      return failure(); // reject multi-output

    const unsigned numLoops = op.getNumLoops();
    const unsigned numTensors = op->getNumOperands();
    const auto iteratorTypes = op.getIteratorTypesArray();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr i, j, k;
    bindDims(getContext(), i, j, k);

    // TODO: more robust patterns, transposed versions, more kernels,
    // identify alpha and beta and pass them to the CUDA calls.

    // Recognize a SpMV kernel.
    if (numLoops == 2 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isReductionIterator(iteratorTypes[1]) &&
        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
      return rewriteSpMV(rewriter, op, enableRT);
    }

    // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
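    // That is, indexing maps {(i,k), (k,j), (i,j)} with iterator types
    // ["parallel", "parallel", "reduction"] and a body that reduces x += a * b.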
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
      if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
        return rewriteSpGEMM(rewriter, op, enableRT);
      if (isConversionInto24(op.getOperand(0)))
        return rewrite2To4SpMM(rewriter, op);
      return rewriteSpMM(rewriter, op, enableRT);
    }

    // Recognize an SDDMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) &&
        matchSumReductionOfMulUnary(op)) {
      return rewriteSDDMM(rewriter, op, enableRT);
    }

    return failure();
  }

private:
  bool enableRT;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating GPU rewriting rules.
//
// Currently two sets of rewriting rules are made available. The first set
// implements direct code generation, currently by means of converting the
// outermost parallel loop into GPU threads. The second set implements
// library recognition of a set of sparse operations. Eventually, the right
// combination of these two approaches has to be found.
//===----------------------------------------------------------------------===//

void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                            unsigned numThreads) {
  patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
}

void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
                                           bool enableRT) {
  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
}
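// A typical way to exercise these entry points from a pass (sketch only, with
// an assumed greedy rewrite driver; not part of this file):
//   RewritePatternSet patterns(ctx);
//   populateSparseGPULibgenPatterns(patterns, /*enableRT=*/true);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));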