1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements rewriting rules that are specific to sparse tensor 10 // primitives with memref operands. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodegenUtils.h" 15 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/SCF/IR/SCF.h" 21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 22 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 23 #include "mlir/Support/LLVM.h" 24 25 using namespace mlir; 26 using namespace mlir::sparse_tensor; 27 28 //===---------------------------------------------------------------------===// 29 // Helper methods for the actual rewriting rules. 
//===---------------------------------------------------------------------===//

// Positions of the low index, the high index, and the first x buffer in the
// operand list shared by all the generated sort helper functions.
static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;

// Name prefixes for the private helper functions materialized into the module
// by the rewriting rules below.
static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
    "_sparse_binary_search_";
static constexpr const char kSortNonstableFuncNamePrefix[] =
    "_sparse_sort_nonstable_";
static constexpr const char kSortStableFuncNamePrefix[] =
    "_sparse_sort_stable_";

// Signature of the callbacks that emit the body of a generated helper
// function: (builder, module, function, nx, ny, isCoo).
using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
                                            uint64_t, uint64_t, bool)>;

/// Constructs a function name with this format to facilitate quick sort:
///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
                                         StringRef namePrefix, uint64_t nx,
                                         uint64_t ny, bool isCoo,
                                         ValueRange operands) {
  // Only the element type of the first x buffer participates in the mangling;
  // presumably all x buffers share the same element type — TODO confirm.
  nameOstream
      << namePrefix << nx << "_"
      << operands[xStartIdx].getType().cast<MemRefType>().getElementType();

  if (isCoo)
    nameOstream << "_coo_" << ny;

  // For sort_coo the xs and ys are packed in a single buffer at xStartIdx, so
  // the separate y buffers start right after it; for sort they start after the
  // nx x buffers.
  uint64_t yBufferOffset = isCoo ? 1 : nx;
  for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
    nameOstream << "_" << v.getType().cast<MemRefType>().getElementType();
}

/// Looks up a function that is appropriate for the given operands being
/// sorted, and creates such a function if it doesn't exist yet. The
/// parameters `nx` and `ny` tell the number of x and y values provided
/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
static FlatSymbolRefAttr
getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
                         TypeRange resultTypes, StringRef namePrefix,
                         uint64_t nx, uint64_t ny, bool isCoo,
                         ValueRange operands, FuncGeneratorType createFunc) {
  // Mangle a name that uniquely identifies the helper for this operand
  // type/arity combination, so identical call sites share one function.
  SmallString<32> nameBuffer;
  llvm::raw_svector_ostream nameOstream(nameBuffer);
  getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
                               operands);

  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
  MLIRContext *context = module.getContext();
  auto result = SymbolRefAttr::get(context, nameOstream.str());
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());

  if (!func) {
    // Create the function. The helper is emitted right before `insertPoint`
    // (i.e., as a sibling in the same module) and marked private since it is
    // an internal implementation detail.
    OpBuilder::InsertionGuard insertionGuard(builder);
    builder.setInsertionPoint(insertPoint);
    Location loc = insertPoint.getLoc();
    func = builder.create<func::FuncOp>(
        loc, nameOstream.str(),
        FunctionType::get(context, operands.getTypes(), resultTypes));
    func.setPrivate();
    // Delegate the body generation to the supplied callback.
    createFunc(builder, module, func, nx, ny, isCoo);
  }

  return result;
}

