1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements rewriting rules that are specific to sparse tensor 10 // primitives with memref operands. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodegenUtils.h" 15 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/Math/IR/Math.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/Dialect/SCF/IR/SCF.h" 22 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 23 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 24 #include "mlir/Support/LLVM.h" 25 26 using namespace mlir; 27 using namespace mlir::sparse_tensor; 28 29 //===---------------------------------------------------------------------===// 30 // Helper methods for the actual rewriting rules. 
31 //===---------------------------------------------------------------------===// 32 33 static constexpr uint64_t loIdx = 0; 34 static constexpr uint64_t hiIdx = 1; 35 static constexpr uint64_t xStartIdx = 2; 36 37 static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_"; 38 static constexpr const char kBinarySearchFuncNamePrefix[] = 39 "_sparse_binary_search_"; 40 static constexpr const char kHybridQuickSortFuncNamePrefix[] = 41 "_sparse_hybrid_qsort_"; 42 static constexpr const char kSortStableFuncNamePrefix[] = 43 "_sparse_sort_stable_"; 44 static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_"; 45 static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_"; 46 static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_"; 47 48 using FuncGeneratorType = function_ref<void( 49 OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>; 50 51 /// Constructs a function name with this format to facilitate quick sort: 52 /// <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort 53 /// <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo 54 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, 55 StringRef namePrefix, uint64_t nx, 56 uint64_t ny, bool isCoo, 57 ValueRange operands) { 58 nameOstream << namePrefix << nx << "_" 59 << getMemRefType(operands[xStartIdx]).getElementType(); 60 61 if (isCoo) 62 nameOstream << "_coo_" << ny; 63 64 uint64_t yBufferOffset = isCoo ? 1 : nx; 65 for (Value v : operands.drop_front(xStartIdx + yBufferOffset)) 66 nameOstream << "_" << getMemRefType(v).getElementType(); 67 } 68 69 /// Looks up a function that is appropriate for the given operands being 70 /// sorted, and creates such a function if it doesn't exist yet. 
The 71 /// parameters `nx` and `ny` tell the number of x and y values provided 72 /// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction 73 /// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo. 74 // 75 // All sorting function generators take (lo, hi, xs, ys) in `operands` as 76 // parameters for the sorting functions. Other parameters, such as the recursive 77 // call depth, are appended to the end of the parameter list as 78 // "trailing parameters". 79 static FlatSymbolRefAttr 80 getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint, 81 TypeRange resultTypes, StringRef namePrefix, 82 uint64_t nx, uint64_t ny, bool isCoo, 83 ValueRange operands, FuncGeneratorType createFunc, 84 uint32_t nTrailingP = 0) { 85 SmallString<32> nameBuffer; 86 llvm::raw_svector_ostream nameOstream(nameBuffer); 87 getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo, 88 operands.drop_back(nTrailingP)); 89 90 ModuleOp module = insertPoint->getParentOfType<ModuleOp>(); 91 MLIRContext *context = module.getContext(); 92 auto result = SymbolRefAttr::get(context, nameOstream.str()); 93 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr()); 94 95 if (!func) { 96 // Create the function. 97 OpBuilder::InsertionGuard insertionGuard(builder); 98 builder.setInsertionPoint(insertPoint); 99 Location loc = insertPoint.getLoc(); 100 func = builder.create<func::FuncOp>( 101 loc, nameOstream.str(), 102 FunctionType::get(context, operands.getTypes(), resultTypes)); 103 func.setPrivate(); 104 createFunc(builder, module, func, nx, ny, isCoo, nTrailingP); 105 } 106 107 return result; 108 } 109 110 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting. 111 /// The code to process the value pairs is generated by `bodyBuilder`. 
112 static void forEachIJPairInXs( 113 OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny, 114 bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { 115 Value iOffset, jOffset; 116 if (isCoo) { 117 Value cstep = constantIndex(builder, loc, nx + ny); 118 iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep); 119 jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep); 120 } 121 for (uint64_t k = 0; k < nx; k++) { 122 scf::IfOp ifOp; 123 Value i, j, buffer; 124 if (isCoo) { 125 Value ck = constantIndex(builder, loc, k); 126 i = builder.create<arith::AddIOp>(loc, ck, iOffset); 127 j = builder.create<arith::AddIOp>(loc, ck, jOffset); 128 buffer = args[xStartIdx]; 129 } else { 130 i = args[0]; 131 j = args[1]; 132 buffer = args[xStartIdx + k]; 133 } 134 bodyBuilder(k, i, j, buffer); 135 } 136 } 137 138 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting. 139 /// The code to process the value pairs is generated by `bodyBuilder`. 140 static void forEachIJPairInAllBuffers( 141 OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny, 142 bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { 143 144 // Create code for the first (nx + ny) buffers. When isCoo==true, these 145 // logical buffers are all from the xy buffer of the sort_coo operator. 146 forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder); 147 148 uint64_t numHandledBuffers = isCoo ? 1 : nx + ny; 149 150 // Create code for the remaining buffers. 151 Value i = args[0]; 152 Value j = args[1]; 153 for (const auto &arg : 154 llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) { 155 bodyBuilder(arg.index() + nx + ny, i, j, arg.value()); 156 } 157 } 158 159 /// Creates a code block for swapping the values in index i and j for all the 160 /// buffers. 
161 // 162 // The generated IR corresponds to this C like algorithm: 163 // swap(x0[i], x0[j]); 164 // swap(x1[i], x1[j]); 165 // ... 166 // swap(xn[i], xn[j]); 167 // swap(y0[i], y0[j]); 168 // ... 169 // swap(yn[i], yn[j]); 170 static void createSwap(OpBuilder &builder, Location loc, ValueRange args, 171 uint64_t nx, uint64_t ny, bool isCoo) { 172 auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) { 173 Value vi = builder.create<memref::LoadOp>(loc, buffer, i); 174 Value vj = builder.create<memref::LoadOp>(loc, buffer, j); 175 builder.create<memref::StoreOp>(loc, vj, buffer, i); 176 builder.create<memref::StoreOp>(loc, vi, buffer, j); 177 }; 178 179 forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair); 180 } 181 182 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare 183 /// each pair is create via `compareBuilder`. 184 static Value createInlinedCompareImplementation( 185 OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny, 186 bool isCoo, 187 function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> 188 compareBuilder) { 189 Value result; 190 auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) { 191 bool isFirstDim = (k == 0); 192 bool isLastDim = (k == nx - 1); 193 Value val = 194 compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim); 195 if (isFirstDim) { 196 result = val; 197 } else if (!isLastDim) { 198 OpBuilder::InsertionGuard insertionGuard(builder); 199 auto ifOp = cast<scf::IfOp>(val.getDefiningOp()); 200 builder.setInsertionPointAfter(ifOp); 201 builder.create<scf::YieldOp>(loc, ifOp.getResult(0)); 202 } 203 }; 204 205 forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder); 206 207 builder.setInsertionPointAfterValue(result); 208 return result; 209 } 210 211 /// Generates code to compare whether x[i] is equal to x[j] and returns the 212 /// result of the comparison. 
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
                             Value x, bool isFirstDim, bool isLastDim) {
  // Note: for a non-last dimension this returns the result of a new scf.if and
  // leaves the builder at the start of that if's else region. The next
  // dimension is therefore compared only when x[i] == x[j].
  Value vi = builder.create<memref::LoadOp>(loc, x, i);
  Value vj = builder.create<memref::LoadOp>(loc, x, j);

  Value res;
  if (isLastDim) {
    res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
    // For 1D, we create a compare without any control flow. Otherwise, we
    // create YieldOp to return the result in the nested if-stmt.
    if (!isFirstDim)
      builder.create<scf::YieldOp>(loc, res);
  } else {
    Value ne =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
                                               ne, /*else=*/true);
    // If (x[i] != x[j]).
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    Value f = constantI1(builder, loc, false);
    builder.create<scf::YieldOp>(loc, f);

    // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
    // checks the remaining dimensions.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
    res = ifOp.getResult(0);
  }

  return res;
}

