//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensor
// primitives with memref operands.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;

static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
    "_sparse_binary_search_";
static constexpr const char kHybridQuickSortFuncNamePrefix[] =
    "_sparse_hybrid_qsort_";
static constexpr const char kSortStableFuncNamePrefix[] =
    "_sparse_sort_stable_";
static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";

using FuncGeneratorType = function_ref<void(
    OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;

/// Constructs a function name with this format to facilitate quick sort:
///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
                                         StringRef namePrefix, uint64_t nx,
                                         uint64_t ny, bool isCoo,
                                         ValueRange operands) {
  nameOstream << namePrefix << nx << "_"
              << getMemRefType(operands[xStartIdx]).getElementType();

  if (isCoo)
    nameOstream << "_coo_" << ny;

  uint64_t yBufferOffset = isCoo ? 1 : nx;
  for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
    nameOstream << "_" << getMemRefType(v).getElementType();
}

/// Looks up a function that is appropriate for the given operands being
/// sorted, and creates such a function if it doesn't exist yet. The
/// parameters `nx` and `ny` tell the number of x and y values provided
/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
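/// Helper functions are created in the module only once: the mangled symbol
/// is looked up first and the function body is generated (and marked private)
/// only when the lookup fails, so repeated sort operations on buffers of the
/// same element types reuse the same helpers.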
//
// All sorting function generators take (lo, hi, xs, ys) in `operands` as
// parameters for the sorting functions. Other parameters, such as the recursive
// call depth, are appended to the end of the parameter list as
// "trailing parameters".
static FlatSymbolRefAttr
getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
                         TypeRange resultTypes, StringRef namePrefix,
                         uint64_t nx, uint64_t ny, bool isCoo,
                         ValueRange operands, FuncGeneratorType createFunc,
                         uint32_t nTrailingP = 0) {
  SmallString<32> nameBuffer;
  llvm::raw_svector_ostream nameOstream(nameBuffer);
  getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
                               operands.drop_back(nTrailingP));

  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
  MLIRContext *context = module.getContext();
  auto result = SymbolRefAttr::get(context, nameOstream.str());
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());

  if (!func) {
    // Create the function.
    OpBuilder::InsertionGuard insertionGuard(builder);
    builder.setInsertionPoint(insertPoint);
    Location loc = insertPoint.getLoc();
    func = builder.create<func::FuncOp>(
        loc, nameOstream.str(),
        FunctionType::get(context, operands.getTypes(), resultTypes));
    func.setPrivate();
    createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
  }

  return result;
}

/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInXs(
    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
  Value iOffset, jOffset;
  if (isCoo) {
    Value cstep = constantIndex(builder, loc, nx + ny);
    iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
    jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
  }
  for (uint64_t k = 0; k < nx; k++) {
    scf::IfOp ifOp;
    Value i, j, buffer;
    if (isCoo) {
      Value ck = constantIndex(builder, loc, k);
      i = builder.create<arith::AddIOp>(loc, ck, iOffset);
      j = builder.create<arith::AddIOp>(loc, ck, jOffset);
      buffer = args[xStartIdx];
    } else {
      i = args[0];
      j = args[1];
      buffer = args[xStartIdx + k];
    }
    bodyBuilder(k, i, j, buffer);
  }
}

/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInAllBuffers(
    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {

  // Create code for the first (nx + ny) buffers. When isCoo==true, these
  // logical buffers are all from the xy buffer of the sort_coo operator.
  forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);

  uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;

  // Create code for the remaining buffers.
  Value i = args[0];
  Value j = args[1];
  for (const auto &arg :
       llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
    bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
  }
}

/// Creates a code block for swapping the values in index i and j for all the
/// buffers.
//
// The generated IR corresponds to this C-like algorithm:
//   swap(x0[i], x0[j]);
//   swap(x1[i], x1[j]);
//   ...
//   swap(xn[i], xn[j]);
//   swap(y0[i], y0[j]);
//   ...
//   swap(yn[i], yn[j]);
static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
                       uint64_t nx, uint64_t ny, bool isCoo) {
  auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
    Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
    Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
    builder.create<memref::StoreOp>(loc, vj, buffer, i);
    builder.create<memref::StoreOp>(loc, vi, buffer, j);
  };

  forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
}

/// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
/// compare each pair is created via `compareBuilder`.
static void createCompareFuncImplementation(
    OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx,
    uint64_t ny, bool isCoo,
    function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
        compareBuilder) {
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);
  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();

  scf::IfOp topIfOp;
  auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
    scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1));
    if (k == 0) {
      topIfOp = ifOp;
    } else {
      OpBuilder::InsertionGuard insertionGuard(builder);
      builder.setInsertionPointAfter(ifOp);
      builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
    }
  };

  forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);

  builder.setInsertionPointAfter(topIfOp);
  builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
}

/// Generates an if-statement to compare whether x[i] is equal to x[j].
static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
                                 Value j, Value x, bool isLastDim) {
  Value f = constantI1(builder, loc, false);
  Value t = constantI1(builder, loc, true);
  Value vi = builder.create<memref::LoadOp>(loc, x, i);
  Value vj = builder.create<memref::LoadOp>(loc, x, j);

  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);

  // x[i] != x[j]:
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, f);

  // x[i] == x[j]:
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  if (isLastDim == 1) {
    // Finish checking all dimensions.
    builder.create<scf::YieldOp>(loc, t);
  }

  return ifOp;
}

/// Creates a function to compare whether xs[i] is equal to xs[j].
//
// The generated IR corresponds to this C-like algorithm:
//   if (x0[i] != x0[j])
//     return false;
//   else
//     if (x1[i] != x1[j])
//       return false;
//     else if (x2[i] != x2[j])
//       and so on ...
static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo, uint32_t nTrailingP = 0) {
  // Compare functions don't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
                                  createEqCompare);
}

/// Generates an if-statement to compare whether x[i] is less than x[j].
static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
                                       Value i, Value j, Value x,
                                       bool isLastDim) {
  Value f = constantI1(builder, loc, false);
  Value t = constantI1(builder, loc, true);
  Value vi = builder.create<memref::LoadOp>(loc, x, i);
  Value vj = builder.create<memref::LoadOp>(loc, x, j);

  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
  // If (x[i] < x[j]).
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, t);

  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  if (isLastDim == 1) {
    // Finish checking all dimensions.
    builder.create<scf::YieldOp>(loc, f);
  } else {
    cond =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
    scf::IfOp ifOp2 =
        builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
    // Otherwise if (x[j] < x[i]).
    builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
    builder.create<scf::YieldOp>(loc, f);

    // Otherwise check the remaining dimensions.
    builder.setInsertionPointAfter(ifOp2);
    builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
    // Set up the insertion point for the nested if-stmt that checks the
    // remaining dimensions.
    builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  }

  return ifOp;
}

/// Creates a function to compare whether xs[i] is less than xs[j].
//
// The generated IR corresponds to this C-like algorithm:
//   if (x0[i] < x0[j])
//     return true;
//   else if (x0[j] < x0[i])
//     return false;
//   else
//     if (x1[i] < x1[j])
//       return true;
//     else if (x1[j] < x1[i])
//       and so on ...
static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
                               func::FuncOp func, uint64_t nx, uint64_t ny,
                               bool isCoo, uint32_t nTrailingP = 0) {
  // Compare functions don't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
                                  createLessThanCompare);
}

/// Creates a function to use a binary search to find the insertion point for
/// inserting xs[hi] into the sorted values xs[lo..hi).
//
// The generated IR corresponds to this C-like algorithm:
//   p = hi
//   while (lo < hi)
//     mid = (lo + hi) >> 1
//     if (xs[p] < xs[mid])
//       hi = mid
//     else
//       lo = mid + 1
//   return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
                                   func::FuncOp func, uint64_t nx, uint64_t ny,
                                   bool isCoo, uint32_t nTrailingP = 0) {
  // Binary search doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value p = args[hiIdx];
  SmallVector<Type, 2> types(2, p.getType()); // Only two types.
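  // The while loop below carries (lo, hi) and keeps halving the range until
  // lo == hi, at which point lo is the insertion point that gets returned.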
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});

  // The before-region of the WhileOp.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(before);
  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                              before->getArgument(0),
                                              before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(after);
  Value lo = after->getArgument(0);
  Value hi = after->getArgument(1);
  // Compute mid = (lo + hi) >> 1.
  Value c1 = constantIndex(builder, loc, 1);
  Value mid = builder.create<arith::ShRUIOp>(
      loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
  Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);

  // Compare xs[p] < xs[mid].
  SmallVector<Value> compareOperands{p, mid};
  uint64_t numXBuffers = isCoo ? 1 : nx;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createLessThanFunc, nTrailingP);
  Value cond2 = builder
                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
                                          compareOperands)
                    .getResult(0);

  // Update lo and hi for the WhileOp as follows:
  //   if (xs[p] < xs[mid])
  //     hi = mid;
  //   else
  //     lo = mid + 1;
  Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
  Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
  builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});

  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
}

/// Creates code to advance i in a loop based on xs[p] as follows:
///   while (xs[i] < xs[p]) i += step (step > 0)
/// or
///   while (xs[i] > xs[p]) i += step (step < 0)
/// The routine returns i as well as a boolean value to indicate whether
/// xs[i] == xs[p].
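/// Both the less-than test inside the loop and the final equality test reuse
/// the helper functions generated by createLessThanFunc and
/// createEqCompareFunc, looked up through getMangledSortHelperFunc.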
static std::pair<Value, Value>
createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
               ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
               bool isCoo, int step) {
  Location loc = func.getLoc();
  scf::WhileOp whileOp =
      builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});

  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
  builder.setInsertionPointToEnd(before);
  SmallVector<Value> compareOperands;
  if (step > 0) {
    compareOperands.push_back(before->getArgument(0));
    compareOperands.push_back(p);
  } else {
    assert(step < 0);
    compareOperands.push_back(p);
    compareOperands.push_back(before->getArgument(0));
  }
  compareOperands.append(xs.begin(), xs.end());
  MLIRContext *context = module.getContext();
  Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createLessThanFunc);
  Value cond = builder
                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
                                         compareOperands)
                   .getResult(0);
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
  builder.setInsertionPointToEnd(after);
  Value cs = constantIndex(builder, loc, step);
  i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  i = whileOp.getResult(0);

  builder.setInsertionPointAfter(whileOp);
  compareOperands[0] = i;
  compareOperands[1] = p;
  FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createEqCompareFunc);
  Value compareEq =
      builder
          .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
                                compareOperands)
          .getResult(0);

  return std::make_pair(whileOp.getResult(0), compareEq);
}

/// Creates a code block to swap the values so that data[mi] is the median among
/// data[lo], data[hi], and data[mi].
//
// The generated code corresponds to this C-like algorithm:
//   median = mi
//   if (data[mi] < data[lo])                        (if1)
//     if (data[hi] < data[lo])                      (if2)
//       median = data[hi] < data[mi] ? mi : hi
//     else
//       median = lo
//   else
//     if (data[hi] < data[mi])                      (if3)
//       median = data[hi] < data[lo] ? lo : hi
//   if median != mi swap data[median] with data[mi]
static void createChoosePivot(OpBuilder &builder, ModuleOp module,
                              func::FuncOp func, uint64_t nx, uint64_t ny,
                              bool isCoo, Value lo, Value hi, Value mi,
                              ValueRange args) {
  SmallVector<Value> compareOperands{mi, lo};
  uint64_t numXBuffers = isCoo ? 1 : nx;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
  SmallVector<Type, 1> cmpTypes{i1Type};
  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
      builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createLessThanFunc);
  Location loc = func.getLoc();
  // Compare data[mi] < data[lo].
  Value cond1 =
      builder.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
          .getResult(0);
  SmallVector<Type, 1> ifTypes{lo.getType()};
  scf::IfOp ifOp1 =
      builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);

  // Generate an if-stmt to find the median value, assuming we already know that
  // data[b] < data[a] and we haven't compared data[c] yet.
  auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp {
    compareOperands[0] = c;
    compareOperands[1] = a;
    // Compare data[c] < data[a].
    Value cond2 =
        builder
            .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
            .getResult(0);
    scf::IfOp ifOp2 =
        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
    builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
    compareOperands[0] = c;
    compareOperands[1] = b;
    // Compare data[c] < data[b].
    Value cond3 =
        builder
            .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
            .getResult(0);
    builder.create<scf::YieldOp>(
        loc, ValueRange{builder.create<arith::SelectOp>(loc, cond3, b, c)});
    builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{a});
    return ifOp2;
  };

  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
  scf::IfOp ifOp2 = createFindMedian(lo, mi, hi);
  builder.setInsertionPointAfter(ifOp2);
  builder.create<scf::YieldOp>(loc, ValueRange{ifOp2.getResult(0)});

  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
  scf::IfOp ifOp3 = createFindMedian(mi, lo, hi);

  builder.setInsertionPointAfter(ifOp3);
  builder.create<scf::YieldOp>(loc, ValueRange{ifOp3.getResult(0)});

  builder.setInsertionPointAfter(ifOp1);
  Value median = ifOp1.getResult(0);
  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, mi, median);
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, TypeRange(), cond, /*else=*/false);

  SmallVector<Value> swapOperands{median, mi};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  builder.setInsertionPointAfter(ifOp);
}

/// Creates a function to perform quick sort partition on the values in the
/// range of index [lo, hi), assuming lo < hi.
//
// The generated IR corresponds to this C-like algorithm:
//   int partition(lo, hi, xs) {
//     p = (lo+hi)/2  // pivot index
//     i = lo
//     j = hi-1
//     while (i < j) do {
//       while (xs[i] < xs[p]) i ++;
//       i_eq = (xs[i] == xs[p]);
//       while (xs[j] > xs[p]) j --;
//       j_eq = (xs[j] == xs[p]);
//       if (i < j) {
//         swap(xs[i], xs[j])
//         if (i == p) {
//           p = j;
//         } else if (j == p) {
//           p = i;
//         }
//         if (i_eq && j_eq) {
//           ++i;
//           --j;
//         }
//       }
//     }
//     return p
//   }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo, uint32_t nTrailingP = 0) {
  // Quick sort partition doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
  Value c1 = constantIndex(builder, loc, 1);
  Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);

  Value i = lo;
  Value j = builder.create<arith::SubIOp>(loc, hi, c1);
  createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
  SmallVector<Value, 3> operands{i, j, p}; // Exactly three values.
  SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

  // The before-region of the WhileOp.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
  builder.setInsertionPointToEnd(before);
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                             before->getArgument(0),
                                             before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
  builder.setInsertionPointToEnd(after);
  i = after->getArgument(0);
  j = after->getArgument(1);
  p = after->getArgument(2);

  uint64_t numXBuffers = isCoo ? 1 : nx;
  auto [iresult, iCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     i, p, nx, ny, isCoo, 1);
  i = iresult;
  auto [jresult, jCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     j, p, nx, ny, isCoo, -1);
  j = jresult;

  // If i < j:
  cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  SmallVector<Value> swapOperands{i, j};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  // If the pivot is moved, update p with the new pivot.
  Value icond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
  scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              icond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{j});
  builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
  Value jcond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
  scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              jcond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{p});
  builder.setInsertionPointAfter(ifOpJ);
  builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
  builder.setInsertionPointAfter(ifOpI);
  Value compareEqIJ =
      builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(
      loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
  Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
  Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
  builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j});
  builder.setInsertionPointAfter(ifOp2);
  builder.create<scf::YieldOp>(
      loc,
      ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});

  // False branch for if i < j:
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});

  // Return for the whileOp.
  builder.setInsertionPointAfter(ifOp);
  builder.create<scf::YieldOp>(loc, ifOp.getResults());

  // Return for the function.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
}

/// Computes (n-2)/2, assuming n has index type.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
                                      Value n) {
  Value i2 = constantIndex(builder, loc, 2);
  Value res = builder.create<arith::SubIOp>(loc, n, i2);
  Value i1 = constantIndex(builder, loc, 1);
  return builder.create<arith::ShRUIOp>(loc, res, i1);
}

/// Creates a function to heapify the subtree with root `start` within the full
/// binary tree in the range of index [first, first + n).
//
// The generated IR corresponds to this C-like algorithm:
//   void shiftDown(first, start, n, data) {
//     if (n >= 2) {
//       child = start - first
//       if ((n-2)/2 >= child) {
//         // Left child exists.
//         child = child * 2 + 1 // Initialize the bigger child to left child.
//         childIndex = child + first
//         if (child+1 < n && data[childIndex] < data[childIndex+1])
//           // Right child exists and is bigger.
//           childIndex++; child++;
//         // Shift data[start] down to where it belongs in the subtree.
//         while (data[start] < data[childIndex]) {
//           swap(data[start], data[childIndex])
//           start = childIndex
//           if ((n - 2)/2 >= child) {
//             // Left child exists.
//             child = 2*child + 1
//             childIndex = child + first
//             if (child + 1 < n && data[childIndex] < data[childIndex+1])
//               childIndex++; child++;
//           }
//         }
//       }
//     }
//   }
//
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo, uint32_t nTrailingP) {
  // The value n is passed in as a trailing parameter.
  assert(nTrailingP == 1);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  Value n = entryBlock->getArguments().back();
  ValueRange args = entryBlock->getArguments().drop_back();
  Value first = args[loIdx];
  Value start = args[hiIdx];

  // If (n >= 2).
  Value c2 = constantIndex(builder, loc, 2);
  Value condN =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
  scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
  builder.setInsertionPointToStart(&ifN.getThenRegion().front());
  Value child = builder.create<arith::SubIOp>(loc, start, first);

  // If ((n-2)/2 >= child).
  Value t = createSubTwoDividedByTwo(builder, loc, n);
  Value condNc =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
  scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);

  builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
  Value c1 = constantIndex(builder, loc, 1);
  SmallVector<Value> compareOperands{start, start};
  uint64_t numXBuffers = isCoo ? 1 : nx;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createLessThanFunc);

  // Generate code to inspect the children of 'r' and return the larger child
  // as follows:
  //   child = r * 2 + 1 // Left child.
  //   childIndex = child + first
  //   if (child+1 < n && data[childIndex] < data[childIndex+1])
  //     childIndex ++; child ++ // Right child is bigger.
  auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
    Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
    lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
    Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
    Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
    Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                rChild, n);
    SmallVector<Type, 2> ifTypes(2, r.getType());
    scf::IfOp if1 =
        builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
    builder.setInsertionPointToStart(&if1.getThenRegion().front());
    Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
    // Compare data[left] < data[right].
    compareOperands[0] = lChildIdx;
    compareOperands[1] = rChildIdx;
    Value cond2 = builder
                      .create<func::CallOp>(loc, lessThanFunc,
                                            TypeRange{i1Type}, compareOperands)
                      .getResult(0);
    scf::IfOp if2 =
        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
    builder.setInsertionPointToStart(&if2.getThenRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
    builder.setInsertionPointToStart(&if2.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
    builder.setInsertionPointAfter(if2);
    builder.create<scf::YieldOp>(loc, if2.getResults());
    builder.setInsertionPointToStart(&if1.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
    builder.setInsertionPointAfter(if1);
    return std::make_pair(if1.getResult(0), if1.getResult(1));
  };

  Value childIdx;
  std::tie(child, childIdx) = getLargerChild(child);

  // While (data[start] < data[childIndex]).
  SmallVector<Type, 3> types(3, child.getType());
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 3>{start, child, childIdx});

  // The before-region of the WhileOp.
  SmallVector<Location, 3> locs(3, loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  builder.setInsertionPointToEnd(before);
  start = before->getArgument(0);
  childIdx = before->getArgument(2);
  compareOperands[0] = start;
  compareOperands[1] = childIdx;
  Value cond = builder
                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
                                         compareOperands)
                   .getResult(0);
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region of the WhileOp.
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
  start = after->getArgument(0);
  child = after->getArgument(1);
  childIdx = after->getArgument(2);
  SmallVector<Value> swapOperands{start, childIdx};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  start = childIdx;
  Value cond2 =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
  scf::IfOp if2 = builder.create<scf::IfOp>(
      loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
  builder.setInsertionPointToStart(&if2.getThenRegion().front());
  auto [newChild, newChildIdx] = getLargerChild(child);
  builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
  builder.setInsertionPointToStart(&if2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
  builder.setInsertionPointAfter(if2);
  builder.create<scf::YieldOp>(
      loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});

  builder.setInsertionPointAfter(ifN);
  builder.create<func::ReturnOp>(loc);
}

/// Creates a function to perform heap sort on the values in the range of index
/// [lo, hi) with the assumption hi - lo >= 2.
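/// The first loop below builds a max-heap in place; the second loop repeatedly
/// swaps the root with the last element of the shrinking heap and restores the
/// heap property via shiftDown.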
//
// The generated IR corresponds to this C-like algorithm:
//   void heapSort(lo, hi, data) {
//     n = hi - lo
//     for i = (n-2)/2 downto 0
//       shiftDown(lo, lo+i, n)
//
//     for l = n downto 2
//       swap(lo, lo+l-1)
//       shiftDown(lo, lo, l-1)
//   }
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
                               func::FuncOp func, uint64_t nx, uint64_t ny,
                               bool isCoo, uint32_t nTrailingP) {
  // Heap sort function doesn't have trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value n = builder.create<arith::SubIOp>(loc, hi, lo);

  // For i = (n-2)/2 downto 0.
  Value c0 = constantIndex(builder, loc, 0);
  Value c1 = constantIndex(builder, loc, 1);
  Value s = createSubTwoDividedByTwo(builder, loc, n);
  Value up = builder.create<arith::AddIOp>(loc, s, c1);
  scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
  builder.setInsertionPointToStart(forI.getBody());
  Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
  Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
  SmallVector<Value> shiftDownOperands = {lo, lopi};
  shiftDownOperands.append(args.begin() + xStartIdx, args.end());
  shiftDownOperands.push_back(n);
  FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
      builder, func, TypeRange(), kShiftDownFuncNamePrefix, nx, ny, isCoo,
      shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
                               shiftDownOperands);

  builder.setInsertionPointAfter(forI);
  // For l = n downto 2.
  up = builder.create<arith::SubIOp>(loc, n, c1);
  scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
  builder.setInsertionPointToStart(forL.getBody());
  Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
  Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
  loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
  SmallVector<Value> swapOperands{lo, loplm1};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  shiftDownOperands[1] = lo;
  shiftDownOperands[shiftDownOperands.size() - 1] =
      builder.create<arith::SubIOp>(loc, l, c1);
  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
                               shiftDownOperands);

  builder.setInsertionPointAfter(forL);
  builder.create<func::ReturnOp>(loc);
}

static void createQuickSort(OpBuilder &builder, ModuleOp module,
                            func::FuncOp func, ValueRange args, uint64_t nx,
                            uint64_t ny, bool isCoo, uint32_t nTrailingP) {
  MLIRContext *context = module.getContext();
  Location loc = func.getLoc();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
      ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
  auto p = builder.create<func::CallOp>(loc, partitionFunc,
                                        TypeRange{IndexType::get(context)},
                                        args.drop_back(nTrailingP));

  SmallVector<Value> lowOperands{lo, p.getResult(0)};
  lowOperands.append(args.begin() + xStartIdx, args.end());
  builder.create<func::CallOp>(loc, func, lowOperands);

  SmallVector<Value> highOperands{
      builder.create<arith::AddIOp>(loc, p.getResult(0),
                                    constantIndex(builder, loc, 1)),
      hi};
  highOperands.append(args.begin() + xStartIdx, args.end());
  builder.create<func::CallOp>(loc, func, highOperands);
}

/// Creates a function to perform insertion sort on the values in the range of
/// index [lo, hi).
//
// The generated IR corresponds to this C-like algorithm:
//   void insertionSort(lo, hi, data) {
//     for (i = lo+1; i < hi; i++) {
//       d = data[i];
//       p = binarySearch(lo, i-1, data)
//       for (j = 0; j < i - p; j++)
//         data[i-j] = data[i-j-1]
//       data[p] = d
//     }
//   }
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
                                 func::FuncOp func, uint64_t nx, uint64_t ny,
                                 bool isCoo, uint32_t nTrailingP) {
  // Stable sort function doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  MLIRContext *context = module.getContext();
  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value c1 = constantIndex(builder, loc, 1);
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);

  // Start the outer for-stmt with induction variable i.
  scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
  builder.setInsertionPointToStart(forOpI.getBody());
  Value i = forOpI.getInductionVar();

  // Binary search to find the insertion point p.
  SmallVector<Value> operands{lo, i};
  operands.append(args.begin() + xStartIdx, args.end());
  FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
      ny, isCoo, operands, createBinarySearchFunc);
  Value p = builder
                .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
                                      operands)
                .getResult(0);

  // Move the value at data[i] to a temporary location.
  operands[0] = operands[1] = i;
  SmallVector<Value> d;
  forEachIJPairInAllBuffers(
      builder, loc, operands, nx, ny, isCoo,
      [&](uint64_t unused, Value i, Value unused2, Value buffer) {
        d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
      });

  // Start the inner for-stmt with induction variable j, for moving data[p..i)
  // to data[p+1..i+1).
  Value imp = builder.create<arith::SubIOp>(loc, i, p);
  Value c0 = constantIndex(builder, loc, 0);
  scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
  builder.setInsertionPointToStart(forOpJ.getBody());
  Value j = forOpJ.getInductionVar();
  Value imj = builder.create<arith::SubIOp>(loc, i, j);
  operands[1] = imj;
  operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
  forEachIJPairInAllBuffers(
      builder, loc, operands, nx, ny, isCoo,
      [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
        Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
        builder.create<memref::StoreOp>(loc, t, buffer, imj);
      });

  // Store the value at data[i] to data[p].
  builder.setInsertionPointAfter(forOpJ);
  operands[0] = operands[1] = p;
  forEachIJPairInAllBuffers(
      builder, loc, operands, nx, ny, isCoo,
      [&](uint64_t k, Value p, Value unused, Value buffer) {
        builder.create<memref::StoreOp>(loc, d[k], buffer, p);
      });

  builder.setInsertionPointAfter(forOpI);
  builder.create<func::ReturnOp>(loc);
}

/// Creates a function to perform quick sort or a hybrid quick sort on the
/// values in the range of index [lo, hi).
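/// A single trailing parameter selects the variant: without it plain recursive
/// quick sort is emitted, with it the hybrid variant that falls back to
/// insertion sort for short ranges and to heap sort when the recursion depth
/// budget is exhausted.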
//
// When nTrailingP == 0, the generated IR corresponds to this C-like algorithm:
//   void quickSort(lo, hi, data) {
//     if (lo + 1 < hi) {
//       p = partition(lo, hi, data);
//       quickSort(lo, p, data);
//       quickSort(p + 1, hi, data);
//     }
//   }
//
// When nTrailingP == 1, the generated IR corresponds to this C-like algorithm:
//   void hybridQuickSort(lo, hi, data, depthLimit) {
//     if (lo + 1 < hi) {
//       len = hi - lo;
//       if (len <= limit) {
//         insertionSort(lo, hi, data);
//       } else {
//         depthLimit --;
//         if (depthLimit <= 0) {
//           heapSort(lo, hi, data);
//         } else {
//           p = partition(lo, hi, data);
//           quickSort(lo, p, data);
//           quickSort(p + 1, hi, data);
//         }
//         depthLimit ++;
//       }
//     }
//   }
//
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo, uint32_t nTrailingP) {
  assert(nTrailingP == 1 || nTrailingP == 0);
  bool isHybrid = (nTrailingP == 1);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value loCmp =
      builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loCmp, hi);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);

  // The if-stmt true branch.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  Value pDepthLimit;
  Value savedDepthLimit;
  scf::IfOp depthIf;

  if (isHybrid) {
    Value len = builder.create<arith::SubIOp>(loc, hi, lo);
    Value lenLimit = constantIndex(builder, loc, 30);
    Value lenCond = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ule, len, lenLimit);
    scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);

    // When len <= limit.
    builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
    FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
        builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
        args.drop_back(nTrailingP), createSortStableFunc);
    builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
                                 ValueRange(args.drop_back(nTrailingP)));

    // When len > limit.
    builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
    pDepthLimit = args.back();
    savedDepthLimit = builder.create<memref::LoadOp>(loc, pDepthLimit);
    Value depthLimit = builder.create<arith::SubIOp>(
        loc, savedDepthLimit, constantI64(builder, loc, 1));
    builder.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
    Value depthCond =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                      depthLimit, constantI64(builder, loc, 0));
    depthIf = builder.create<scf::IfOp>(loc, depthCond, /*else=*/true);

    // When depth exceeds limit.
    builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
    FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
        builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
        args.drop_back(nTrailingP), createHeapSortFunc);
    builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
                                 ValueRange(args.drop_back(nTrailingP)));

    // When depth doesn't exceed limit.
    builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
  }

  createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);

  if (isHybrid) {
    // Restore depthLimit.
    builder.setInsertionPointAfter(depthIf);
    builder.create<memref::StoreOp>(loc, savedDepthLimit, pDepthLimit);
  }

  // After the if-stmt.
  builder.setInsertionPointAfter(ifOp);
  builder.create<func::ReturnOp>(loc);
}

/// Implements the rewriting for operator sort and sort_coo.
template <typename OpTy>
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
                                    uint64_t ny, bool isCoo,
                                    PatternRewriter &rewriter) {
  Location loc = op.getLoc();
  SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};

  // Convert the values in `xys` to have dynamic shape and append them to
  // `operands`.
  for (Value v : xys) {
    auto mtp = getMemRefType(v);
    if (!mtp.isDynamicDim(0)) {
      auto newMtp =
          MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
      v = rewriter.create<memref::CastOp>(loc, newMtp, v);
    }
    operands.push_back(v);
  }

  auto insertPoint = op->template getParentOfType<func::FuncOp>();
  SmallString<32> funcName;
  FuncGeneratorType funcGenerator;
  uint32_t nTrailingP = 0;
  switch (op.getAlgorithm()) {
  case SparseTensorSortKind::HybridQuickSort: {
    funcName = kHybridQuickSortFuncNamePrefix;
    funcGenerator = createQuickSortFunc;
    nTrailingP = 1;
    Value pDepthLimit = rewriter.create<memref::AllocaOp>(
        loc, MemRefType::get({}, rewriter.getI64Type()));
    operands.push_back(pDepthLimit);
    // As a heuristic, set depthLimit = 2 * log2(n).
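    // With an i64 length, 64 - CountLeadingZeros(len) equals
    // floor(log2(len)) + 1 for len > 0, and the shift left by one doubles it.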
    Value lo = operands[loIdx];
    Value hi = operands[hiIdx];
    Value len = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI64Type(),
        rewriter.create<arith::SubIOp>(loc, hi, lo));
    Value depthLimit = rewriter.create<arith::SubIOp>(
        loc, constantI64(rewriter, loc, 64),
        rewriter.create<math::CountLeadingZerosOp>(loc, len));
    depthLimit = rewriter.create<arith::ShLIOp>(loc, depthLimit,
                                                constantI64(rewriter, loc, 1));
    rewriter.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
    break;
  }
  case SparseTensorSortKind::QuickSort:
    funcName = kQuickSortFuncNamePrefix;
    funcGenerator = createQuickSortFunc;
    break;
  case SparseTensorSortKind::InsertionSortStable:
    funcName = kSortStableFuncNamePrefix;
    funcGenerator = createSortStableFunc;
    break;
  case SparseTensorSortKind::HeapSort:
    funcName = kHeapSortFuncNamePrefix;
    funcGenerator = createHeapSortFunc;
    break;
  }

  FlatSymbolRefAttr func =
      getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
                               ny, isCoo, operands, funcGenerator, nTrailingP);
  rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
  return success();
}

//===---------------------------------------------------------------------===//
// The actual sparse buffer rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for the push_back operator.
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
  using OpRewritePattern<PushBackOp>::OpRewritePattern;
  PushBackRewriter(MLIRContext *context, bool enableInit)
      : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
  LogicalResult matchAndRewrite(PushBackOp op,
                                PatternRewriter &rewriter) const override {
    // Rewrite push_back(buffer, value, n) to:
    //   new_size = size(buffer) + n
    //   if (new_size > capacity(buffer))
    //     while new_size > new_capacity
    //       new_capacity = new_capacity*2
    //     new_buffer = realloc(buffer, new_capacity)
    //     buffer = new_buffer
    //     subBuffer = subviewof(buffer)
    //     linalg.fill subBuffer value
    //
    //   size(buffer) += n
    //
    // The capacity check is skipped when the attribute inbounds is present.
    Location loc = op->getLoc();
    Value c0 = constantIndex(rewriter, loc, 0);
    Value buffer = op.getInBuffer();
    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
    Value size = op.getCurSize();
    Value value = op.getValue();

    Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
    Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
    auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
    bool nIsOne = (nValue && nValue.value() == 1);

    if (!op.getInbounds()) {
      Value cond = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ugt, newSize, capacity);

      Value c2 = constantIndex(rewriter, loc, 2);
      auto bufferType =
          MemRefType::get({ShapedType::kDynamic}, value.getType());
      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
                                                  /*else=*/true);
      // True branch.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (nIsOne) {
        capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
      } else {
        // Use a do-while loop to calculate the new capacity as follows:
        //   do { new_capacity *= 2 } while (new_size > new_capacity)
        scf::WhileOp whileOp =
            rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);

        // The before-region of the WhileOp.
        Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
                                             {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(before);

        capacity =
            rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
        cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                              newSize, capacity);
        rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
        // The after-region of the WhileOp.
        Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
                                            {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(after);
        rewriter.create<scf::YieldOp>(loc, after->getArguments());

        rewriter.setInsertionPointAfter(whileOp);
        capacity = whileOp.getResult(0);
      }

      Value newBuffer =
          rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
      if (enableBufferInitialization) {
        Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
        Value fillValue = constantZero(rewriter, loc, value.getType());
        Value subBuffer = rewriter.create<memref::SubViewOp>(
            loc, newBuffer, /*offset=*/ValueRange{newSize},
            /*size=*/ValueRange{fillSize},
            /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
        rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
      }
      rewriter.create<scf::YieldOp>(loc, newBuffer);

      // False branch.
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      rewriter.create<scf::YieldOp>(loc, buffer);

      // Prepare for adding the value to the end of the buffer.
      rewriter.setInsertionPointAfter(ifOp);
      buffer = ifOp.getResult(0);
    }

    // Add the value to the end of the buffer.
    if (nIsOne) {
      rewriter.create<memref::StoreOp>(loc, value, buffer, size);
    } else {
      Value subBuffer = rewriter.create<memref::SubViewOp>(
          loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
          /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
      rewriter.create<linalg::FillOp>(loc, value, subBuffer);
    }

    // Update the buffer size.
    rewriter.replaceOp(op, {buffer, newSize});
    return success();
  }

private:
  bool enableBufferInitialization;
};

/// Sparse rewriting rule for the sort operator.
struct SortRewriter : public OpRewritePattern<SortOp> {
public:
  using OpRewritePattern<SortOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SortOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> xys(op.getXs());
    xys.append(op.getYs().begin(), op.getYs().end());
    return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
                                 /*isCoo=*/false, rewriter);
  }
};

/// Sparse rewriting rule for the sort_coo operator.
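/// A sort_coo operator keeps its nx x values (and the ny trailing y values
/// stored with them) interleaved in a single xy buffer, so nx and ny are taken
/// from the optional attributes rather than from separate operands.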
struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
public:
  using OpRewritePattern<SortCooOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SortCooOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> xys;
    xys.push_back(op.getXy());
    xys.append(op.getYs().begin(), op.getYs().end());
    uint64_t nx = 1;
    if (auto nxAttr = op.getNxAttr())
      nx = nxAttr.getInt();

    uint64_t ny = 0;
    if (auto nyAttr = op.getNyAttr())
      ny = nyAttr.getInt();

    return matchAndRewriteSortOp(op, xys, nx, ny,
                                 /*isCoo=*/true, rewriter);
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
                                         bool enableBufferInitialization) {
  patterns.add<PushBackRewriter>(patterns.getContext(),
                                 enableBufferInitialization);
  patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
}
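
// Usage sketch (not part of this file): a pass that wants these rules would
// typically populate and apply them greedily, for example:
//   RewritePatternSet patterns(ctx);
//   populateSparseBufferRewriting(patterns,
//                                 /*enableBufferInitialization=*/false);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));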