1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements rewriting rules that are specific to sparse tensor 10 // primitives with memref operands. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "Utils/CodegenUtils.h" 15 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/Math/IR/Math.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/Dialect/SCF/IR/SCF.h" 22 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 23 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 24 #include "mlir/Support/LLVM.h" 25 26 using namespace mlir; 27 using namespace mlir::sparse_tensor; 28 29 //===---------------------------------------------------------------------===// 30 // Helper methods for the actual rewriting rules. 
31 //===---------------------------------------------------------------------===// 32 33 static constexpr uint64_t loIdx = 0; 34 static constexpr uint64_t hiIdx = 1; 35 static constexpr uint64_t xStartIdx = 2; 36 37 static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_"; 38 static constexpr const char kBinarySearchFuncNamePrefix[] = 39 "_sparse_binary_search_"; 40 static constexpr const char kHybridQuickSortFuncNamePrefix[] = 41 "_sparse_hybrid_qsort_"; 42 static constexpr const char kSortStableFuncNamePrefix[] = 43 "_sparse_sort_stable_"; 44 static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_"; 45 static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_"; 46 static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_"; 47 48 using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, 49 AffineMap, uint64_t, uint32_t)>; 50 51 /// Constructs a function name with this format to facilitate quick sort: 52 /// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort 53 /// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo 54 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, 55 StringRef namePrefix, AffineMap xPerm, 56 uint64_t ny, ValueRange operands) { 57 nameOstream << namePrefix; 58 for (auto res : xPerm.getResults()) 59 nameOstream << cast<AffineDimExpr>(res).getPosition() << "_"; 60 61 nameOstream << getMemRefType(operands[xStartIdx]).getElementType(); 62 nameOstream << "_coo_" << ny; 63 64 constexpr uint64_t yBufferOffset = 1; 65 for (Value v : operands.drop_front(xStartIdx + yBufferOffset)) 66 nameOstream << "_" << getMemRefType(v).getElementType(); 67 } 68 69 /// Looks up a function that is appropriate for the given operands being 70 /// sorted, and creates such a function if it doesn't exist yet. 
The 71 /// parameters `xPerm` and `ny` tell the number of x and y values provided 72 /// by the buffer in xStartIdx. 73 // 74 // All sorting function generators take (lo, hi, xs, ys) in `operands` as 75 // parameters for the sorting functions. Other parameters, such as the recursive 76 // call depth, are appended to the end of the parameter list as 77 // "trailing parameters". 78 static FlatSymbolRefAttr getMangledSortHelperFunc( 79 OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, 80 StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, 81 FuncGeneratorType createFunc, uint32_t nTrailingP = 0) { 82 SmallString<32> nameBuffer; 83 llvm::raw_svector_ostream nameOstream(nameBuffer); 84 getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny, 85 operands.drop_back(nTrailingP)); 86 87 ModuleOp module = insertPoint->getParentOfType<ModuleOp>(); 88 MLIRContext *context = module.getContext(); 89 auto result = SymbolRefAttr::get(context, nameOstream.str()); 90 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr()); 91 92 if (!func) { 93 // Create the function. 94 OpBuilder::InsertionGuard insertionGuard(builder); 95 builder.setInsertionPoint(insertPoint); 96 Location loc = insertPoint.getLoc(); 97 func = builder.create<func::FuncOp>( 98 loc, nameOstream.str(), 99 FunctionType::get(context, operands.getTypes(), resultTypes)); 100 func.setPrivate(); 101 createFunc(builder, module, func, xPerm, ny, nTrailingP); 102 } 103 104 return result; 105 } 106 107 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting. 108 /// The code to process the value pairs is generated by `bodyBuilder`. 
109 static void forEachIJPairInXs( 110 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, 111 uint64_t ny, 112 function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { 113 Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny); 114 Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep); 115 Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep); 116 for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) { 117 unsigned actualK = cast<AffineDimExpr>(xPerm.getResult(k)).getPosition(); 118 Value ak = constantIndex(builder, loc, actualK); 119 Value i = builder.create<arith::AddIOp>(loc, ak, iOffset); 120 Value j = builder.create<arith::AddIOp>(loc, ak, jOffset); 121 Value buffer = args[xStartIdx]; 122 123 bodyBuilder(k, i, j, buffer); 124 } 125 } 126 127 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting. 128 /// The code to process the value pairs is generated by `bodyBuilder`. 129 static void forEachIJPairInAllBuffers( 130 OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, 131 uint64_t ny, 132 function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { 133 134 // Create code for the first (xPerm + ny) buffers. 135 SmallVector<AffineExpr> exps(xPerm.getResults()); 136 for (unsigned y = 0; y < ny; y++) { 137 exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults())); 138 } 139 AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext()); 140 assert(xyPerm.isPermutation()); 141 142 forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder); 143 144 constexpr uint64_t numHandledBuffers = 1; 145 // Create code for the remaining buffers. 
146 Value i = args[0]; 147 Value j = args[1]; 148 for (const auto &arg : 149 llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) { 150 bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value()); 151 } 152 } 153 154 /// Creates a code block for swapping the values in index i and j for all the 155 /// buffers. 156 // 157 // The generated IR corresponds to this C like algorithm: 158 // swap(x0[i], x0[j]); 159 // swap(x1[i], x1[j]); 160 // ... 161 // swap(xn[i], xn[j]); 162 // swap(y0[i], y0[j]); 163 // ... 164 // swap(yn[i], yn[j]); 165 static void createSwap(OpBuilder &builder, Location loc, ValueRange args, 166 AffineMap xPerm, uint64_t ny) { 167 auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) { 168 Value vi = builder.create<memref::LoadOp>(loc, buffer, i); 169 Value vj = builder.create<memref::LoadOp>(loc, buffer, j); 170 builder.create<memref::StoreOp>(loc, vj, buffer, i); 171 builder.create<memref::StoreOp>(loc, vi, buffer, j); 172 }; 173 174 forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair); 175 } 176 177 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare 178 /// each pair is create via `compareBuilder`. 
static Value createInlinedCompareImplementation(
    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
    uint64_t ny,
    function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
        compareBuilder) {
  // The value produced for the first dimension; it is the overall result.
  Value result;
  auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
    bool isFirstDim = (k == 0);
    bool isLastDim = (k == xPerm.getNumResults() - 1);
    Value val =
        compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
    if (isFirstDim) {
      result = val;
    } else if (!isLastDim) {
      // A middle dimension: `val` must be the result of an scf.if (the cast
      // below asserts this). With the compare builders in this file, that
      // scf.if was emitted inside the else-region of the previous dimension's
      // scf.if, so its result is yielded right after it. The guard restores
      // the insertion point chosen by `compareBuilder`, so the next dimension
      // is emitted at that point.
      OpBuilder::InsertionGuard insertionGuard(builder);
      auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
      builder.setInsertionPointAfter(ifOp);
      builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
    }
  };

  forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);

  // Continue emitting code right after the op that defines the result.
  builder.setInsertionPointAfterValue(result);
  return result;
}

