xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp (revision 5262865aac683b72f3e66de7a122e0c455ab6b9b)
1062e515bSbixia1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2062e515bSbixia1 //
3062e515bSbixia1 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4062e515bSbixia1 // See https://llvm.org/LICENSE.txt for license information.
5062e515bSbixia1 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6062e515bSbixia1 //
7062e515bSbixia1 //===----------------------------------------------------------------------===//
8062e515bSbixia1 //
9062e515bSbixia1 // This file implements rewriting rules that are specific to sparse tensor
10062e515bSbixia1 // primitives with memref operands.
11062e515bSbixia1 //
12062e515bSbixia1 //===----------------------------------------------------------------------===//
13062e515bSbixia1 
14365777ecSAart Bik #include "Utils/CodegenUtils.h"
15062e515bSbixia1 
16bfbf3bcbSAart Bik #include "mlir/Dialect/Arith/IR/Arith.h"
17062e515bSbixia1 #include "mlir/Dialect/Func/IR/FuncOps.h"
18d45be887Sbixia1 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19a1507668Sbixia1 #include "mlir/Dialect/Math/IR/Math.h"
20062e515bSbixia1 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21062e515bSbixia1 #include "mlir/Dialect/SCF/IR/SCF.h"
22062e515bSbixia1 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
23062e515bSbixia1 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
24062e515bSbixia1 #include "mlir/Support/LLVM.h"
25062e515bSbixia1 
26062e515bSbixia1 using namespace mlir;
27062e515bSbixia1 using namespace mlir::sparse_tensor;
28062e515bSbixia1 
29062e515bSbixia1 //===---------------------------------------------------------------------===//
30062e515bSbixia1 // Helper methods for the actual rewriting rules.
31062e515bSbixia1 //===---------------------------------------------------------------------===//
32062e515bSbixia1 
// Positions of the leading arguments shared by the generated sorting helper
// functions, whose parameter lists start with (lo, hi, xs, ys).
static constexpr uint64_t loIdx = 0;             // The `lo` bound.
static constexpr uint64_t hiIdx = loIdx + 1;     // The `hi` bound.
static constexpr uint64_t xStartIdx = hiIdx + 1; // The first buffer.
369409bbb2Sbixia1 
// Symbol-name prefixes of the generated sorting helper functions. Each one is
// expected to be used as the `namePrefix` of a mangled helper name, which
// appends the permutation and buffer element types.
static constexpr char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_";
static constexpr char kHybridQuickSortFuncNamePrefix[] =
    "_sparse_hybrid_qsort_";
