//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensor
// primitives with memref operands.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Support/LLVM.h"

#include <tuple>
#include <utility>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===// static constexpr uint64_t loIdx = 0; static constexpr uint64_t hiIdx = 1; static constexpr uint64_t xStartIdx = 2; static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_"; static constexpr const char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_"; static constexpr const char kHybridQuickSortFuncNamePrefix[] = "_sparse_hybrid_qsort_"; static constexpr const char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_"; static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_"; static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_"; static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_"; using FuncGeneratorType = function_ref; /// Constructs a function name with this format to facilitate quick sort: /// __..._ for sort /// __coo__..._ for sort_coo static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands) { nameOstream << namePrefix; for (auto res : xPerm.getResults()) nameOstream << cast(res).getPosition() << "_"; nameOstream << getMemRefType(operands[xStartIdx]).getElementType(); nameOstream << "_coo_" << ny; constexpr uint64_t yBufferOffset = 1; for (Value v : operands.drop_front(xStartIdx + yBufferOffset)) nameOstream << "_" << getMemRefType(v).getElementType(); } /// Looks up a function that is appropriate for the given operands being /// sorted, and creates such a function if it doesn't exist yet. The /// parameters `xPerm` and `ny` tell the number of x and y values provided /// by the buffer in xStartIdx. // // All sorting function generators take (lo, hi, xs, ys) in `operands` as // parameters for the sorting functions. Other parameters, such as the recursive // call depth, are appended to the end of the parameter list as // "trailing parameters". 
static FlatSymbolRefAttr getMangledSortHelperFunc( OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, FuncGeneratorType createFunc, uint32_t nTrailingP = 0) { SmallString<32> nameBuffer; llvm::raw_svector_ostream nameOstream(nameBuffer); getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny, operands.drop_back(nTrailingP)); ModuleOp module = insertPoint->getParentOfType(); MLIRContext *context = module.getContext(); auto result = SymbolRefAttr::get(context, nameOstream.str()); auto func = module.lookupSymbol(result.getAttr()); if (!func) { // Create the function. OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPoint(insertPoint); Location loc = insertPoint.getLoc(); func = builder.create( loc, nameOstream.str(), FunctionType::get(context, operands.getTypes(), resultTypes)); func.setPrivate(); createFunc(builder, module, func, xPerm, ny, nTrailingP); } return result; } /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting. /// The code to process the value pairs is generated by `bodyBuilder`. static void forEachIJPairInXs( OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref bodyBuilder) { Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny); Value iOffset = builder.create(loc, args[0], cstep); Value jOffset = builder.create(loc, args[1], cstep); for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) { unsigned actualK = cast(xPerm.getResult(k)).getPosition(); Value ak = constantIndex(builder, loc, actualK); Value i = builder.create(loc, ak, iOffset); Value j = builder.create(loc, ak, jOffset); Value buffer = args[xStartIdx]; bodyBuilder(k, i, j, buffer); } } /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting. /// The code to process the value pairs is generated by `bodyBuilder`. 
static void forEachIJPairInAllBuffers( OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref bodyBuilder) { // Create code for the first (xPerm + ny) buffers. SmallVector exps(xPerm.getResults()); for (unsigned y = 0; y < ny; y++) { exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults())); } AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext()); assert(xyPerm.isPermutation()); forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder); constexpr uint64_t numHandledBuffers = 1; // Create code for the remaining buffers. Value i = args[0]; Value j = args[1]; for (const auto &arg : llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) { bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value()); } } /// Creates a code block for swapping the values in index i and j for all the /// buffers. // // The generated IR corresponds to this C like algorithm: // swap(x0[i], x0[j]); // swap(x1[i], x1[j]); // ... // swap(xn[i], xn[j]); // swap(y0[i], y0[j]); // ... // swap(yn[i], yn[j]); static void createSwap(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny) { auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) { Value vi = builder.create(loc, buffer, i); Value vj = builder.create(loc, buffer, j); builder.create(loc, vj, buffer, i); builder.create(loc, vi, buffer, j); }; forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair); } /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare /// each pair is create via `compareBuilder`. 
static Value createInlinedCompareImplementation(
    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
    uint64_t ny,
    function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
        compareBuilder) {
  Value result;
  auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
    bool isFirstDim = (k == 0);
    bool isLastDim = (k == xPerm.getNumResults() - 1);
    Value val =
        compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
    if (isFirstDim) {
      // The outermost comparison produces the overall result.
      result = val;
    } else if (!isLastDim) {
      // An inner if-stmt: yield its result to the enclosing if-stmt. The guard
      // keeps the insertion point inside the else-region set up by
      // `compareBuilder` for the next dimension.
      OpBuilder::InsertionGuard insertionGuard(builder);
      auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
      builder.setInsertionPointAfter(ifOp);
      builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
    }
  };

  forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);

  // Continue emitting code after the outermost comparison.
  builder.setInsertionPointAfterValue(result);
  return result;
}