/// Generates code to compare whether x[i] is equal to x[j] and returns the
/// result of the comparison.
///
/// For the last dimension, a plain `eq` compare is created (and yielded unless
/// it is also the first dimension). For any other dimension, an scf.if on
/// `x[i] != x[j]` is created whose then-region yields false; the builder is
/// left at the start of the else-region and the scf.if result is returned.
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
                             Value x, bool isFirstDim, bool isLastDim) {
  Value vi = builder.create<memref::LoadOp>(loc, x, i);
  Value vj = builder.create<memref::LoadOp>(loc, x, j);

  Value res;
  if (isLastDim) {
    res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
    // For 1D, we create a compare without any control flow. Otherwise, we
    // create YieldOp to return the result in the nested if-stmt.
    if (!isFirstDim)
      builder.create<scf::YieldOp>(loc, res);
  } else {
    Value ne =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
                                               ne, /*else=*/true);
    // If (x[i] != x[j]).
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    Value f = constantI1(builder, loc, false);
    builder.create<scf::YieldOp>(loc, f);

    // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
    // checks the remaining dimensions.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
    res = ifOp.getResult(0);
  }

  return res;
}

/// Creates code to compare whether xs[i] is equal to xs[j].
//
// The generated IR corresponds to this C like algorithm:
//   if (x0[i] != x0[j])
//     return false;
//   else
//     if (x1[i] != x1[j])
//       return false;
//     else if (x2[i] != x2[j])
//       and so on ...
static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
                                    ValueRange args, AffineMap xPerm,
                                    uint64_t ny, uint32_t nTrailingP = 0) {
  // Compare functions don't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
                                            createEqCompare);
}

/// Generates code to compare whether x[i] is less than x[j] and returns the
/// result of the comparison.
///
/// The comparison is unsigned (`ult`). For the last dimension, a plain compare
/// is created (and yielded unless it is also the first dimension). For any
/// other dimension, an scf.if on `x[i] != x[j]` is created whose then-region
/// yields `x[i] < x[j]`; the builder is left at the start of the else-region
/// and the scf.if result is returned.
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
                                   Value j, Value x, bool isFirstDim,
                                   bool isLastDim) {
  Value vi = builder.create<memref::LoadOp>(loc, x, i);
  Value vj = builder.create<memref::LoadOp>(loc, x, j);

  Value res;
  if (isLastDim) {
    res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
    // For 1D, we create a compare without any control flow. Otherwise, we
    // create YieldOp to return the result in the nested if-stmt.
    if (!isFirstDim)
      builder.create<scf::YieldOp>(loc, res);
  } else {
    Value ne =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
                                               ne, /*else=*/true);
    // If (x[i] != x[j]).
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    Value lt =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
    builder.create<scf::YieldOp>(loc, lt);

    // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
    // checks the remaining dimensions.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
    res = ifOp.getResult(0);
  }

  return res;
}

/// Creates code to compare whether xs[i] is less than xs[j].
//
// The generated IR corresponds to this C like algorithm:
//   if (x0[i] != x0[j])
//     return x0[i] < x0[j];
//   else if (x1[i] != x1[j])
//     return x1[i] < x1[j];
//   else
//       and so on ...
static Value createInlinedLessThan(OpBuilder &builder, Location loc,
                                   ValueRange args, AffineMap xPerm,
                                   uint64_t ny, uint32_t nTrailingP = 0) {
  // Compare functions don't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
                                            createLessThanCompare);
}

/// Creates a function to use a binary search to find the insertion point for
/// inserting xs[hi] to the sorted values xs[lo..hi).
//
// The generated IR corresponds to this C like algorithm:
//   p = hi
//   while (lo < hi)
//      mid = (lo + hi) >> 1
//      if (xs[p] < xs[mid])
//        hi = mid
//      else
//        lo = mid + 1
//   return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
                                   func::FuncOp func, AffineMap xPerm,
                                   uint64_t ny, uint32_t nTrailingP = 0) {
  // Binary search doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  // The element being inserted is the one at the original `hi`.
  Value p = args[hiIdx];
  SmallVector<Type, 2> types(2, p.getType()); // Only two types.
  // The loop carries (lo, hi).
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});

  // The before-region of the WhileOp.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(before);
  // Continue while (lo < hi), using an unsigned comparison.
  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                              before->getArgument(0),
                                              before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(after);
  Value lo = after->getArgument(0);
  Value hi = after->getArgument(1);
  // Compute mid = (lo + hi) >> 1.
  Value c1 = constantIndex(builder, loc, 1);
  Value mid = builder.create<arith::ShRUIOp>(
      loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
  Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);

  // Compare xs[p] < xs[mid].
  SmallVector<Value> compareOperands{p, mid};
  constexpr uint64_t numXBuffers = 1;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
  // Update lo and hi for the WhileOp as follows:
  //   if (xs[p] < xs[mid]))
  //     hi = mid;
  //   else
  //     lo = mid + 1;
  Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
  Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
  builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});

  // Return the final `lo` as the insertion point.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
}