/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
105 static void forEachIJPairInXs( 106 OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny, 107 bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { 108 Value iOffset, jOffset; 109 if (isCoo) { 110 Value cstep = constantIndex(builder, loc, nx + ny); 111 iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep); 112 jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep); 113 } 114 for (uint64_t k = 0; k < nx; k++) { 115 scf::IfOp ifOp; 116 Value i, j, buffer; 117 if (isCoo) { 118 Value ck = constantIndex(builder, loc, k); 119 i = builder.create<arith::AddIOp>(loc, ck, iOffset); 120 j = builder.create<arith::AddIOp>(loc, ck, jOffset); 121 buffer = args[xStartIdx]; 122 } else { 123 i = args[0]; 124 j = args[1]; 125 buffer = args[xStartIdx + k]; 126 } 127 bodyBuilder(k, i, j, buffer); 128 } 129 } 130 131 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting. 132 /// The code to process the value pairs is generated by `bodyBuilder`. 133 static void forEachIJPairInAllBuffers( 134 OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny, 135 bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { 136 137 // Create code for the first (nx + ny) buffers. When isCoo==true, these 138 // logical buffers are all from the xy buffer of the sort_coo operator. 139 forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder); 140 141 uint64_t numHandledBuffers = isCoo ? 1 : nx + ny; 142 143 // Create code for the remaining buffers. 144 Value i = args[0]; 145 Value j = args[1]; 146 for (const auto &arg : 147 llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) { 148 bodyBuilder(arg.index() + nx + ny, i, j, arg.value()); 149 } 150 } 151 152 /// Creates a code block for swapping the values in index i and j for all the 153 /// buffers. 
154 // 155 // The generated IR corresponds to this C like algorithm: 156 // swap(x0[i], x0[j]); 157 // swap(x1[i], x1[j]); 158 // ... 159 // swap(xn[i], xn[j]); 160 // swap(y0[i], y0[j]); 161 // ... 162 // swap(yn[i], yn[j]); 163 static void createSwap(OpBuilder &builder, Location loc, ValueRange args, 164 uint64_t nx, uint64_t ny, bool isCoo) { 165 auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) { 166 Value vi = builder.create<memref::LoadOp>(loc, buffer, i); 167 Value vj = builder.create<memref::LoadOp>(loc, buffer, j); 168 builder.create<memref::StoreOp>(loc, vj, buffer, i); 169 builder.create<memref::StoreOp>(loc, vi, buffer, j); 170 }; 171 172 forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair); 173 } 174 175 /// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to 176 /// compare each pair is create via `compareBuilder`. 177 static void createCompareFuncImplementation( 178 OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx, 179 uint64_t ny, bool isCoo, 180 function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)> 181 compareBuilder) { 182 OpBuilder::InsertionGuard insertionGuard(builder); 183 184 Block *entryBlock = func.addEntryBlock(); 185 builder.setInsertionPointToStart(entryBlock); 186 Location loc = func.getLoc(); 187 ValueRange args = entryBlock->getArguments(); 188 189 scf::IfOp topIfOp; 190 auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) { 191 scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1)); 192 if (k == 0) { 193 topIfOp = ifOp; 194 } else { 195 OpBuilder::InsertionGuard insertionGuard(builder); 196 builder.setInsertionPointAfter(ifOp); 197 builder.create<scf::YieldOp>(loc, ifOp.getResult(0)); 198 } 199 }; 200 201 forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder); 202 203 builder.setInsertionPointAfter(topIfOp); 204 builder.create<func::ReturnOp>(loc, topIfOp.getResult(0)); 205 } 206 207 /// 
Generates an if-statement to compare whether x[i] is equal to x[j]. 208 static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i, 209 Value j, Value x, bool isLastDim) { 210 Value f = constantI1(builder, loc, false); 211 Value t = constantI1(builder, loc, true); 212 Value vi = builder.create<memref::LoadOp>(loc, x, i); 213 Value vj = builder.create<memref::LoadOp>(loc, x, j); 214 215 Value cond = 216 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj); 217 scf::IfOp ifOp = 218 builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true); 219 220 // x[1] != x[j]: 221 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 222 builder.create<scf::YieldOp>(loc, f); 223 224 // x[i] == x[j]: 225 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 226 if (isLastDim == 1) { 227 // Finish checking all dimensions. 228 builder.create<scf::YieldOp>(loc, t); 229 } 230 231 return ifOp; 232 } 233 234 /// Creates a function to compare whether xs[i] is equal to xs[j]. 235 // 236 // The generate IR corresponds to this C like algorithm: 237 // if (x0[i] != x0[j]) 238 // return false; 239 // else 240 // if (x1[i] != x1[j]) 241 // return false; 242 // else if (x2[2] != x2[j])) 243 // and so on ... 244 static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused, 245 func::FuncOp func, uint64_t nx, uint64_t ny, 246 bool isCoo) { 247 createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo, 248 createEqCompare); 249 } 250 251 /// Generates an if-statement to compare whether x[i] is less than x[j]. 
252 static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc, 253 Value i, Value j, Value x, 254 bool isLastDim) { 255 Value f = constantI1(builder, loc, false); 256 Value t = constantI1(builder, loc, true); 257 Value vi = builder.create<memref::LoadOp>(loc, x, i); 258 Value vj = builder.create<memref::LoadOp>(loc, x, j); 259 260 Value cond = 261 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj); 262 scf::IfOp ifOp = 263 builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true); 264 // If (x[i] < x[j]). 265 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 266 builder.create<scf::YieldOp>(loc, t); 267 268 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 269 if (isLastDim == 1) { 270 // Finish checking all dimensions. 271 builder.create<scf::YieldOp>(loc, f); 272 } else { 273 cond = 274 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi); 275 scf::IfOp ifOp2 = 276 builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true); 277 // Otherwise if (x[j] < x[i]). 278 builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); 279 builder.create<scf::YieldOp>(loc, f); 280 281 // Otherwise check the remaining dimensions. 282 builder.setInsertionPointAfter(ifOp2); 283 builder.create<scf::YieldOp>(loc, ifOp2.getResult(0)); 284 // Set up the insertion point for the nested if-stmt that checks the 285 // remaining dimensions. 286 builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); 287 } 288 289 return ifOp; 290 } 291 292 /// Creates a function to compare whether xs[i] is less than xs[j]. 293 // 294 // The generate IR corresponds to this C like algorithm: 295 // if (x0[i] < x0[j]) 296 // return true; 297 // else if (x0[j] < x0[i]) 298 // return false; 299 // else 300 // if (x1[i] < x1[j]) 301 // return true; 302 // else if (x1[j] < x1[i])) 303 // and so on ... 
static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
                               func::FuncOp func, uint64_t nx, uint64_t ny,
                               bool isCoo) {
  createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
                                  createLessThanCompare);
}

