1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements rewriting rules that are specific to sparse tensor 10 // primitives with memref operands. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodegenUtils.h" 15 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/Math/IR/Math.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/Dialect/SCF/IR/SCF.h" 22 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 23 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 24 #include "mlir/Support/LLVM.h" 25 26 using namespace mlir; 27 using namespace mlir::sparse_tensor; 28 29 //===---------------------------------------------------------------------===// 30 // Helper methods for the actual rewriting rules. 
31 //===---------------------------------------------------------------------===// 32 33 static constexpr uint64_t loIdx = 0; 34 static constexpr uint64_t hiIdx = 1; 35 static constexpr uint64_t xStartIdx = 2; 36 37 static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_"; 38 static constexpr const char kBinarySearchFuncNamePrefix[] = 39 "_sparse_binary_search_"; 40 static constexpr const char kHybridQuickSortFuncNamePrefix[] = 41 "_sparse_hybrid_qsort_"; 42 static constexpr const char kSortStableFuncNamePrefix[] = 43 "_sparse_sort_stable_"; 44 static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_"; 45 static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_"; 46 static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_"; 47 48 using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, 49 AffineMap, uint64_t, uint32_t)>; 50 51 /// Constructs a function name with this format to facilitate quick sort: 52 /// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort 53 /// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo 54 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, 55 StringRef namePrefix, AffineMap xPerm, 56 uint64_t ny, ValueRange operands) { 57 nameOstream << namePrefix; 58 for (auto res : xPerm.getResults()) 59 nameOstream << cast<AffineDimExpr>(res).getPosition() << "_"; 60 61 nameOstream << getMemRefType(operands[xStartIdx]).getElementType(); 62 nameOstream << "_coo_" << ny; 63 64 constexpr uint64_t yBufferOffset = 1; 65 for (Value v : operands.drop_front(xStartIdx + yBufferOffset)) 66 nameOstream << "_" << getMemRefType(v).getElementType(); 67 } 68 69 /// Looks up a function that is appropriate for the given operands being 70 /// sorted, and creates such a function if it doesn't exist yet. 
The 71 /// parameters `xPerm` and `ny` tell the number of x and y values provided 72 /// by the buffer in xStartIdx. 73 // 74 // All sorting function generators take (lo, hi, xs, ys) in `operands` as 75 // parameters for the sorting functions. Other parameters, such as the recursive 76 // call depth, are appended to the end of the parameter list as 77 // "trailing parameters". 78 static FlatSymbolRefAttr getMangledSortHelperFunc( 79 OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, 80 StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, 81 FuncGeneratorType createFunc, uint32_t nTrailingP = 0) { 82 SmallString<32> nameBuffer; 83 llvm::raw_svector_ostream nameOstream(nameBuffer); 84 getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny, 85 operands.drop_back(nTrailingP)); 86 87 ModuleOp module = insertPoint->getParentOfType<ModuleOp>(); 88 MLIRContext *context = module.getContext(); 89 auto result = SymbolRefAttr::get(context, nameOstream.str()); 90 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr()); 91 92 if (!func) { 93 // Create the function. 94 OpBuilder::InsertionGuard insertionGuard(builder); 95 builder.setInsertionPoint(insertPoint); 96 Location loc = insertPoint.getLoc(); 97 func = builder.create<func::FuncOp>( 98 loc, nameOstream.str(), 99 FunctionType::get(context, operands.getTypes(), resultTypes)); 100 func.setPrivate(); 101 createFunc(builder, module, func, xPerm, ny, nTrailingP); 102 } 103 104 return result; 105 } 106 107 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting. 108 /// The code to process the value pairs is generated by `bodyBuilder`. 
109 static void forEachIJPairInXs( 110 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, 111 uint64_t ny, 112 function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { 113 Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny); 114 Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep); 115 Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep); 116 for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) { 117 unsigned actualK = cast<AffineDimExpr>(xPerm.getResult(k)).getPosition(); 118 Value ak = constantIndex(builder, loc, actualK); 119 Value i = builder.create<arith::AddIOp>(loc, ak, iOffset); 120 Value j = builder.create<arith::AddIOp>(loc, ak, jOffset); 121 Value buffer = args[xStartIdx]; 122 123 bodyBuilder(k, i, j, buffer); 124 } 125 } 126 127 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting. 128 /// The code to process the value pairs is generated by `bodyBuilder`. 129 static void forEachIJPairInAllBuffers( 130 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, 131 uint64_t ny, 132 function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { 133 134 // Create code for the first (xPerm + ny) buffers. 135 SmallVector<AffineExpr> exps(xPerm.getResults().begin(), 136 xPerm.getResults().end()); 137 for (unsigned y = 0; y < ny; y++) { 138 exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults())); 139 } 140 AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext()); 141 assert(xyPerm.isPermutation()); 142 143 forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder); 144 145 constexpr uint64_t numHandledBuffers = 1; 146 // Create code for the remaining buffers. 
147 Value i = args[0]; 148 Value j = args[1]; 149 for (const auto &arg : 150 llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) { 151 bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value()); 152 } 153 } 154 155 /// Creates a code block for swapping the values in index i and j for all the 156 /// buffers. 157 // 158 // The generated IR corresponds to this C like algorithm: 159 // swap(x0[i], x0[j]); 160 // swap(x1[i], x1[j]); 161 // ... 162 // swap(xn[i], xn[j]); 163 // swap(y0[i], y0[j]); 164 // ... 165 // swap(yn[i], yn[j]); 166 static void createSwap(OpBuilder &builder, Location loc, ValueRange args, 167 AffineMap xPerm, uint64_t ny) { 168 auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) { 169 Value vi = builder.create<memref::LoadOp>(loc, buffer, i); 170 Value vj = builder.create<memref::LoadOp>(loc, buffer, j); 171 builder.create<memref::StoreOp>(loc, vj, buffer, i); 172 builder.create<memref::StoreOp>(loc, vi, buffer, j); 173 }; 174 175 forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair); 176 } 177 178 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare 179 /// each pair is create via `compareBuilder`. 
180 static Value createInlinedCompareImplementation( 181 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, 182 uint64_t ny, 183 function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> 184 compareBuilder) { 185 Value result; 186 auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) { 187 bool isFirstDim = (k == 0); 188 bool isLastDim = (k == xPerm.getNumResults() - 1); 189 Value val = 190 compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim); 191 if (isFirstDim) { 192 result = val; 193 } else if (!isLastDim) { 194 OpBuilder::InsertionGuard insertionGuard(builder); 195 auto ifOp = cast<scf::IfOp>(val.getDefiningOp()); 196 builder.setInsertionPointAfter(ifOp); 197 builder.create<scf::YieldOp>(loc, ifOp.getResult(0)); 198 } 199 }; 200 201 forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder); 202 203 builder.setInsertionPointAfterValue(result); 204 return result; 205 } 206 207 /// Generates code to compare whether x[i] is equal to x[j] and returns the 208 /// result of the comparison. 209 static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, 210 Value x, bool isFirstDim, bool isLastDim) { 211 Value vi = builder.create<memref::LoadOp>(loc, x, i); 212 Value vj = builder.create<memref::LoadOp>(loc, x, j); 213 214 Value res; 215 if (isLastDim) { 216 res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj); 217 // For 1D, we create a compare without any control flow. Otherwise, we 218 // create YieldOp to return the result in the nested if-stmt. 219 if (!isFirstDim) 220 builder.create<scf::YieldOp>(loc, res); 221 } else { 222 Value ne = 223 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj); 224 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1), 225 ne, /*else=*/true); 226 // If (x[i] != x[j]). 
227 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 228 Value f = constantI1(builder, loc, false); 229 builder.create<scf::YieldOp>(loc, f); 230 231 // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that 232 // checks the remaining dimensions. 233 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 234 res = ifOp.getResult(0); 235 } 236 237 return res; 238 } 239 240 /// Creates code to compare whether xs[i] is equal to xs[j]. 241 // 242 // The generate IR corresponds to this C like algorithm: 243 // if (x0[i] != x0[j]) 244 // return false; 245 // else 246 // if (x1[i] != x1[j]) 247 // return false; 248 // else if (x2[2] != x2[j])) 249 // and so on ... 250 static Value createInlinedEqCompare(OpBuilder &builder, Location loc, 251 ValueRange args, AffineMap xPerm, 252 uint64_t ny, uint32_t nTrailingP = 0) { 253 // Compare functions don't use trailing parameters. 254 (void)nTrailingP; 255 assert(nTrailingP == 0); 256 return createInlinedCompareImplementation(builder, loc, args, xPerm, ny, 257 createEqCompare); 258 } 259 260 /// Generates code to compare whether x[i] is less than x[j] and returns the 261 /// result of the comparison. 262 static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, 263 Value j, Value x, bool isFirstDim, 264 bool isLastDim) { 265 Value vi = builder.create<memref::LoadOp>(loc, x, i); 266 Value vj = builder.create<memref::LoadOp>(loc, x, j); 267 268 Value res; 269 if (isLastDim) { 270 res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj); 271 // For 1D, we create a compare without any control flow. Otherwise, we 272 // create YieldOp to return the result in the nested if-stmt. 
273 if (!isFirstDim) 274 builder.create<scf::YieldOp>(loc, res); 275 } else { 276 Value ne = 277 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj); 278 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1), 279 ne, /*else=*/true); 280 // If (x[i] != x[j]). 281 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 282 Value lt = 283 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj); 284 builder.create<scf::YieldOp>(loc, lt); 285 286 // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that 287 // checks the remaining dimensions. 288 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 289 res = ifOp.getResult(0); 290 } 291 292 return res; 293 } 294 295 /// Creates code to compare whether xs[i] is less than xs[j]. 296 // 297 // The generate IR corresponds to this C like algorithm: 298 // if (x0[i] != x0[j]) 299 // return x0[i] < x0[j]; 300 // else if (x1[j] != x1[i]) 301 // return x1[i] < x1[j]; 302 // else 303 // and so on ... 304 static Value createInlinedLessThan(OpBuilder &builder, Location loc, 305 ValueRange args, AffineMap xPerm, 306 uint64_t ny, uint32_t nTrailingP = 0) { 307 // Compare functions don't use trailing parameters. 308 (void)nTrailingP; 309 assert(nTrailingP == 0); 310 return createInlinedCompareImplementation(builder, loc, args, xPerm, ny, 311 createLessThanCompare); 312 } 313 314 /// Creates a function to use a binary search to find the insertion point for 315 /// inserting xs[hi] to the sorted values xs[lo..hi). 
316 // 317 // The generate IR corresponds to this C like algorithm: 318 // p = hi 319 // while (lo < hi) 320 // mid = (lo + hi) >> 1 321 // if (xs[p] < xs[mid]) 322 // hi = mid 323 // else 324 // lo = mid - 1 325 // return lo; 326 // 327 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, 328 func::FuncOp func, AffineMap xPerm, 329 uint64_t ny, uint32_t nTrailingP = 0) { 330 // Binary search doesn't use trailing parameters. 331 (void)nTrailingP; 332 assert(nTrailingP == 0); 333 OpBuilder::InsertionGuard insertionGuard(builder); 334 Block *entryBlock = func.addEntryBlock(); 335 builder.setInsertionPointToStart(entryBlock); 336 337 Location loc = func.getLoc(); 338 ValueRange args = entryBlock->getArguments(); 339 Value p = args[hiIdx]; 340 SmallVector<Type, 2> types(2, p.getType()); // Only two types. 341 scf::WhileOp whileOp = builder.create<scf::WhileOp>( 342 loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]}); 343 344 // The before-region of the WhileOp. 345 Block *before = 346 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); 347 builder.setInsertionPointToEnd(before); 348 Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 349 before->getArgument(0), 350 before->getArgument(1)); 351 builder.create<scf::ConditionOp>(loc, cond1, before->getArguments()); 352 353 // The after-region of the WhileOp. 354 Block *after = 355 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); 356 builder.setInsertionPointToEnd(after); 357 Value lo = after->getArgument(0); 358 Value hi = after->getArgument(1); 359 // Compute mid = (lo + hi) >> 1. 360 Value c1 = constantIndex(builder, loc, 1); 361 Value mid = builder.create<arith::ShRUIOp>( 362 loc, builder.create<arith::AddIOp>(loc, lo, hi), c1); 363 Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1); 364 365 // Compare xs[p] < xs[mid]. 
366 SmallVector<Value> compareOperands{p, mid}; 367 constexpr uint64_t numXBuffers = 1; 368 compareOperands.append(args.begin() + xStartIdx, 369 args.begin() + xStartIdx + numXBuffers); 370 Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); 371 // Update lo and hi for the WhileOp as follows: 372 // if (xs[p] < xs[mid])) 373 // hi = mid; 374 // else 375 // lo = mid + 1; 376 Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1); 377 Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi); 378 builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi}); 379 380 builder.setInsertionPointAfter(whileOp); 381 builder.create<func::ReturnOp>(loc, whileOp.getResult(0)); 382 } 383 384 /// Creates code to advance i in a loop based on xs[p] as follows: 385 /// while (xs[i] < xs[p]) i += step (step > 0) 386 /// or 387 /// while (xs[i] > xs[p]) i += step (step < 0) 388 /// The routine returns i as well as a boolean value to indicate whether 389 /// xs[i] == xs[p]. 
390 static std::pair<Value, Value> createScanLoop(OpBuilder &builder, 391 ModuleOp module, 392 func::FuncOp func, ValueRange xs, 393 Value i, Value p, AffineMap xPerm, 394 uint64_t ny, int step) { 395 Location loc = func.getLoc(); 396 scf::WhileOp whileOp = 397 builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i}); 398 399 Block *before = 400 builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc}); 401 builder.setInsertionPointToEnd(before); 402 SmallVector<Value> compareOperands; 403 if (step > 0) { 404 compareOperands.push_back(before->getArgument(0)); 405 compareOperands.push_back(p); 406 } else { 407 assert(step < 0); 408 compareOperands.push_back(p); 409 compareOperands.push_back(before->getArgument(0)); 410 } 411 compareOperands.append(xs.begin(), xs.end()); 412 Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); 413 builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); 414 415 Block *after = 416 builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc}); 417 builder.setInsertionPointToEnd(after); 418 Value cs = constantIndex(builder, loc, step); 419 i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs); 420 builder.create<scf::YieldOp>(loc, ValueRange{i}); 421 i = whileOp.getResult(0); 422 423 builder.setInsertionPointAfter(whileOp); 424 compareOperands[0] = i; 425 compareOperands[1] = p; 426 Value compareEq = 427 createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny); 428 429 return std::make_pair(whileOp.getResult(0), compareEq); 430 } 431 432 /// Creates and returns an IfOp to compare two elements and swap the elements 433 /// if compareFunc(data[b], data[a]) returns true. The new insertion point is 434 /// right after the swap instructions. 
435 static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, 436 AffineMap xPerm, uint64_t ny, 437 SmallVectorImpl<Value> &swapOperands, 438 SmallVectorImpl<Value> &compareOperands, 439 Value a, Value b) { 440 // Compare(data[b], data[a]). 441 compareOperands[0] = b; 442 compareOperands[1] = a; 443 Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); 444 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); 445 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 446 swapOperands[0] = b; 447 swapOperands[1] = a; 448 createSwap(builder, loc, swapOperands, xPerm, ny); 449 return ifOp; 450 } 451 452 /// Creates code to insert the 3rd element to a list of two sorted elements. 453 static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, 454 uint64_t ny, SmallVectorImpl<Value> &swapOperands, 455 SmallVectorImpl<Value> &compareOperands, Value v0, 456 Value v1, Value v2) { 457 scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, 458 compareOperands, v1, v2); 459 createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands, 460 v0, v1); 461 builder.setInsertionPointAfter(ifOp); 462 } 463 464 /// Creates code to sort 3 elements. 465 static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, 466 uint64_t ny, SmallVectorImpl<Value> &swapOperands, 467 SmallVectorImpl<Value> &compareOperands, Value v0, 468 Value v1, Value v2) { 469 // Sort the first 2 elements. 470 scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, 471 compareOperands, v0, v1); 472 builder.setInsertionPointAfter(ifOp1); 473 474 // Insert the 3th element. 475 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, 476 v1, v2); 477 } 478 479 /// Creates code to sort 5 elements. 
480 static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, 481 uint64_t ny, SmallVectorImpl<Value> &swapOperands, 482 SmallVectorImpl<Value> &compareOperands, Value v0, 483 Value v1, Value v2, Value v3, Value v4) { 484 // Sort the first 3 elements. 485 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1, 486 v2); 487 488 auto insert4th = [&]() { 489 scf::IfOp ifOp = createCompareThenSwap( 490 builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3); 491 createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, 492 v1, v2); 493 builder.setInsertionPointAfter(ifOp); 494 }; 495 496 // Insert the 4th element. 497 insert4th(); 498 499 // Insert the 5th element. 500 scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, 501 compareOperands, v3, v4); 502 insert4th(); 503 builder.setInsertionPointAfter(ifOp); 504 } 505 506 /// Creates a code block to swap the values in indices lo, mi, and hi so that 507 /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When 508 /// the number of values in range [lo, hi) is more than a threshold, we also 509 /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values. 
510 static void createChoosePivot(OpBuilder &builder, ModuleOp module, 511 func::FuncOp func, AffineMap xPerm, uint64_t ny, 512 Value lo, Value hi, Value mi, ValueRange args) { 513 SmallVector<Value> compareOperands{mi, lo}; 514 constexpr uint64_t numXBuffers = 1; 515 compareOperands.append(args.begin() + xStartIdx, 516 args.begin() + xStartIdx + numXBuffers); 517 SmallVector<Value> swapOperands{mi, lo}; 518 swapOperands.append(args.begin() + xStartIdx, args.end()); 519 Location loc = func.getLoc(); 520 Value c1 = constantIndex(builder, loc, 1); 521 Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1); 522 Value len = builder.create<arith::SubIOp>(loc, hiP1, lo); 523 Value lenThreshold = constantIndex(builder, loc, 1000); 524 Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 525 len, lenThreshold); 526 scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true); 527 528 // When len < 1000, choose pivot from median of 3 values. 529 builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); 530 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi, 531 hi); 532 533 // When len >= 1000, choose pivot from median of 5 values. 534 builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); 535 Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1); 536 Value a = builder.create<arith::AddIOp>(loc, lo, miP1); 537 // Value a is the middle between [loc, mi]. 538 a = builder.create<arith::ShRUIOp>(loc, a, c1); 539 Value b = builder.create<arith::AddIOp>(loc, mi, hiP1); 540 // Value b is the middle between [mi, hi]. 541 b = builder.create<arith::ShRUIOp>(loc, b, c1); 542 createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi, 543 b, hi); 544 545 builder.setInsertionPointAfter(lenIf); 546 } 547 548 /// Creates a function to perform quick sort partition on the values in the 549 /// range of index [lo, hi), assuming lo < hi. 
//
// The generated IR corresponds to this C like algorithm:
// int partition(lo, hi, xs) {
//   p = (lo+hi)/2  // pivot index
//   i = lo
//   j = hi-1
//   while (true) do {
//     while (xs[i] < xs[p]) i ++;
//     i_eq = (xs[i] == xs[p]);
//     while (xs[j] > xs[p]) j --;
//     j_eq = (xs[j] == xs[p]);
//
//     if (i >= j) return j + 1;
//
//     if (i < j) {
//       swap(xs[i], xs[j])
//       if (i == p) {
//         p = j;
//       } else if (j == p) {
//         p = i;
//       }
//       if (i_eq && j_eq) {
//         ++i;
//         --j;
//       }
//     }
//   }
// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP = 0) {
  // Quick sort partition doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  // p = (lo + hi) >> 1 is the initial pivot index.
  Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
  Value c1 = constantIndex(builder, loc, 1);
  Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);

  // i starts at lo and j at hi - 1; the pivot selection code reorders the
  // sampled values (including those at i, p and j) before the main loop.
  Value i = lo;
  Value j = builder.create<arith::SubIOp>(loc, hi, c1);
  createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
  Value trueVal = constantI1(builder, loc, true); // The value for while (true)
  // The loop carries (i, j, p, continue-flag).
  SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
  SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
                             trueVal.getType()};
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

  // The before-region of the WhileOp. It only tests the carried
  // continue-flag and forwards all carried values.
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
                                      {loc, loc, loc, loc});
  builder.setInsertionPointToEnd(before);
  builder.create<scf::ConditionOp>(loc, before->getArgument(3),
                                   before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
  builder.setInsertionPointToEnd(after);
  i = after->getArgument(0);
  j = after->getArgument(1);
  p = after->getArgument(2);

  // Advance i upwards (step 1) and j downwards (step -1) relative to xs[p].
  // Each scan also yields whether the element it stopped at equals xs[p].
  constexpr uint64_t numXBuffers = 1;
  auto [iresult, iCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     i, p, xPerm, ny, 1);
  i = iresult;
  auto [jresult, jCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     j, p, xPerm, ny, -1);
  j = jresult;

  // If i < j:
  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  SmallVector<Value> swapOperands{i, j};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, xPerm, ny);
  // If the pivot is moved, update p with the new pivot.
  Value icond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
  scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              icond, /*else=*/true);
  // i == p: the pivot value now lives at j.
  builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{j});
  builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
  Value jcond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
  scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              jcond, /*else=*/true);
  // j == p: the pivot value now lives at i.
  builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  // Otherwise the pivot was not touched by the swap.
  builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{p});
  builder.setInsertionPointAfter(ifOpJ);
  builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
  builder.setInsertionPointAfter(ifOpI);
  // If both scans stopped on elements equal to the pivot, step over them
  // (++i, --j); otherwise keep i and j unchanged.
  Value compareEqIJ =
      builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(
      loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
  Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
  Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
  builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j});
  builder.setInsertionPointAfter(ifOp2);
  // Yield the updated (i, j, p) and keep looping.
  builder.create<scf::YieldOp>(
      loc,
      ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
                 /*cont=*/constantI1(builder, loc, true)});

  // False branch for if i < j (i.e., i >= j):
  // The third carried value is reused to hold the return value j + 1, and the
  // continue-flag is cleared to terminate the loop.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  p = builder.create<arith::AddIOp>(loc, j,
                                    constantOne(builder, loc, j.getType()));
  builder.create<scf::YieldOp>(
      loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});

  // Return for the whileOp.
  builder.setInsertionPointAfter(ifOp);
  builder.create<scf::YieldOp>(loc, ifOp.getResults());

  // Return for the function: the third loop result, which holds j + 1 when
  // the loop exits.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
}