static constexpr char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_";
static constexpr char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
47062e515bSbixia1 
48bfa3bc43SPeiming Liu using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
49bfa3bc43SPeiming Liu                                             AffineMap, uint64_t, uint32_t)>;
50062e515bSbixia1 
51062e515bSbixia1 /// Constructs a function name with this format to facilitate quick sort:
52bfa3bc43SPeiming Liu ///   <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
53bfa3bc43SPeiming Liu ///   <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
54062e515bSbixia1 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
55bfa3bc43SPeiming Liu                                          StringRef namePrefix, AffineMap xPerm,
56bfa3bc43SPeiming Liu                                          uint64_t ny, ValueRange operands) {
57bfa3bc43SPeiming Liu   nameOstream << namePrefix;
58bfa3bc43SPeiming Liu   for (auto res : xPerm.getResults())
591609f1c2Slong.chen     nameOstream << cast<AffineDimExpr>(res).getPosition() << "_";
60062e515bSbixia1 
61bfa3bc43SPeiming Liu   nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
624f729d5aSbixia1   nameOstream << "_coo_" << ny;
634f729d5aSbixia1 
64bfa3bc43SPeiming Liu   constexpr uint64_t yBufferOffset = 1;
654f729d5aSbixia1   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
669916ab03Swren romano     nameOstream << "_" << getMemRefType(v).getElementType();
67062e515bSbixia1 }
68062e515bSbixia1 
69062e515bSbixia1 /// Looks up a function that is appropriate for the given operands being
704f729d5aSbixia1 /// sorted, and creates such a function if it doesn't exist yet. The
71bfa3bc43SPeiming Liu /// parameters `xPerm` and `ny` tell the number of x and y values provided
72bfa3bc43SPeiming Liu /// by the buffer in xStartIdx.
738550aebdSbixia1 //
748550aebdSbixia1 // All sorting function generators take (lo, hi, xs, ys) in `operands` as
758550aebdSbixia1 // parameters for the sorting functions. Other parameters, such as the recursive
768550aebdSbixia1 // call depth, are appended to the end of the parameter list as
778550aebdSbixia1 // "trailing parameters".
78bfa3bc43SPeiming Liu static FlatSymbolRefAttr getMangledSortHelperFunc(
79bfa3bc43SPeiming Liu     OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
80bfa3bc43SPeiming Liu     StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
81bfa3bc43SPeiming Liu     FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
82062e515bSbixia1   SmallString<32> nameBuffer;
83062e515bSbixia1   llvm::raw_svector_ostream nameOstream(nameBuffer);
84bfa3bc43SPeiming Liu   getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
858550aebdSbixia1                                operands.drop_back(nTrailingP));
86062e515bSbixia1 
87062e515bSbixia1   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
88062e515bSbixia1   MLIRContext *context = module.getContext();
89062e515bSbixia1   auto result = SymbolRefAttr::get(context, nameOstream.str());
90062e515bSbixia1   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
91062e515bSbixia1 
92062e515bSbixia1   if (!func) {
93062e515bSbixia1     // Create the function.
94062e515bSbixia1     OpBuilder::InsertionGuard insertionGuard(builder);
95062e515bSbixia1     builder.setInsertionPoint(insertPoint);
96062e515bSbixia1     Location loc = insertPoint.getLoc();
97062e515bSbixia1     func = builder.create<func::FuncOp>(
98062e515bSbixia1         loc, nameOstream.str(),
99062e515bSbixia1         FunctionType::get(context, operands.getTypes(), resultTypes));
100062e515bSbixia1     func.setPrivate();
101bfa3bc43SPeiming Liu     createFunc(builder, module, func, xPerm, ny, nTrailingP);
102062e515bSbixia1   }
103062e515bSbixia1 
104062e515bSbixia1   return result;
105062e515bSbixia1 }
106062e515bSbixia1 
1074f729d5aSbixia1 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
1084f729d5aSbixia1 /// The code to process the value pairs is generated by `bodyBuilder`.
1094f729d5aSbixia1 static void forEachIJPairInXs(
110bfa3bc43SPeiming Liu     OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
111bfa3bc43SPeiming Liu     uint64_t ny,
112bfa3bc43SPeiming Liu     function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
113bfa3bc43SPeiming Liu   Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
114bfa3bc43SPeiming Liu   Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
115bfa3bc43SPeiming Liu   Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
116bfa3bc43SPeiming Liu   for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
1171609f1c2Slong.chen     unsigned actualK = cast<AffineDimExpr>(xPerm.getResult(k)).getPosition();
118bfa3bc43SPeiming Liu     Value ak = constantIndex(builder, loc, actualK);
119bfa3bc43SPeiming Liu     Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
120bfa3bc43SPeiming Liu     Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
121bfa3bc43SPeiming Liu     Value buffer = args[xStartIdx];
122bfa3bc43SPeiming Liu 
1234f729d5aSbixia1     bodyBuilder(k, i, j, buffer);
1244f729d5aSbixia1   }
1254f729d5aSbixia1 }
1264f729d5aSbixia1 
1274f729d5aSbixia1 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
1284f729d5aSbixia1 /// The code to process the value pairs is generated by `bodyBuilder`.
1294f729d5aSbixia1 static void forEachIJPairInAllBuffers(
130bfa3bc43SPeiming Liu     OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
131bfa3bc43SPeiming Liu     uint64_t ny,
132bfa3bc43SPeiming Liu     function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
1334f729d5aSbixia1 
134bfa3bc43SPeiming Liu   // Create code for the first (xPerm + ny) buffers.
135*5262865aSKazu Hirata   SmallVector<AffineExpr> exps(xPerm.getResults());
136bfa3bc43SPeiming Liu   for (unsigned y = 0; y < ny; y++) {
137bfa3bc43SPeiming Liu     exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
138bfa3bc43SPeiming Liu   }
139bfa3bc43SPeiming Liu   AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
140bfa3bc43SPeiming Liu   assert(xyPerm.isPermutation());
1414f729d5aSbixia1 
142bfa3bc43SPeiming Liu   forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
1434f729d5aSbixia1 
144bfa3bc43SPeiming Liu   constexpr uint64_t numHandledBuffers = 1;
1454f729d5aSbixia1   // Create code for the remaining buffers.
1464f729d5aSbixia1   Value i = args[0];
1474f729d5aSbixia1   Value j = args[1];
1484f729d5aSbixia1   for (const auto &arg :
1494f729d5aSbixia1        llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
150bfa3bc43SPeiming Liu     bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
1514f729d5aSbixia1   }
1524f729d5aSbixia1 }
1534f729d5aSbixia1 
1549b800bf7Sbixia1 /// Creates a code block for swapping the values in index i and j for all the
155062e515bSbixia1 /// buffers.
156062e515bSbixia1 //
1579b800bf7Sbixia1 // The generated IR corresponds to this C like algorithm:
158062e515bSbixia1 //     swap(x0[i], x0[j]);
159062e515bSbixia1 //     swap(x1[i], x1[j]);
160062e515bSbixia1 //     ...
161062e515bSbixia1 //     swap(xn[i], xn[j]);
162062e515bSbixia1 //     swap(y0[i], y0[j]);
163062e515bSbixia1 //     ...
164062e515bSbixia1 //     swap(yn[i], yn[j]);
1654f729d5aSbixia1 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
166bfa3bc43SPeiming Liu                        AffineMap xPerm, uint64_t ny) {
1674f729d5aSbixia1   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
1684f729d5aSbixia1     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
1694f729d5aSbixia1     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
1704f729d5aSbixia1     builder.create<memref::StoreOp>(loc, vj, buffer, i);
1714f729d5aSbixia1     builder.create<memref::StoreOp>(loc, vi, buffer, j);
1724f729d5aSbixia1   };
1734f729d5aSbixia1 
174bfa3bc43SPeiming Liu   forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
175062e515bSbixia1 }
176062e515bSbixia1 
1772ef41627Sbixia1 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
1782ef41627Sbixia1 /// each pair is create via `compareBuilder`.
1792ef41627Sbixia1 static Value createInlinedCompareImplementation(
180bfa3bc43SPeiming Liu     OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
181bfa3bc43SPeiming Liu     uint64_t ny,
1822ef41627Sbixia1     function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
1839b800bf7Sbixia1         compareBuilder) {
1842ef41627Sbixia1   Value result;
1854f729d5aSbixia1   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
1862ef41627Sbixia1     bool isFirstDim = (k == 0);
187bfa3bc43SPeiming Liu     bool isLastDim = (k == xPerm.getNumResults() - 1);
1882ef41627Sbixia1     Value val =
1892ef41627Sbixia1         compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
1902ef41627Sbixia1     if (isFirstDim) {
1912ef41627Sbixia1       result = val;
1922ef41627Sbixia1     } else if (!isLastDim) {
1939b800bf7Sbixia1       OpBuilder::InsertionGuard insertionGuard(builder);
1942ef41627Sbixia1       auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
1959b800bf7Sbixia1       builder.setInsertionPointAfter(ifOp);
1969b800bf7Sbixia1       builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
1979b800bf7Sbixia1     }
1984f729d5aSbixia1   };
1994f729d5aSbixia1 
200bfa3bc43SPeiming Liu   forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
2019b800bf7Sbixia1 
2022ef41627Sbixia1   builder.setInsertionPointAfterValue(result);
2032ef41627Sbixia1   return result;
2049b800bf7Sbixia1 }
2059b800bf7Sbixia1 
2062ef41627Sbixia1 /// Generates code to compare whether x[i] is equal to x[j] and returns the
2072ef41627Sbixia1 /// result of the comparison.
2082ef41627Sbixia1 static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
2092ef41627Sbixia1                              Value x, bool isFirstDim, bool isLastDim) {
2109b800bf7Sbixia1   Value vi = builder.create<memref::LoadOp>(loc, x, i);
2119b800bf7Sbixia1   Value vj = builder.create<memref::LoadOp>(loc, x, j);
2129b800bf7Sbixia1 
2132ef41627Sbixia1   Value res;
2142ef41627Sbixia1   if (isLastDim) {
2152ef41627Sbixia1     res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
2162ef41627Sbixia1     // For 1D, we create a compare without any control flow. Otherwise, we
2172ef41627Sbixia1     // create YieldOp to return the result in the nested if-stmt.
2182ef41627Sbixia1     if (!isFirstDim)
2192ef41627Sbixia1       builder.create<scf::YieldOp>(loc, res);
2202ef41627Sbixia1   } else {
2212ef41627Sbixia1     Value ne =
2222ef41627Sbixia1         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
2232ef41627Sbixia1     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
2242ef41627Sbixia1                                                ne, /*else=*/true);
2252ef41627Sbixia1     // If (x[i] != x[j]).
2262ef41627Sbixia1     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
2272ef41627Sbixia1     Value f = constantI1(builder, loc, false);
2289b800bf7Sbixia1     builder.create<scf::YieldOp>(loc, f);
2299b800bf7Sbixia1 
2302ef41627Sbixia1     // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
2312ef41627Sbixia1     // checks the remaining dimensions.
2322ef41627Sbixia1     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
2332ef41627Sbixia1     res = ifOp.getResult(0);
2349b800bf7Sbixia1   }
2359b800bf7Sbixia1 
2362ef41627Sbixia1   return res;
2379b800bf7Sbixia1 }
2389b800bf7Sbixia1 
2392ef41627Sbixia1 /// Creates code to compare whether xs[i] is equal to xs[j].
2409b800bf7Sbixia1 //
2419b800bf7Sbixia1 // The generate IR corresponds to this C like algorithm:
2429b800bf7Sbixia1 //   if (x0[i] != x0[j])
2439b800bf7Sbixia1 //     return false;
2449b800bf7Sbixia1 //   else
2459b800bf7Sbixia1 //     if (x1[i] != x1[j])
2469b800bf7Sbixia1 //       return false;
2479b800bf7Sbixia1 //     else if (x2[2] != x2[j]))
2489b800bf7Sbixia1 //       and so on ...
2492ef41627Sbixia1 static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
250bfa3bc43SPeiming Liu                                     ValueRange args, AffineMap xPerm,
251bfa3bc43SPeiming Liu                                     uint64_t ny, uint32_t nTrailingP = 0) {
2528550aebdSbixia1   // Compare functions don't use trailing parameters.
2538550aebdSbixia1   (void)nTrailingP;
2548550aebdSbixia1   assert(nTrailingP == 0);
255bfa3bc43SPeiming Liu   return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
2564f729d5aSbixia1                                             createEqCompare);
2579b800bf7Sbixia1 }
2589b800bf7Sbixia1 
2592ef41627Sbixia1 /// Generates code to compare whether x[i] is less than x[j] and returns the
2602ef41627Sbixia1 /// result of the comparison.
2612ef41627Sbixia1 static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
2622ef41627Sbixia1                                    Value j, Value x, bool isFirstDim,
263062e515bSbixia1                                    bool isLastDim) {
264062e515bSbixia1   Value vi = builder.create<memref::LoadOp>(loc, x, i);
265062e515bSbixia1   Value vj = builder.create<memref::LoadOp>(loc, x, j);
266062e515bSbixia1 
2672ef41627Sbixia1   Value res;
2682ef41627Sbixia1   if (isLastDim) {
2692ef41627Sbixia1     res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
2702ef41627Sbixia1     // For 1D, we create a compare without any control flow. Otherwise, we
2712ef41627Sbixia1     // create YieldOp to return the result in the nested if-stmt.
2722ef41627Sbixia1     if (!isFirstDim)
2732ef41627Sbixia1       builder.create<scf::YieldOp>(loc, res);
274062e515bSbixia1   } else {
2752ef41627Sbixia1     Value ne =
2762ef41627Sbixia1         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
2772ef41627Sbixia1     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
2782ef41627Sbixia1                                                ne, /*else=*/true);
2792ef41627Sbixia1     // If (x[i] != x[j]).
2802ef41627Sbixia1     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
2812ef41627Sbixia1     Value lt =
2822ef41627Sbixia1         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
2832ef41627Sbixia1     builder.create<scf::YieldOp>(loc, lt);
284062e515bSbixia1 
2852ef41627Sbixia1     // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
2862ef41627Sbixia1     // checks the remaining dimensions.
2872ef41627Sbixia1     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
2882ef41627Sbixia1     res = ifOp.getResult(0);
289062e515bSbixia1   }
290062e515bSbixia1 
2912ef41627Sbixia1   return res;
292062e515bSbixia1 }
293062e515bSbixia1 
2942ef41627Sbixia1 /// Creates code to compare whether xs[i] is less than xs[j].
295062e515bSbixia1 //
296062e515bSbixia1 // The generate IR corresponds to this C like algorithm:
2972ef41627Sbixia1 //   if (x0[i] != x0[j])
2982ef41627Sbixia1 //     return x0[i] < x0[j];
2992ef41627Sbixia1 //   else if (x1[j] != x1[i])
3002ef41627Sbixia1 //     return x1[i] < x1[j];
301062e515bSbixia1 //   else
302062e515bSbixia1 //       and so on ...
3032ef41627Sbixia1 static Value createInlinedLessThan(OpBuilder &builder, Location loc,
304bfa3bc43SPeiming Liu                                    ValueRange args, AffineMap xPerm,
305bfa3bc43SPeiming Liu                                    uint64_t ny, uint32_t nTrailingP = 0) {
3068550aebdSbixia1   // Compare functions don't use trailing parameters.
3078550aebdSbixia1   (void)nTrailingP;
3088550aebdSbixia1   assert(nTrailingP == 0);
309bfa3bc43SPeiming Liu   return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
3109b800bf7Sbixia1                                             createLessThanCompare);
311062e515bSbixia1 }
312062e515bSbixia1 
3139409bbb2Sbixia1 /// Creates a function to use a binary search to find the insertion point for
3149409bbb2Sbixia1 /// inserting xs[hi] to the sorted values xs[lo..hi).
3159409bbb2Sbixia1 //
3169409bbb2Sbixia1 // The generate IR corresponds to this C like algorithm:
3179409bbb2Sbixia1 //   p = hi
3189409bbb2Sbixia1 //   while (lo < hi)
3199409bbb2Sbixia1 //      mid = (lo + hi) >> 1
3209409bbb2Sbixia1 //      if (xs[p] < xs[mid])
3219409bbb2Sbixia1 //        hi = mid
3229409bbb2Sbixia1 //      else
3239409bbb2Sbixia1 //        lo = mid - 1
3249409bbb2Sbixia1 //   return lo;
3259409bbb2Sbixia1 //
3269409bbb2Sbixia1 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
327bfa3bc43SPeiming Liu                                    func::FuncOp func, AffineMap xPerm,
328bfa3bc43SPeiming Liu                                    uint64_t ny, uint32_t nTrailingP = 0) {
3298550aebdSbixia1   // Binary search doesn't use trailing parameters.
3308550aebdSbixia1   (void)nTrailingP;
3318550aebdSbixia1   assert(nTrailingP == 0);
3329409bbb2Sbixia1   OpBuilder::InsertionGuard insertionGuard(builder);
3339409bbb2Sbixia1   Block *entryBlock = func.addEntryBlock();
3349409bbb2Sbixia1   builder.setInsertionPointToStart(entryBlock);
3359409bbb2Sbixia1 
3369409bbb2Sbixia1   Location loc = func.getLoc();
3379409bbb2Sbixia1   ValueRange args = entryBlock->getArguments();
3389409bbb2Sbixia1   Value p = args[hiIdx];
3398550aebdSbixia1   SmallVector<Type, 2> types(2, p.getType()); // Only two types.
3409409bbb2Sbixia1   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
3419409bbb2Sbixia1       loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
3429409bbb2Sbixia1 
3439409bbb2Sbixia1   // The before-region of the WhileOp.
3449409bbb2Sbixia1   Block *before =
3459409bbb2Sbixia1       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
3469409bbb2Sbixia1   builder.setInsertionPointToEnd(before);
3479409bbb2Sbixia1   Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
3489409bbb2Sbixia1                                               before->getArgument(0),
3499409bbb2Sbixia1                                               before->getArgument(1));
3509409bbb2Sbixia1   builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
3519409bbb2Sbixia1 
3529409bbb2Sbixia1   // The after-region of the WhileOp.
3539409bbb2Sbixia1   Block *after =
3549409bbb2Sbixia1       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
3559409bbb2Sbixia1   builder.setInsertionPointToEnd(after);
3569409bbb2Sbixia1   Value lo = after->getArgument(0);
3579409bbb2Sbixia1   Value hi = after->getArgument(1);
3589409bbb2Sbixia1   // Compute mid = (lo + hi) >> 1.
3599409bbb2Sbixia1   Value c1 = constantIndex(builder, loc, 1);
3609409bbb2Sbixia1   Value mid = builder.create<arith::ShRUIOp>(
3619409bbb2Sbixia1       loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
3629409bbb2Sbixia1   Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
3639409bbb2Sbixia1 
3649409bbb2Sbixia1   // Compare xs[p] < xs[mid].
3650e1708ffSAart Bik   SmallVector<Value> compareOperands{p, mid};
366bfa3bc43SPeiming Liu   constexpr uint64_t numXBuffers = 1;
3679409bbb2Sbixia1   compareOperands.append(args.begin() + xStartIdx,
3684f729d5aSbixia1                          args.begin() + xStartIdx + numXBuffers);
369bfa3bc43SPeiming Liu   Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
3709409bbb2Sbixia1   // Update lo and hi for the WhileOp as follows:
3719409bbb2Sbixia1   //   if (xs[p] < xs[mid]))
3729409bbb2Sbixia1   //     hi = mid;
3739409bbb2Sbixia1   //   else
3749409bbb2Sbixia1   //     lo = mid + 1;
3759409bbb2Sbixia1   Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
3769409bbb2Sbixia1   Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
3779409bbb2Sbixia1   builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
3789409bbb2Sbixia1 
3799409bbb2Sbixia1   builder.setInsertionPointAfter(whileOp);
3809409bbb2Sbixia1   builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
3819409bbb2Sbixia1 }
3829409bbb2Sbixia1 
3839b800bf7Sbixia1 /// Creates code to advance i in a loop based on xs[p] as follows:
3849b800bf7Sbixia1 ///   while (xs[i] < xs[p]) i += step (step > 0)
3859b800bf7Sbixia1 /// or
3869b800bf7Sbixia1 ///   while (xs[i] > xs[p]) i += step (step < 0)
3879b800bf7Sbixia1 /// The routine returns i as well as a boolean value to indicate whether
3889b800bf7Sbixia1 /// xs[i] == xs[p].
389bfa3bc43SPeiming Liu static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
390bfa3bc43SPeiming Liu                                               ModuleOp module,
391bfa3bc43SPeiming Liu                                               func::FuncOp func, ValueRange xs,
392bfa3bc43SPeiming Liu                                               Value i, Value p, AffineMap xPerm,
393bfa3bc43SPeiming Liu                                               uint64_t ny, int step) {
394062e515bSbixia1   Location loc = func.getLoc();
3959b800bf7Sbixia1   scf::WhileOp whileOp =
3969b800bf7Sbixia1       builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
397062e515bSbixia1 
3989b800bf7Sbixia1   Block *before =
3999b800bf7Sbixia1       builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
4009b800bf7Sbixia1   builder.setInsertionPointToEnd(before);
4010e1708ffSAart Bik   SmallVector<Value> compareOperands;
4029b800bf7Sbixia1   if (step > 0) {
4039b800bf7Sbixia1     compareOperands.push_back(before->getArgument(0));
4049b800bf7Sbixia1     compareOperands.push_back(p);
4059b800bf7Sbixia1   } else {
4069b800bf7Sbixia1     assert(step < 0);
4079b800bf7Sbixia1     compareOperands.push_back(p);
4089b800bf7Sbixia1     compareOperands.push_back(before->getArgument(0));
4099b800bf7Sbixia1   }
410062e515bSbixia1   compareOperands.append(xs.begin(), xs.end());
411bfa3bc43SPeiming Liu   Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
4129b800bf7Sbixia1   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
413062e515bSbixia1 
4149b800bf7Sbixia1   Block *after =
4159b800bf7Sbixia1       builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
4169b800bf7Sbixia1   builder.setInsertionPointToEnd(after);
4179b800bf7Sbixia1   Value cs = constantIndex(builder, loc, step);
4189b800bf7Sbixia1   i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
4199b800bf7Sbixia1   builder.create<scf::YieldOp>(loc, ValueRange{i});
4209b800bf7Sbixia1   i = whileOp.getResult(0);
4219b800bf7Sbixia1 
4229b800bf7Sbixia1   builder.setInsertionPointAfter(whileOp);
4239b800bf7Sbixia1   compareOperands[0] = i;
4249b800bf7Sbixia1   compareOperands[1] = p;
4259b800bf7Sbixia1   Value compareEq =
426bfa3bc43SPeiming Liu       createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);
4279b800bf7Sbixia1 
4289b800bf7Sbixia1   return std::make_pair(whileOp.getResult(0), compareEq);
4299b800bf7Sbixia1 }
4309b800bf7Sbixia1 
431abb05014Sbixia1 /// Creates and returns an IfOp to compare two elements and swap the elements
432abb05014Sbixia1 /// if compareFunc(data[b], data[a]) returns true. The new insertion point is
433abb05014Sbixia1 /// right after the swap instructions.
434abb05014Sbixia1 static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
435bfa3bc43SPeiming Liu                                        AffineMap xPerm, uint64_t ny,
436abb05014Sbixia1                                        SmallVectorImpl<Value> &swapOperands,
437abb05014Sbixia1                                        SmallVectorImpl<Value> &compareOperands,
438abb05014Sbixia1                                        Value a, Value b) {
439abb05014Sbixia1   // Compare(data[b], data[a]).
440abb05014Sbixia1   compareOperands[0] = b;
441abb05014Sbixia1   compareOperands[1] = a;
442bfa3bc43SPeiming Liu   Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
443abb05014Sbixia1   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
444abb05014Sbixia1   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
445abb05014Sbixia1   swapOperands[0] = b;
446abb05014Sbixia1   swapOperands[1] = a;
447bfa3bc43SPeiming Liu   createSwap(builder, loc, swapOperands, xPerm, ny);
448abb05014Sbixia1   return ifOp;
449abb05014Sbixia1 }
450abb05014Sbixia1 
451abb05014Sbixia1 /// Creates code to insert the 3rd element to a list of two sorted elements.
452bfa3bc43SPeiming Liu static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
453bfa3bc43SPeiming Liu                             uint64_t ny, SmallVectorImpl<Value> &swapOperands,
454abb05014Sbixia1                             SmallVectorImpl<Value> &compareOperands, Value v0,
455abb05014Sbixia1                             Value v1, Value v2) {
456bfa3bc43SPeiming Liu   scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
457bfa3bc43SPeiming Liu                                          compareOperands, v1, v2);
458bfa3bc43SPeiming Liu   createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
459bfa3bc43SPeiming Liu                         v0, v1);
460abb05014Sbixia1   builder.setInsertionPointAfter(ifOp);
461abb05014Sbixia1 }
462abb05014Sbixia1 
463abb05014Sbixia1 /// Creates code to sort 3 elements.
464bfa3bc43SPeiming Liu static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
465bfa3bc43SPeiming Liu                         uint64_t ny, SmallVectorImpl<Value> &swapOperands,
466abb05014Sbixia1                         SmallVectorImpl<Value> &compareOperands, Value v0,
467abb05014Sbixia1                         Value v1, Value v2) {
468abb05014Sbixia1   // Sort the first 2 elements.
469bfa3bc43SPeiming Liu   scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
470bfa3bc43SPeiming Liu                                           compareOperands, v0, v1);
471abb05014Sbixia1   builder.setInsertionPointAfter(ifOp1);
472abb05014Sbixia1 
473abb05014Sbixia1   // Insert the 3th element.
474bfa3bc43SPeiming Liu   createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
475bfa3bc43SPeiming Liu                   v1, v2);
476abb05014Sbixia1 }
477abb05014Sbixia1 
478abb05014Sbixia1 /// Creates code to sort 5 elements.
479bfa3bc43SPeiming Liu static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
480bfa3bc43SPeiming Liu                         uint64_t ny, SmallVectorImpl<Value> &swapOperands,
481abb05014Sbixia1                         SmallVectorImpl<Value> &compareOperands, Value v0,
482abb05014Sbixia1                         Value v1, Value v2, Value v3, Value v4) {
483abb05014Sbixia1   // Sort the first 3 elements.
484bfa3bc43SPeiming Liu   createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
485bfa3bc43SPeiming Liu               v2);
486abb05014Sbixia1 
487abb05014Sbixia1   auto insert4th = [&]() {
488abb05014Sbixia1     scf::IfOp ifOp = createCompareThenSwap(
489bfa3bc43SPeiming Liu         builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
490bfa3bc43SPeiming Liu     createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
491bfa3bc43SPeiming Liu                     v1, v2);
492abb05014Sbixia1     builder.setInsertionPointAfter(ifOp);
493abb05014Sbixia1   };
494abb05014Sbixia1 
495abb05014Sbixia1   // Insert the 4th element.
496abb05014Sbixia1   insert4th();
497abb05014Sbixia1 
498abb05014Sbixia1   // Insert the 5th element.
499bfa3bc43SPeiming Liu   scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
500bfa3bc43SPeiming Liu                                          compareOperands, v3, v4);
501abb05014Sbixia1   insert4th();
502abb05014Sbixia1   builder.setInsertionPointAfter(ifOp);
503abb05014Sbixia1 }
504abb05014Sbixia1 
505abb05014Sbixia1 /// Creates a code block to swap the values in indices lo, mi, and hi so that
506abb05014Sbixia1 /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
507abb05014Sbixia1 /// the number of values in range [lo, hi) is more than a threshold, we also
508abb05014Sbixia1 /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
5097cec4d16Sbixia1 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
510bfa3bc43SPeiming Liu                               func::FuncOp func, AffineMap xPerm, uint64_t ny,
511bfa3bc43SPeiming Liu                               Value lo, Value hi, Value mi, ValueRange args) {
5127cec4d16Sbixia1   SmallVector<Value> compareOperands{mi, lo};
513bfa3bc43SPeiming Liu   constexpr uint64_t numXBuffers = 1;
5147cec4d16Sbixia1   compareOperands.append(args.begin() + xStartIdx,
5157cec4d16Sbixia1                          args.begin() + xStartIdx + numXBuffers);
516abb05014Sbixia1   SmallVector<Value> swapOperands{mi, lo};
5177cec4d16Sbixia1   swapOperands.append(args.begin() + xStartIdx, args.end());
518abb05014Sbixia1   Location loc = func.getLoc();
519abb05014Sbixia1   Value c1 = constantIndex(builder, loc, 1);
520abb05014Sbixia1   Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1);
521abb05014Sbixia1   Value len = builder.create<arith::SubIOp>(loc, hiP1, lo);
522abb05014Sbixia1   Value lenThreshold = constantIndex(builder, loc, 1000);
523abb05014Sbixia1   Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
524abb05014Sbixia1                                                 len, lenThreshold);
525abb05014Sbixia1   scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
526abb05014Sbixia1 
527abb05014Sbixia1   // When len < 1000, choose pivot from median of 3 values.
528abb05014Sbixia1   builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
529bfa3bc43SPeiming Liu   createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
530bfa3bc43SPeiming Liu               hi);
531abb05014Sbixia1 
532abb05014Sbixia1   // When len >= 1000, choose pivot from median of 5 values.
533abb05014Sbixia1   builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
534abb05014Sbixia1   Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1);
535abb05014Sbixia1   Value a = builder.create<arith::AddIOp>(loc, lo, miP1);
536abb05014Sbixia1   // Value a is the middle between [loc, mi].
537abb05014Sbixia1   a = builder.create<arith::ShRUIOp>(loc, a, c1);
538abb05014Sbixia1   Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
539abb05014Sbixia1   // Value b is the middle between [mi, hi].
540abb05014Sbixia1   b = builder.create<arith::ShRUIOp>(loc, b, c1);
541bfa3bc43SPeiming Liu   createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
542bfa3bc43SPeiming Liu               b, hi);
543abb05014Sbixia1 
544abb05014Sbixia1   builder.setInsertionPointAfter(lenIf);
5457cec4d16Sbixia1 }
5467cec4d16Sbixia1 
5479b800bf7Sbixia1 /// Creates a function to perform quick sort partition on the values in the
5489b800bf7Sbixia1 /// range of index [lo, hi), assuming lo < hi.
5499b800bf7Sbixia1 //
5509b800bf7Sbixia1 // The generated IR corresponds to this C like algorithm:
5519b800bf7Sbixia1 // int partition(lo, hi, xs) {
5529b800bf7Sbixia1 //   p = (lo+hi)/2  // pivot index
5539b800bf7Sbixia1 //   i = lo
5549b800bf7Sbixia1 //   j = hi-1
5554176ce61SPeiming Liu //   while (true) do {
5569b800bf7Sbixia1 //     while (xs[i] < xs[p]) i ++;
5579b800bf7Sbixia1 //     i_eq = (xs[i] == xs[p]);
5589b800bf7Sbixia1 //     while (xs[j] > xs[p]) j --;
5599b800bf7Sbixia1 //     j_eq = (xs[j] == xs[p]);
5604176ce61SPeiming Liu //
5614176ce61SPeiming Liu //     if (i >= j) return j + 1;
5624176ce61SPeiming Liu //
5639b800bf7Sbixia1 //     if (i < j) {
5649b800bf7Sbixia1 //       swap(xs[i], xs[j])
5659b800bf7Sbixia1 //       if (i == p) {
5669b800bf7Sbixia1 //         p = j;
5679b800bf7Sbixia1 //       } else if (j == p) {
5689b800bf7Sbixia1 //         p = i;
5699b800bf7Sbixia1 //       }
5709b800bf7Sbixia1 //       if (i_eq && j_eq) {
5719b800bf7Sbixia1 //         ++i;
5729b800bf7Sbixia1 //         --j;
5739b800bf7Sbixia1 //       }
5749b800bf7Sbixia1 //     }
5759b800bf7Sbixia1 //   }
5769b800bf7Sbixia1 // }
5779b800bf7Sbixia1 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
578bfa3bc43SPeiming Liu                                 func::FuncOp func, AffineMap xPerm, uint64_t ny,
579bfa3bc43SPeiming Liu                                 uint32_t nTrailingP = 0) {
5808550aebdSbixia1   // Quick sort partition doesn't use trailing parameters.
5818550aebdSbixia1   (void)nTrailingP;
5828550aebdSbixia1   assert(nTrailingP == 0);
5839b800bf7Sbixia1   OpBuilder::InsertionGuard insertionGuard(builder);
5849b800bf7Sbixia1 
5859b800bf7Sbixia1   Block *entryBlock = func.addEntryBlock();
5869b800bf7Sbixia1   builder.setInsertionPointToStart(entryBlock);
5879b800bf7Sbixia1 
5889b800bf7Sbixia1   Location loc = func.getLoc();
5899b800bf7Sbixia1   ValueRange args = entryBlock->getArguments();
5909b800bf7Sbixia1   Value lo = args[loIdx];
5919b800bf7Sbixia1   Value hi = args[hiIdx];
5929b800bf7Sbixia1   Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
5939b800bf7Sbixia1   Value c1 = constantIndex(builder, loc, 1);
5949b800bf7Sbixia1   Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
5959b800bf7Sbixia1 
5969b800bf7Sbixia1   Value i = lo;
5979b800bf7Sbixia1   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
598bfa3bc43SPeiming Liu   createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
5994176ce61SPeiming Liu   Value trueVal = constantI1(builder, loc, true); // The value for while (true)
6004176ce61SPeiming Liu   SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
6014176ce61SPeiming Liu   SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
6024176ce61SPeiming Liu                              trueVal.getType()};
6039b800bf7Sbixia1   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
6049b800bf7Sbixia1 
6059b800bf7Sbixia1   // The before-region of the WhileOp.
6064176ce61SPeiming Liu   Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
6074176ce61SPeiming Liu                                       {loc, loc, loc, loc});
6089b800bf7Sbixia1   builder.setInsertionPointToEnd(before);
6094176ce61SPeiming Liu   builder.create<scf::ConditionOp>(loc, before->getArgument(3),
6104176ce61SPeiming Liu                                    before->getArguments());
6119b800bf7Sbixia1 
6129b800bf7Sbixia1   // The after-region of the WhileOp.
6139b800bf7Sbixia1   Block *after =
6144176ce61SPeiming Liu       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
6159b800bf7Sbixia1   builder.setInsertionPointToEnd(after);
6169b800bf7Sbixia1   i = after->getArgument(0);
6179b800bf7Sbixia1   j = after->getArgument(1);
6189b800bf7Sbixia1   p = after->getArgument(2);
6199b800bf7Sbixia1 
620bfa3bc43SPeiming Liu   constexpr uint64_t numXBuffers = 1;
6214f729d5aSbixia1   auto [iresult, iCompareEq] =
6224f729d5aSbixia1       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
623bfa3bc43SPeiming Liu                      i, p, xPerm, ny, 1);
6249b800bf7Sbixia1   i = iresult;
6254f729d5aSbixia1   auto [jresult, jCompareEq] =
6264f729d5aSbixia1       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
627bfa3bc43SPeiming Liu                      j, p, xPerm, ny, -1);
6289b800bf7Sbixia1   j = jresult;
6299b800bf7Sbixia1 
6309b800bf7Sbixia1   // If i < j:
6314176ce61SPeiming Liu   Value cond =
6324176ce61SPeiming Liu       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
6339b800bf7Sbixia1   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
634062e515bSbixia1   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
6350e1708ffSAart Bik   SmallVector<Value> swapOperands{i, j};
636062e515bSbixia1   swapOperands.append(args.begin() + xStartIdx, args.end());
637bfa3bc43SPeiming Liu   createSwap(builder, loc, swapOperands, xPerm, ny);
6389b800bf7Sbixia1   // If the pivot is moved, update p with the new pivot.
6399b800bf7Sbixia1   Value icond =
6409b800bf7Sbixia1       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
6419b800bf7Sbixia1   scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
6429b800bf7Sbixia1                                               icond, /*else=*/true);
6439b800bf7Sbixia1   builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
6449b800bf7Sbixia1   builder.create<scf::YieldOp>(loc, ValueRange{j});
6459b800bf7Sbixia1   builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
6469b800bf7Sbixia1   Value jcond =
6479b800bf7Sbixia1       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
6489b800bf7Sbixia1   scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
6499b800bf7Sbixia1                                               jcond, /*else=*/true);
6509b800bf7Sbixia1   builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
6519b800bf7Sbixia1   builder.create<scf::YieldOp>(loc, ValueRange{i});
6529b800bf7Sbixia1   builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
6539b800bf7Sbixia1   builder.create<scf::YieldOp>(loc, ValueRange{p});
6549b800bf7Sbixia1   builder.setInsertionPointAfter(ifOpJ);
6559b800bf7Sbixia1   builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
6569b800bf7Sbixia1   builder.setInsertionPointAfter(ifOpI);
6579b800bf7Sbixia1   Value compareEqIJ =
6589b800bf7Sbixia1       builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
6599b800bf7Sbixia1   scf::IfOp ifOp2 = builder.create<scf::IfOp>(
6609b800bf7Sbixia1       loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
6619b800bf7Sbixia1   builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
6629b800bf7Sbixia1   Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
6639b800bf7Sbixia1   Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
6649b800bf7Sbixia1   builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
6659b800bf7Sbixia1   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
6669b800bf7Sbixia1   builder.create<scf::YieldOp>(loc, ValueRange{i, j});
6679b800bf7Sbixia1   builder.setInsertionPointAfter(ifOp2);
6689b800bf7Sbixia1   builder.create<scf::YieldOp>(
6699b800bf7Sbixia1       loc,
6704176ce61SPeiming Liu       ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
6714176ce61SPeiming Liu                  /*cont=*/constantI1(builder, loc, true)});
672062e515bSbixia1 
6734176ce61SPeiming Liu   // False branch for if i < j (i.e., i >= j):
674062e515bSbixia1   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
6754176ce61SPeiming Liu   p = builder.create<arith::AddIOp>(loc, j,
6764176ce61SPeiming Liu                                     constantOne(builder, loc, j.getType()));
6774176ce61SPeiming Liu   builder.create<scf::YieldOp>(
6784176ce61SPeiming Liu       loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});
679062e515bSbixia1 
6809b800bf7Sbixia1   // Return for the whileOp.
681062e515bSbixia1   builder.setInsertionPointAfter(ifOp);
6829b800bf7Sbixia1   builder.create<scf::YieldOp>(loc, ifOp.getResults());
683062e515bSbixia1 
6849b800bf7Sbixia1   // Return for the function.
6859b800bf7Sbixia1   builder.setInsertionPointAfter(whileOp);
6869b800bf7Sbixia1   builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
687062e515bSbixia1 }
688062e515bSbixia1 
6893b1c86cdSbixia1 /// Computes (n-2)/n, assuming n has index type.
6903b1c86cdSbixia1 static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
6913b1c86cdSbixia1                                       Value n) {
6923b1c86cdSbixia1   Value i2 = constantIndex(builder, loc, 2);
6933b1c86cdSbixia1   Value res = builder.create<arith::SubIOp>(loc, n, i2);
6943b1c86cdSbixia1   Value i1 = constantIndex(builder, loc, 1);
6953b1c86cdSbixia1   return builder.create<arith::ShRUIOp>(loc, res, i1);
6963b1c86cdSbixia1 }
6973b1c86cdSbixia1 
6983b1c86cdSbixia1 /// Creates a function to heapify the subtree with root `start` within the full
6993b1c86cdSbixia1 /// binary tree in the range of index [first, first + n).
7003b1c86cdSbixia1 //
7013b1c86cdSbixia1 // The generated IR corresponds to this C like algorithm:
7023b1c86cdSbixia1 // void shiftDown(first, start, n, data) {
7033b1c86cdSbixia1 //   if (n >= 2) {
7043b1c86cdSbixia1 //     child = start - first
7053b1c86cdSbixia1 //     if ((n-2)/2 >= child) {
7063b1c86cdSbixia1 //       // Left child exists.
7073b1c86cdSbixia1 //       child = child * 2 + 1 // Initialize the bigger child to left child.
7083b1c86cdSbixia1 //       childIndex = child + first
7093b1c86cdSbixia1 //       if (child+1 < n && data[childIndex] < data[childIndex+1])
7103b1c86cdSbixia1 //         // Right child exits and is bigger.
7113b1c86cdSbixia1 //         childIndex++; child++;
7123b1c86cdSbixia1 //       // Shift data[start] down to where it belongs in the subtree.
7133b1c86cdSbixia1 //       while (data[start] < data[childIndex) {
7143b1c86cdSbixia1 //         swap(data[start], data[childIndex])
7153b1c86cdSbixia1 //         start = childIndex
7163b1c86cdSbixia1 //         if ((n - 2)/2 >= child) {
7173b1c86cdSbixia1 //           // Left child exists.
7183b1c86cdSbixia1 //           child = 2*child + 1
7193b1c86cdSbixia1 //           childIndex = child + 1
7203b1c86cdSbixia1 //           if (child + 1) < n && data[childIndex] < data[childIndex+1]
7213b1c86cdSbixia1 //             childIndex++; child++;
7223b1c86cdSbixia1 //         }
7233b1c86cdSbixia1 //       }
7243b1c86cdSbixia1 //     }
7253b1c86cdSbixia1 //   }
7263b1c86cdSbixia1 // }
7273b1c86cdSbixia1 //
7283b1c86cdSbixia1 static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
729bfa3bc43SPeiming Liu                                 func::FuncOp func, AffineMap xPerm, uint64_t ny,
730bfa3bc43SPeiming Liu                                 uint32_t nTrailingP) {
7313b1c86cdSbixia1   // The value n is passed in as a trailing parameter.
7323b1c86cdSbixia1   assert(nTrailingP == 1);
7333b1c86cdSbixia1   OpBuilder::InsertionGuard insertionGuard(builder);
7343b1c86cdSbixia1   Block *entryBlock = func.addEntryBlock();
7353b1c86cdSbixia1   builder.setInsertionPointToStart(entryBlock);
7363b1c86cdSbixia1 
7373b1c86cdSbixia1   Location loc = func.getLoc();
7383b1c86cdSbixia1   Value n = entryBlock->getArguments().back();
7393b1c86cdSbixia1   ValueRange args = entryBlock->getArguments().drop_back();
7403b1c86cdSbixia1   Value first = args[loIdx];
7413b1c86cdSbixia1   Value start = args[hiIdx];
7423b1c86cdSbixia1 
7433b1c86cdSbixia1   // If (n >= 2).
7443b1c86cdSbixia1   Value c2 = constantIndex(builder, loc, 2);
7453b1c86cdSbixia1   Value condN =
7463b1c86cdSbixia1       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
7473b1c86cdSbixia1   scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
7483b1c86cdSbixia1   builder.setInsertionPointToStart(&ifN.getThenRegion().front());
7493b1c86cdSbixia1   Value child = builder.create<arith::SubIOp>(loc, start, first);
7503b1c86cdSbixia1 
7513b1c86cdSbixia1   // If ((n-2)/2 >= child).
7523b1c86cdSbixia1   Value t = createSubTwoDividedByTwo(builder, loc, n);
7533b1c86cdSbixia1   Value condNc =
7543b1c86cdSbixia1       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
7553b1c86cdSbixia1   scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);
7563b1c86cdSbixia1 
7573b1c86cdSbixia1   builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
7583b1c86cdSbixia1   Value c1 = constantIndex(builder, loc, 1);
7593b1c86cdSbixia1   SmallVector<Value> compareOperands{start, start};
760bfa3bc43SPeiming Liu   constexpr uint64_t numXBuffers = 1;
7613b1c86cdSbixia1   compareOperands.append(args.begin() + xStartIdx,
7623b1c86cdSbixia1                          args.begin() + xStartIdx + numXBuffers);
7633b1c86cdSbixia1 
7643b1c86cdSbixia1   // Generate code to inspect the children of 'r' and return the larger child
7653b1c86cdSbixia1   // as follows:
7663b1c86cdSbixia1   //   child = r * 2 + 1 // Left child.
7673b1c86cdSbixia1   //   childIndex = child + first
7683b1c86cdSbixia1   //   if (child+1 < n && data[childIndex] < data[childIndex+1])
7693b1c86cdSbixia1   //     childIndex ++; child ++ // Right child is bigger.
7703b1c86cdSbixia1   auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
7713b1c86cdSbixia1     Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
7723b1c86cdSbixia1     lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
7733b1c86cdSbixia1     Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
7743b1c86cdSbixia1     Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
7753b1c86cdSbixia1     Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
7763b1c86cdSbixia1                                                 rChild, n);
7773b1c86cdSbixia1     SmallVector<Type, 2> ifTypes(2, r.getType());
7783b1c86cdSbixia1     scf::IfOp if1 =
7793b1c86cdSbixia1         builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
7803b1c86cdSbixia1     builder.setInsertionPointToStart(&if1.getThenRegion().front());
7813b1c86cdSbixia1     Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
7823b1c86cdSbixia1     // Compare data[left] < data[right].
7833b1c86cdSbixia1     compareOperands[0] = lChildIdx;
7843b1c86cdSbixia1     compareOperands[1] = rChildIdx;
7852ef41627Sbixia1     Value cond2 =
786bfa3bc43SPeiming Liu         createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
7873b1c86cdSbixia1     scf::IfOp if2 =
7883b1c86cdSbixia1         builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
7893b1c86cdSbixia1     builder.setInsertionPointToStart(&if2.getThenRegion().front());
7903b1c86cdSbixia1     builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
7913b1c86cdSbixia1     builder.setInsertionPointToStart(&if2.getElseRegion().front());
7923b1c86cdSbixia1     builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
7933b1c86cdSbixia1     builder.setInsertionPointAfter(if2);
7943b1c86cdSbixia1     builder.create<scf::YieldOp>(loc, if2.getResults());
7953b1c86cdSbixia1     builder.setInsertionPointToStart(&if1.getElseRegion().front());
7963b1c86cdSbixia1     builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
7973b1c86cdSbixia1     builder.setInsertionPointAfter(if1);
7983b1c86cdSbixia1     return std::make_pair(if1.getResult(0), if1.getResult(1));
7993b1c86cdSbixia1   };
8003b1c86cdSbixia1 
8013b1c86cdSbixia1   Value childIdx;
8023b1c86cdSbixia1   std::tie(child, childIdx) = getLargerChild(child);
8033b1c86cdSbixia1 
8043b1c86cdSbixia1   // While (data[start] < data[childIndex]).
8053b1c86cdSbixia1   SmallVector<Type, 3> types(3, child.getType());
8063b1c86cdSbixia1   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
8073b1c86cdSbixia1       loc, types, SmallVector<Value, 2>{start, child, childIdx});
8083b1c86cdSbixia1 
8093b1c86cdSbixia1   // The before-region of the WhileOp.
8103b1c86cdSbixia1   SmallVector<Location, 3> locs(3, loc);
8113b1c86cdSbixia1   Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
8123b1c86cdSbixia1   builder.setInsertionPointToEnd(before);
8133b1c86cdSbixia1   start = before->getArgument(0);
8143b1c86cdSbixia1   childIdx = before->getArgument(2);
8153b1c86cdSbixia1   compareOperands[0] = start;
8163b1c86cdSbixia1   compareOperands[1] = childIdx;
817bfa3bc43SPeiming Liu   Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
8183b1c86cdSbixia1   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
8193b1c86cdSbixia1 
8203b1c86cdSbixia1   // The after-region of the WhileOp.
8213b1c86cdSbixia1   Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
8223b1c86cdSbixia1   start = after->getArgument(0);
8233b1c86cdSbixia1   child = after->getArgument(1);
8243b1c86cdSbixia1   childIdx = after->getArgument(2);
8253b1c86cdSbixia1   SmallVector<Value> swapOperands{start, childIdx};
8263b1c86cdSbixia1   swapOperands.append(args.begin() + xStartIdx, args.end());
827bfa3bc43SPeiming Liu   createSwap(builder, loc, swapOperands, xPerm, ny);
8283b1c86cdSbixia1   start = childIdx;
8293b1c86cdSbixia1   Value cond2 =
8303b1c86cdSbixia1       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
8313b1c86cdSbixia1   scf::IfOp if2 = builder.create<scf::IfOp>(
8323b1c86cdSbixia1       loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
8333b1c86cdSbixia1   builder.setInsertionPointToStart(&if2.getThenRegion().front());
8343b1c86cdSbixia1   auto [newChild, newChildIdx] = getLargerChild(child);
8353b1c86cdSbixia1   builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
8363b1c86cdSbixia1   builder.setInsertionPointToStart(&if2.getElseRegion().front());
8373b1c86cdSbixia1   builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
8383b1c86cdSbixia1   builder.setInsertionPointAfter(if2);
8393b1c86cdSbixia1   builder.create<scf::YieldOp>(
8403b1c86cdSbixia1       loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});
8413b1c86cdSbixia1 
8423b1c86cdSbixia1   builder.setInsertionPointAfter(ifN);
8433b1c86cdSbixia1   builder.create<func::ReturnOp>(loc);
8443b1c86cdSbixia1 }
8453b1c86cdSbixia1 
8463b1c86cdSbixia1 /// Creates a function to perform heap sort on the values in the range of index
8473b1c86cdSbixia1 /// [lo, hi) with the assumption hi - lo >= 2.
8483b1c86cdSbixia1 //
8493b1c86cdSbixia1 // The generate IR corresponds to this C like algorithm:
8503b1c86cdSbixia1 // void heapSort(lo, hi, data) {
8513b1c86cdSbixia1 //   n = hi - lo
8523b1c86cdSbixia1 //   for i = (n-2)/2 downto 0
8533b1c86cdSbixia1 //     shiftDown(lo, lo+i, n)
8543b1c86cdSbixia1 //
8553b1c86cdSbixia1 //   for l = n downto 2
8563b1c86cdSbixia1 //      swap(lo, lo+l-1)
8573b1c86cdSbixia1 //      shiftdown(lo, lo, l-1)
8583b1c86cdSbixia1 // }
8593b1c86cdSbixia1 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
860bfa3bc43SPeiming Liu                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
861bfa3bc43SPeiming Liu                                uint32_t nTrailingP) {
8623b1c86cdSbixia1   // Heap sort function doesn't have trailing parameters.
8633b1c86cdSbixia1   (void)nTrailingP;
8643b1c86cdSbixia1   assert(nTrailingP == 0);
8653b1c86cdSbixia1   OpBuilder::InsertionGuard insertionGuard(builder);
8663b1c86cdSbixia1   Block *entryBlock = func.addEntryBlock();
8673b1c86cdSbixia1   builder.setInsertionPointToStart(entryBlock);
8683b1c86cdSbixia1 
8693b1c86cdSbixia1   Location loc = func.getLoc();
8703b1c86cdSbixia1   ValueRange args = entryBlock->getArguments();
8713b1c86cdSbixia1   Value lo = args[loIdx];
8723b1c86cdSbixia1   Value hi = args[hiIdx];
8733b1c86cdSbixia1   Value n = builder.create<arith::SubIOp>(loc, hi, lo);
8743b1c86cdSbixia1 
8753b1c86cdSbixia1   // For i = (n-2)/2 downto 0.
8763b1c86cdSbixia1   Value c0 = constantIndex(builder, loc, 0);
8773b1c86cdSbixia1   Value c1 = constantIndex(builder, loc, 1);
8783b1c86cdSbixia1   Value s = createSubTwoDividedByTwo(builder, loc, n);
8793b1c86cdSbixia1   Value up = builder.create<arith::AddIOp>(loc, s, c1);
8803b1c86cdSbixia1   scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
8813b1c86cdSbixia1   builder.setInsertionPointToStart(forI.getBody());
8823b1c86cdSbixia1   Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
8833b1c86cdSbixia1   Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
8843b1c86cdSbixia1   SmallVector<Value> shiftDownOperands = {lo, lopi};
8853b1c86cdSbixia1   shiftDownOperands.append(args.begin() + xStartIdx, args.end());
8863b1c86cdSbixia1   shiftDownOperands.push_back(n);
8873b1c86cdSbixia1   FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
888bfa3bc43SPeiming Liu       builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
8893b1c86cdSbixia1       shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
8903b1c86cdSbixia1   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
8913b1c86cdSbixia1                                shiftDownOperands);
8923b1c86cdSbixia1 
8933b1c86cdSbixia1   builder.setInsertionPointAfter(forI);
8943b1c86cdSbixia1   // For l = n downto 2.
8953b1c86cdSbixia1   up = builder.create<arith::SubIOp>(loc, n, c1);
8963b1c86cdSbixia1   scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
8973b1c86cdSbixia1   builder.setInsertionPointToStart(forL.getBody());
8983b1c86cdSbixia1   Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
8993b1c86cdSbixia1   Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
9003b1c86cdSbixia1   loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
9013b1c86cdSbixia1   SmallVector<Value> swapOperands{lo, loplm1};
9023b1c86cdSbixia1   swapOperands.append(args.begin() + xStartIdx, args.end());
903bfa3bc43SPeiming Liu   createSwap(builder, loc, swapOperands, xPerm, ny);
9043b1c86cdSbixia1   shiftDownOperands[1] = lo;
9053b1c86cdSbixia1   shiftDownOperands[shiftDownOperands.size() - 1] =
9063b1c86cdSbixia1       builder.create<arith::SubIOp>(loc, l, c1);
9073b1c86cdSbixia1   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
9083b1c86cdSbixia1                                shiftDownOperands);
9093b1c86cdSbixia1 
9103b1c86cdSbixia1   builder.setInsertionPointAfter(forL);
9113b1c86cdSbixia1   builder.create<func::ReturnOp>(loc);
9123b1c86cdSbixia1 }
9133b1c86cdSbixia1 
914f6424d11Sbixia1 /// A helper for generating code to perform quick sort. It partitions [lo, hi),
915f6424d11Sbixia1 /// recursively calls quick sort to process the smaller partition and returns
916f6424d11Sbixia1 /// the bigger partition to be processed by the enclosed while-loop.
917f6424d11Sbixia1 static std::pair<Value, Value>
918f6424d11Sbixia1 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
919bfa3bc43SPeiming Liu                 ValueRange args, AffineMap xPerm, uint64_t ny,
920f6424d11Sbixia1                 uint32_t nTrailingP) {
921062e515bSbixia1   MLIRContext *context = module.getContext();
922062e515bSbixia1   Location loc = func.getLoc();
923062e515bSbixia1   Value lo = args[loIdx];
924062e515bSbixia1   Value hi = args[hiIdx];
9254176ce61SPeiming Liu   SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
9264176ce61SPeiming Liu 
927062e515bSbixia1   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
928bfa3bc43SPeiming Liu       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
929bfa3bc43SPeiming Liu       ny, args.drop_back(nTrailingP), createPartitionFunc);
930f6424d11Sbixia1   Value p = builder
931f6424d11Sbixia1                 .create<func::CallOp>(loc, partitionFunc,
932a1507668Sbixia1                                       TypeRange{IndexType::get(context)},
933f6424d11Sbixia1                                       args.drop_back(nTrailingP))
934f6424d11Sbixia1                 .getResult(0);
9354176ce61SPeiming Liu 
936f6424d11Sbixia1   Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
937f6424d11Sbixia1   Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
9384176ce61SPeiming Liu   // Partition already sorts array with len <= 2
9394176ce61SPeiming Liu   Value c2 = constantIndex(builder, loc, 2);
9404176ce61SPeiming Liu   Value len = builder.create<arith::SubIOp>(loc, hi, lo);
9414176ce61SPeiming Liu   Value lenGtTwo =
9424176ce61SPeiming Liu       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
9434176ce61SPeiming Liu   scf::IfOp ifLenGtTwo =
9444176ce61SPeiming Liu       builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
9454176ce61SPeiming Liu   builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
9464176ce61SPeiming Liu   // Returns an empty range to mark the entire region is fully sorted.
9474176ce61SPeiming Liu   builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
9484176ce61SPeiming Liu 
9494176ce61SPeiming Liu   // Else len > 2, need recursion.
9504176ce61SPeiming Liu   builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
951f6424d11Sbixia1   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
952f6424d11Sbixia1                                              lenLow, lenHigh);
953062e515bSbixia1 
954851f85ffSMatthias Springer   Value c0 = constantIndex(builder, loc, 0);
955f6424d11Sbixia1   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
956062e515bSbixia1 
957f6424d11Sbixia1   auto mayRecursion = [&](Value low, Value high, Value len) {
958f6424d11Sbixia1     Value cond =
959f6424d11Sbixia1         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
960f6424d11Sbixia1     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
961f6424d11Sbixia1     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
962f6424d11Sbixia1     SmallVector<Value> operands{low, high};
963f6424d11Sbixia1     operands.append(args.begin() + xStartIdx, args.end());
964f6424d11Sbixia1     builder.create<func::CallOp>(loc, func, operands);
965f6424d11Sbixia1     builder.setInsertionPointAfter(ifOp);
966f6424d11Sbixia1   };
967f6424d11Sbixia1 
968f6424d11Sbixia1   // Recursively call quickSort to process the smaller partition and return
969f6424d11Sbixia1   // the bigger partition to be processed by the enclosed while-loop.
970f6424d11Sbixia1   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
971f6424d11Sbixia1   mayRecursion(lo, p, lenLow);
9724176ce61SPeiming Liu   builder.create<scf::YieldOp>(loc, ValueRange{p, hi});
973f6424d11Sbixia1 
974f6424d11Sbixia1   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
9754176ce61SPeiming Liu   mayRecursion(p, hi, lenHigh);
976f6424d11Sbixia1   builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
977f6424d11Sbixia1 
978f6424d11Sbixia1   builder.setInsertionPointAfter(ifOp);
9794176ce61SPeiming Liu   builder.create<scf::YieldOp>(loc, ifOp.getResults());
9804176ce61SPeiming Liu 
9814176ce61SPeiming Liu   builder.setInsertionPointAfter(ifLenGtTwo);
9824176ce61SPeiming Liu   return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
983062e515bSbixia1 }
984062e515bSbixia1 
9859409bbb2Sbixia1 /// Creates a function to perform insertion sort on the values in the range of
9869409bbb2Sbixia1 /// index [lo, hi).
9879409bbb2Sbixia1 //
9889409bbb2Sbixia1 // The generate IR corresponds to this C like algorithm:
9899409bbb2Sbixia1 // void insertionSort(lo, hi, data) {
9909409bbb2Sbixia1 //   for (i = lo+1; i < hi; i++) {
9919409bbb2Sbixia1 //      d = data[i];
9929409bbb2Sbixia1 //      p = binarySearch(lo, i-1, data)
9939409bbb2Sbixia1 //      for (j = 0; j > i - p; j++)
9949409bbb2Sbixia1 //        data[i-j] = data[i-j-1]
9959409bbb2Sbixia1 //      data[p] = d
9969409bbb2Sbixia1 //   }
9979409bbb2Sbixia1 // }
9989409bbb2Sbixia1 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
999bfa3bc43SPeiming Liu                                  func::FuncOp func, AffineMap xPerm,
1000bfa3bc43SPeiming Liu                                  uint64_t ny, uint32_t nTrailingP) {
10018550aebdSbixia1   // Stable sort function doesn't use trailing parameters.
10028550aebdSbixia1   (void)nTrailingP;
10038550aebdSbixia1   assert(nTrailingP == 0);
10049409bbb2Sbixia1   OpBuilder::InsertionGuard insertionGuard(builder);
10059409bbb2Sbixia1   Block *entryBlock = func.addEntryBlock();
10069409bbb2Sbixia1   builder.setInsertionPointToStart(entryBlock);
10079409bbb2Sbixia1 
10089409bbb2Sbixia1   MLIRContext *context = module.getContext();
10099409bbb2Sbixia1   Location loc = func.getLoc();
10109409bbb2Sbixia1   ValueRange args = entryBlock->getArguments();
10119409bbb2Sbixia1   Value c1 = constantIndex(builder, loc, 1);
10129409bbb2Sbixia1   Value lo = args[loIdx];
10139409bbb2Sbixia1   Value hi = args[hiIdx];
10149409bbb2Sbixia1   Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
10159409bbb2Sbixia1 
10169409bbb2Sbixia1   // Start the outer for-stmt with induction variable i.
10179409bbb2Sbixia1   scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
10189409bbb2Sbixia1   builder.setInsertionPointToStart(forOpI.getBody());
10199409bbb2Sbixia1   Value i = forOpI.getInductionVar();
10209409bbb2Sbixia1 
10219409bbb2Sbixia1   // Binary search to find the insertion point p.
10220e1708ffSAart Bik   SmallVector<Value> operands{lo, i};
10234f729d5aSbixia1   operands.append(args.begin() + xStartIdx, args.end());
10249409bbb2Sbixia1   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
1025bfa3bc43SPeiming Liu       builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
1026bfa3bc43SPeiming Liu       xPerm, ny, operands, createBinarySearchFunc);
10279409bbb2Sbixia1   Value p = builder
10289409bbb2Sbixia1                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
10299409bbb2Sbixia1                                       operands)
10309409bbb2Sbixia1                 .getResult(0);
10319409bbb2Sbixia1 
10329409bbb2Sbixia1   // Move the value at data[i] to a temporary location.
10334f729d5aSbixia1   operands[0] = operands[1] = i;
10340e1708ffSAart Bik   SmallVector<Value> d;
10354f729d5aSbixia1   forEachIJPairInAllBuffers(
1036bfa3bc43SPeiming Liu       builder, loc, operands, xPerm, ny,
10374f729d5aSbixia1       [&](uint64_t unused, Value i, Value unused2, Value buffer) {
10384f729d5aSbixia1         d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
10394f729d5aSbixia1       });
10409409bbb2Sbixia1 
10419409bbb2Sbixia1   // Start the inner for-stmt with induction variable j, for moving data[p..i)
10429409bbb2Sbixia1   // to data[p+1..i+1).
10439409bbb2Sbixia1   Value imp = builder.create<arith::SubIOp>(loc, i, p);
10449409bbb2Sbixia1   Value c0 = constantIndex(builder, loc, 0);
10459409bbb2Sbixia1   scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
10469409bbb2Sbixia1   builder.setInsertionPointToStart(forOpJ.getBody());
10479409bbb2Sbixia1   Value j = forOpJ.getInductionVar();
10489409bbb2Sbixia1   Value imj = builder.create<arith::SubIOp>(loc, i, j);
10494f729d5aSbixia1   operands[1] = imj;
10504f729d5aSbixia1   operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
10514f729d5aSbixia1   forEachIJPairInAllBuffers(
1052bfa3bc43SPeiming Liu       builder, loc, operands, xPerm, ny,
10534f729d5aSbixia1       [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
10544f729d5aSbixia1         Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
10554f729d5aSbixia1         builder.create<memref::StoreOp>(loc, t, buffer, imj);
10564f729d5aSbixia1       });
10579409bbb2Sbixia1 
10589409bbb2Sbixia1   // Store the value at data[i] to data[p].
10599409bbb2Sbixia1   builder.setInsertionPointAfter(forOpJ);
10604f729d5aSbixia1   operands[0] = operands[1] = p;
10614f729d5aSbixia1   forEachIJPairInAllBuffers(
1062bfa3bc43SPeiming Liu       builder, loc, operands, xPerm, ny,
10634f729d5aSbixia1       [&](uint64_t k, Value p, Value usused, Value buffer) {
10644f729d5aSbixia1         builder.create<memref::StoreOp>(loc, d[k], buffer, p);
10654f729d5aSbixia1       });
10669409bbb2Sbixia1 
10679409bbb2Sbixia1   builder.setInsertionPointAfter(forOpI);
10689409bbb2Sbixia1   builder.create<func::ReturnOp>(loc);
10699409bbb2Sbixia1 }
10709409bbb2Sbixia1 
1071a1507668Sbixia1 /// Creates a function to perform quick sort or a hybrid quick sort on the
1072a1507668Sbixia1 /// values in the range of index [lo, hi).
1073a1507668Sbixia1 //
1074a1507668Sbixia1 //
1075a1507668Sbixia1 // When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
1076a1507668Sbixia1 // void quickSort(lo, hi, data) {
1077f6424d11Sbixia1 //   while (lo + 1 < hi) {
1078a1507668Sbixia1 //        p = partition(low, high, data);
1079f6424d11Sbixia1 //        if (len(lo, p) < len(p+1, hi)) {
1080a1507668Sbixia1 //          quickSort(lo, p, data);
1081f6424d11Sbixia1 //          lo = p+1;
1082f6424d11Sbixia1 //        } else {
1083a1507668Sbixia1 //          quickSort(p + 1, hi, data);
1084f6424d11Sbixia1 //          hi = p;
1085f6424d11Sbixia1 //        }
1086a1507668Sbixia1 //   }
1087a1507668Sbixia1 // }
1088a1507668Sbixia1 //
1089a1507668Sbixia1 // When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
1090a1507668Sbixia1 // void hybridQuickSort(lo, hi, data, depthLimit) {
1091f6424d11Sbixia1 //   while (lo + 1 < hi) {
1092a1507668Sbixia1 //     len = hi - lo;
1093a1507668Sbixia1 //     if (len <= limit) {
1094a1507668Sbixia1 //       insertionSort(lo, hi, data);
1095a1507668Sbixia1 //     } else {
1096a1507668Sbixia1 //       depthLimit --;
1097a1507668Sbixia1 //       if (depthLimit <= 0) {
1098a1507668Sbixia1 //         heapSort(lo, hi, data);
1099a1507668Sbixia1 //       } else {
1100a1507668Sbixia1 //          p = partition(low, high, data);
1101f6424d11Sbixia1 //          if (len(lo, p) < len(p+1, hi)) {
1102f6424d11Sbixia1 //            quickSort(lo, p, data, depthLimit);
1103f6424d11Sbixia1 //            lo = p+1;
1104f6424d11Sbixia1 //          } else {
1105f6424d11Sbixia1 //            quickSort(p + 1, hi, data, depthLimit);
1106f6424d11Sbixia1 //            hi = p;
1107a1507668Sbixia1 //          }
1108f6424d11Sbixia1 //       }
1109a1507668Sbixia1 //     }
1110a1507668Sbixia1 //   }
1111a1507668Sbixia1 // }
1112a1507668Sbixia1 //
1113a1507668Sbixia1 static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
1114bfa3bc43SPeiming Liu                                 func::FuncOp func, AffineMap xPerm, uint64_t ny,
1115bfa3bc43SPeiming Liu                                 uint32_t nTrailingP) {
1116a1507668Sbixia1   assert(nTrailingP == 1 || nTrailingP == 0);
1117a1507668Sbixia1   bool isHybrid = (nTrailingP == 1);
1118a1507668Sbixia1   OpBuilder::InsertionGuard insertionGuard(builder);
1119a1507668Sbixia1   Block *entryBlock = func.addEntryBlock();
1120a1507668Sbixia1   builder.setInsertionPointToStart(entryBlock);
1121a1507668Sbixia1 
1122a1507668Sbixia1   Location loc = func.getLoc();
1123f6424d11Sbixia1   SmallVector<Value> args;
1124f6424d11Sbixia1   args.append(entryBlock->getArguments().begin(),
1125f6424d11Sbixia1               entryBlock->getArguments().end());
1126a1507668Sbixia1   Value lo = args[loIdx];
1127a1507668Sbixia1   Value hi = args[hiIdx];
1128f6424d11Sbixia1   SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
1129f6424d11Sbixia1   scf::WhileOp whileOp =
1130f6424d11Sbixia1       builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});
1131a1507668Sbixia1 
1132f6424d11Sbixia1   // The before-region of the WhileOp.
1133f6424d11Sbixia1   Block *before =
1134f6424d11Sbixia1       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
1135f6424d11Sbixia1   builder.setInsertionPointToEnd(before);
1136f6424d11Sbixia1   lo = before->getArgument(0);
1137f6424d11Sbixia1   hi = before->getArgument(1);
1138f6424d11Sbixia1   Value loP1 =
1139f6424d11Sbixia1       builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
1140f6424d11Sbixia1   Value needSort =
1141f6424d11Sbixia1       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
1142f6424d11Sbixia1   builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());
1143f6424d11Sbixia1 
1144f6424d11Sbixia1   // The after-region of the WhileOp.
1145f6424d11Sbixia1   Block *after =
1146f6424d11Sbixia1       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
1147f6424d11Sbixia1   builder.setInsertionPointToEnd(after);
1148f6424d11Sbixia1   lo = after->getArgument(0);
1149f6424d11Sbixia1   hi = after->getArgument(1);
1150f6424d11Sbixia1   args[0] = lo;
1151f6424d11Sbixia1   args[1] = hi;
1152a1507668Sbixia1 
1153a1507668Sbixia1   if (isHybrid) {
1154a1507668Sbixia1     Value len = builder.create<arith::SubIOp>(loc, hi, lo);
1155a1507668Sbixia1     Value lenLimit = constantIndex(builder, loc, 30);
1156a1507668Sbixia1     Value lenCond = builder.create<arith::CmpIOp>(
1157a1507668Sbixia1         loc, arith::CmpIPredicate::ule, len, lenLimit);
1158f6424d11Sbixia1     scf::IfOp lenIf =
1159f6424d11Sbixia1         builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);
1160a1507668Sbixia1 
1161a1507668Sbixia1     // When len <= limit.
1162a1507668Sbixia1     builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
1163a1507668Sbixia1     FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
1164bfa3bc43SPeiming Liu         builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
1165f6424d11Sbixia1         ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
1166a1507668Sbixia1     builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
1167f6424d11Sbixia1                                  ValueRange(args).drop_back(nTrailingP));
1168f6424d11Sbixia1     builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1169a1507668Sbixia1 
1170a1507668Sbixia1     // When len > limit.
1171a1507668Sbixia1     builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
1172f6424d11Sbixia1     Value depthLimit = args.back();
1173f6424d11Sbixia1     depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
1174f6424d11Sbixia1                                                constantI64(builder, loc, 1));
1175a1507668Sbixia1     Value depthCond =
1176a1507668Sbixia1         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1177a1507668Sbixia1                                       depthLimit, constantI64(builder, loc, 0));
1178f6424d11Sbixia1     scf::IfOp depthIf =
1179f6424d11Sbixia1         builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);
1180a1507668Sbixia1 
1181a1507668Sbixia1     // When depth exceeds limit.
1182a1507668Sbixia1     builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
1183a1507668Sbixia1     FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
1184bfa3bc43SPeiming Liu         builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
1185f6424d11Sbixia1         ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
1186a1507668Sbixia1     builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
1187f6424d11Sbixia1                                  ValueRange(args).drop_back(nTrailingP));
1188f6424d11Sbixia1     builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1189a1507668Sbixia1 
1190a1507668Sbixia1     // When depth doesn't exceed limit.
1191a1507668Sbixia1     builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
1192f6424d11Sbixia1     args.back() = depthLimit;
1193f6424d11Sbixia1     std::tie(lo, hi) =
1194bfa3bc43SPeiming Liu         createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1195f6424d11Sbixia1     builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1196a1507668Sbixia1 
1197a1507668Sbixia1     builder.setInsertionPointAfter(depthIf);
1198f6424d11Sbixia1     lo = depthIf.getResult(0);
1199f6424d11Sbixia1     hi = depthIf.getResult(1);
1200f6424d11Sbixia1     builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1201f6424d11Sbixia1 
1202f6424d11Sbixia1     builder.setInsertionPointAfter(lenIf);
1203f6424d11Sbixia1     lo = lenIf.getResult(0);
1204f6424d11Sbixia1     hi = lenIf.getResult(1);
1205f6424d11Sbixia1   } else {
1206f6424d11Sbixia1     std::tie(lo, hi) =
1207bfa3bc43SPeiming Liu         createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1208a1507668Sbixia1   }
1209a1507668Sbixia1 
1210f6424d11Sbixia1   // New [lo, hi) for the next while-loop iteration.
1211f6424d11Sbixia1   builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1212f6424d11Sbixia1 
1213f6424d11Sbixia1   // After the while-loop.
1214f6424d11Sbixia1   builder.setInsertionPointAfter(whileOp);
1215a1507668Sbixia1   builder.create<func::ReturnOp>(loc);
1216a1507668Sbixia1 }
1217a1507668Sbixia1 
12184f729d5aSbixia1 /// Implements the rewriting for operator sort and sort_coo.
12194f729d5aSbixia1 template <typename OpTy>
1220bfa3bc43SPeiming Liu LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
1221bfa3bc43SPeiming Liu                                     uint64_t ny, PatternRewriter &rewriter) {
12224f729d5aSbixia1   Location loc = op.getLoc();
12230e1708ffSAart Bik   SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
12244f729d5aSbixia1 
12254f729d5aSbixia1   // Convert `values` to have dynamic shape and append them to `operands`.
12264f729d5aSbixia1   for (Value v : xys) {
12279916ab03Swren romano     auto mtp = getMemRefType(v);
12284f729d5aSbixia1     if (!mtp.isDynamicDim(0)) {
12294f729d5aSbixia1       auto newMtp =
1230399638f9SAliia Khasanova           MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
12314f729d5aSbixia1       v = rewriter.create<memref::CastOp>(loc, newMtp, v);
12324f729d5aSbixia1     }
12334f729d5aSbixia1     operands.push_back(v);
12344f729d5aSbixia1   }
12353b1c86cdSbixia1 
12364f729d5aSbixia1   auto insertPoint = op->template getParentOfType<func::FuncOp>();
12379a29d875SKohei Yamaguchi   if (!insertPoint)
12389a29d875SKohei Yamaguchi     return failure();
12399a29d875SKohei Yamaguchi 
12403b1c86cdSbixia1   SmallString<32> funcName;
12413b1c86cdSbixia1   FuncGeneratorType funcGenerator;
12428550aebdSbixia1   uint32_t nTrailingP = 0;
12433b1c86cdSbixia1   switch (op.getAlgorithm()) {
1244a1507668Sbixia1   case SparseTensorSortKind::HybridQuickSort: {
1245a1507668Sbixia1     funcName = kHybridQuickSortFuncNamePrefix;
1246a1507668Sbixia1     funcGenerator = createQuickSortFunc;
1247a1507668Sbixia1     nTrailingP = 1;
1248a1507668Sbixia1     // As a heuristics, set depthLimit = 2 * log2(n).
1249a1507668Sbixia1     Value lo = operands[loIdx];
1250a1507668Sbixia1     Value hi = operands[hiIdx];
1251a1507668Sbixia1     Value len = rewriter.create<arith::IndexCastOp>(
1252a1507668Sbixia1         loc, rewriter.getI64Type(),
1253a1507668Sbixia1         rewriter.create<arith::SubIOp>(loc, hi, lo));
1254a1507668Sbixia1     Value depthLimit = rewriter.create<arith::SubIOp>(
1255a1507668Sbixia1         loc, constantI64(rewriter, loc, 64),
1256a1507668Sbixia1         rewriter.create<math::CountLeadingZerosOp>(loc, len));
1257f6424d11Sbixia1     operands.push_back(depthLimit);
1258a1507668Sbixia1     break;
1259a1507668Sbixia1   }
12603b1c86cdSbixia1   case SparseTensorSortKind::QuickSort:
1261a1507668Sbixia1     funcName = kQuickSortFuncNamePrefix;
1262a1507668Sbixia1     funcGenerator = createQuickSortFunc;
12633b1c86cdSbixia1     break;
12643b1c86cdSbixia1   case SparseTensorSortKind::InsertionSortStable:
12653b1c86cdSbixia1     funcName = kSortStableFuncNamePrefix;
12663b1c86cdSbixia1     funcGenerator = createSortStableFunc;
12673b1c86cdSbixia1     break;
12683b1c86cdSbixia1   case SparseTensorSortKind::HeapSort:
12693b1c86cdSbixia1     funcName = kHeapSortFuncNamePrefix;
12703b1c86cdSbixia1     funcGenerator = createHeapSortFunc;
12713b1c86cdSbixia1     break;
12723b1c86cdSbixia1   }
12733b1c86cdSbixia1 
12744f729d5aSbixia1   FlatSymbolRefAttr func =
1275bfa3bc43SPeiming Liu       getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
1276bfa3bc43SPeiming Liu                                xPerm, ny, operands, funcGenerator, nTrailingP);
12774f729d5aSbixia1   rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
12784f729d5aSbixia1   return success();
12794f729d5aSbixia1 }
12804f729d5aSbixia1 
1281062e515bSbixia1 //===---------------------------------------------------------------------===//
1282062e515bSbixia1 // The actual sparse buffer rewriting rules.
1283062e515bSbixia1 //===---------------------------------------------------------------------===//
1284062e515bSbixia1 
1285062e515bSbixia1 namespace {
1286654bbbdeSbixia1 /// Sparse rewriting rule for the push_back operator.
1287654bbbdeSbixia1 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
1288654bbbdeSbixia1 public:
1289654bbbdeSbixia1   using OpRewritePattern<PushBackOp>::OpRewritePattern;
12905618d2beSbixia1   PushBackRewriter(MLIRContext *context, bool enableInit)
12915618d2beSbixia1       : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
1292654bbbdeSbixia1   LogicalResult matchAndRewrite(PushBackOp op,
1293654bbbdeSbixia1                                 PatternRewriter &rewriter) const override {
1294d45be887Sbixia1     // Rewrite push_back(buffer, value, n) to:
1295d45be887Sbixia1     // new_size = size(buffer) + n
1296d45be887Sbixia1     // if (new_size > capacity(buffer))
1297d45be887Sbixia1     //    while new_size > new_capacity
1298d45be887Sbixia1     //      new_capacity = new_capacity*2
1299654bbbdeSbixia1     //    new_buffer = realloc(buffer, new_capacity)
1300654bbbdeSbixia1     // buffer = new_buffer
1301d45be887Sbixia1     // subBuffer = subviewof(buffer)
1302d45be887Sbixia1     // linalg.fill subBuffer value
1303d45be887Sbixia1     //
1304d45be887Sbixia1     // size(buffer) += n
13051c835b5aSbixia1     //
13061c835b5aSbixia1     // The capacity check is skipped when the attribute inbounds is presented.
1307654bbbdeSbixia1     Location loc = op->getLoc();
1308654bbbdeSbixia1     Value c0 = constantIndex(rewriter, loc, 0);
1309654bbbdeSbixia1     Value buffer = op.getInBuffer();
1310654bbbdeSbixia1     Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
1311988733c6SPeiming Liu     Value size = op.getCurSize();
1312654bbbdeSbixia1     Value value = op.getValue();
13131c835b5aSbixia1 
1314d45be887Sbixia1     Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
1315d45be887Sbixia1     Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
1316d45be887Sbixia1     auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
1317d45be887Sbixia1     bool nIsOne = (nValue && nValue.value() == 1);
1318d45be887Sbixia1 
13191c835b5aSbixia1     if (!op.getInbounds()) {
13201c835b5aSbixia1       Value cond = rewriter.create<arith::CmpIOp>(
1321d45be887Sbixia1           loc, arith::CmpIPredicate::ugt, newSize, capacity);
13221c835b5aSbixia1 
1323d45be887Sbixia1       Value c2 = constantIndex(rewriter, loc, 2);
1324654bbbdeSbixia1       auto bufferType =
1325399638f9SAliia Khasanova           MemRefType::get({ShapedType::kDynamic}, value.getType());
1326654bbbdeSbixia1       scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
1327654bbbdeSbixia1                                                   /*else=*/true);
1328654bbbdeSbixia1       // True branch.
1329654bbbdeSbixia1       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1330d45be887Sbixia1       if (nIsOne) {
1331654bbbdeSbixia1         capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
1332d45be887Sbixia1       } else {
1333d45be887Sbixia1         // Use a do-while loop to calculate the new capacity as follows:
1334d45be887Sbixia1         //   do { new_capacity *= 2 } while (size > new_capacity)
1335d45be887Sbixia1         scf::WhileOp whileOp =
1336d45be887Sbixia1             rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
1337d45be887Sbixia1 
1338d45be887Sbixia1         // The before-region of the WhileOp.
1339d45be887Sbixia1         Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
1340d45be887Sbixia1                                              {capacity.getType()}, {loc});
1341d45be887Sbixia1         rewriter.setInsertionPointToEnd(before);
1342d45be887Sbixia1 
1343d45be887Sbixia1         capacity =
1344d45be887Sbixia1             rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
1345d45be887Sbixia1         cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
1346d45be887Sbixia1                                               newSize, capacity);
1347d45be887Sbixia1         rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
1348d45be887Sbixia1         // The after-region of the WhileOp.
1349d45be887Sbixia1         Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
1350d45be887Sbixia1                                             {capacity.getType()}, {loc});
1351d45be887Sbixia1         rewriter.setInsertionPointToEnd(after);
1352d45be887Sbixia1         rewriter.create<scf::YieldOp>(loc, after->getArguments());
1353d45be887Sbixia1 
1354d45be887Sbixia1         rewriter.setInsertionPointAfter(whileOp);
1355d45be887Sbixia1         capacity = whileOp.getResult(0);
1356d45be887Sbixia1       }
1357d45be887Sbixia1 
1358654bbbdeSbixia1       Value newBuffer =
1359654bbbdeSbixia1           rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
13605618d2beSbixia1       if (enableBufferInitialization) {
13615618d2beSbixia1         Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
1362ea4be70cSbixia1         Value fillValue = constantZero(rewriter, loc, value.getType());
13635618d2beSbixia1         Value subBuffer = rewriter.create<memref::SubViewOp>(
13645618d2beSbixia1             loc, newBuffer, /*offset=*/ValueRange{newSize},
13655618d2beSbixia1             /*size=*/ValueRange{fillSize},
13665618d2beSbixia1             /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
13675618d2beSbixia1         rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
13685618d2beSbixia1       }
1369654bbbdeSbixia1       rewriter.create<scf::YieldOp>(loc, newBuffer);
1370654bbbdeSbixia1 
1371654bbbdeSbixia1       // False branch.
1372654bbbdeSbixia1       rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1373654bbbdeSbixia1       rewriter.create<scf::YieldOp>(loc, buffer);
1374654bbbdeSbixia1 
13751c835b5aSbixia1       // Prepare for adding the value to the end of the buffer.
1376654bbbdeSbixia1       rewriter.setInsertionPointAfter(ifOp);
1377654bbbdeSbixia1       buffer = ifOp.getResult(0);
13781c835b5aSbixia1     }
13791c835b5aSbixia1 
13801c835b5aSbixia1     // Add the value to the end of the buffer.
1381d45be887Sbixia1     if (nIsOne) {
1382654bbbdeSbixia1       rewriter.create<memref::StoreOp>(loc, value, buffer, size);
1383d45be887Sbixia1     } else {
1384d45be887Sbixia1       Value subBuffer = rewriter.create<memref::SubViewOp>(
1385d45be887Sbixia1           loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
1386d45be887Sbixia1           /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1387d45be887Sbixia1       rewriter.create<linalg::FillOp>(loc, value, subBuffer);
1388d45be887Sbixia1     }
1389654bbbdeSbixia1 
1390d45be887Sbixia1     // Update the buffer size.
1391988733c6SPeiming Liu     rewriter.replaceOp(op, {buffer, newSize});
1392654bbbdeSbixia1     return success();
1393654bbbdeSbixia1   }
13945618d2beSbixia1 
13955618d2beSbixia1 private:
13965618d2beSbixia1   bool enableBufferInitialization;
1397654bbbdeSbixia1 };
1398654bbbdeSbixia1 
13994f729d5aSbixia1 /// Sparse rewriting rule for the sort_coo operator.
14000083f833SPeiming Liu struct SortRewriter : public OpRewritePattern<SortOp> {
14014f729d5aSbixia1 public:
14020083f833SPeiming Liu   using OpRewritePattern<SortOp>::OpRewritePattern;
14034f729d5aSbixia1 
14040083f833SPeiming Liu   LogicalResult matchAndRewrite(SortOp op,
14054f729d5aSbixia1                                 PatternRewriter &rewriter) const override {
14060e1708ffSAart Bik     SmallVector<Value> xys;
14074f729d5aSbixia1     xys.push_back(op.getXy());
14084f729d5aSbixia1     xys.append(op.getYs().begin(), op.getYs().end());
14094f729d5aSbixia1 
1410bfa3bc43SPeiming Liu     auto xPerm = op.getPermMap();
14114f729d5aSbixia1     uint64_t ny = 0;
14124f729d5aSbixia1     if (auto nyAttr = op.getNyAttr())
14134f729d5aSbixia1       ny = nyAttr.getInt();
14144f729d5aSbixia1 
1415bfa3bc43SPeiming Liu     return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
1416062e515bSbixia1   }
1417062e515bSbixia1 };
1418062e515bSbixia1 
1419062e515bSbixia1 } // namespace
1420062e515bSbixia1 
1421062e515bSbixia1 //===---------------------------------------------------------------------===//
1422062e515bSbixia1 // Methods that add patterns described in this file to a pattern list.
1423062e515bSbixia1 //===---------------------------------------------------------------------===//
1424062e515bSbixia1 
14255618d2beSbixia1 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
14265618d2beSbixia1                                          bool enableBufferInitialization) {
14275618d2beSbixia1   patterns.add<PushBackRewriter>(patterns.getContext(),
14285618d2beSbixia1                                  enableBufferInitialization);
14290083f833SPeiming Liu   patterns.add<SortRewriter>(patterns.getContext());
1430062e515bSbixia1 }
1431