/// Creates a function to use a binary search to find the insertion point for
/// inserting xs[hi] to the sorted values xs[lo..hi).
//
// The generate IR corresponds to this C like algorithm:
//   p = hi
//   while (lo < hi)
//     mid = (lo + hi) >> 1
//     if (xs[p] < xs[mid])
//       hi = mid
//     else
//       lo = mid + 1
//   return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
                                   func::FuncOp func, uint64_t nx, uint64_t ny,
                                   bool isCoo) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  // The value being inserted lives at index hi; it never moves during the
  // search, so p stays fixed.
  Value p = args[hiIdx];
  SmallVector<Type, 2> types(2, p.getType()); // only two
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});

  // The before-region of the WhileOp: continue while lo < hi.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(before);
  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                              before->getArgument(0),
                                              before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(after);
  Value lo = after->getArgument(0);
  Value hi = after->getArgument(1);
  // Compute mid = (lo + hi) >> 1.
  Value c1 = constantIndex(builder, loc, 1);
  Value mid = builder.create<arith::ShRUIOp>(
      loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
  Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);

  // Compare xs[p] < xs[mid] by calling (and lazily creating) the mangled
  // less-than helper for these buffer types.
  SmallVector<Value> compareOperands{p, mid};
  uint64_t numXBuffers = isCoo ? 1 : nx;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createLessThanFunc);
  Value cond2 = builder
                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
                                          compareOperands)
                    .getResult(0);

  // Update lo and hi for the WhileOp as follows:
  //   if (xs[p] < xs[mid]))
  //     hi = mid;
  //   else
  //     lo = mid + 1;
  Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
  Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
  builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});

  // Return the final lo, i.e. the insertion point.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
}