/// Computes (n-2)/2, assuming n has index type. The division is implemented
/// as an unsigned right shift by one.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
                                      Value n) {
  Value i2 = constantIndex(builder, loc, 2);
  Value res = builder.create<arith::SubIOp>(loc, n, i2);
  Value i1 = constantIndex(builder, loc, 1);
  return builder.create<arith::ShRUIOp>(loc, res, i1);
}

/// Creates a function to heapify the subtree with root `start` within the full
/// binary tree in the range of index [first, first + n).
//
// The generated IR corresponds to this C like algorithm:
// void shiftDown(first, start, n, data) {
//   if (n >= 2) {
//     child = start - first
//     if ((n-2)/2 >= child) {
//       // Left child exists.
//       child = child * 2 + 1 // Initialize the bigger child to left child.
//       childIndex = child + first
//       if (child+1 < n && data[childIndex] < data[childIndex+1])
//         // Right child exists and is bigger.
//         childIndex++; child++;
//       // Shift data[start] down to where it belongs in the subtree.
//       while (data[start] < data[childIndex]) {
//         swap(data[start], data[childIndex])
//         start = childIndex
//         if ((n - 2)/2 >= child) {
//           // Left child exists.
//           child = 2*child + 1
//           childIndex = child + first
//           if (child + 1) < n && data[childIndex] < data[childIndex+1]
//             childIndex++; child++;
//         }
//       }
//     }
//   }
// }
//
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP) {
  // The value n is passed in as a trailing parameter.
  assert(nTrailingP == 1);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  // The last block argument is n; the remaining ones are
  // (first, start, buffers...).
  Value n = entryBlock->getArguments().back();
  ValueRange args = entryBlock->getArguments().drop_back();
  Value first = args[loIdx];
  Value start = args[hiIdx];

  // If (n >= 2).
  Value c2 = constantIndex(builder, loc, 2);
  Value condN =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
  scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
  builder.setInsertionPointToStart(&ifN.getThenRegion().front());
  // `child` is the position of `start` relative to `first`.
  Value child = builder.create<arith::SubIOp>(loc, start, first);

  // If ((n-2)/2 >= child).
  Value t = createSubTwoDividedByTwo(builder, loc, n);
  Value condNc =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
  scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);

  builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
  Value c1 = constantIndex(builder, loc, 1);
  // The first two compare operands are placeholders that are overwritten
  // before each comparison; they are followed by the x buffer.
  SmallVector<Value> compareOperands{start, start};
  constexpr uint64_t numXBuffers = 1;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);

  // Generate code to inspect the children of 'r' and return the larger child
  // as follows:
  //   child = r * 2 + 1 // Left child.
  //   childIndex = child + first
  //   if (child+1 < n && data[childIndex] < data[childIndex+1])
  //     childIndex ++; child ++ // Right child is bigger.
  // The returned pair is (child, childIndex), and the builder is left right
  // after the generated outer scf.if.
  auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
    Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
    lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
    Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
    Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
    // Does the right child exist (rChild < n)?
    Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                rChild, n);
    SmallVector<Type, 2> ifTypes(2, r.getType());
    scf::IfOp if1 =
        builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
    builder.setInsertionPointToStart(&if1.getThenRegion().front());
    Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
    // Compare data[left] < data[right].
    compareOperands[0] = lChildIdx;
    compareOperands[1] = rChildIdx;
    Value cond2 =
        createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
    scf::IfOp if2 =
        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
    // Right child is bigger.
    builder.setInsertionPointToStart(&if2.getThenRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
    // Left child is at least as big as the right child.
    builder.setInsertionPointToStart(&if2.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
    builder.setInsertionPointAfter(if2);
    builder.create<scf::YieldOp>(loc, if2.getResults());
    // No right child: the left child is the only candidate.
    builder.setInsertionPointToStart(&if1.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
    builder.setInsertionPointAfter(if1);
    return std::make_pair(if1.getResult(0), if1.getResult(1));
  };

  Value childIdx;
  std::tie(child, childIdx) = getLargerChild(child);

  // While (data[start] < data[childIndex]).
  // The loop carries (start, child, childIdx).
  SmallVector<Type, 3> types(3, child.getType());
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{start, child, childIdx});

  // The before-region of the WhileOp.
  SmallVector<Location, 3> locs(3, loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  builder.setInsertionPointToEnd(before);
  start = before->getArgument(0);
  childIdx = before->getArgument(2);
  compareOperands[0] = start;
  compareOperands[1] = childIdx;
  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region of the WhileOp.
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
  start = after->getArgument(0);
  child = after->getArgument(1);
  childIdx = after->getArgument(2);
  // Swap data[start] with data[childIdx] in all buffers, then continue from
  // the child's position.
  SmallVector<Value> swapOperands{start, childIdx};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, xPerm, ny);
  start = childIdx;
  // If ((n-2)/2 >= child), the new position has a left child: pick its larger
  // child. Otherwise carry child and childIdx over unchanged.
  Value cond2 =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
  scf::IfOp if2 = builder.create<scf::IfOp>(
      loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
  builder.setInsertionPointToStart(&if2.getThenRegion().front());
  auto [newChild, newChildIdx] = getLargerChild(child);
  builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
  builder.setInsertionPointToStart(&if2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
  builder.setInsertionPointAfter(if2);
  builder.create<scf::YieldOp>(
      loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});

  // The function returns nothing; the return is placed after the outermost
  // if.
  builder.setInsertionPointAfter(ifN);
  builder.create<func::ReturnOp>(loc);
}

/// Creates a function to perform heap sort on the values in the range of index
/// [lo, hi)
/// with the assumption hi - lo >= 2.
//
// The generated IR corresponds to this C like algorithm:
// void heapSort(lo, hi, data) {
//   n = hi - lo
//   for i = (n-2)/2 downto 0
//     shiftDown(lo, lo+i, n)
//
//   for l = n downto 2
//      swap(lo, lo+l-1)
//      shiftdown(lo, lo, l-1)
// }
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
                               func::FuncOp func, AffineMap xPerm, uint64_t ny,
                               uint32_t nTrailingP) {
  // Heap sort function doesn't have trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  // Restore the caller's insertion point when this generator returns.
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value n = builder.create<arith::SubIOp>(loc, hi, lo);

  // For i = (n-2)/2 downto 0.
  // The scf.for below counts upwards over [0, s + 1), so the descending index
  // is recovered as i = s - iv.
  Value c0 = constantIndex(builder, loc, 0);
  Value c1 = constantIndex(builder, loc, 1);
  Value s = createSubTwoDividedByTwo(builder, loc, n);
  Value up = builder.create<arith::AddIOp>(loc, s, c1);
  scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
  builder.setInsertionPointToStart(forI.getBody());
  Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
  Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
  // Call operands: (lo, lo + i, buffers..., n); n is the trailing parameter.
  SmallVector<Value> shiftDownOperands = {lo, lopi};
  shiftDownOperands.append(args.begin() + xStartIdx, args.end());
  shiftDownOperands.push_back(n);
  FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
      builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
      shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
                               shiftDownOperands);

  builder.setInsertionPointAfter(forI);
  // For l = n downto 2.
  // Again an upward scf.for over [0, n - 1) with l = n - iv.
  up = builder.create<arith::SubIOp>(loc, n, c1);
  scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
  builder.setInsertionPointToStart(forL.getBody());
  Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
  // loplm1 = lo + l - 1.
  Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
  loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
  SmallVector<Value> swapOperands{lo, loplm1};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, xPerm, ny);
  // Reuse the operand vector of the first call: start becomes lo and the
  // trailing size becomes l - 1.
  shiftDownOperands[1] = lo;
  shiftDownOperands[shiftDownOperands.size() - 1] =
      builder.create<arith::SubIOp>(loc, l, c1);
  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
                               shiftDownOperands);

  builder.setInsertionPointAfter(forL);
  builder.create<func::ReturnOp>(loc);
}

