//===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a prototype GPU code generator for the sparse compiler.
// The objective is to eventually use the right combination of
// direct code generation and library calls into vendor-specific
// highly optimized sparse libraries (e.g. cuSparse for CUDA).
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "LoopEmitter.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Sparse formats supported by cuSparse.
enum class CuSparseFormat {
  kNone,
  kCOO,
  kCSR,
  kCSC,
  kBSR,
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Marks the given top module as a GPU container module.
static void markAsGPUContainer(ModuleOp topModule) {
  topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
                     UnitAttr::get(topModule->getContext()));
}

/// Constructs a new GPU module (for GPU kernels) inside the given top module,
/// or returns an existing GPU module if one was built previously.
static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
  for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
    return op; // existing
  markAsGPUContainer(topModule);
  builder.setInsertionPointToStart(&topModule.getBodyRegion().front());
  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
                                          "sparse_kernels");
}

/// Constructs a new GPU kernel in the given GPU module.
static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
                                 SmallVectorImpl<Value> &args) {
  // Get a unique kernel name. Not very creative,
  // but we simply try kernel0, kernel1, etc.
  unsigned kernelNumber = 0;
  SmallString<16> kernelName;
  do {
    kernelName.clear();
    ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
  } while (gpuModule.lookupSymbol(kernelName));
  // Then we insert a new kernel with given arguments into the module.
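  // For illustration only (the exact argument types depend on the values
  // outlined from the loop), the kernel created here has roughly the shape:
  //
  //   gpu.module @sparse_kernels {
  //     gpu.func @kernel0(%arg0: index, %arg1: memref<?xf64>, ...) kernel {
  //       ...
  //       gpu.return
  //     }
  //   }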
78 builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front()); 79 SmallVector<Type> argsTp; 80 for (unsigned i = 0, e = args.size(); i < e; i++) 81 argsTp.push_back(args[i].getType()); 82 FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {}); 83 auto gpuFunc = 84 builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type); 85 gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), 86 builder.getUnitAttr()); 87 return gpuFunc; 88 } 89 90 /// Constructs code to launch GPU kernel. 91 static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc, 92 SmallVectorImpl<Value> &args, 93 SmallVectorImpl<Value> &tokens, 94 unsigned numThreads) { 95 Location loc = gpuFunc->getLoc(); 96 Value none = TypedValue<::mlir::IntegerType>{}; 97 Value one = constantIndex(builder, loc, 1); 98 Value numT = constantIndex(builder, loc, numThreads); 99 gpu::KernelDim3 gridSize = {one, one, one}; 100 gpu::KernelDim3 blckSize = {numT, one, one}; 101 return builder 102 .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize, 103 /*dynSharedMemSz*/ none, args, 104 builder.getType<gpu::AsyncTokenType>(), tokens) 105 .getAsyncToken(); 106 } 107 108 /// Maps the provided ranked host buffer into the device address space. 109 /// Writes from the host are guaranteed to be visible to device kernels 110 /// that are launched afterwards. Writes from the device are guaranteed 111 /// to be visible on the host after synchronizing with the device kernel 112 /// completion. Needs to cast the buffer to a unranked buffer. 113 static Value genHostRegisterMemref(OpBuilder &builder, Location loc, 114 Value mem) { 115 MemRefType memTp = cast<MemRefType>(mem.getType()); 116 UnrankedMemRefType resTp = 117 UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0); 118 Value cast = builder.create<memref::CastOp>(loc, resTp, mem); 119 builder.create<gpu::HostRegisterOp>(loc, cast); 120 return cast; 121 } 122 123 /// Unmaps the provided buffer, expecting the casted buffer. 124 static void genHostUnregisterMemref(OpBuilder &builder, Location loc, 125 Value cast) { 126 builder.create<gpu::HostUnregisterOp>(loc, cast); 127 } 128 129 /// Generates first wait in an asynchronous chain. 130 static Value genFirstWait(OpBuilder &builder, Location loc) { 131 Type tokenType = builder.getType<gpu::AsyncTokenType>(); 132 return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange()) 133 .getAsyncToken(); 134 } 135 136 /// Generates last, blocking wait in an asynchronous chain. 137 static void genBlockingWait(OpBuilder &builder, Location loc, 138 ValueRange operands) { 139 builder.create<gpu::WaitOp>(loc, Type(), operands); 140 } 141 142 /// Allocates memory on the device. 143 /// TODO: A `host_shared` attribute could be used to indicate that 144 /// the buffer is visible by both host and device, but lowering 145 /// that feature does not seem to be fully supported yet. 
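/// A minimal sketch (value names are illustrative) of what this emits for a
/// one-dimensional buffer with a dynamic size:
///
///   %devMem, %t1 = gpu.alloc async [%token] (%sz) : memref<?xf64>
///
/// where %sz is obtained from the host buffer with a dim operation.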
146 static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem, 147 Value token) { 148 auto tp = cast<ShapedType>(mem.getType()); 149 auto elemTp = tp.getElementType(); 150 auto shape = tp.getShape(); 151 auto memTp = MemRefType::get(shape, elemTp); 152 SmallVector<Value> dynamicSizes; 153 for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) { 154 if (shape[r] == ShapedType::kDynamic) { 155 Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r); 156 dynamicSizes.push_back(dimOp); 157 } 158 } 159 return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}), 160 token, dynamicSizes, ValueRange()); 161 } 162 163 // Allocates a typed buffer on the host with given size. 164 static Value genHostBuffer(OpBuilder &builder, Location loc, Type type, 165 Value size) { 166 const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); 167 return builder.create<memref::AllocOp>(loc, memTp, size).getResult(); 168 } 169 170 // Allocates a typed buffer on the device with given size. 171 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type, 172 Value size, Value token) { 173 const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); 174 return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}), 175 token, size, ValueRange()); 176 } 177 178 // Allocates a void buffer on the device with given size. 179 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size, 180 Value token) { 181 return genAllocBuffer(builder, loc, builder.getI8Type(), size, token); 182 } 183 184 /// Deallocates memory from the device. 185 static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem, 186 Value token) { 187 return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem) 188 .getAsyncToken(); 189 } 190 191 /// Copies memory between host and device (direction is implicit). 192 static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst, 193 Value src, Value token) { 194 return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src) 195 .getAsyncToken(); 196 } 197 198 /// Generates an alloc/copy pair. 199 static Value genAllocCopy(OpBuilder &builder, Location loc, Value b, 200 SmallVectorImpl<Value> &tokens) { 201 Value firstToken = genFirstWait(builder, loc); 202 auto alloc = genAllocMemRef(builder, loc, b, firstToken); 203 Value devMem = alloc.getResult(0); 204 Value depToken = alloc.getAsyncToken(); // copy-after-alloc 205 tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken)); 206 return devMem; 207 } 208 209 /// Generates a memref from tensor operation. 210 static Value genTensorToMemref(PatternRewriter &rewriter, Location loc, 211 Value tensor) { 212 auto tensorType = llvm::cast<ShapedType>(tensor.getType()); 213 auto memrefType = 214 MemRefType::get(tensorType.getShape(), tensorType.getElementType()); 215 return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor); 216 } 217 218 /// Prepares the outlined arguments, passing scalars and buffers in. Here we 219 /// assume that the first buffer is the one allocated for output. We create 220 /// a set of properly chained asynchronous allocation/copy pairs to increase 221 /// overlap before launching the kernel. 
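/// For every buffer, the chain emitted below looks roughly as follows
/// (a sketch; value names are placeholders):
///
///   %t0 = gpu.wait async
///   %d, %t1 = gpu.alloc async [%t0] (...) : memref<?xf64>
///   %t2 = gpu.memcpy async [%t1] %d, %host : memref<?xf64>, memref<?xf64>
///
/// so each copy only waits for its own allocation, and all pairs can run
/// concurrently until the collected tokens are joined before the launch.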
222 static Value genParametersIn(OpBuilder &builder, Location loc, 223 SmallVectorImpl<Value> &scalars, 224 SmallVectorImpl<Value> &buffers, 225 SmallVectorImpl<Value> &args, 226 SmallVectorImpl<Value> &tokens, 227 bool useHostRegistrationForOut) { 228 Value out; 229 // Scalars are passed by value. 230 for (Value s : scalars) 231 args.push_back(s); 232 // Buffers are need to be made visible on device. 233 for (Value b : buffers) { 234 if (useHostRegistrationForOut) { 235 out = genHostRegisterMemref(builder, loc, b); 236 args.push_back(b); 237 useHostRegistrationForOut = false; 238 continue; 239 } 240 args.push_back(genAllocCopy(builder, loc, b, tokens)); 241 } 242 return out; 243 } 244 245 /// Finalizes the outlined arguments. The output buffer is copied depending 246 /// on the kernel token and then deallocated. All other buffers are simply 247 /// deallocated. Then we wait for all operations to complete. 248 static void genParametersOut(OpBuilder &builder, Location loc, Value out, 249 Value kernelToken, SmallVectorImpl<Value> &scalars, 250 SmallVectorImpl<Value> &buffers, 251 SmallVectorImpl<Value> &args, 252 SmallVectorImpl<Value> &tokens) { 253 unsigned base = scalars.size(); 254 for (unsigned i = base, e = args.size(); i < e; i++) { 255 Value firstToken; 256 if (i == base) { 257 // Assumed output parameter: unregister or copy-out. 258 if (out) { 259 genHostUnregisterMemref(builder, loc, out); 260 out = Value(); 261 continue; 262 } 263 firstToken = 264 genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken); 265 } else { 266 firstToken = genFirstWait(builder, loc); 267 } 268 tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken)); 269 } 270 } 271 272 /// Constructs code for new GPU kernel. 273 static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc, 274 scf::ParallelOp forallOp, 275 SmallVectorImpl<Value> &constants, 276 SmallVectorImpl<Value> &scalars, 277 SmallVectorImpl<Value> &buffers) { 278 Location loc = gpuFunc->getLoc(); 279 Block &block = gpuFunc.getBody().front(); 280 rewriter.setInsertionPointToStart(&block); 281 282 // Re-generate the constants, recapture all arguments. 283 unsigned arg = 0; 284 IRMapping irMap; 285 for (Value c : constants) 286 irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0)); 287 for (Value s : scalars) 288 irMap.map(s, block.getArgument(arg++)); 289 for (Value b : buffers) 290 irMap.map(b, block.getArgument(arg++)); 291 292 // Assume 1-dimensional grid/block configuration (only x dimension), 293 // so that: 294 // row = blockIdx.x * blockDim.x + threadIdx.x 295 // inc = blockDim.x * gridDim.x 296 Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x); 297 Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x); 298 Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x); 299 Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x); 300 Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz); 301 Value row = rewriter.create<arith::AddIOp>(loc, mul, tid); 302 Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz); 303 304 // Construct the iteration over the computational space that 305 // accounts for the fact that the total number of threads and 306 // the amount of work to be done usually do not match precisely. 
307 // for (r = row; r < N; r += inc) { 308 // <loop-body> 309 // } 310 Value upper = irMap.lookup(forallOp.getUpperBound()[0]); 311 scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc); 312 rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(), 313 forOp.getRegion().begin(), irMap); 314 315 // Done. 316 rewriter.setInsertionPointAfter(forOp); 317 rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc()); 318 } 319 320 //===----------------------------------------------------------------------===// 321 // Library helper methods. 322 //===----------------------------------------------------------------------===// 323 324 /// Helper to detect a + b with arguments taken from given block. 325 static bool matchAddOfArgs(Block *block, Value val) { 326 if (auto *def = val.getDefiningOp()) { 327 if (isa<arith::AddFOp, arith::AddIOp>(def)) { 328 Value a = block->getArguments()[0]; 329 Value b = block->getArguments()[1]; 330 return (def->getOperand(0) == a && def->getOperand(1) == b) || 331 (def->getOperand(0) == b && def->getOperand(1) == a); 332 } 333 } 334 return false; 335 } 336 337 /// Helper to detect a * b with arguments taken from given block. 338 static bool matchMulOfArgs(Block *block, Value val) { 339 if (auto *def = val.getDefiningOp()) { 340 if (isa<arith::MulFOp, arith::MulIOp>(def)) { 341 Value a = block->getArguments()[0]; 342 Value b = block->getArguments()[1]; 343 return (def->getOperand(0) == a && def->getOperand(1) == b) || 344 (def->getOperand(0) == b && def->getOperand(1) == a); 345 } 346 } 347 return false; 348 } 349 350 /// Helper to detect x = x + a * b 351 static bool matchSumOfMultOfArgs(linalg::GenericOp op) { 352 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); 353 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { 354 if (isa<arith::AddFOp, arith::AddIOp>(def)) { 355 Value x = op.getBlock()->getArguments()[2]; 356 return (def->getOperand(0) == x && 357 matchMulOfArgs(op.getBlock(), def->getOperand(1))) || 358 (def->getOperand(1) == x && 359 matchMulOfArgs(op.getBlock(), def->getOperand(0))); 360 } 361 } 362 return false; 363 } 364 365 // Helper to detect c += spy(s) x (a * b) 366 static bool matchSumReductionOfMulUnary(linalg::GenericOp op) { 367 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); 368 // The linalg yields a custom reduce result. 369 Value s_out = op.getBlock()->getArguments()[2]; 370 if (auto redOp = 371 yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) { 372 // The reduce consumes the output. 373 Value other; 374 if (s_out == redOp->getOperand(0)) 375 other = redOp->getOperand(1); 376 else if (s_out == redOp->getOperand(1)) 377 other = redOp->getOperand(0); 378 else 379 return false; 380 // The reduce op also consumes an unary which also consumes the output 381 // and does not define an absent value. 382 if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) { 383 if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty()) 384 return false; 385 // And the bodies are as expected. 386 auto yieldUn = cast<sparse_tensor::YieldOp>( 387 unOp.getRegion(0).front().getTerminator()); 388 auto yieldRed = cast<sparse_tensor::YieldOp>( 389 redOp.getRegion().front().getTerminator()); 390 return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) && 391 matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0)); 392 } 393 } 394 return false; 395 } 396 397 /// Test for dense tensor. 
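/// A plain ranked tensor without a sparse encoding, e.g. tensor<?x?xf64>,
/// qualifies, as does an annotated tensor whose levels are all dense.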
static bool isDenseTensor(Value v) {
  auto sTp = getSparseTensorType(v);
  return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
}

/// Test for suitable positions/coordinates width.
static bool isAdmissibleMetaData(SparseTensorType &aTp) {
  return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
}

/// Test for sorted COO matrix with suitable metadata.
static bool isAdmissibleCOO(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
         aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
         aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
         isAdmissibleMetaData(aTp);
}

/// Test for CSR matrix with suitable metadata.
static bool isAdmissibleCSR(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
         aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
         aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}

/// Test for CSC matrix with suitable metadata.
static bool isAdmissibleCSC(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
         aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
         aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}

/// Test for BSR matrix with suitable metadata.
static bool isAdmissibleBSR(SparseTensorType &aTp) {
  if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
      aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
      aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
    // CuSparse only supports "square" blocks currently.
    SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
    assert(dims.size() == 2);
    return dims[0] == dims[1] && dims[0] > 1;
  }
  return false;
}

/// Returns a suitable sparse format for the operation and given operand
/// types with cuSparse, or kNone if none is available.
static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
                                        SparseTensorType bTp,
                                        SparseTensorType cTp, bool enableRT,
                                        bool isMatVec) {
  // The other operands must be dense.
  if (bTp.hasEncoding() || cTp.hasEncoding())
    return CuSparseFormat::kNone;
  // Now check for suitable operand type for the main operand.
  if (isAdmissibleCOO(aTp))
#ifdef CUSPARSE_COO_AOS
    return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#else
    return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#endif
  if (isAdmissibleCSR(aTp))
    return CuSparseFormat::kCSR;
  if (isAdmissibleCSC(aTp))
    return CuSparseFormat::kCSC;
  if (isAdmissibleBSR(aTp))
    return CuSparseFormat::kBSR;
  return CuSparseFormat::kNone;
}

/// Generates the first positions/coordinates of a sparse matrix.
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
                               CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT)
      return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
    return genToCoordinatesBuffer(builder, loc, a);
  }
  // Formats CSR/CSC and BSR use positions at 1.
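  // For CSR, for example, the positions at level 1 form the familiar
  // row-pointer array of length #rows + 1.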
479 return genToPositions(builder, loc, a, 1); 480 } 481 482 /// Generates the second coordinates of a sparse matrix. 483 static Value genSecondCrds(OpBuilder &builder, Location loc, Value a, 484 CuSparseFormat format, bool enableRT) { 485 bool isCOO = format == CuSparseFormat::kCOO; 486 if (isCOO && !enableRT) 487 return Value(); // nothing needed 488 // Formats CSR/CSC and BSR use coordinates at 1. 489 return genToCoordinates(builder, loc, a, 1, /*cooStart=*/isCOO ? 0 : 2); 490 } 491 492 /// Generates the sparse matrix handle. 493 static Operation *genSpMat(OpBuilder &builder, Location loc, 494 SparseTensorType &aTp, Type handleTp, Type tokenTp, 495 Value token, Value sz1, Value sz2, Value nseA, 496 Value rowA, Value colA, Value valA, 497 CuSparseFormat format, bool enableRT) { 498 if (format == CuSparseFormat::kCOO) { 499 // Library uses SoA COO, direct IR uses AoS COO. 500 if (enableRT) { 501 assert(colA); 502 return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token, 503 sz1, sz2, nseA, rowA, colA, valA); 504 } 505 #ifdef CUSPARSE_COO_AOS 506 assert(!colA); 507 return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token, 508 sz1, sz2, nseA, rowA, valA); 509 #else 510 llvm_unreachable("gpu::CreateCooAoSOp is deprecated"); 511 #endif 512 } 513 assert(colA); 514 if (format == CuSparseFormat::kCSR) 515 return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1, 516 sz2, nseA, rowA, colA, valA); 517 if (format == CuSparseFormat::kCSC) 518 return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1, 519 sz2, nseA, rowA, colA, valA); 520 // BSR requires a bit more work since we need to pass in the block size 521 // and all others sizes in terms of blocks (#block-rows, #block-cols, 522 // #nonzero-blocks). 523 assert(format == CuSparseFormat::kBSR); 524 SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl()); 525 assert(dims.size() == 2 && dims[0] == dims[1]); 526 uint64_t b = dims[0]; 527 Value bSz = constantIndex(builder, loc, b); 528 Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz); 529 Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz); 530 Value bNum = builder.create<arith::DivUIOp>( 531 loc, nseA, constantIndex(builder, loc, b * b)); 532 return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows, 533 bCols, bNum, bSz, bSz, rowA, colA, 534 valA); 535 } 536 537 /// Match and rewrite SpMV kernel. 538 static LogicalResult 539 rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT, 540 GPUDataTransferStrategy gpuDataTransferStrategy) { 541 Location loc = op.getLoc(); 542 Value a = op.getOperand(0); 543 Value x = op.getOperand(1); 544 Value y = op.getOperand(2); // we have y = Ax 545 SmallVector<Value> tokens; 546 547 bool isZeroCopy = 548 gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy; 549 550 // Only admissible sparse matrix format and dense vectors (no BSR). 551 SparseTensorType aTp = getSparseTensorType(a); 552 SparseTensorType xTp = getSparseTensorType(x); 553 SparseTensorType yTp = getSparseTensorType(y); 554 auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true); 555 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR) 556 return failure(); 557 558 // Start sparse kernel and copy data from host to device. 
559 // a : memR/memC/memV -> rowA,colA,valA 560 // x : memX -> vecX 561 // y : memY -> vecY 562 Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); 563 Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0); 564 Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1); 565 Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); 566 Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); 567 Value memV = genToValues(rewriter, loc, a); 568 Value memX, memY; 569 Value castR, castC, castV, castX, castY; 570 if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) { 571 memX = genTensorToMemref(rewriter, loc, x); 572 memY = genTensorToMemref(rewriter, loc, y); 573 castR = genHostRegisterMemref(rewriter, loc, memR); 574 if (memC) 575 castC = genHostRegisterMemref(rewriter, loc, memC); 576 castV = genHostRegisterMemref(rewriter, loc, memV); 577 castX = genHostRegisterMemref(rewriter, loc, memX); 578 castY = genHostRegisterMemref(rewriter, loc, memY); 579 } 580 581 Value rowA = genAllocCopy(rewriter, loc, memR, tokens); 582 Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); 583 Value valA = genAllocCopy(rewriter, loc, memV, tokens); 584 if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA) 585 memX = genTensorToMemref(rewriter, loc, x); 586 Value vecX = isZeroCopy ? memX : genAllocCopy(rewriter, loc, memX, tokens); 587 if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA) 588 memY = genTensorToMemref(rewriter, loc, y); 589 Value vecY = genAllocCopy(rewriter, loc, memY, tokens); 590 genBlockingWait(rewriter, loc, tokens); 591 tokens.clear(); 592 593 // Create sparse environment and sparse matrix/dense vector handles. 594 Type indexTp = rewriter.getIndexType(); 595 Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); 596 Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 597 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 598 Value token = genFirstWait(rewriter, loc); 599 Operation *spGenA = 600 genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX, 601 nseA, rowA, colA, valA, format, enableRT); 602 Value spMatA = spGenA->getResult(0); 603 token = spGenA->getResult(1); 604 auto dvecX = rewriter.create<gpu::CreateDnTensorOp>( 605 loc, dnTensorHandleTp, tokenTp, token, vecX, szX); 606 Value dnX = dvecX.getResult(0); 607 token = dvecX.getAsyncToken(); 608 auto dvecY = rewriter.create<gpu::CreateDnTensorOp>( 609 loc, dnTensorHandleTp, tokenTp, token, vecY, szY); 610 Value dnY = dvecY.getResult(0); 611 token = dvecY.getAsyncToken(); 612 auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType(); 613 614 // Precompute buffersize for SpMV. 615 auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>( 616 loc, indexTp, tokenTp, token, spMatA, dnX, dnY, 617 /*computeType=*/dnYType); 618 Value bufferSz = bufferComp.getResult(0); 619 token = bufferComp.getAsyncToken(); 620 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); 621 Value buffer = buf.getResult(0); 622 token = buf.getAsyncToken(); 623 624 // Perform the SpMV. 625 auto spmvComp = rewriter.create<gpu::SpMVOp>( 626 loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer); 627 token = spmvComp.getAsyncToken(); 628 629 // Copy data back to host and free all the resoures. 
630 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) 631 .getAsyncToken(); 632 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX) 633 .getAsyncToken(); 634 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY) 635 .getAsyncToken(); 636 token = genDeallocMemRef(rewriter, loc, rowA, token); 637 if (colA) 638 token = genDeallocMemRef(rewriter, loc, colA, token); 639 token = genDeallocMemRef(rewriter, loc, valA, token); 640 token = genDeallocMemRef(rewriter, loc, buffer, token); 641 if (!isZeroCopy) 642 token = genDeallocMemRef(rewriter, loc, vecX, token); 643 token = genCopyMemRef(rewriter, loc, memY, vecY, token); 644 token = genDeallocMemRef(rewriter, loc, vecY, token); 645 tokens.push_back(token); 646 genBlockingWait(rewriter, loc, tokens); 647 tokens.clear(); 648 if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) { 649 genHostUnregisterMemref(rewriter, loc, castR); 650 if (memC) 651 genHostUnregisterMemref(rewriter, loc, castC); 652 genHostUnregisterMemref(rewriter, loc, castV); 653 genHostUnregisterMemref(rewriter, loc, castX); 654 genHostUnregisterMemref(rewriter, loc, castY); 655 } 656 657 // Done. 658 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY); 659 return success(); 660 } 661 662 /// Match and rewrite SpMM kernel. 663 static LogicalResult 664 rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT, 665 GPUDataTransferStrategy gpuDataTransferStrategy) { 666 Location loc = op.getLoc(); 667 Value a = op.getOperand(0); 668 Value b = op.getOperand(1); 669 Value c = op.getOperand(2); // we have C = AB 670 SmallVector<Value> tokens; 671 672 bool isZeroCopy = 673 gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy; 674 675 // Only admissible sparse matrix format and dense matrices (no BSR). 676 SparseTensorType aTp = getSparseTensorType(a); 677 SparseTensorType bTp = getSparseTensorType(b); 678 SparseTensorType cTp = getSparseTensorType(c); 679 auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false); 680 if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR) 681 return failure(); 682 683 // Start sparse kernel and copy data from host to device. 684 // a : memR/memC/memV -> rowA,colA,valA 685 // b : bufB -> matA 686 // c : bufC -> matC 687 Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); 688 Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); 689 Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); 690 Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); 691 Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); 692 Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); 693 Value memV = genToValues(rewriter, loc, a); 694 Value bufB, bufC; 695 Value castR, castC, castV, castB, castBufC; 696 if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) { 697 bufB = genTensorToMemref(rewriter, loc, b); 698 bufC = genTensorToMemref(rewriter, loc, c); 699 castR = genHostRegisterMemref(rewriter, loc, memR); 700 if (memC) 701 castC = genHostRegisterMemref(rewriter, loc, memC); 702 castV = genHostRegisterMemref(rewriter, loc, memV); 703 castB = genHostRegisterMemref(rewriter, loc, bufB); 704 castBufC = genHostRegisterMemref(rewriter, loc, bufC); 705 } 706 Value rowA = genAllocCopy(rewriter, loc, memR, tokens); 707 Value colA = memC ? 
genAllocCopy(rewriter, loc, memC, tokens) : Value(); 708 Value valA = genAllocCopy(rewriter, loc, memV, tokens); 709 if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA) 710 bufB = genTensorToMemref(rewriter, loc, b); 711 Value matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens); 712 if (gpuDataTransferStrategy == GPUDataTransferStrategy::kRegularDMA) 713 bufC = genTensorToMemref(rewriter, loc, c); 714 Value matC = genAllocCopy(rewriter, loc, bufC, tokens); 715 genBlockingWait(rewriter, loc, tokens); 716 tokens.clear(); 717 718 // Create sparse environment and sparse matrix/dense matrix handles. 719 Type indexTp = rewriter.getIndexType(); 720 Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); 721 Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 722 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 723 Value token = genFirstWait(rewriter, loc); 724 Operation *spGenA = 725 genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk, 726 nseA, rowA, colA, valA, format, enableRT); 727 Value spMatA = spGenA->getResult(0); 728 token = spGenA->getResult(1); 729 auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( 730 loc, dnTensorHandleTp, tokenTp, token, matB, 731 SmallVector<Value>{szk, szn}); 732 Value dnB = dmatB.getResult(0); 733 token = dmatB.getAsyncToken(); 734 auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( 735 loc, dnTensorHandleTp, tokenTp, token, matC, 736 SmallVector<Value>{szm, szn}); 737 Value dnC = dmatC.getResult(0); 738 token = dmatC.getAsyncToken(); 739 auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType(); 740 741 // Precompute buffersize for SpMM. 742 auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( 743 loc, indexTp, tokenTp, token, spMatA, dnB, dnC, 744 /*computeType=*/dmatCType); 745 Value bufferSz = bufferComp.getResult(0); 746 token = bufferComp.getAsyncToken(); 747 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); 748 Value buffer = buf.getResult(0); 749 token = buf.getAsyncToken(); 750 auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType(); 751 752 // Perform the SpMM. 753 auto spmmComp = rewriter.create<gpu::SpMMOp>( 754 loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer); 755 token = spmmComp.getAsyncToken(); 756 757 // Copy data back to host and free all the resoures. 
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  if (!isZeroCopy)
    token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();
  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
    genHostUnregisterMemref(rewriter, loc, castR);
    if (memC)
      genHostUnregisterMemref(rewriter, loc, castC);
    genHostUnregisterMemref(rewriter, loc, castV);
    genHostUnregisterMemref(rewriter, loc, castB);
    genHostUnregisterMemref(rewriter, loc, castBufC);
  }

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
  return success();
}

/// Match and rewrite SpGEMM kernel.
static LogicalResult
rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
              GPUDataTransferStrategy gpuDataTransferStrategy) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // Only CSR <- CSR x CSR supported.
  auto format = CuSparseFormat::kCSR;
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
    return failure();

  // Start sparse kernel and copy data from host to device.
  // a : amemR/amemC/amemV -> rowA,colA,valA
  // b : bmemR/bmemC/bmemV -> rowB,colB,valB
  // c : materializes
  auto dnCType = cTp.getElementType();
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT);
  Value amemV = genToValues(rewriter, loc, a);
  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT);
  Value bmemV = genToValues(rewriter, loc, b);
  Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
  Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
  Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
  Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
  Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
  Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix handles.
834 Type indexTp = rewriter.getIndexType(); 835 Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 836 Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>(); 837 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 838 Value token = genFirstWait(rewriter, loc); 839 Operation *spGenA = 840 genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk, 841 nseA, rowA, colA, valA, format, enableRT); 842 Value spMatA = spGenA->getResult(0); 843 token = spGenA->getResult(1); 844 Operation *spGenB = 845 genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn, 846 nseB, rowB, colB, valB, format, enableRT); 847 Value spMatB = spGenB->getResult(0); 848 token = spGenB->getResult(1); 849 850 // Sparse matrix C materializes (also assumes beta == 0). 851 Value zero = constantIndex(rewriter, loc, 0); 852 Value one = constantIndex(rewriter, loc, 1); 853 Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one); 854 auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token); 855 Value rowC = e1.getResult(0); 856 token = e1.getAsyncToken(); 857 auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token); 858 Value colC = e2.getResult(0); // no free needed 859 token = e2.getAsyncToken(); 860 auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token); 861 Value valC = e3.getResult(0); // no free needed 862 token = e3.getAsyncToken(); 863 Operation *spGenC = 864 genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn, 865 zero, rowC, colC, valC, format, enableRT); 866 Value spMatC = spGenC->getResult(0); 867 token = spGenC->getResult(1); 868 869 // Precompute buffersizes for SpGEMM. 870 Operation *descOp = 871 rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token); 872 Value desc = descOp->getResult(0); 873 token = descOp->getResult(1); 874 Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( 875 loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, 876 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero, 877 valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION); 878 Value bufferSz1 = work1->getResult(0); 879 token = work1->getResult(1); 880 auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token); 881 Value buffer1 = buf1.getResult(0); 882 token = buf1.getAsyncToken(); 883 Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( 884 loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, 885 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, 886 bufferSz1, buffer1, 887 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION); 888 token = work2->getResult(1); 889 890 // Compute step. 
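  // Like the work estimation above, the COMPUTE phase is issued twice:
  // first with a zero-sized buffer to query the required workspace size,
  // and then again with the actually allocated buffer.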
891 Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( 892 loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, 893 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero, 894 valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE); 895 Value bufferSz2 = compute1->getResult(0); 896 token = compute1->getResult(1); 897 auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token); 898 Value buffer2 = buf2.getResult(0); 899 token = buf2.getAsyncToken(); 900 Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( 901 loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, 902 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, 903 bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE); 904 token = compute2->getResult(1); 905 906 // Get sizes. 907 Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>( 908 loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC); 909 Value nnz = sizes->getResult(2); 910 token = sizes->getResult(3); 911 auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token); 912 colC = a2.getResult(0); 913 token = a2.getAsyncToken(); 914 auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token); 915 valC = a3.getResult(0); 916 token = a3.getAsyncToken(); 917 918 // Update C with new pointers and copy final product back into C. 919 Operation *update = rewriter.create<gpu::SetCsrPointersOp>( 920 loc, tokenTp, token, spMatC, rowC, colC, valC); 921 token = update->getResult(0); 922 Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>( 923 loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, 924 gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType); 925 token = copy->getResult(0); 926 927 // Allocate buffers on host. 928 Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1); 929 Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz); 930 Value valH = genHostBuffer(rewriter, loc, dnCType, nnz); 931 932 // Copy data back to host and free all the resoures. 933 token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc) 934 .getAsyncToken(); 935 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) 936 .getAsyncToken(); 937 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB) 938 .getAsyncToken(); 939 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC) 940 .getAsyncToken(); 941 token = genCopyMemRef(rewriter, loc, rowH, rowC, token); 942 token = genCopyMemRef(rewriter, loc, colH, colC, token); 943 token = genCopyMemRef(rewriter, loc, valH, valC, token); 944 token = genDeallocMemRef(rewriter, loc, rowA, token); 945 token = genDeallocMemRef(rewriter, loc, colA, token); 946 token = genDeallocMemRef(rewriter, loc, valA, token); 947 token = genDeallocMemRef(rewriter, loc, rowB, token); 948 token = genDeallocMemRef(rewriter, loc, colB, token); 949 token = genDeallocMemRef(rewriter, loc, valB, token); 950 token = genDeallocMemRef(rewriter, loc, rowC, token); 951 token = genDeallocMemRef(rewriter, loc, colC, token); 952 token = genDeallocMemRef(rewriter, loc, valC, token); 953 token = genDeallocMemRef(rewriter, loc, buffer1, token); 954 token = genDeallocMemRef(rewriter, loc, buffer2, token); 955 tokens.push_back(token); 956 genBlockingWait(rewriter, loc, tokens); 957 tokens.clear(); 958 959 // Done. 
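  // Assemble the resulting CSR tensor on the host from the copied-back
  // positions (rowH), coordinates (colH), and values (valH) buffers.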
960 Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH); 961 Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH); 962 Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH); 963 rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt, 964 ValueRange{rt, ct}); 965 return success(); 966 } 967 968 // Match and rewrite 2:4 SpMM kernel. 969 static LogicalResult 970 rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op, 971 GPUDataTransferStrategy gpuDataTransferStrategy) { 972 Location loc = op.getLoc(); 973 Value A = op.getOperand(0); 974 Value B = op.getOperand(1); 975 Value C = op.getOperand(2); // we have C = AB 976 SmallVector<Value> tokens; 977 978 bool isZeroCopy = 979 gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy; 980 981 // All input should be dense tensors. 982 if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C)) 983 return failure(); 984 985 Value matA, matB; 986 Value bufA = genTensorToMemref(rewriter, loc, A); 987 if (!isZeroCopy) 988 matA = genAllocCopy(rewriter, loc, bufA, tokens); 989 Value bufB = genTensorToMemref(rewriter, loc, B); 990 if (!isZeroCopy) 991 matB = genAllocCopy(rewriter, loc, bufB, tokens); 992 Value bufC = genTensorToMemref(rewriter, loc, C); 993 Value castA, castB, castC; 994 if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) { 995 castA = genHostRegisterMemref(rewriter, loc, bufA); 996 castB = genHostRegisterMemref(rewriter, loc, bufB); 997 castC = genHostRegisterMemref(rewriter, loc, bufC); 998 } 999 if (isZeroCopy) { 1000 matA = bufA; 1001 matB = bufB; 1002 } 1003 Value matC = genAllocCopy(rewriter, loc, bufC, tokens); 1004 genBlockingWait(rewriter, loc, tokens); 1005 tokens.clear(); 1006 1007 // Create sparse environment and sparse matrix/dense vector handles. 1008 Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0); 1009 Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0); 1010 Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1); 1011 Type indexTp = rewriter.getIndexType(); 1012 Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); 1013 Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); 1014 Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); 1015 Value token = genFirstWait(rewriter, loc); 1016 Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>( 1017 loc, spMatHandleTp, tokenTp, token, szm, szk, 1018 gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA); 1019 Value spMatA = spGenA->getResult(0); 1020 token = spGenA->getResult(1); 1021 auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( 1022 loc, dnTensorHandleTp, tokenTp, token, matB, 1023 SmallVector<Value>{szk, szn}); 1024 Value dnB = dmatB.getResult(0); 1025 token = dmatB.getAsyncToken(); 1026 auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( 1027 loc, dnTensorHandleTp, tokenTp, token, matC, 1028 SmallVector<Value>{szm, szn}); 1029 Value dnC = dmatC.getResult(0); 1030 token = dmatC.getAsyncToken(); 1031 auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType(); 1032 1033 // Precompute buffersize for SpMM. 
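  // Unlike the plain SpMM path, the 2:4 buffer-size query reports three
  // workspace sizes, so three separate buffers are allocated below.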
1034 SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp}; 1035 TypeRange bufferTypes(bufferTypes_); 1036 auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( 1037 loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE, 1038 gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC, 1039 /*computeType=*/dmatCType); 1040 token = bufferComp.getAsyncToken(); 1041 1042 Value bufferSz = bufferComp.getResult(0); 1043 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); 1044 Value buffer = buf.getResult(0); 1045 token = buf.getAsyncToken(); 1046 1047 Value bufferSz2 = bufferComp.getResult(1); 1048 auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token); 1049 Value buffer2 = buf2.getResult(0); 1050 token = buf2.getAsyncToken(); 1051 1052 Value bufferSz3 = bufferComp.getResult(2); 1053 auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token); 1054 Value buffer3 = buf3.getResult(0); 1055 token = buf3.getAsyncToken(); 1056 1057 auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType(); 1058 1059 // Perform the SpMM. 1060 auto spmmComp = rewriter.create<gpu::SpMMOp>( 1061 loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, 1062 SmallVector<Value>{buffer, buffer2, buffer3}); 1063 token = spmmComp.getAsyncToken(); 1064 1065 // Copy data back to host and free all the resources. 1066 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) 1067 .getAsyncToken(); 1068 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) 1069 .getAsyncToken(); 1070 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC) 1071 .getAsyncToken(); 1072 SmallVector<Value> newDynamicSizes; 1073 token = genDeallocMemRef(rewriter, loc, buffer, token); 1074 token = genDeallocMemRef(rewriter, loc, buffer2, token); 1075 token = genDeallocMemRef(rewriter, loc, buffer3, token); 1076 if (!isZeroCopy) 1077 token = genDeallocMemRef(rewriter, loc, matA, token); 1078 if (!isZeroCopy) 1079 token = genDeallocMemRef(rewriter, loc, matB, token); 1080 token = genCopyMemRef(rewriter, loc, bufC, matC, token); 1081 token = genDeallocMemRef(rewriter, loc, matC, token); 1082 tokens.push_back(token); 1083 genBlockingWait(rewriter, loc, tokens); 1084 tokens.clear(); 1085 if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) { 1086 genHostUnregisterMemref(rewriter, loc, castA); 1087 genHostUnregisterMemref(rewriter, loc, castB); 1088 genHostUnregisterMemref(rewriter, loc, castC); 1089 } 1090 1091 // Done. 1092 rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC); 1093 return success(); 1094 } 1095 1096 /// Match and rewrite SDDMM kernel. 1097 static LogicalResult 1098 rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT, 1099 GPUDataTransferStrategy gpuDataTransferStrategy) { 1100 Location loc = op.getLoc(); 1101 Value a = op.getOperand(0); 1102 Value b = op.getOperand(1); 1103 Value c = op.getOperand(2); 1104 SmallVector<Value> tokens; 1105 1106 bool isZeroCopy = 1107 gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy; 1108 1109 // Only admissible sparse matrix format (no COO/CSC) and dense matrices. 
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
      format == CuSparseFormat::kCSC)
    return failure();

  // The SDDMM does the in-place operation.
  // Start sparse kernel and copy data from host to device.
  // a : bufA -> matA
  // b : bufB -> matB
  // c : memR/memC/memV -> rowC,colC,valC
  Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value matA, matB;
  Value bufA = genTensorToMemref(rewriter, loc, a);
  if (!isZeroCopy)
    matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  if (!isZeroCopy)
    matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, c, format, enableRT);
  Value memV = genToValues(rewriter, loc, c);
  Value castB, castA, castR, castC, castV;
  if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
    castB = genHostRegisterMemref(rewriter, loc, bufB);
    castA = genHostRegisterMemref(rewriter, loc, bufA);
    castR = genHostRegisterMemref(rewriter, loc, memR);
    if (memC)
      castC = genHostRegisterMemref(rewriter, loc, memC);
    castV = genHostRegisterMemref(rewriter, loc, memV);
  }
  if (isZeroCopy) {
    matA = bufA;
    matB = bufB;
  }
  Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
  Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valC = genAllocCopy(rewriter, loc, memV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense matrix handles.
  Type indexTp = rewriter.getIndexType();
  Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
  Value dnA = dmatA.getResult(0);
  token = dmatA.getAsyncToken();
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  Operation *spGenC =
      genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
               nseC, rowC, colC, valC, format, enableRT);
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);
  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Precompute buffersize for SDDMM.
1178 auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>( 1179 loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType); 1180 Value bufferSz = bufferComp.getResult(0); 1181 token = bufferComp.getAsyncToken(); 1182 auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); 1183 Value buffer = buf.getResult(0); 1184 token = buf.getAsyncToken(); 1185 1186 // Perform the SDDMM. 1187 auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB, 1188 spMatC, dnCType, buffer); 1189 token = sddmmComp.getAsyncToken(); 1190 1191 // Copy data back to host and free all the resoures. 1192 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA) 1193 .getAsyncToken(); 1194 token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) 1195 .getAsyncToken(); 1196 token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC) 1197 .getAsyncToken(); 1198 token = genDeallocMemRef(rewriter, loc, buffer, token); 1199 if (!isZeroCopy) { 1200 token = genDeallocMemRef(rewriter, loc, matA, token); 1201 token = genDeallocMemRef(rewriter, loc, matB, token); 1202 } 1203 token = genDeallocMemRef(rewriter, loc, rowC, token); 1204 if (colC) 1205 token = genDeallocMemRef(rewriter, loc, colC, token); 1206 token = genCopyMemRef(rewriter, loc, memV, valC, token); 1207 token = genDeallocMemRef(rewriter, loc, valC, token); 1208 tokens.push_back(token); 1209 genBlockingWait(rewriter, loc, tokens); 1210 tokens.clear(); 1211 if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) { 1212 genHostUnregisterMemref(rewriter, loc, castB); 1213 genHostUnregisterMemref(rewriter, loc, castA); 1214 genHostUnregisterMemref(rewriter, loc, castR); 1215 if (memC) 1216 genHostUnregisterMemref(rewriter, loc, castC); 1217 genHostUnregisterMemref(rewriter, loc, castV); 1218 } 1219 1220 // Done. 1221 rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c); 1222 return success(); 1223 } 1224 1225 //===----------------------------------------------------------------------===// 1226 // Rewriting rules for direct code generation. 1227 //===----------------------------------------------------------------------===// 1228 1229 /// Proof-of-concept rewriter. This rule generates a GPU implementation 1230 /// for each outermost forall loop generated by the sparse compiler. 1231 /// TODO: right now works with parallelization-strategy=dense-outer-loop 1232 /// but give this its own flags in the future 1233 struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> { 1234 using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; 1235 1236 ForallRewriter(MLIRContext *context, unsigned nT) 1237 : OpRewritePattern(context), numThreads(nT){}; 1238 1239 LogicalResult matchAndRewrite(scf::ParallelOp forallOp, 1240 PatternRewriter &rewriter) const override { 1241 // Reject inadmissible loop form. 1242 // Essentially only accept a loop, generated by the sparse compiler, 1243 // of the form 1244 // forall (i = 0; i < N; i++) 1245 // so that cyclic scheduling over the threads is easy. 1246 if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) || 1247 forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 || 1248 !matchPattern(forallOp.getLowerBound()[0], m_Zero()) || 1249 !matchPattern(forallOp.getStep()[0], m_One())) 1250 return failure(); 1251 // Collect every value that is computed outside the parallel loop. 1252 SetVector<Value> invariants; // stable iteration! 1253 forallOp->walk([&](Operation *op) { 1254 // Collect all values of admissible ops. 
1255 for (OpOperand &o : op->getOpOperands()) { 1256 Value val = o.get(); 1257 Block *block; 1258 if (auto arg = dyn_cast<BlockArgument>(val)) 1259 block = arg.getOwner(); 1260 else 1261 block = val.getDefiningOp()->getBlock(); 1262 if (!isNestedIn(block, forallOp)) 1263 invariants.insert(val); 1264 } 1265 }); 1266 // Outline the outside values as proper parameters. Fail when sharing 1267 // value between host and device is not straightforward. 1268 SmallVector<Value> constants; 1269 SmallVector<Value> scalars; 1270 SmallVector<Value> buffers; 1271 for (Value val : invariants) { 1272 Type tp = val.getType(); 1273 if (val.getDefiningOp<arith::ConstantOp>()) 1274 constants.push_back(val); 1275 else if (isa<FloatType>(tp) || tp.isIntOrIndex()) 1276 scalars.push_back(val); 1277 else if (isa<MemRefType>(tp)) 1278 buffers.push_back(val); 1279 else 1280 return failure(); // don't know how to share 1281 } 1282 // Pass outlined non-constant values. 1283 // TODO: Experiment with `useHostRegistrationForOut` to see if we want to 1284 // keep the feature at all (either through a heuristic or compiler 1285 // option for gpu codegen). 1286 Location loc = forallOp->getLoc(); 1287 SmallVector<Value> args; 1288 SmallVector<Value> tokens; 1289 Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens, 1290 /*useHostRegistrationForOut=*/false); 1291 // Set up GPU module and construct GPU function. 1292 auto saveIp = rewriter.saveInsertionPoint(); 1293 ModuleOp topModule = forallOp->getParentOfType<ModuleOp>(); 1294 auto gpuModule = genGPUModule(rewriter, topModule); 1295 auto gpuFunc = genGPUFunc(rewriter, gpuModule, args); 1296 genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers); 1297 // Generate code that launches the kernel asynchronously, blocking on all 1298 // opens tokens and yielding a new token for the output. 1299 // TODO: Passing in tokens to launch up does not seem to be properly lowered 1300 // by cubin yet, hence the current blocking wait. 1301 rewriter.restoreInsertionPoint(saveIp); 1302 genBlockingWait(rewriter, loc, tokens); 1303 tokens.clear(); 1304 Value kernelToken = 1305 genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads); 1306 // Finalize the outlined arguments. 1307 genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args, 1308 tokens); 1309 genBlockingWait(rewriter, loc, tokens); 1310 rewriter.eraseOp(forallOp); 1311 return success(); 1312 } 1313 1314 private: 1315 // Helper method to see if block appears in given loop. 1316 static bool isNestedIn(Block *block, scf::ParallelOp forallOp) { 1317 for (Operation *o = block->getParentOp(); o; o = o->getParentOp()) { 1318 if (o == forallOp) 1319 return true; 1320 } 1321 return false; 1322 } 1323 1324 unsigned numThreads; 1325 }; 1326 1327 //===----------------------------------------------------------------------===// 1328 // Rewriting rules for library recognition and code generation. 1329 //===----------------------------------------------------------------------===// 1330 1331 /// Proof-of-concept rewriter. This rule recognizes certain math kernels 1332 /// and replaces these with corresponding calls into a sparse library. 
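/// As an illustrative example (types and names are placeholders), a SpMV
/// kernel y = A * x that this pattern matches looks roughly like:
///
///   %0 = linalg.generic {
///          indexing_maps = [affine_map<(i, j) -> (i, j)>,
///                           affine_map<(i, j) -> (j)>,
///                           affine_map<(i, j) -> (i)>],
///          iterator_types = ["parallel", "reduction"] }
///        ins(%A, %x : tensor<?x?xf64, #CSR>, tensor<?xf64>)
///        outs(%y : tensor<?xf64>) {
///        ^bb0(%a: f64, %b: f64, %out: f64):
///          %p = arith.mulf %a, %b : f64
///          %s = arith.addf %out, %p : f64
///          linalg.yield %s : f64
///   } -> tensor<?xf64>
///
/// and is rewritten into asynchronous gpu.create_csr / gpu.spmv style calls.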
struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LinalgOpRewriter(MLIRContext *context, bool rt, GPUDataTransferStrategy t)
      : OpRewritePattern(context), enableRT(rt), gpuDataTransferStrategy(t) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumDpsInits() != 1)
      return failure(); // reject multi-output

    const unsigned numLoops = op.getNumLoops();
    const unsigned numTensors = op->getNumOperands();
    const auto iteratorTypes = op.getIteratorTypesArray();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr i, j, k;
    bindDims(getContext(), i, j, k);

    // TODO: more robust patterns, transposed versions, more kernels,
    // identify alpha and beta and pass them to the CUDA calls.

    // Recognize a SpMV kernel.
    if (numLoops == 2 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isReductionIterator(iteratorTypes[1]) &&
        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
      return rewriteSpMV(rewriter, op, enableRT, gpuDataTransferStrategy);
    }

    // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
      if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
        return rewriteSpGEMM(rewriter, op, enableRT, gpuDataTransferStrategy);
      if (op->getAttr("DENSE24"))
        return rewrite2To4SpMM(rewriter, op, gpuDataTransferStrategy);
      return rewriteSpMM(rewriter, op, enableRT, gpuDataTransferStrategy);
    }

    // Recognize a SDDMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) &&
        matchSumReductionOfMulUnary(op)) {
      return rewriteSDDMM(rewriter, op, enableRT, gpuDataTransferStrategy);
    }

    return failure();
  }

private:
  bool enableRT;
  GPUDataTransferStrategy gpuDataTransferStrategy;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating GPU rewriting rules.
//
// Currently two sets of rewriting rules are made available. The first set
// implements direct code generation, currently by means of converting the
// outermost parallel loop into GPU threads. The second set implements
// library recognition of a set of sparse operations. Eventually, the right
// combination of these two approaches has to be found.
1406 //===----------------------------------------------------------------------===// 1407 1408 void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, 1409 unsigned numThreads) { 1410 patterns.add<ForallRewriter>(patterns.getContext(), numThreads); 1411 } 1412 1413 void mlir::populateSparseGPULibgenPatterns( 1414 RewritePatternSet &patterns, bool enableRT, 1415 GPUDataTransferStrategy gpuDataTransfer) { 1416 patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT, 1417 gpuDataTransfer); 1418 } 1419