/// Creates code to advance i in a loop based on xs[p] as follows:
///   while (xs[i] < xs[p]) i += step (step > 0)
/// or
///   while (xs[i] > xs[p]) i += step (step < 0)
/// The routine returns i as well as a boolean value to indicate whether
/// xs[i] == xs[p].
static std::pair<Value, Value>
createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
               ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
               bool isCoo, int step) {
  Location loc = func.getLoc();
  scf::WhileOp whileOp =
      builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});

  // The before-region: keep looping while the strict comparison holds.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
  builder.setInsertionPointToEnd(before);
  // For a forward scan (step > 0) compare xs[i] < xs[p]; for a backward scan
  // (step < 0) compare xs[p] < xs[i], i.e. xs[i] > xs[p].
  SmallVector<Value> compareOperands;
  if (step > 0) {
    compareOperands.push_back(before->getArgument(0));
    compareOperands.push_back(p);
  } else {
    assert(step < 0);
    compareOperands.push_back(p);
    compareOperands.push_back(before->getArgument(0));
  }
  compareOperands.append(xs.begin(), xs.end());
  MLIRContext *context = module.getContext();
  Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createLessThanFunc);
  Value cond = builder
                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
                                         compareOperands)
                   .getResult(0);
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region: advance i by step and iterate again.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
  builder.setInsertionPointToEnd(after);
  Value cs = constantIndex(builder, loc, step);
  i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  i = whileOp.getResult(0);

  // After the loop, test xs[i] == xs[p]. Operand order doesn't matter for
  // equality, so slots 0/1 are simply overwritten with (i, p).
  builder.setInsertionPointAfter(whileOp);
  compareOperands[0] = i;
  compareOperands[1] = p;
  FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
      builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo,
      compareOperands, createEqCompareFunc);
  Value compareEq =
      builder
          .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
                                compareOperands)
          .getResult(0);

  return std::make_pair(whileOp.getResult(0), compareEq);
}

/// Creates a function to perform quick sort partition on the values in the
/// range of index [lo, hi), assuming lo < hi.
//
// The generated IR corresponds to this C like algorithm:
//   int partition(lo, hi, xs) {
//     p = (lo+hi)/2  // pivot index
//     i = lo
//     j = hi-1
//     while (i < j) do {
//       while (xs[i] < xs[p]) i ++;
//       i_eq = (xs[i] == xs[p]);
//       while (xs[j] > xs[p]) j --;
//       j_eq = (xs[j] == xs[p]);
//       if (i < j) {
//         swap(xs[i], xs[j])
//         if (i == p) {
//           p = j;
//         } else if (j == p) {
//           p = i;
//         }
//         if (i_eq && j_eq) {
//           ++i;
//           --j;
//         }
//       }
//     }
//     return p
//   }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo) {
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  // Pivot index p = (lo + hi) >> 1.
  Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
  Value c1 = constantIndex(builder, loc, 1);
  Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);

  Value i = lo;
  Value j = builder.create<arith::SubIOp>(loc, hi, c1);
  // The while loop carries (i, j, p) since swaps can relocate the pivot.
  SmallVector<Value, 3> operands{i, j, p}; // exactly three
  SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

  // The before-region of the WhileOp: continue while i < j.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
  builder.setInsertionPointToEnd(before);
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                             before->getArgument(0),
                                             before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
  builder.setInsertionPointToEnd(after);
  i = after->getArgument(0);
  j = after->getArgument(1);
  p = after->getArgument(2);

  // Scan i forward past values < pivot and j backward past values > pivot,
  // also recording whether the stopping element equals the pivot.
  uint64_t numXBuffers = isCoo ? 1 : nx;
  auto [iresult, iCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     i, p, nx, ny, isCoo, 1);
  i = iresult;
  auto [jresult, jCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     j, p, nx, ny, isCoo, -1);
  j = jresult;

  // If i < j:
  cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  SmallVector<Value> swapOperands{i, j};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  // If the pivot is moved, update p with the new pivot.
  Value icond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
  scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              icond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{j});
  builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
  Value jcond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
  scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              jcond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{p});
  builder.setInsertionPointAfter(ifOpJ);
  builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
  builder.setInsertionPointAfter(ifOpI);
  // If both stopping elements equal the pivot, step past both to guarantee
  // progress (avoids an infinite loop on runs of equal values).
  Value compareEqIJ =
      builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(
      loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
  Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
  Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
  builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j});
  builder.setInsertionPointAfter(ifOp2);
  builder.create<scf::YieldOp>(
      loc,
      ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});

  // False branch for if i < j:
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});

  // Return for the whileOp.
  builder.setInsertionPointAfter(ifOp);
  builder.create<scf::YieldOp>(loc, ifOp.getResults());

  // Return for the function: the final pivot index.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
}

