1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements rewriting rules that are specific to sparse tensor 10 // primitives with memref operands. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodegenUtils.h" 15 16 #include "mlir/Dialect/Arith/IR/Arith.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/SCF/IR/SCF.h" 21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 22 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 23 #include "mlir/Support/LLVM.h" 24 25 using namespace mlir; 26 using namespace mlir::sparse_tensor; 27 28 //===---------------------------------------------------------------------===// 29 // Helper methods for the actual rewriting rules. 30 //===---------------------------------------------------------------------===// 31 32 static constexpr uint64_t loIdx = 0; 33 static constexpr uint64_t hiIdx = 1; 34 static constexpr uint64_t xStartIdx = 2; 35 36 static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_"; 37 static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_"; 38 static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_"; 39 static constexpr const char kBinarySearchFuncNamePrefix[] = 40 "_sparse_binary_search_"; 41 static constexpr const char kSortNonstableFuncNamePrefix[] = 42 "_sparse_sort_nonstable_"; 43 static constexpr const char kSortStableFuncNamePrefix[] = 44 "_sparse_sort_stable_"; 45 46 using FuncGeneratorType = 47 function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, size_t)>; 48 49 /// Constructs a function name with this format to facilitate quick sort: 50 /// <namePrefix><dim>_<x type>_<y0 type>..._<yn type> 51 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, 52 StringRef namePrefix, size_t dim, 53 ValueRange operands) { 54 nameOstream 55 << namePrefix << dim << "_" 56 << operands[xStartIdx].getType().cast<MemRefType>().getElementType(); 57 58 for (Value v : operands.drop_front(xStartIdx + dim)) 59 nameOstream << "_" << v.getType().cast<MemRefType>().getElementType(); 60 } 61 62 /// Looks up a function that is appropriate for the given operands being 63 /// sorted, and creates such a function if it doesn't exist yet. 64 static FlatSymbolRefAttr 65 getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint, 66 TypeRange resultTypes, StringRef namePrefix, 67 size_t dim, ValueRange operands, 68 FuncGeneratorType createFunc) { 69 SmallString<32> nameBuffer; 70 llvm::raw_svector_ostream nameOstream(nameBuffer); 71 getMangledSortHelperFuncName(nameOstream, namePrefix, dim, operands); 72 73 ModuleOp module = insertPoint->getParentOfType<ModuleOp>(); 74 MLIRContext *context = module.getContext(); 75 auto result = SymbolRefAttr::get(context, nameOstream.str()); 76 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr()); 77 78 if (!func) { 79 // Create the function. 80 OpBuilder::InsertionGuard insertionGuard(builder); 81 builder.setInsertionPoint(insertPoint); 82 Location loc = insertPoint.getLoc(); 83 func = builder.create<func::FuncOp>( 84 loc, nameOstream.str(), 85 FunctionType::get(context, operands.getTypes(), resultTypes)); 86 func.setPrivate(); 87 createFunc(builder, module, func, dim); 88 } 89 90 return result; 91 } 92 93 /// Creates a code block for swapping the values in index i and j for all the 94 /// buffers. 95 // 96 // The generated IR corresponds to this C like algorithm: 97 // swap(x0[i], x0[j]); 98 // swap(x1[i], x1[j]); 99 // ... 100 // swap(xn[i], xn[j]); 101 // swap(y0[i], y0[j]); 102 // ... 103 // swap(yn[i], yn[j]); 104 static void createSwap(OpBuilder &builder, Location loc, ValueRange args) { 105 Value i = args[0]; 106 Value j = args[1]; 107 for (auto arg : args.drop_front(xStartIdx)) { 108 Value vi = builder.create<memref::LoadOp>(loc, arg, i); 109 Value vj = builder.create<memref::LoadOp>(loc, arg, j); 110 builder.create<memref::StoreOp>(loc, vj, arg, i); 111 builder.create<memref::StoreOp>(loc, vi, arg, j); 112 } 113 } 114 115 /// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to 116 /// compare each pair is create via `compareBuilder`. 117 static void createCompareFuncImplementation( 118 OpBuilder &builder, ModuleOp unused, func::FuncOp func, size_t dim, 119 function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)> 120 compareBuilder) { 121 OpBuilder::InsertionGuard insertionGuard(builder); 122 123 Block *entryBlock = func.addEntryBlock(); 124 builder.setInsertionPointToStart(entryBlock); 125 Location loc = func.getLoc(); 126 ValueRange args = entryBlock->getArguments(); 127 128 scf::IfOp topIfOp; 129 for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) { 130 scf::IfOp ifOp = compareBuilder(builder, loc, args[0], args[1], 131 item.value(), (item.index() == dim - 1)); 132 if (item.index() == 0) { 133 topIfOp = ifOp; 134 } else { 135 OpBuilder::InsertionGuard insertionGuard(builder); 136 builder.setInsertionPointAfter(ifOp); 137 builder.create<scf::YieldOp>(loc, ifOp.getResult(0)); 138 } 139 } 140 141 builder.setInsertionPointAfter(topIfOp); 142 builder.create<func::ReturnOp>(loc, topIfOp.getResult(0)); 143 } 144 145 /// Generates an if-statement to compare whether x[i] is equal to x[j]. 146 static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i, 147 Value j, Value x, bool isLastDim) { 148 Value f = constantI1(builder, loc, false); 149 Value t = constantI1(builder, loc, true); 150 Value vi = builder.create<memref::LoadOp>(loc, x, i); 151 Value vj = builder.create<memref::LoadOp>(loc, x, j); 152 153 Value cond = 154 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj); 155 scf::IfOp ifOp = 156 builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true); 157 158 // x[1] != x[j]: 159 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 160 builder.create<scf::YieldOp>(loc, f); 161 162 // x[i] == x[j]: 163 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 164 if (isLastDim == 1) { 165 // Finish checking all dimensions. 166 builder.create<scf::YieldOp>(loc, t); 167 } 168 169 return ifOp; 170 } 171 172 /// Creates a function to compare whether xs[i] is equal to xs[j]. 173 // 174 // The generate IR corresponds to this C like algorithm: 175 // if (x0[i] != x0[j]) 176 // return false; 177 // else 178 // if (x1[i] != x1[j]) 179 // return false; 180 // else if (x2[2] != x2[j])) 181 // and so on ... 182 static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused, 183 func::FuncOp func, size_t dim) { 184 createCompareFuncImplementation(builder, unused, func, dim, createEqCompare); 185 } 186 187 /// Generates an if-statement to compare whether x[i] is less than x[j]. 188 static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc, 189 Value i, Value j, Value x, 190 bool isLastDim) { 191 Value f = constantI1(builder, loc, false); 192 Value t = constantI1(builder, loc, true); 193 Value vi = builder.create<memref::LoadOp>(loc, x, i); 194 Value vj = builder.create<memref::LoadOp>(loc, x, j); 195 196 Value cond = 197 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj); 198 scf::IfOp ifOp = 199 builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true); 200 // If (x[i] < x[j]). 201 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 202 builder.create<scf::YieldOp>(loc, t); 203 204 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 205 if (isLastDim == 1) { 206 // Finish checking all dimensions. 207 builder.create<scf::YieldOp>(loc, f); 208 } else { 209 cond = 210 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi); 211 scf::IfOp ifOp2 = 212 builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true); 213 // Otherwise if (x[j] < x[i]). 214 builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); 215 builder.create<scf::YieldOp>(loc, f); 216 217 // Otherwise check the remaining dimensions. 218 builder.setInsertionPointAfter(ifOp2); 219 builder.create<scf::YieldOp>(loc, ifOp2.getResult(0)); 220 // Set up the insertion point for the nested if-stmt that checks the 221 // remaining dimensions. 222 builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); 223 } 224 225 return ifOp; 226 } 227 228 /// Creates a function to compare whether xs[i] is less than xs[j]. 229 // 230 // The generate IR corresponds to this C like algorithm: 231 // if (x0[i] < x0[j]) 232 // return true; 233 // else if (x0[j] < x0[i]) 234 // return false; 235 // else 236 // if (x1[i] < x1[j]) 237 // return true; 238 // else if (x1[j] < x1[i])) 239 // and so on ... 240 static void createLessThanFunc(OpBuilder &builder, ModuleOp unused, 241 func::FuncOp func, size_t dim) { 242 createCompareFuncImplementation(builder, unused, func, dim, 243 createLessThanCompare); 244 } 245 246 /// Creates a function to use a binary search to find the insertion point for 247 /// inserting xs[hi] to the sorted values xs[lo..hi). 248 // 249 // The generate IR corresponds to this C like algorithm: 250 // p = hi 251 // while (lo < hi) 252 // mid = (lo + hi) >> 1 253 // if (xs[p] < xs[mid]) 254 // hi = mid 255 // else 256 // lo = mid - 1 257 // return lo; 258 // 259 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, 260 func::FuncOp func, size_t dim) { 261 OpBuilder::InsertionGuard insertionGuard(builder); 262 Block *entryBlock = func.addEntryBlock(); 263 builder.setInsertionPointToStart(entryBlock); 264 265 Location loc = func.getLoc(); 266 ValueRange args = entryBlock->getArguments(); 267 Value p = args[hiIdx]; 268 SmallVector<Type, 2> types(2, p.getType()); 269 scf::WhileOp whileOp = builder.create<scf::WhileOp>( 270 loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]}); 271 272 // The before-region of the WhileOp. 273 Block *before = 274 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); 275 builder.setInsertionPointToEnd(before); 276 Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 277 before->getArgument(0), 278 before->getArgument(1)); 279 builder.create<scf::ConditionOp>(loc, cond1, before->getArguments()); 280 281 // The after-region of the WhileOp. 282 Block *after = 283 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); 284 builder.setInsertionPointToEnd(after); 285 Value lo = after->getArgument(0); 286 Value hi = after->getArgument(1); 287 // Compute mid = (lo + hi) >> 1. 288 Value c1 = constantIndex(builder, loc, 1); 289 Value mid = builder.create<arith::ShRUIOp>( 290 loc, builder.create<arith::AddIOp>(loc, lo, hi), c1); 291 Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1); 292 293 // Compare xs[p] < xs[mid]. 294 SmallVector<Value, 6> compareOperands{p, mid}; 295 compareOperands.append(args.begin() + xStartIdx, 296 args.begin() + xStartIdx + dim); 297 Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless); 298 FlatSymbolRefAttr lessThanFunc = 299 getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix, 300 dim, compareOperands, createLessThanFunc); 301 Value cond2 = builder 302 .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type}, 303 compareOperands) 304 .getResult(0); 305 306 // Update lo and hi for the WhileOp as follows: 307 // if (xs[p] < xs[mid])) 308 // hi = mid; 309 // else 310 // lo = mid + 1; 311 Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1); 312 Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi); 313 builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi}); 314 315 builder.setInsertionPointAfter(whileOp); 316 builder.create<func::ReturnOp>(loc, whileOp.getResult(0)); 317 } 318 319 /// Creates code to advance i in a loop based on xs[p] as follows: 320 /// while (xs[i] < xs[p]) i += step (step > 0) 321 /// or 322 /// while (xs[i] > xs[p]) i += step (step < 0) 323 /// The routine returns i as well as a boolean value to indicate whether 324 /// xs[i] == xs[p]. 325 static std::pair<Value, Value> 326 createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, 327 ValueRange xs, Value i, Value p, size_t dim, int step) { 328 Location loc = func.getLoc(); 329 scf::WhileOp whileOp = 330 builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i}); 331 332 Block *before = 333 builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc}); 334 builder.setInsertionPointToEnd(before); 335 SmallVector<Value, 6> compareOperands; 336 if (step > 0) { 337 compareOperands.push_back(before->getArgument(0)); 338 compareOperands.push_back(p); 339 } else { 340 assert(step < 0); 341 compareOperands.push_back(p); 342 compareOperands.push_back(before->getArgument(0)); 343 } 344 compareOperands.append(xs.begin(), xs.end()); 345 MLIRContext *context = module.getContext(); 346 Type i1Type = IntegerType::get(context, 1, IntegerType::Signless); 347 FlatSymbolRefAttr lessThanFunc = 348 getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix, 349 dim, compareOperands, createLessThanFunc); 350 Value cond = builder 351 .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type}, 352 compareOperands) 353 .getResult(0); 354 builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); 355 356 Block *after = 357 builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc}); 358 builder.setInsertionPointToEnd(after); 359 Value cs = constantIndex(builder, loc, step); 360 i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs); 361 builder.create<scf::YieldOp>(loc, ValueRange{i}); 362 i = whileOp.getResult(0); 363 364 builder.setInsertionPointAfter(whileOp); 365 compareOperands[0] = i; 366 compareOperands[1] = p; 367 FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc( 368 builder, func, {i1Type}, kCompareEqFuncNamePrefix, dim, compareOperands, 369 createEqCompareFunc); 370 Value compareEq = 371 builder 372 .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type}, 373 compareOperands) 374 .getResult(0); 375 376 return std::make_pair(whileOp.getResult(0), compareEq); 377 } 378 379 /// Creates a function to perform quick sort partition on the values in the 380 /// range of index [lo, hi), assuming lo < hi. 381 // 382 // The generated IR corresponds to this C like algorithm: 383 // int partition(lo, hi, xs) { 384 // p = (lo+hi)/2 // pivot index 385 // i = lo 386 // j = hi-1 387 // while (i < j) do { 388 // while (xs[i] < xs[p]) i ++; 389 // i_eq = (xs[i] == xs[p]); 390 // while (xs[j] > xs[p]) j --; 391 // j_eq = (xs[j] == xs[p]); 392 // if (i < j) { 393 // swap(xs[i], xs[j]) 394 // if (i == p) { 395 // p = j; 396 // } else if (j == p) { 397 // p = i; 398 // } 399 // if (i_eq && j_eq) { 400 // ++i; 401 // --j; 402 // } 403 // } 404 // } 405 // return p 406 // } 407 static void createPartitionFunc(OpBuilder &builder, ModuleOp module, 408 func::FuncOp func, size_t dim) { 409 OpBuilder::InsertionGuard insertionGuard(builder); 410 411 Block *entryBlock = func.addEntryBlock(); 412 builder.setInsertionPointToStart(entryBlock); 413 414 Location loc = func.getLoc(); 415 ValueRange args = entryBlock->getArguments(); 416 Value lo = args[loIdx]; 417 Value hi = args[hiIdx]; 418 Value sum = builder.create<arith::AddIOp>(loc, lo, hi); 419 Value c1 = constantIndex(builder, loc, 1); 420 Value p = builder.create<arith::ShRUIOp>(loc, sum, c1); 421 422 Value i = lo; 423 Value j = builder.create<arith::SubIOp>(loc, hi, c1); 424 SmallVector<Value, 4> operands{i, j, p}; 425 SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType()}; 426 scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands); 427 428 // The before-region of the WhileOp. 429 Block *before = 430 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc}); 431 builder.setInsertionPointToEnd(before); 432 Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 433 before->getArgument(0), 434 before->getArgument(1)); 435 builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); 436 437 // The after-region of the WhileOp. 438 Block *after = 439 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc}); 440 builder.setInsertionPointToEnd(after); 441 i = after->getArgument(0); 442 j = after->getArgument(1); 443 p = after->getArgument(2); 444 445 auto [iresult, iCompareEq] = createScanLoop( 446 builder, module, func, args.slice(xStartIdx, dim), i, p, dim, 1); 447 i = iresult; 448 auto [jresult, jCompareEq] = createScanLoop( 449 builder, module, func, args.slice(xStartIdx, dim), j, p, dim, -1); 450 j = jresult; 451 452 // If i < j: 453 cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j); 454 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); 455 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 456 SmallVector<Value, 6> swapOperands{i, j}; 457 swapOperands.append(args.begin() + xStartIdx, args.end()); 458 createSwap(builder, loc, swapOperands); 459 // If the pivot is moved, update p with the new pivot. 460 Value icond = 461 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p); 462 scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()}, 463 icond, /*else=*/true); 464 builder.setInsertionPointToStart(&ifOpI.getThenRegion().front()); 465 builder.create<scf::YieldOp>(loc, ValueRange{j}); 466 builder.setInsertionPointToStart(&ifOpI.getElseRegion().front()); 467 Value jcond = 468 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p); 469 scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()}, 470 jcond, /*else=*/true); 471 builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front()); 472 builder.create<scf::YieldOp>(loc, ValueRange{i}); 473 builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front()); 474 builder.create<scf::YieldOp>(loc, ValueRange{p}); 475 builder.setInsertionPointAfter(ifOpJ); 476 builder.create<scf::YieldOp>(loc, ifOpJ.getResults()); 477 builder.setInsertionPointAfter(ifOpI); 478 Value compareEqIJ = 479 builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq); 480 scf::IfOp ifOp2 = builder.create<scf::IfOp>( 481 loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true); 482 builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); 483 Value i2 = builder.create<arith::AddIOp>(loc, i, c1); 484 Value j2 = builder.create<arith::SubIOp>(loc, j, c1); 485 builder.create<scf::YieldOp>(loc, ValueRange{i2, j2}); 486 builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); 487 builder.create<scf::YieldOp>(loc, ValueRange{i, j}); 488 builder.setInsertionPointAfter(ifOp2); 489 builder.create<scf::YieldOp>( 490 loc, 491 ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)}); 492 493 // False branch for if i < j: 494 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 495 builder.create<scf::YieldOp>(loc, ValueRange{i, j, p}); 496 497 // Return for the whileOp. 498 builder.setInsertionPointAfter(ifOp); 499 builder.create<scf::YieldOp>(loc, ifOp.getResults()); 500 501 // Return for the function. 502 builder.setInsertionPointAfter(whileOp); 503 builder.create<func::ReturnOp>(loc, whileOp.getResult(2)); 504 } 505 506 /// Creates a function to perform quick sort on the value in the range of 507 /// index [lo, hi). 508 // 509 // The generate IR corresponds to this C like algorithm: 510 // void quickSort(lo, hi, data) { 511 // if (lo < hi) { 512 // p = partition(low, high, data); 513 // quickSort(lo, p, data); 514 // quickSort(p + 1, hi, data); 515 // } 516 // } 517 static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module, 518 func::FuncOp func, size_t dim) { 519 OpBuilder::InsertionGuard insertionGuard(builder); 520 Block *entryBlock = func.addEntryBlock(); 521 builder.setInsertionPointToStart(entryBlock); 522 523 MLIRContext *context = module.getContext(); 524 Location loc = func.getLoc(); 525 ValueRange args = entryBlock->getArguments(); 526 Value lo = args[loIdx]; 527 Value hi = args[hiIdx]; 528 Value cond = 529 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi); 530 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); 531 532 // The if-stmt true branch. 533 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 534 FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( 535 builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, dim, 536 args, createPartitionFunc); 537 auto p = builder.create<func::CallOp>( 538 loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args)); 539 540 SmallVector<Value, 6> lowOperands{lo, p.getResult(0)}; 541 lowOperands.append(args.begin() + xStartIdx, args.end()); 542 builder.create<func::CallOp>(loc, func, lowOperands); 543 544 SmallVector<Value, 6> highOperands{ 545 builder.create<arith::AddIOp>(loc, p.getResult(0), 546 constantIndex(builder, loc, 1)), 547 hi}; 548 highOperands.append(args.begin() + xStartIdx, args.end()); 549 builder.create<func::CallOp>(loc, func, highOperands); 550 551 // After the if-stmt. 552 builder.setInsertionPointAfter(ifOp); 553 builder.create<func::ReturnOp>(loc); 554 } 555 556 /// Creates a function to perform insertion sort on the values in the range of 557 /// index [lo, hi). 558 // 559 // The generate IR corresponds to this C like algorithm: 560 // void insertionSort(lo, hi, data) { 561 // for (i = lo+1; i < hi; i++) { 562 // d = data[i]; 563 // p = binarySearch(lo, i-1, data) 564 // for (j = 0; j > i - p; j++) 565 // data[i-j] = data[i-j-1] 566 // data[p] = d 567 // } 568 // } 569 static void createSortStableFunc(OpBuilder &builder, ModuleOp module, 570 func::FuncOp func, size_t dim) { 571 OpBuilder::InsertionGuard insertionGuard(builder); 572 Block *entryBlock = func.addEntryBlock(); 573 builder.setInsertionPointToStart(entryBlock); 574 575 MLIRContext *context = module.getContext(); 576 Location loc = func.getLoc(); 577 ValueRange args = entryBlock->getArguments(); 578 Value c1 = constantIndex(builder, loc, 1); 579 Value lo = args[loIdx]; 580 Value hi = args[hiIdx]; 581 Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1); 582 583 // Start the outer for-stmt with induction variable i. 584 scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1); 585 builder.setInsertionPointToStart(forOpI.getBody()); 586 Value i = forOpI.getInductionVar(); 587 588 // Binary search to find the insertion point p. 589 SmallVector<Value, 6> operands{lo, i}; 590 operands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + dim); 591 FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( 592 builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, 593 dim, operands, createBinarySearchFunc); 594 Value p = builder 595 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()}, 596 operands) 597 .getResult(0); 598 599 // Move the value at data[i] to a temporary location. 600 ValueRange data = args.drop_front(xStartIdx); 601 SmallVector<Value, 6> d; 602 for (Value v : data) 603 d.push_back(builder.create<memref::LoadOp>(loc, v, i)); 604 605 // Start the inner for-stmt with induction variable j, for moving data[p..i) 606 // to data[p+1..i+1). 607 Value imp = builder.create<arith::SubIOp>(loc, i, p); 608 Value c0 = constantIndex(builder, loc, 0); 609 scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1); 610 builder.setInsertionPointToStart(forOpJ.getBody()); 611 Value j = forOpJ.getInductionVar(); 612 Value imj = builder.create<arith::SubIOp>(loc, i, j); 613 Value imjm1 = builder.create<arith::SubIOp>(loc, imj, c1); 614 for (Value v : data) { 615 Value t = builder.create<memref::LoadOp>(loc, v, imjm1); 616 builder.create<memref::StoreOp>(loc, t, v, imj); 617 } 618 619 // Store the value at data[i] to data[p]. 620 builder.setInsertionPointAfter(forOpJ); 621 for (auto it : llvm::zip(d, data)) 622 builder.create<memref::StoreOp>(loc, std::get<0>(it), std::get<1>(it), p); 623 624 builder.setInsertionPointAfter(forOpI); 625 builder.create<func::ReturnOp>(loc); 626 } 627 628 //===---------------------------------------------------------------------===// 629 // The actual sparse buffer rewriting rules. 630 //===---------------------------------------------------------------------===// 631 632 namespace { 633 634 /// Sparse rewriting rule for the push_back operator. 635 struct PushBackRewriter : OpRewritePattern<PushBackOp> { 636 public: 637 using OpRewritePattern<PushBackOp>::OpRewritePattern; 638 PushBackRewriter(MLIRContext *context, bool enableInit) 639 : OpRewritePattern(context), enableBufferInitialization(enableInit) {} 640 LogicalResult matchAndRewrite(PushBackOp op, 641 PatternRewriter &rewriter) const override { 642 // Rewrite push_back(buffer, value, n) to: 643 // new_size = size(buffer) + n 644 // if (new_size > capacity(buffer)) 645 // while new_size > new_capacity 646 // new_capacity = new_capacity*2 647 // new_buffer = realloc(buffer, new_capacity) 648 // buffer = new_buffer 649 // subBuffer = subviewof(buffer) 650 // linalg.fill subBuffer value 651 // 652 // size(buffer) += n 653 // 654 // The capacity check is skipped when the attribute inbounds is presented. 655 Location loc = op->getLoc(); 656 Value c0 = constantIndex(rewriter, loc, 0); 657 Value buffer = op.getInBuffer(); 658 Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0); 659 Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue()); 660 Value bufferSizes = op.getBufferSizes(); 661 Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx); 662 Value value = op.getValue(); 663 664 Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1); 665 Value newSize = rewriter.create<arith::AddIOp>(loc, size, n); 666 auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp()); 667 bool nIsOne = (nValue && nValue.value() == 1); 668 669 if (!op.getInbounds()) { 670 Value cond = rewriter.create<arith::CmpIOp>( 671 loc, arith::CmpIPredicate::ugt, newSize, capacity); 672 673 Value c2 = constantIndex(rewriter, loc, 2); 674 auto bufferType = 675 MemRefType::get({ShapedType::kDynamicSize}, value.getType()); 676 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond, 677 /*else=*/true); 678 // True branch. 679 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); 680 if (nIsOne) { 681 capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2); 682 } else { 683 // Use a do-while loop to calculate the new capacity as follows: 684 // do { new_capacity *= 2 } while (size > new_capacity) 685 scf::WhileOp whileOp = 686 rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity); 687 688 // The before-region of the WhileOp. 689 Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, 690 {capacity.getType()}, {loc}); 691 rewriter.setInsertionPointToEnd(before); 692 693 capacity = 694 rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2); 695 cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, 696 newSize, capacity); 697 rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity}); 698 // The after-region of the WhileOp. 699 Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, 700 {capacity.getType()}, {loc}); 701 rewriter.setInsertionPointToEnd(after); 702 rewriter.create<scf::YieldOp>(loc, after->getArguments()); 703 704 rewriter.setInsertionPointAfter(whileOp); 705 capacity = whileOp.getResult(0); 706 } 707 708 Value newBuffer = 709 rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity); 710 if (enableBufferInitialization) { 711 Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize); 712 Value fillValue = rewriter.create<arith::ConstantOp>( 713 loc, value.getType(), rewriter.getZeroAttr(value.getType())); 714 Value subBuffer = rewriter.create<memref::SubViewOp>( 715 loc, newBuffer, /*offset=*/ValueRange{newSize}, 716 /*size=*/ValueRange{fillSize}, 717 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); 718 rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer); 719 } 720 rewriter.create<scf::YieldOp>(loc, newBuffer); 721 722 // False branch. 723 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); 724 rewriter.create<scf::YieldOp>(loc, buffer); 725 726 // Prepare for adding the value to the end of the buffer. 727 rewriter.setInsertionPointAfter(ifOp); 728 buffer = ifOp.getResult(0); 729 } 730 731 // Add the value to the end of the buffer. 732 if (nIsOne) { 733 rewriter.create<memref::StoreOp>(loc, value, buffer, size); 734 } else { 735 Value subBuffer = rewriter.create<memref::SubViewOp>( 736 loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n}, 737 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); 738 rewriter.create<linalg::FillOp>(loc, value, subBuffer); 739 } 740 741 // Update the buffer size. 742 rewriter.create<memref::StoreOp>(loc, newSize, bufferSizes, idx); 743 rewriter.replaceOp(op, buffer); 744 return success(); 745 } 746 747 private: 748 bool enableBufferInitialization; 749 }; 750 751 /// Sparse rewriting rule for the sort operator. 752 struct SortRewriter : public OpRewritePattern<SortOp> { 753 public: 754 using OpRewritePattern<SortOp>::OpRewritePattern; 755 756 LogicalResult matchAndRewrite(SortOp op, 757 PatternRewriter &rewriter) const override { 758 Location loc = op.getLoc(); 759 SmallVector<Value, 6> operands{constantIndex(rewriter, loc, 0), op.getN()}; 760 761 // Convert `values` to have dynamic shape and append them to `operands`. 762 auto addValues = [&](ValueRange values) { 763 for (Value v : values) { 764 auto mtp = v.getType().cast<MemRefType>(); 765 if (!mtp.isDynamicDim(0)) { 766 auto newMtp = 767 MemRefType::get({ShapedType::kDynamicSize}, mtp.getElementType()); 768 v = rewriter.create<memref::CastOp>(loc, newMtp, v); 769 } 770 operands.push_back(v); 771 } 772 }; 773 ValueRange xs = op.getXs(); 774 addValues(xs); 775 addValues(op.getYs()); 776 auto insertPoint = op->getParentOfType<func::FuncOp>(); 777 SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix 778 : kSortNonstableFuncNamePrefix); 779 FuncGeneratorType funcGenerator = 780 op.getStable() ? createSortStableFunc : createSortNonstableFunc; 781 FlatSymbolRefAttr func = 782 getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, 783 xs.size(), operands, funcGenerator); 784 rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands); 785 return success(); 786 } 787 }; 788 789 } // namespace 790 791 //===---------------------------------------------------------------------===// 792 // Methods that add patterns described in this file to a pattern list. 793 //===---------------------------------------------------------------------===// 794 795 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns, 796 bool enableBufferInitialization) { 797 patterns.add<PushBackRewriter>(patterns.getContext(), 798 enableBufferInitialization); 799 patterns.add<SortRewriter>(patterns.getContext()); 800 } 801