/// Generates code to compare whether x[i] is equal to x[j] and returns the
/// result of the comparison.
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
                             Value x, bool isFirstDim, bool isLastDim) {
  Value vi = builder.create<memref::LoadOp>(loc, x, i);
  Value vj = builder.create<memref::LoadOp>(loc, x, j);

  Value res;
  if (isLastDim) {
    res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
    // For 1D, we create a compare without any control flow. Otherwise, we
    // create YieldOp to return the result in the nested if-stmt.
    if (!isFirstDim)
      builder.create<scf::YieldOp>(loc, res);
  } else {
    Value ne =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
                                               ne, /*else=*/true);
    // If (x[i] != x[j]).
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    Value f = constantI1(builder, loc, false);
    builder.create<scf::YieldOp>(loc, f);

    // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
    // checks the remaining dimensions.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
    res = ifOp.getResult(0);
  }

  return res;
}

/// Creates code to compare whether xs[i] is equal to xs[j].
// // The generate IR corresponds to this C like algorithm: // if (x0[i] != x0[j]) // return false; // else // if (x1[i] != x1[j]) // return false; // else if (x2[2] != x2[j])) // and so on ... static Value createInlinedEqCompare(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP = 0) { // Compare functions don't use trailing parameters. (void)nTrailingP; assert(nTrailingP == 0); return createInlinedCompareImplementation(builder, loc, args, xPerm, ny, createEqCompare); } /// Generates code to compare whether x[i] is less than x[j] and returns the /// result of the comparison. static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim) { Value vi = builder.create(loc, x, i); Value vj = builder.create(loc, x, j); Value res; if (isLastDim) { res = builder.create(loc, arith::CmpIPredicate::ult, vi, vj); // For 1D, we create a compare without any control flow. Otherwise, we // create YieldOp to return the result in the nested if-stmt. if (!isFirstDim) builder.create(loc, res); } else { Value ne = builder.create(loc, arith::CmpIPredicate::ne, vi, vj); scf::IfOp ifOp = builder.create(loc, builder.getIntegerType(1), ne, /*else=*/true); // If (x[i] != x[j]). builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); Value lt = builder.create(loc, arith::CmpIPredicate::ult, vi, vj); builder.create(loc, lt); // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that // checks the remaining dimensions. builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); res = ifOp.getResult(0); } return res; } /// Creates code to compare whether xs[i] is less than xs[j]. // // The generate IR corresponds to this C like algorithm: // if (x0[i] != x0[j]) // return x0[i] < x0[j]; // else if (x1[j] != x1[i]) // return x1[i] < x1[j]; // else // and so on ... 
static Value createInlinedLessThan(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP = 0) { // Compare functions don't use trailing parameters. (void)nTrailingP; assert(nTrailingP == 0); return createInlinedCompareImplementation(builder, loc, args, xPerm, ny, createLessThanCompare); } /// Creates a function to use a binary search to find the insertion point for /// inserting xs[hi] to the sorted values xs[lo..hi). // // The generate IR corresponds to this C like algorithm: // p = hi // while (lo < hi) // mid = (lo + hi) >> 1 // if (xs[p] < xs[mid]) // hi = mid // else // lo = mid - 1 // return lo; // static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP = 0) { // Binary search doesn't use trailing parameters. (void)nTrailingP; assert(nTrailingP == 0); OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); Value p = args[hiIdx]; SmallVector types(2, p.getType()); // Only two types. scf::WhileOp whileOp = builder.create( loc, types, SmallVector{args[loIdx], args[hiIdx]}); // The before-region of the WhileOp. Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); builder.setInsertionPointToEnd(before); Value cond1 = builder.create(loc, arith::CmpIPredicate::ult, before->getArgument(0), before->getArgument(1)); builder.create(loc, cond1, before->getArguments()); // The after-region of the WhileOp. Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); builder.setInsertionPointToEnd(after); Value lo = after->getArgument(0); Value hi = after->getArgument(1); // Compute mid = (lo + hi) >> 1. 
Value c1 = constantIndex(builder, loc, 1); Value mid = builder.create( loc, builder.create(loc, lo, hi), c1); Value midp1 = builder.create(loc, mid, c1); // Compare xs[p] < xs[mid]. SmallVector compareOperands{p, mid}; constexpr uint64_t numXBuffers = 1; compareOperands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + numXBuffers); Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); // Update lo and hi for the WhileOp as follows: // if (xs[p] < xs[mid])) // hi = mid; // else // lo = mid + 1; Value newLo = builder.create(loc, cond2, lo, midp1); Value newHi = builder.create(loc, cond2, mid, hi); builder.create(loc, ValueRange{newLo, newHi}); builder.setInsertionPointAfter(whileOp); builder.create(loc, whileOp.getResult(0)); } /// Creates code to advance i in a loop based on xs[p] as follows: /// while (xs[i] < xs[p]) i += step (step > 0) /// or /// while (xs[i] > xs[p]) i += step (step < 0) /// The routine returns i as well as a boolean value to indicate whether /// xs[i] == xs[p]. 
static std::pair createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny, int step) { Location loc = func.getLoc(); scf::WhileOp whileOp = builder.create(loc, TypeRange{i.getType()}, ValueRange{i}); Block *before = builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc}); builder.setInsertionPointToEnd(before); SmallVector compareOperands; if (step > 0) { compareOperands.push_back(before->getArgument(0)); compareOperands.push_back(p); } else { assert(step < 0); compareOperands.push_back(p); compareOperands.push_back(before->getArgument(0)); } compareOperands.append(xs.begin(), xs.end()); Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); builder.create(loc, cond, before->getArguments()); Block *after = builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc}); builder.setInsertionPointToEnd(after); Value cs = constantIndex(builder, loc, step); i = builder.create(loc, after->getArgument(0), cs); builder.create(loc, ValueRange{i}); i = whileOp.getResult(0); builder.setInsertionPointAfter(whileOp); compareOperands[0] = i; compareOperands[1] = p; Value compareEq = createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny); return std::make_pair(whileOp.getResult(0), compareEq); } /// Creates and returns an IfOp to compare two elements and swap the elements /// if compareFunc(data[b], data[a]) returns true. The new insertion point is /// right after the swap instructions. static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl &swapOperands, SmallVectorImpl &compareOperands, Value a, Value b) { // Compare(data[b], data[a]). 
compareOperands[0] = b; compareOperands[1] = a; Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); scf::IfOp ifOp = builder.create(loc, cond, /*else=*/false); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); swapOperands[0] = b; swapOperands[1] = a; createSwap(builder, loc, swapOperands, xPerm, ny); return ifOp; } /// Creates code to insert the 3rd element to a list of two sorted elements. static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl &swapOperands, SmallVectorImpl &compareOperands, Value v0, Value v1, Value v2) { scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands, v1, v2); createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1); builder.setInsertionPointAfter(ifOp); } /// Creates code to sort 3 elements. static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl &swapOperands, SmallVectorImpl &compareOperands, Value v0, Value v1, Value v2) { // Sort the first 2 elements. scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1); builder.setInsertionPointAfter(ifOp1); // Insert the 3th element. createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1, v2); } /// Creates code to sort 5 elements. static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl &swapOperands, SmallVectorImpl &compareOperands, Value v0, Value v1, Value v2, Value v3, Value v4) { // Sort the first 3 elements. 
createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1, v2); auto insert4th = [&]() { scf::IfOp ifOp = createCompareThenSwap( builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3); createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1, v2); builder.setInsertionPointAfter(ifOp); }; // Insert the 4th element. insert4th(); // Insert the 5th element. scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands, v3, v4); insert4th(); builder.setInsertionPointAfter(ifOp); } /// Creates a code block to swap the values in indices lo, mi, and hi so that /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When /// the number of values in range [lo, hi) is more than a threshold, we also /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values. static void createChoosePivot(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, Value lo, Value hi, Value mi, ValueRange args) { SmallVector compareOperands{mi, lo}; constexpr uint64_t numXBuffers = 1; compareOperands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + numXBuffers); SmallVector swapOperands{mi, lo}; swapOperands.append(args.begin() + xStartIdx, args.end()); Location loc = func.getLoc(); Value c1 = constantIndex(builder, loc, 1); Value hiP1 = builder.create(loc, hi, c1); Value len = builder.create(loc, hiP1, lo); Value lenThreshold = constantIndex(builder, loc, 1000); Value lenCond = builder.create(loc, arith::CmpIPredicate::ult, len, lenThreshold); scf::IfOp lenIf = builder.create(loc, lenCond, /*else=*/true); // When len < 1000, choose pivot from median of 3 values. builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi, hi); // When len >= 1000, choose pivot from median of 5 values. 
builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); Value miP1 = builder.create(loc, hi, c1); Value a = builder.create(loc, lo, miP1); // Value a is the middle between [loc, mi]. a = builder.create(loc, a, c1); Value b = builder.create(loc, mi, hiP1); // Value b is the middle between [mi, hi]. b = builder.create(loc, b, c1); createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi, b, hi); builder.setInsertionPointAfter(lenIf); } /// Creates a function to perform quick sort partition on the values in the /// range of index [lo, hi), assuming lo < hi. // // The generated IR corresponds to this C like algorithm: // int partition(lo, hi, xs) { // p = (lo+hi)/2 // pivot index // i = lo // j = hi-1 // while (true) do { // while (xs[i] < xs[p]) i ++; // i_eq = (xs[i] == xs[p]); // while (xs[j] > xs[p]) j --; // j_eq = (xs[j] == xs[p]); // // if (i >= j) return j + 1; // // if (i < j) { // swap(xs[i], xs[j]) // if (i == p) { // p = j; // } else if (j == p) { // p = i; // } // if (i_eq && j_eq) { // ++i; // --j; // } // } // } // } static void createPartitionFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP = 0) { // Quick sort partition doesn't use trailing parameters. (void)nTrailingP; assert(nTrailingP == 0); OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); Value lo = args[loIdx]; Value hi = args[hiIdx]; Value sum = builder.create(loc, lo, hi); Value c1 = constantIndex(builder, loc, 1); Value p = builder.create(loc, sum, c1); Value i = lo; Value j = builder.create(loc, hi, c1); createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args); Value trueVal = constantI1(builder, loc, true); // The value for while (true) SmallVector operands{i, j, p, trueVal}; // Exactly four values. 
SmallVector types{i.getType(), j.getType(), p.getType(), trueVal.getType()}; scf::WhileOp whileOp = builder.create(loc, types, operands); // The before-region of the WhileOp. Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc, loc}); builder.setInsertionPointToEnd(before); builder.create(loc, before->getArgument(3), before->getArguments()); // The after-region of the WhileOp. Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc}); builder.setInsertionPointToEnd(after); i = after->getArgument(0); j = after->getArgument(1); p = after->getArgument(2); constexpr uint64_t numXBuffers = 1; auto [iresult, iCompareEq] = createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers), i, p, xPerm, ny, 1); i = iresult; auto [jresult, jCompareEq] = createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers), j, p, xPerm, ny, -1); j = jresult; // If i < j: Value cond = builder.create(loc, arith::CmpIPredicate::ult, i, j); scf::IfOp ifOp = builder.create(loc, types, cond, /*else=*/true); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); SmallVector swapOperands{i, j}; swapOperands.append(args.begin() + xStartIdx, args.end()); createSwap(builder, loc, swapOperands, xPerm, ny); // If the pivot is moved, update p with the new pivot. 
Value icond = builder.create(loc, arith::CmpIPredicate::eq, i, p); scf::IfOp ifOpI = builder.create(loc, TypeRange{p.getType()}, icond, /*else=*/true); builder.setInsertionPointToStart(&ifOpI.getThenRegion().front()); builder.create(loc, ValueRange{j}); builder.setInsertionPointToStart(&ifOpI.getElseRegion().front()); Value jcond = builder.create(loc, arith::CmpIPredicate::eq, j, p); scf::IfOp ifOpJ = builder.create(loc, TypeRange{p.getType()}, jcond, /*else=*/true); builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front()); builder.create(loc, ValueRange{i}); builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front()); builder.create(loc, ValueRange{p}); builder.setInsertionPointAfter(ifOpJ); builder.create(loc, ifOpJ.getResults()); builder.setInsertionPointAfter(ifOpI); Value compareEqIJ = builder.create(loc, iCompareEq, jCompareEq); scf::IfOp ifOp2 = builder.create( loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true); builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); Value i2 = builder.create(loc, i, c1); Value j2 = builder.create(loc, j, c1); builder.create(loc, ValueRange{i2, j2}); builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); builder.create(loc, ValueRange{i, j}); builder.setInsertionPointAfter(ifOp2); builder.create( loc, ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0), /*cont=*/constantI1(builder, loc, true)}); // False branch for if i < j (i.e., i >= j): builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); p = builder.create(loc, j, constantOne(builder, loc, j.getType())); builder.create( loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)}); // Return for the whileOp. builder.setInsertionPointAfter(ifOp); builder.create(loc, ifOp.getResults()); // Return for the function. builder.setInsertionPointAfter(whileOp); builder.create(loc, whileOp.getResult(2)); } /// Computes (n-2)/n, assuming n has index type. 
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, Value n) { Value i2 = constantIndex(builder, loc, 2); Value res = builder.create(loc, n, i2); Value i1 = constantIndex(builder, loc, 1); return builder.create(loc, res, i1); } /// Creates a function to heapify the subtree with root `start` within the full /// binary tree in the range of index [first, first + n). // // The generated IR corresponds to this C like algorithm: // void shiftDown(first, start, n, data) { // if (n >= 2) { // child = start - first // if ((n-2)/2 >= child) { // // Left child exists. // child = child * 2 + 1 // Initialize the bigger child to left child. // childIndex = child + first // if (child+1 < n && data[childIndex] < data[childIndex+1]) // // Right child exits and is bigger. // childIndex++; child++; // // Shift data[start] down to where it belongs in the subtree. // while (data[start] < data[childIndex) { // swap(data[start], data[childIndex]) // start = childIndex // if ((n - 2)/2 >= child) { // // Left child exists. // child = 2*child + 1 // childIndex = child + 1 // if (child + 1) < n && data[childIndex] < data[childIndex+1] // childIndex++; child++; // } // } // } // } // } // static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) { // The value n is passed in as a trailing parameter. assert(nTrailingP == 1); OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); Location loc = func.getLoc(); Value n = entryBlock->getArguments().back(); ValueRange args = entryBlock->getArguments().drop_back(); Value first = args[loIdx]; Value start = args[hiIdx]; // If (n >= 2). 
Value c2 = constantIndex(builder, loc, 2); Value condN = builder.create(loc, arith::CmpIPredicate::uge, n, c2); scf::IfOp ifN = builder.create(loc, condN, /*else=*/false); builder.setInsertionPointToStart(&ifN.getThenRegion().front()); Value child = builder.create(loc, start, first); // If ((n-2)/2 >= child). Value t = createSubTwoDividedByTwo(builder, loc, n); Value condNc = builder.create(loc, arith::CmpIPredicate::uge, t, child); scf::IfOp ifNc = builder.create(loc, condNc, /*else=*/false); builder.setInsertionPointToStart(&ifNc.getThenRegion().front()); Value c1 = constantIndex(builder, loc, 1); SmallVector compareOperands{start, start}; constexpr uint64_t numXBuffers = 1; compareOperands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + numXBuffers); // Generate code to inspect the children of 'r' and return the larger child // as follows: // child = r * 2 + 1 // Left child. // childIndex = child + first // if (child+1 < n && data[childIndex] < data[childIndex+1]) // childIndex ++; child ++ // Right child is bigger. auto getLargerChild = [&](Value r) -> std::pair { Value lChild = builder.create(loc, r, c1); lChild = builder.create(loc, lChild, c1); Value lChildIdx = builder.create(loc, lChild, first); Value rChild = builder.create(loc, lChild, c1); Value cond1 = builder.create(loc, arith::CmpIPredicate::ult, rChild, n); SmallVector ifTypes(2, r.getType()); scf::IfOp if1 = builder.create(loc, ifTypes, cond1, /*else=*/true); builder.setInsertionPointToStart(&if1.getThenRegion().front()); Value rChildIdx = builder.create(loc, rChild, first); // Compare data[left] < data[right]. 
compareOperands[0] = lChildIdx; compareOperands[1] = rChildIdx; Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); scf::IfOp if2 = builder.create(loc, ifTypes, cond2, /*else=*/true); builder.setInsertionPointToStart(&if2.getThenRegion().front()); builder.create(loc, ValueRange{rChild, rChildIdx}); builder.setInsertionPointToStart(&if2.getElseRegion().front()); builder.create(loc, ValueRange{lChild, lChildIdx}); builder.setInsertionPointAfter(if2); builder.create(loc, if2.getResults()); builder.setInsertionPointToStart(&if1.getElseRegion().front()); builder.create(loc, ValueRange{lChild, lChildIdx}); builder.setInsertionPointAfter(if1); return std::make_pair(if1.getResult(0), if1.getResult(1)); }; Value childIdx; std::tie(child, childIdx) = getLargerChild(child); // While (data[start] < data[childIndex]). SmallVector types(3, child.getType()); scf::WhileOp whileOp = builder.create( loc, types, SmallVector{start, child, childIdx}); // The before-region of the WhileOp. SmallVector locs(3, loc); Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs); builder.setInsertionPointToEnd(before); start = before->getArgument(0); childIdx = before->getArgument(2); compareOperands[0] = start; compareOperands[1] = childIdx; Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); builder.create(loc, cond, before->getArguments()); // The after-region of the WhileOp. 
Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs); start = after->getArgument(0); child = after->getArgument(1); childIdx = after->getArgument(2); SmallVector swapOperands{start, childIdx}; swapOperands.append(args.begin() + xStartIdx, args.end()); createSwap(builder, loc, swapOperands, xPerm, ny); start = childIdx; Value cond2 = builder.create(loc, arith::CmpIPredicate::uge, t, child); scf::IfOp if2 = builder.create( loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true); builder.setInsertionPointToStart(&if2.getThenRegion().front()); auto [newChild, newChildIdx] = getLargerChild(child); builder.create(loc, ValueRange{newChild, newChildIdx}); builder.setInsertionPointToStart(&if2.getElseRegion().front()); builder.create(loc, ValueRange{child, childIdx}); builder.setInsertionPointAfter(if2); builder.create( loc, ValueRange{start, if2.getResult(0), if2.getResult(1)}); builder.setInsertionPointAfter(ifN); builder.create(loc); } /// Creates a function to perform heap sort on the values in the range of index /// [lo, hi) with the assumption hi - lo >= 2. // // The generate IR corresponds to this C like algorithm: // void heapSort(lo, hi, data) { // n = hi - lo // for i = (n-2)/2 downto 0 // shiftDown(lo, lo+i, n) // // for l = n downto 2 // swap(lo, lo+l-1) // shiftdown(lo, lo, l-1) // } static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) { // Heap sort function doesn't have trailing parameters. (void)nTrailingP; assert(nTrailingP == 0); OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); Value lo = args[loIdx]; Value hi = args[hiIdx]; Value n = builder.create(loc, hi, lo); // For i = (n-2)/2 downto 0. 
Value c0 = constantIndex(builder, loc, 0); Value c1 = constantIndex(builder, loc, 1); Value s = createSubTwoDividedByTwo(builder, loc, n); Value up = builder.create(loc, s, c1); scf::ForOp forI = builder.create(loc, c0, up, c1); builder.setInsertionPointToStart(forI.getBody()); Value i = builder.create(loc, s, forI.getInductionVar()); Value lopi = builder.create(loc, lo, i); SmallVector shiftDownOperands = {lo, lopi}; shiftDownOperands.append(args.begin() + xStartIdx, args.end()); shiftDownOperands.push_back(n); FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc( builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny, shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1); builder.create(loc, shiftDownFunc, TypeRange(), shiftDownOperands); builder.setInsertionPointAfter(forI); // For l = n downto 2. up = builder.create(loc, n, c1); scf::ForOp forL = builder.create(loc, c0, up, c1); builder.setInsertionPointToStart(forL.getBody()); Value l = builder.create(loc, n, forL.getInductionVar()); Value loplm1 = builder.create(loc, lo, l); loplm1 = builder.create(loc, loplm1, c1); SmallVector swapOperands{lo, loplm1}; swapOperands.append(args.begin() + xStartIdx, args.end()); createSwap(builder, loc, swapOperands, xPerm, ny); shiftDownOperands[1] = lo; shiftDownOperands[shiftDownOperands.size() - 1] = builder.create(loc, l, c1); builder.create(loc, shiftDownFunc, TypeRange(), shiftDownOperands); builder.setInsertionPointAfter(forL); builder.create(loc); } /// A helper for generating code to perform quick sort. It partitions [lo, hi), /// recursively calls quick sort to process the smaller partition and returns /// the bigger partition to be processed by the enclosed while-loop. 
static std::pair createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) { MLIRContext *context = module.getContext(); Location loc = func.getLoc(); Value lo = args[loIdx]; Value hi = args[hiIdx]; SmallVector types(2, lo.getType()); // Only two types. FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm, ny, args.drop_back(nTrailingP), createPartitionFunc); Value p = builder .create(loc, partitionFunc, TypeRange{IndexType::get(context)}, args.drop_back(nTrailingP)) .getResult(0); Value lenLow = builder.create(loc, p, lo); Value lenHigh = builder.create(loc, hi, p); // Partition already sorts array with len <= 2 Value c2 = constantIndex(builder, loc, 2); Value len = builder.create(loc, hi, lo); Value lenGtTwo = builder.create(loc, arith::CmpIPredicate::ugt, len, c2); scf::IfOp ifLenGtTwo = builder.create(loc, types, lenGtTwo, /*else=*/true); builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front()); // Returns an empty range to mark the entire region is fully sorted. builder.create(loc, ValueRange{lo, lo}); // Else len > 2, need recursion. 
builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front()); Value cond = builder.create(loc, arith::CmpIPredicate::ule, lenLow, lenHigh); Value c0 = constantIndex(builder, loc, 0); scf::IfOp ifOp = builder.create(loc, types, cond, /*else=*/true); auto mayRecursion = [&](Value low, Value high, Value len) { Value cond = builder.create(loc, arith::CmpIPredicate::ne, len, c0); scf::IfOp ifOp = builder.create(loc, cond, /*else=*/false); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); SmallVector operands{low, high}; operands.append(args.begin() + xStartIdx, args.end()); builder.create(loc, func, operands); builder.setInsertionPointAfter(ifOp); }; // Recursively call quickSort to process the smaller partition and return // the bigger partition to be processed by the enclosed while-loop. builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); mayRecursion(lo, p, lenLow); builder.create(loc, ValueRange{p, hi}); builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); mayRecursion(p, hi, lenHigh); builder.create(loc, ValueRange{lo, p}); builder.setInsertionPointAfter(ifOp); builder.create(loc, ifOp.getResults()); builder.setInsertionPointAfter(ifLenGtTwo); return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1)); } /// Creates a function to perform insertion sort on the values in the range of /// index [lo, hi). // // The generate IR corresponds to this C like algorithm: // void insertionSort(lo, hi, data) { // for (i = lo+1; i < hi; i++) { // d = data[i]; // p = binarySearch(lo, i-1, data) // for (j = 0; j > i - p; j++) // data[i-j] = data[i-j-1] // data[p] = d // } // } static void createSortStableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) { // Stable sort function doesn't use trailing parameters. 
(void)nTrailingP; assert(nTrailingP == 0); OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); MLIRContext *context = module.getContext(); Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); Value c1 = constantIndex(builder, loc, 1); Value lo = args[loIdx]; Value hi = args[hiIdx]; Value lop1 = builder.create(loc, lo, c1); // Start the outer for-stmt with induction variable i. scf::ForOp forOpI = builder.create(loc, lop1, hi, c1); builder.setInsertionPointToStart(forOpI.getBody()); Value i = forOpI.getInductionVar(); // Binary search to find the insertion point p. SmallVector operands{lo, i}; operands.append(args.begin() + xStartIdx, args.end()); FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, xPerm, ny, operands, createBinarySearchFunc); Value p = builder .create(loc, searchFunc, TypeRange{c1.getType()}, operands) .getResult(0); // Move the value at data[i] to a temporary location. operands[0] = operands[1] = i; SmallVector d; forEachIJPairInAllBuffers( builder, loc, operands, xPerm, ny, [&](uint64_t unused, Value i, Value unused2, Value buffer) { d.push_back(builder.create(loc, buffer, i)); }); // Start the inner for-stmt with induction variable j, for moving data[p..i) // to data[p+1..i+1). Value imp = builder.create(loc, i, p); Value c0 = constantIndex(builder, loc, 0); scf::ForOp forOpJ = builder.create(loc, c0, imp, c1); builder.setInsertionPointToStart(forOpJ.getBody()); Value j = forOpJ.getInductionVar(); Value imj = builder.create(loc, i, j); operands[1] = imj; operands[0] = builder.create(loc, imj, c1); forEachIJPairInAllBuffers( builder, loc, operands, xPerm, ny, [&](uint64_t unused, Value imjm1, Value imj, Value buffer) { Value t = builder.create(loc, buffer, imjm1); builder.create(loc, t, buffer, imj); }); // Store the value at data[i] to data[p]. 
builder.setInsertionPointAfter(forOpJ); operands[0] = operands[1] = p; forEachIJPairInAllBuffers( builder, loc, operands, xPerm, ny, [&](uint64_t k, Value p, Value usused, Value buffer) { builder.create(loc, d[k], buffer, p); }); builder.setInsertionPointAfter(forOpI); builder.create(loc); } /// Creates a function to perform quick sort or a hybrid quick sort on the /// values in the range of index [lo, hi). // // // When nTrailingP == 0, the generated IR corresponds to this C like algorithm: // void quickSort(lo, hi, data) { // while (lo + 1 < hi) { // p = partition(low, high, data); // if (len(lo, p) < len(p+1, hi)) { // quickSort(lo, p, data); // lo = p+1; // } else { // quickSort(p + 1, hi, data); // hi = p; // } // } // } // // When nTrailingP == 1, the generated IR corresponds to this C like algorithm: // void hybridQuickSort(lo, hi, data, depthLimit) { // while (lo + 1 < hi) { // len = hi - lo; // if (len <= limit) { // insertionSort(lo, hi, data); // } else { // depthLimit --; // if (depthLimit <= 0) { // heapSort(lo, hi, data); // } else { // p = partition(low, high, data); // if (len(lo, p) < len(p+1, hi)) { // quickSort(lo, p, data, depthLimit); // lo = p+1; // } else { // quickSort(p + 1, hi, data, depthLimit); // hi = p; // } // } // } // } // } // static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) { assert(nTrailingP == 1 || nTrailingP == 0); bool isHybrid = (nTrailingP == 1); OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); Location loc = func.getLoc(); SmallVector args; args.append(entryBlock->getArguments().begin(), entryBlock->getArguments().end()); Value lo = args[loIdx]; Value hi = args[hiIdx]; SmallVector types(2, lo.getType()); // Only two types. scf::WhileOp whileOp = builder.create(loc, types, SmallVector{lo, hi}); // The before-region of the WhileOp. 
Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); builder.setInsertionPointToEnd(before); lo = before->getArgument(0); hi = before->getArgument(1); Value loP1 = builder.create(loc, lo, constantIndex(builder, loc, 1)); Value needSort = builder.create(loc, arith::CmpIPredicate::ult, loP1, hi); builder.create(loc, needSort, before->getArguments()); // The after-region of the WhileOp. Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); builder.setInsertionPointToEnd(after); lo = after->getArgument(0); hi = after->getArgument(1); args[0] = lo; args[1] = hi; if (isHybrid) { Value len = builder.create(loc, hi, lo); Value lenLimit = constantIndex(builder, loc, 30); Value lenCond = builder.create( loc, arith::CmpIPredicate::ule, len, lenLimit); scf::IfOp lenIf = builder.create(loc, types, lenCond, /*else=*/true); // When len <= limit. builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc( builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny, ValueRange(args).drop_back(nTrailingP), createSortStableFunc); builder.create(loc, insertionSortFunc, TypeRange(), ValueRange(args).drop_back(nTrailingP)); builder.create(loc, ValueRange{lo, lo}); // When len > limit. builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); Value depthLimit = args.back(); depthLimit = builder.create(loc, depthLimit, constantI64(builder, loc, 1)); Value depthCond = builder.create(loc, arith::CmpIPredicate::ule, depthLimit, constantI64(builder, loc, 0)); scf::IfOp depthIf = builder.create(loc, types, depthCond, /*else=*/true); // When depth exceeds limit. 
builder.setInsertionPointToStart(&depthIf.getThenRegion().front()); FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc( builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny, ValueRange(args).drop_back(nTrailingP), createHeapSortFunc); builder.create(loc, heapSortFunc, TypeRange(), ValueRange(args).drop_back(nTrailingP)); builder.create(loc, ValueRange{lo, lo}); // When depth doesn't exceed limit. builder.setInsertionPointToStart(&depthIf.getElseRegion().front()); args.back() = depthLimit; std::tie(lo, hi) = createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); builder.create(loc, ValueRange{lo, hi}); builder.setInsertionPointAfter(depthIf); lo = depthIf.getResult(0); hi = depthIf.getResult(1); builder.create(loc, ValueRange{lo, hi}); builder.setInsertionPointAfter(lenIf); lo = lenIf.getResult(0); hi = lenIf.getResult(1); } else { std::tie(lo, hi) = createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); } // New [lo, hi) for the next while-loop iteration. builder.create(loc, ValueRange{lo, hi}); // After the while-loop. builder.setInsertionPointAfter(whileOp); builder.create(loc); } /// Implements the rewriting for operator sort and sort_coo. template LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, uint64_t ny, PatternRewriter &rewriter) { Location loc = op.getLoc(); SmallVector operands{constantIndex(rewriter, loc, 0), op.getN()}; // Convert `values` to have dynamic shape and append them to `operands`. 
for (Value v : xys) { auto mtp = getMemRefType(v); if (!mtp.isDynamicDim(0)) { auto newMtp = MemRefType::get({ShapedType::kDynamic}, mtp.getElementType()); v = rewriter.create(loc, newMtp, v); } operands.push_back(v); } auto insertPoint = op->template getParentOfType(); if (!insertPoint) return failure(); SmallString<32> funcName; FuncGeneratorType funcGenerator; uint32_t nTrailingP = 0; switch (op.getAlgorithm()) { case SparseTensorSortKind::HybridQuickSort: { funcName = kHybridQuickSortFuncNamePrefix; funcGenerator = createQuickSortFunc; nTrailingP = 1; // As a heuristics, set depthLimit = 2 * log2(n). Value lo = operands[loIdx]; Value hi = operands[hiIdx]; Value len = rewriter.create( loc, rewriter.getI64Type(), rewriter.create(loc, hi, lo)); Value depthLimit = rewriter.create( loc, constantI64(rewriter, loc, 64), rewriter.create(loc, len)); operands.push_back(depthLimit); break; } case SparseTensorSortKind::QuickSort: funcName = kQuickSortFuncNamePrefix; funcGenerator = createQuickSortFunc; break; case SparseTensorSortKind::InsertionSortStable: funcName = kSortStableFuncNamePrefix; funcGenerator = createSortStableFunc; break; case SparseTensorSortKind::HeapSort: funcName = kHeapSortFuncNamePrefix; funcGenerator = createHeapSortFunc; break; } FlatSymbolRefAttr func = getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, xPerm, ny, operands, funcGenerator, nTrailingP); rewriter.replaceOpWithNewOp(op, func, TypeRange(), operands); return success(); } //===---------------------------------------------------------------------===// // The actual sparse buffer rewriting rules. //===---------------------------------------------------------------------===// namespace { /// Sparse rewriting rule for the push_back operator. 
struct PushBackRewriter : OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PushBackRewriter(MLIRContext *context, bool enableInit) : OpRewritePattern(context), enableBufferInitialization(enableInit) {} LogicalResult matchAndRewrite(PushBackOp op, PatternRewriter &rewriter) const override { // Rewrite push_back(buffer, value, n) to: // new_size = size(buffer) + n // if (new_size > capacity(buffer)) // while new_size > new_capacity // new_capacity = new_capacity*2 // new_buffer = realloc(buffer, new_capacity) // buffer = new_buffer // subBuffer = subviewof(buffer) // linalg.fill subBuffer value // // size(buffer) += n // // The capacity check is skipped when the attribute inbounds is presented. Location loc = op->getLoc(); Value c0 = constantIndex(rewriter, loc, 0); Value buffer = op.getInBuffer(); Value capacity = rewriter.create(loc, buffer, c0); Value size = op.getCurSize(); Value value = op.getValue(); Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1); Value newSize = rewriter.create(loc, size, n); auto nValue = dyn_cast_or_null(n.getDefiningOp()); bool nIsOne = (nValue && nValue.value() == 1); if (!op.getInbounds()) { Value cond = rewriter.create( loc, arith::CmpIPredicate::ugt, newSize, capacity); Value c2 = constantIndex(rewriter, loc, 2); auto bufferType = MemRefType::get({ShapedType::kDynamic}, value.getType()); scf::IfOp ifOp = rewriter.create(loc, bufferType, cond, /*else=*/true); // True branch. rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); if (nIsOne) { capacity = rewriter.create(loc, capacity, c2); } else { // Use a do-while loop to calculate the new capacity as follows: // do { new_capacity *= 2 } while (size > new_capacity) scf::WhileOp whileOp = rewriter.create(loc, capacity.getType(), capacity); // The before-region of the WhileOp. 
Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, {capacity.getType()}, {loc}); rewriter.setInsertionPointToEnd(before); capacity = rewriter.create(loc, before->getArgument(0), c2); cond = rewriter.create(loc, arith::CmpIPredicate::ugt, newSize, capacity); rewriter.create(loc, cond, ValueRange{capacity}); // The after-region of the WhileOp. Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, {capacity.getType()}, {loc}); rewriter.setInsertionPointToEnd(after); rewriter.create(loc, after->getArguments()); rewriter.setInsertionPointAfter(whileOp); capacity = whileOp.getResult(0); } Value newBuffer = rewriter.create(loc, bufferType, buffer, capacity); if (enableBufferInitialization) { Value fillSize = rewriter.create(loc, capacity, newSize); Value fillValue = constantZero(rewriter, loc, value.getType()); Value subBuffer = rewriter.create( loc, newBuffer, /*offset=*/ValueRange{newSize}, /*size=*/ValueRange{fillSize}, /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); rewriter.create(loc, fillValue, subBuffer); } rewriter.create(loc, newBuffer); // False branch. rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); rewriter.create(loc, buffer); // Prepare for adding the value to the end of the buffer. rewriter.setInsertionPointAfter(ifOp); buffer = ifOp.getResult(0); } // Add the value to the end of the buffer. if (nIsOne) { rewriter.create(loc, value, buffer, size); } else { Value subBuffer = rewriter.create( loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n}, /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); rewriter.create(loc, value, subBuffer); } // Update the buffer size. rewriter.replaceOp(op, {buffer, newSize}); return success(); } private: bool enableBufferInitialization; }; /// Sparse rewriting rule for the sort_coo operator. 
struct SortRewriter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SortOp op, PatternRewriter &rewriter) const override { SmallVector xys; xys.push_back(op.getXy()); xys.append(op.getYs().begin(), op.getYs().end()); auto xPerm = op.getPermMap(); uint64_t ny = 0; if (auto nyAttr = op.getNyAttr()) ny = nyAttr.getInt(); return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter); } }; } // namespace //===---------------------------------------------------------------------===// // Methods that add patterns described in this file to a pattern list. //===---------------------------------------------------------------------===// void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization) { patterns.add(patterns.getContext(), enableBufferInitialization); patterns.add(patterns.getContext()); }