/// Creates a function to perform quick sort on the value in the range of
/// index [lo, hi).
//
// The generate IR corresponds to this C like algorithm:
//   void quickSort(lo, hi, data) {
//     if (lo < hi) {
//       p = partition(low, high, data);
//       quickSort(lo, p, data);
//       quickSort(p + 1, hi, data);
//     }
//   }
static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
                                    func::FuncOp func, uint64_t nx, uint64_t ny,
                                    bool isCoo) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  MLIRContext *context = module.getContext();
  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);

  // The if-stmt true branch: partition, then recurse on both halves.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
      ny, isCoo, args, createPartitionFunc);
  auto p = builder.create<func::CallOp>(
      loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));

  // Recursive call for [lo, p).
  SmallVector<Value> lowOperands{lo, p.getResult(0)};
  lowOperands.append(args.begin() + xStartIdx, args.end());
  builder.create<func::CallOp>(loc, func, lowOperands);

  // Recursive call for [p + 1, hi).
  SmallVector<Value> highOperands{
      builder.create<arith::AddIOp>(loc, p.getResult(0),
                                    constantIndex(builder, loc, 1)),
      hi};
  highOperands.append(args.begin() + xStartIdx, args.end());
  builder.create<func::CallOp>(loc, func, highOperands);

  // After the if-stmt.
  builder.setInsertionPointAfter(ifOp);
  builder.create<func::ReturnOp>(loc);
}

/// Creates a function to perform insertion sort on the values in the range of
/// index [lo, hi).
//
// The generate IR corresponds to this C like algorithm:
//   void insertionSort(lo, hi, data) {
//     for (i = lo+1; i < hi; i++) {
//       d = data[i];
//       p = binarySearch(lo, i, data)
//       for (j = 0; j < i - p; j++)
//         data[i-j] = data[i-j-1]
//       data[p] = d
//     }
//   }
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
                                 func::FuncOp func, uint64_t nx, uint64_t ny,
                                 bool isCoo) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  MLIRContext *context = module.getContext();
  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value c1 = constantIndex(builder, loc, 1);
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);

  // Start the outer for-stmt with induction variable i.
  scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
  builder.setInsertionPointToStart(forOpI.getBody());
  Value i = forOpI.getInductionVar();

  // Binary search to find the insertion point p for the value at index i
  // within the already-sorted prefix [lo, i).
  SmallVector<Value> operands{lo, i};
  operands.append(args.begin() + xStartIdx, args.end());
  FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
      ny, isCoo, operands, createBinarySearchFunc);
  Value p = builder
                .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
                                      operands)
                .getResult(0);

  // Move the value at data[i] to a temporary location.
  operands[0] = operands[1] = i;
  SmallVector<Value> d;
  forEachIJPairInAllBuffers(
      builder, loc, operands, nx, ny, isCoo,
      [&](uint64_t unused, Value i, Value unused2, Value buffer) {
        d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
      });

  // Start the inner for-stmt with induction variable j, for moving data[p..i)
  // to data[p+1..i+1).
  Value imp = builder.create<arith::SubIOp>(loc, i, p);
  Value c0 = constantIndex(builder, loc, 0);
  scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
  builder.setInsertionPointToStart(forOpJ.getBody());
  Value j = forOpJ.getInductionVar();
  Value imj = builder.create<arith::SubIOp>(loc, i, j);
  // Shift data[i-j-1] into data[i-j] across all buffers.
  operands[1] = imj;
  operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
  forEachIJPairInAllBuffers(
      builder, loc, operands, nx, ny, isCoo,
      [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
        Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
        builder.create<memref::StoreOp>(loc, t, buffer, imj);
      });

  // Store the value at data[i] to data[p].
  builder.setInsertionPointAfter(forOpJ);
  operands[0] = operands[1] = p;
  forEachIJPairInAllBuffers(
      builder, loc, operands, nx, ny, isCoo,
      [&](uint64_t k, Value p, Value usused, Value buffer) {
        builder.create<memref::StoreOp>(loc, d[k], buffer, p);
      });

  builder.setInsertionPointAfter(forOpI);
  builder.create<func::ReturnOp>(loc);
}

