//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensor
// primitives with memref operands.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;

static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
    "_sparse_binary_search_";
static constexpr const char kSortNonstableFuncNamePrefix[] =
    "_sparse_sort_nonstable_";
static constexpr const char kSortStableFuncNamePrefix[] =
    "_sparse_sort_stable_";

using FuncGeneratorType = function_ref<void(
    OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;

/// Constructs a function name with this format to facilitate quick sort:
///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
                                         StringRef namePrefix, uint64_t nx,
                                         uint64_t ny, bool isCoo,
                                         ValueRange operands) {
  nameOstream << namePrefix << nx << "_"
              << getMemRefType(operands[xStartIdx]).getElementType();

  if (isCoo)
    nameOstream << "_coo_" << ny;

  uint64_t yBufferOffset = isCoo ? 1 : nx;
  for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
    nameOstream << "_" << getMemRefType(v).getElementType();
}

/// Looks up a function that is appropriate for the given operands being
/// sorted, and creates such a function if it doesn't exist yet. The
/// parameters `nx` and `ny` tell the number of x and y values provided
/// by the buffer in xStartIdx, and `isCoo` indicates whether the operation
/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
//
// All sorting function generators take (lo, hi, xs, ys) in `operands` as
// parameters for the sorting functions. Other parameters, such as the
// recursive call depth, are appended to the end of the parameter list as
// "trailing parameters".
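//
// For example (an illustrative name only; the exact spelling is produced by
// getMangledSortHelperFuncName above), a non-stable sort over two `index`
// x buffers and one `f32` y buffer would look up or create a private helper
// named roughly:
//   _sparse_sort_nonstable_2_index_f32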
static FlatSymbolRefAttr
getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
                         TypeRange resultTypes, StringRef namePrefix,
                         uint64_t nx, uint64_t ny, bool isCoo,
                         ValueRange operands, FuncGeneratorType createFunc,
                         uint32_t nTrailingP = 0) {
  SmallString<32> nameBuffer;
  llvm::raw_svector_ostream nameOstream(nameBuffer);
  getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
                               operands.drop_back(nTrailingP));

  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
  MLIRContext *context = module.getContext();
  auto result = SymbolRefAttr::get(context, nameOstream.str());
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());

  if (!func) {
    // Create the function.
    OpBuilder::InsertionGuard insertionGuard(builder);
    builder.setInsertionPoint(insertPoint);
    Location loc = insertPoint.getLoc();
    func = builder.create<func::FuncOp>(
        loc, nameOstream.str(),
        FunctionType::get(context, operands.getTypes(), resultTypes));
    func.setPrivate();
    createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
  }

  return result;
}

/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInXs(
    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
  Value iOffset, jOffset;
  if (isCoo) {
    Value cstep = constantIndex(builder, loc, nx + ny);
    iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
    jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
  }
  for (uint64_t k = 0; k < nx; k++) {
    scf::IfOp ifOp;
    Value i, j, buffer;
    if (isCoo) {
      Value ck = constantIndex(builder, loc, k);
      i = builder.create<arith::AddIOp>(loc, ck, iOffset);
      j = builder.create<arith::AddIOp>(loc, ck, jOffset);
      buffer = args[xStartIdx];
    } else {
      i = args[0];
      j = args[1];
      buffer = args[xStartIdx + k];
    }
    bodyBuilder(k, i, j, buffer);
  }
}

/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInAllBuffers(
    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {

  // Create code for the first (nx + ny) buffers. When isCoo==true, these
  // logical buffers are all from the xy buffer of the sort_coo operator.
  forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);

  uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;

  // Create code for the remaining buffers.
  Value i = args[0];
  Value j = args[1];
  for (const auto &arg :
       llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
    bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
  }
}

/// Creates a code block for swapping the values at index i and index j for
/// all the buffers.
//
// The generated IR corresponds to this C-like algorithm:
//   swap(x0[i], x0[j]);
//   swap(x1[i], x1[j]);
//   ...
//   swap(xn[i], xn[j]);
//   swap(y0[i], y0[j]);
//   ...
//   swap(yn[i], yn[j]);
static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
                       uint64_t nx, uint64_t ny, bool isCoo) {
  auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
    Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
    Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
    builder.create<memref::StoreOp>(loc, vj, buffer, i);
    builder.create<memref::StoreOp>(loc, vi, buffer, j);
  };

  forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
}

/// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
/// compare each pair is created via `compareBuilder`.
static void createCompareFuncImplementation(
    OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx,
    uint64_t ny, bool isCoo,
    function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
        compareBuilder) {
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);
  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();

  scf::IfOp topIfOp;
  auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
    scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1));
    if (k == 0) {
      topIfOp = ifOp;
    } else {
      OpBuilder::InsertionGuard insertionGuard(builder);
      builder.setInsertionPointAfter(ifOp);
      builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
    }
  };

  forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);

  builder.setInsertionPointAfter(topIfOp);
  builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
}

/// Generates an if-statement to compare whether x[i] is equal to x[j].
static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
                                 Value j, Value x, bool isLastDim) {
  Value f = constantI1(builder, loc, false);
  Value t = constantI1(builder, loc, true);
  Value vi = builder.create<memref::LoadOp>(loc, x, i);
  Value vj = builder.create<memref::LoadOp>(loc, x, j);

  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);

  // x[i] != x[j]:
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, f);

  // x[i] == x[j]:
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  if (isLastDim) {
    // Finish checking all dimensions.
    builder.create<scf::YieldOp>(loc, t);
  }

  return ifOp;
}

/// Creates a function to compare whether xs[i] is equal to xs[j].
//
// The generated IR corresponds to this C-like algorithm:
//   if (x0[i] != x0[j])
//     return false;
//   else
//     if (x1[i] != x1[j])
//       return false;
//     else if (x2[i] != x2[j])
//       and so on ...
static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo, uint32_t nTrailingP = 0) {
  // Compare functions don't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
                                  createEqCompare);
}

/// Generates an if-statement to compare whether x[i] is less than x[j].
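//
// As a rough illustration only (value names and types are elided; the exact IR
// is produced by the builder calls below), the comparison emitted for a single
// x buffer looks like:
//   %vi  = memref.load %x[%i]
//   %vj  = memref.load %x[%j]
//   %lt  = arith.cmpi ult, %vi, %vj
//   %res = scf.if %lt -> (i1) {
//     scf.yield %true
//   } else {
//     // yield false, or recurse into the next dimension
//   }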
static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
                                       Value i, Value j, Value x,
                                       bool isLastDim) {
  Value f = constantI1(builder, loc, false);
  Value t = constantI1(builder, loc, true);
  Value vi = builder.create<memref::LoadOp>(loc, x, i);
  Value vj = builder.create<memref::LoadOp>(loc, x, j);

  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
  // If (x[i] < x[j]).
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, t);

  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  if (isLastDim) {
    // Finish checking all dimensions.
    builder.create<scf::YieldOp>(loc, f);
  } else {
    cond =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
    scf::IfOp ifOp2 =
        builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
    // Otherwise if (x[j] < x[i]).
    builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
    builder.create<scf::YieldOp>(loc, f);

    // Otherwise check the remaining dimensions.
    builder.setInsertionPointAfter(ifOp2);
    builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
    // Set up the insertion point for the nested if-stmt that checks the
    // remaining dimensions.
    builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  }

  return ifOp;
}

/// Creates a function to compare whether xs[i] is less than xs[j].
//
// The generated IR corresponds to this C-like algorithm:
//   if (x0[i] < x0[j])
//     return true;
//   else if (x0[j] < x0[i])
//     return false;
//   else
//     if (x1[i] < x1[j])
//       return true;
//     else if (x1[j] < x1[i])
//       and so on ...
static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
                               func::FuncOp func, uint64_t nx, uint64_t ny,
                               bool isCoo, uint32_t nTrailingP = 0) {
  // Compare functions don't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
                                  createLessThanCompare);
}

/// Creates a function to use a binary search to find the insertion point for
/// inserting xs[hi] to the sorted values xs[lo..hi).
//
// The generated IR corresponds to this C-like algorithm:
//   p = hi
//   while (lo < hi)
//     mid = (lo + hi) >> 1
//     if (xs[p] < xs[mid])
//       hi = mid
//     else
//       lo = mid + 1
//   return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
                                   func::FuncOp func, uint64_t nx, uint64_t ny,
                                   bool isCoo, uint32_t nTrailingP = 0) {
  // Binary search doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value p = args[hiIdx];
  SmallVector<Type, 2> types(2, p.getType()); // Only two types.
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});

  // The before-region of the WhileOp.
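  // Together with the after-region built below, the generated loop has roughly
  // this shape (an illustrative sketch; value names and types are elided):
  //   %res:2 = scf.while (%lo = lo, %hi = hi) {
  //     %cont = arith.cmpi ult, %lo, %hi
  //     scf.condition(%cont) %lo, %hi
  //   } do {
  //   ^bb0(%lo, %hi):
  //     %mid = (%lo + %hi) >> 1
  //     %lt  = func.call @_sparse_less_than_...(%p, %mid, %xs...)
  //     scf.yield (select %lt, %lo, %mid + 1), (select %lt, %mid, %hi)
  //   }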
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(before);
  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                              before->getArgument(0),
                                              before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(after);
  Value lo = after->getArgument(0);
  Value hi = after->getArgument(1);
  // Compute mid = (lo + hi) >> 1.
  Value c1 = constantIndex(builder, loc, 1);
  Value mid = builder.create<arith::ShRUIOp>(
      loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
  Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);

  // Compare xs[p] < xs[mid].
  SmallVector<Value> compareOperands{p, mid};
  uint64_t numXBuffers = isCoo ? 1 : nx;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createLessThanFunc, nTrailingP);
  Value cond2 = builder
                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
                                          compareOperands)
                    .getResult(0);

  // Update lo and hi for the WhileOp as follows:
  //   if (xs[p] < xs[mid])
  //     hi = mid;
  //   else
  //     lo = mid + 1;
  Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
  Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
  builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});

  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
}

/// Creates code to advance i in a loop based on xs[p] as follows:
///   while (xs[i] < xs[p]) i += step (step > 0)
/// or
///   while (xs[i] > xs[p]) i += step (step < 0)
/// The routine returns i as well as a boolean value to indicate whether
/// xs[i] == xs[p].
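//
// Roughly, for step > 0 (an illustrative sketch; the comparison operands are
// swapped for step < 0, and value names and types are elided):
//   %i2 = scf.while (%i = %i0) {
//     %lt = func.call @_sparse_less_than_...(%i, %p, %xs...)
//     scf.condition(%lt) %i
//   } do {
//   ^bb0(%i):
//     scf.yield %i + step
//   }
//   %eq = func.call @_sparse_compare_eq_...(%i2, %p, %xs...)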
static std::pair<Value, Value>
createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
               ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
               bool isCoo, int step) {
  Location loc = func.getLoc();
  scf::WhileOp whileOp =
      builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});

  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
  builder.setInsertionPointToEnd(before);
  SmallVector<Value> compareOperands;
  if (step > 0) {
    compareOperands.push_back(before->getArgument(0));
    compareOperands.push_back(p);
  } else {
    assert(step < 0);
    compareOperands.push_back(p);
    compareOperands.push_back(before->getArgument(0));
  }
  compareOperands.append(xs.begin(), xs.end());
  MLIRContext *context = module.getContext();
  Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createLessThanFunc);
  Value cond = builder
                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
                                         compareOperands)
                   .getResult(0);
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
  builder.setInsertionPointToEnd(after);
  Value cs = constantIndex(builder, loc, step);
  i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  i = whileOp.getResult(0);

  builder.setInsertionPointAfter(whileOp);
  compareOperands[0] = i;
  compareOperands[1] = p;
  FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createEqCompareFunc);
  Value compareEq =
      builder
          .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
                                compareOperands)
          .getResult(0);

  return std::make_pair(whileOp.getResult(0), compareEq);
}

/// Creates a code block to swap the values so that data[mi] is the median
/// among data[lo], data[hi], and data[mi].
//
// The generated code corresponds to this C-like algorithm:
//   median = mi
//   if (data[mi] < data[lo])                     (if1)
//     if (data[hi] < data[lo])                   (if2)
//       median = data[hi] < data[mi] ? mi : hi
//     else
//       median = lo
//   else
//     if (data[hi] < data[mi])                   (if3)
//       median = data[hi] < data[lo] ? lo : hi
//   if (median != mi) swap data[median] with data[mi]
static void createChoosePivot(OpBuilder &builder, ModuleOp module,
                              func::FuncOp func, uint64_t nx, uint64_t ny,
                              bool isCoo, Value lo, Value hi, Value mi,
                              ValueRange args) {
  SmallVector<Value> compareOperands{mi, lo};
  uint64_t numXBuffers = isCoo ? 1 : nx;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
  SmallVector<Type, 1> cmpTypes{i1Type};
  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
      builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createLessThanFunc);
  Location loc = func.getLoc();
  // Compare data[mi] < data[lo].
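  // As an illustrative sketch (the exact helper name is mangled from the
  // operand types), the comparison below becomes a call such as:
  //   %cond1 = func.call @_sparse_less_than_...(%mi, %lo, %x0, ...) : (...) -> i1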
  Value cond1 =
      builder.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
          .getResult(0);
  SmallVector<Type, 1> ifTypes{lo.getType()};
  scf::IfOp ifOp1 =
      builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);

  // Generate an if-stmt to find the median value, assuming we already know
  // that data[b] < data[a] and we haven't compared data[c] yet.
  auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp {
    compareOperands[0] = c;
    compareOperands[1] = a;
    // Compare data[c] < data[a].
    Value cond2 =
        builder
            .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
            .getResult(0);
    scf::IfOp ifOp2 =
        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
    builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
    compareOperands[0] = c;
    compareOperands[1] = b;
    // Compare data[c] < data[b].
    Value cond3 =
        builder
            .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
            .getResult(0);
    builder.create<scf::YieldOp>(
        loc, ValueRange{builder.create<arith::SelectOp>(loc, cond3, b, c)});
    builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{a});
    return ifOp2;
  };

  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
  scf::IfOp ifOp2 = createFindMedian(lo, mi, hi);
  builder.setInsertionPointAfter(ifOp2);
  builder.create<scf::YieldOp>(loc, ValueRange{ifOp2.getResult(0)});

  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
  scf::IfOp ifOp3 = createFindMedian(mi, lo, hi);

  builder.setInsertionPointAfter(ifOp3);
  builder.create<scf::YieldOp>(loc, ValueRange{ifOp3.getResult(0)});

  builder.setInsertionPointAfter(ifOp1);
  Value median = ifOp1.getResult(0);
  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, mi, median);
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, TypeRange(), cond, /*else=*/false);

  SmallVector<Value> swapOperands{median, mi};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  builder.setInsertionPointAfter(ifOp);
}

/// Creates a function to perform quick sort partition on the values in the
/// range of index [lo, hi), assuming lo < hi.
//
// The generated IR corresponds to this C-like algorithm:
//   int partition(lo, hi, xs) {
//     p = (lo + hi) / 2      // pivot index
//     i = lo
//     j = hi - 1
//     while (i < j) do {
//       while (xs[i] < xs[p]) i++;
//       i_eq = (xs[i] == xs[p]);
//       while (xs[j] > xs[p]) j--;
//       j_eq = (xs[j] == xs[p]);
//       if (i < j) {
//         swap(xs[i], xs[j])
//         if (i == p) {
//           p = j;
//         } else if (j == p) {
//           p = i;
//         }
//         if (i_eq && j_eq) {
//           ++i;
//           --j;
//         }
//       }
//     }
//     return p
//   }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo, uint32_t nTrailingP = 0) {
  // Quick sort partition doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
  Value c1 = constantIndex(builder, loc, 1);
  Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);

  Value i = lo;
  Value j = builder.create<arith::SubIOp>(loc, hi, c1);
  createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
  SmallVector<Value, 3> operands{i, j, p}; // Exactly three values.
  SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

  // The before-region of the WhileOp.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
  builder.setInsertionPointToEnd(before);
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                             before->getArgument(0),
                                             before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
  builder.setInsertionPointToEnd(after);
  i = after->getArgument(0);
  j = after->getArgument(1);
  p = after->getArgument(2);

  uint64_t numXBuffers = isCoo ? 1 : nx;
  auto [iresult, iCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     i, p, nx, ny, isCoo, 1);
  i = iresult;
  auto [jresult, jCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     j, p, nx, ny, isCoo, -1);
  j = jresult;

  // If i < j:
  cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  SmallVector<Value> swapOperands{i, j};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  // If the pivot is moved, update p with the new pivot.
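  // That is, emit IR that roughly corresponds to:
  //   if (i == p)
  //     p = j;
  //   else if (j == p)
  //     p = i;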
  Value icond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
  scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              icond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{j});
  builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
  Value jcond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
  scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              jcond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{p});
  builder.setInsertionPointAfter(ifOpJ);
  builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
  builder.setInsertionPointAfter(ifOpI);
  Value compareEqIJ =
      builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(
      loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
  Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
  Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
  builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j});
  builder.setInsertionPointAfter(ifOp2);
  builder.create<scf::YieldOp>(
      loc,
      ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});

  // False branch for if i < j:
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});

  // Return for the whileOp.
  builder.setInsertionPointAfter(ifOp);
  builder.create<scf::YieldOp>(loc, ifOp.getResults());

  // Return for the function.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
}

/// Creates a function to perform quick sort on the values in the range of
/// index [lo, hi).
//
// The generated IR corresponds to this C-like algorithm:
//   void quickSort(lo, hi, data) {
//     if (lo < hi) {
//       p = partition(lo, hi, data);
//       quickSort(lo, p, data);
//       quickSort(p + 1, hi, data);
//     }
//   }
static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
                                    func::FuncOp func, uint64_t nx, uint64_t ny,
                                    bool isCoo, uint32_t nTrailingP) {
  (void)nTrailingP;
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  MLIRContext *context = module.getContext();
  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);

  // The if-stmt true branch.
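  // Emit the partition call followed by the two recursive calls to this very
  // function, i.e. roughly:
  //   p = partition(lo, hi, xs...);
  //   thisFunc(lo, p, xs...);
  //   thisFunc(p + 1, hi, xs...);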
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
      ny, isCoo, args, createPartitionFunc);
  auto p = builder.create<func::CallOp>(
      loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));

  SmallVector<Value> lowOperands{lo, p.getResult(0)};
  lowOperands.append(args.begin() + xStartIdx, args.end());
  builder.create<func::CallOp>(loc, func, lowOperands);

  SmallVector<Value> highOperands{
      builder.create<arith::AddIOp>(loc, p.getResult(0),
                                    constantIndex(builder, loc, 1)),
      hi};
  highOperands.append(args.begin() + xStartIdx, args.end());
  builder.create<func::CallOp>(loc, func, highOperands);

  // After the if-stmt.
  builder.setInsertionPointAfter(ifOp);
  builder.create<func::ReturnOp>(loc);
}

/// Creates a function to perform insertion sort on the values in the range of
/// index [lo, hi).
//
// The generated IR corresponds to this C-like algorithm:
//   void insertionSort(lo, hi, data) {
//     for (i = lo+1; i < hi; i++) {
//       d = data[i];
//       p = binarySearch(lo, i-1, data)
//       for (j = 0; j < i - p; j++)
//         data[i-j] = data[i-j-1]
//       data[p] = d
//     }
//   }
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
                                 func::FuncOp func, uint64_t nx, uint64_t ny,
                                 bool isCoo, uint32_t nTrailingP) {
  // Stable sort function doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  MLIRContext *context = module.getContext();
  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value c1 = constantIndex(builder, loc, 1);
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);

  // Start the outer for-stmt with induction variable i.
  scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
  builder.setInsertionPointToStart(forOpI.getBody());
  Value i = forOpI.getInductionVar();

  // Binary search to find the insertion point p.
  SmallVector<Value> operands{lo, i};
  operands.append(args.begin() + xStartIdx, args.end());
  FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
      ny, isCoo, operands, createBinarySearchFunc);
  Value p = builder
                .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
                                      operands)
                .getResult(0);

  // Move the value at data[i] to a temporary location.
  operands[0] = operands[1] = i;
  SmallVector<Value> d;
  forEachIJPairInAllBuffers(
      builder, loc, operands, nx, ny, isCoo,
      [&](uint64_t unused, Value i, Value unused2, Value buffer) {
        d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
      });

  // Start the inner for-stmt with induction variable j, for moving data[p..i)
  // to data[p+1..i+1).
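  // That is, emit IR that roughly corresponds to:
  //   for (j = 0; j < i - p; j++)
  //     data[i - j] = data[i - j - 1];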
  Value imp = builder.create<arith::SubIOp>(loc, i, p);
  Value c0 = constantIndex(builder, loc, 0);
  scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
  builder.setInsertionPointToStart(forOpJ.getBody());
  Value j = forOpJ.getInductionVar();
  Value imj = builder.create<arith::SubIOp>(loc, i, j);
  operands[1] = imj;
  operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
  forEachIJPairInAllBuffers(
      builder, loc, operands, nx, ny, isCoo,
      [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
        Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
        builder.create<memref::StoreOp>(loc, t, buffer, imj);
      });

  // Store the value at data[i] to data[p].
  builder.setInsertionPointAfter(forOpJ);
  operands[0] = operands[1] = p;
  forEachIJPairInAllBuffers(
      builder, loc, operands, nx, ny, isCoo,
      [&](uint64_t k, Value p, Value unused, Value buffer) {
        builder.create<memref::StoreOp>(loc, d[k], buffer, p);
      });

  builder.setInsertionPointAfter(forOpI);
  builder.create<func::ReturnOp>(loc);
}

/// Implements the rewriting for operator sort and sort_coo.
template <typename OpTy>
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
                                    uint64_t ny, bool isCoo,
                                    PatternRewriter &rewriter) {
  Location loc = op.getLoc();
  SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};

  // Convert the buffers in `xys` to have dynamic shape and append them to
  // `operands`.
  for (Value v : xys) {
    auto mtp = getMemRefType(v);
    if (!mtp.isDynamicDim(0)) {
      auto newMtp =
          MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
      v = rewriter.create<memref::CastOp>(loc, newMtp, v);
    }
    operands.push_back(v);
  }
  bool isStable =
      (op.getAlgorithm() == SparseTensorSortKind::InsertionSortStable);
  auto insertPoint = op->template getParentOfType<func::FuncOp>();
  SmallString<32> funcName(isStable ? kSortStableFuncNamePrefix
                                    : kSortNonstableFuncNamePrefix);
  FuncGeneratorType funcGenerator =
      isStable ? createSortStableFunc : createSortNonstableFunc;
  uint32_t nTrailingP = 0;
  FlatSymbolRefAttr func =
      getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
                               ny, isCoo, operands, funcGenerator, nTrailingP);
  rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
  return success();
}

//===---------------------------------------------------------------------===//
// The actual sparse buffer rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for the push_back operator.
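///
/// For example (an illustrative sketch only), pushing a single value that is
/// known to be in bounds reduces to an updated size plus a store at the old
/// size:
///   %newSize = arith.addi %size, %c1
///   memref.store %value, %buffer[%size]
/// The general case additionally emits the capacity check and the
/// capacity-doubling realloc described in the comment inside matchAndRewrite.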
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
  using OpRewritePattern<PushBackOp>::OpRewritePattern;
  PushBackRewriter(MLIRContext *context, bool enableInit)
      : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
  LogicalResult matchAndRewrite(PushBackOp op,
                                PatternRewriter &rewriter) const override {
    // Rewrite push_back(buffer, value, n) to:
    //   new_size = size(buffer) + n
    //   if (new_size > capacity(buffer))
    //     while (new_size > new_capacity)
    //       new_capacity = new_capacity * 2
    //     new_buffer = realloc(buffer, new_capacity)
    //     buffer = new_buffer
    //     subBuffer = subviewof(buffer)
    //     linalg.fill subBuffer value
    //
    //   size(buffer) += n
    //
    // The capacity check is skipped when the inbounds attribute is present.
    Location loc = op->getLoc();
    Value c0 = constantIndex(rewriter, loc, 0);
    Value buffer = op.getInBuffer();
    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
    Value size = op.getCurSize();
    Value value = op.getValue();

    Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
    Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
    auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
    bool nIsOne = (nValue && nValue.value() == 1);

    if (!op.getInbounds()) {
      Value cond = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ugt, newSize, capacity);

      Value c2 = constantIndex(rewriter, loc, 2);
      auto bufferType =
          MemRefType::get({ShapedType::kDynamic}, value.getType());
      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
                                                  /*else=*/true);
      // True branch.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (nIsOne) {
        capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
      } else {
        // Use a do-while loop to calculate the new capacity as follows:
        //   do { new_capacity *= 2 } while (new_size > new_capacity)
        scf::WhileOp whileOp =
            rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);

        // The before-region of the WhileOp.
        Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
                                             {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(before);

        capacity =
            rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
        cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                              newSize, capacity);
        rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
        // The after-region of the WhileOp.
        Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
                                            {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(after);
        rewriter.create<scf::YieldOp>(loc, after->getArguments());

        rewriter.setInsertionPointAfter(whileOp);
        capacity = whileOp.getResult(0);
      }

      Value newBuffer =
          rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
      if (enableBufferInitialization) {
        Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
        Value fillValue = constantZero(rewriter, loc, value.getType());
        Value subBuffer = rewriter.create<memref::SubViewOp>(
            loc, newBuffer, /*offset=*/ValueRange{newSize},
            /*size=*/ValueRange{fillSize},
            /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
        rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
      }
      rewriter.create<scf::YieldOp>(loc, newBuffer);

      // False branch.
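      // No reallocation was needed; pass the original buffer through unchanged
      // as the scf.if result.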
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      rewriter.create<scf::YieldOp>(loc, buffer);

      // Prepare for adding the value to the end of the buffer.
      rewriter.setInsertionPointAfter(ifOp);
      buffer = ifOp.getResult(0);
    }

    // Add the value to the end of the buffer.
    if (nIsOne) {
      rewriter.create<memref::StoreOp>(loc, value, buffer, size);
    } else {
      Value subBuffer = rewriter.create<memref::SubViewOp>(
          loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
          /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
      rewriter.create<linalg::FillOp>(loc, value, subBuffer);
    }

    // Update the buffer size.
    rewriter.replaceOp(op, {buffer, newSize});
    return success();
  }

private:
  bool enableBufferInitialization;
};

/// Sparse rewriting rule for the sort operator.
struct SortRewriter : public OpRewritePattern<SortOp> {
public:
  using OpRewritePattern<SortOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SortOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> xys(op.getXs());
    xys.append(op.getYs().begin(), op.getYs().end());
    return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
                                 /*isCoo=*/false, rewriter);
  }
};

/// Sparse rewriting rule for the sort_coo operator.
struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
public:
  using OpRewritePattern<SortCooOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SortCooOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> xys;
    xys.push_back(op.getXy());
    xys.append(op.getYs().begin(), op.getYs().end());
    uint64_t nx = 1;
    if (auto nxAttr = op.getNxAttr())
      nx = nxAttr.getInt();

    uint64_t ny = 0;
    if (auto nyAttr = op.getNyAttr())
      ny = nyAttr.getInt();

    return matchAndRewriteSortOp(op, xys, nx, ny,
                                 /*isCoo=*/true, rewriter);
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
                                         bool enableBufferInitialization) {
  patterns.add<PushBackRewriter>(patterns.getContext(),
                                 enableBufferInitialization);
  patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
}
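// A typical driver (an illustrative sketch only; the actual pass wiring lives
// elsewhere) populates and applies these patterns roughly as follows:
//   RewritePatternSet patterns(op->getContext());
//   populateSparseBufferRewriting(patterns, /*enableBufferInitialization=*/true);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));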