/// A helper for generating code to perform quick sort. It partitions [lo, hi),
/// recursively calls quick sort to process the smaller partition and returns
/// the bigger partition to be processed by the enclosed while-loop.
///
/// `args` holds (lo, hi, buffers..., trailing parameters); the last
/// `nTrailingP` entries are dropped for the partition call but are forwarded
/// to the recursive call. The returned pair is the (lo, hi) range that still
/// needs processing; (lo, lo) is returned when len <= 2.
static std::pair<Value, Value>
createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
                ValueRange args, AffineMap xPerm, uint64_t ny,
                uint32_t nTrailingP) {
  MLIRContext *context = module.getContext();
  Location loc = func.getLoc();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.

  // Look up (or create) the partition helper and call it; it returns one
  // index value p.
  FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
      ny, args.drop_back(nTrailingP), createPartitionFunc);
  Value p = builder
                .create<func::CallOp>(loc, partitionFunc,
                                      TypeRange{IndexType::get(context)},
                                      args.drop_back(nTrailingP))
                .getResult(0);

  Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
  Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
  // Partition already sorts array with len <= 2
  Value c2 = constantIndex(builder, loc, 2);
  Value len = builder.create<arith::SubIOp>(loc, hi, lo);
  Value lenGtTwo =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
  scf::IfOp ifLenGtTwo =
      builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
  // Else-branch (len <= 2) is populated first.
  builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
  // Returns an empty range to mark the entire region is fully sorted.
  builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

  // Then-branch: len > 2, need recursion.
  builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                             lenLow, lenHigh);

  Value c0 = constantIndex(builder, loc, 0);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);

  // Emits `if (len != 0) func(low, high, buffers..., trailing...)` at the
  // current insertion point and leaves the insertion point after that scf.if.
  // Note that the lambda's `cond`, `ifOp` and `len` shadow the outer ones.
  auto mayRecursion = [&](Value low, Value high, Value len) {
    Value cond =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    SmallVector<Value> operands{low, high};
    operands.append(args.begin() + xStartIdx, args.end());
    builder.create<func::CallOp>(loc, func, operands);
    builder.setInsertionPointAfter(ifOp);
  };

  // Recursively call quickSort to process the smaller partition and return
  // the bigger partition to be processed by the enclosed while-loop.
  // lenLow <= lenHigh: recurse on [lo, p), hand [p, hi) back to the loop.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  mayRecursion(lo, p, lenLow);
  builder.create<scf::YieldOp>(loc, ValueRange{p, hi});

  // lenLow > lenHigh: recurse on [p, hi), hand [lo, p) back to the loop.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  mayRecursion(p, hi, lenHigh);
  builder.create<scf::YieldOp>(loc, ValueRange{lo, p});

  builder.setInsertionPointAfter(ifOp);
  builder.create<scf::YieldOp>(loc, ifOp.getResults());

  builder.setInsertionPointAfter(ifLenGtTwo);
  return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
}