/// Implements the rewriting for operator sort and sort_coo: replaces the op
/// with a call to a (lazily created) sort helper function over the range
/// [0, n) of the given buffers.
template <typename OpTy>
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
                                    uint64_t ny, bool isCoo,
                                    PatternRewriter &rewriter) {
  Location loc = op.getLoc();
  SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};

  // Convert `values` to have dynamic shape and append them to `operands`,
  // so one helper function can serve buffers of any static size.
  for (Value v : xys) {
    auto mtp = v.getType().cast<MemRefType>();
    if (!mtp.isDynamicDim(0)) {
      auto newMtp =
          MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
      v = rewriter.create<memref::CastOp>(loc, newMtp, v);
    }
    operands.push_back(v);
  }
  auto insertPoint = op->template getParentOfType<func::FuncOp>();
  // Select the stable (insertion sort) or non-stable (quick sort) flavor.
  SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
                                          : kSortNonstableFuncNamePrefix);
  FuncGeneratorType funcGenerator =
      op.getStable() ? createSortStableFunc : createSortNonstableFunc;
  FlatSymbolRefAttr func =
      getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
                               ny, isCoo, operands, funcGenerator);
  rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
  return success();
}

//===---------------------------------------------------------------------===//
// The actual sparse buffer rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for the push_back operator.
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
  using OpRewritePattern<PushBackOp>::OpRewritePattern;
  PushBackRewriter(MLIRContext *context, bool enableInit)
      : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
  LogicalResult matchAndRewrite(PushBackOp op,
                                PatternRewriter &rewriter) const override {
    // Rewrite push_back(buffer, value, n) to:
    // new_size = size(buffer) + n
    // if (new_size > capacity(buffer))
    //    while new_size > new_capacity
    //      new_capacity = new_capacity*2
    //    new_buffer = realloc(buffer, new_capacity)
    //    buffer = new_buffer
    //    subBuffer = subviewof(buffer)
    //    linalg.fill subBuffer value
    //
    // size(buffer) += n
    //
    // The capacity check is skipped when the attribute inbounds is presented.
    Location loc = op->getLoc();
    Value c0 = constantIndex(rewriter, loc, 0);
    Value buffer = op.getInBuffer();
    // Capacity is the (dynamic) extent of the buffer; the logical size is
    // tracked separately in the bufferSizes memref at position idx.
    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
    Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
    Value bufferSizes = op.getBufferSizes();
    Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
    Value value = op.getValue();

    // n defaults to 1 when absent.
    Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
    Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
    // Detect the common n == 1 case statically to emit simpler IR.
    auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
    bool nIsOne = (nValue && nValue.value() == 1);

    if (!op.getInbounds()) {
      Value cond = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ugt, newSize, capacity);

      Value c2 = constantIndex(rewriter, loc, 2);
      auto bufferType =
          MemRefType::get({ShapedType::kDynamic}, value.getType());
      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
                                                  /*else=*/true);
      // True branch: grow the capacity, then realloc.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (nIsOne) {
        // One doubling always suffices when appending a single element.
        capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
      } else {
        // Use a do-while loop to calculate the new capacity as follows:
        //   do { new_capacity *= 2 } while (size > new_capacity)
        scf::WhileOp whileOp =
            rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);

        // The before-region of the WhileOp.
        Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
                                             {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(before);

        capacity =
            rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
        cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                              newSize, capacity);
        rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
        // The after-region of the WhileOp.
        Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
                                            {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(after);
        rewriter.create<scf::YieldOp>(loc, after->getArguments());

        rewriter.setInsertionPointAfter(whileOp);
        capacity = whileOp.getResult(0);
      }

      Value newBuffer =
          rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
      if (enableBufferInitialization) {
        // Zero-initialize the grown tail [newSize, capacity).
        Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
        Value fillValue = constantZero(rewriter, loc, value.getType());
        Value subBuffer = rewriter.create<memref::SubViewOp>(
            loc, newBuffer, /*offset=*/ValueRange{newSize},
            /*size=*/ValueRange{fillSize},
            /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
        rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
      }
      rewriter.create<scf::YieldOp>(loc, newBuffer);

      // False branch: capacity suffices, keep the original buffer.
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      rewriter.create<scf::YieldOp>(loc, buffer);

      // Prepare for adding the value to the end of the buffer.
      rewriter.setInsertionPointAfter(ifOp);
      buffer = ifOp.getResult(0);
    }

    // Add the value to the end of the buffer.
    if (nIsOne) {
      rewriter.create<memref::StoreOp>(loc, value, buffer, size);
    } else {
      // Fill all n appended slots with the same value.
      Value subBuffer = rewriter.create<memref::SubViewOp>(
          loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
          /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
      rewriter.create<linalg::FillOp>(loc, value, subBuffer);
    }

    // Update the buffer size.
    rewriter.create<memref::StoreOp>(loc, newSize, bufferSizes, idx);
    rewriter.replaceOp(op, buffer);
    return success();
  }

