//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensor
// primitives with memref operands.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;

static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
    "_sparse_binary_search_";
static constexpr const char kSortNonstableFuncNamePrefix[] =
    "_sparse_sort_nonstable_";
static constexpr const char kSortStableFuncNamePrefix[] =
    "_sparse_sort_stable_";

using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
                                            uint64_t, uint64_t, bool)>;

/// Constructs a function name with this format to facilitate quick sort:
///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
                                         StringRef namePrefix, uint64_t nx,
                                         uint64_t ny, bool isCoo,
                                         ValueRange operands) {
  nameOstream << namePrefix << nx << "_"
              << getMemRefType(operands[xStartIdx]).getElementType();

  if (isCoo)
    nameOstream << "_coo_" << ny;

  uint64_t yBufferOffset = isCoo ? 1 : nx;
  for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
    nameOstream << "_" << getMemRefType(v).getElementType();
}

/// Looks up a function that is appropriate for the given operands being
/// sorted, and creates such a function if it doesn't exist yet. The
/// parameters `nx` and `ny` tell the number of x and y values provided
/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
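/// The helper function is created right before `insertPoint` and marked
/// private; callers refer to it through the returned symbol reference.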
static FlatSymbolRefAttr
getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
                         TypeRange resultTypes, StringRef namePrefix,
                         uint64_t nx, uint64_t ny, bool isCoo,
                         ValueRange operands, FuncGeneratorType createFunc) {
  SmallString<32> nameBuffer;
  llvm::raw_svector_ostream nameOstream(nameBuffer);
  getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
                               operands);

  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
  MLIRContext *context = module.getContext();
  auto result = SymbolRefAttr::get(context, nameOstream.str());
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());

  if (!func) {
    // Create the function.
    OpBuilder::InsertionGuard insertionGuard(builder);
    builder.setInsertionPoint(insertPoint);
    Location loc = insertPoint.getLoc();
    func = builder.create<func::FuncOp>(
        loc, nameOstream.str(),
        FunctionType::get(context, operands.getTypes(), resultTypes));
    func.setPrivate();
    createFunc(builder, module, func, nx, ny, isCoo);
  }

  return result;
}

/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInXs(
    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
  Value iOffset, jOffset;
  if (isCoo) {
    Value cstep = constantIndex(builder, loc, nx + ny);
    iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
    jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
  }
  for (uint64_t k = 0; k < nx; k++) {
    scf::IfOp ifOp;
    Value i, j, buffer;
    if (isCoo) {
      Value ck = constantIndex(builder, loc, k);
      i = builder.create<arith::AddIOp>(loc, ck, iOffset);
      j = builder.create<arith::AddIOp>(loc, ck, jOffset);
      buffer = args[xStartIdx];
    } else {
      i = args[0];
      j = args[1];
      buffer = args[xStartIdx + k];
    }
    bodyBuilder(k, i, j, buffer);
  }
}

/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInAllBuffers(
    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {

  // Create code for the first (nx + ny) buffers. When isCoo==true, these
  // logical buffers are all from the xy buffer of the sort_coo operator.
  forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);

  uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;

  // Create code for the remaining buffers.
  Value i = args[0];
  Value j = args[1];
  for (const auto &arg :
       llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
    bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
  }
}

/// Creates a code block for swapping the values at index i and j for all the
/// buffers.
//
// The generated IR corresponds to this C-like algorithm:
//   swap(x0[i], x0[j]);
//   swap(x1[i], x1[j]);
//   ...
//   swap(xn[i], xn[j]);
//   swap(y0[i], y0[j]);
//   ...
//   swap(yn[i], yn[j]);
static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
                       uint64_t nx, uint64_t ny, bool isCoo) {
  auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
    Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
    Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
    builder.create<memref::StoreOp>(loc, vj, buffer, i);
    builder.create<memref::StoreOp>(loc, vi, buffer, j);
  };

  forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
}

/// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
/// compare each pair is created via `compareBuilder`.
static void createCompareFuncImplementation(
    OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx,
    uint64_t ny, bool isCoo,
    function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
        compareBuilder) {
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);
  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();

  scf::IfOp topIfOp;
  auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
    scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1));
    if (k == 0) {
      topIfOp = ifOp;
    } else {
      OpBuilder::InsertionGuard insertionGuard(builder);
      builder.setInsertionPointAfter(ifOp);
      builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
    }
  };

  forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);

  builder.setInsertionPointAfter(topIfOp);
  builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
}

/// Generates an if-statement to compare whether x[i] is equal to x[j].
static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
                                 Value j, Value x, bool isLastDim) {
  Value f = constantI1(builder, loc, false);
  Value t = constantI1(builder, loc, true);
  Value vi = builder.create<memref::LoadOp>(loc, x, i);
  Value vj = builder.create<memref::LoadOp>(loc, x, j);

  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);

  // x[i] != x[j]:
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, f);

  // x[i] == x[j]:
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  if (isLastDim) {
    // Finish checking all dimensions.
    builder.create<scf::YieldOp>(loc, t);
  }

  return ifOp;
}

/// Creates a function to compare whether xs[i] is equal to xs[j].
//
// The generated IR corresponds to this C-like algorithm:
//   if (x0[i] != x0[j])
//     return false;
//   else
//     if (x1[i] != x1[j])
//       return false;
//     else if (x2[i] != x2[j])
//       and so on ...
static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo) {
  createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
                                  createEqCompare);
}

/// Generates an if-statement to compare whether x[i] is less than x[j].
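/// When `isLastDim` is false, one region of the generated if is left without
/// a terminator so that the caller can nest the comparison for the next
/// dimension (and the corresponding yield) inside it.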
static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
                                       Value i, Value j, Value x,
                                       bool isLastDim) {
  Value f = constantI1(builder, loc, false);
  Value t = constantI1(builder, loc, true);
  Value vi = builder.create<memref::LoadOp>(loc, x, i);
  Value vj = builder.create<memref::LoadOp>(loc, x, j);

  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
  // If (x[i] < x[j]).
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, t);

  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  if (isLastDim) {
    // Finish checking all dimensions.
    builder.create<scf::YieldOp>(loc, f);
  } else {
    cond =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
    scf::IfOp ifOp2 =
        builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
    // Otherwise if (x[j] < x[i]).
    builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
    builder.create<scf::YieldOp>(loc, f);

    // Otherwise check the remaining dimensions.
    builder.setInsertionPointAfter(ifOp2);
    builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
    // Set up the insertion point for the nested if-stmt that checks the
    // remaining dimensions.
    builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  }

  return ifOp;
}

/// Creates a function to compare whether xs[i] is less than xs[j].
//
// The generated IR corresponds to this C-like algorithm:
//   if (x0[i] < x0[j])
//     return true;
//   else if (x0[j] < x0[i])
//     return false;
//   else
//     if (x1[i] < x1[j])
//       return true;
//     else if (x1[j] < x1[i])
//       and so on ...
static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
                               func::FuncOp func, uint64_t nx, uint64_t ny,
                               bool isCoo) {
  createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
                                  createLessThanCompare);
}

/// Creates a function to use a binary search to find the insertion point for
/// inserting xs[hi] into the sorted values xs[lo..hi).
//
// The generated IR corresponds to this C-like algorithm:
//   p = hi
//   while (lo < hi)
//     mid = (lo + hi) >> 1
//     if (xs[p] < xs[mid])
//       hi = mid
//     else
//       lo = mid + 1
//   return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
                                   func::FuncOp func, uint64_t nx, uint64_t ny,
                                   bool isCoo) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value p = args[hiIdx];
  SmallVector<Type, 2> types(2, p.getType()); // only two
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});

  // The before-region of the WhileOp.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(before);
  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                              before->getArgument(0),
                                              before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());

  // The after-region of the WhileOp.
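  // Each iteration halves the search range: (lo, hi) becomes either (lo, mid)
  // or (mid + 1, hi), depending on the comparison below.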
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(after);
  Value lo = after->getArgument(0);
  Value hi = after->getArgument(1);
  // Compute mid = (lo + hi) >> 1.
  Value c1 = constantIndex(builder, loc, 1);
  Value mid = builder.create<arith::ShRUIOp>(
      loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
  Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);

  // Compare xs[p] < xs[mid].
  SmallVector<Value> compareOperands{p, mid};
  uint64_t numXBuffers = isCoo ? 1 : nx;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createLessThanFunc);
  Value cond2 = builder
                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
                                          compareOperands)
                    .getResult(0);

  // Update lo and hi for the WhileOp as follows:
  //   if (xs[p] < xs[mid])
  //     hi = mid;
  //   else
  //     lo = mid + 1;
  Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
  Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
  builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});

  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
}

/// Creates code to advance i in a loop based on xs[p] as follows:
///   while (xs[i] < xs[p]) i += step (step > 0)
/// or
///   while (xs[i] > xs[p]) i += step (step < 0)
/// The routine returns i as well as a boolean value to indicate whether
/// xs[i] == xs[p].
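/// (The partition routine calls this with step == 1 to advance i and with
/// step == -1 to retreat j.)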
static std::pair<Value, Value>
createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
               ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
               bool isCoo, int step) {
  Location loc = func.getLoc();
  scf::WhileOp whileOp =
      builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});

  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
  builder.setInsertionPointToEnd(before);
  SmallVector<Value> compareOperands;
  if (step > 0) {
    compareOperands.push_back(before->getArgument(0));
    compareOperands.push_back(p);
  } else {
    assert(step < 0);
    compareOperands.push_back(p);
    compareOperands.push_back(before->getArgument(0));
  }
  compareOperands.append(xs.begin(), xs.end());
  MLIRContext *context = module.getContext();
  Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createLessThanFunc);
  Value cond = builder
                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
                                         compareOperands)
                   .getResult(0);
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
  builder.setInsertionPointToEnd(after);
  Value cs = constantIndex(builder, loc, step);
  i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  i = whileOp.getResult(0);

  builder.setInsertionPointAfter(whileOp);
  compareOperands[0] = i;
  compareOperands[1] = p;
  FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createEqCompareFunc);
  Value compareEq =
      builder
          .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
                                compareOperands)
          .getResult(0);

  return std::make_pair(whileOp.getResult(0), compareEq);
}

/// Creates a code block to swap the values so that data[mi] is the median
/// among data[lo], data[hi], and data[mi].
//
// The generated code corresponds to this C-like algorithm:
//   median = mi
//   if (data[mi] < data[lo])        (if1)
//     if (data[hi] < data[lo])      (if2)
//       median = data[hi] < data[mi] ? mi : hi
//     else
//       median = lo
//   else
//     if (data[hi] < data[mi])      (if3)
//       median = data[hi] < data[lo] ? lo : hi
//   if (median != mi) swap data[median] with data[mi]
static void createChoosePivot(OpBuilder &builder, ModuleOp module,
                              func::FuncOp func, uint64_t nx, uint64_t ny,
                              bool isCoo, Value lo, Value hi, Value mi,
                              ValueRange args) {
  SmallVector<Value> compareOperands{mi, lo};
  uint64_t numXBuffers = isCoo ? 1 : nx;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
  SmallVector<Type, 1> cmpTypes{i1Type};
  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
      builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createLessThanFunc);
  Location loc = func.getLoc();
  // Compare data[mi] < data[lo].
  Value cond1 =
      builder.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
          .getResult(0);
  SmallVector<Type, 1> ifTypes{lo.getType()};
  scf::IfOp ifOp1 =
      builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);

  // Generate an if-stmt to find the median value, assuming we already know
  // that data[b] < data[a] and we haven't compared data[c] yet.
  auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp {
    compareOperands[0] = c;
    compareOperands[1] = a;
    // Compare data[c] < data[a].
    Value cond2 =
        builder
            .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
            .getResult(0);
    scf::IfOp ifOp2 =
        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
    builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
    compareOperands[0] = c;
    compareOperands[1] = b;
    // Compare data[c] < data[b].
    Value cond3 =
        builder
            .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
            .getResult(0);
    builder.create<scf::YieldOp>(
        loc, ValueRange{builder.create<arith::SelectOp>(loc, cond3, b, c)});
    builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{a});
    return ifOp2;
  };

  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
  scf::IfOp ifOp2 = createFindMedian(lo, mi, hi);
  builder.setInsertionPointAfter(ifOp2);
  builder.create<scf::YieldOp>(loc, ValueRange{ifOp2.getResult(0)});

  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
  scf::IfOp ifOp3 = createFindMedian(mi, lo, hi);

  builder.setInsertionPointAfter(ifOp3);
  builder.create<scf::YieldOp>(loc, ValueRange{ifOp3.getResult(0)});

  builder.setInsertionPointAfter(ifOp1);
  Value median = ifOp1.getResult(0);
  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, mi, median);
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, TypeRange(), cond, /*else=*/false);

  SmallVector<Value> swapOperands{median, mi};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  builder.setInsertionPointAfter(ifOp);
}

/// Creates a function to perform quick sort partition on the values in the
/// range of index [lo, hi), assuming lo < hi.
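/// The generated function returns the final position of the pivot.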
//
// The generated IR corresponds to this C-like algorithm:
//   int partition(lo, hi, xs) {
//     p = (lo+hi)/2  // pivot index
//     i = lo
//     j = hi-1
//     while (i < j) {
//       while (xs[i] < xs[p]) i++;
//       i_eq = (xs[i] == xs[p]);
//       while (xs[j] > xs[p]) j--;
//       j_eq = (xs[j] == xs[p]);
//       if (i < j) {
//         swap(xs[i], xs[j])
//         if (i == p) {
//           p = j;
//         } else if (j == p) {
//           p = i;
//         }
//         if (i_eq && j_eq) {
//           ++i;
//           --j;
//         }
//       }
//     }
//     return p
//   }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo) {
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
  Value c1 = constantIndex(builder, loc, 1);
  Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);

  Value i = lo;
  Value j = builder.create<arith::SubIOp>(loc, hi, c1);
  createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
  SmallVector<Value, 3> operands{i, j, p}; // Exactly three values.
  SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

  // The before-region of the WhileOp.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
  builder.setInsertionPointToEnd(before);
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                             before->getArgument(0),
                                             before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
  builder.setInsertionPointToEnd(after);
  i = after->getArgument(0);
  j = after->getArgument(1);
  p = after->getArgument(2);

  uint64_t numXBuffers = isCoo ? 1 : nx;
  auto [iresult, iCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     i, p, nx, ny, isCoo, 1);
  i = iresult;
  auto [jresult, jCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     j, p, nx, ny, isCoo, -1);
  j = jresult;

  // If i < j:
  cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  SmallVector<Value> swapOperands{i, j};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  // If the pivot is moved, update p with the new pivot.
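  // The pivot can only have moved to i or j, since those are the only two
  // positions that were swapped.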
  Value icond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
  scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              icond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{j});
  builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
  Value jcond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
  scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              jcond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{p});
  builder.setInsertionPointAfter(ifOpJ);
  builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
  builder.setInsertionPointAfter(ifOpI);
  Value compareEqIJ =
      builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(
      loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
  Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
  Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
  builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j});
  builder.setInsertionPointAfter(ifOp2);
  builder.create<scf::YieldOp>(
      loc,
      ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});

  // False branch for if i < j:
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});

  // Return for the whileOp.
  builder.setInsertionPointAfter(ifOp);
  builder.create<scf::YieldOp>(loc, ifOp.getResults());

  // Return for the function.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
}

/// Creates a function to perform quick sort on the values in the range of
/// index [lo, hi).
//
// The generated IR corresponds to this C-like algorithm:
//   void quickSort(lo, hi, data) {
//     if (lo < hi) {
//       p = partition(lo, hi, data);
//       quickSort(lo, p, data);
//       quickSort(p + 1, hi, data);
//     }
//   }
static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
                                    func::FuncOp func, uint64_t nx, uint64_t ny,
                                    bool isCoo) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  MLIRContext *context = module.getContext();
  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);

  // The if-stmt true branch.
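  // Partition the range around the pivot and recursively sort both halves.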
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
      ny, isCoo, args, createPartitionFunc);
  auto p = builder.create<func::CallOp>(
      loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));

  SmallVector<Value> lowOperands{lo, p.getResult(0)};
  lowOperands.append(args.begin() + xStartIdx, args.end());
  builder.create<func::CallOp>(loc, func, lowOperands);

  SmallVector<Value> highOperands{
      builder.create<arith::AddIOp>(loc, p.getResult(0),
                                    constantIndex(builder, loc, 1)),
      hi};
  highOperands.append(args.begin() + xStartIdx, args.end());
  builder.create<func::CallOp>(loc, func, highOperands);

  // After the if-stmt.
  builder.setInsertionPointAfter(ifOp);
  builder.create<func::ReturnOp>(loc);
}

/// Creates a function to perform insertion sort on the values in the range of
/// index [lo, hi).
//
// The generated IR corresponds to this C-like algorithm:
//   void insertionSort(lo, hi, data) {
//     for (i = lo+1; i < hi; i++) {
//       d = data[i];
//       p = binarySearch(lo, i-1, data)
//       for (j = 0; j < i - p; j++)
//         data[i-j] = data[i-j-1]
//       data[p] = d
//     }
//   }
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
                                 func::FuncOp func, uint64_t nx, uint64_t ny,
                                 bool isCoo) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  MLIRContext *context = module.getContext();
  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value c1 = constantIndex(builder, loc, 1);
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);

  // Start the outer for-stmt with induction variable i.
  scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
  builder.setInsertionPointToStart(forOpI.getBody());
  Value i = forOpI.getInductionVar();

  // Binary search to find the insertion point p.
  SmallVector<Value> operands{lo, i};
  operands.append(args.begin() + xStartIdx, args.end());
  FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
      ny, isCoo, operands, createBinarySearchFunc);
  Value p = builder
                .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
                                      operands)
                .getResult(0);

  // Move the value at data[i] to a temporary location.
  operands[0] = operands[1] = i;
  SmallVector<Value> d;
  forEachIJPairInAllBuffers(
      builder, loc, operands, nx, ny, isCoo,
      [&](uint64_t unused, Value i, Value unused2, Value buffer) {
        d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
      });

  // Start the inner for-stmt with induction variable j, for moving data[p..i)
  // to data[p+1..i+1).
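  // The copy walks from data[i-1] down to data[p] so that no element is
  // overwritten before it has been moved.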
  Value imp = builder.create<arith::SubIOp>(loc, i, p);
  Value c0 = constantIndex(builder, loc, 0);
  scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
  builder.setInsertionPointToStart(forOpJ.getBody());
  Value j = forOpJ.getInductionVar();
  Value imj = builder.create<arith::SubIOp>(loc, i, j);
  operands[1] = imj;
  operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
  forEachIJPairInAllBuffers(
      builder, loc, operands, nx, ny, isCoo,
      [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
        Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
        builder.create<memref::StoreOp>(loc, t, buffer, imj);
      });

  // Store the value at data[i] to data[p].
  builder.setInsertionPointAfter(forOpJ);
  operands[0] = operands[1] = p;
  forEachIJPairInAllBuffers(
      builder, loc, operands, nx, ny, isCoo,
      [&](uint64_t k, Value p, Value unused, Value buffer) {
        builder.create<memref::StoreOp>(loc, d[k], buffer, p);
      });

  builder.setInsertionPointAfter(forOpI);
  builder.create<func::ReturnOp>(loc);
}

/// Implements the rewriting for operator sort and sort_coo.
template <typename OpTy>
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
                                    uint64_t ny, bool isCoo,
                                    PatternRewriter &rewriter) {
  Location loc = op.getLoc();
  SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};

  // Convert the values in `xys` to have dynamic shape and append them to
  // `operands`.
  for (Value v : xys) {
    auto mtp = getMemRefType(v);
    if (!mtp.isDynamicDim(0)) {
      auto newMtp =
          MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
      v = rewriter.create<memref::CastOp>(loc, newMtp, v);
    }
    operands.push_back(v);
  }
  bool isStable =
      (op.getAlgorithm() == SparseTensorSortKind::InsertionSortStable);
  auto insertPoint = op->template getParentOfType<func::FuncOp>();
  SmallString<32> funcName(isStable ? kSortStableFuncNamePrefix
                                    : kSortNonstableFuncNamePrefix);
  FuncGeneratorType funcGenerator =
      isStable ? createSortStableFunc : createSortNonstableFunc;
  FlatSymbolRefAttr func =
      getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
                               ny, isCoo, operands, funcGenerator);
  rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
  return success();
}

//===---------------------------------------------------------------------===//
// The actual sparse buffer rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for the push_back operator.
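/// When `enableInit` is set, the part of a reallocated buffer beyond the new
/// size is zero-initialized.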
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
  using OpRewritePattern<PushBackOp>::OpRewritePattern;
  PushBackRewriter(MLIRContext *context, bool enableInit)
      : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
  LogicalResult matchAndRewrite(PushBackOp op,
                                PatternRewriter &rewriter) const override {
    // Rewrite push_back(buffer, value, n) to:
    //   new_size = size(buffer) + n
    //   if (new_size > capacity(buffer))
    //     while (new_size > new_capacity)
    //       new_capacity = new_capacity * 2
    //     new_buffer = realloc(buffer, new_capacity)
    //     buffer = new_buffer
    //     subBuffer = subviewof(buffer)
    //     linalg.fill subBuffer value
    //
    //   size(buffer) += n
    //
    // The capacity check is skipped when the attribute inbounds is present.
    Location loc = op->getLoc();
    Value c0 = constantIndex(rewriter, loc, 0);
    Value buffer = op.getInBuffer();
    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
    Value size = op.getCurSize();
    Value value = op.getValue();

    Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
    Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
    auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
    bool nIsOne = (nValue && nValue.value() == 1);

    if (!op.getInbounds()) {
      Value cond = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ugt, newSize, capacity);

      Value c2 = constantIndex(rewriter, loc, 2);
      auto bufferType =
          MemRefType::get({ShapedType::kDynamic}, value.getType());
      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
                                                  /*else=*/true);
      // True branch.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (nIsOne) {
        capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
      } else {
        // Use a do-while loop to calculate the new capacity as follows:
        //   do { new_capacity *= 2 } while (new_size > new_capacity)
        scf::WhileOp whileOp =
            rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);

        // The before-region of the WhileOp.
        Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
                                             {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(before);

        capacity =
            rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
        cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                              newSize, capacity);
        rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
        // The after-region of the WhileOp.
        Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
                                            {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(after);
        rewriter.create<scf::YieldOp>(loc, after->getArguments());

        rewriter.setInsertionPointAfter(whileOp);
        capacity = whileOp.getResult(0);
      }

      Value newBuffer =
          rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
      if (enableBufferInitialization) {
        Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
        Value fillValue = constantZero(rewriter, loc, value.getType());
        Value subBuffer = rewriter.create<memref::SubViewOp>(
            loc, newBuffer, /*offset=*/ValueRange{newSize},
            /*size=*/ValueRange{fillSize},
            /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
        rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
      }
      rewriter.create<scf::YieldOp>(loc, newBuffer);

      // False branch.
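      // The existing buffer is large enough; reuse it unchanged.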
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      rewriter.create<scf::YieldOp>(loc, buffer);

      // Prepare for adding the value to the end of the buffer.
      rewriter.setInsertionPointAfter(ifOp);
      buffer = ifOp.getResult(0);
    }

    // Add the value to the end of the buffer.
    if (nIsOne) {
      rewriter.create<memref::StoreOp>(loc, value, buffer, size);
    } else {
      Value subBuffer = rewriter.create<memref::SubViewOp>(
          loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
          /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
      rewriter.create<linalg::FillOp>(loc, value, subBuffer);
    }

    // Update the buffer size.
    rewriter.replaceOp(op, {buffer, newSize});
    return success();
  }

private:
  bool enableBufferInitialization;
};

/// Sparse rewriting rule for the sort operator.
struct SortRewriter : public OpRewritePattern<SortOp> {
public:
  using OpRewritePattern<SortOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SortOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> xys(op.getXs());
    xys.append(op.getYs().begin(), op.getYs().end());
    return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
                                 /*isCoo=*/false, rewriter);
  }
};

/// Sparse rewriting rule for the sort_coo operator.
struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
public:
  using OpRewritePattern<SortCooOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SortCooOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> xys;
    xys.push_back(op.getXy());
    xys.append(op.getYs().begin(), op.getYs().end());
    uint64_t nx = 1;
    if (auto nxAttr = op.getNxAttr())
      nx = nxAttr.getInt();

    uint64_t ny = 0;
    if (auto nyAttr = op.getNyAttr())
      ny = nyAttr.getInt();

    return matchAndRewriteSortOp(op, xys, nx, ny,
                                 /*isCoo=*/true, rewriter);
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
                                         bool enableBufferInitialization) {
  patterns.add<PushBackRewriter>(patterns.getContext(),
                                 enableBufferInitialization);
  patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
}