/// Creates code to compare whether xs[i] is equal to xs[j]. `args` is
/// (i, j, x buffer(s)...), and the result is an i1 value.
//
// The generated IR corresponds to this C like algorithm:
//   if (x0[i] != x0[j])
//     return false;
//   else
//     if (x1[i] != x1[j])
//       return false;
//     else if (x2[i] != x2[j])
//       and so on ...
static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
                                    ValueRange args, uint64_t nx, uint64_t ny,
                                    bool isCoo, uint32_t nTrailingP = 0) {
  // Compare functions don't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
                                            createEqCompare);
}

/// Generates code to compare whether x[i] is less than x[j] and returns the
/// result of the comparison. Values are compared with an unsigned (ult)
/// predicate.
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
                                   Value j, Value x, bool isFirstDim,
                                   bool isLastDim) {
  Value vi = builder.create<memref::LoadOp>(loc, x, i);
  Value vj = builder.create<memref::LoadOp>(loc, x, j);

  Value res;
  if (isLastDim) {
    res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
    // For 1D, we create a compare without any control flow. Otherwise, we
    // create YieldOp to return the result in the nested if-stmt.
    if (!isFirstDim)
      builder.create<scf::YieldOp>(loc, res);
  } else {
    Value ne =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
                                               ne, /*else=*/true);
    // If (x[i] != x[j]).
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    Value lt =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
    builder.create<scf::YieldOp>(loc, lt);

    // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
    // checks the remaining dimensions.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
    res = ifOp.getResult(0);
  }

  return res;
}

/// Creates code to compare whether xs[i] is less than xs[j]. `args` is
/// (i, j, x buffer(s)...), and the result is an i1 value.
//
// The generated IR corresponds to this C like algorithm:
//   if (x0[i] != x0[j])
//     return x0[i] < x0[j];
//   else if (x1[i] != x1[j])
//     return x1[i] < x1[j];
//   else
//     and so on ...
static Value createInlinedLessThan(OpBuilder &builder, Location loc,
                                   ValueRange args, uint64_t nx, uint64_t ny,
                                   bool isCoo, uint32_t nTrailingP = 0) {
  // `args` is (i, j, x buffer(s)...).
  // Compare functions don't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
                                            createLessThanCompare);
}

/// Creates a function to use a binary search to find the insertion point for
/// inserting xs[hi] to the sorted values xs[lo..hi). The returned index is
/// the first position whose value compares greater than xs[hi] (or hi if there
/// is none). Equal values therefore stay in front of the inserted one.
//
// The generated IR corresponds to this C like algorithm:
// p = hi
// while (lo < hi)
//    mid = (lo + hi) >> 1
//    if (xs[p] < xs[mid])
//      hi = mid
//    else
//      lo = mid + 1
// return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
                                   func::FuncOp func, uint64_t nx, uint64_t ny,
                                   bool isCoo, uint32_t nTrailingP = 0) {
  // Binary search doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value p = args[hiIdx];
  SmallVector<Type, 2> types(2, p.getType()); // Only two types.
  // The loop carries (lo, hi).
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});

  // The before-region of the WhileOp.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(before);
  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                              before->getArgument(0),
                                              before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(after);
  Value lo = after->getArgument(0);
  Value hi = after->getArgument(1);
  // Compute mid = (lo + hi) >> 1.
  Value c1 = constantIndex(builder, loc, 1);
  Value mid = builder.create<arith::ShRUIOp>(
      loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
  Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);

  // Compare xs[p] < xs[mid].
  SmallVector<Value> compareOperands{p, mid};
  uint64_t numXBuffers = isCoo ? 1 : nx;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  Value cond2 =
      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
  // Update lo and hi for the WhileOp as follows:
  //   if (xs[p] < xs[mid])
  //     hi = mid;
  //   else
  //     lo = mid + 1;
  Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
  Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
  builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});

  // The function returns the final lo.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
}

