xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp (revision 1609f1c2a5ecc0e0e28f433ec9205122478ddaa3)
1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensor
10 // primitives with memref operands.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodegenUtils.h"
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Math/IR/Math.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
23 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
24 #include "mlir/Support/LLVM.h"
25 
26 using namespace mlir;
27 using namespace mlir::sparse_tensor;
28 
29 //===---------------------------------------------------------------------===//
30 // Helper methods for the actual rewriting rules.
31 //===---------------------------------------------------------------------===//
32 
// Positions of the leading operands shared by every generated sort helper:
// (lo, hi, buffer holding the x/y values, further buffers...).
static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = loIdx + 1;
static constexpr uint64_t xStartIdx = hiIdx + 1;

// Name prefixes of the generated sort helper functions (presumably passed as
// `namePrefix` to getMangledSortHelperFunc, which appends the mangled suffix).
static constexpr char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_";
static constexpr char kHybridQuickSortFuncNamePrefix[] = "_sparse_hybrid_qsort_";
static constexpr char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_";
static constexpr char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
47 
48 using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
49                                             AffineMap, uint64_t, uint32_t)>;
50 
51 /// Constructs a function name with this format to facilitate quick sort:
52 ///   <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
53 ///   <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
54 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
55                                          StringRef namePrefix, AffineMap xPerm,
56                                          uint64_t ny, ValueRange operands) {
57   nameOstream << namePrefix;
58   for (auto res : xPerm.getResults())
59     nameOstream << cast<AffineDimExpr>(res).getPosition() << "_";
60 
61   nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
62   nameOstream << "_coo_" << ny;
63 
64   constexpr uint64_t yBufferOffset = 1;
65   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
66     nameOstream << "_" << getMemRefType(v).getElementType();
67 }
68 
69 /// Looks up a function that is appropriate for the given operands being
70 /// sorted, and creates such a function if it doesn't exist yet. The
71 /// parameters `xPerm` and `ny` tell the number of x and y values provided
72 /// by the buffer in xStartIdx.
73 //
74 // All sorting function generators take (lo, hi, xs, ys) in `operands` as
75 // parameters for the sorting functions. Other parameters, such as the recursive
76 // call depth, are appended to the end of the parameter list as
77 // "trailing parameters".
78 static FlatSymbolRefAttr getMangledSortHelperFunc(
79     OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
80     StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
81     FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
82   SmallString<32> nameBuffer;
83   llvm::raw_svector_ostream nameOstream(nameBuffer);
84   getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
85                                operands.drop_back(nTrailingP));
86 
87   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
88   MLIRContext *context = module.getContext();
89   auto result = SymbolRefAttr::get(context, nameOstream.str());
90   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
91 
92   if (!func) {
93     // Create the function.
94     OpBuilder::InsertionGuard insertionGuard(builder);
95     builder.setInsertionPoint(insertPoint);
96     Location loc = insertPoint.getLoc();
97     func = builder.create<func::FuncOp>(
98         loc, nameOstream.str(),
99         FunctionType::get(context, operands.getTypes(), resultTypes));
100     func.setPrivate();
101     createFunc(builder, module, func, xPerm, ny, nTrailingP);
102   }
103 
104   return result;
105 }
106 
107 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
108 /// The code to process the value pairs is generated by `bodyBuilder`.
109 static void forEachIJPairInXs(
110     OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
111     uint64_t ny,
112     function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
113   Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
114   Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
115   Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
116   for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
117     unsigned actualK = cast<AffineDimExpr>(xPerm.getResult(k)).getPosition();
118     Value ak = constantIndex(builder, loc, actualK);
119     Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
120     Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
121     Value buffer = args[xStartIdx];
122 
123     bodyBuilder(k, i, j, buffer);
124   }
125 }
126 
127 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
128 /// The code to process the value pairs is generated by `bodyBuilder`.
129 static void forEachIJPairInAllBuffers(
130     OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
131     uint64_t ny,
132     function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
133 
134   // Create code for the first (xPerm + ny) buffers.
135   SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
136                                xPerm.getResults().end());
137   for (unsigned y = 0; y < ny; y++) {
138     exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
139   }
140   AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
141   assert(xyPerm.isPermutation());
142 
143   forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
144 
145   constexpr uint64_t numHandledBuffers = 1;
146   // Create code for the remaining buffers.
147   Value i = args[0];
148   Value j = args[1];
149   for (const auto &arg :
150        llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
151     bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
152   }
153 }
154 
155 /// Creates a code block for swapping the values in index i and j for all the
156 /// buffers.
157 //
158 // The generated IR corresponds to this C like algorithm:
159 //     swap(x0[i], x0[j]);
160 //     swap(x1[i], x1[j]);
161 //     ...
162 //     swap(xn[i], xn[j]);
163 //     swap(y0[i], y0[j]);
164 //     ...
165 //     swap(yn[i], yn[j]);
166 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
167                        AffineMap xPerm, uint64_t ny) {
168   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
169     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
170     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
171     builder.create<memref::StoreOp>(loc, vj, buffer, i);
172     builder.create<memref::StoreOp>(loc, vi, buffer, j);
173   };
174 
175   forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
176 }
177 
178 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
179 /// each pair is create via `compareBuilder`.
180 static Value createInlinedCompareImplementation(
181     OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
182     uint64_t ny,
183     function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
184         compareBuilder) {
185   Value result;
186   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
187     bool isFirstDim = (k == 0);
188     bool isLastDim = (k == xPerm.getNumResults() - 1);
189     Value val =
190         compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
191     if (isFirstDim) {
192       result = val;
193     } else if (!isLastDim) {
194       OpBuilder::InsertionGuard insertionGuard(builder);
195       auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
196       builder.setInsertionPointAfter(ifOp);
197       builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
198     }
199   };
200 
201   forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
202 
203   builder.setInsertionPointAfterValue(result);
204   return result;
205 }
206 
207 /// Generates code to compare whether x[i] is equal to x[j] and returns the
208 /// result of the comparison.
209 static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
210                              Value x, bool isFirstDim, bool isLastDim) {
211   Value vi = builder.create<memref::LoadOp>(loc, x, i);
212   Value vj = builder.create<memref::LoadOp>(loc, x, j);
213 
214   Value res;
215   if (isLastDim) {
216     res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
217     // For 1D, we create a compare without any control flow. Otherwise, we
218     // create YieldOp to return the result in the nested if-stmt.
219     if (!isFirstDim)
220       builder.create<scf::YieldOp>(loc, res);
221   } else {
222     Value ne =
223         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
224     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
225                                                ne, /*else=*/true);
226     // If (x[i] != x[j]).
227     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
228     Value f = constantI1(builder, loc, false);
229     builder.create<scf::YieldOp>(loc, f);
230 
231     // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
232     // checks the remaining dimensions.
233     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
234     res = ifOp.getResult(0);
235   }
236 
237   return res;
238 }
239 
240 /// Creates code to compare whether xs[i] is equal to xs[j].
241 //
242 // The generate IR corresponds to this C like algorithm:
243 //   if (x0[i] != x0[j])
244 //     return false;
245 //   else
246 //     if (x1[i] != x1[j])
247 //       return false;
248 //     else if (x2[2] != x2[j]))
249 //       and so on ...
250 static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
251                                     ValueRange args, AffineMap xPerm,
252                                     uint64_t ny, uint32_t nTrailingP = 0) {
253   // Compare functions don't use trailing parameters.
254   (void)nTrailingP;
255   assert(nTrailingP == 0);
256   return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
257                                             createEqCompare);
258 }
259 
260 /// Generates code to compare whether x[i] is less than x[j] and returns the
261 /// result of the comparison.
262 static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
263                                    Value j, Value x, bool isFirstDim,
264                                    bool isLastDim) {
265   Value vi = builder.create<memref::LoadOp>(loc, x, i);
266   Value vj = builder.create<memref::LoadOp>(loc, x, j);
267 
268   Value res;
269   if (isLastDim) {
270     res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
271     // For 1D, we create a compare without any control flow. Otherwise, we
272     // create YieldOp to return the result in the nested if-stmt.
273     if (!isFirstDim)
274       builder.create<scf::YieldOp>(loc, res);
275   } else {
276     Value ne =
277         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
278     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
279                                                ne, /*else=*/true);
280     // If (x[i] != x[j]).
281     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
282     Value lt =
283         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
284     builder.create<scf::YieldOp>(loc, lt);
285 
286     // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
287     // checks the remaining dimensions.
288     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
289     res = ifOp.getResult(0);
290   }
291 
292   return res;
293 }
294 
295 /// Creates code to compare whether xs[i] is less than xs[j].
296 //
297 // The generate IR corresponds to this C like algorithm:
298 //   if (x0[i] != x0[j])
299 //     return x0[i] < x0[j];
300 //   else if (x1[j] != x1[i])
301 //     return x1[i] < x1[j];
302 //   else
303 //       and so on ...
304 static Value createInlinedLessThan(OpBuilder &builder, Location loc,
305                                    ValueRange args, AffineMap xPerm,
306                                    uint64_t ny, uint32_t nTrailingP = 0) {
307   // Compare functions don't use trailing parameters.
308   (void)nTrailingP;
309   assert(nTrailingP == 0);
310   return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
311                                             createLessThanCompare);
312 }
313 
314 /// Creates a function to use a binary search to find the insertion point for
315 /// inserting xs[hi] to the sorted values xs[lo..hi).
316 //
317 // The generate IR corresponds to this C like algorithm:
318 //   p = hi
319 //   while (lo < hi)
320 //      mid = (lo + hi) >> 1
321 //      if (xs[p] < xs[mid])
322 //        hi = mid
323 //      else
324 //        lo = mid - 1
325 //   return lo;
326 //
327 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
328                                    func::FuncOp func, AffineMap xPerm,
329                                    uint64_t ny, uint32_t nTrailingP = 0) {
330   // Binary search doesn't use trailing parameters.
331   (void)nTrailingP;
332   assert(nTrailingP == 0);
333   OpBuilder::InsertionGuard insertionGuard(builder);
334   Block *entryBlock = func.addEntryBlock();
335   builder.setInsertionPointToStart(entryBlock);
336 
337   Location loc = func.getLoc();
338   ValueRange args = entryBlock->getArguments();
339   Value p = args[hiIdx];
340   SmallVector<Type, 2> types(2, p.getType()); // Only two types.
341   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
342       loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
343 
344   // The before-region of the WhileOp.
345   Block *before =
346       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
347   builder.setInsertionPointToEnd(before);
348   Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
349                                               before->getArgument(0),
350                                               before->getArgument(1));
351   builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
352 
353   // The after-region of the WhileOp.
354   Block *after =
355       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
356   builder.setInsertionPointToEnd(after);
357   Value lo = after->getArgument(0);
358   Value hi = after->getArgument(1);
359   // Compute mid = (lo + hi) >> 1.
360   Value c1 = constantIndex(builder, loc, 1);
361   Value mid = builder.create<arith::ShRUIOp>(
362       loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
363   Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
364 
365   // Compare xs[p] < xs[mid].
366   SmallVector<Value> compareOperands{p, mid};
367   constexpr uint64_t numXBuffers = 1;
368   compareOperands.append(args.begin() + xStartIdx,
369                          args.begin() + xStartIdx + numXBuffers);
370   Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
371   // Update lo and hi for the WhileOp as follows:
372   //   if (xs[p] < xs[mid]))
373   //     hi = mid;
374   //   else
375   //     lo = mid + 1;
376   Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
377   Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
378   builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
379 
380   builder.setInsertionPointAfter(whileOp);
381   builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
382 }
383 
384 /// Creates code to advance i in a loop based on xs[p] as follows:
385 ///   while (xs[i] < xs[p]) i += step (step > 0)
386 /// or
387 ///   while (xs[i] > xs[p]) i += step (step < 0)
388 /// The routine returns i as well as a boolean value to indicate whether
389 /// xs[i] == xs[p].
390 static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
391                                               ModuleOp module,
392                                               func::FuncOp func, ValueRange xs,
393                                               Value i, Value p, AffineMap xPerm,
394                                               uint64_t ny, int step) {
395   Location loc = func.getLoc();
396   scf::WhileOp whileOp =
397       builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
398 
399   Block *before =
400       builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
401   builder.setInsertionPointToEnd(before);
402   SmallVector<Value> compareOperands;
403   if (step > 0) {
404     compareOperands.push_back(before->getArgument(0));
405     compareOperands.push_back(p);
406   } else {
407     assert(step < 0);
408     compareOperands.push_back(p);
409     compareOperands.push_back(before->getArgument(0));
410   }
411   compareOperands.append(xs.begin(), xs.end());
412   Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
413   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
414 
415   Block *after =
416       builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
417   builder.setInsertionPointToEnd(after);
418   Value cs = constantIndex(builder, loc, step);
419   i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
420   builder.create<scf::YieldOp>(loc, ValueRange{i});
421   i = whileOp.getResult(0);
422 
423   builder.setInsertionPointAfter(whileOp);
424   compareOperands[0] = i;
425   compareOperands[1] = p;
426   Value compareEq =
427       createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);
428 
429   return std::make_pair(whileOp.getResult(0), compareEq);
430 }
431 
432 /// Creates and returns an IfOp to compare two elements and swap the elements
433 /// if compareFunc(data[b], data[a]) returns true. The new insertion point is
434 /// right after the swap instructions.
435 static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
436                                        AffineMap xPerm, uint64_t ny,
437                                        SmallVectorImpl<Value> &swapOperands,
438                                        SmallVectorImpl<Value> &compareOperands,
439                                        Value a, Value b) {
440   // Compare(data[b], data[a]).
441   compareOperands[0] = b;
442   compareOperands[1] = a;
443   Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
444   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
445   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
446   swapOperands[0] = b;
447   swapOperands[1] = a;
448   createSwap(builder, loc, swapOperands, xPerm, ny);
449   return ifOp;
450 }
451 
452 /// Creates code to insert the 3rd element to a list of two sorted elements.
453 static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
454                             uint64_t ny, SmallVectorImpl<Value> &swapOperands,
455                             SmallVectorImpl<Value> &compareOperands, Value v0,
456                             Value v1, Value v2) {
457   scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
458                                          compareOperands, v1, v2);
459   createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
460                         v0, v1);
461   builder.setInsertionPointAfter(ifOp);
462 }
463 
464 /// Creates code to sort 3 elements.
465 static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
466                         uint64_t ny, SmallVectorImpl<Value> &swapOperands,
467                         SmallVectorImpl<Value> &compareOperands, Value v0,
468                         Value v1, Value v2) {
469   // Sort the first 2 elements.
470   scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
471                                           compareOperands, v0, v1);
472   builder.setInsertionPointAfter(ifOp1);
473 
474   // Insert the 3th element.
475   createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
476                   v1, v2);
477 }
478 
479 /// Creates code to sort 5 elements.
480 static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
481                         uint64_t ny, SmallVectorImpl<Value> &swapOperands,
482                         SmallVectorImpl<Value> &compareOperands, Value v0,
483                         Value v1, Value v2, Value v3, Value v4) {
484   // Sort the first 3 elements.
485   createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
486               v2);
487 
488   auto insert4th = [&]() {
489     scf::IfOp ifOp = createCompareThenSwap(
490         builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
491     createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
492                     v1, v2);
493     builder.setInsertionPointAfter(ifOp);
494   };
495 
496   // Insert the 4th element.
497   insert4th();
498 
499   // Insert the 5th element.
500   scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
501                                          compareOperands, v3, v4);
502   insert4th();
503   builder.setInsertionPointAfter(ifOp);
504 }
505 
506 /// Creates a code block to swap the values in indices lo, mi, and hi so that
507 /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
508 /// the number of values in range [lo, hi) is more than a threshold, we also
509 /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
510 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
511                               func::FuncOp func, AffineMap xPerm, uint64_t ny,
512                               Value lo, Value hi, Value mi, ValueRange args) {
513   SmallVector<Value> compareOperands{mi, lo};
514   constexpr uint64_t numXBuffers = 1;
515   compareOperands.append(args.begin() + xStartIdx,
516                          args.begin() + xStartIdx + numXBuffers);
517   SmallVector<Value> swapOperands{mi, lo};
518   swapOperands.append(args.begin() + xStartIdx, args.end());
519   Location loc = func.getLoc();
520   Value c1 = constantIndex(builder, loc, 1);
521   Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1);
522   Value len = builder.create<arith::SubIOp>(loc, hiP1, lo);
523   Value lenThreshold = constantIndex(builder, loc, 1000);
524   Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
525                                                 len, lenThreshold);
526   scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
527 
528   // When len < 1000, choose pivot from median of 3 values.
529   builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
530   createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
531               hi);
532 
533   // When len >= 1000, choose pivot from median of 5 values.
534   builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
535   Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1);
536   Value a = builder.create<arith::AddIOp>(loc, lo, miP1);
537   // Value a is the middle between [loc, mi].
538   a = builder.create<arith::ShRUIOp>(loc, a, c1);
539   Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
540   // Value b is the middle between [mi, hi].
541   b = builder.create<arith::ShRUIOp>(loc, b, c1);
542   createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
543               b, hi);
544 
545   builder.setInsertionPointAfter(lenIf);
546 }
547 
548 /// Creates a function to perform quick sort partition on the values in the
549 /// range of index [lo, hi), assuming lo < hi.
550 //
551 // The generated IR corresponds to this C like algorithm:
552 // int partition(lo, hi, xs) {
553 //   p = (lo+hi)/2  // pivot index
554 //   i = lo
555 //   j = hi-1
556 //   while (true) do {
557 //     while (xs[i] < xs[p]) i ++;
558 //     i_eq = (xs[i] == xs[p]);
559 //     while (xs[j] > xs[p]) j --;
560 //     j_eq = (xs[j] == xs[p]);
561 //
562 //     if (i >= j) return j + 1;
563 //
564 //     if (i < j) {
565 //       swap(xs[i], xs[j])
566 //       if (i == p) {
567 //         p = j;
568 //       } else if (j == p) {
569 //         p = i;
570 //       }
571 //       if (i_eq && j_eq) {
572 //         ++i;
573 //         --j;
574 //       }
575 //     }
576 //   }
577 // }
578 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
579                                 func::FuncOp func, AffineMap xPerm, uint64_t ny,
580                                 uint32_t nTrailingP = 0) {
581   // Quick sort partition doesn't use trailing parameters.
582   (void)nTrailingP;
583   assert(nTrailingP == 0);
584   OpBuilder::InsertionGuard insertionGuard(builder);
585 
586   Block *entryBlock = func.addEntryBlock();
587   builder.setInsertionPointToStart(entryBlock);
588 
589   Location loc = func.getLoc();
590   ValueRange args = entryBlock->getArguments();
591   Value lo = args[loIdx];
592   Value hi = args[hiIdx];
593   Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
594   Value c1 = constantIndex(builder, loc, 1);
595   Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
596 
597   Value i = lo;
598   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
599   createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
600   Value trueVal = constantI1(builder, loc, true); // The value for while (true)
601   SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
602   SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
603                              trueVal.getType()};
604   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
605 
606   // The before-region of the WhileOp.
607   Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
608                                       {loc, loc, loc, loc});
609   builder.setInsertionPointToEnd(before);
610   builder.create<scf::ConditionOp>(loc, before->getArgument(3),
611                                    before->getArguments());
612 
613   // The after-region of the WhileOp.
614   Block *after =
615       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
616   builder.setInsertionPointToEnd(after);
617   i = after->getArgument(0);
618   j = after->getArgument(1);
619   p = after->getArgument(2);
620 
621   constexpr uint64_t numXBuffers = 1;
622   auto [iresult, iCompareEq] =
623       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
624                      i, p, xPerm, ny, 1);
625   i = iresult;
626   auto [jresult, jCompareEq] =
627       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
628                      j, p, xPerm, ny, -1);
629   j = jresult;
630 
631   // If i < j:
632   Value cond =
633       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
634   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
635   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
636   SmallVector<Value> swapOperands{i, j};
637   swapOperands.append(args.begin() + xStartIdx, args.end());
638   createSwap(builder, loc, swapOperands, xPerm, ny);
639   // If the pivot is moved, update p with the new pivot.
640   Value icond =
641       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
642   scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
643                                               icond, /*else=*/true);
644   builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
645   builder.create<scf::YieldOp>(loc, ValueRange{j});
646   builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
647   Value jcond =
648       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
649   scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
650                                               jcond, /*else=*/true);
651   builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
652   builder.create<scf::YieldOp>(loc, ValueRange{i});
653   builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
654   builder.create<scf::YieldOp>(loc, ValueRange{p});
655   builder.setInsertionPointAfter(ifOpJ);
656   builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
657   builder.setInsertionPointAfter(ifOpI);
658   Value compareEqIJ =
659       builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
660   scf::IfOp ifOp2 = builder.create<scf::IfOp>(
661       loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
662   builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
663   Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
664   Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
665   builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
666   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
667   builder.create<scf::YieldOp>(loc, ValueRange{i, j});
668   builder.setInsertionPointAfter(ifOp2);
669   builder.create<scf::YieldOp>(
670       loc,
671       ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
672                  /*cont=*/constantI1(builder, loc, true)});
673 
674   // False branch for if i < j (i.e., i >= j):
675   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
676   p = builder.create<arith::AddIOp>(loc, j,
677                                     constantOne(builder, loc, j.getType()));
678   builder.create<scf::YieldOp>(
679       loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});
680 
681   // Return for the whileOp.
682   builder.setInsertionPointAfter(ifOp);
683   builder.create<scf::YieldOp>(loc, ifOp.getResults());
684 
685   // Return for the function.
686   builder.setInsertionPointAfter(whileOp);
687   builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
688 }
689 
690 /// Computes (n-2)/n, assuming n has index type.
691 static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
692                                       Value n) {
693   Value i2 = constantIndex(builder, loc, 2);
694   Value res = builder.create<arith::SubIOp>(loc, n, i2);
695   Value i1 = constantIndex(builder, loc, 1);
696   return builder.create<arith::ShRUIOp>(loc, res, i1);
697 }
698 
699 /// Creates a function to heapify the subtree with root `start` within the full
700 /// binary tree in the range of index [first, first + n).
701 //
702 // The generated IR corresponds to this C like algorithm:
703 // void shiftDown(first, start, n, data) {
704 //   if (n >= 2) {
705 //     child = start - first
706 //     if ((n-2)/2 >= child) {
707 //       // Left child exists.
708 //       child = child * 2 + 1 // Initialize the bigger child to left child.
709 //       childIndex = child + first
710 //       if (child+1 < n && data[childIndex] < data[childIndex+1])
711 //         // Right child exits and is bigger.
712 //         childIndex++; child++;
713 //       // Shift data[start] down to where it belongs in the subtree.
714 //       while (data[start] < data[childIndex) {
715 //         swap(data[start], data[childIndex])
716 //         start = childIndex
717 //         if ((n - 2)/2 >= child) {
718 //           // Left child exists.
719 //           child = 2*child + 1
720 //           childIndex = child + 1
721 //           if (child + 1) < n && data[childIndex] < data[childIndex+1]
722 //             childIndex++; child++;
723 //         }
724 //       }
725 //     }
726 //   }
727 // }
728 //
729 static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
730                                 func::FuncOp func, AffineMap xPerm, uint64_t ny,
731                                 uint32_t nTrailingP) {
732   // The value n is passed in as a trailing parameter.
733   assert(nTrailingP == 1);
734   OpBuilder::InsertionGuard insertionGuard(builder);
735   Block *entryBlock = func.addEntryBlock();
736   builder.setInsertionPointToStart(entryBlock);
737 
738   Location loc = func.getLoc();
739   Value n = entryBlock->getArguments().back();
740   ValueRange args = entryBlock->getArguments().drop_back();
741   Value first = args[loIdx];
742   Value start = args[hiIdx];
743 
744   // If (n >= 2).
745   Value c2 = constantIndex(builder, loc, 2);
746   Value condN =
747       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
748   scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
749   builder.setInsertionPointToStart(&ifN.getThenRegion().front());
750   Value child = builder.create<arith::SubIOp>(loc, start, first);
751 
752   // If ((n-2)/2 >= child).
753   Value t = createSubTwoDividedByTwo(builder, loc, n);
754   Value condNc =
755       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
756   scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);
757 
758   builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
759   Value c1 = constantIndex(builder, loc, 1);
760   SmallVector<Value> compareOperands{start, start};
761   constexpr uint64_t numXBuffers = 1;
762   compareOperands.append(args.begin() + xStartIdx,
763                          args.begin() + xStartIdx + numXBuffers);
764 
765   // Generate code to inspect the children of 'r' and return the larger child
766   // as follows:
767   //   child = r * 2 + 1 // Left child.
768   //   childIndex = child + first
769   //   if (child+1 < n && data[childIndex] < data[childIndex+1])
770   //     childIndex ++; child ++ // Right child is bigger.
771   auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
772     Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
773     lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
774     Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
775     Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
776     Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
777                                                 rChild, n);
778     SmallVector<Type, 2> ifTypes(2, r.getType());
779     scf::IfOp if1 =
780         builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
781     builder.setInsertionPointToStart(&if1.getThenRegion().front());
782     Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
783     // Compare data[left] < data[right].
784     compareOperands[0] = lChildIdx;
785     compareOperands[1] = rChildIdx;
786     Value cond2 =
787         createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
788     scf::IfOp if2 =
789         builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
790     builder.setInsertionPointToStart(&if2.getThenRegion().front());
791     builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
792     builder.setInsertionPointToStart(&if2.getElseRegion().front());
793     builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
794     builder.setInsertionPointAfter(if2);
795     builder.create<scf::YieldOp>(loc, if2.getResults());
796     builder.setInsertionPointToStart(&if1.getElseRegion().front());
797     builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
798     builder.setInsertionPointAfter(if1);
799     return std::make_pair(if1.getResult(0), if1.getResult(1));
800   };
801 
802   Value childIdx;
803   std::tie(child, childIdx) = getLargerChild(child);
804 
805   // While (data[start] < data[childIndex]).
806   SmallVector<Type, 3> types(3, child.getType());
807   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
808       loc, types, SmallVector<Value, 2>{start, child, childIdx});
809 
810   // The before-region of the WhileOp.
811   SmallVector<Location, 3> locs(3, loc);
812   Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
813   builder.setInsertionPointToEnd(before);
814   start = before->getArgument(0);
815   childIdx = before->getArgument(2);
816   compareOperands[0] = start;
817   compareOperands[1] = childIdx;
818   Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
819   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
820 
821   // The after-region of the WhileOp.
822   Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
823   start = after->getArgument(0);
824   child = after->getArgument(1);
825   childIdx = after->getArgument(2);
826   SmallVector<Value> swapOperands{start, childIdx};
827   swapOperands.append(args.begin() + xStartIdx, args.end());
828   createSwap(builder, loc, swapOperands, xPerm, ny);
829   start = childIdx;
830   Value cond2 =
831       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
832   scf::IfOp if2 = builder.create<scf::IfOp>(
833       loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
834   builder.setInsertionPointToStart(&if2.getThenRegion().front());
835   auto [newChild, newChildIdx] = getLargerChild(child);
836   builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
837   builder.setInsertionPointToStart(&if2.getElseRegion().front());
838   builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
839   builder.setInsertionPointAfter(if2);
840   builder.create<scf::YieldOp>(
841       loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});
842 
843   builder.setInsertionPointAfter(ifN);
844   builder.create<func::ReturnOp>(loc);
845 }
846 
847 /// Creates a function to perform heap sort on the values in the range of index
848 /// [lo, hi) with the assumption hi - lo >= 2.
849 //
850 // The generate IR corresponds to this C like algorithm:
851 // void heapSort(lo, hi, data) {
852 //   n = hi - lo
853 //   for i = (n-2)/2 downto 0
854 //     shiftDown(lo, lo+i, n)
855 //
856 //   for l = n downto 2
857 //      swap(lo, lo+l-1)
858 //      shiftdown(lo, lo, l-1)
859 // }
860 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
861                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
862                                uint32_t nTrailingP) {
863   // Heap sort function doesn't have trailing parameters.
864   (void)nTrailingP;
865   assert(nTrailingP == 0);
866   OpBuilder::InsertionGuard insertionGuard(builder);
867   Block *entryBlock = func.addEntryBlock();
868   builder.setInsertionPointToStart(entryBlock);
869 
870   Location loc = func.getLoc();
871   ValueRange args = entryBlock->getArguments();
872   Value lo = args[loIdx];
873   Value hi = args[hiIdx];
874   Value n = builder.create<arith::SubIOp>(loc, hi, lo);
875 
876   // For i = (n-2)/2 downto 0.
877   Value c0 = constantIndex(builder, loc, 0);
878   Value c1 = constantIndex(builder, loc, 1);
879   Value s = createSubTwoDividedByTwo(builder, loc, n);
880   Value up = builder.create<arith::AddIOp>(loc, s, c1);
881   scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
882   builder.setInsertionPointToStart(forI.getBody());
883   Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
884   Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
885   SmallVector<Value> shiftDownOperands = {lo, lopi};
886   shiftDownOperands.append(args.begin() + xStartIdx, args.end());
887   shiftDownOperands.push_back(n);
888   FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
889       builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
890       shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
891   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
892                                shiftDownOperands);
893 
894   builder.setInsertionPointAfter(forI);
895   // For l = n downto 2.
896   up = builder.create<arith::SubIOp>(loc, n, c1);
897   scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
898   builder.setInsertionPointToStart(forL.getBody());
899   Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
900   Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
901   loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
902   SmallVector<Value> swapOperands{lo, loplm1};
903   swapOperands.append(args.begin() + xStartIdx, args.end());
904   createSwap(builder, loc, swapOperands, xPerm, ny);
905   shiftDownOperands[1] = lo;
906   shiftDownOperands[shiftDownOperands.size() - 1] =
907       builder.create<arith::SubIOp>(loc, l, c1);
908   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
909                                shiftDownOperands);
910 
911   builder.setInsertionPointAfter(forL);
912   builder.create<func::ReturnOp>(loc);
913 }
914 
915 /// A helper for generating code to perform quick sort. It partitions [lo, hi),
916 /// recursively calls quick sort to process the smaller partition and returns
917 /// the bigger partition to be processed by the enclosed while-loop.
918 static std::pair<Value, Value>
919 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
920                 ValueRange args, AffineMap xPerm, uint64_t ny,
921                 uint32_t nTrailingP) {
922   MLIRContext *context = module.getContext();
923   Location loc = func.getLoc();
924   Value lo = args[loIdx];
925   Value hi = args[hiIdx];
926   SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
927 
928   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
929       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
930       ny, args.drop_back(nTrailingP), createPartitionFunc);
931   Value p = builder
932                 .create<func::CallOp>(loc, partitionFunc,
933                                       TypeRange{IndexType::get(context)},
934                                       args.drop_back(nTrailingP))
935                 .getResult(0);
936 
937   Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
938   Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
939   // Partition already sorts array with len <= 2
940   Value c2 = constantIndex(builder, loc, 2);
941   Value len = builder.create<arith::SubIOp>(loc, hi, lo);
942   Value lenGtTwo =
943       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
944   scf::IfOp ifLenGtTwo =
945       builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
946   builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
947   // Returns an empty range to mark the entire region is fully sorted.
948   builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
949 
950   // Else len > 2, need recursion.
951   builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
952   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
953                                              lenLow, lenHigh);
954 
955   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
956 
957   Value c0 = constantIndex(builder, loc, 0);
958   auto mayRecursion = [&](Value low, Value high, Value len) {
959     Value cond =
960         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
961     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
962     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
963     SmallVector<Value> operands{low, high};
964     operands.append(args.begin() + xStartIdx, args.end());
965     builder.create<func::CallOp>(loc, func, operands);
966     builder.setInsertionPointAfter(ifOp);
967   };
968 
969   // Recursively call quickSort to process the smaller partition and return
970   // the bigger partition to be processed by the enclosed while-loop.
971   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
972   mayRecursion(lo, p, lenLow);
973   builder.create<scf::YieldOp>(loc, ValueRange{p, hi});
974 
975   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
976   mayRecursion(p, hi, lenHigh);
977   builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
978 
979   builder.setInsertionPointAfter(ifOp);
980   builder.create<scf::YieldOp>(loc, ifOp.getResults());
981 
982   builder.setInsertionPointAfter(ifLenGtTwo);
983   return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
984 }
985 
986 /// Creates a function to perform insertion sort on the values in the range of
987 /// index [lo, hi).
988 //
989 // The generate IR corresponds to this C like algorithm:
990 // void insertionSort(lo, hi, data) {
991 //   for (i = lo+1; i < hi; i++) {
992 //      d = data[i];
993 //      p = binarySearch(lo, i-1, data)
994 //      for (j = 0; j > i - p; j++)
995 //        data[i-j] = data[i-j-1]
996 //      data[p] = d
997 //   }
998 // }
999 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
1000                                  func::FuncOp func, AffineMap xPerm,
1001                                  uint64_t ny, uint32_t nTrailingP) {
1002   // Stable sort function doesn't use trailing parameters.
1003   (void)nTrailingP;
1004   assert(nTrailingP == 0);
1005   OpBuilder::InsertionGuard insertionGuard(builder);
1006   Block *entryBlock = func.addEntryBlock();
1007   builder.setInsertionPointToStart(entryBlock);
1008 
1009   MLIRContext *context = module.getContext();
1010   Location loc = func.getLoc();
1011   ValueRange args = entryBlock->getArguments();
1012   Value c1 = constantIndex(builder, loc, 1);
1013   Value lo = args[loIdx];
1014   Value hi = args[hiIdx];
1015   Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
1016 
1017   // Start the outer for-stmt with induction variable i.
1018   scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
1019   builder.setInsertionPointToStart(forOpI.getBody());
1020   Value i = forOpI.getInductionVar();
1021 
1022   // Binary search to find the insertion point p.
1023   SmallVector<Value> operands{lo, i};
1024   operands.append(args.begin() + xStartIdx, args.end());
1025   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
1026       builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
1027       xPerm, ny, operands, createBinarySearchFunc);
1028   Value p = builder
1029                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
1030                                       operands)
1031                 .getResult(0);
1032 
1033   // Move the value at data[i] to a temporary location.
1034   operands[0] = operands[1] = i;
1035   SmallVector<Value> d;
1036   forEachIJPairInAllBuffers(
1037       builder, loc, operands, xPerm, ny,
1038       [&](uint64_t unused, Value i, Value unused2, Value buffer) {
1039         d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
1040       });
1041 
1042   // Start the inner for-stmt with induction variable j, for moving data[p..i)
1043   // to data[p+1..i+1).
1044   Value imp = builder.create<arith::SubIOp>(loc, i, p);
1045   Value c0 = constantIndex(builder, loc, 0);
1046   scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
1047   builder.setInsertionPointToStart(forOpJ.getBody());
1048   Value j = forOpJ.getInductionVar();
1049   Value imj = builder.create<arith::SubIOp>(loc, i, j);
1050   operands[1] = imj;
1051   operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
1052   forEachIJPairInAllBuffers(
1053       builder, loc, operands, xPerm, ny,
1054       [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
1055         Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
1056         builder.create<memref::StoreOp>(loc, t, buffer, imj);
1057       });
1058 
1059   // Store the value at data[i] to data[p].
1060   builder.setInsertionPointAfter(forOpJ);
1061   operands[0] = operands[1] = p;
1062   forEachIJPairInAllBuffers(
1063       builder, loc, operands, xPerm, ny,
1064       [&](uint64_t k, Value p, Value usused, Value buffer) {
1065         builder.create<memref::StoreOp>(loc, d[k], buffer, p);
1066       });
1067 
1068   builder.setInsertionPointAfter(forOpI);
1069   builder.create<func::ReturnOp>(loc);
1070 }
1071 
/// Creates a function to perform quick sort or a hybrid quick sort on the
/// values in the range of index [lo, hi).
//
// The function arguments are [lo, hi, buffers..., trailing...] with
// nTrailingP trailing arguments (none for plain quick sort, an i64 depth
// limit for the hybrid variant).
//
// When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
// void quickSort(lo, hi, data) {
//   while (lo + 1 < hi) {
//     p = partition(lo, hi, data);
//     if (hi - lo <= 2)
//       break;                       // partition already sorted the range
//     if (p - lo <= hi - p) {
//       if (p - lo != 0) quickSort(lo, p, data);
//       lo = p;
//     } else {
//       if (hi - p != 0) quickSort(p, hi, data);
//       hi = p;
//     }
//   }
// }
//
// When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
// void hybridQuickSort(lo, hi, data, depthLimit) {
//   while (lo + 1 < hi) {
//     len = hi - lo;
//     if (len <= 30) {
//       insertionSort(lo, hi, data); // the stable-sort helper
//       break;
//     } else {
//       d = depthLimit - 1;
//       if (d <= 0) {                // unsigned compare, i.e. d == 0
//         heapSort(lo, hi, data);
//         break;
//       } else {
//         // One quick sort step exactly as above, except that the recursive
//         // call receives d as its depthLimit.
//       }
//     }
//   }
// }
//
// NOTE(review): depthLimit is not a loop-carried value of the while-loop.
// Every iteration recomputes d from the function argument, so only recursion
// into the smaller partition consumes depth budget. Repeatedly iterating on
// the larger partition never reaches the heap sort fallback. Confirm whether
// this is intended.
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP) {
  assert(nTrailingP == 1 || nTrailingP == 0);
  bool isHybrid = (nTrailingP == 1);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  // Mutable copy of the entry block arguments; lo/hi (and, in the hybrid
  // case, the depth limit) are patched below before being handed to helpers.
  SmallVector<Value> args;
  args.append(entryBlock->getArguments().begin(),
              entryBlock->getArguments().end());
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
  // Only [lo, hi) is carried by the loop.
  scf::WhileOp whileOp =
      builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});

  // The before-region of the WhileOp: keep looping while lo + 1 < hi.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(before);
  lo = before->getArgument(0);
  hi = before->getArgument(1);
  Value loP1 =
      builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
  Value needSort =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
  builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(after);
  lo = after->getArgument(0);
  hi = after->getArgument(1);
  // Make the helpers below operate on the current loop-carried range.
  args[0] = lo;
  args[1] = hi;

  if (isHybrid) {
    Value len = builder.create<arith::SubIOp>(loc, hi, lo);
    Value lenLimit = constantIndex(builder, loc, 30);
    Value lenCond = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ule, len, lenLimit);
    scf::IfOp lenIf =
        builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);

    // When len <= limit: insertion sort the whole range, then yield an empty
    // range to end the while-loop.
    builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
    FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
        builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
        ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
    builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
                                 ValueRange(args).drop_back(nTrailingP));
    builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

    // When len > limit.
    builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
    // args.back() is still the function's trailing argument here (it is only
    // replaced further below), so every loop iteration starts from it.
    Value depthLimit = args.back();
    depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
                                               constantI64(builder, loc, 1));
    // Unsigned `<= 0` is equivalent to `== 0`.
    Value depthCond =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                      depthLimit, constantI64(builder, loc, 0));
    scf::IfOp depthIf =
        builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);

    // When depth exceeds limit: heap sort the whole range, then yield an
    // empty range to end the while-loop.
    builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
    FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
        builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
        ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
    builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
                                 ValueRange(args).drop_back(nTrailingP));
    builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

    // When depth doesn't exceed limit: one quick sort step whose recursive
    // call receives the decremented depth limit.
    builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
    args.back() = depthLimit;
    std::tie(lo, hi) =
        createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
    builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});

    builder.setInsertionPointAfter(depthIf);
    lo = depthIf.getResult(0);
    hi = depthIf.getResult(1);
    builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});

    builder.setInsertionPointAfter(lenIf);
    lo = lenIf.getResult(0);
    hi = lenIf.getResult(1);
  } else {
    std::tie(lo, hi) =
        createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
  }

  // New [lo, hi) for the next while-loop iteration.
  builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});

  // After the while-loop.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc);
}
1218 
1219 /// Implements the rewriting for operator sort and sort_coo.
1220 template <typename OpTy>
1221 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
1222                                     uint64_t ny, PatternRewriter &rewriter) {
1223   Location loc = op.getLoc();
1224   SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
1225 
1226   // Convert `values` to have dynamic shape and append them to `operands`.
1227   for (Value v : xys) {
1228     auto mtp = getMemRefType(v);
1229     if (!mtp.isDynamicDim(0)) {
1230       auto newMtp =
1231           MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
1232       v = rewriter.create<memref::CastOp>(loc, newMtp, v);
1233     }
1234     operands.push_back(v);
1235   }
1236 
1237   auto insertPoint = op->template getParentOfType<func::FuncOp>();
1238   if (!insertPoint)
1239     return failure();
1240 
1241   SmallString<32> funcName;
1242   FuncGeneratorType funcGenerator;
1243   uint32_t nTrailingP = 0;
1244   switch (op.getAlgorithm()) {
1245   case SparseTensorSortKind::HybridQuickSort: {
1246     funcName = kHybridQuickSortFuncNamePrefix;
1247     funcGenerator = createQuickSortFunc;
1248     nTrailingP = 1;
1249     // As a heuristics, set depthLimit = 2 * log2(n).
1250     Value lo = operands[loIdx];
1251     Value hi = operands[hiIdx];
1252     Value len = rewriter.create<arith::IndexCastOp>(
1253         loc, rewriter.getI64Type(),
1254         rewriter.create<arith::SubIOp>(loc, hi, lo));
1255     Value depthLimit = rewriter.create<arith::SubIOp>(
1256         loc, constantI64(rewriter, loc, 64),
1257         rewriter.create<math::CountLeadingZerosOp>(loc, len));
1258     operands.push_back(depthLimit);
1259     break;
1260   }
1261   case SparseTensorSortKind::QuickSort:
1262     funcName = kQuickSortFuncNamePrefix;
1263     funcGenerator = createQuickSortFunc;
1264     break;
1265   case SparseTensorSortKind::InsertionSortStable:
1266     funcName = kSortStableFuncNamePrefix;
1267     funcGenerator = createSortStableFunc;
1268     break;
1269   case SparseTensorSortKind::HeapSort:
1270     funcName = kHeapSortFuncNamePrefix;
1271     funcGenerator = createHeapSortFunc;
1272     break;
1273   }
1274 
1275   FlatSymbolRefAttr func =
1276       getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
1277                                xPerm, ny, operands, funcGenerator, nTrailingP);
1278   rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
1279   return success();
1280 }
1281 
1282 //===---------------------------------------------------------------------===//
1283 // The actual sparse buffer rewriting rules.
1284 //===---------------------------------------------------------------------===//
1285 
1286 namespace {
1287 /// Sparse rewriting rule for the push_back operator.
1288 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
1289 public:
1290   using OpRewritePattern<PushBackOp>::OpRewritePattern;
1291   PushBackRewriter(MLIRContext *context, bool enableInit)
1292       : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
1293   LogicalResult matchAndRewrite(PushBackOp op,
1294                                 PatternRewriter &rewriter) const override {
1295     // Rewrite push_back(buffer, value, n) to:
1296     // new_size = size(buffer) + n
1297     // if (new_size > capacity(buffer))
1298     //    while new_size > new_capacity
1299     //      new_capacity = new_capacity*2
1300     //    new_buffer = realloc(buffer, new_capacity)
1301     // buffer = new_buffer
1302     // subBuffer = subviewof(buffer)
1303     // linalg.fill subBuffer value
1304     //
1305     // size(buffer) += n
1306     //
1307     // The capacity check is skipped when the attribute inbounds is presented.
1308     Location loc = op->getLoc();
1309     Value c0 = constantIndex(rewriter, loc, 0);
1310     Value buffer = op.getInBuffer();
1311     Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
1312     Value size = op.getCurSize();
1313     Value value = op.getValue();
1314 
1315     Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
1316     Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
1317     auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
1318     bool nIsOne = (nValue && nValue.value() == 1);
1319 
1320     if (!op.getInbounds()) {
1321       Value cond = rewriter.create<arith::CmpIOp>(
1322           loc, arith::CmpIPredicate::ugt, newSize, capacity);
1323 
1324       Value c2 = constantIndex(rewriter, loc, 2);
1325       auto bufferType =
1326           MemRefType::get({ShapedType::kDynamic}, value.getType());
1327       scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
1328                                                   /*else=*/true);
1329       // True branch.
1330       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1331       if (nIsOne) {
1332         capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
1333       } else {
1334         // Use a do-while loop to calculate the new capacity as follows:
1335         //   do { new_capacity *= 2 } while (size > new_capacity)
1336         scf::WhileOp whileOp =
1337             rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
1338 
1339         // The before-region of the WhileOp.
1340         Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
1341                                              {capacity.getType()}, {loc});
1342         rewriter.setInsertionPointToEnd(before);
1343 
1344         capacity =
1345             rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
1346         cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
1347                                               newSize, capacity);
1348         rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
1349         // The after-region of the WhileOp.
1350         Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
1351                                             {capacity.getType()}, {loc});
1352         rewriter.setInsertionPointToEnd(after);
1353         rewriter.create<scf::YieldOp>(loc, after->getArguments());
1354 
1355         rewriter.setInsertionPointAfter(whileOp);
1356         capacity = whileOp.getResult(0);
1357       }
1358 
1359       Value newBuffer =
1360           rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
1361       if (enableBufferInitialization) {
1362         Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
1363         Value fillValue = constantZero(rewriter, loc, value.getType());
1364         Value subBuffer = rewriter.create<memref::SubViewOp>(
1365             loc, newBuffer, /*offset=*/ValueRange{newSize},
1366             /*size=*/ValueRange{fillSize},
1367             /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1368         rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
1369       }
1370       rewriter.create<scf::YieldOp>(loc, newBuffer);
1371 
1372       // False branch.
1373       rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1374       rewriter.create<scf::YieldOp>(loc, buffer);
1375 
1376       // Prepare for adding the value to the end of the buffer.
1377       rewriter.setInsertionPointAfter(ifOp);
1378       buffer = ifOp.getResult(0);
1379     }
1380 
1381     // Add the value to the end of the buffer.
1382     if (nIsOne) {
1383       rewriter.create<memref::StoreOp>(loc, value, buffer, size);
1384     } else {
1385       Value subBuffer = rewriter.create<memref::SubViewOp>(
1386           loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
1387           /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1388       rewriter.create<linalg::FillOp>(loc, value, subBuffer);
1389     }
1390 
1391     // Update the buffer size.
1392     rewriter.replaceOp(op, {buffer, newSize});
1393     return success();
1394   }
1395 
1396 private:
1397   bool enableBufferInitialization;
1398 };
1399 
1400 /// Sparse rewriting rule for the sort_coo operator.
1401 struct SortRewriter : public OpRewritePattern<SortOp> {
1402 public:
1403   using OpRewritePattern<SortOp>::OpRewritePattern;
1404 
1405   LogicalResult matchAndRewrite(SortOp op,
1406                                 PatternRewriter &rewriter) const override {
1407     SmallVector<Value> xys;
1408     xys.push_back(op.getXy());
1409     xys.append(op.getYs().begin(), op.getYs().end());
1410 
1411     auto xPerm = op.getPermMap();
1412     uint64_t ny = 0;
1413     if (auto nyAttr = op.getNyAttr())
1414       ny = nyAttr.getInt();
1415 
1416     return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
1417   }
1418 };
1419 
1420 } // namespace
1421 
1422 //===---------------------------------------------------------------------===//
1423 // Methods that add patterns described in this file to a pattern list.
1424 //===---------------------------------------------------------------------===//
1425 
1426 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
1427                                          bool enableBufferInitialization) {
1428   patterns.add<PushBackRewriter>(patterns.getContext(),
1429                                  enableBufferInitialization);
1430   patterns.add<SortRewriter>(patterns.getContext());
1431 }
1432