1062e515bSbixia1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===// 2062e515bSbixia1 // 3062e515bSbixia1 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4062e515bSbixia1 // See https://llvm.org/LICENSE.txt for license information. 5062e515bSbixia1 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6062e515bSbixia1 // 7062e515bSbixia1 //===----------------------------------------------------------------------===// 8062e515bSbixia1 // 9062e515bSbixia1 // This file implements rewriting rules that are specific to sparse tensor 10062e515bSbixia1 // primitives with memref operands. 11062e515bSbixia1 // 12062e515bSbixia1 //===----------------------------------------------------------------------===// 13062e515bSbixia1 14365777ecSAart Bik #include "Utils/CodegenUtils.h" 15062e515bSbixia1 16bfbf3bcbSAart Bik #include "mlir/Dialect/Arith/IR/Arith.h" 17062e515bSbixia1 #include "mlir/Dialect/Func/IR/FuncOps.h" 18d45be887Sbixia1 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19a1507668Sbixia1 #include "mlir/Dialect/Math/IR/Math.h" 20062e515bSbixia1 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21062e515bSbixia1 #include "mlir/Dialect/SCF/IR/SCF.h" 22062e515bSbixia1 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 23062e515bSbixia1 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 24062e515bSbixia1 #include "mlir/Support/LLVM.h" 25062e515bSbixia1 26062e515bSbixia1 using namespace mlir; 27062e515bSbixia1 using namespace mlir::sparse_tensor; 28062e515bSbixia1 29062e515bSbixia1 //===---------------------------------------------------------------------===// 30062e515bSbixia1 // Helper methods for the actual rewriting rules. 31062e515bSbixia1 //===---------------------------------------------------------------------===// 32062e515bSbixia1 339409bbb2Sbixia1 static constexpr uint64_t loIdx = 0; 349409bbb2Sbixia1 static constexpr uint64_t hiIdx = 1; 359409bbb2Sbixia1 static constexpr uint64_t xStartIdx = 2; 369409bbb2Sbixia1 379409bbb2Sbixia1 static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_"; 389409bbb2Sbixia1 static constexpr const char kBinarySearchFuncNamePrefix[] = 399409bbb2Sbixia1 "_sparse_binary_search_"; 40a1507668Sbixia1 static constexpr const char kHybridQuickSortFuncNamePrefix[] = 41a1507668Sbixia1 "_sparse_hybrid_qsort_"; 429409bbb2Sbixia1 static constexpr const char kSortStableFuncNamePrefix[] = 439409bbb2Sbixia1 "_sparse_sort_stable_"; 443b1c86cdSbixia1 static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_"; 453b1c86cdSbixia1 static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_"; 46a1507668Sbixia1 static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_"; 47062e515bSbixia1 48bfa3bc43SPeiming Liu using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, 49bfa3bc43SPeiming Liu AffineMap, uint64_t, uint32_t)>; 50062e515bSbixia1 51062e515bSbixia1 /// Constructs a function name with this format to facilitate quick sort: 52bfa3bc43SPeiming Liu /// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort 53bfa3bc43SPeiming Liu /// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo 54062e515bSbixia1 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, 55bfa3bc43SPeiming Liu StringRef namePrefix, AffineMap xPerm, 56bfa3bc43SPeiming Liu uint64_t ny, ValueRange operands) { 57bfa3bc43SPeiming Liu nameOstream << namePrefix; 58bfa3bc43SPeiming Liu for (auto res : xPerm.getResults()) 591609f1c2Slong.chen nameOstream << cast<AffineDimExpr>(res).getPosition() << "_"; 60062e515bSbixia1 61bfa3bc43SPeiming Liu nameOstream << getMemRefType(operands[xStartIdx]).getElementType(); 624f729d5aSbixia1 nameOstream << "_coo_" << ny; 634f729d5aSbixia1 64bfa3bc43SPeiming Liu constexpr uint64_t yBufferOffset = 1; 654f729d5aSbixia1 for (Value v : operands.drop_front(xStartIdx + yBufferOffset)) 669916ab03Swren romano nameOstream << "_" << getMemRefType(v).getElementType(); 67062e515bSbixia1 } 68062e515bSbixia1 69062e515bSbixia1 /// Looks up a function that is appropriate for the given operands being 704f729d5aSbixia1 /// sorted, and creates such a function if it doesn't exist yet. The 71bfa3bc43SPeiming Liu /// parameters `xPerm` and `ny` tell the number of x and y values provided 72bfa3bc43SPeiming Liu /// by the buffer in xStartIdx. 738550aebdSbixia1 // 748550aebdSbixia1 // All sorting function generators take (lo, hi, xs, ys) in `operands` as 758550aebdSbixia1 // parameters for the sorting functions. Other parameters, such as the recursive 768550aebdSbixia1 // call depth, are appended to the end of the parameter list as 778550aebdSbixia1 // "trailing parameters". 78bfa3bc43SPeiming Liu static FlatSymbolRefAttr getMangledSortHelperFunc( 79bfa3bc43SPeiming Liu OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, 80bfa3bc43SPeiming Liu StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, 81bfa3bc43SPeiming Liu FuncGeneratorType createFunc, uint32_t nTrailingP = 0) { 82062e515bSbixia1 SmallString<32> nameBuffer; 83062e515bSbixia1 llvm::raw_svector_ostream nameOstream(nameBuffer); 84bfa3bc43SPeiming Liu getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny, 858550aebdSbixia1 operands.drop_back(nTrailingP)); 86062e515bSbixia1 87062e515bSbixia1 ModuleOp module = insertPoint->getParentOfType<ModuleOp>(); 88062e515bSbixia1 MLIRContext *context = module.getContext(); 89062e515bSbixia1 auto result = SymbolRefAttr::get(context, nameOstream.str()); 90062e515bSbixia1 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr()); 91062e515bSbixia1 92062e515bSbixia1 if (!func) { 93062e515bSbixia1 // Create the function. 94062e515bSbixia1 OpBuilder::InsertionGuard insertionGuard(builder); 95062e515bSbixia1 builder.setInsertionPoint(insertPoint); 96062e515bSbixia1 Location loc = insertPoint.getLoc(); 97062e515bSbixia1 func = builder.create<func::FuncOp>( 98062e515bSbixia1 loc, nameOstream.str(), 99062e515bSbixia1 FunctionType::get(context, operands.getTypes(), resultTypes)); 100062e515bSbixia1 func.setPrivate(); 101bfa3bc43SPeiming Liu createFunc(builder, module, func, xPerm, ny, nTrailingP); 102062e515bSbixia1 } 103062e515bSbixia1 104062e515bSbixia1 return result; 105062e515bSbixia1 } 106062e515bSbixia1 1074f729d5aSbixia1 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting. 1084f729d5aSbixia1 /// The code to process the value pairs is generated by `bodyBuilder`. 1094f729d5aSbixia1 static void forEachIJPairInXs( 110bfa3bc43SPeiming Liu OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, 111bfa3bc43SPeiming Liu uint64_t ny, 112bfa3bc43SPeiming Liu function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { 113bfa3bc43SPeiming Liu Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny); 114bfa3bc43SPeiming Liu Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep); 115bfa3bc43SPeiming Liu Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep); 116bfa3bc43SPeiming Liu for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) { 1171609f1c2Slong.chen unsigned actualK = cast<AffineDimExpr>(xPerm.getResult(k)).getPosition(); 118bfa3bc43SPeiming Liu Value ak = constantIndex(builder, loc, actualK); 119bfa3bc43SPeiming Liu Value i = builder.create<arith::AddIOp>(loc, ak, iOffset); 120bfa3bc43SPeiming Liu Value j = builder.create<arith::AddIOp>(loc, ak, jOffset); 121bfa3bc43SPeiming Liu Value buffer = args[xStartIdx]; 122bfa3bc43SPeiming Liu 1234f729d5aSbixia1 bodyBuilder(k, i, j, buffer); 1244f729d5aSbixia1 } 1254f729d5aSbixia1 } 1264f729d5aSbixia1 1274f729d5aSbixia1 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting. 1284f729d5aSbixia1 /// The code to process the value pairs is generated by `bodyBuilder`. 1294f729d5aSbixia1 static void forEachIJPairInAllBuffers( 130bfa3bc43SPeiming Liu OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, 131bfa3bc43SPeiming Liu uint64_t ny, 132bfa3bc43SPeiming Liu function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { 1334f729d5aSbixia1 134bfa3bc43SPeiming Liu // Create code for the first (xPerm + ny) buffers. 135*5262865aSKazu Hirata SmallVector<AffineExpr> exps(xPerm.getResults()); 136bfa3bc43SPeiming Liu for (unsigned y = 0; y < ny; y++) { 137bfa3bc43SPeiming Liu exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults())); 138bfa3bc43SPeiming Liu } 139bfa3bc43SPeiming Liu AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext()); 140bfa3bc43SPeiming Liu assert(xyPerm.isPermutation()); 1414f729d5aSbixia1 142bfa3bc43SPeiming Liu forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder); 1434f729d5aSbixia1 144bfa3bc43SPeiming Liu constexpr uint64_t numHandledBuffers = 1; 1454f729d5aSbixia1 // Create code for the remaining buffers. 1464f729d5aSbixia1 Value i = args[0]; 1474f729d5aSbixia1 Value j = args[1]; 1484f729d5aSbixia1 for (const auto &arg : 1494f729d5aSbixia1 llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) { 150bfa3bc43SPeiming Liu bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value()); 1514f729d5aSbixia1 } 1524f729d5aSbixia1 } 1534f729d5aSbixia1 1549b800bf7Sbixia1 /// Creates a code block for swapping the values in index i and j for all the 155062e515bSbixia1 /// buffers. 156062e515bSbixia1 // 1579b800bf7Sbixia1 // The generated IR corresponds to this C like algorithm: 158062e515bSbixia1 // swap(x0[i], x0[j]); 159062e515bSbixia1 // swap(x1[i], x1[j]); 160062e515bSbixia1 // ... 161062e515bSbixia1 // swap(xn[i], xn[j]); 162062e515bSbixia1 // swap(y0[i], y0[j]); 163062e515bSbixia1 // ... 164062e515bSbixia1 // swap(yn[i], yn[j]); 1654f729d5aSbixia1 static void createSwap(OpBuilder &builder, Location loc, ValueRange args, 166bfa3bc43SPeiming Liu AffineMap xPerm, uint64_t ny) { 1674f729d5aSbixia1 auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) { 1684f729d5aSbixia1 Value vi = builder.create<memref::LoadOp>(loc, buffer, i); 1694f729d5aSbixia1 Value vj = builder.create<memref::LoadOp>(loc, buffer, j); 1704f729d5aSbixia1 builder.create<memref::StoreOp>(loc, vj, buffer, i); 1714f729d5aSbixia1 builder.create<memref::StoreOp>(loc, vi, buffer, j); 1724f729d5aSbixia1 }; 1734f729d5aSbixia1 174bfa3bc43SPeiming Liu forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair); 175062e515bSbixia1 } 176062e515bSbixia1 1772ef41627Sbixia1 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare 1782ef41627Sbixia1 /// each pair is create via `compareBuilder`. 1792ef41627Sbixia1 static Value createInlinedCompareImplementation( 180bfa3bc43SPeiming Liu OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, 181bfa3bc43SPeiming Liu uint64_t ny, 1822ef41627Sbixia1 function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> 1839b800bf7Sbixia1 compareBuilder) { 1842ef41627Sbixia1 Value result; 1854f729d5aSbixia1 auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) { 1862ef41627Sbixia1 bool isFirstDim = (k == 0); 187bfa3bc43SPeiming Liu bool isLastDim = (k == xPerm.getNumResults() - 1); 1882ef41627Sbixia1 Value val = 1892ef41627Sbixia1 compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim); 1902ef41627Sbixia1 if (isFirstDim) { 1912ef41627Sbixia1 result = val; 1922ef41627Sbixia1 } else if (!isLastDim) { 1939b800bf7Sbixia1 OpBuilder::InsertionGuard insertionGuard(builder); 1942ef41627Sbixia1 auto ifOp = cast<scf::IfOp>(val.getDefiningOp()); 1959b800bf7Sbixia1 builder.setInsertionPointAfter(ifOp); 1969b800bf7Sbixia1 builder.create<scf::YieldOp>(loc, ifOp.getResult(0)); 1979b800bf7Sbixia1 } 1984f729d5aSbixia1 }; 1994f729d5aSbixia1 200bfa3bc43SPeiming Liu forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder); 2019b800bf7Sbixia1 2022ef41627Sbixia1 builder.setInsertionPointAfterValue(result); 2032ef41627Sbixia1 return result; 2049b800bf7Sbixia1 } 2059b800bf7Sbixia1 2062ef41627Sbixia1 /// Generates code to compare whether x[i] is equal to x[j] and returns the 2072ef41627Sbixia1 /// result of the comparison. 2082ef41627Sbixia1 static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, 2092ef41627Sbixia1 Value x, bool isFirstDim, bool isLastDim) { 2109b800bf7Sbixia1 Value vi = builder.create<memref::LoadOp>(loc, x, i); 2119b800bf7Sbixia1 Value vj = builder.create<memref::LoadOp>(loc, x, j); 2129b800bf7Sbixia1 2132ef41627Sbixia1 Value res; 2142ef41627Sbixia1 if (isLastDim) { 2152ef41627Sbixia1 res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj); 2162ef41627Sbixia1 // For 1D, we create a compare without any control flow. Otherwise, we 2172ef41627Sbixia1 // create YieldOp to return the result in the nested if-stmt. 2182ef41627Sbixia1 if (!isFirstDim) 2192ef41627Sbixia1 builder.create<scf::YieldOp>(loc, res); 2202ef41627Sbixia1 } else { 2212ef41627Sbixia1 Value ne = 2222ef41627Sbixia1 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj); 2232ef41627Sbixia1 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1), 2242ef41627Sbixia1 ne, /*else=*/true); 2252ef41627Sbixia1 // If (x[i] != x[j]). 2262ef41627Sbixia1 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 2272ef41627Sbixia1 Value f = constantI1(builder, loc, false); 2289b800bf7Sbixia1 builder.create<scf::YieldOp>(loc, f); 2299b800bf7Sbixia1 2302ef41627Sbixia1 // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that 2312ef41627Sbixia1 // checks the remaining dimensions. 2322ef41627Sbixia1 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 2332ef41627Sbixia1 res = ifOp.getResult(0); 2349b800bf7Sbixia1 } 2359b800bf7Sbixia1 2362ef41627Sbixia1 return res; 2379b800bf7Sbixia1 } 2389b800bf7Sbixia1 2392ef41627Sbixia1 /// Creates code to compare whether xs[i] is equal to xs[j]. 2409b800bf7Sbixia1 // 2419b800bf7Sbixia1 // The generate IR corresponds to this C like algorithm: 2429b800bf7Sbixia1 // if (x0[i] != x0[j]) 2439b800bf7Sbixia1 // return false; 2449b800bf7Sbixia1 // else 2459b800bf7Sbixia1 // if (x1[i] != x1[j]) 2469b800bf7Sbixia1 // return false; 2479b800bf7Sbixia1 // else if (x2[2] != x2[j])) 2489b800bf7Sbixia1 // and so on ... 2492ef41627Sbixia1 static Value createInlinedEqCompare(OpBuilder &builder, Location loc, 250bfa3bc43SPeiming Liu ValueRange args, AffineMap xPerm, 251bfa3bc43SPeiming Liu uint64_t ny, uint32_t nTrailingP = 0) { 2528550aebdSbixia1 // Compare functions don't use trailing parameters. 2538550aebdSbixia1 (void)nTrailingP; 2548550aebdSbixia1 assert(nTrailingP == 0); 255bfa3bc43SPeiming Liu return createInlinedCompareImplementation(builder, loc, args, xPerm, ny, 2564f729d5aSbixia1 createEqCompare); 2579b800bf7Sbixia1 } 2589b800bf7Sbixia1 2592ef41627Sbixia1 /// Generates code to compare whether x[i] is less than x[j] and returns the 2602ef41627Sbixia1 /// result of the comparison. 2612ef41627Sbixia1 static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, 2622ef41627Sbixia1 Value j, Value x, bool isFirstDim, 263062e515bSbixia1 bool isLastDim) { 264062e515bSbixia1 Value vi = builder.create<memref::LoadOp>(loc, x, i); 265062e515bSbixia1 Value vj = builder.create<memref::LoadOp>(loc, x, j); 266062e515bSbixia1 2672ef41627Sbixia1 Value res; 2682ef41627Sbixia1 if (isLastDim) { 2692ef41627Sbixia1 res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj); 2702ef41627Sbixia1 // For 1D, we create a compare without any control flow. Otherwise, we 2712ef41627Sbixia1 // create YieldOp to return the result in the nested if-stmt. 2722ef41627Sbixia1 if (!isFirstDim) 2732ef41627Sbixia1 builder.create<scf::YieldOp>(loc, res); 274062e515bSbixia1 } else { 2752ef41627Sbixia1 Value ne = 2762ef41627Sbixia1 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj); 2772ef41627Sbixia1 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1), 2782ef41627Sbixia1 ne, /*else=*/true); 2792ef41627Sbixia1 // If (x[i] != x[j]). 2802ef41627Sbixia1 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 2812ef41627Sbixia1 Value lt = 2822ef41627Sbixia1 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj); 2832ef41627Sbixia1 builder.create<scf::YieldOp>(loc, lt); 284062e515bSbixia1 2852ef41627Sbixia1 // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that 2862ef41627Sbixia1 // checks the remaining dimensions. 2872ef41627Sbixia1 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 2882ef41627Sbixia1 res = ifOp.getResult(0); 289062e515bSbixia1 } 290062e515bSbixia1 2912ef41627Sbixia1 return res; 292062e515bSbixia1 } 293062e515bSbixia1 2942ef41627Sbixia1 /// Creates code to compare whether xs[i] is less than xs[j]. 295062e515bSbixia1 // 296062e515bSbixia1 // The generate IR corresponds to this C like algorithm: 2972ef41627Sbixia1 // if (x0[i] != x0[j]) 2982ef41627Sbixia1 // return x0[i] < x0[j]; 2992ef41627Sbixia1 // else if (x1[j] != x1[i]) 3002ef41627Sbixia1 // return x1[i] < x1[j]; 301062e515bSbixia1 // else 302062e515bSbixia1 // and so on ... 3032ef41627Sbixia1 static Value createInlinedLessThan(OpBuilder &builder, Location loc, 304bfa3bc43SPeiming Liu ValueRange args, AffineMap xPerm, 305bfa3bc43SPeiming Liu uint64_t ny, uint32_t nTrailingP = 0) { 3068550aebdSbixia1 // Compare functions don't use trailing parameters. 3078550aebdSbixia1 (void)nTrailingP; 3088550aebdSbixia1 assert(nTrailingP == 0); 309bfa3bc43SPeiming Liu return createInlinedCompareImplementation(builder, loc, args, xPerm, ny, 3109b800bf7Sbixia1 createLessThanCompare); 311062e515bSbixia1 } 312062e515bSbixia1 3139409bbb2Sbixia1 /// Creates a function to use a binary search to find the insertion point for 3149409bbb2Sbixia1 /// inserting xs[hi] to the sorted values xs[lo..hi). 3159409bbb2Sbixia1 // 3169409bbb2Sbixia1 // The generate IR corresponds to this C like algorithm: 3179409bbb2Sbixia1 // p = hi 3189409bbb2Sbixia1 // while (lo < hi) 3199409bbb2Sbixia1 // mid = (lo + hi) >> 1 3209409bbb2Sbixia1 // if (xs[p] < xs[mid]) 3219409bbb2Sbixia1 // hi = mid 3229409bbb2Sbixia1 // else 3239409bbb2Sbixia1 // lo = mid - 1 3249409bbb2Sbixia1 // return lo; 3259409bbb2Sbixia1 // 3269409bbb2Sbixia1 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, 327bfa3bc43SPeiming Liu func::FuncOp func, AffineMap xPerm, 328bfa3bc43SPeiming Liu uint64_t ny, uint32_t nTrailingP = 0) { 3298550aebdSbixia1 // Binary search doesn't use trailing parameters. 3308550aebdSbixia1 (void)nTrailingP; 3318550aebdSbixia1 assert(nTrailingP == 0); 3329409bbb2Sbixia1 OpBuilder::InsertionGuard insertionGuard(builder); 3339409bbb2Sbixia1 Block *entryBlock = func.addEntryBlock(); 3349409bbb2Sbixia1 builder.setInsertionPointToStart(entryBlock); 3359409bbb2Sbixia1 3369409bbb2Sbixia1 Location loc = func.getLoc(); 3379409bbb2Sbixia1 ValueRange args = entryBlock->getArguments(); 3389409bbb2Sbixia1 Value p = args[hiIdx]; 3398550aebdSbixia1 SmallVector<Type, 2> types(2, p.getType()); // Only two types. 3409409bbb2Sbixia1 scf::WhileOp whileOp = builder.create<scf::WhileOp>( 3419409bbb2Sbixia1 loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]}); 3429409bbb2Sbixia1 3439409bbb2Sbixia1 // The before-region of the WhileOp. 3449409bbb2Sbixia1 Block *before = 3459409bbb2Sbixia1 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); 3469409bbb2Sbixia1 builder.setInsertionPointToEnd(before); 3479409bbb2Sbixia1 Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 3489409bbb2Sbixia1 before->getArgument(0), 3499409bbb2Sbixia1 before->getArgument(1)); 3509409bbb2Sbixia1 builder.create<scf::ConditionOp>(loc, cond1, before->getArguments()); 3519409bbb2Sbixia1 3529409bbb2Sbixia1 // The after-region of the WhileOp. 3539409bbb2Sbixia1 Block *after = 3549409bbb2Sbixia1 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); 3559409bbb2Sbixia1 builder.setInsertionPointToEnd(after); 3569409bbb2Sbixia1 Value lo = after->getArgument(0); 3579409bbb2Sbixia1 Value hi = after->getArgument(1); 3589409bbb2Sbixia1 // Compute mid = (lo + hi) >> 1. 3599409bbb2Sbixia1 Value c1 = constantIndex(builder, loc, 1); 3609409bbb2Sbixia1 Value mid = builder.create<arith::ShRUIOp>( 3619409bbb2Sbixia1 loc, builder.create<arith::AddIOp>(loc, lo, hi), c1); 3629409bbb2Sbixia1 Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1); 3639409bbb2Sbixia1 3649409bbb2Sbixia1 // Compare xs[p] < xs[mid]. 3650e1708ffSAart Bik SmallVector<Value> compareOperands{p, mid}; 366bfa3bc43SPeiming Liu constexpr uint64_t numXBuffers = 1; 3679409bbb2Sbixia1 compareOperands.append(args.begin() + xStartIdx, 3684f729d5aSbixia1 args.begin() + xStartIdx + numXBuffers); 369bfa3bc43SPeiming Liu Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); 3709409bbb2Sbixia1 // Update lo and hi for the WhileOp as follows: 3719409bbb2Sbixia1 // if (xs[p] < xs[mid])) 3729409bbb2Sbixia1 // hi = mid; 3739409bbb2Sbixia1 // else 3749409bbb2Sbixia1 // lo = mid + 1; 3759409bbb2Sbixia1 Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1); 3769409bbb2Sbixia1 Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi); 3779409bbb2Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi}); 3789409bbb2Sbixia1 3799409bbb2Sbixia1 builder.setInsertionPointAfter(whileOp); 3809409bbb2Sbixia1 builder.create<func::ReturnOp>(loc, whileOp.getResult(0)); 3819409bbb2Sbixia1 } 3829409bbb2Sbixia1 3839b800bf7Sbixia1 /// Creates code to advance i in a loop based on xs[p] as follows: 3849b800bf7Sbixia1 /// while (xs[i] < xs[p]) i += step (step > 0) 3859b800bf7Sbixia1 /// or 3869b800bf7Sbixia1 /// while (xs[i] > xs[p]) i += step (step < 0) 3879b800bf7Sbixia1 /// The routine returns i as well as a boolean value to indicate whether 3889b800bf7Sbixia1 /// xs[i] == xs[p]. 389bfa3bc43SPeiming Liu static std::pair<Value, Value> createScanLoop(OpBuilder &builder, 390bfa3bc43SPeiming Liu ModuleOp module, 391bfa3bc43SPeiming Liu func::FuncOp func, ValueRange xs, 392bfa3bc43SPeiming Liu Value i, Value p, AffineMap xPerm, 393bfa3bc43SPeiming Liu uint64_t ny, int step) { 394062e515bSbixia1 Location loc = func.getLoc(); 3959b800bf7Sbixia1 scf::WhileOp whileOp = 3969b800bf7Sbixia1 builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i}); 397062e515bSbixia1 3989b800bf7Sbixia1 Block *before = 3999b800bf7Sbixia1 builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc}); 4009b800bf7Sbixia1 builder.setInsertionPointToEnd(before); 4010e1708ffSAart Bik SmallVector<Value> compareOperands; 4029b800bf7Sbixia1 if (step > 0) { 4039b800bf7Sbixia1 compareOperands.push_back(before->getArgument(0)); 4049b800bf7Sbixia1 compareOperands.push_back(p); 4059b800bf7Sbixia1 } else { 4069b800bf7Sbixia1 assert(step < 0); 4079b800bf7Sbixia1 compareOperands.push_back(p); 4089b800bf7Sbixia1 compareOperands.push_back(before->getArgument(0)); 4099b800bf7Sbixia1 } 410062e515bSbixia1 compareOperands.append(xs.begin(), xs.end()); 411bfa3bc43SPeiming Liu Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); 4129b800bf7Sbixia1 builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); 413062e515bSbixia1 4149b800bf7Sbixia1 Block *after = 4159b800bf7Sbixia1 builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc}); 4169b800bf7Sbixia1 builder.setInsertionPointToEnd(after); 4179b800bf7Sbixia1 Value cs = constantIndex(builder, loc, step); 4189b800bf7Sbixia1 i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs); 4199b800bf7Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{i}); 4209b800bf7Sbixia1 i = whileOp.getResult(0); 4219b800bf7Sbixia1 4229b800bf7Sbixia1 builder.setInsertionPointAfter(whileOp); 4239b800bf7Sbixia1 compareOperands[0] = i; 4249b800bf7Sbixia1 compareOperands[1] = p; 4259b800bf7Sbixia1 Value compareEq = 426bfa3bc43SPeiming Liu createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny); 4279b800bf7Sbixia1 4289b800bf7Sbixia1 return std::make_pair(whileOp.getResult(0), compareEq); 4299b800bf7Sbixia1 } 4309b800bf7Sbixia1 431abb05014Sbixia1 /// Creates and returns an IfOp to compare two elements and swap the elements 432abb05014Sbixia1 /// if compareFunc(data[b], data[a]) returns true. The new insertion point is 433abb05014Sbixia1 /// right after the swap instructions. 434abb05014Sbixia1 static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, 435bfa3bc43SPeiming Liu AffineMap xPerm, uint64_t ny, 436abb05014Sbixia1 SmallVectorImpl<Value> &swapOperands, 437abb05014Sbixia1 SmallVectorImpl<Value> &compareOperands, 438abb05014Sbixia1 Value a, Value b) { 439abb05014Sbixia1 // Compare(data[b], data[a]). 440abb05014Sbixia1 compareOperands[0] = b; 441abb05014Sbixia1 compareOperands[1] = a; 442bfa3bc43SPeiming Liu Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); 443abb05014Sbixia1 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); 444abb05014Sbixia1 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 445abb05014Sbixia1 swapOperands[0] = b; 446abb05014Sbixia1 swapOperands[1] = a; 447bfa3bc43SPeiming Liu createSwap(builder, loc, swapOperands, xPerm, ny); 448abb05014Sbixia1 return ifOp; 449abb05014Sbixia1 } 450abb05014Sbixia1 451abb05014Sbixia1 /// Creates code to insert the 3rd element to a list of two sorted elements. 452bfa3bc43SPeiming Liu static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, 453bfa3bc43SPeiming Liu uint64_t ny, SmallVectorImpl<Value> &swapOperands, 454abb05014Sbixia1 SmallVectorImpl<Value> &compareOperands, Value v0, 455abb05014Sbixia1 Value v1, Value v2) { 456bfa3bc43SPeiming Liu scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, 457bfa3bc43SPeiming Liu compareOperands, v1, v2); 458bfa3bc43SPeiming Liu createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands, 459bfa3bc43SPeiming Liu v0, v1); 460abb05014Sbixia1 builder.setInsertionPointAfter(ifOp); 461abb05014Sbixia1 } 462abb05014Sbixia1 463abb05014Sbixia1 /// Creates code to sort 3 elements. 464bfa3bc43SPeiming Liu static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, 465bfa3bc43SPeiming Liu uint64_t ny, SmallVectorImpl<Value> &swapOperands, 466abb05014Sbixia1 SmallVectorImpl<Value> &compareOperands, Value v0, 467abb05014Sbixia1 Value v1, Value v2) { 468abb05014Sbixia1 // Sort the first 2 elements. 469bfa3bc43SPeiming Liu scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, 470bfa3bc43SPeiming Liu compareOperands, v0, v1); 471abb05014Sbixia1 builder.setInsertionPointAfter(ifOp1); 472abb05014Sbixia1 473abb05014Sbixia1 // Insert the 3th element. 474bfa3bc43SPeiming Liu createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, 475bfa3bc43SPeiming Liu v1, v2); 476abb05014Sbixia1 } 477abb05014Sbixia1 478abb05014Sbixia1 /// Creates code to sort 5 elements. 479bfa3bc43SPeiming Liu static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, 480bfa3bc43SPeiming Liu uint64_t ny, SmallVectorImpl<Value> &swapOperands, 481abb05014Sbixia1 SmallVectorImpl<Value> &compareOperands, Value v0, 482abb05014Sbixia1 Value v1, Value v2, Value v3, Value v4) { 483abb05014Sbixia1 // Sort the first 3 elements. 484bfa3bc43SPeiming Liu createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1, 485bfa3bc43SPeiming Liu v2); 486abb05014Sbixia1 487abb05014Sbixia1 auto insert4th = [&]() { 488abb05014Sbixia1 scf::IfOp ifOp = createCompareThenSwap( 489bfa3bc43SPeiming Liu builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3); 490bfa3bc43SPeiming Liu createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, 491bfa3bc43SPeiming Liu v1, v2); 492abb05014Sbixia1 builder.setInsertionPointAfter(ifOp); 493abb05014Sbixia1 }; 494abb05014Sbixia1 495abb05014Sbixia1 // Insert the 4th element. 496abb05014Sbixia1 insert4th(); 497abb05014Sbixia1 498abb05014Sbixia1 // Insert the 5th element. 499bfa3bc43SPeiming Liu scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, 500bfa3bc43SPeiming Liu compareOperands, v3, v4); 501abb05014Sbixia1 insert4th(); 502abb05014Sbixia1 builder.setInsertionPointAfter(ifOp); 503abb05014Sbixia1 } 504abb05014Sbixia1 505abb05014Sbixia1 /// Creates a code block to swap the values in indices lo, mi, and hi so that 506abb05014Sbixia1 /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When 507abb05014Sbixia1 /// the number of values in range [lo, hi) is more than a threshold, we also 508abb05014Sbixia1 /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values. 5097cec4d16Sbixia1 static void createChoosePivot(OpBuilder &builder, ModuleOp module, 510bfa3bc43SPeiming Liu func::FuncOp func, AffineMap xPerm, uint64_t ny, 511bfa3bc43SPeiming Liu Value lo, Value hi, Value mi, ValueRange args) { 5127cec4d16Sbixia1 SmallVector<Value> compareOperands{mi, lo}; 513bfa3bc43SPeiming Liu constexpr uint64_t numXBuffers = 1; 5147cec4d16Sbixia1 compareOperands.append(args.begin() + xStartIdx, 5157cec4d16Sbixia1 args.begin() + xStartIdx + numXBuffers); 516abb05014Sbixia1 SmallVector<Value> swapOperands{mi, lo}; 5177cec4d16Sbixia1 swapOperands.append(args.begin() + xStartIdx, args.end()); 518abb05014Sbixia1 Location loc = func.getLoc(); 519abb05014Sbixia1 Value c1 = constantIndex(builder, loc, 1); 520abb05014Sbixia1 Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1); 521abb05014Sbixia1 Value len = builder.create<arith::SubIOp>(loc, hiP1, lo); 522abb05014Sbixia1 Value lenThreshold = constantIndex(builder, loc, 1000); 523abb05014Sbixia1 Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 524abb05014Sbixia1 len, lenThreshold); 525abb05014Sbixia1 scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true); 526abb05014Sbixia1 527abb05014Sbixia1 // When len < 1000, choose pivot from median of 3 values. 528abb05014Sbixia1 builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); 529bfa3bc43SPeiming Liu createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi, 530bfa3bc43SPeiming Liu hi); 531abb05014Sbixia1 532abb05014Sbixia1 // When len >= 1000, choose pivot from median of 5 values. 533abb05014Sbixia1 builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); 534abb05014Sbixia1 Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1); 535abb05014Sbixia1 Value a = builder.create<arith::AddIOp>(loc, lo, miP1); 536abb05014Sbixia1 // Value a is the middle between [loc, mi]. 537abb05014Sbixia1 a = builder.create<arith::ShRUIOp>(loc, a, c1); 538abb05014Sbixia1 Value b = builder.create<arith::AddIOp>(loc, mi, hiP1); 539abb05014Sbixia1 // Value b is the middle between [mi, hi]. 540abb05014Sbixia1 b = builder.create<arith::ShRUIOp>(loc, b, c1); 541bfa3bc43SPeiming Liu createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi, 542bfa3bc43SPeiming Liu b, hi); 543abb05014Sbixia1 544abb05014Sbixia1 builder.setInsertionPointAfter(lenIf); 5457cec4d16Sbixia1 } 5467cec4d16Sbixia1 5479b800bf7Sbixia1 /// Creates a function to perform quick sort partition on the values in the 5489b800bf7Sbixia1 /// range of index [lo, hi), assuming lo < hi. 5499b800bf7Sbixia1 // 5509b800bf7Sbixia1 // The generated IR corresponds to this C like algorithm: 5519b800bf7Sbixia1 // int partition(lo, hi, xs) { 5529b800bf7Sbixia1 // p = (lo+hi)/2 // pivot index 5539b800bf7Sbixia1 // i = lo 5549b800bf7Sbixia1 // j = hi-1 5554176ce61SPeiming Liu // while (true) do { 5569b800bf7Sbixia1 // while (xs[i] < xs[p]) i ++; 5579b800bf7Sbixia1 // i_eq = (xs[i] == xs[p]); 5589b800bf7Sbixia1 // while (xs[j] > xs[p]) j --; 5599b800bf7Sbixia1 // j_eq = (xs[j] == xs[p]); 5604176ce61SPeiming Liu // 5614176ce61SPeiming Liu // if (i >= j) return j + 1; 5624176ce61SPeiming Liu // 5639b800bf7Sbixia1 // if (i < j) { 5649b800bf7Sbixia1 // swap(xs[i], xs[j]) 5659b800bf7Sbixia1 // if (i == p) { 5669b800bf7Sbixia1 // p = j; 5679b800bf7Sbixia1 // } else if (j == p) { 5689b800bf7Sbixia1 // p = i; 5699b800bf7Sbixia1 // } 5709b800bf7Sbixia1 // if (i_eq && j_eq) { 5719b800bf7Sbixia1 // ++i; 5729b800bf7Sbixia1 // --j; 5739b800bf7Sbixia1 // } 5749b800bf7Sbixia1 // } 5759b800bf7Sbixia1 // } 5769b800bf7Sbixia1 // } 5779b800bf7Sbixia1 static void createPartitionFunc(OpBuilder &builder, ModuleOp module, 578bfa3bc43SPeiming Liu func::FuncOp func, AffineMap xPerm, uint64_t ny, 579bfa3bc43SPeiming Liu uint32_t nTrailingP = 0) { 5808550aebdSbixia1 // Quick sort partition doesn't use trailing parameters. 5818550aebdSbixia1 (void)nTrailingP; 5828550aebdSbixia1 assert(nTrailingP == 0); 5839b800bf7Sbixia1 OpBuilder::InsertionGuard insertionGuard(builder); 5849b800bf7Sbixia1 5859b800bf7Sbixia1 Block *entryBlock = func.addEntryBlock(); 5869b800bf7Sbixia1 builder.setInsertionPointToStart(entryBlock); 5879b800bf7Sbixia1 5889b800bf7Sbixia1 Location loc = func.getLoc(); 5899b800bf7Sbixia1 ValueRange args = entryBlock->getArguments(); 5909b800bf7Sbixia1 Value lo = args[loIdx]; 5919b800bf7Sbixia1 Value hi = args[hiIdx]; 5929b800bf7Sbixia1 Value sum = builder.create<arith::AddIOp>(loc, lo, hi); 5939b800bf7Sbixia1 Value c1 = constantIndex(builder, loc, 1); 5949b800bf7Sbixia1 Value p = builder.create<arith::ShRUIOp>(loc, sum, c1); 5959b800bf7Sbixia1 5969b800bf7Sbixia1 Value i = lo; 5979b800bf7Sbixia1 Value j = builder.create<arith::SubIOp>(loc, hi, c1); 598bfa3bc43SPeiming Liu createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args); 5994176ce61SPeiming Liu Value trueVal = constantI1(builder, loc, true); // The value for while (true) 6004176ce61SPeiming Liu SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values. 6014176ce61SPeiming Liu SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(), 6024176ce61SPeiming Liu trueVal.getType()}; 6039b800bf7Sbixia1 scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands); 6049b800bf7Sbixia1 6059b800bf7Sbixia1 // The before-region of the WhileOp. 6064176ce61SPeiming Liu Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, 6074176ce61SPeiming Liu {loc, loc, loc, loc}); 6089b800bf7Sbixia1 builder.setInsertionPointToEnd(before); 6094176ce61SPeiming Liu builder.create<scf::ConditionOp>(loc, before->getArgument(3), 6104176ce61SPeiming Liu before->getArguments()); 6119b800bf7Sbixia1 6129b800bf7Sbixia1 // The after-region of the WhileOp. 6139b800bf7Sbixia1 Block *after = 6144176ce61SPeiming Liu builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc}); 6159b800bf7Sbixia1 builder.setInsertionPointToEnd(after); 6169b800bf7Sbixia1 i = after->getArgument(0); 6179b800bf7Sbixia1 j = after->getArgument(1); 6189b800bf7Sbixia1 p = after->getArgument(2); 6199b800bf7Sbixia1 620bfa3bc43SPeiming Liu constexpr uint64_t numXBuffers = 1; 6214f729d5aSbixia1 auto [iresult, iCompareEq] = 6224f729d5aSbixia1 createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers), 623bfa3bc43SPeiming Liu i, p, xPerm, ny, 1); 6249b800bf7Sbixia1 i = iresult; 6254f729d5aSbixia1 auto [jresult, jCompareEq] = 6264f729d5aSbixia1 createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers), 627bfa3bc43SPeiming Liu j, p, xPerm, ny, -1); 6289b800bf7Sbixia1 j = jresult; 6299b800bf7Sbixia1 6309b800bf7Sbixia1 // If i < j: 6314176ce61SPeiming Liu Value cond = 6324176ce61SPeiming Liu builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j); 6339b800bf7Sbixia1 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); 634062e515bSbixia1 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 6350e1708ffSAart Bik SmallVector<Value> swapOperands{i, j}; 636062e515bSbixia1 swapOperands.append(args.begin() + xStartIdx, args.end()); 637bfa3bc43SPeiming Liu createSwap(builder, loc, swapOperands, xPerm, ny); 6389b800bf7Sbixia1 // If the pivot is moved, update p with the new pivot. 6399b800bf7Sbixia1 Value icond = 6409b800bf7Sbixia1 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p); 6419b800bf7Sbixia1 scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()}, 6429b800bf7Sbixia1 icond, /*else=*/true); 6439b800bf7Sbixia1 builder.setInsertionPointToStart(&ifOpI.getThenRegion().front()); 6449b800bf7Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{j}); 6459b800bf7Sbixia1 builder.setInsertionPointToStart(&ifOpI.getElseRegion().front()); 6469b800bf7Sbixia1 Value jcond = 6479b800bf7Sbixia1 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p); 6489b800bf7Sbixia1 scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()}, 6499b800bf7Sbixia1 jcond, /*else=*/true); 6509b800bf7Sbixia1 builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front()); 6519b800bf7Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{i}); 6529b800bf7Sbixia1 builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front()); 6539b800bf7Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{p}); 6549b800bf7Sbixia1 builder.setInsertionPointAfter(ifOpJ); 6559b800bf7Sbixia1 builder.create<scf::YieldOp>(loc, ifOpJ.getResults()); 6569b800bf7Sbixia1 builder.setInsertionPointAfter(ifOpI); 6579b800bf7Sbixia1 Value compareEqIJ = 6589b800bf7Sbixia1 builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq); 6599b800bf7Sbixia1 scf::IfOp ifOp2 = builder.create<scf::IfOp>( 6609b800bf7Sbixia1 loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true); 6619b800bf7Sbixia1 builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); 6629b800bf7Sbixia1 Value i2 = builder.create<arith::AddIOp>(loc, i, c1); 6639b800bf7Sbixia1 Value j2 = builder.create<arith::SubIOp>(loc, j, c1); 6649b800bf7Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{i2, j2}); 6659b800bf7Sbixia1 builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); 6669b800bf7Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{i, j}); 6679b800bf7Sbixia1 builder.setInsertionPointAfter(ifOp2); 6689b800bf7Sbixia1 builder.create<scf::YieldOp>( 6699b800bf7Sbixia1 loc, 6704176ce61SPeiming Liu ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0), 6714176ce61SPeiming Liu /*cont=*/constantI1(builder, loc, true)}); 672062e515bSbixia1 6734176ce61SPeiming Liu // False branch for if i < j (i.e., i >= j): 674062e515bSbixia1 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 6754176ce61SPeiming Liu p = builder.create<arith::AddIOp>(loc, j, 6764176ce61SPeiming Liu constantOne(builder, loc, j.getType())); 6774176ce61SPeiming Liu builder.create<scf::YieldOp>( 6784176ce61SPeiming Liu loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)}); 679062e515bSbixia1 6809b800bf7Sbixia1 // Return for the whileOp. 681062e515bSbixia1 builder.setInsertionPointAfter(ifOp); 6829b800bf7Sbixia1 builder.create<scf::YieldOp>(loc, ifOp.getResults()); 683062e515bSbixia1 6849b800bf7Sbixia1 // Return for the function. 6859b800bf7Sbixia1 builder.setInsertionPointAfter(whileOp); 6869b800bf7Sbixia1 builder.create<func::ReturnOp>(loc, whileOp.getResult(2)); 687062e515bSbixia1 } 688062e515bSbixia1 6893b1c86cdSbixia1 /// Computes (n-2)/n, assuming n has index type. 6903b1c86cdSbixia1 static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, 6913b1c86cdSbixia1 Value n) { 6923b1c86cdSbixia1 Value i2 = constantIndex(builder, loc, 2); 6933b1c86cdSbixia1 Value res = builder.create<arith::SubIOp>(loc, n, i2); 6943b1c86cdSbixia1 Value i1 = constantIndex(builder, loc, 1); 6953b1c86cdSbixia1 return builder.create<arith::ShRUIOp>(loc, res, i1); 6963b1c86cdSbixia1 } 6973b1c86cdSbixia1 6983b1c86cdSbixia1 /// Creates a function to heapify the subtree with root `start` within the full 6993b1c86cdSbixia1 /// binary tree in the range of index [first, first + n). 7003b1c86cdSbixia1 // 7013b1c86cdSbixia1 // The generated IR corresponds to this C like algorithm: 7023b1c86cdSbixia1 // void shiftDown(first, start, n, data) { 7033b1c86cdSbixia1 // if (n >= 2) { 7043b1c86cdSbixia1 // child = start - first 7053b1c86cdSbixia1 // if ((n-2)/2 >= child) { 7063b1c86cdSbixia1 // // Left child exists. 7073b1c86cdSbixia1 // child = child * 2 + 1 // Initialize the bigger child to left child. 7083b1c86cdSbixia1 // childIndex = child + first 7093b1c86cdSbixia1 // if (child+1 < n && data[childIndex] < data[childIndex+1]) 7103b1c86cdSbixia1 // // Right child exits and is bigger. 7113b1c86cdSbixia1 // childIndex++; child++; 7123b1c86cdSbixia1 // // Shift data[start] down to where it belongs in the subtree. 7133b1c86cdSbixia1 // while (data[start] < data[childIndex) { 7143b1c86cdSbixia1 // swap(data[start], data[childIndex]) 7153b1c86cdSbixia1 // start = childIndex 7163b1c86cdSbixia1 // if ((n - 2)/2 >= child) { 7173b1c86cdSbixia1 // // Left child exists. 7183b1c86cdSbixia1 // child = 2*child + 1 7193b1c86cdSbixia1 // childIndex = child + 1 7203b1c86cdSbixia1 // if (child + 1) < n && data[childIndex] < data[childIndex+1] 7213b1c86cdSbixia1 // childIndex++; child++; 7223b1c86cdSbixia1 // } 7233b1c86cdSbixia1 // } 7243b1c86cdSbixia1 // } 7253b1c86cdSbixia1 // } 7263b1c86cdSbixia1 // } 7273b1c86cdSbixia1 // 7283b1c86cdSbixia1 static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, 729bfa3bc43SPeiming Liu func::FuncOp func, AffineMap xPerm, uint64_t ny, 730bfa3bc43SPeiming Liu uint32_t nTrailingP) { 7313b1c86cdSbixia1 // The value n is passed in as a trailing parameter. 7323b1c86cdSbixia1 assert(nTrailingP == 1); 7333b1c86cdSbixia1 OpBuilder::InsertionGuard insertionGuard(builder); 7343b1c86cdSbixia1 Block *entryBlock = func.addEntryBlock(); 7353b1c86cdSbixia1 builder.setInsertionPointToStart(entryBlock); 7363b1c86cdSbixia1 7373b1c86cdSbixia1 Location loc = func.getLoc(); 7383b1c86cdSbixia1 Value n = entryBlock->getArguments().back(); 7393b1c86cdSbixia1 ValueRange args = entryBlock->getArguments().drop_back(); 7403b1c86cdSbixia1 Value first = args[loIdx]; 7413b1c86cdSbixia1 Value start = args[hiIdx]; 7423b1c86cdSbixia1 7433b1c86cdSbixia1 // If (n >= 2). 7443b1c86cdSbixia1 Value c2 = constantIndex(builder, loc, 2); 7453b1c86cdSbixia1 Value condN = 7463b1c86cdSbixia1 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2); 7473b1c86cdSbixia1 scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false); 7483b1c86cdSbixia1 builder.setInsertionPointToStart(&ifN.getThenRegion().front()); 7493b1c86cdSbixia1 Value child = builder.create<arith::SubIOp>(loc, start, first); 7503b1c86cdSbixia1 7513b1c86cdSbixia1 // If ((n-2)/2 >= child). 7523b1c86cdSbixia1 Value t = createSubTwoDividedByTwo(builder, loc, n); 7533b1c86cdSbixia1 Value condNc = 7543b1c86cdSbixia1 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child); 7553b1c86cdSbixia1 scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false); 7563b1c86cdSbixia1 7573b1c86cdSbixia1 builder.setInsertionPointToStart(&ifNc.getThenRegion().front()); 7583b1c86cdSbixia1 Value c1 = constantIndex(builder, loc, 1); 7593b1c86cdSbixia1 SmallVector<Value> compareOperands{start, start}; 760bfa3bc43SPeiming Liu constexpr uint64_t numXBuffers = 1; 7613b1c86cdSbixia1 compareOperands.append(args.begin() + xStartIdx, 7623b1c86cdSbixia1 args.begin() + xStartIdx + numXBuffers); 7633b1c86cdSbixia1 7643b1c86cdSbixia1 // Generate code to inspect the children of 'r' and return the larger child 7653b1c86cdSbixia1 // as follows: 7663b1c86cdSbixia1 // child = r * 2 + 1 // Left child. 7673b1c86cdSbixia1 // childIndex = child + first 7683b1c86cdSbixia1 // if (child+1 < n && data[childIndex] < data[childIndex+1]) 7693b1c86cdSbixia1 // childIndex ++; child ++ // Right child is bigger. 7703b1c86cdSbixia1 auto getLargerChild = [&](Value r) -> std::pair<Value, Value> { 7713b1c86cdSbixia1 Value lChild = builder.create<arith::ShLIOp>(loc, r, c1); 7723b1c86cdSbixia1 lChild = builder.create<arith::AddIOp>(loc, lChild, c1); 7733b1c86cdSbixia1 Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first); 7743b1c86cdSbixia1 Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1); 7753b1c86cdSbixia1 Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, 7763b1c86cdSbixia1 rChild, n); 7773b1c86cdSbixia1 SmallVector<Type, 2> ifTypes(2, r.getType()); 7783b1c86cdSbixia1 scf::IfOp if1 = 7793b1c86cdSbixia1 builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true); 7803b1c86cdSbixia1 builder.setInsertionPointToStart(&if1.getThenRegion().front()); 7813b1c86cdSbixia1 Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first); 7823b1c86cdSbixia1 // Compare data[left] < data[right]. 7833b1c86cdSbixia1 compareOperands[0] = lChildIdx; 7843b1c86cdSbixia1 compareOperands[1] = rChildIdx; 7852ef41627Sbixia1 Value cond2 = 786bfa3bc43SPeiming Liu createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); 7873b1c86cdSbixia1 scf::IfOp if2 = 7883b1c86cdSbixia1 builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true); 7893b1c86cdSbixia1 builder.setInsertionPointToStart(&if2.getThenRegion().front()); 7903b1c86cdSbixia1 builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx}); 7913b1c86cdSbixia1 builder.setInsertionPointToStart(&if2.getElseRegion().front()); 7923b1c86cdSbixia1 builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx}); 7933b1c86cdSbixia1 builder.setInsertionPointAfter(if2); 7943b1c86cdSbixia1 builder.create<scf::YieldOp>(loc, if2.getResults()); 7953b1c86cdSbixia1 builder.setInsertionPointToStart(&if1.getElseRegion().front()); 7963b1c86cdSbixia1 builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx}); 7973b1c86cdSbixia1 builder.setInsertionPointAfter(if1); 7983b1c86cdSbixia1 return std::make_pair(if1.getResult(0), if1.getResult(1)); 7993b1c86cdSbixia1 }; 8003b1c86cdSbixia1 8013b1c86cdSbixia1 Value childIdx; 8023b1c86cdSbixia1 std::tie(child, childIdx) = getLargerChild(child); 8033b1c86cdSbixia1 8043b1c86cdSbixia1 // While (data[start] < data[childIndex]). 8053b1c86cdSbixia1 SmallVector<Type, 3> types(3, child.getType()); 8063b1c86cdSbixia1 scf::WhileOp whileOp = builder.create<scf::WhileOp>( 8073b1c86cdSbixia1 loc, types, SmallVector<Value, 2>{start, child, childIdx}); 8083b1c86cdSbixia1 8093b1c86cdSbixia1 // The before-region of the WhileOp. 8103b1c86cdSbixia1 SmallVector<Location, 3> locs(3, loc); 8113b1c86cdSbixia1 Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs); 8123b1c86cdSbixia1 builder.setInsertionPointToEnd(before); 8133b1c86cdSbixia1 start = before->getArgument(0); 8143b1c86cdSbixia1 childIdx = before->getArgument(2); 8153b1c86cdSbixia1 compareOperands[0] = start; 8163b1c86cdSbixia1 compareOperands[1] = childIdx; 817bfa3bc43SPeiming Liu Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny); 8183b1c86cdSbixia1 builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); 8193b1c86cdSbixia1 8203b1c86cdSbixia1 // The after-region of the WhileOp. 8213b1c86cdSbixia1 Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs); 8223b1c86cdSbixia1 start = after->getArgument(0); 8233b1c86cdSbixia1 child = after->getArgument(1); 8243b1c86cdSbixia1 childIdx = after->getArgument(2); 8253b1c86cdSbixia1 SmallVector<Value> swapOperands{start, childIdx}; 8263b1c86cdSbixia1 swapOperands.append(args.begin() + xStartIdx, args.end()); 827bfa3bc43SPeiming Liu createSwap(builder, loc, swapOperands, xPerm, ny); 8283b1c86cdSbixia1 start = childIdx; 8293b1c86cdSbixia1 Value cond2 = 8303b1c86cdSbixia1 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child); 8313b1c86cdSbixia1 scf::IfOp if2 = builder.create<scf::IfOp>( 8323b1c86cdSbixia1 loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true); 8333b1c86cdSbixia1 builder.setInsertionPointToStart(&if2.getThenRegion().front()); 8343b1c86cdSbixia1 auto [newChild, newChildIdx] = getLargerChild(child); 8353b1c86cdSbixia1 builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx}); 8363b1c86cdSbixia1 builder.setInsertionPointToStart(&if2.getElseRegion().front()); 8373b1c86cdSbixia1 builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx}); 8383b1c86cdSbixia1 builder.setInsertionPointAfter(if2); 8393b1c86cdSbixia1 builder.create<scf::YieldOp>( 8403b1c86cdSbixia1 loc, ValueRange{start, if2.getResult(0), if2.getResult(1)}); 8413b1c86cdSbixia1 8423b1c86cdSbixia1 builder.setInsertionPointAfter(ifN); 8433b1c86cdSbixia1 builder.create<func::ReturnOp>(loc); 8443b1c86cdSbixia1 } 8453b1c86cdSbixia1 8463b1c86cdSbixia1 /// Creates a function to perform heap sort on the values in the range of index 8473b1c86cdSbixia1 /// [lo, hi) with the assumption hi - lo >= 2. 8483b1c86cdSbixia1 // 8493b1c86cdSbixia1 // The generate IR corresponds to this C like algorithm: 8503b1c86cdSbixia1 // void heapSort(lo, hi, data) { 8513b1c86cdSbixia1 // n = hi - lo 8523b1c86cdSbixia1 // for i = (n-2)/2 downto 0 8533b1c86cdSbixia1 // shiftDown(lo, lo+i, n) 8543b1c86cdSbixia1 // 8553b1c86cdSbixia1 // for l = n downto 2 8563b1c86cdSbixia1 // swap(lo, lo+l-1) 8573b1c86cdSbixia1 // shiftdown(lo, lo, l-1) 8583b1c86cdSbixia1 // } 8593b1c86cdSbixia1 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, 860bfa3bc43SPeiming Liu func::FuncOp func, AffineMap xPerm, uint64_t ny, 861bfa3bc43SPeiming Liu uint32_t nTrailingP) { 8623b1c86cdSbixia1 // Heap sort function doesn't have trailing parameters. 8633b1c86cdSbixia1 (void)nTrailingP; 8643b1c86cdSbixia1 assert(nTrailingP == 0); 8653b1c86cdSbixia1 OpBuilder::InsertionGuard insertionGuard(builder); 8663b1c86cdSbixia1 Block *entryBlock = func.addEntryBlock(); 8673b1c86cdSbixia1 builder.setInsertionPointToStart(entryBlock); 8683b1c86cdSbixia1 8693b1c86cdSbixia1 Location loc = func.getLoc(); 8703b1c86cdSbixia1 ValueRange args = entryBlock->getArguments(); 8713b1c86cdSbixia1 Value lo = args[loIdx]; 8723b1c86cdSbixia1 Value hi = args[hiIdx]; 8733b1c86cdSbixia1 Value n = builder.create<arith::SubIOp>(loc, hi, lo); 8743b1c86cdSbixia1 8753b1c86cdSbixia1 // For i = (n-2)/2 downto 0. 8763b1c86cdSbixia1 Value c0 = constantIndex(builder, loc, 0); 8773b1c86cdSbixia1 Value c1 = constantIndex(builder, loc, 1); 8783b1c86cdSbixia1 Value s = createSubTwoDividedByTwo(builder, loc, n); 8793b1c86cdSbixia1 Value up = builder.create<arith::AddIOp>(loc, s, c1); 8803b1c86cdSbixia1 scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1); 8813b1c86cdSbixia1 builder.setInsertionPointToStart(forI.getBody()); 8823b1c86cdSbixia1 Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar()); 8833b1c86cdSbixia1 Value lopi = builder.create<arith::AddIOp>(loc, lo, i); 8843b1c86cdSbixia1 SmallVector<Value> shiftDownOperands = {lo, lopi}; 8853b1c86cdSbixia1 shiftDownOperands.append(args.begin() + xStartIdx, args.end()); 8863b1c86cdSbixia1 shiftDownOperands.push_back(n); 8873b1c86cdSbixia1 FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc( 888bfa3bc43SPeiming Liu builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny, 8893b1c86cdSbixia1 shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1); 8903b1c86cdSbixia1 builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(), 8913b1c86cdSbixia1 shiftDownOperands); 8923b1c86cdSbixia1 8933b1c86cdSbixia1 builder.setInsertionPointAfter(forI); 8943b1c86cdSbixia1 // For l = n downto 2. 8953b1c86cdSbixia1 up = builder.create<arith::SubIOp>(loc, n, c1); 8963b1c86cdSbixia1 scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1); 8973b1c86cdSbixia1 builder.setInsertionPointToStart(forL.getBody()); 8983b1c86cdSbixia1 Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar()); 8993b1c86cdSbixia1 Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l); 9003b1c86cdSbixia1 loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1); 9013b1c86cdSbixia1 SmallVector<Value> swapOperands{lo, loplm1}; 9023b1c86cdSbixia1 swapOperands.append(args.begin() + xStartIdx, args.end()); 903bfa3bc43SPeiming Liu createSwap(builder, loc, swapOperands, xPerm, ny); 9043b1c86cdSbixia1 shiftDownOperands[1] = lo; 9053b1c86cdSbixia1 shiftDownOperands[shiftDownOperands.size() - 1] = 9063b1c86cdSbixia1 builder.create<arith::SubIOp>(loc, l, c1); 9073b1c86cdSbixia1 builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(), 9083b1c86cdSbixia1 shiftDownOperands); 9093b1c86cdSbixia1 9103b1c86cdSbixia1 builder.setInsertionPointAfter(forL); 9113b1c86cdSbixia1 builder.create<func::ReturnOp>(loc); 9123b1c86cdSbixia1 } 9133b1c86cdSbixia1 914f6424d11Sbixia1 /// A helper for generating code to perform quick sort. It partitions [lo, hi), 915f6424d11Sbixia1 /// recursively calls quick sort to process the smaller partition and returns 916f6424d11Sbixia1 /// the bigger partition to be processed by the enclosed while-loop. 917f6424d11Sbixia1 static std::pair<Value, Value> 918f6424d11Sbixia1 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, 919bfa3bc43SPeiming Liu ValueRange args, AffineMap xPerm, uint64_t ny, 920f6424d11Sbixia1 uint32_t nTrailingP) { 921062e515bSbixia1 MLIRContext *context = module.getContext(); 922062e515bSbixia1 Location loc = func.getLoc(); 923062e515bSbixia1 Value lo = args[loIdx]; 924062e515bSbixia1 Value hi = args[hiIdx]; 9254176ce61SPeiming Liu SmallVector<Type, 2> types(2, lo.getType()); // Only two types. 9264176ce61SPeiming Liu 927062e515bSbixia1 FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( 928bfa3bc43SPeiming Liu builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm, 929bfa3bc43SPeiming Liu ny, args.drop_back(nTrailingP), createPartitionFunc); 930f6424d11Sbixia1 Value p = builder 931f6424d11Sbixia1 .create<func::CallOp>(loc, partitionFunc, 932a1507668Sbixia1 TypeRange{IndexType::get(context)}, 933f6424d11Sbixia1 args.drop_back(nTrailingP)) 934f6424d11Sbixia1 .getResult(0); 9354176ce61SPeiming Liu 936f6424d11Sbixia1 Value lenLow = builder.create<arith::SubIOp>(loc, p, lo); 937f6424d11Sbixia1 Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p); 9384176ce61SPeiming Liu // Partition already sorts array with len <= 2 9394176ce61SPeiming Liu Value c2 = constantIndex(builder, loc, 2); 9404176ce61SPeiming Liu Value len = builder.create<arith::SubIOp>(loc, hi, lo); 9414176ce61SPeiming Liu Value lenGtTwo = 9424176ce61SPeiming Liu builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2); 9434176ce61SPeiming Liu scf::IfOp ifLenGtTwo = 9444176ce61SPeiming Liu builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true); 9454176ce61SPeiming Liu builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front()); 9464176ce61SPeiming Liu // Returns an empty range to mark the entire region is fully sorted. 9474176ce61SPeiming Liu builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); 9484176ce61SPeiming Liu 9494176ce61SPeiming Liu // Else len > 2, need recursion. 9504176ce61SPeiming Liu builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front()); 951f6424d11Sbixia1 Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, 952f6424d11Sbixia1 lenLow, lenHigh); 953062e515bSbixia1 954851f85ffSMatthias Springer Value c0 = constantIndex(builder, loc, 0); 955f6424d11Sbixia1 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); 956062e515bSbixia1 957f6424d11Sbixia1 auto mayRecursion = [&](Value low, Value high, Value len) { 958f6424d11Sbixia1 Value cond = 959f6424d11Sbixia1 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0); 960f6424d11Sbixia1 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false); 961f6424d11Sbixia1 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 962f6424d11Sbixia1 SmallVector<Value> operands{low, high}; 963f6424d11Sbixia1 operands.append(args.begin() + xStartIdx, args.end()); 964f6424d11Sbixia1 builder.create<func::CallOp>(loc, func, operands); 965f6424d11Sbixia1 builder.setInsertionPointAfter(ifOp); 966f6424d11Sbixia1 }; 967f6424d11Sbixia1 968f6424d11Sbixia1 // Recursively call quickSort to process the smaller partition and return 969f6424d11Sbixia1 // the bigger partition to be processed by the enclosed while-loop. 970f6424d11Sbixia1 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 971f6424d11Sbixia1 mayRecursion(lo, p, lenLow); 9724176ce61SPeiming Liu builder.create<scf::YieldOp>(loc, ValueRange{p, hi}); 973f6424d11Sbixia1 974f6424d11Sbixia1 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 9754176ce61SPeiming Liu mayRecursion(p, hi, lenHigh); 976f6424d11Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{lo, p}); 977f6424d11Sbixia1 978f6424d11Sbixia1 builder.setInsertionPointAfter(ifOp); 9794176ce61SPeiming Liu builder.create<scf::YieldOp>(loc, ifOp.getResults()); 9804176ce61SPeiming Liu 9814176ce61SPeiming Liu builder.setInsertionPointAfter(ifLenGtTwo); 9824176ce61SPeiming Liu return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1)); 983062e515bSbixia1 } 984062e515bSbixia1 9859409bbb2Sbixia1 /// Creates a function to perform insertion sort on the values in the range of 9869409bbb2Sbixia1 /// index [lo, hi). 9879409bbb2Sbixia1 // 9889409bbb2Sbixia1 // The generate IR corresponds to this C like algorithm: 9899409bbb2Sbixia1 // void insertionSort(lo, hi, data) { 9909409bbb2Sbixia1 // for (i = lo+1; i < hi; i++) { 9919409bbb2Sbixia1 // d = data[i]; 9929409bbb2Sbixia1 // p = binarySearch(lo, i-1, data) 9939409bbb2Sbixia1 // for (j = 0; j > i - p; j++) 9949409bbb2Sbixia1 // data[i-j] = data[i-j-1] 9959409bbb2Sbixia1 // data[p] = d 9969409bbb2Sbixia1 // } 9979409bbb2Sbixia1 // } 9989409bbb2Sbixia1 static void createSortStableFunc(OpBuilder &builder, ModuleOp module, 999bfa3bc43SPeiming Liu func::FuncOp func, AffineMap xPerm, 1000bfa3bc43SPeiming Liu uint64_t ny, uint32_t nTrailingP) { 10018550aebdSbixia1 // Stable sort function doesn't use trailing parameters. 10028550aebdSbixia1 (void)nTrailingP; 10038550aebdSbixia1 assert(nTrailingP == 0); 10049409bbb2Sbixia1 OpBuilder::InsertionGuard insertionGuard(builder); 10059409bbb2Sbixia1 Block *entryBlock = func.addEntryBlock(); 10069409bbb2Sbixia1 builder.setInsertionPointToStart(entryBlock); 10079409bbb2Sbixia1 10089409bbb2Sbixia1 MLIRContext *context = module.getContext(); 10099409bbb2Sbixia1 Location loc = func.getLoc(); 10109409bbb2Sbixia1 ValueRange args = entryBlock->getArguments(); 10119409bbb2Sbixia1 Value c1 = constantIndex(builder, loc, 1); 10129409bbb2Sbixia1 Value lo = args[loIdx]; 10139409bbb2Sbixia1 Value hi = args[hiIdx]; 10149409bbb2Sbixia1 Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1); 10159409bbb2Sbixia1 10169409bbb2Sbixia1 // Start the outer for-stmt with induction variable i. 10179409bbb2Sbixia1 scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1); 10189409bbb2Sbixia1 builder.setInsertionPointToStart(forOpI.getBody()); 10199409bbb2Sbixia1 Value i = forOpI.getInductionVar(); 10209409bbb2Sbixia1 10219409bbb2Sbixia1 // Binary search to find the insertion point p. 10220e1708ffSAart Bik SmallVector<Value> operands{lo, i}; 10234f729d5aSbixia1 operands.append(args.begin() + xStartIdx, args.end()); 10249409bbb2Sbixia1 FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( 1025bfa3bc43SPeiming Liu builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, 1026bfa3bc43SPeiming Liu xPerm, ny, operands, createBinarySearchFunc); 10279409bbb2Sbixia1 Value p = builder 10289409bbb2Sbixia1 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()}, 10299409bbb2Sbixia1 operands) 10309409bbb2Sbixia1 .getResult(0); 10319409bbb2Sbixia1 10329409bbb2Sbixia1 // Move the value at data[i] to a temporary location. 10334f729d5aSbixia1 operands[0] = operands[1] = i; 10340e1708ffSAart Bik SmallVector<Value> d; 10354f729d5aSbixia1 forEachIJPairInAllBuffers( 1036bfa3bc43SPeiming Liu builder, loc, operands, xPerm, ny, 10374f729d5aSbixia1 [&](uint64_t unused, Value i, Value unused2, Value buffer) { 10384f729d5aSbixia1 d.push_back(builder.create<memref::LoadOp>(loc, buffer, i)); 10394f729d5aSbixia1 }); 10409409bbb2Sbixia1 10419409bbb2Sbixia1 // Start the inner for-stmt with induction variable j, for moving data[p..i) 10429409bbb2Sbixia1 // to data[p+1..i+1). 10439409bbb2Sbixia1 Value imp = builder.create<arith::SubIOp>(loc, i, p); 10449409bbb2Sbixia1 Value c0 = constantIndex(builder, loc, 0); 10459409bbb2Sbixia1 scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1); 10469409bbb2Sbixia1 builder.setInsertionPointToStart(forOpJ.getBody()); 10479409bbb2Sbixia1 Value j = forOpJ.getInductionVar(); 10489409bbb2Sbixia1 Value imj = builder.create<arith::SubIOp>(loc, i, j); 10494f729d5aSbixia1 operands[1] = imj; 10504f729d5aSbixia1 operands[0] = builder.create<arith::SubIOp>(loc, imj, c1); 10514f729d5aSbixia1 forEachIJPairInAllBuffers( 1052bfa3bc43SPeiming Liu builder, loc, operands, xPerm, ny, 10534f729d5aSbixia1 [&](uint64_t unused, Value imjm1, Value imj, Value buffer) { 10544f729d5aSbixia1 Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1); 10554f729d5aSbixia1 builder.create<memref::StoreOp>(loc, t, buffer, imj); 10564f729d5aSbixia1 }); 10579409bbb2Sbixia1 10589409bbb2Sbixia1 // Store the value at data[i] to data[p]. 10599409bbb2Sbixia1 builder.setInsertionPointAfter(forOpJ); 10604f729d5aSbixia1 operands[0] = operands[1] = p; 10614f729d5aSbixia1 forEachIJPairInAllBuffers( 1062bfa3bc43SPeiming Liu builder, loc, operands, xPerm, ny, 10634f729d5aSbixia1 [&](uint64_t k, Value p, Value usused, Value buffer) { 10644f729d5aSbixia1 builder.create<memref::StoreOp>(loc, d[k], buffer, p); 10654f729d5aSbixia1 }); 10669409bbb2Sbixia1 10679409bbb2Sbixia1 builder.setInsertionPointAfter(forOpI); 10689409bbb2Sbixia1 builder.create<func::ReturnOp>(loc); 10699409bbb2Sbixia1 } 10709409bbb2Sbixia1 1071a1507668Sbixia1 /// Creates a function to perform quick sort or a hybrid quick sort on the 1072a1507668Sbixia1 /// values in the range of index [lo, hi). 1073a1507668Sbixia1 // 1074a1507668Sbixia1 // 1075a1507668Sbixia1 // When nTrailingP == 0, the generated IR corresponds to this C like algorithm: 1076a1507668Sbixia1 // void quickSort(lo, hi, data) { 1077f6424d11Sbixia1 // while (lo + 1 < hi) { 1078a1507668Sbixia1 // p = partition(low, high, data); 1079f6424d11Sbixia1 // if (len(lo, p) < len(p+1, hi)) { 1080a1507668Sbixia1 // quickSort(lo, p, data); 1081f6424d11Sbixia1 // lo = p+1; 1082f6424d11Sbixia1 // } else { 1083a1507668Sbixia1 // quickSort(p + 1, hi, data); 1084f6424d11Sbixia1 // hi = p; 1085f6424d11Sbixia1 // } 1086a1507668Sbixia1 // } 1087a1507668Sbixia1 // } 1088a1507668Sbixia1 // 1089a1507668Sbixia1 // When nTrailingP == 1, the generated IR corresponds to this C like algorithm: 1090a1507668Sbixia1 // void hybridQuickSort(lo, hi, data, depthLimit) { 1091f6424d11Sbixia1 // while (lo + 1 < hi) { 1092a1507668Sbixia1 // len = hi - lo; 1093a1507668Sbixia1 // if (len <= limit) { 1094a1507668Sbixia1 // insertionSort(lo, hi, data); 1095a1507668Sbixia1 // } else { 1096a1507668Sbixia1 // depthLimit --; 1097a1507668Sbixia1 // if (depthLimit <= 0) { 1098a1507668Sbixia1 // heapSort(lo, hi, data); 1099a1507668Sbixia1 // } else { 1100a1507668Sbixia1 // p = partition(low, high, data); 1101f6424d11Sbixia1 // if (len(lo, p) < len(p+1, hi)) { 1102f6424d11Sbixia1 // quickSort(lo, p, data, depthLimit); 1103f6424d11Sbixia1 // lo = p+1; 1104f6424d11Sbixia1 // } else { 1105f6424d11Sbixia1 // quickSort(p + 1, hi, data, depthLimit); 1106f6424d11Sbixia1 // hi = p; 1107a1507668Sbixia1 // } 1108f6424d11Sbixia1 // } 1109a1507668Sbixia1 // } 1110a1507668Sbixia1 // } 1111a1507668Sbixia1 // } 1112a1507668Sbixia1 // 1113a1507668Sbixia1 static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, 1114bfa3bc43SPeiming Liu func::FuncOp func, AffineMap xPerm, uint64_t ny, 1115bfa3bc43SPeiming Liu uint32_t nTrailingP) { 1116a1507668Sbixia1 assert(nTrailingP == 1 || nTrailingP == 0); 1117a1507668Sbixia1 bool isHybrid = (nTrailingP == 1); 1118a1507668Sbixia1 OpBuilder::InsertionGuard insertionGuard(builder); 1119a1507668Sbixia1 Block *entryBlock = func.addEntryBlock(); 1120a1507668Sbixia1 builder.setInsertionPointToStart(entryBlock); 1121a1507668Sbixia1 1122a1507668Sbixia1 Location loc = func.getLoc(); 1123f6424d11Sbixia1 SmallVector<Value> args; 1124f6424d11Sbixia1 args.append(entryBlock->getArguments().begin(), 1125f6424d11Sbixia1 entryBlock->getArguments().end()); 1126a1507668Sbixia1 Value lo = args[loIdx]; 1127a1507668Sbixia1 Value hi = args[hiIdx]; 1128f6424d11Sbixia1 SmallVector<Type, 2> types(2, lo.getType()); // Only two types. 1129f6424d11Sbixia1 scf::WhileOp whileOp = 1130f6424d11Sbixia1 builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi}); 1131a1507668Sbixia1 1132f6424d11Sbixia1 // The before-region of the WhileOp. 1133f6424d11Sbixia1 Block *before = 1134f6424d11Sbixia1 builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc}); 1135f6424d11Sbixia1 builder.setInsertionPointToEnd(before); 1136f6424d11Sbixia1 lo = before->getArgument(0); 1137f6424d11Sbixia1 hi = before->getArgument(1); 1138f6424d11Sbixia1 Value loP1 = 1139f6424d11Sbixia1 builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1)); 1140f6424d11Sbixia1 Value needSort = 1141f6424d11Sbixia1 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi); 1142f6424d11Sbixia1 builder.create<scf::ConditionOp>(loc, needSort, before->getArguments()); 1143f6424d11Sbixia1 1144f6424d11Sbixia1 // The after-region of the WhileOp. 1145f6424d11Sbixia1 Block *after = 1146f6424d11Sbixia1 builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc}); 1147f6424d11Sbixia1 builder.setInsertionPointToEnd(after); 1148f6424d11Sbixia1 lo = after->getArgument(0); 1149f6424d11Sbixia1 hi = after->getArgument(1); 1150f6424d11Sbixia1 args[0] = lo; 1151f6424d11Sbixia1 args[1] = hi; 1152a1507668Sbixia1 1153a1507668Sbixia1 if (isHybrid) { 1154a1507668Sbixia1 Value len = builder.create<arith::SubIOp>(loc, hi, lo); 1155a1507668Sbixia1 Value lenLimit = constantIndex(builder, loc, 30); 1156a1507668Sbixia1 Value lenCond = builder.create<arith::CmpIOp>( 1157a1507668Sbixia1 loc, arith::CmpIPredicate::ule, len, lenLimit); 1158f6424d11Sbixia1 scf::IfOp lenIf = 1159f6424d11Sbixia1 builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true); 1160a1507668Sbixia1 1161a1507668Sbixia1 // When len <= limit. 1162a1507668Sbixia1 builder.setInsertionPointToStart(&lenIf.getThenRegion().front()); 1163a1507668Sbixia1 FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc( 1164bfa3bc43SPeiming Liu builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny, 1165f6424d11Sbixia1 ValueRange(args).drop_back(nTrailingP), createSortStableFunc); 1166a1507668Sbixia1 builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(), 1167f6424d11Sbixia1 ValueRange(args).drop_back(nTrailingP)); 1168f6424d11Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); 1169a1507668Sbixia1 1170a1507668Sbixia1 // When len > limit. 1171a1507668Sbixia1 builder.setInsertionPointToStart(&lenIf.getElseRegion().front()); 1172f6424d11Sbixia1 Value depthLimit = args.back(); 1173f6424d11Sbixia1 depthLimit = builder.create<arith::SubIOp>(loc, depthLimit, 1174f6424d11Sbixia1 constantI64(builder, loc, 1)); 1175a1507668Sbixia1 Value depthCond = 1176a1507668Sbixia1 builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, 1177a1507668Sbixia1 depthLimit, constantI64(builder, loc, 0)); 1178f6424d11Sbixia1 scf::IfOp depthIf = 1179f6424d11Sbixia1 builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true); 1180a1507668Sbixia1 1181a1507668Sbixia1 // When depth exceeds limit. 1182a1507668Sbixia1 builder.setInsertionPointToStart(&depthIf.getThenRegion().front()); 1183a1507668Sbixia1 FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc( 1184bfa3bc43SPeiming Liu builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny, 1185f6424d11Sbixia1 ValueRange(args).drop_back(nTrailingP), createHeapSortFunc); 1186a1507668Sbixia1 builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(), 1187f6424d11Sbixia1 ValueRange(args).drop_back(nTrailingP)); 1188f6424d11Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{lo, lo}); 1189a1507668Sbixia1 1190a1507668Sbixia1 // When depth doesn't exceed limit. 1191a1507668Sbixia1 builder.setInsertionPointToStart(&depthIf.getElseRegion().front()); 1192f6424d11Sbixia1 args.back() = depthLimit; 1193f6424d11Sbixia1 std::tie(lo, hi) = 1194bfa3bc43SPeiming Liu createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); 1195f6424d11Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); 1196a1507668Sbixia1 1197a1507668Sbixia1 builder.setInsertionPointAfter(depthIf); 1198f6424d11Sbixia1 lo = depthIf.getResult(0); 1199f6424d11Sbixia1 hi = depthIf.getResult(1); 1200f6424d11Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); 1201f6424d11Sbixia1 1202f6424d11Sbixia1 builder.setInsertionPointAfter(lenIf); 1203f6424d11Sbixia1 lo = lenIf.getResult(0); 1204f6424d11Sbixia1 hi = lenIf.getResult(1); 1205f6424d11Sbixia1 } else { 1206f6424d11Sbixia1 std::tie(lo, hi) = 1207bfa3bc43SPeiming Liu createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP); 1208a1507668Sbixia1 } 1209a1507668Sbixia1 1210f6424d11Sbixia1 // New [lo, hi) for the next while-loop iteration. 1211f6424d11Sbixia1 builder.create<scf::YieldOp>(loc, ValueRange{lo, hi}); 1212f6424d11Sbixia1 1213f6424d11Sbixia1 // After the while-loop. 1214f6424d11Sbixia1 builder.setInsertionPointAfter(whileOp); 1215a1507668Sbixia1 builder.create<func::ReturnOp>(loc); 1216a1507668Sbixia1 } 1217a1507668Sbixia1 12184f729d5aSbixia1 /// Implements the rewriting for operator sort and sort_coo. 12194f729d5aSbixia1 template <typename OpTy> 1220bfa3bc43SPeiming Liu LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, 1221bfa3bc43SPeiming Liu uint64_t ny, PatternRewriter &rewriter) { 12224f729d5aSbixia1 Location loc = op.getLoc(); 12230e1708ffSAart Bik SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()}; 12244f729d5aSbixia1 12254f729d5aSbixia1 // Convert `values` to have dynamic shape and append them to `operands`. 12264f729d5aSbixia1 for (Value v : xys) { 12279916ab03Swren romano auto mtp = getMemRefType(v); 12284f729d5aSbixia1 if (!mtp.isDynamicDim(0)) { 12294f729d5aSbixia1 auto newMtp = 1230399638f9SAliia Khasanova MemRefType::get({ShapedType::kDynamic}, mtp.getElementType()); 12314f729d5aSbixia1 v = rewriter.create<memref::CastOp>(loc, newMtp, v); 12324f729d5aSbixia1 } 12334f729d5aSbixia1 operands.push_back(v); 12344f729d5aSbixia1 } 12353b1c86cdSbixia1 12364f729d5aSbixia1 auto insertPoint = op->template getParentOfType<func::FuncOp>(); 12379a29d875SKohei Yamaguchi if (!insertPoint) 12389a29d875SKohei Yamaguchi return failure(); 12399a29d875SKohei Yamaguchi 12403b1c86cdSbixia1 SmallString<32> funcName; 12413b1c86cdSbixia1 FuncGeneratorType funcGenerator; 12428550aebdSbixia1 uint32_t nTrailingP = 0; 12433b1c86cdSbixia1 switch (op.getAlgorithm()) { 1244a1507668Sbixia1 case SparseTensorSortKind::HybridQuickSort: { 1245a1507668Sbixia1 funcName = kHybridQuickSortFuncNamePrefix; 1246a1507668Sbixia1 funcGenerator = createQuickSortFunc; 1247a1507668Sbixia1 nTrailingP = 1; 1248a1507668Sbixia1 // As a heuristics, set depthLimit = 2 * log2(n). 1249a1507668Sbixia1 Value lo = operands[loIdx]; 1250a1507668Sbixia1 Value hi = operands[hiIdx]; 1251a1507668Sbixia1 Value len = rewriter.create<arith::IndexCastOp>( 1252a1507668Sbixia1 loc, rewriter.getI64Type(), 1253a1507668Sbixia1 rewriter.create<arith::SubIOp>(loc, hi, lo)); 1254a1507668Sbixia1 Value depthLimit = rewriter.create<arith::SubIOp>( 1255a1507668Sbixia1 loc, constantI64(rewriter, loc, 64), 1256a1507668Sbixia1 rewriter.create<math::CountLeadingZerosOp>(loc, len)); 1257f6424d11Sbixia1 operands.push_back(depthLimit); 1258a1507668Sbixia1 break; 1259a1507668Sbixia1 } 12603b1c86cdSbixia1 case SparseTensorSortKind::QuickSort: 1261a1507668Sbixia1 funcName = kQuickSortFuncNamePrefix; 1262a1507668Sbixia1 funcGenerator = createQuickSortFunc; 12633b1c86cdSbixia1 break; 12643b1c86cdSbixia1 case SparseTensorSortKind::InsertionSortStable: 12653b1c86cdSbixia1 funcName = kSortStableFuncNamePrefix; 12663b1c86cdSbixia1 funcGenerator = createSortStableFunc; 12673b1c86cdSbixia1 break; 12683b1c86cdSbixia1 case SparseTensorSortKind::HeapSort: 12693b1c86cdSbixia1 funcName = kHeapSortFuncNamePrefix; 12703b1c86cdSbixia1 funcGenerator = createHeapSortFunc; 12713b1c86cdSbixia1 break; 12723b1c86cdSbixia1 } 12733b1c86cdSbixia1 12744f729d5aSbixia1 FlatSymbolRefAttr func = 1275bfa3bc43SPeiming Liu getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, 1276bfa3bc43SPeiming Liu xPerm, ny, operands, funcGenerator, nTrailingP); 12774f729d5aSbixia1 rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands); 12784f729d5aSbixia1 return success(); 12794f729d5aSbixia1 } 12804f729d5aSbixia1 1281062e515bSbixia1 //===---------------------------------------------------------------------===// 1282062e515bSbixia1 // The actual sparse buffer rewriting rules. 1283062e515bSbixia1 //===---------------------------------------------------------------------===// 1284062e515bSbixia1 1285062e515bSbixia1 namespace { 1286654bbbdeSbixia1 /// Sparse rewriting rule for the push_back operator. 1287654bbbdeSbixia1 struct PushBackRewriter : OpRewritePattern<PushBackOp> { 1288654bbbdeSbixia1 public: 1289654bbbdeSbixia1 using OpRewritePattern<PushBackOp>::OpRewritePattern; 12905618d2beSbixia1 PushBackRewriter(MLIRContext *context, bool enableInit) 12915618d2beSbixia1 : OpRewritePattern(context), enableBufferInitialization(enableInit) {} 1292654bbbdeSbixia1 LogicalResult matchAndRewrite(PushBackOp op, 1293654bbbdeSbixia1 PatternRewriter &rewriter) const override { 1294d45be887Sbixia1 // Rewrite push_back(buffer, value, n) to: 1295d45be887Sbixia1 // new_size = size(buffer) + n 1296d45be887Sbixia1 // if (new_size > capacity(buffer)) 1297d45be887Sbixia1 // while new_size > new_capacity 1298d45be887Sbixia1 // new_capacity = new_capacity*2 1299654bbbdeSbixia1 // new_buffer = realloc(buffer, new_capacity) 1300654bbbdeSbixia1 // buffer = new_buffer 1301d45be887Sbixia1 // subBuffer = subviewof(buffer) 1302d45be887Sbixia1 // linalg.fill subBuffer value 1303d45be887Sbixia1 // 1304d45be887Sbixia1 // size(buffer) += n 13051c835b5aSbixia1 // 13061c835b5aSbixia1 // The capacity check is skipped when the attribute inbounds is presented. 1307654bbbdeSbixia1 Location loc = op->getLoc(); 1308654bbbdeSbixia1 Value c0 = constantIndex(rewriter, loc, 0); 1309654bbbdeSbixia1 Value buffer = op.getInBuffer(); 1310654bbbdeSbixia1 Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0); 1311988733c6SPeiming Liu Value size = op.getCurSize(); 1312654bbbdeSbixia1 Value value = op.getValue(); 13131c835b5aSbixia1 1314d45be887Sbixia1 Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1); 1315d45be887Sbixia1 Value newSize = rewriter.create<arith::AddIOp>(loc, size, n); 1316d45be887Sbixia1 auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp()); 1317d45be887Sbixia1 bool nIsOne = (nValue && nValue.value() == 1); 1318d45be887Sbixia1 13191c835b5aSbixia1 if (!op.getInbounds()) { 13201c835b5aSbixia1 Value cond = rewriter.create<arith::CmpIOp>( 1321d45be887Sbixia1 loc, arith::CmpIPredicate::ugt, newSize, capacity); 13221c835b5aSbixia1 1323d45be887Sbixia1 Value c2 = constantIndex(rewriter, loc, 2); 1324654bbbdeSbixia1 auto bufferType = 1325399638f9SAliia Khasanova MemRefType::get({ShapedType::kDynamic}, value.getType()); 1326654bbbdeSbixia1 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond, 1327654bbbdeSbixia1 /*else=*/true); 1328654bbbdeSbixia1 // True branch. 1329654bbbdeSbixia1 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); 1330d45be887Sbixia1 if (nIsOne) { 1331654bbbdeSbixia1 capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2); 1332d45be887Sbixia1 } else { 1333d45be887Sbixia1 // Use a do-while loop to calculate the new capacity as follows: 1334d45be887Sbixia1 // do { new_capacity *= 2 } while (size > new_capacity) 1335d45be887Sbixia1 scf::WhileOp whileOp = 1336d45be887Sbixia1 rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity); 1337d45be887Sbixia1 1338d45be887Sbixia1 // The before-region of the WhileOp. 1339d45be887Sbixia1 Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, 1340d45be887Sbixia1 {capacity.getType()}, {loc}); 1341d45be887Sbixia1 rewriter.setInsertionPointToEnd(before); 1342d45be887Sbixia1 1343d45be887Sbixia1 capacity = 1344d45be887Sbixia1 rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2); 1345d45be887Sbixia1 cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, 1346d45be887Sbixia1 newSize, capacity); 1347d45be887Sbixia1 rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity}); 1348d45be887Sbixia1 // The after-region of the WhileOp. 1349d45be887Sbixia1 Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, 1350d45be887Sbixia1 {capacity.getType()}, {loc}); 1351d45be887Sbixia1 rewriter.setInsertionPointToEnd(after); 1352d45be887Sbixia1 rewriter.create<scf::YieldOp>(loc, after->getArguments()); 1353d45be887Sbixia1 1354d45be887Sbixia1 rewriter.setInsertionPointAfter(whileOp); 1355d45be887Sbixia1 capacity = whileOp.getResult(0); 1356d45be887Sbixia1 } 1357d45be887Sbixia1 1358654bbbdeSbixia1 Value newBuffer = 1359654bbbdeSbixia1 rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity); 13605618d2beSbixia1 if (enableBufferInitialization) { 13615618d2beSbixia1 Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize); 1362ea4be70cSbixia1 Value fillValue = constantZero(rewriter, loc, value.getType()); 13635618d2beSbixia1 Value subBuffer = rewriter.create<memref::SubViewOp>( 13645618d2beSbixia1 loc, newBuffer, /*offset=*/ValueRange{newSize}, 13655618d2beSbixia1 /*size=*/ValueRange{fillSize}, 13665618d2beSbixia1 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); 13675618d2beSbixia1 rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer); 13685618d2beSbixia1 } 1369654bbbdeSbixia1 rewriter.create<scf::YieldOp>(loc, newBuffer); 1370654bbbdeSbixia1 1371654bbbdeSbixia1 // False branch. 1372654bbbdeSbixia1 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); 1373654bbbdeSbixia1 rewriter.create<scf::YieldOp>(loc, buffer); 1374654bbbdeSbixia1 13751c835b5aSbixia1 // Prepare for adding the value to the end of the buffer. 1376654bbbdeSbixia1 rewriter.setInsertionPointAfter(ifOp); 1377654bbbdeSbixia1 buffer = ifOp.getResult(0); 13781c835b5aSbixia1 } 13791c835b5aSbixia1 13801c835b5aSbixia1 // Add the value to the end of the buffer. 1381d45be887Sbixia1 if (nIsOne) { 1382654bbbdeSbixia1 rewriter.create<memref::StoreOp>(loc, value, buffer, size); 1383d45be887Sbixia1 } else { 1384d45be887Sbixia1 Value subBuffer = rewriter.create<memref::SubViewOp>( 1385d45be887Sbixia1 loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n}, 1386d45be887Sbixia1 /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); 1387d45be887Sbixia1 rewriter.create<linalg::FillOp>(loc, value, subBuffer); 1388d45be887Sbixia1 } 1389654bbbdeSbixia1 1390d45be887Sbixia1 // Update the buffer size. 1391988733c6SPeiming Liu rewriter.replaceOp(op, {buffer, newSize}); 1392654bbbdeSbixia1 return success(); 1393654bbbdeSbixia1 } 13945618d2beSbixia1 13955618d2beSbixia1 private: 13965618d2beSbixia1 bool enableBufferInitialization; 1397654bbbdeSbixia1 }; 1398654bbbdeSbixia1 13994f729d5aSbixia1 /// Sparse rewriting rule for the sort_coo operator. 14000083f833SPeiming Liu struct SortRewriter : public OpRewritePattern<SortOp> { 14014f729d5aSbixia1 public: 14020083f833SPeiming Liu using OpRewritePattern<SortOp>::OpRewritePattern; 14034f729d5aSbixia1 14040083f833SPeiming Liu LogicalResult matchAndRewrite(SortOp op, 14054f729d5aSbixia1 PatternRewriter &rewriter) const override { 14060e1708ffSAart Bik SmallVector<Value> xys; 14074f729d5aSbixia1 xys.push_back(op.getXy()); 14084f729d5aSbixia1 xys.append(op.getYs().begin(), op.getYs().end()); 14094f729d5aSbixia1 1410bfa3bc43SPeiming Liu auto xPerm = op.getPermMap(); 14114f729d5aSbixia1 uint64_t ny = 0; 14124f729d5aSbixia1 if (auto nyAttr = op.getNyAttr()) 14134f729d5aSbixia1 ny = nyAttr.getInt(); 14144f729d5aSbixia1 1415bfa3bc43SPeiming Liu return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter); 1416062e515bSbixia1 } 1417062e515bSbixia1 }; 1418062e515bSbixia1 1419062e515bSbixia1 } // namespace 1420062e515bSbixia1 1421062e515bSbixia1 //===---------------------------------------------------------------------===// 1422062e515bSbixia1 // Methods that add patterns described in this file to a pattern list. 1423062e515bSbixia1 //===---------------------------------------------------------------------===// 1424062e515bSbixia1 14255618d2beSbixia1 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns, 14265618d2beSbixia1 bool enableBufferInitialization) { 14275618d2beSbixia1 patterns.add<PushBackRewriter>(patterns.getContext(), 14285618d2beSbixia1 enableBufferInitialization); 14290083f833SPeiming Liu patterns.add<SortRewriter>(patterns.getContext()); 1430062e515bSbixia1 } 1431