/// Creates a function to perform insertion sort on the values in the range of
/// index [lo, hi).
988 // 989 // The generate IR corresponds to this C like algorithm: 990 // void insertionSort(lo, hi, data) { 991 // for (i = lo+1; i < hi; i++) { 992 // d = data[i]; 993 // p = binarySearch(lo, i-1, data) 994 // for (j = 0; j > i - p; j++) 995 // data[i-j] = data[i-j-1] 996 // data[p] = d 997 // } 998 // } 999 static void createSortStableFunc(OpBuilder &builder, ModuleOp module, 1000 func::FuncOp func, AffineMap xPerm, 1001 uint64_t ny, uint32_t nTrailingP) { 1002 // Stable sort function doesn't use trailing parameters. 1003 (void)nTrailingP; 1004 assert(nTrailingP == 0); 1005 OpBuilder::InsertionGuard insertionGuard(builder); 1006 Block *entryBlock = func.addEntryBlock(); 1007 builder.setInsertionPointToStart(entryBlock); 1008 1009 MLIRContext *context = module.getContext(); 1010 Location loc = func.getLoc(); 1011 ValueRange args = entryBlock->getArguments(); 1012 Value c1 = constantIndex(builder, loc, 1); 1013 Value lo = args[loIdx]; 1014 Value hi = args[hiIdx]; 1015 Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1); 1016 1017 // Start the outer for-stmt with induction variable i. 1018 scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1); 1019 builder.setInsertionPointToStart(forOpI.getBody()); 1020 Value i = forOpI.getInductionVar(); 1021 1022 // Binary search to find the insertion point p. 1023 SmallVector<Value> operands{lo, i}; 1024 operands.append(args.begin() + xStartIdx, args.end()); 1025 FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( 1026 builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, 1027 xPerm, ny, operands, createBinarySearchFunc); 1028 Value p = builder 1029 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()}, 1030 operands) 1031 .getResult(0); 1032 1033 // Move the value at data[i] to a temporary location. 
1034 operands[0] = operands[1] = i; 1035 SmallVector<Value> d; 1036 forEachIJPairInAllBuffers( 1037 builder, loc, operands, xPerm, ny, 1038 [&](uint64_t unused, Value i, Value unused2, Value buffer) { 1039 d.push_back(builder.create<memref::LoadOp>(loc, buffer, i)); 1040 }); 1041 1042 // Start the inner for-stmt with induction variable j, for moving data[p..i) 1043 // to data[p+1..i+1). 1044 Value imp = builder.create<arith::SubIOp>(loc, i, p); 1045 Value c0 = constantIndex(builder, loc, 0); 1046 scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1); 1047 builder.setInsertionPointToStart(forOpJ.getBody()); 1048 Value j = forOpJ.getInductionVar(); 1049 Value imj = builder.create<arith::SubIOp>(loc, i, j); 1050 operands[1] = imj; 1051 operands[0] = builder.create<arith::SubIOp>(loc, imj, c1); 1052 forEachIJPairInAllBuffers( 1053 builder, loc, operands, xPerm, ny, 1054 [&](uint64_t unused, Value imjm1, Value imj, Value buffer) { 1055 Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1); 1056 builder.create<memref::StoreOp>(loc, t, buffer, imj); 1057 }); 1058 1059 // Store the value at data[i] to data[p]. 1060 builder.setInsertionPointAfter(forOpJ); 1061 operands[0] = operands[1] = p; 1062 forEachIJPairInAllBuffers( 1063 builder, loc, operands, xPerm, ny, 1064 [&](uint64_t k, Value p, Value usused, Value buffer) { 1065 builder.create<memref::StoreOp>(loc, d[k], buffer, p); 1066 }); 1067 1068 builder.setInsertionPointAfter(forOpI); 1069 builder.create<func::ReturnOp>(loc); 1070 } 1071 1072 /// Creates a function to perform quick sort or a hybrid quick sort on the 1073 /// values in the range of index [lo, hi). 
//
//
// When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
// void quickSort(lo, hi, data) {
//   while (lo + 1 < hi) {
//     p = partition(low, high, data);
//     if (len(lo, p) < len(p+1, hi)) {
//       quickSort(lo, p, data);
//       lo = p+1;
//     } else {
//       quickSort(p + 1, hi, data);
//       hi = p;
//     }
//   }
// }
//
// When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
// void hybridQuickSort(lo, hi, data, depthLimit) {
//   while (lo + 1 < hi) {
//     len = hi - lo;
//     if (len <= limit) {
//       insertionSort(lo, hi, data);
//     } else {
//       depthLimit --;
//       if (depthLimit <= 0) {
//         heapSort(lo, hi, data);
//       } else {
//          p = partition(low, high, data);
//          if (len(lo, p) < len(p+1, hi)) {
//            quickSort(lo, p, data, depthLimit);
//            lo = p+1;
//          } else {
//            quickSort(p + 1, hi, data, depthLimit);
//            hi = p;
//          }
//       }
//     }
//   }
// }
//
// In the hybrid form the depth limit is the single trailing argument of the
// generated function, and `limit` is the constant 30 used below.
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP) {
  assert(nTrailingP == 1 || nTrailingP == 0);
  bool isHybrid = (nTrailingP == 1);
  // Restore the caller's insertion point when this generator returns.
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  // Copy the entry arguments into a mutable vector: lo/hi (and, for the hybrid
  // variant, the depth limit) are replaced by loop-local values further down.
  SmallVector<Value> args;
  args.append(entryBlock->getArguments().begin(),
              entryBlock->getArguments().end());
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
  // The while-loop carries the (lo, hi) range that still has to be sorted.
  scf::WhileOp whileOp =
      builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});

  // The before-region of the WhileOp.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(before);
  lo = before->getArgument(0);
  hi = before->getArgument(1);
  // Keep looping while lo + 1 < hi, i.e. while at least two elements remain.
  Value loP1 =
      builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
  Value needSort =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
  builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(after);
  lo = after->getArgument(0);
  hi = after->getArgument(1);
  args[0] = lo;
  args[1] = hi;

  if (isHybrid) {
    Value len = builder.create<arith::SubIOp>(loc, hi, lo);
    Value lenLimit = constantIndex(builder, loc, 30);
    Value lenCond = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ule, len, lenLimit);
    scf::IfOp lenIf =
        builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);

    // When len <= limit.
    // Sort the whole range with the stable-sort helper and yield the empty
    // range (lo, lo) so that the while-loop stops.
    builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
    FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
        builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
        ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
    builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
                                 ValueRange(args).drop_back(nTrailingP));
    builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

    // When len > limit.
    // The depth limit is read from the function argument and is not carried by
    // the while-loop, so every iteration sees the same decremented value.
    builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
    Value depthLimit = args.back();
    depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
                                               constantI64(builder, loc, 1));
    // The comparison is unsigned (ule against 0), so it only holds when the
    // decremented limit is exactly 0.
    Value depthCond =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                      depthLimit, constantI64(builder, loc, 0));
    scf::IfOp depthIf =
        builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);

    // When depth exceeds limit.
    // Fall back to the heap-sort helper and yield the empty range (lo, lo).
    builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
    FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
        builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
        ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
    builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
                                 ValueRange(args).drop_back(nTrailingP));
    builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

    // When depth doesn't exceed limit.
    // Recursive calls emitted by createQuickSort receive the decremented limit.
    builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
    args.back() = depthLimit;
    std::tie(lo, hi) =
        createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
    builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});

    // Forward the depthIf results out of the enclosing lenIf else-branch.
    builder.setInsertionPointAfter(depthIf);
    lo = depthIf.getResult(0);
    hi = depthIf.getResult(1);
    builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});

    builder.setInsertionPointAfter(lenIf);
    lo = lenIf.getResult(0);
    hi = lenIf.getResult(1);
  } else {
    std::tie(lo, hi) =
        createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
  }

  // New [lo, hi) for the next while-loop iteration.
  builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});

  // After the while-loop.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc);
}