/// Creates code to advance i in a loop based on xs[p] as follows:
///   while (xs[i] < xs[p]) i += step (step > 0)
/// or
///   while (xs[i] > xs[p]) i += step (step < 0)
/// The routine returns i as well as a boolean value to indicate whether
/// xs[i] == xs[p].
395 static std::pair<Value, Value> 396 createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, 397 ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny, 398 bool isCoo, int step) { 399 Location loc = func.getLoc(); 400 scf::WhileOp whileOp = 401 builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i}); 402 403 Block *before = 404 builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc}); 405 builder.setInsertionPointToEnd(before); 406 SmallVector<Value> compareOperands; 407 if (step > 0) { 408 compareOperands.push_back(before->getArgument(0)); 409 compareOperands.push_back(p); 410 } else { 411 assert(step < 0); 412 compareOperands.push_back(p); 413 compareOperands.push_back(before->getArgument(0)); 414 } 415 compareOperands.append(xs.begin(), xs.end()); 416 Value cond = 417 createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo); 418 builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); 419 420 Block *after = 421 builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc}); 422 builder.setInsertionPointToEnd(after); 423 Value cs = constantIndex(builder, loc, step); 424 i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs); 425 builder.create<scf::YieldOp>(loc, ValueRange{i}); 426 i = whileOp.getResult(0); 427 428 builder.setInsertionPointAfter(whileOp); 429 compareOperands[0] = i; 430 compareOperands[1] = p; 431 Value compareEq = 432 createInlinedEqCompare(builder, loc, compareOperands, nx, ny, isCoo); 433 434 return std::make_pair(whileOp.getResult(0), compareEq); 435 } 436 437 /// Creates and returns an IfOp to compare two elements and swap the elements 438 /// if compareFunc(data[b], data[a]) returns true. The new insertion point is 439 /// right after the swap instructions. 
440 static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, 441 uint64_t nx, uint64_t ny, bool isCoo, 442 SmallVectorImpl<Value> &swapOperands, 443 SmallVectorImpl<Value> &compareOperands, 444 Value a, Value b) { 445 // Compare(data[b], data[a]). 446 compareOperands[0] = b; 447 compareOperands[1] = a; 448 Value cond = 449 createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo); 450 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); 451 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 452 swapOperands[0] = b; 453 swapOperands[1] = a; 454 createSwap(builder, loc, swapOperands, nx, ny, isCoo); 455 return ifOp; 456 } 457 458 /// Creates code to insert the 3rd element to a list of two sorted elements. 459 static void createInsert3rd(OpBuilder &builder, Location loc, uint64_t nx, 460 uint64_t ny, bool isCoo, 461 SmallVectorImpl<Value> &swapOperands, 462 SmallVectorImpl<Value> &compareOperands, Value v0, 463 Value v1, Value v2) { 464 scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo, 465 swapOperands, compareOperands, v1, v2); 466 createCompareThenSwap(builder, loc, nx, ny, isCoo, swapOperands, 467 compareOperands, v0, v1); 468 builder.setInsertionPointAfter(ifOp); 469 } 470 471 /// Creates code to sort 3 elements. 472 static void createSort3(OpBuilder &builder, Location loc, uint64_t nx, 473 uint64_t ny, bool isCoo, 474 SmallVectorImpl<Value> &swapOperands, 475 SmallVectorImpl<Value> &compareOperands, Value v0, 476 Value v1, Value v2) { 477 // Sort the first 2 elements. 478 scf::IfOp ifOp1 = createCompareThenSwap( 479 builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0, v1); 480 builder.setInsertionPointAfter(ifOp1); 481 482 // Insert the 3th element. 483 createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, 484 v0, v1, v2); 485 } 486 487 /// Creates code to sort 5 elements. 
488 static void createSort5(OpBuilder &builder, Location loc, uint64_t nx, 489 uint64_t ny, bool isCoo, 490 SmallVectorImpl<Value> &swapOperands, 491 SmallVectorImpl<Value> &compareOperands, Value v0, 492 Value v1, Value v2, Value v3, Value v4) { 493 // Sort the first 3 elements. 494 createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0, 495 v1, v2); 496 497 auto insert4th = [&]() { 498 scf::IfOp ifOp = createCompareThenSwap( 499 builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v2, v3); 500 createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, 501 v0, v1, v2); 502 builder.setInsertionPointAfter(ifOp); 503 }; 504 505 // Insert the 4th element. 506 insert4th(); 507 508 // Insert the 5th element. 509 scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo, 510 swapOperands, compareOperands, v3, v4); 511 insert4th(); 512 builder.setInsertionPointAfter(ifOp); 513 } 514 515 /// Creates a code block to swap the values in indices lo, mi, and hi so that 516 /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When 517 /// the number of values in range [lo, hi) is more than a threshold, we also 518 /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values. 519 static void createChoosePivot(OpBuilder &builder, ModuleOp module, 520 func::FuncOp func, uint64_t nx, uint64_t ny, 521 bool isCoo, Value lo, Value hi, Value mi, 522 ValueRange args) { 523 SmallVector<Value> compareOperands{mi, lo}; 524 uint64_t numXBuffers = isCoo ? 
1 : nx; 525 compareOperands.append(args.begin() + xStartIdx, 526 args.begin() + xStartIdx + numXBuffers); 527 SmallVector<Value> swapOperands{mi, lo}; 528 swapOperands.append(args.begin() + xStartIdx, args.end()); 529 Location loc = func.getLoc(); 530 Value c1 = constantIndex(builder, loc, 1); 531 Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1); 532 Value len = builder.create<arith::SubIOp>(loc, hiP1, lo); 533 Value lenThreshold = constantIndex(builder, loc, 1000); 534 Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 535 len, lenThreshold); 536 scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true); 537 538 // When len < 1000, choose pivot from median of 3 values. 539 builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); 540 createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo, 541 mi, hi); 542 543 // When len >= 1000, choose pivot from median of 5 values. 544 builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); 545 Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1); 546 Value a = builder.create<arith::AddIOp>(loc, lo, miP1); 547 // Value a is the middle between [loc, mi]. 548 a = builder.create<arith::ShRUIOp>(loc, a, c1); 549 Value b = builder.create<arith::AddIOp>(loc, mi, hiP1); 550 // Value b is the middle between [mi, hi]. 551 b = builder.create<arith::ShRUIOp>(loc, b, c1); 552 createSort5(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo, a, 553 mi, b, hi); 554 555 builder.setInsertionPointAfter(lenIf); 556 } 557 558 /// Creates a function to perform quick sort partition on the values in the 559 /// range of index [lo, hi), assuming lo < hi. 
/// The function returns the partition point (j + 1 at loop exit). The pivot
/// index is first arranged by createChoosePivot.
//
// The generated IR corresponds to this C like algorithm:
// int partition(lo, hi, xs) {
//   p = (lo+hi)/2  // pivot index
//   i = lo
//   j = hi-1
//   choosePivot(i, j, p)  // reorder sampled values so xs[p] is their median
//   while (true) do {
//     while (xs[i] < xs[p]) i ++;
//     i_eq = (xs[i] == xs[p]);
//     while (xs[j] > xs[p]) j --;
//     j_eq = (xs[j] == xs[p]);
//
//     if (i >= j) return j + 1;
//
//     if (i < j) {
//       swap(xs[i], xs[j])
//       if (i == p) {
//         p = j;
//       } else if (j == p) {
//         p = i;
//       }
//       if (i_eq && j_eq) {
//         ++i;
//         --j;
//       }
//     }
//   }
// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo, uint32_t nTrailingP = 0) {
  // Quick sort partition doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
  Value c1 = constantIndex(builder, loc, 1);
  Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);

  Value i = lo;
  Value j = builder.create<arith::SubIOp>(loc, hi, c1);
  // j is the last valid index, so the pivot helper sees the inclusive
  // range [i, j].
  createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
  Value trueVal = constantI1(builder, loc, true); // The value for while (true)
  // Loop-carried state: (i, j, p, continue-flag).
  SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
  SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
                             trueVal.getType()};
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

  // The before-region of the WhileOp.
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
                                      {loc, loc, loc, loc});
  builder.setInsertionPointToEnd(before);
  builder.create<scf::ConditionOp>(loc, before->getArgument(3),
                                   before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
  builder.setInsertionPointToEnd(after);
  i = after->getArgument(0);
  j = after->getArgument(1);
  p = after->getArgument(2);

  uint64_t numXBuffers = isCoo ? 1 : nx;
  auto [iresult, iCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     i, p, nx, ny, isCoo, 1);
  i = iresult;
  auto [jresult, jCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     j, p, nx, ny, isCoo, -1);
  j = jresult;

  // If i < j:
  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  SmallVector<Value> swapOperands{i, j};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  // If the pivot is moved, update p with the new pivot.
  Value icond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
  scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              icond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{j});
  builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
  Value jcond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
  scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              jcond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{p});
  builder.setInsertionPointAfter(ifOpJ);
  builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
  builder.setInsertionPointAfter(ifOpI);
  // If both scans stopped on values equal to the pivot, step past them.
  // Without this, two pivot-equal elements could keep swapping forever.
  Value compareEqIJ =
      builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(
      loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
  Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
  Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
  builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j});
  builder.setInsertionPointAfter(ifOp2);
  builder.create<scf::YieldOp>(
      loc,
      ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
                 /*cont=*/constantI1(builder, loc, true)});

  // False branch for if i < j (i.e., i >= j):
  // p is reused here to carry the return value j + 1 out of the loop.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  p = builder.create<arith::AddIOp>(loc, j,
                                    constantOne(builder, loc, j.getType()));
  builder.create<scf::YieldOp>(
      loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});

  // Return for the whileOp.
  builder.setInsertionPointAfter(ifOp);
  builder.create<scf::YieldOp>(loc, ifOp.getResults());

  // Return for the function.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
}