private:
  // Whether newly allocated storage beyond the logical size is zero-filled.
  bool enableBufferInitialization;
};

/// Sparse rewriting rule for the sort operator.
865 struct SortRewriter : public OpRewritePattern<SortOp> { 866 public: 867 using OpRewritePattern<SortOp>::OpRewritePattern; 868 869 LogicalResult matchAndRewrite(SortOp op, 870 PatternRewriter &rewriter) const override { 871 SmallVector<Value> xys(op.getXs()); 872 xys.append(op.getYs().begin(), op.getYs().end()); 873 return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0, 874 /*isCoo=*/false, rewriter); 875 } 876 }; 877 878 /// Sparse rewriting rule for the sort_coo operator. 879 struct SortCooRewriter : public OpRewritePattern<SortCooOp> { 880 public: 881 using OpRewritePattern<SortCooOp>::OpRewritePattern; 882 883 LogicalResult matchAndRewrite(SortCooOp op, 884 PatternRewriter &rewriter) const override { 885 SmallVector<Value> xys; 886 xys.push_back(op.getXy()); 887 xys.append(op.getYs().begin(), op.getYs().end()); 888 uint64_t nx = 1; 889 if (auto nxAttr = op.getNxAttr()) 890 nx = nxAttr.getInt(); 891 892 uint64_t ny = 0; 893 if (auto nyAttr = op.getNyAttr()) 894 ny = nyAttr.getInt(); 895 896 return matchAndRewriteSortOp(op, xys, nx, ny, 897 /*isCoo=*/true, rewriter); 898 } 899 }; 900 901 } // namespace 902 903 //===---------------------------------------------------------------------===// 904 // Methods that add patterns described in this file to a pattern list. 905 //===---------------------------------------------------------------------===// 906 907 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns, 908 bool enableBufferInitialization) { 909 patterns.add<PushBackRewriter>(patterns.getContext(), 910 enableBufferInitialization); 911 patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext()); 912 } 913