/// Implements the rewriting for operator sort and sort_coo.
1220 template <typename OpTy> 1221 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, 1222 uint64_t ny, PatternRewriter &rewriter) { 1223 Location loc = op.getLoc(); 1224 SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()}; 1225 1226 // Convert `values` to have dynamic shape and append them to `operands`. 1227 for (Value v : xys) { 1228 auto mtp = getMemRefType(v); 1229 if (!mtp.isDynamicDim(0)) { 1230 auto newMtp = 1231 MemRefType::get({ShapedType::kDynamic}, mtp.getElementType()); 1232 v = rewriter.create<memref::CastOp>(loc, newMtp, v); 1233 } 1234 operands.push_back(v); 1235 } 1236 1237 auto insertPoint = op->template getParentOfType<func::FuncOp>(); 1238 if (!insertPoint) 1239 return failure(); 1240 1241 SmallString<32> funcName; 1242 FuncGeneratorType funcGenerator; 1243 uint32_t nTrailingP = 0; 1244 switch (op.getAlgorithm()) { 1245 case SparseTensorSortKind::HybridQuickSort: { 1246 funcName = kHybridQuickSortFuncNamePrefix; 1247 funcGenerator = createQuickSortFunc; 1248 nTrailingP = 1; 1249 // As a heuristics, set depthLimit = 2 * log2(n). 
1250 Value lo = operands[loIdx]; 1251 Value hi = operands[hiIdx]; 1252 Value len = rewriter.create<arith::IndexCastOp>( 1253 loc, rewriter.getI64Type(), 1254 rewriter.create<arith::SubIOp>(loc, hi, lo)); 1255 Value depthLimit = rewriter.create<arith::SubIOp>( 1256 loc, constantI64(rewriter, loc, 64), 1257 rewriter.create<math::CountLeadingZerosOp>(loc, len)); 1258 operands.push_back(depthLimit); 1259 break; 1260 } 1261 case SparseTensorSortKind::QuickSort: 1262 funcName = kQuickSortFuncNamePrefix; 1263 funcGenerator = createQuickSortFunc; 1264 break; 1265 case SparseTensorSortKind::InsertionSortStable: 1266 funcName = kSortStableFuncNamePrefix; 1267 funcGenerator = createSortStableFunc; 1268 break; 1269 case SparseTensorSortKind::HeapSort: 1270 funcName = kHeapSortFuncNamePrefix; 1271 funcGenerator = createHeapSortFunc; 1272 break; 1273 } 1274 1275 FlatSymbolRefAttr func = 1276 getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, 1277 xPerm, ny, operands, funcGenerator, nTrailingP); 1278 rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands); 1279 return success(); 1280 } 1281 1282 //===---------------------------------------------------------------------===// 1283 // The actual sparse buffer rewriting rules. 1284 //===---------------------------------------------------------------------===// 1285 1286 namespace { 1287 /// Sparse rewriting rule for the push_back operator. 
1288 struct PushBackRewriter : OpRewritePattern<PushBackOp> { 1289 public: 1290 using OpRewritePattern<PushBackOp>::OpRewritePattern; 1291 PushBackRewriter(MLIRContext *context, bool enableInit) 1292 : OpRewritePattern(context), enableBufferInitialization(enableInit) {} 1293 LogicalResult matchAndRewrite(PushBackOp op, 1294 PatternRewriter &rewriter) const override { 1295 // Rewrite push_back(buffer, value, n) to: 1296 // new_size = size(buffer) + n 1297 // if (new_size > capacity(buffer)) 1298 // while new_size > new_capacity 1299 // new_capacity = new_capacity*2 1300 // new_buffer = realloc(buffer, new_capacity) 1301 // buffer = new_buffer 1302 // subBuffer = subviewof(buffer) 1303 // linalg.fill subBuffer value 1304 // 1305 // size(buffer) += n 1306 // 1307 // The capacity check is skipped when the attribute inbounds is presented. 1308 Location loc = op->getLoc(); 1309 Value c0 = constantIndex(rewriter, loc, 0); 1310 Value buffer = op.getInBuffer(); 1311 Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0); 1312 Value size = op.getCurSize(); 1313 Value value = op.getValue(); 1314 1315 Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1); 1316 Value newSize = rewriter.create<arith::AddIOp>(loc, size, n); 1317 auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp()); 1318 bool nIsOne = (nValue && nValue.value() == 1); 1319 1320 if (!op.getInbounds()) { 1321 Value cond = rewriter.create<arith::CmpIOp>( 1322 loc, arith::CmpIPredicate::ugt, newSize, capacity); 1323 1324 Value c2 = constantIndex(rewriter, loc, 2); 1325 auto bufferType = 1326 MemRefType::get({ShapedType::kDynamic}, value.getType()); 1327 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond, 1328 /*else=*/true); 1329 // True branch. 
1330 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); 1331 if (nIsOne) { 1332 capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2); 1333 } else { 1334 // Use a do-while loop to calculate the new capacity as follows: 1335 // do { new_capacity *= 2 } while (size > new_capacity) 1336 scf::WhileOp whileOp = 1337 rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity); 1338 1339 // The before-region of the WhileOp. 1340 Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, 1341 {capacity.getType()}, {loc}); 1342 rewriter.setInsertionPointToEnd(before); 1343 1344 capacity = 1345 rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2); 1346 cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, 1347 newSize, capacity); 1348 rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity}); 1349 // The after-region of the WhileOp. 1350 Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, 1351 {capacity.getType()}, {loc}); 1352 rewriter.setInsertionPointToEnd(after); 1353 rewriter.create<scf::YieldOp>(loc, after->getArguments()); 1354 1355 rewriter.setInsertionPointAfter(whileOp); 1356 capacity = whileOp.getResult(0); 1357 } 1358 1359 Value newBuffer = 1360 rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity); 1361 if (enableBufferInitialization) { 1362 Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize); 1363 Value fillValue = constantZero(rewriter, loc, value.getType()); 1364 Value subBuffer = rewriter.create<memref::SubViewOp>( 1365 loc, newBuffer, /*offset=*/ValueRange{newSize}, 1366 /*size=*/ValueRange{fillSize}, 1367 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); 1368 rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer); 1369 } 1370 rewriter.create<scf::YieldOp>(loc, newBuffer); 1371 1372 // False branch. 
1373 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); 1374 rewriter.create<scf::YieldOp>(loc, buffer); 1375 1376 // Prepare for adding the value to the end of the buffer. 1377 rewriter.setInsertionPointAfter(ifOp); 1378 buffer = ifOp.getResult(0); 1379 } 1380 1381 // Add the value to the end of the buffer. 1382 if (nIsOne) { 1383 rewriter.create<memref::StoreOp>(loc, value, buffer, size); 1384 } else { 1385 Value subBuffer = rewriter.create<memref::SubViewOp>( 1386 loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n}, 1387 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); 1388 rewriter.create<linalg::FillOp>(loc, value, subBuffer); 1389 } 1390 1391 // Update the buffer size. 1392 rewriter.replaceOp(op, {buffer, newSize}); 1393 return success(); 1394 } 1395 1396 private: 1397 bool enableBufferInitialization; 1398 }; 1399 1400 /// Sparse rewriting rule for the sort_coo operator. 1401 struct SortRewriter : public OpRewritePattern<SortOp> { 1402 public: 1403 using OpRewritePattern<SortOp>::OpRewritePattern; 1404 1405 LogicalResult matchAndRewrite(SortOp op, 1406 PatternRewriter &rewriter) const override { 1407 SmallVector<Value> xys; 1408 xys.push_back(op.getXy()); 1409 xys.append(op.getYs().begin(), op.getYs().end()); 1410 1411 auto xPerm = op.getPermMap(); 1412 uint64_t ny = 0; 1413 if (auto nyAttr = op.getNyAttr()) 1414 ny = nyAttr.getInt(); 1415 1416 return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter); 1417 } 1418 }; 1419 1420 } // namespace 1421 1422 //===---------------------------------------------------------------------===// 1423 // Methods that add patterns described in this file to a pattern list. 
1424 //===---------------------------------------------------------------------===// 1425 1426 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns, 1427 bool enableBufferInitialization) { 1428 patterns.add<PushBackRewriter>(patterns.getContext(), 1429 enableBufferInitialization); 1430 patterns.add<SortRewriter>(patterns.getContext()); 1431 } 1432