//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensor
// primitives with memref operands.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Support/LLVM.h"

#include <cstdint>
#include <utility>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===// static constexpr uint64_t loIdx = 0; static constexpr uint64_t hiIdx = 1; static constexpr uint64_t xStartIdx = 2; static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_"; static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_"; static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_"; static constexpr const char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_"; static constexpr const char kSortNonstableFuncNamePrefix[] = "_sparse_sort_nonstable_"; static constexpr const char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_"; using FuncGeneratorType = function_ref; /// Constructs a function name with this format to facilitate quick sort: /// __..._ for sort /// __coo__..._ for sort_coo static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, StringRef namePrefix, uint64_t nx, uint64_t ny, bool isCoo, ValueRange operands) { nameOstream << namePrefix << nx << "_" << getMemRefType(operands[xStartIdx]).getElementType(); if (isCoo) nameOstream << "_coo_" << ny; uint64_t yBufferOffset = isCoo ? 1 : nx; for (Value v : operands.drop_front(xStartIdx + yBufferOffset)) nameOstream << "_" << getMemRefType(v).getElementType(); } /// Looks up a function that is appropriate for the given operands being /// sorted, and creates such a function if it doesn't exist yet. The /// parameters `nx` and `ny` tell the number of x and y values provided /// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction /// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo. 
static FlatSymbolRefAttr getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, uint64_t nx, uint64_t ny, bool isCoo, ValueRange operands, FuncGeneratorType createFunc) { SmallString<32> nameBuffer; llvm::raw_svector_ostream nameOstream(nameBuffer); getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo, operands); ModuleOp module = insertPoint->getParentOfType(); MLIRContext *context = module.getContext(); auto result = SymbolRefAttr::get(context, nameOstream.str()); auto func = module.lookupSymbol(result.getAttr()); if (!func) { // Create the function. OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPoint(insertPoint); Location loc = insertPoint.getLoc(); func = builder.create( loc, nameOstream.str(), FunctionType::get(context, operands.getTypes(), resultTypes)); func.setPrivate(); createFunc(builder, module, func, nx, ny, isCoo); } return result; } /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting. /// The code to process the value pairs is generated by `bodyBuilder`. static void forEachIJPairInXs( OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny, bool isCoo, function_ref bodyBuilder) { Value iOffset, jOffset; if (isCoo) { Value cstep = constantIndex(builder, loc, nx + ny); iOffset = builder.create(loc, args[0], cstep); jOffset = builder.create(loc, args[1], cstep); } for (uint64_t k = 0; k < nx; k++) { scf::IfOp ifOp; Value i, j, buffer; if (isCoo) { Value ck = constantIndex(builder, loc, k); i = builder.create(loc, ck, iOffset); j = builder.create(loc, ck, jOffset); buffer = args[xStartIdx]; } else { i = args[0]; j = args[1]; buffer = args[xStartIdx + k]; } bodyBuilder(k, i, j, buffer); } } /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting. /// The code to process the value pairs is generated by `bodyBuilder`. 
static void forEachIJPairInAllBuffers( OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny, bool isCoo, function_ref bodyBuilder) { // Create code for the first (nx + ny) buffers. When isCoo==true, these // logical buffers are all from the xy buffer of the sort_coo operator. forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder); uint64_t numHandledBuffers = isCoo ? 1 : nx + ny; // Create code for the remaining buffers. Value i = args[0]; Value j = args[1]; for (const auto &arg : llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) { bodyBuilder(arg.index() + nx + ny, i, j, arg.value()); } } /// Creates a code block for swapping the values in index i and j for all the /// buffers. // // The generated IR corresponds to this C like algorithm: // swap(x0[i], x0[j]); // swap(x1[i], x1[j]); // ... // swap(xn[i], xn[j]); // swap(y0[i], y0[j]); // ... // swap(yn[i], yn[j]); static void createSwap(OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny, bool isCoo) { auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) { Value vi = builder.create(loc, buffer, i); Value vj = builder.create(loc, buffer, j); builder.create(loc, vj, buffer, i); builder.create(loc, vi, buffer, j); }; forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair); } /// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to /// compare each pair is create via `compareBuilder`. 
static void createCompareFuncImplementation( OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx, uint64_t ny, bool isCoo, function_ref compareBuilder) { OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); scf::IfOp topIfOp; auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) { scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1)); if (k == 0) { topIfOp = ifOp; } else { OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPointAfter(ifOp); builder.create(loc, ifOp.getResult(0)); } }; forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder); builder.setInsertionPointAfter(topIfOp); builder.create(loc, topIfOp.getResult(0)); } /// Generates an if-statement to compare whether x[i] is equal to x[j]. static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isLastDim) { Value f = constantI1(builder, loc, false); Value t = constantI1(builder, loc, true); Value vi = builder.create(loc, x, i); Value vj = builder.create(loc, x, j); Value cond = builder.create(loc, arith::CmpIPredicate::eq, vi, vj); scf::IfOp ifOp = builder.create(loc, f.getType(), cond, /*else=*/true); // x[1] != x[j]: builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); builder.create(loc, f); // x[i] == x[j]: builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); if (isLastDim == 1) { // Finish checking all dimensions. builder.create(loc, t); } return ifOp; } /// Creates a function to compare whether xs[i] is equal to xs[j]. // // The generate IR corresponds to this C like algorithm: // if (x0[i] != x0[j]) // return false; // else // if (x1[i] != x1[j]) // return false; // else if (x2[2] != x2[j])) // and so on ... 
static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx, uint64_t ny, bool isCoo) { createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo, createEqCompare); } /// Generates an if-statement to compare whether x[i] is less than x[j]. static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isLastDim) { Value f = constantI1(builder, loc, false); Value t = constantI1(builder, loc, true); Value vi = builder.create(loc, x, i); Value vj = builder.create(loc, x, j); Value cond = builder.create(loc, arith::CmpIPredicate::ult, vi, vj); scf::IfOp ifOp = builder.create(loc, f.getType(), cond, /*else=*/true); // If (x[i] < x[j]). builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); builder.create(loc, t); builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); if (isLastDim == 1) { // Finish checking all dimensions. builder.create(loc, f); } else { cond = builder.create(loc, arith::CmpIPredicate::ult, vj, vi); scf::IfOp ifOp2 = builder.create(loc, f.getType(), cond, /*else=*/true); // Otherwise if (x[j] < x[i]). builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); builder.create(loc, f); // Otherwise check the remaining dimensions. builder.setInsertionPointAfter(ifOp2); builder.create(loc, ifOp2.getResult(0)); // Set up the insertion point for the nested if-stmt that checks the // remaining dimensions. builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); } return ifOp; } /// Creates a function to compare whether xs[i] is less than xs[j]. // // The generate IR corresponds to this C like algorithm: // if (x0[i] < x0[j]) // return true; // else if (x0[j] < x0[i]) // return false; // else // if (x1[i] < x1[j]) // return true; // else if (x1[j] < x1[i])) // and so on ... 
static void createLessThanFunc(OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx, uint64_t ny, bool isCoo) { createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo, createLessThanCompare); } /// Creates a function to use a binary search to find the insertion point for /// inserting xs[hi] to the sorted values xs[lo..hi). // // The generate IR corresponds to this C like algorithm: // p = hi // while (lo < hi) // mid = (lo + hi) >> 1 // if (xs[p] < xs[mid]) // hi = mid // else // lo = mid - 1 // return lo; // static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, bool isCoo) { OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); Value p = args[hiIdx]; SmallVector types(2, p.getType()); // only two scf::WhileOp whileOp = builder.create( loc, types, SmallVector{args[loIdx], args[hiIdx]}); // The before-region of the WhileOp. Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); builder.setInsertionPointToEnd(before); Value cond1 = builder.create(loc, arith::CmpIPredicate::ult, before->getArgument(0), before->getArgument(1)); builder.create(loc, cond1, before->getArguments()); // The after-region of the WhileOp. Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); builder.setInsertionPointToEnd(after); Value lo = after->getArgument(0); Value hi = after->getArgument(1); // Compute mid = (lo + hi) >> 1. Value c1 = constantIndex(builder, loc, 1); Value mid = builder.create( loc, builder.create(loc, lo, hi), c1); Value midp1 = builder.create(loc, mid, c1); // Compare xs[p] < xs[mid]. SmallVector compareOperands{p, mid}; uint64_t numXBuffers = isCoo ? 
1 : nx; compareOperands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + numXBuffers); Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless); FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc( builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo, compareOperands, createLessThanFunc); Value cond2 = builder .create(loc, lessThanFunc, TypeRange{i1Type}, compareOperands) .getResult(0); // Update lo and hi for the WhileOp as follows: // if (xs[p] < xs[mid])) // hi = mid; // else // lo = mid + 1; Value newLo = builder.create(loc, cond2, lo, midp1); Value newHi = builder.create(loc, cond2, mid, hi); builder.create(loc, ValueRange{newLo, newHi}); builder.setInsertionPointAfter(whileOp); builder.create(loc, whileOp.getResult(0)); } /// Creates code to advance i in a loop based on xs[p] as follows: /// while (xs[i] < xs[p]) i += step (step > 0) /// or /// while (xs[i] > xs[p]) i += step (step < 0) /// The routine returns i as well as a boolean value to indicate whether /// xs[i] == xs[p]. 
static std::pair createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny, bool isCoo, int step) { Location loc = func.getLoc(); scf::WhileOp whileOp = builder.create(loc, TypeRange{i.getType()}, ValueRange{i}); Block *before = builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc}); builder.setInsertionPointToEnd(before); SmallVector compareOperands; if (step > 0) { compareOperands.push_back(before->getArgument(0)); compareOperands.push_back(p); } else { assert(step < 0); compareOperands.push_back(p); compareOperands.push_back(before->getArgument(0)); } compareOperands.append(xs.begin(), xs.end()); MLIRContext *context = module.getContext(); Type i1Type = IntegerType::get(context, 1, IntegerType::Signless); FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc( builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo, compareOperands, createLessThanFunc); Value cond = builder .create(loc, lessThanFunc, TypeRange{i1Type}, compareOperands) .getResult(0); builder.create(loc, cond, before->getArguments()); Block *after = builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc}); builder.setInsertionPointToEnd(after); Value cs = constantIndex(builder, loc, step); i = builder.create(loc, after->getArgument(0), cs); builder.create(loc, ValueRange{i}); i = whileOp.getResult(0); builder.setInsertionPointAfter(whileOp); compareOperands[0] = i; compareOperands[1] = p; FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc( builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo, compareOperands, createEqCompareFunc); Value compareEq = builder .create(loc, compareEqFunc, TypeRange{i1Type}, compareOperands) .getResult(0); return std::make_pair(whileOp.getResult(0), compareEq); } /// Creates a code block to swap the values so that data[mi] is the median among /// data[lo], data[hi], and data[mi]. 
// The generated code corresponds to this C-like algorithm: // median = mi // if (data[mi] < data[lo]). (if1) // if (data[hi] < data[lo]) (if2) // median = data[hi] < data[mi] ? mi : hi // else // median = lo // else // if data[hi] < data[mi] (if3) // median = data[hi] < data[lo] ? lo : hi // if median != mi swap data[median] with data[mi] static void createChoosePivot(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, bool isCoo, Value lo, Value hi, Value mi, ValueRange args) { SmallVector compareOperands{mi, lo}; uint64_t numXBuffers = isCoo ? 1 : nx; compareOperands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + numXBuffers); Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless); SmallVector cmpTypes{i1Type}; FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc( builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo, compareOperands, createLessThanFunc); Location loc = func.getLoc(); // Compare data[mi] < data[lo]. Value cond1 = builder.create(loc, lessThanFunc, cmpTypes, compareOperands) .getResult(0); SmallVector ifTypes{lo.getType()}; scf::IfOp ifOp1 = builder.create(loc, ifTypes, cond1, /*else=*/true); // Generate an if-stmt to find the median value, assuming we already know that // data[b] < data[a] and we haven't compare data[c] yet. auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp { compareOperands[0] = c; compareOperands[1] = a; // Compare data[c]] < data[a]. Value cond2 = builder .create(loc, lessThanFunc, cmpTypes, compareOperands) .getResult(0); scf::IfOp ifOp2 = builder.create(loc, ifTypes, cond2, /*else=*/true); builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); compareOperands[0] = c; compareOperands[1] = b; // Compare data[c] < data[b]. 
Value cond3 = builder .create(loc, lessThanFunc, cmpTypes, compareOperands) .getResult(0); builder.create( loc, ValueRange{builder.create(loc, cond3, b, c)}); builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); builder.create(loc, ValueRange{a}); return ifOp2; }; builder.setInsertionPointToStart(&ifOp1.getThenRegion().front()); scf::IfOp ifOp2 = createFindMedian(lo, mi, hi); builder.setInsertionPointAfter(ifOp2); builder.create(loc, ValueRange{ifOp2.getResult(0)}); builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); scf::IfOp ifOp3 = createFindMedian(mi, lo, hi); builder.setInsertionPointAfter(ifOp3); builder.create(loc, ValueRange{ifOp3.getResult(0)}); builder.setInsertionPointAfter(ifOp1); Value median = ifOp1.getResult(0); Value cond = builder.create(loc, arith::CmpIPredicate::ne, mi, median); scf::IfOp ifOp = builder.create(loc, TypeRange(), cond, /*else=*/false); SmallVector swapOperands{median, mi}; swapOperands.append(args.begin() + xStartIdx, args.end()); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); createSwap(builder, loc, swapOperands, nx, ny, isCoo); builder.setInsertionPointAfter(ifOp); } /// Creates a function to perform quick sort partition on the values in the /// range of index [lo, hi), assuming lo < hi. 
// // The generated IR corresponds to this C like algorithm: // int partition(lo, hi, xs) { // p = (lo+hi)/2 // pivot index // i = lo // j = hi-1 // while (i < j) do { // while (xs[i] < xs[p]) i ++; // i_eq = (xs[i] == xs[p]); // while (xs[j] > xs[p]) j --; // j_eq = (xs[j] == xs[p]); // if (i < j) { // swap(xs[i], xs[j]) // if (i == p) { // p = j; // } else if (j == p) { // p = i; // } // if (i_eq && j_eq) { // ++i; // --j; // } // } // } // return p // } static void createPartitionFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, bool isCoo) { OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); Value lo = args[loIdx]; Value hi = args[hiIdx]; Value sum = builder.create(loc, lo, hi); Value c1 = constantIndex(builder, loc, 1); Value p = builder.create(loc, sum, c1); Value i = lo; Value j = builder.create(loc, hi, c1); createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args); SmallVector operands{i, j, p}; // Exactly three values. SmallVector types{i.getType(), j.getType(), p.getType()}; scf::WhileOp whileOp = builder.create(loc, types, operands); // The before-region of the WhileOp. Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc}); builder.setInsertionPointToEnd(before); Value cond = builder.create(loc, arith::CmpIPredicate::ult, before->getArgument(0), before->getArgument(1)); builder.create(loc, cond, before->getArguments()); // The after-region of the WhileOp. Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc}); builder.setInsertionPointToEnd(after); i = after->getArgument(0); j = after->getArgument(1); p = after->getArgument(2); uint64_t numXBuffers = isCoo ? 
1 : nx; auto [iresult, iCompareEq] = createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers), i, p, nx, ny, isCoo, 1); i = iresult; auto [jresult, jCompareEq] = createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers), j, p, nx, ny, isCoo, -1); j = jresult; // If i < j: cond = builder.create(loc, arith::CmpIPredicate::ult, i, j); scf::IfOp ifOp = builder.create(loc, types, cond, /*else=*/true); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); SmallVector swapOperands{i, j}; swapOperands.append(args.begin() + xStartIdx, args.end()); createSwap(builder, loc, swapOperands, nx, ny, isCoo); // If the pivot is moved, update p with the new pivot. Value icond = builder.create(loc, arith::CmpIPredicate::eq, i, p); scf::IfOp ifOpI = builder.create(loc, TypeRange{p.getType()}, icond, /*else=*/true); builder.setInsertionPointToStart(&ifOpI.getThenRegion().front()); builder.create(loc, ValueRange{j}); builder.setInsertionPointToStart(&ifOpI.getElseRegion().front()); Value jcond = builder.create(loc, arith::CmpIPredicate::eq, j, p); scf::IfOp ifOpJ = builder.create(loc, TypeRange{p.getType()}, jcond, /*else=*/true); builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front()); builder.create(loc, ValueRange{i}); builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front()); builder.create(loc, ValueRange{p}); builder.setInsertionPointAfter(ifOpJ); builder.create(loc, ifOpJ.getResults()); builder.setInsertionPointAfter(ifOpI); Value compareEqIJ = builder.create(loc, iCompareEq, jCompareEq); scf::IfOp ifOp2 = builder.create( loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true); builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); Value i2 = builder.create(loc, i, c1); Value j2 = builder.create(loc, j, c1); builder.create(loc, ValueRange{i2, j2}); builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); builder.create(loc, ValueRange{i, j}); builder.setInsertionPointAfter(ifOp2); 
builder.create( loc, ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)}); // False branch for if i < j: builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); builder.create(loc, ValueRange{i, j, p}); // Return for the whileOp. builder.setInsertionPointAfter(ifOp); builder.create(loc, ifOp.getResults()); // Return for the function. builder.setInsertionPointAfter(whileOp); builder.create(loc, whileOp.getResult(2)); } /// Creates a function to perform quick sort on the value in the range of /// index [lo, hi). // // The generate IR corresponds to this C like algorithm: // void quickSort(lo, hi, data) { // if (lo < hi) { // p = partition(low, high, data); // quickSort(lo, p, data); // quickSort(p + 1, hi, data); // } // } static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, bool isCoo) { OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); MLIRContext *context = module.getContext(); Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); Value lo = args[loIdx]; Value hi = args[hiIdx]; Value cond = builder.create(loc, arith::CmpIPredicate::ult, lo, hi); scf::IfOp ifOp = builder.create(loc, cond, /*else=*/false); // The if-stmt true branch. 
builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx, ny, isCoo, args, createPartitionFunc); auto p = builder.create( loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args)); SmallVector lowOperands{lo, p.getResult(0)}; lowOperands.append(args.begin() + xStartIdx, args.end()); builder.create(loc, func, lowOperands); SmallVector highOperands{ builder.create(loc, p.getResult(0), constantIndex(builder, loc, 1)), hi}; highOperands.append(args.begin() + xStartIdx, args.end()); builder.create(loc, func, highOperands); // After the if-stmt. builder.setInsertionPointAfter(ifOp); builder.create(loc); } /// Creates a function to perform insertion sort on the values in the range of /// index [lo, hi). // // The generate IR corresponds to this C like algorithm: // void insertionSort(lo, hi, data) { // for (i = lo+1; i < hi; i++) { // d = data[i]; // p = binarySearch(lo, i-1, data) // for (j = 0; j > i - p; j++) // data[i-j] = data[i-j-1] // data[p] = d // } // } static void createSortStableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, uint64_t nx, uint64_t ny, bool isCoo) { OpBuilder::InsertionGuard insertionGuard(builder); Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); MLIRContext *context = module.getContext(); Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); Value c1 = constantIndex(builder, loc, 1); Value lo = args[loIdx]; Value hi = args[hiIdx]; Value lop1 = builder.create(loc, lo, c1); // Start the outer for-stmt with induction variable i. scf::ForOp forOpI = builder.create(loc, lop1, hi, c1); builder.setInsertionPointToStart(forOpI.getBody()); Value i = forOpI.getInductionVar(); // Binary search to find the insertion point p. 
SmallVector operands{lo, i}; operands.append(args.begin() + xStartIdx, args.end()); FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx, ny, isCoo, operands, createBinarySearchFunc); Value p = builder .create(loc, searchFunc, TypeRange{c1.getType()}, operands) .getResult(0); // Move the value at data[i] to a temporary location. operands[0] = operands[1] = i; SmallVector d; forEachIJPairInAllBuffers( builder, loc, operands, nx, ny, isCoo, [&](uint64_t unused, Value i, Value unused2, Value buffer) { d.push_back(builder.create(loc, buffer, i)); }); // Start the inner for-stmt with induction variable j, for moving data[p..i) // to data[p+1..i+1). Value imp = builder.create(loc, i, p); Value c0 = constantIndex(builder, loc, 0); scf::ForOp forOpJ = builder.create(loc, c0, imp, c1); builder.setInsertionPointToStart(forOpJ.getBody()); Value j = forOpJ.getInductionVar(); Value imj = builder.create(loc, i, j); operands[1] = imj; operands[0] = builder.create(loc, imj, c1); forEachIJPairInAllBuffers( builder, loc, operands, nx, ny, isCoo, [&](uint64_t unused, Value imjm1, Value imj, Value buffer) { Value t = builder.create(loc, buffer, imjm1); builder.create(loc, t, buffer, imj); }); // Store the value at data[i] to data[p]. builder.setInsertionPointAfter(forOpJ); operands[0] = operands[1] = p; forEachIJPairInAllBuffers( builder, loc, operands, nx, ny, isCoo, [&](uint64_t k, Value p, Value usused, Value buffer) { builder.create(loc, d[k], buffer, p); }); builder.setInsertionPointAfter(forOpI); builder.create(loc); } /// Implements the rewriting for operator sort and sort_coo. template LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx, uint64_t ny, bool isCoo, PatternRewriter &rewriter) { Location loc = op.getLoc(); SmallVector operands{constantIndex(rewriter, loc, 0), op.getN()}; // Convert `values` to have dynamic shape and append them to `operands`. 
for (Value v : xys) { auto mtp = getMemRefType(v); if (!mtp.isDynamicDim(0)) { auto newMtp = MemRefType::get({ShapedType::kDynamic}, mtp.getElementType()); v = rewriter.create(loc, newMtp, v); } operands.push_back(v); } bool isStable = (op.getAlgorithm() == SparseTensorSortKind::InsertionSortStable); auto insertPoint = op->template getParentOfType(); SmallString<32> funcName(isStable ? kSortStableFuncNamePrefix : kSortNonstableFuncNamePrefix); FuncGeneratorType funcGenerator = isStable ? createSortStableFunc : createSortNonstableFunc; FlatSymbolRefAttr func = getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx, ny, isCoo, operands, funcGenerator); rewriter.replaceOpWithNewOp(op, func, TypeRange(), operands); return success(); } //===---------------------------------------------------------------------===// // The actual sparse buffer rewriting rules. //===---------------------------------------------------------------------===// namespace { /// Sparse rewriting rule for the push_back operator. struct PushBackRewriter : OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PushBackRewriter(MLIRContext *context, bool enableInit) : OpRewritePattern(context), enableBufferInitialization(enableInit) {} LogicalResult matchAndRewrite(PushBackOp op, PatternRewriter &rewriter) const override { // Rewrite push_back(buffer, value, n) to: // new_size = size(buffer) + n // if (new_size > capacity(buffer)) // while new_size > new_capacity // new_capacity = new_capacity*2 // new_buffer = realloc(buffer, new_capacity) // buffer = new_buffer // subBuffer = subviewof(buffer) // linalg.fill subBuffer value // // size(buffer) += n // // The capacity check is skipped when the attribute inbounds is presented. Location loc = op->getLoc(); Value c0 = constantIndex(rewriter, loc, 0); Value buffer = op.getInBuffer(); Value capacity = rewriter.create(loc, buffer, c0); Value size = op.getCurSize(); Value value = op.getValue(); Value n = op.getN() ? 
op.getN() : constantIndex(rewriter, loc, 1); Value newSize = rewriter.create(loc, size, n); auto nValue = dyn_cast_or_null(n.getDefiningOp()); bool nIsOne = (nValue && nValue.value() == 1); if (!op.getInbounds()) { Value cond = rewriter.create( loc, arith::CmpIPredicate::ugt, newSize, capacity); Value c2 = constantIndex(rewriter, loc, 2); auto bufferType = MemRefType::get({ShapedType::kDynamic}, value.getType()); scf::IfOp ifOp = rewriter.create(loc, bufferType, cond, /*else=*/true); // True branch. rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); if (nIsOne) { capacity = rewriter.create(loc, capacity, c2); } else { // Use a do-while loop to calculate the new capacity as follows: // do { new_capacity *= 2 } while (size > new_capacity) scf::WhileOp whileOp = rewriter.create(loc, capacity.getType(), capacity); // The before-region of the WhileOp. Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, {capacity.getType()}, {loc}); rewriter.setInsertionPointToEnd(before); capacity = rewriter.create(loc, before->getArgument(0), c2); cond = rewriter.create(loc, arith::CmpIPredicate::ugt, newSize, capacity); rewriter.create(loc, cond, ValueRange{capacity}); // The after-region of the WhileOp. Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, {capacity.getType()}, {loc}); rewriter.setInsertionPointToEnd(after); rewriter.create(loc, after->getArguments()); rewriter.setInsertionPointAfter(whileOp); capacity = whileOp.getResult(0); } Value newBuffer = rewriter.create(loc, bufferType, buffer, capacity); if (enableBufferInitialization) { Value fillSize = rewriter.create(loc, capacity, newSize); Value fillValue = constantZero(rewriter, loc, value.getType()); Value subBuffer = rewriter.create( loc, newBuffer, /*offset=*/ValueRange{newSize}, /*size=*/ValueRange{fillSize}, /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); rewriter.create(loc, fillValue, subBuffer); } rewriter.create(loc, newBuffer); // False branch. 
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); rewriter.create(loc, buffer); // Prepare for adding the value to the end of the buffer. rewriter.setInsertionPointAfter(ifOp); buffer = ifOp.getResult(0); } // Add the value to the end of the buffer. if (nIsOne) { rewriter.create(loc, value, buffer, size); } else { Value subBuffer = rewriter.create( loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n}, /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); rewriter.create(loc, value, subBuffer); } // Update the buffer size. rewriter.replaceOp(op, {buffer, newSize}); return success(); } private: bool enableBufferInitialization; }; /// Sparse rewriting rule for the sort operator. struct SortRewriter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SortOp op, PatternRewriter &rewriter) const override { SmallVector xys(op.getXs()); xys.append(op.getYs().begin(), op.getYs().end()); return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0, /*isCoo=*/false, rewriter); } }; /// Sparse rewriting rule for the sort_coo operator. struct SortCooRewriter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SortCooOp op, PatternRewriter &rewriter) const override { SmallVector xys; xys.push_back(op.getXy()); xys.append(op.getYs().begin(), op.getYs().end()); uint64_t nx = 1; if (auto nxAttr = op.getNxAttr()) nx = nxAttr.getInt(); uint64_t ny = 0; if (auto nyAttr = op.getNyAttr()) ny = nyAttr.getInt(); return matchAndRewriteSortOp(op, xys, nx, ny, /*isCoo=*/true, rewriter); } }; } // namespace //===---------------------------------------------------------------------===// // Methods that add patterns described in this file to a pattern list. 
//===---------------------------------------------------------------------===// void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization) { patterns.add(patterns.getContext(), enableBufferInitialization); patterns.add(patterns.getContext()); }