/// Creates code to advance i in a loop based on xs[p] as follows:
///   while (xs[i] < xs[p]) i += step (step > 0)
/// or
///   while (xs[i] > xs[p]) i += step (step < 0)
/// The routine returns i as well as a boolean value to indicate whether
/// xs[i] == xs[p].
static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
                                              ModuleOp module,
                                              func::FuncOp func, ValueRange xs,
                                              Value i, Value p, AffineMap xPerm,
                                              uint64_t ny, int step) {
  Location loc = func.getLoc();
  // The loop carries a single value: the current index i.
  scf::WhileOp whileOp =
      builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});

  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
  builder.setInsertionPointToEnd(before);
  // Both directions use the less-than compare: (i, p) tests xs[i] < xs[p],
  // while (p, i) tests xs[p] < xs[i], i.e. xs[i] > xs[p].
  SmallVector<Value> compareOperands;
  if (step > 0) {
    compareOperands.push_back(before->getArgument(0));
    compareOperands.push_back(p);
  } else {
    assert(step < 0);
    compareOperands.push_back(p);
    compareOperands.push_back(before->getArgument(0));
  }
  compareOperands.append(xs.begin(), xs.end());
  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region advances the index by `step`.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
  builder.setInsertionPointToEnd(after);
  Value cs = constantIndex(builder, loc, step);
  i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  i = whileOp.getResult(0);

  // After the loop, test whether the final xs[i] equals xs[p]. The operand
  // vector is reused with its first two entries overwritten.
  builder.setInsertionPointAfter(whileOp);
  compareOperands[0] = i;
  compareOperands[1] = p;
  Value compareEq =
      createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);

  return std::make_pair(whileOp.getResult(0), compareEq);
}

/// Creates and returns an IfOp to compare two elements and swap the elements
/// if compareFunc(data[b], data[a]) returns true. The new insertion point is
/// right after the swap instructions.
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
                                       AffineMap xPerm, uint64_t ny,
                                       SmallVectorImpl<Value> &swapOperands,
                                       SmallVectorImpl<Value> &compareOperands,
                                       Value a, Value b) {
  // Compare(data[b], data[a]).
  // The first two entries of both operand vectors are scratch slots that are
  // overwritten here; the remaining entries are the buffers.
  compareOperands[0] = b;
  compareOperands[1] = a;
  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
  // Emit the swap inside the then-region; the builder is left there so that
  // callers can append more code that only runs when the swap happened.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  swapOperands[0] = b;
  swapOperands[1] = a;
  createSwap(builder, loc, swapOperands, xPerm, ny);
  return ifOp;
}

/// Creates code to insert the 3rd element to a list of two sorted elements.
///
/// The generated code compares and swaps (v1, v2) and, only inside the
/// then-region of that swap, compares and swaps (v0, v1). On return, the
/// insertion point is right after the outer IfOp.
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
                            uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                            SmallVectorImpl<Value> &compareOperands, Value v0,
                            Value v1, Value v2) {
  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
                                         compareOperands, v1, v2);
  createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
                        v0, v1);
  builder.setInsertionPointAfter(ifOp);
}

/// Creates code to sort 3 elements.
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
                        uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                        SmallVectorImpl<Value> &compareOperands, Value v0,
                        Value v1, Value v2) {
  // Sort the first 2 elements.
  scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
                                          compareOperands, v0, v1);
  builder.setInsertionPointAfter(ifOp1);

  // Insert the 3rd element.
  createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
                  v1, v2);
}

/// Creates code to sort 5 elements.
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
                        uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                        SmallVectorImpl<Value> &compareOperands, Value v0,
                        Value v1, Value v2, Value v3, Value v4) {
  // Sort the first 3 elements.
  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
              v2);

  // Emits code that compares and swaps (v2, v3) and, only when that swap
  // happened, inserts the new v2 into (v0, v1). The insertion point ends up
  // right after the outer IfOp.
  auto insert4th = [&]() {
    scf::IfOp ifOp = createCompareThenSwap(
        builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
    createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
                    v1, v2);
    builder.setInsertionPointAfter(ifOp);
  };

  // Insert the 4th element.
  insert4th();

  // Insert the 5th element.
  // The second insert4th() is emitted inside the then-region of the (v3, v4)
  // swap, so it only runs when v4 was moved into position v3.
  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
                                         compareOperands, v3, v4);
  insert4th();
  builder.setInsertionPointAfter(ifOp);
}