/// Computes (n-2)/2, assuming n has index type. The subtraction wraps and the
/// shift is logical (unsigned), so the result is only meaningful for n >= 2.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
                                      Value n) {
  Value i2 = constantIndex(builder, loc, 2);
  Value res = builder.create<arith::SubIOp>(loc, n, i2);
  Value i1 = constantIndex(builder, loc, 1);
  return builder.create<arith::ShRUIOp>(loc, res, i1);
}

/// Creates a function to heapify the subtree with root `start` within the full
/// binary tree in the range of index [first, first + n).
//
// The generated IR corresponds to this C like algorithm:
// void shiftDown(first, start, n, data) {
//   if (n >= 2) {
//     child = start - first
//     if ((n-2)/2 >= child) {
//       // Left child exists.
//       child = child * 2 + 1 // Initialize the bigger child to left child.
//       childIndex = child + first
//       if (child+1 < n && data[childIndex] < data[childIndex+1])
//         // Right child exists and is bigger.
//         childIndex++; child++;
//       // Shift data[start] down to where it belongs in the subtree.
//       while (data[start] < data[childIndex]) {
//         swap(data[start], data[childIndex])
//         start = childIndex
//         if ((n - 2)/2 >= child) {
//           // Left child exists.
//           child = 2*child + 1
//           childIndex = child + first
//           if (child + 1) < n && data[childIndex] < data[childIndex+1]
//             childIndex++; child++;
//         }
//       }
//     }
//   }
// }
//
// In the generated code, `child` is a position relative to `first`, while
// `childIdx` is the corresponding absolute index into the buffers.
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo, uint32_t nTrailingP) {
  // The value n is passed in as a trailing parameter.
  assert(nTrailingP == 1);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  Value n = entryBlock->getArguments().back();
  ValueRange args = entryBlock->getArguments().drop_back();
  // The lo/hi argument slots carry `first` and `start` for this helper.
  Value first = args[loIdx];
  Value start = args[hiIdx];

  // If (n >= 2).
  Value c2 = constantIndex(builder, loc, 2);
  Value condN =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
  scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
  builder.setInsertionPointToStart(&ifN.getThenRegion().front());
  Value child = builder.create<arith::SubIOp>(loc, start, first);

  // If ((n-2)/2 >= child).
  Value t = createSubTwoDividedByTwo(builder, loc, n);
  Value condNc =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
  scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);

  builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
  Value c1 = constantIndex(builder, loc, 1);
  // The two index slots are overwritten before each comparison below.
  SmallVector<Value> compareOperands{start, start};
  uint64_t numXBuffers = isCoo ? 1 : nx;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);

  // Generate code to inspect the children of 'r' and return the larger child
  // as follows:
  //   child = r * 2 + 1 // Left child.
  //   childIndex = child + first
  //   if (child+1 < n && data[childIndex] < data[childIndex+1])
  //     childIndex ++; child ++ // Right child is bigger.
  // The lambda returns (child, childIndex) and leaves the insertion point
  // after the scf.if it creates.
  auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
    Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
    lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
    Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
    Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
    Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                rChild, n);
    SmallVector<Type, 2> ifTypes(2, r.getType());
    scf::IfOp if1 =
        builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
    builder.setInsertionPointToStart(&if1.getThenRegion().front());
    Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
    // Compare data[left] < data[right].
    compareOperands[0] = lChildIdx;
    compareOperands[1] = rChildIdx;
    Value cond2 =
        createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
    scf::IfOp if2 =
        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
    builder.setInsertionPointToStart(&if2.getThenRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
    builder.setInsertionPointToStart(&if2.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
    builder.setInsertionPointAfter(if2);
    builder.create<scf::YieldOp>(loc, if2.getResults());
    builder.setInsertionPointToStart(&if1.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
    builder.setInsertionPointAfter(if1);
    return std::make_pair(if1.getResult(0), if1.getResult(1));
  };

  Value childIdx;
  std::tie(child, childIdx) = getLargerChild(child);

  // While (data[start] < data[childIndex]).
  // Loop-carried state: (start, child, childIdx).
  SmallVector<Type, 3> types(3, child.getType());
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{start, child, childIdx});

  // The before-region of the WhileOp.
  SmallVector<Location, 3> locs(3, loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  builder.setInsertionPointToEnd(before);
  start = before->getArgument(0);
  childIdx = before->getArgument(2);
  compareOperands[0] = start;
  compareOperands[1] = childIdx;
  Value cond =
      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region of the WhileOp.
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
  start = after->getArgument(0);
  child = after->getArgument(1);
  childIdx = after->getArgument(2);
  SmallVector<Value> swapOperands{start, childIdx};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  start = childIdx;
  Value cond2 =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
  scf::IfOp if2 = builder.create<scf::IfOp>(
      loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
  builder.setInsertionPointToStart(&if2.getThenRegion().front());
  auto [newChild, newChildIdx] = getLargerChild(child);
  builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
  // No left child: yielding childIdx (== the new start) makes the next loop
  // test compare data[start] with itself, which is false and ends the loop.
  builder.setInsertionPointToStart(&if2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
  builder.setInsertionPointAfter(if2);
  builder.create<scf::YieldOp>(
      loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});

  builder.setInsertionPointAfter(ifN);
  builder.create<func::ReturnOp>(loc);
}

/// Creates a function to perform heap sort on the values in the range of index
/// [lo, hi) with the assumption hi - lo >= 2. 860 // 861 // The generate IR corresponds to this C like algorithm: 862 // void heapSort(lo, hi, data) { 863 // n = hi - lo 864 // for i = (n-2)/2 downto 0 865 // shiftDown(lo, lo+i, n) 866 // 867 // for l = n downto 2 868 // swap(lo, lo+l-1) 869 // shiftdown(lo, lo, l-1) 870 // } 871 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, 872 func::FuncOp func, uint64_t nx, uint64_t ny, 873 bool isCoo, uint32_t nTrailingP) { 874 // Heap sort function doesn't have trailing parameters. 875 (void)nTrailingP; 876 assert(nTrailingP == 0); 877 OpBuilder::InsertionGuard insertionGuard(builder); 878 Block *entryBlock = func.addEntryBlock(); 879 builder.setInsertionPointToStart(entryBlock); 880 881 Location loc = func.getLoc(); 882 ValueRange args = entryBlock->getArguments(); 883 Value lo = args[loIdx]; 884 Value hi = args[hiIdx]; 885 Value n = builder.create<arith::SubIOp>(loc, hi, lo); 886 887 // For i = (n-2)/2 downto 0. 888 Value c0 = constantIndex(builder, loc, 0); 889 Value c1 = constantIndex(builder, loc, 1); 890 Value s = createSubTwoDividedByTwo(builder, loc, n); 891 Value up = builder.create<arith::AddIOp>(loc, s, c1); 892 scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1); 893 builder.setInsertionPointToStart(forI.getBody()); 894 Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar()); 895 Value lopi = builder.create<arith::AddIOp>(loc, lo, i); 896 SmallVector<Value> shiftDownOperands = {lo, lopi}; 897 shiftDownOperands.append(args.begin() + xStartIdx, args.end()); 898 shiftDownOperands.push_back(n); 899 FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc( 900 builder, func, TypeRange(), kShiftDownFuncNamePrefix, nx, ny, isCoo, 901 shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1); 902 builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(), 903 shiftDownOperands); 904 905 builder.setInsertionPointAfter(forI); 906 // For l = n downto 2. 
907 up = builder.create<arith::SubIOp>(loc, n, c1); 908 scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1); 909 builder.setInsertionPointToStart(forL.getBody()); 910 Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar()); 911 Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l); 912 loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1); 913 SmallVector<Value> swapOperands{lo, loplm1}; 914 swapOperands.append(args.begin() + xStartIdx, args.end()); 915 createSwap(builder, loc, swapOperands, nx, ny, isCoo); 916 shiftDownOperands[1] = lo; 917 shiftDownOperands[shiftDownOperands.size() - 1] = 918 builder.create<arith::SubIOp>(loc, l, c1); 919 builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(), 920 shiftDownOperands); 921 922 builder.setInsertionPointAfter(forL); 923 builder.create<func::ReturnOp>(loc); 924 } 925 926 /// A helper for generating code to perform quick sort. It partitions [lo, hi), 927 /// recursively calls quick sort to process the smaller partition and returns 928 /// the bigger partition to be processed by the enclosed while-loop. 929 static std::pair<Value, Value> 930 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, 931 ValueRange args, uint64_t nx, uint64_t ny, bool isCoo, 932 uint32_t nTrailingP) { 933 MLIRContext *context = module.getContext(); 934 Location loc = func.getLoc(); 935 Value lo = args[loIdx]; 936 Value hi = args[hiIdx]; 937 SmallVector<Type, 2> types(2, lo.getType()); // Only two types. 
938 939 FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( 940 builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx, 941 ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc); 942 Value p = builder 943 .create<func::CallOp>(loc, partitionFunc, 944 TypeRange{IndexType::get(context)}, 945 args.drop_back(nTrailingP)) 946 .getResult(0); 947 948 Value lenLow = builder.create<arith::SubIOp>(loc, p, lo); 949 Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p); 950 // Partition already sorts array with len <= 2 951 Value c2 = constantIndex(builder, loc, 2); 952 Value len = builder.create<arith::SubIOp>(loc, hi, lo); 953 Value lenGtTwo = 954 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2); 955 scf::IfOp ifLenGtTwo = 956 builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true); 957 builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front()); 958 // Returns an empty range to mark the entire region is fully sorted. 959 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); 960 961 // Else len > 2, need recursion. 
962 builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front()); 963 Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, 964 lenLow, lenHigh); 965 966 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); 967 968 Value c0 = constantIndex(builder, loc, 0); 969 auto mayRecursion = [&](Value low, Value high, Value len) { 970 Value cond = 971 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0); 972 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); 973 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 974 SmallVector<Value> operands{low, high}; 975 operands.append(args.begin() + xStartIdx, args.end()); 976 builder.create<func::CallOp>(loc, func, operands); 977 builder.setInsertionPointAfter(ifOp); 978 }; 979 980 // Recursively call quickSort to process the smaller partition and return 981 // the bigger partition to be processed by the enclosed while-loop. 982 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 983 mayRecursion(lo, p, lenLow); 984 builder.create<scf::YieldOp>(loc, ValueRange{p, hi}); 985 986 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 987 mayRecursion(p, hi, lenHigh); 988 builder.create<scf::YieldOp>(loc, ValueRange{lo, p}); 989 990 builder.setInsertionPointAfter(ifOp); 991 builder.create<scf::YieldOp>(loc, ifOp.getResults()); 992 993 builder.setInsertionPointAfter(ifLenGtTwo); 994 return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1)); 995 } 996 997 /// Creates a function to perform insertion sort on the values in the range of 998 /// index [lo, hi). 
999 // 1000 // The generate IR corresponds to this C like algorithm: 1001 // void insertionSort(lo, hi, data) { 1002 // for (i = lo+1; i < hi; i++) { 1003 // d = data[i]; 1004 // p = binarySearch(lo, i-1, data) 1005 // for (j = 0; j > i - p; j++) 1006 // data[i-j] = data[i-j-1] 1007 // data[p] = d 1008 // } 1009 // } 1010 static void createSortStableFunc(OpBuilder &builder, ModuleOp module, 1011 func::FuncOp func, uint64_t nx, uint64_t ny, 1012 bool isCoo, uint32_t nTrailingP) { 1013 // Stable sort function doesn't use trailing parameters. 1014 (void)nTrailingP; 1015 assert(nTrailingP == 0); 1016 OpBuilder::InsertionGuard insertionGuard(builder); 1017 Block *entryBlock = func.addEntryBlock(); 1018 builder.setInsertionPointToStart(entryBlock); 1019 1020 MLIRContext *context = module.getContext(); 1021 Location loc = func.getLoc(); 1022 ValueRange args = entryBlock->getArguments(); 1023 Value c1 = constantIndex(builder, loc, 1); 1024 Value lo = args[loIdx]; 1025 Value hi = args[hiIdx]; 1026 Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1); 1027 1028 // Start the outer for-stmt with induction variable i. 1029 scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1); 1030 builder.setInsertionPointToStart(forOpI.getBody()); 1031 Value i = forOpI.getInductionVar(); 1032 1033 // Binary search to find the insertion point p. 1034 SmallVector<Value> operands{lo, i}; 1035 operands.append(args.begin() + xStartIdx, args.end()); 1036 FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( 1037 builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx, 1038 ny, isCoo, operands, createBinarySearchFunc); 1039 Value p = builder 1040 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()}, 1041 operands) 1042 .getResult(0); 1043 1044 // Move the value at data[i] to a temporary location. 
1045 operands[0] = operands[1] = i; 1046 SmallVector<Value> d; 1047 forEachIJPairInAllBuffers( 1048 builder, loc, operands, nx, ny, isCoo, 1049 [&](uint64_t unused, Value i, Value unused2, Value buffer) { 1050 d.push_back(builder.create<memref::LoadOp>(loc, buffer, i)); 1051 }); 1052 1053 // Start the inner for-stmt with induction variable j, for moving data[p..i) 1054 // to data[p+1..i+1). 1055 Value imp = builder.create<arith::SubIOp>(loc, i, p); 1056 Value c0 = constantIndex(builder, loc, 0); 1057 scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1); 1058 builder.setInsertionPointToStart(forOpJ.getBody()); 1059 Value j = forOpJ.getInductionVar(); 1060 Value imj = builder.create<arith::SubIOp>(loc, i, j); 1061 operands[1] = imj; 1062 operands[0] = builder.create<arith::SubIOp>(loc, imj, c1); 1063 forEachIJPairInAllBuffers( 1064 builder, loc, operands, nx, ny, isCoo, 1065 [&](uint64_t unused, Value imjm1, Value imj, Value buffer) { 1066 Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1); 1067 builder.create<memref::StoreOp>(loc, t, buffer, imj); 1068 }); 1069 1070 // Store the value at data[i] to data[p]. 1071 builder.setInsertionPointAfter(forOpJ); 1072 operands[0] = operands[1] = p; 1073 forEachIJPairInAllBuffers( 1074 builder, loc, operands, nx, ny, isCoo, 1075 [&](uint64_t k, Value p, Value usused, Value buffer) { 1076 builder.create<memref::StoreOp>(loc, d[k], buffer, p); 1077 }); 1078 1079 builder.setInsertionPointAfter(forOpI); 1080 builder.create<func::ReturnOp>(loc); 1081 } 1082 1083 /// Creates a function to perform quick sort or a hybrid quick sort on the 1084 /// values in the range of index [lo, hi). 
1085 // 1086 // 1087 // When nTrailingP == 0, the generated IR corresponds to this C like algorithm: 1088 // void quickSort(lo, hi, data) { 1089 // while (lo + 1 < hi) { 1090 // p = partition(low, high, data); 1091 // if (len(lo, p) < len(p+1, hi)) { 1092 // quickSort(lo, p, data); 1093 // lo = p+1; 1094 // } else { 1095 // quickSort(p + 1, hi, data); 1096 // hi = p; 1097 // } 1098 // } 1099 // } 1100 // 1101 // When nTrailingP == 1, the generated IR corresponds to this C like algorithm: 1102 // void hybridQuickSort(lo, hi, data, depthLimit) { 1103 // while (lo + 1 < hi) { 1104 // len = hi - lo; 1105 // if (len <= limit) { 1106 // insertionSort(lo, hi, data); 1107 // } else { 1108 // depthLimit --; 1109 // if (depthLimit <= 0) { 1110 // heapSort(lo, hi, data); 1111 // } else { 1112 // p = partition(low, high, data); 1113 // if (len(lo, p) < len(p+1, hi)) { 1114 // quickSort(lo, p, data, depthLimit); 1115 // lo = p+1; 1116 // } else { 1117 // quickSort(p + 1, hi, data, depthLimit); 1118 // hi = p; 1119 // } 1120 // } 1121 // } 1122 // } 1123 // } 1124 // 1125 static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, 1126 func::FuncOp func, uint64_t nx, uint64_t ny, 1127 bool isCoo, uint32_t nTrailingP) { 1128 assert(nTrailingP == 1 || nTrailingP == 0); 1129 bool isHybrid = (nTrailingP == 1); 1130 OpBuilder::InsertionGuard insertionGuard(builder); 1131 Block *entryBlock = func.addEntryBlock(); 1132 builder.setInsertionPointToStart(entryBlock); 1133 1134 Location loc = func.getLoc(); 1135 SmallVector<Value> args; 1136 args.append(entryBlock->getArguments().begin(), 1137 entryBlock->getArguments().end()); 1138 Value lo = args[loIdx]; 1139 Value hi = args[hiIdx]; 1140 SmallVector<Type, 2> types(2, lo.getType()); // Only two types. 1141 scf::WhileOp whileOp = 1142 builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi}); 1143 1144 // The before-region of the WhileOp. 
1145 Block *before = 1146 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); 1147 builder.setInsertionPointToEnd(before); 1148 lo = before->getArgument(0); 1149 hi = before->getArgument(1); 1150 Value loP1 = 1151 builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1)); 1152 Value needSort = 1153 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi); 1154 builder.create<scf::ConditionOp>(loc, needSort, before->getArguments()); 1155 1156 // The after-region of the WhileOp. 1157 Block *after = 1158 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); 1159 builder.setInsertionPointToEnd(after); 1160 lo = after->getArgument(0); 1161 hi = after->getArgument(1); 1162 args[0] = lo; 1163 args[1] = hi; 1164 1165 if (isHybrid) { 1166 Value len = builder.create<arith::SubIOp>(loc, hi, lo); 1167 Value lenLimit = constantIndex(builder, loc, 30); 1168 Value lenCond = builder.create<arith::CmpIOp>( 1169 loc, arith::CmpIPredicate::ule, len, lenLimit); 1170 scf::IfOp lenIf = 1171 builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true); 1172 1173 // When len <= limit. 1174 builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); 1175 FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc( 1176 builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo, 1177 ValueRange(args).drop_back(nTrailingP), createSortStableFunc); 1178 builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(), 1179 ValueRange(args).drop_back(nTrailingP)); 1180 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); 1181 1182 // When len > limit. 
1183 builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); 1184 Value depthLimit = args.back(); 1185 depthLimit = builder.create<arith::SubIOp>(loc, depthLimit, 1186 constantI64(builder, loc, 1)); 1187 Value depthCond = 1188 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, 1189 depthLimit, constantI64(builder, loc, 0)); 1190 scf::IfOp depthIf = 1191 builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true); 1192 1193 // When depth exceeds limit. 1194 builder.setInsertionPointToStart(&depthIf.getThenRegion().front()); 1195 FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc( 1196 builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo, 1197 ValueRange(args).drop_back(nTrailingP), createHeapSortFunc); 1198 builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(), 1199 ValueRange(args).drop_back(nTrailingP)); 1200 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); 1201 1202 // When depth doesn't exceed limit. 1203 builder.setInsertionPointToStart(&depthIf.getElseRegion().front()); 1204 args.back() = depthLimit; 1205 std::tie(lo, hi) = 1206 createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP); 1207 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); 1208 1209 builder.setInsertionPointAfter(depthIf); 1210 lo = depthIf.getResult(0); 1211 hi = depthIf.getResult(1); 1212 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); 1213 1214 builder.setInsertionPointAfter(lenIf); 1215 lo = lenIf.getResult(0); 1216 hi = lenIf.getResult(1); 1217 } else { 1218 std::tie(lo, hi) = 1219 createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP); 1220 } 1221 1222 // New [lo, hi) for the next while-loop iteration. 1223 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); 1224 1225 // After the while-loop. 1226 builder.setInsertionPointAfter(whileOp); 1227 builder.create<func::ReturnOp>(loc); 1228 } 1229 1230 /// Implements the rewriting for operator sort and sort_coo. 
1231 template <typename OpTy> 1232 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx, 1233 uint64_t ny, bool isCoo, 1234 PatternRewriter &rewriter) { 1235 Location loc = op.getLoc(); 1236 SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()}; 1237 1238 // Convert `values` to have dynamic shape and append them to `operands`. 1239 for (Value v : xys) { 1240 auto mtp = getMemRefType(v); 1241 if (!mtp.isDynamicDim(0)) { 1242 auto newMtp = 1243 MemRefType::get({ShapedType::kDynamic}, mtp.getElementType()); 1244 v = rewriter.create<memref::CastOp>(loc, newMtp, v); 1245 } 1246 operands.push_back(v); 1247 } 1248 1249 auto insertPoint = op->template getParentOfType<func::FuncOp>(); 1250 if (!insertPoint) 1251 return failure(); 1252 1253 SmallString<32> funcName; 1254 FuncGeneratorType funcGenerator; 1255 uint32_t nTrailingP = 0; 1256 switch (op.getAlgorithm()) { 1257 case SparseTensorSortKind::HybridQuickSort: { 1258 funcName = kHybridQuickSortFuncNamePrefix; 1259 funcGenerator = createQuickSortFunc; 1260 nTrailingP = 1; 1261 // As a heuristics, set depthLimit = 2 * log2(n). 
1262 Value lo = operands[loIdx]; 1263 Value hi = operands[hiIdx]; 1264 Value len = rewriter.create<arith::IndexCastOp>( 1265 loc, rewriter.getI64Type(), 1266 rewriter.create<arith::SubIOp>(loc, hi, lo)); 1267 Value depthLimit = rewriter.create<arith::SubIOp>( 1268 loc, constantI64(rewriter, loc, 64), 1269 rewriter.create<math::CountLeadingZerosOp>(loc, len)); 1270 operands.push_back(depthLimit); 1271 break; 1272 } 1273 case SparseTensorSortKind::QuickSort: 1274 funcName = kQuickSortFuncNamePrefix; 1275 funcGenerator = createQuickSortFunc; 1276 break; 1277 case SparseTensorSortKind::InsertionSortStable: 1278 funcName = kSortStableFuncNamePrefix; 1279 funcGenerator = createSortStableFunc; 1280 break; 1281 case SparseTensorSortKind::HeapSort: 1282 funcName = kHeapSortFuncNamePrefix; 1283 funcGenerator = createHeapSortFunc; 1284 break; 1285 } 1286 1287 FlatSymbolRefAttr func = 1288 getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx, 1289 ny, isCoo, operands, funcGenerator, nTrailingP); 1290 rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands); 1291 return success(); 1292 } 1293 1294 //===---------------------------------------------------------------------===// 1295 // The actual sparse buffer rewriting rules. 1296 //===---------------------------------------------------------------------===// 1297 1298 namespace { 1299 1300 /// Sparse rewriting rule for the push_back operator. 
1301 struct PushBackRewriter : OpRewritePattern<PushBackOp> { 1302 public: 1303 using OpRewritePattern<PushBackOp>::OpRewritePattern; 1304 PushBackRewriter(MLIRContext *context, bool enableInit) 1305 : OpRewritePattern(context), enableBufferInitialization(enableInit) {} 1306 LogicalResult matchAndRewrite(PushBackOp op, 1307 PatternRewriter &rewriter) const override { 1308 // Rewrite push_back(buffer, value, n) to: 1309 // new_size = size(buffer) + n 1310 // if (new_size > capacity(buffer)) 1311 // while new_size > new_capacity 1312 // new_capacity = new_capacity*2 1313 // new_buffer = realloc(buffer, new_capacity) 1314 // buffer = new_buffer 1315 // subBuffer = subviewof(buffer) 1316 // linalg.fill subBuffer value 1317 // 1318 // size(buffer) += n 1319 // 1320 // The capacity check is skipped when the attribute inbounds is presented. 1321 Location loc = op->getLoc(); 1322 Value c0 = constantIndex(rewriter, loc, 0); 1323 Value buffer = op.getInBuffer(); 1324 Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0); 1325 Value size = op.getCurSize(); 1326 Value value = op.getValue(); 1327 1328 Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1); 1329 Value newSize = rewriter.create<arith::AddIOp>(loc, size, n); 1330 auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp()); 1331 bool nIsOne = (nValue && nValue.value() == 1); 1332 1333 if (!op.getInbounds()) { 1334 Value cond = rewriter.create<arith::CmpIOp>( 1335 loc, arith::CmpIPredicate::ugt, newSize, capacity); 1336 1337 Value c2 = constantIndex(rewriter, loc, 2); 1338 auto bufferType = 1339 MemRefType::get({ShapedType::kDynamic}, value.getType()); 1340 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond, 1341 /*else=*/true); 1342 // True branch. 
1343 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); 1344 if (nIsOne) { 1345 capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2); 1346 } else { 1347 // Use a do-while loop to calculate the new capacity as follows: 1348 // do { new_capacity *= 2 } while (size > new_capacity) 1349 scf::WhileOp whileOp = 1350 rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity); 1351 1352 // The before-region of the WhileOp. 1353 Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, 1354 {capacity.getType()}, {loc}); 1355 rewriter.setInsertionPointToEnd(before); 1356 1357 capacity = 1358 rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2); 1359 cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, 1360 newSize, capacity); 1361 rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity}); 1362 // The after-region of the WhileOp. 1363 Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, 1364 {capacity.getType()}, {loc}); 1365 rewriter.setInsertionPointToEnd(after); 1366 rewriter.create<scf::YieldOp>(loc, after->getArguments()); 1367 1368 rewriter.setInsertionPointAfter(whileOp); 1369 capacity = whileOp.getResult(0); 1370 } 1371 1372 Value newBuffer = 1373 rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity); 1374 if (enableBufferInitialization) { 1375 Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize); 1376 Value fillValue = constantZero(rewriter, loc, value.getType()); 1377 Value subBuffer = rewriter.create<memref::SubViewOp>( 1378 loc, newBuffer, /*offset=*/ValueRange{newSize}, 1379 /*size=*/ValueRange{fillSize}, 1380 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); 1381 rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer); 1382 } 1383 rewriter.create<scf::YieldOp>(loc, newBuffer); 1384 1385 // False branch. 
1386 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); 1387 rewriter.create<scf::YieldOp>(loc, buffer); 1388 1389 // Prepare for adding the value to the end of the buffer. 1390 rewriter.setInsertionPointAfter(ifOp); 1391 buffer = ifOp.getResult(0); 1392 } 1393 1394 // Add the value to the end of the buffer. 1395 if (nIsOne) { 1396 rewriter.create<memref::StoreOp>(loc, value, buffer, size); 1397 } else { 1398 Value subBuffer = rewriter.create<memref::SubViewOp>( 1399 loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n}, 1400 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); 1401 rewriter.create<linalg::FillOp>(loc, value, subBuffer); 1402 } 1403 1404 // Update the buffer size. 1405 rewriter.replaceOp(op, {buffer, newSize}); 1406 return success(); 1407 } 1408 1409 private: 1410 bool enableBufferInitialization; 1411 }; 1412 1413 /// Sparse rewriting rule for the sort operator. 1414 struct SortRewriter : public OpRewritePattern<SortOp> { 1415 public: 1416 using OpRewritePattern<SortOp>::OpRewritePattern; 1417 1418 LogicalResult matchAndRewrite(SortOp op, 1419 PatternRewriter &rewriter) const override { 1420 SmallVector<Value> xys(op.getXs()); 1421 xys.append(op.getYs().begin(), op.getYs().end()); 1422 return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0, 1423 /*isCoo=*/false, rewriter); 1424 } 1425 }; 1426 1427 /// Sparse rewriting rule for the sort_coo operator. 
1428 struct SortCooRewriter : public OpRewritePattern<SortCooOp> { 1429 public: 1430 using OpRewritePattern<SortCooOp>::OpRewritePattern; 1431 1432 LogicalResult matchAndRewrite(SortCooOp op, 1433 PatternRewriter &rewriter) const override { 1434 SmallVector<Value> xys; 1435 xys.push_back(op.getXy()); 1436 xys.append(op.getYs().begin(), op.getYs().end()); 1437 uint64_t nx = 1; 1438 if (auto nxAttr = op.getNxAttr()) 1439 nx = nxAttr.getInt(); 1440 1441 uint64_t ny = 0; 1442 if (auto nyAttr = op.getNyAttr()) 1443 ny = nyAttr.getInt(); 1444 1445 return matchAndRewriteSortOp(op, xys, nx, ny, 1446 /*isCoo=*/true, rewriter); 1447 } 1448 }; 1449 1450 } // namespace 1451 1452 //===---------------------------------------------------------------------===// 1453 // Methods that add patterns described in this file to a pattern list. 1454 //===---------------------------------------------------------------------===// 1455 1456 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns, 1457 bool enableBufferInitialization) { 1458 patterns.add<PushBackRewriter>(patterns.getContext(), 1459 enableBufferInitialization); 1460 patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext()); 1461 } 1462