/// Creates a code block to swap the values in indices lo, mi, and hi so that
/// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
/// the number of values in range [lo, hi) is more than a threshold, we also
/// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
509 static void createChoosePivot(OpBuilder &builder, ModuleOp module, 510 func::FuncOp func, AffineMap xPerm, uint64_t ny, 511 Value lo, Value hi, Value mi, ValueRange args) { 512 SmallVector<Value> compareOperands{mi, lo}; 513 constexpr uint64_t numXBuffers = 1; 514 compareOperands.append(args.begin() + xStartIdx, 515 args.begin() + xStartIdx + numXBuffers); 516 SmallVector<Value> swapOperands{mi, lo}; 517 swapOperands.append(args.begin() + xStartIdx, args.end()); 518 Location loc = func.getLoc(); 519 Value c1 = constantIndex(builder, loc, 1); 520 Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1); 521 Value len = builder.create<arith::SubIOp>(loc, hiP1, lo); 522 Value lenThreshold = constantIndex(builder, loc, 1000); 523 Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 524 len, lenThreshold); 525 scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true); 526 527 // When len < 1000, choose pivot from median of 3 values. 528 builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); 529 createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi, 530 hi); 531 532 // When len >= 1000, choose pivot from median of 5 values. 533 builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); 534 Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1); 535 Value a = builder.create<arith::AddIOp>(loc, lo, miP1); 536 // Value a is the middle between [loc, mi]. 537 a = builder.create<arith::ShRUIOp>(loc, a, c1); 538 Value b = builder.create<arith::AddIOp>(loc, mi, hiP1); 539 // Value b is the middle between [mi, hi]. 540 b = builder.create<arith::ShRUIOp>(loc, b, c1); 541 createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi, 542 b, hi); 543 544 builder.setInsertionPointAfter(lenIf); 545 } 546 547 /// Creates a function to perform quick sort partition on the values in the 548 /// range of index [lo, hi), assuming lo < hi. 
//
// The generated IR corresponds to this C like algorithm:
// int partition(lo, hi, xs) {
//   p = (lo+hi)/2  // pivot index
//   i = lo
//   j = hi-1
//   while (true) do {
//     while (xs[i] < xs[p]) i ++;
//     i_eq = (xs[i] == xs[p]);
//     while (xs[j] > xs[p]) j --;
//     j_eq = (xs[j] == xs[p]);
//
//     if (i >= j) return j + 1;
//
//     if (i < j) {
//       swap(xs[i], xs[j])
//       if (i == p) {
//         p = j;
//       } else if (j == p) {
//         p = i;
//       }
//       if (i_eq && j_eq) {
//         ++i;
//         --j;
//       }
//     }
//   }
// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP = 0) {
  // Quick sort partition doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  // p = (lo + hi) >> 1 is the initial pivot index.
  Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
  Value c1 = constantIndex(builder, loc, 1);
  Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);

  Value i = lo;
  Value j = builder.create<arith::SubIOp>(loc, hi, c1);
  // Note that j = hi - 1 is passed as the (inclusive) upper index.
  createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
  Value trueVal = constantI1(builder, loc, true); // The value for while (true)
  SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
  SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
                             trueVal.getType()};
  // The loop carries (i, j, p, continue-flag).
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

  // The before-region of the WhileOp.
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
                                      {loc, loc, loc, loc});
  builder.setInsertionPointToEnd(before);
  // The loop condition is simply the carried continue-flag.
  builder.create<scf::ConditionOp>(loc, before->getArgument(3),
                                   before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
  builder.setInsertionPointToEnd(after);
  i = after->getArgument(0);
  j = after->getArgument(1);
  p = after->getArgument(2);

  // Advance i upwards and j downwards, and record whether each stopped on an
  // element equal to the pivot.
  constexpr uint64_t numXBuffers = 1;
  auto [iresult, iCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     i, p, xPerm, ny, 1);
  i = iresult;
  auto [jresult, jCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     j, p, xPerm, ny, -1);
  j = jresult;

  // If i < j:
  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  SmallVector<Value> swapOperands{i, j};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, xPerm, ny);
  // If the pivot is moved, update p with the new pivot.
  Value icond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
  scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              icond, /*else=*/true);
  // i == p: the pivot now lives at j.
  builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{j});
  builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
  Value jcond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
  scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              jcond, /*else=*/true);
  // j == p: the pivot now lives at i; otherwise p is unchanged.
  builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{p});
  builder.setInsertionPointAfter(ifOpJ);
  builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
  builder.setInsertionPointAfter(ifOpI);
  // When both scans stopped on elements equal to the pivot, step i and j past
  // them (++i, --j as in the pseudo code above).
  Value compareEqIJ =
      builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(
      loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
  Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
  Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
  builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j});
  builder.setInsertionPointAfter(ifOp2);
  builder.create<scf::YieldOp>(
      loc,
      ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
                 /*cont=*/constantI1(builder, loc, true)});

  // False branch for if i < j (i.e., i >= j):
  // The third carried value is reused to hold the return value j + 1, and the
  // continue-flag is cleared to exit the loop.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  p = builder.create<arith::AddIOp>(loc, j,
                                    constantOne(builder, loc, j.getType()));
  builder.create<scf::YieldOp>(
      loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});

  // Return for the whileOp.
  builder.setInsertionPointAfter(ifOp);
  builder.create<scf::YieldOp>(loc, ifOp.getResults());

  // Return for the function.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
}

/// Computes (n-2)/2, assuming n has index type.
/// The division is emitted as an unsigned right shift by one.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
                                      Value n) {
  Value i2 = constantIndex(builder, loc, 2);
  Value res = builder.create<arith::SubIOp>(loc, n, i2);
  Value i1 = constantIndex(builder, loc, 1);
  return builder.create<arith::ShRUIOp>(loc, res, i1);
}

/// Creates a function to heapify the subtree with root `start` within the full
/// binary tree in the range of index [first, first + n).
//
// The generated IR corresponds to this C like algorithm:
// void shiftDown(first, start, n, data) {
//   if (n >= 2) {
//     child = start - first
//     if ((n-2)/2 >= child) {
//       // Left child exists.
//       child = child * 2 + 1 // Initialize the bigger child to left child.
//       childIndex = child + first
//       if (child+1 < n && data[childIndex] < data[childIndex+1])
//         // Right child exists and is bigger.
//         childIndex++; child++;
//       // Shift data[start] down to where it belongs in the subtree.
//       while (data[start] < data[childIndex]) {
//         swap(data[start], data[childIndex])
//         start = childIndex
//         if ((n - 2)/2 >= child) {
//           // Left child exists.
//           child = 2*child + 1
//           childIndex = child + first
//           if (child + 1 < n && data[childIndex] < data[childIndex+1])
//             childIndex++; child++;
//         }
//       }
//     }
//   }
// }
//
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP) {
  // The value n is passed in as a trailing parameter.
  assert(nTrailingP == 1);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  // Split off the trailing parameter n; the rest is (first, start, buffers).
  Value n = entryBlock->getArguments().back();
  ValueRange args = entryBlock->getArguments().drop_back();
  Value first = args[loIdx];
  Value start = args[hiIdx];

  // If (n >= 2).
  Value c2 = constantIndex(builder, loc, 2);
  Value condN =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
  scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
  builder.setInsertionPointToStart(&ifN.getThenRegion().front());
  Value child = builder.create<arith::SubIOp>(loc, start, first);

  // If ((n-2)/2 >= child).
  Value t = createSubTwoDividedByTwo(builder, loc, n);
  Value condNc =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
  scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);

  builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
  Value c1 = constantIndex(builder, loc, 1);
  // The first two entries are placeholders that are overwritten before each
  // comparison; the rest is the xs buffer.
  SmallVector<Value> compareOperands{start, start};
  constexpr uint64_t numXBuffers = 1;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);

  // Generate code to inspect the children of 'r' and return the larger child
  // as follows:
  //   child = r * 2 + 1 // Left child.
  //   childIndex = child + first
  //   if (child+1 < n && data[childIndex] < data[childIndex+1])
  //     childIndex ++; child ++ // Right child is bigger.
  // The returned pair is (child, childIndex), i.e. the position relative to
  // `first` and the absolute index into the buffers.
  auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
    Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
    lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
    Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
    Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
    Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                rChild, n);
    SmallVector<Type, 2> ifTypes(2, r.getType());
    scf::IfOp if1 =
        builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
    builder.setInsertionPointToStart(&if1.getThenRegion().front());
    Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
    // Compare data[left] < data[right].
    compareOperands[0] = lChildIdx;
    compareOperands[1] = rChildIdx;
    Value cond2 =
        createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
    scf::IfOp if2 =
        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
    builder.setInsertionPointToStart(&if2.getThenRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
    builder.setInsertionPointToStart(&if2.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
    builder.setInsertionPointAfter(if2);
    builder.create<scf::YieldOp>(loc, if2.getResults());
    // No right child: the left child is the larger one.
    builder.setInsertionPointToStart(&if1.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
    builder.setInsertionPointAfter(if1);
    return std::make_pair(if1.getResult(0), if1.getResult(1));
  };

  Value childIdx;
  std::tie(child, childIdx) = getLargerChild(child);

  // While (data[start] < data[childIndex]).
  // The loop carries (start, child, childIndex).
  SmallVector<Type, 3> types(3, child.getType());
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{start, child, childIdx});

  // The before-region of the WhileOp.
  SmallVector<Location, 3> locs(3, loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  builder.setInsertionPointToEnd(before);
  start = before->getArgument(0);
  childIdx = before->getArgument(2);
  compareOperands[0] = start;
  compareOperands[1] = childIdx;
  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region of the WhileOp.
  // NOTE(review): no explicit setInsertionPointToEnd follows this createBlock,
  // unlike the other regions in this file; this relies on createBlock leaving
  // the builder positioned in the newly created block.
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
  start = after->getArgument(0);
  child = after->getArgument(1);
  childIdx = after->getArgument(2);
  SmallVector<Value> swapOperands{start, childIdx};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, xPerm, ny);
  // Continue from the child that was just swapped with its parent.
  start = childIdx;
  // If ((n-2)/2 >= child), a left child exists and the larger child is
  // recomputed; otherwise (child, childIndex) are carried over unchanged.
  Value cond2 =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
  scf::IfOp if2 = builder.create<scf::IfOp>(
      loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
  builder.setInsertionPointToStart(&if2.getThenRegion().front());
  auto [newChild, newChildIdx] = getLargerChild(child);
  builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
  builder.setInsertionPointToStart(&if2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
  builder.setInsertionPointAfter(if2);
  builder.create<scf::YieldOp>(
      loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});

  // The function returns nothing; emit the return after the outermost if.
  builder.setInsertionPointAfter(ifN);
  builder.create<func::ReturnOp>(loc);
}

/// Creates a function to perform heap sort on the values in the range of index
/// [lo, hi)
with the assumption hi - lo >= 2. 848 // 849 // The generate IR corresponds to this C like algorithm: 850 // void heapSort(lo, hi, data) { 851 // n = hi - lo 852 // for i = (n-2)/2 downto 0 853 // shiftDown(lo, lo+i, n) 854 // 855 // for l = n downto 2 856 // swap(lo, lo+l-1) 857 // shiftdown(lo, lo, l-1) 858 // } 859 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, 860 func::FuncOp func, AffineMap xPerm, uint64_t ny, 861 uint32_t nTrailingP) { 862 // Heap sort function doesn't have trailing parameters. 863 (void)nTrailingP; 864 assert(nTrailingP == 0); 865 OpBuilder::InsertionGuard insertionGuard(builder); 866 Block *entryBlock = func.addEntryBlock(); 867 builder.setInsertionPointToStart(entryBlock); 868 869 Location loc = func.getLoc(); 870 ValueRange args = entryBlock->getArguments(); 871 Value lo = args[loIdx]; 872 Value hi = args[hiIdx]; 873 Value n = builder.create<arith::SubIOp>(loc, hi, lo); 874 875 // For i = (n-2)/2 downto 0. 876 Value c0 = constantIndex(builder, loc, 0); 877 Value c1 = constantIndex(builder, loc, 1); 878 Value s = createSubTwoDividedByTwo(builder, loc, n); 879 Value up = builder.create<arith::AddIOp>(loc, s, c1); 880 scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1); 881 builder.setInsertionPointToStart(forI.getBody()); 882 Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar()); 883 Value lopi = builder.create<arith::AddIOp>(loc, lo, i); 884 SmallVector<Value> shiftDownOperands = {lo, lopi}; 885 shiftDownOperands.append(args.begin() + xStartIdx, args.end()); 886 shiftDownOperands.push_back(n); 887 FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc( 888 builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny, 889 shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1); 890 builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(), 891 shiftDownOperands); 892 893 builder.setInsertionPointAfter(forI); 894 // For l = n downto 2. 
895 up = builder.create<arith::SubIOp>(loc, n, c1); 896 scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1); 897 builder.setInsertionPointToStart(forL.getBody()); 898 Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar()); 899 Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l); 900 loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1); 901 SmallVector<Value> swapOperands{lo, loplm1}; 902 swapOperands.append(args.begin() + xStartIdx, args.end()); 903 createSwap(builder, loc, swapOperands, xPerm, ny); 904 shiftDownOperands[1] = lo; 905 shiftDownOperands[shiftDownOperands.size() - 1] = 906 builder.create<arith::SubIOp>(loc, l, c1); 907 builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(), 908 shiftDownOperands); 909 910 builder.setInsertionPointAfter(forL); 911 builder.create<func::ReturnOp>(loc); 912 } 913 914 /// A helper for generating code to perform quick sort. It partitions [lo, hi), 915 /// recursively calls quick sort to process the smaller partition and returns 916 /// the bigger partition to be processed by the enclosed while-loop. 917 static std::pair<Value, Value> 918 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, 919 ValueRange args, AffineMap xPerm, uint64_t ny, 920 uint32_t nTrailingP) { 921 MLIRContext *context = module.getContext(); 922 Location loc = func.getLoc(); 923 Value lo = args[loIdx]; 924 Value hi = args[hiIdx]; 925 SmallVector<Type, 2> types(2, lo.getType()); // Only two types. 
926 927 FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( 928 builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm, 929 ny, args.drop_back(nTrailingP), createPartitionFunc); 930 Value p = builder 931 .create<func::CallOp>(loc, partitionFunc, 932 TypeRange{IndexType::get(context)}, 933 args.drop_back(nTrailingP)) 934 .getResult(0); 935 936 Value lenLow = builder.create<arith::SubIOp>(loc, p, lo); 937 Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p); 938 // Partition already sorts array with len <= 2 939 Value c2 = constantIndex(builder, loc, 2); 940 Value len = builder.create<arith::SubIOp>(loc, hi, lo); 941 Value lenGtTwo = 942 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2); 943 scf::IfOp ifLenGtTwo = 944 builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true); 945 builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front()); 946 // Returns an empty range to mark the entire region is fully sorted. 947 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); 948 949 // Else len > 2, need recursion. 
950 builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front()); 951 Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, 952 lenLow, lenHigh); 953 954 Value c0 = constantIndex(builder, loc, 0); 955 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); 956 957 auto mayRecursion = [&](Value low, Value high, Value len) { 958 Value cond = 959 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0); 960 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); 961 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 962 SmallVector<Value> operands{low, high}; 963 operands.append(args.begin() + xStartIdx, args.end()); 964 builder.create<func::CallOp>(loc, func, operands); 965 builder.setInsertionPointAfter(ifOp); 966 }; 967 968 // Recursively call quickSort to process the smaller partition and return 969 // the bigger partition to be processed by the enclosed while-loop. 970 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 971 mayRecursion(lo, p, lenLow); 972 builder.create<scf::YieldOp>(loc, ValueRange{p, hi}); 973 974 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 975 mayRecursion(p, hi, lenHigh); 976 builder.create<scf::YieldOp>(loc, ValueRange{lo, p}); 977 978 builder.setInsertionPointAfter(ifOp); 979 builder.create<scf::YieldOp>(loc, ifOp.getResults()); 980 981 builder.setInsertionPointAfter(ifLenGtTwo); 982 return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1)); 983 } 984 985 /// Creates a function to perform insertion sort on the values in the range of 986 /// index [lo, hi). 
987 // 988 // The generate IR corresponds to this C like algorithm: 989 // void insertionSort(lo, hi, data) { 990 // for (i = lo+1; i < hi; i++) { 991 // d = data[i]; 992 // p = binarySearch(lo, i-1, data) 993 // for (j = 0; j > i - p; j++) 994 // data[i-j] = data[i-j-1] 995 // data[p] = d 996 // } 997 // } 998 static void createSortStableFunc(OpBuilder &builder, ModuleOp module, 999 func::FuncOp func, AffineMap xPerm, 1000 uint64_t ny, uint32_t nTrailingP) { 1001 // Stable sort function doesn't use trailing parameters. 1002 (void)nTrailingP; 1003 assert(nTrailingP == 0); 1004 OpBuilder::InsertionGuard insertionGuard(builder); 1005 Block *entryBlock = func.addEntryBlock(); 1006 builder.setInsertionPointToStart(entryBlock); 1007 1008 MLIRContext *context = module.getContext(); 1009 Location loc = func.getLoc(); 1010 ValueRange args = entryBlock->getArguments(); 1011 Value c1 = constantIndex(builder, loc, 1); 1012 Value lo = args[loIdx]; 1013 Value hi = args[hiIdx]; 1014 Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1); 1015 1016 // Start the outer for-stmt with induction variable i. 1017 scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1); 1018 builder.setInsertionPointToStart(forOpI.getBody()); 1019 Value i = forOpI.getInductionVar(); 1020 1021 // Binary search to find the insertion point p. 1022 SmallVector<Value> operands{lo, i}; 1023 operands.append(args.begin() + xStartIdx, args.end()); 1024 FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( 1025 builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, 1026 xPerm, ny, operands, createBinarySearchFunc); 1027 Value p = builder 1028 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()}, 1029 operands) 1030 .getResult(0); 1031 1032 // Move the value at data[i] to a temporary location. 
1033 operands[0] = operands[1] = i; 1034 SmallVector<Value> d; 1035 forEachIJPairInAllBuffers( 1036 builder, loc, operands, xPerm, ny, 1037 [&](uint64_t unused, Value i, Value unused2, Value buffer) { 1038 d.push_back(builder.create<memref::LoadOp>(loc, buffer, i)); 1039 }); 1040 1041 // Start the inner for-stmt with induction variable j, for moving data[p..i) 1042 // to data[p+1..i+1). 1043 Value imp = builder.create<arith::SubIOp>(loc, i, p); 1044 Value c0 = constantIndex(builder, loc, 0); 1045 scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1); 1046 builder.setInsertionPointToStart(forOpJ.getBody()); 1047 Value j = forOpJ.getInductionVar(); 1048 Value imj = builder.create<arith::SubIOp>(loc, i, j); 1049 operands[1] = imj; 1050 operands[0] = builder.create<arith::SubIOp>(loc, imj, c1); 1051 forEachIJPairInAllBuffers( 1052 builder, loc, operands, xPerm, ny, 1053 [&](uint64_t unused, Value imjm1, Value imj, Value buffer) { 1054 Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1); 1055 builder.create<memref::StoreOp>(loc, t, buffer, imj); 1056 }); 1057 1058 // Store the value at data[i] to data[p]. 1059 builder.setInsertionPointAfter(forOpJ); 1060 operands[0] = operands[1] = p; 1061 forEachIJPairInAllBuffers( 1062 builder, loc, operands, xPerm, ny, 1063 [&](uint64_t k, Value p, Value usused, Value buffer) { 1064 builder.create<memref::StoreOp>(loc, d[k], buffer, p); 1065 }); 1066 1067 builder.setInsertionPointAfter(forOpI); 1068 builder.create<func::ReturnOp>(loc); 1069 } 1070 1071 /// Creates a function to perform quick sort or a hybrid quick sort on the 1072 /// values in the range of index [lo, hi). 
1073 // 1074 // 1075 // When nTrailingP == 0, the generated IR corresponds to this C like algorithm: 1076 // void quickSort(lo, hi, data) { 1077 // while (lo + 1 < hi) { 1078 // p = partition(low, high, data); 1079 // if (len(lo, p) < len(p+1, hi)) { 1080 // quickSort(lo, p, data); 1081 // lo = p+1; 1082 // } else { 1083 // quickSort(p + 1, hi, data); 1084 // hi = p; 1085 // } 1086 // } 1087 // } 1088 // 1089 // When nTrailingP == 1, the generated IR corresponds to this C like algorithm: 1090 // void hybridQuickSort(lo, hi, data, depthLimit) { 1091 // while (lo + 1 < hi) { 1092 // len = hi - lo; 1093 // if (len <= limit) { 1094 // insertionSort(lo, hi, data); 1095 // } else { 1096 // depthLimit --; 1097 // if (depthLimit <= 0) { 1098 // heapSort(lo, hi, data); 1099 // } else { 1100 // p = partition(low, high, data); 1101 // if (len(lo, p) < len(p+1, hi)) { 1102 // quickSort(lo, p, data, depthLimit); 1103 // lo = p+1; 1104 // } else { 1105 // quickSort(p + 1, hi, data, depthLimit); 1106 // hi = p; 1107 // } 1108 // } 1109 // } 1110 // } 1111 // } 1112 // 1113 static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, 1114 func::FuncOp func, AffineMap xPerm, uint64_t ny, 1115 uint32_t nTrailingP) { 1116 assert(nTrailingP == 1 || nTrailingP == 0); 1117 bool isHybrid = (nTrailingP == 1); 1118 OpBuilder::InsertionGuard insertionGuard(builder); 1119 Block *entryBlock = func.addEntryBlock(); 1120 builder.setInsertionPointToStart(entryBlock); 1121 1122 Location loc = func.getLoc(); 1123 SmallVector<Value> args; 1124 args.append(entryBlock->getArguments().begin(), 1125 entryBlock->getArguments().end()); 1126 Value lo = args[loIdx]; 1127 Value hi = args[hiIdx]; 1128 SmallVector<Type, 2> types(2, lo.getType()); // Only two types. 1129 scf::WhileOp whileOp = 1130 builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi}); 1131 1132 // The before-region of the WhileOp. 
1133 Block *before = 1134 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); 1135 builder.setInsertionPointToEnd(before); 1136 lo = before->getArgument(0); 1137 hi = before->getArgument(1); 1138 Value loP1 = 1139 builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1)); 1140 Value needSort = 1141 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi); 1142 builder.create<scf::ConditionOp>(loc, needSort, before->getArguments()); 1143 1144 // The after-region of the WhileOp. 1145 Block *after = 1146 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); 1147 builder.setInsertionPointToEnd(after); 1148 lo = after->getArgument(0); 1149 hi = after->getArgument(1); 1150 args[0] = lo; 1151 args[1] = hi; 1152 1153 if (isHybrid) { 1154 Value len = builder.create<arith::SubIOp>(loc, hi, lo); 1155 Value lenLimit = constantIndex(builder, loc, 30); 1156 Value lenCond = builder.create<arith::CmpIOp>( 1157 loc, arith::CmpIPredicate::ule, len, lenLimit); 1158 scf::IfOp lenIf = 1159 builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true); 1160 1161 // When len <= limit. 1162 builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); 1163 FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc( 1164 builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny, 1165 ValueRange(args).drop_back(nTrailingP), createSortStableFunc); 1166 builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(), 1167 ValueRange(args).drop_back(nTrailingP)); 1168 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); 1169 1170 // When len > limit. 
1171 builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); 1172 Value depthLimit = args.back(); 1173 depthLimit = builder.create<arith::SubIOp>(loc, depthLimit, 1174 constantI64(builder, loc, 1)); 1175 Value depthCond = 1176 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, 1177 depthLimit, constantI64(builder, loc, 0)); 1178 scf::IfOp depthIf = 1179 builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true); 1180 1181 // When depth exceeds limit. 1182 builder.setInsertionPointToStart(&depthIf.getThenRegion().front()); 1183 FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc( 1184 builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny, 1185 ValueRange(args).drop_back(nTrailingP), createHeapSortFunc); 1186 builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(), 1187 ValueRange(args).drop_back(nTrailingP)); 1188 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); 1189 1190 // When depth doesn't exceed limit. 1191 builder.setInsertionPointToStart(&depthIf.getElseRegion().front()); 1192 args.back() = depthLimit; 1193 std::tie(lo, hi) = 1194 createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); 1195 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); 1196 1197 builder.setInsertionPointAfter(depthIf); 1198 lo = depthIf.getResult(0); 1199 hi = depthIf.getResult(1); 1200 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); 1201 1202 builder.setInsertionPointAfter(lenIf); 1203 lo = lenIf.getResult(0); 1204 hi = lenIf.getResult(1); 1205 } else { 1206 std::tie(lo, hi) = 1207 createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); 1208 } 1209 1210 // New [lo, hi) for the next while-loop iteration. 1211 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); 1212 1213 // After the while-loop. 1214 builder.setInsertionPointAfter(whileOp); 1215 builder.create<func::ReturnOp>(loc); 1216 } 1217 1218 /// Implements the rewriting for operator sort and sort_coo. 
1219 template <typename OpTy> 1220 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, 1221 uint64_t ny, PatternRewriter &rewriter) { 1222 Location loc = op.getLoc(); 1223 SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()}; 1224 1225 // Convert `values` to have dynamic shape and append them to `operands`. 1226 for (Value v : xys) { 1227 auto mtp = getMemRefType(v); 1228 if (!mtp.isDynamicDim(0)) { 1229 auto newMtp = 1230 MemRefType::get({ShapedType::kDynamic}, mtp.getElementType()); 1231 v = rewriter.create<memref::CastOp>(loc, newMtp, v); 1232 } 1233 operands.push_back(v); 1234 } 1235 1236 auto insertPoint = op->template getParentOfType<func::FuncOp>(); 1237 if (!insertPoint) 1238 return failure(); 1239 1240 SmallString<32> funcName; 1241 FuncGeneratorType funcGenerator; 1242 uint32_t nTrailingP = 0; 1243 switch (op.getAlgorithm()) { 1244 case SparseTensorSortKind::HybridQuickSort: { 1245 funcName = kHybridQuickSortFuncNamePrefix; 1246 funcGenerator = createQuickSortFunc; 1247 nTrailingP = 1; 1248 // As a heuristics, set depthLimit = 2 * log2(n). 
1249 Value lo = operands[loIdx]; 1250 Value hi = operands[hiIdx]; 1251 Value len = rewriter.create<arith::IndexCastOp>( 1252 loc, rewriter.getI64Type(), 1253 rewriter.create<arith::SubIOp>(loc, hi, lo)); 1254 Value depthLimit = rewriter.create<arith::SubIOp>( 1255 loc, constantI64(rewriter, loc, 64), 1256 rewriter.create<math::CountLeadingZerosOp>(loc, len)); 1257 operands.push_back(depthLimit); 1258 break; 1259 } 1260 case SparseTensorSortKind::QuickSort: 1261 funcName = kQuickSortFuncNamePrefix; 1262 funcGenerator = createQuickSortFunc; 1263 break; 1264 case SparseTensorSortKind::InsertionSortStable: 1265 funcName = kSortStableFuncNamePrefix; 1266 funcGenerator = createSortStableFunc; 1267 break; 1268 case SparseTensorSortKind::HeapSort: 1269 funcName = kHeapSortFuncNamePrefix; 1270 funcGenerator = createHeapSortFunc; 1271 break; 1272 } 1273 1274 FlatSymbolRefAttr func = 1275 getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, 1276 xPerm, ny, operands, funcGenerator, nTrailingP); 1277 rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands); 1278 return success(); 1279 } 1280 1281 //===---------------------------------------------------------------------===// 1282 // The actual sparse buffer rewriting rules. 1283 //===---------------------------------------------------------------------===// 1284 1285 namespace { 1286 /// Sparse rewriting rule for the push_back operator. 
1287 struct PushBackRewriter : OpRewritePattern<PushBackOp> { 1288 public: 1289 using OpRewritePattern<PushBackOp>::OpRewritePattern; 1290 PushBackRewriter(MLIRContext *context, bool enableInit) 1291 : OpRewritePattern(context), enableBufferInitialization(enableInit) {} 1292 LogicalResult matchAndRewrite(PushBackOp op, 1293 PatternRewriter &rewriter) const override { 1294 // Rewrite push_back(buffer, value, n) to: 1295 // new_size = size(buffer) + n 1296 // if (new_size > capacity(buffer)) 1297 // while new_size > new_capacity 1298 // new_capacity = new_capacity*2 1299 // new_buffer = realloc(buffer, new_capacity) 1300 // buffer = new_buffer 1301 // subBuffer = subviewof(buffer) 1302 // linalg.fill subBuffer value 1303 // 1304 // size(buffer) += n 1305 // 1306 // The capacity check is skipped when the attribute inbounds is presented. 1307 Location loc = op->getLoc(); 1308 Value c0 = constantIndex(rewriter, loc, 0); 1309 Value buffer = op.getInBuffer(); 1310 Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0); 1311 Value size = op.getCurSize(); 1312 Value value = op.getValue(); 1313 1314 Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1); 1315 Value newSize = rewriter.create<arith::AddIOp>(loc, size, n); 1316 auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp()); 1317 bool nIsOne = (nValue && nValue.value() == 1); 1318 1319 if (!op.getInbounds()) { 1320 Value cond = rewriter.create<arith::CmpIOp>( 1321 loc, arith::CmpIPredicate::ugt, newSize, capacity); 1322 1323 Value c2 = constantIndex(rewriter, loc, 2); 1324 auto bufferType = 1325 MemRefType::get({ShapedType::kDynamic}, value.getType()); 1326 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond, 1327 /*else=*/true); 1328 // True branch. 
1329 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); 1330 if (nIsOne) { 1331 capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2); 1332 } else { 1333 // Use a do-while loop to calculate the new capacity as follows: 1334 // do { new_capacity *= 2 } while (size > new_capacity) 1335 scf::WhileOp whileOp = 1336 rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity); 1337 1338 // The before-region of the WhileOp. 1339 Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, 1340 {capacity.getType()}, {loc}); 1341 rewriter.setInsertionPointToEnd(before); 1342 1343 capacity = 1344 rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2); 1345 cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, 1346 newSize, capacity); 1347 rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity}); 1348 // The after-region of the WhileOp. 1349 Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, 1350 {capacity.getType()}, {loc}); 1351 rewriter.setInsertionPointToEnd(after); 1352 rewriter.create<scf::YieldOp>(loc, after->getArguments()); 1353 1354 rewriter.setInsertionPointAfter(whileOp); 1355 capacity = whileOp.getResult(0); 1356 } 1357 1358 Value newBuffer = 1359 rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity); 1360 if (enableBufferInitialization) { 1361 Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize); 1362 Value fillValue = constantZero(rewriter, loc, value.getType()); 1363 Value subBuffer = rewriter.create<memref::SubViewOp>( 1364 loc, newBuffer, /*offset=*/ValueRange{newSize}, 1365 /*size=*/ValueRange{fillSize}, 1366 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); 1367 rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer); 1368 } 1369 rewriter.create<scf::YieldOp>(loc, newBuffer); 1370 1371 // False branch. 
1372 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); 1373 rewriter.create<scf::YieldOp>(loc, buffer); 1374 1375 // Prepare for adding the value to the end of the buffer. 1376 rewriter.setInsertionPointAfter(ifOp); 1377 buffer = ifOp.getResult(0); 1378 } 1379 1380 // Add the value to the end of the buffer. 1381 if (nIsOne) { 1382 rewriter.create<memref::StoreOp>(loc, value, buffer, size); 1383 } else { 1384 Value subBuffer = rewriter.create<memref::SubViewOp>( 1385 loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n}, 1386 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); 1387 rewriter.create<linalg::FillOp>(loc, value, subBuffer); 1388 } 1389 1390 // Update the buffer size. 1391 rewriter.replaceOp(op, {buffer, newSize}); 1392 return success(); 1393 } 1394 1395 private: 1396 bool enableBufferInitialization; 1397 }; 1398 1399 /// Sparse rewriting rule for the sort_coo operator. 1400 struct SortRewriter : public OpRewritePattern<SortOp> { 1401 public: 1402 using OpRewritePattern<SortOp>::OpRewritePattern; 1403 1404 LogicalResult matchAndRewrite(SortOp op, 1405 PatternRewriter &rewriter) const override { 1406 SmallVector<Value> xys; 1407 xys.push_back(op.getXy()); 1408 xys.append(op.getYs().begin(), op.getYs().end()); 1409 1410 auto xPerm = op.getPermMap(); 1411 uint64_t ny = 0; 1412 if (auto nyAttr = op.getNyAttr()) 1413 ny = nyAttr.getInt(); 1414 1415 return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter); 1416 } 1417 }; 1418 1419 } // namespace 1420 1421 //===---------------------------------------------------------------------===// 1422 // Methods that add patterns described in this file to a pattern list. 
1423 //===---------------------------------------------------------------------===// 1424 1425 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns, 1426 bool enableBufferInitialization) { 1427 patterns.add<PushBackRewriter>(patterns.getContext(), 1428 enableBufferInitialization); 1429 patterns.add<SortRewriter>(patterns.getContext()); 1430 } 1431