xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp (revision 5262865aac683b72f3e66de7a122e0c455ab6b9b)
1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensor
10 // primitives with memref operands.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "Utils/CodegenUtils.h"
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Math/IR/Math.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
23 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
24 #include "mlir/Support/LLVM.h"
25 
26 using namespace mlir;
27 using namespace mlir::sparse_tensor;
28 
29 //===---------------------------------------------------------------------===//
30 // Helper methods for the actual rewriting rules.
31 //===---------------------------------------------------------------------===//
32 
33 static constexpr uint64_t loIdx = 0;
34 static constexpr uint64_t hiIdx = 1;
35 static constexpr uint64_t xStartIdx = 2;
36 
37 static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
38 static constexpr const char kBinarySearchFuncNamePrefix[] =
39     "_sparse_binary_search_";
40 static constexpr const char kHybridQuickSortFuncNamePrefix[] =
41     "_sparse_hybrid_qsort_";
42 static constexpr const char kSortStableFuncNamePrefix[] =
43     "_sparse_sort_stable_";
44 static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
45 static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
46 static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
47 
48 using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
49                                             AffineMap, uint64_t, uint32_t)>;
50 
51 /// Constructs a function name with this format to facilitate quick sort:
52 ///   <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
53 ///   <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
54 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
55                                          StringRef namePrefix, AffineMap xPerm,
56                                          uint64_t ny, ValueRange operands) {
57   nameOstream << namePrefix;
58   for (auto res : xPerm.getResults())
59     nameOstream << cast<AffineDimExpr>(res).getPosition() << "_";
60 
61   nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
62   nameOstream << "_coo_" << ny;
63 
64   constexpr uint64_t yBufferOffset = 1;
65   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
66     nameOstream << "_" << getMemRefType(v).getElementType();
67 }
68 
69 /// Looks up a function that is appropriate for the given operands being
70 /// sorted, and creates such a function if it doesn't exist yet. The
71 /// parameters `xPerm` and `ny` tell the number of x and y values provided
72 /// by the buffer in xStartIdx.
73 //
74 // All sorting function generators take (lo, hi, xs, ys) in `operands` as
75 // parameters for the sorting functions. Other parameters, such as the recursive
76 // call depth, are appended to the end of the parameter list as
77 // "trailing parameters".
78 static FlatSymbolRefAttr getMangledSortHelperFunc(
79     OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
80     StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
81     FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
82   SmallString<32> nameBuffer;
83   llvm::raw_svector_ostream nameOstream(nameBuffer);
84   getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
85                                operands.drop_back(nTrailingP));
86 
87   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
88   MLIRContext *context = module.getContext();
89   auto result = SymbolRefAttr::get(context, nameOstream.str());
90   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
91 
92   if (!func) {
93     // Create the function.
94     OpBuilder::InsertionGuard insertionGuard(builder);
95     builder.setInsertionPoint(insertPoint);
96     Location loc = insertPoint.getLoc();
97     func = builder.create<func::FuncOp>(
98         loc, nameOstream.str(),
99         FunctionType::get(context, operands.getTypes(), resultTypes));
100     func.setPrivate();
101     createFunc(builder, module, func, xPerm, ny, nTrailingP);
102   }
103 
104   return result;
105 }
106 
107 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
108 /// The code to process the value pairs is generated by `bodyBuilder`.
109 static void forEachIJPairInXs(
110     OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
111     uint64_t ny,
112     function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
113   Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
114   Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
115   Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
116   for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
117     unsigned actualK = cast<AffineDimExpr>(xPerm.getResult(k)).getPosition();
118     Value ak = constantIndex(builder, loc, actualK);
119     Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
120     Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
121     Value buffer = args[xStartIdx];
122 
123     bodyBuilder(k, i, j, buffer);
124   }
125 }
126 
127 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
128 /// The code to process the value pairs is generated by `bodyBuilder`.
129 static void forEachIJPairInAllBuffers(
130     OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
131     uint64_t ny,
132     function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
133 
134   // Create code for the first (xPerm + ny) buffers.
135   SmallVector<AffineExpr> exps(xPerm.getResults());
136   for (unsigned y = 0; y < ny; y++) {
137     exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
138   }
139   AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
140   assert(xyPerm.isPermutation());
141 
142   forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
143 
144   constexpr uint64_t numHandledBuffers = 1;
145   // Create code for the remaining buffers.
146   Value i = args[0];
147   Value j = args[1];
148   for (const auto &arg :
149        llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
150     bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
151   }
152 }
153 
154 /// Creates a code block for swapping the values in index i and j for all the
155 /// buffers.
156 //
157 // The generated IR corresponds to this C like algorithm:
158 //     swap(x0[i], x0[j]);
159 //     swap(x1[i], x1[j]);
160 //     ...
161 //     swap(xn[i], xn[j]);
162 //     swap(y0[i], y0[j]);
163 //     ...
164 //     swap(yn[i], yn[j]);
165 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
166                        AffineMap xPerm, uint64_t ny) {
167   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
168     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
169     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
170     builder.create<memref::StoreOp>(loc, vj, buffer, i);
171     builder.create<memref::StoreOp>(loc, vi, buffer, j);
172   };
173 
174   forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
175 }
176 
177 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
178 /// each pair is create via `compareBuilder`.
179 static Value createInlinedCompareImplementation(
180     OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
181     uint64_t ny,
182     function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
183         compareBuilder) {
184   Value result;
185   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
186     bool isFirstDim = (k == 0);
187     bool isLastDim = (k == xPerm.getNumResults() - 1);
188     Value val =
189         compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
190     if (isFirstDim) {
191       result = val;
192     } else if (!isLastDim) {
193       OpBuilder::InsertionGuard insertionGuard(builder);
194       auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
195       builder.setInsertionPointAfter(ifOp);
196       builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
197     }
198   };
199 
200   forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
201 
202   builder.setInsertionPointAfterValue(result);
203   return result;
204 }
205 
206 /// Generates code to compare whether x[i] is equal to x[j] and returns the
207 /// result of the comparison.
208 static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
209                              Value x, bool isFirstDim, bool isLastDim) {
210   Value vi = builder.create<memref::LoadOp>(loc, x, i);
211   Value vj = builder.create<memref::LoadOp>(loc, x, j);
212 
213   Value res;
214   if (isLastDim) {
215     res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
216     // For 1D, we create a compare without any control flow. Otherwise, we
217     // create YieldOp to return the result in the nested if-stmt.
218     if (!isFirstDim)
219       builder.create<scf::YieldOp>(loc, res);
220   } else {
221     Value ne =
222         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
223     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
224                                                ne, /*else=*/true);
225     // If (x[i] != x[j]).
226     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
227     Value f = constantI1(builder, loc, false);
228     builder.create<scf::YieldOp>(loc, f);
229 
230     // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
231     // checks the remaining dimensions.
232     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
233     res = ifOp.getResult(0);
234   }
235 
236   return res;
237 }
238 
239 /// Creates code to compare whether xs[i] is equal to xs[j].
240 //
241 // The generate IR corresponds to this C like algorithm:
242 //   if (x0[i] != x0[j])
243 //     return false;
244 //   else
245 //     if (x1[i] != x1[j])
246 //       return false;
247 //     else if (x2[2] != x2[j]))
248 //       and so on ...
249 static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
250                                     ValueRange args, AffineMap xPerm,
251                                     uint64_t ny, uint32_t nTrailingP = 0) {
252   // Compare functions don't use trailing parameters.
253   (void)nTrailingP;
254   assert(nTrailingP == 0);
255   return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
256                                             createEqCompare);
257 }
258 
259 /// Generates code to compare whether x[i] is less than x[j] and returns the
260 /// result of the comparison.
261 static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
262                                    Value j, Value x, bool isFirstDim,
263                                    bool isLastDim) {
264   Value vi = builder.create<memref::LoadOp>(loc, x, i);
265   Value vj = builder.create<memref::LoadOp>(loc, x, j);
266 
267   Value res;
268   if (isLastDim) {
269     res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
270     // For 1D, we create a compare without any control flow. Otherwise, we
271     // create YieldOp to return the result in the nested if-stmt.
272     if (!isFirstDim)
273       builder.create<scf::YieldOp>(loc, res);
274   } else {
275     Value ne =
276         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
277     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
278                                                ne, /*else=*/true);
279     // If (x[i] != x[j]).
280     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
281     Value lt =
282         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
283     builder.create<scf::YieldOp>(loc, lt);
284 
285     // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
286     // checks the remaining dimensions.
287     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
288     res = ifOp.getResult(0);
289   }
290 
291   return res;
292 }
293 
294 /// Creates code to compare whether xs[i] is less than xs[j].
295 //
296 // The generate IR corresponds to this C like algorithm:
297 //   if (x0[i] != x0[j])
298 //     return x0[i] < x0[j];
299 //   else if (x1[j] != x1[i])
300 //     return x1[i] < x1[j];
301 //   else
302 //       and so on ...
303 static Value createInlinedLessThan(OpBuilder &builder, Location loc,
304                                    ValueRange args, AffineMap xPerm,
305                                    uint64_t ny, uint32_t nTrailingP = 0) {
306   // Compare functions don't use trailing parameters.
307   (void)nTrailingP;
308   assert(nTrailingP == 0);
309   return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
310                                             createLessThanCompare);
311 }
312 
313 /// Creates a function to use a binary search to find the insertion point for
314 /// inserting xs[hi] to the sorted values xs[lo..hi).
315 //
316 // The generate IR corresponds to this C like algorithm:
317 //   p = hi
318 //   while (lo < hi)
319 //      mid = (lo + hi) >> 1
320 //      if (xs[p] < xs[mid])
321 //        hi = mid
322 //      else
323 //        lo = mid - 1
324 //   return lo;
325 //
326 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
327                                    func::FuncOp func, AffineMap xPerm,
328                                    uint64_t ny, uint32_t nTrailingP = 0) {
329   // Binary search doesn't use trailing parameters.
330   (void)nTrailingP;
331   assert(nTrailingP == 0);
332   OpBuilder::InsertionGuard insertionGuard(builder);
333   Block *entryBlock = func.addEntryBlock();
334   builder.setInsertionPointToStart(entryBlock);
335 
336   Location loc = func.getLoc();
337   ValueRange args = entryBlock->getArguments();
338   Value p = args[hiIdx];
339   SmallVector<Type, 2> types(2, p.getType()); // Only two types.
340   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
341       loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
342 
343   // The before-region of the WhileOp.
344   Block *before =
345       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
346   builder.setInsertionPointToEnd(before);
347   Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
348                                               before->getArgument(0),
349                                               before->getArgument(1));
350   builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
351 
352   // The after-region of the WhileOp.
353   Block *after =
354       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
355   builder.setInsertionPointToEnd(after);
356   Value lo = after->getArgument(0);
357   Value hi = after->getArgument(1);
358   // Compute mid = (lo + hi) >> 1.
359   Value c1 = constantIndex(builder, loc, 1);
360   Value mid = builder.create<arith::ShRUIOp>(
361       loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
362   Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
363 
364   // Compare xs[p] < xs[mid].
365   SmallVector<Value> compareOperands{p, mid};
366   constexpr uint64_t numXBuffers = 1;
367   compareOperands.append(args.begin() + xStartIdx,
368                          args.begin() + xStartIdx + numXBuffers);
369   Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
370   // Update lo and hi for the WhileOp as follows:
371   //   if (xs[p] < xs[mid]))
372   //     hi = mid;
373   //   else
374   //     lo = mid + 1;
375   Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
376   Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
377   builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
378 
379   builder.setInsertionPointAfter(whileOp);
380   builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
381 }
382 
383 /// Creates code to advance i in a loop based on xs[p] as follows:
384 ///   while (xs[i] < xs[p]) i += step (step > 0)
385 /// or
386 ///   while (xs[i] > xs[p]) i += step (step < 0)
387 /// The routine returns i as well as a boolean value to indicate whether
388 /// xs[i] == xs[p].
389 static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
390                                               ModuleOp module,
391                                               func::FuncOp func, ValueRange xs,
392                                               Value i, Value p, AffineMap xPerm,
393                                               uint64_t ny, int step) {
394   Location loc = func.getLoc();
395   scf::WhileOp whileOp =
396       builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
397 
398   Block *before =
399       builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
400   builder.setInsertionPointToEnd(before);
401   SmallVector<Value> compareOperands;
402   if (step > 0) {
403     compareOperands.push_back(before->getArgument(0));
404     compareOperands.push_back(p);
405   } else {
406     assert(step < 0);
407     compareOperands.push_back(p);
408     compareOperands.push_back(before->getArgument(0));
409   }
410   compareOperands.append(xs.begin(), xs.end());
411   Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
412   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
413 
414   Block *after =
415       builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
416   builder.setInsertionPointToEnd(after);
417   Value cs = constantIndex(builder, loc, step);
418   i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
419   builder.create<scf::YieldOp>(loc, ValueRange{i});
420   i = whileOp.getResult(0);
421 
422   builder.setInsertionPointAfter(whileOp);
423   compareOperands[0] = i;
424   compareOperands[1] = p;
425   Value compareEq =
426       createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);
427 
428   return std::make_pair(whileOp.getResult(0), compareEq);
429 }
430 
431 /// Creates and returns an IfOp to compare two elements and swap the elements
432 /// if compareFunc(data[b], data[a]) returns true. The new insertion point is
433 /// right after the swap instructions.
434 static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
435                                        AffineMap xPerm, uint64_t ny,
436                                        SmallVectorImpl<Value> &swapOperands,
437                                        SmallVectorImpl<Value> &compareOperands,
438                                        Value a, Value b) {
439   // Compare(data[b], data[a]).
440   compareOperands[0] = b;
441   compareOperands[1] = a;
442   Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
443   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
444   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
445   swapOperands[0] = b;
446   swapOperands[1] = a;
447   createSwap(builder, loc, swapOperands, xPerm, ny);
448   return ifOp;
449 }
450 
451 /// Creates code to insert the 3rd element to a list of two sorted elements.
452 static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
453                             uint64_t ny, SmallVectorImpl<Value> &swapOperands,
454                             SmallVectorImpl<Value> &compareOperands, Value v0,
455                             Value v1, Value v2) {
456   scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
457                                          compareOperands, v1, v2);
458   createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
459                         v0, v1);
460   builder.setInsertionPointAfter(ifOp);
461 }
462 
463 /// Creates code to sort 3 elements.
464 static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
465                         uint64_t ny, SmallVectorImpl<Value> &swapOperands,
466                         SmallVectorImpl<Value> &compareOperands, Value v0,
467                         Value v1, Value v2) {
468   // Sort the first 2 elements.
469   scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
470                                           compareOperands, v0, v1);
471   builder.setInsertionPointAfter(ifOp1);
472 
473   // Insert the 3th element.
474   createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
475                   v1, v2);
476 }
477 
478 /// Creates code to sort 5 elements.
479 static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
480                         uint64_t ny, SmallVectorImpl<Value> &swapOperands,
481                         SmallVectorImpl<Value> &compareOperands, Value v0,
482                         Value v1, Value v2, Value v3, Value v4) {
483   // Sort the first 3 elements.
484   createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
485               v2);
486 
487   auto insert4th = [&]() {
488     scf::IfOp ifOp = createCompareThenSwap(
489         builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
490     createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
491                     v1, v2);
492     builder.setInsertionPointAfter(ifOp);
493   };
494 
495   // Insert the 4th element.
496   insert4th();
497 
498   // Insert the 5th element.
499   scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
500                                          compareOperands, v3, v4);
501   insert4th();
502   builder.setInsertionPointAfter(ifOp);
503 }
504 
505 /// Creates a code block to swap the values in indices lo, mi, and hi so that
506 /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
507 /// the number of values in range [lo, hi) is more than a threshold, we also
508 /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
509 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
510                               func::FuncOp func, AffineMap xPerm, uint64_t ny,
511                               Value lo, Value hi, Value mi, ValueRange args) {
512   SmallVector<Value> compareOperands{mi, lo};
513   constexpr uint64_t numXBuffers = 1;
514   compareOperands.append(args.begin() + xStartIdx,
515                          args.begin() + xStartIdx + numXBuffers);
516   SmallVector<Value> swapOperands{mi, lo};
517   swapOperands.append(args.begin() + xStartIdx, args.end());
518   Location loc = func.getLoc();
519   Value c1 = constantIndex(builder, loc, 1);
520   Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1);
521   Value len = builder.create<arith::SubIOp>(loc, hiP1, lo);
522   Value lenThreshold = constantIndex(builder, loc, 1000);
523   Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
524                                                 len, lenThreshold);
525   scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
526 
527   // When len < 1000, choose pivot from median of 3 values.
528   builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
529   createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
530               hi);
531 
532   // When len >= 1000, choose pivot from median of 5 values.
533   builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
534   Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1);
535   Value a = builder.create<arith::AddIOp>(loc, lo, miP1);
536   // Value a is the middle between [loc, mi].
537   a = builder.create<arith::ShRUIOp>(loc, a, c1);
538   Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
539   // Value b is the middle between [mi, hi].
540   b = builder.create<arith::ShRUIOp>(loc, b, c1);
541   createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
542               b, hi);
543 
544   builder.setInsertionPointAfter(lenIf);
545 }
546 
547 /// Creates a function to perform quick sort partition on the values in the
548 /// range of index [lo, hi), assuming lo < hi.
549 //
550 // The generated IR corresponds to this C like algorithm:
551 // int partition(lo, hi, xs) {
552 //   p = (lo+hi)/2  // pivot index
553 //   i = lo
554 //   j = hi-1
555 //   while (true) do {
556 //     while (xs[i] < xs[p]) i ++;
557 //     i_eq = (xs[i] == xs[p]);
558 //     while (xs[j] > xs[p]) j --;
559 //     j_eq = (xs[j] == xs[p]);
560 //
561 //     if (i >= j) return j + 1;
562 //
563 //     if (i < j) {
564 //       swap(xs[i], xs[j])
565 //       if (i == p) {
566 //         p = j;
567 //       } else if (j == p) {
568 //         p = i;
569 //       }
570 //       if (i_eq && j_eq) {
571 //         ++i;
572 //         --j;
573 //       }
574 //     }
575 //   }
576 // }
577 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
578                                 func::FuncOp func, AffineMap xPerm, uint64_t ny,
579                                 uint32_t nTrailingP = 0) {
580   // Quick sort partition doesn't use trailing parameters.
581   (void)nTrailingP;
582   assert(nTrailingP == 0);
583   OpBuilder::InsertionGuard insertionGuard(builder);
584 
585   Block *entryBlock = func.addEntryBlock();
586   builder.setInsertionPointToStart(entryBlock);
587 
588   Location loc = func.getLoc();
589   ValueRange args = entryBlock->getArguments();
590   Value lo = args[loIdx];
591   Value hi = args[hiIdx];
592   Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
593   Value c1 = constantIndex(builder, loc, 1);
594   Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
595 
596   Value i = lo;
597   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
598   createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
599   Value trueVal = constantI1(builder, loc, true); // The value for while (true)
600   SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
601   SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
602                              trueVal.getType()};
603   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
604 
605   // The before-region of the WhileOp.
606   Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
607                                       {loc, loc, loc, loc});
608   builder.setInsertionPointToEnd(before);
609   builder.create<scf::ConditionOp>(loc, before->getArgument(3),
610                                    before->getArguments());
611 
612   // The after-region of the WhileOp.
613   Block *after =
614       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
615   builder.setInsertionPointToEnd(after);
616   i = after->getArgument(0);
617   j = after->getArgument(1);
618   p = after->getArgument(2);
619 
620   constexpr uint64_t numXBuffers = 1;
621   auto [iresult, iCompareEq] =
622       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
623                      i, p, xPerm, ny, 1);
624   i = iresult;
625   auto [jresult, jCompareEq] =
626       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
627                      j, p, xPerm, ny, -1);
628   j = jresult;
629 
630   // If i < j:
631   Value cond =
632       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
633   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
634   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
635   SmallVector<Value> swapOperands{i, j};
636   swapOperands.append(args.begin() + xStartIdx, args.end());
637   createSwap(builder, loc, swapOperands, xPerm, ny);
638   // If the pivot is moved, update p with the new pivot.
639   Value icond =
640       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
641   scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
642                                               icond, /*else=*/true);
643   builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
644   builder.create<scf::YieldOp>(loc, ValueRange{j});
645   builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
646   Value jcond =
647       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
648   scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
649                                               jcond, /*else=*/true);
650   builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
651   builder.create<scf::YieldOp>(loc, ValueRange{i});
652   builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
653   builder.create<scf::YieldOp>(loc, ValueRange{p});
654   builder.setInsertionPointAfter(ifOpJ);
655   builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
656   builder.setInsertionPointAfter(ifOpI);
657   Value compareEqIJ =
658       builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
659   scf::IfOp ifOp2 = builder.create<scf::IfOp>(
660       loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
661   builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
662   Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
663   Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
664   builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
665   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
666   builder.create<scf::YieldOp>(loc, ValueRange{i, j});
667   builder.setInsertionPointAfter(ifOp2);
668   builder.create<scf::YieldOp>(
669       loc,
670       ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
671                  /*cont=*/constantI1(builder, loc, true)});
672 
673   // False branch for if i < j (i.e., i >= j):
674   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
675   p = builder.create<arith::AddIOp>(loc, j,
676                                     constantOne(builder, loc, j.getType()));
677   builder.create<scf::YieldOp>(
678       loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});
679 
680   // Return for the whileOp.
681   builder.setInsertionPointAfter(ifOp);
682   builder.create<scf::YieldOp>(loc, ifOp.getResults());
683 
684   // Return for the function.
685   builder.setInsertionPointAfter(whileOp);
686   builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
687 }
688 
689 /// Computes (n-2)/n, assuming n has index type.
690 static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
691                                       Value n) {
692   Value i2 = constantIndex(builder, loc, 2);
693   Value res = builder.create<arith::SubIOp>(loc, n, i2);
694   Value i1 = constantIndex(builder, loc, 1);
695   return builder.create<arith::ShRUIOp>(loc, res, i1);
696 }
697 
698 /// Creates a function to heapify the subtree with root `start` within the full
699 /// binary tree in the range of index [first, first + n).
700 //
701 // The generated IR corresponds to this C like algorithm:
702 // void shiftDown(first, start, n, data) {
703 //   if (n >= 2) {
704 //     child = start - first
705 //     if ((n-2)/2 >= child) {
706 //       // Left child exists.
707 //       child = child * 2 + 1 // Initialize the bigger child to left child.
708 //       childIndex = child + first
709 //       if (child+1 < n && data[childIndex] < data[childIndex+1])
710 //         // Right child exits and is bigger.
711 //         childIndex++; child++;
712 //       // Shift data[start] down to where it belongs in the subtree.
713 //       while (data[start] < data[childIndex) {
714 //         swap(data[start], data[childIndex])
715 //         start = childIndex
716 //         if ((n - 2)/2 >= child) {
717 //           // Left child exists.
718 //           child = 2*child + 1
719 //           childIndex = child + 1
720 //           if (child + 1) < n && data[childIndex] < data[childIndex+1]
721 //             childIndex++; child++;
722 //         }
723 //       }
724 //     }
725 //   }
726 // }
727 //
728 static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
729                                 func::FuncOp func, AffineMap xPerm, uint64_t ny,
730                                 uint32_t nTrailingP) {
731   // The value n is passed in as a trailing parameter.
732   assert(nTrailingP == 1);
733   OpBuilder::InsertionGuard insertionGuard(builder);
734   Block *entryBlock = func.addEntryBlock();
735   builder.setInsertionPointToStart(entryBlock);
736 
737   Location loc = func.getLoc();
738   Value n = entryBlock->getArguments().back();
739   ValueRange args = entryBlock->getArguments().drop_back();
740   Value first = args[loIdx];
741   Value start = args[hiIdx];
742 
743   // If (n >= 2).
744   Value c2 = constantIndex(builder, loc, 2);
745   Value condN =
746       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
747   scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
748   builder.setInsertionPointToStart(&ifN.getThenRegion().front());
749   Value child = builder.create<arith::SubIOp>(loc, start, first);
750 
751   // If ((n-2)/2 >= child).
752   Value t = createSubTwoDividedByTwo(builder, loc, n);
753   Value condNc =
754       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
755   scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);
756 
757   builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
758   Value c1 = constantIndex(builder, loc, 1);
759   SmallVector<Value> compareOperands{start, start};
760   constexpr uint64_t numXBuffers = 1;
761   compareOperands.append(args.begin() + xStartIdx,
762                          args.begin() + xStartIdx + numXBuffers);
763 
764   // Generate code to inspect the children of 'r' and return the larger child
765   // as follows:
766   //   child = r * 2 + 1 // Left child.
767   //   childIndex = child + first
768   //   if (child+1 < n && data[childIndex] < data[childIndex+1])
769   //     childIndex ++; child ++ // Right child is bigger.
770   auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
771     Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
772     lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
773     Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
774     Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
775     Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
776                                                 rChild, n);
777     SmallVector<Type, 2> ifTypes(2, r.getType());
778     scf::IfOp if1 =
779         builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
780     builder.setInsertionPointToStart(&if1.getThenRegion().front());
781     Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
782     // Compare data[left] < data[right].
783     compareOperands[0] = lChildIdx;
784     compareOperands[1] = rChildIdx;
785     Value cond2 =
786         createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
787     scf::IfOp if2 =
788         builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
789     builder.setInsertionPointToStart(&if2.getThenRegion().front());
790     builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
791     builder.setInsertionPointToStart(&if2.getElseRegion().front());
792     builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
793     builder.setInsertionPointAfter(if2);
794     builder.create<scf::YieldOp>(loc, if2.getResults());
795     builder.setInsertionPointToStart(&if1.getElseRegion().front());
796     builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
797     builder.setInsertionPointAfter(if1);
798     return std::make_pair(if1.getResult(0), if1.getResult(1));
799   };
800 
801   Value childIdx;
802   std::tie(child, childIdx) = getLargerChild(child);
803 
804   // While (data[start] < data[childIndex]).
805   SmallVector<Type, 3> types(3, child.getType());
806   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
807       loc, types, SmallVector<Value, 2>{start, child, childIdx});
808 
809   // The before-region of the WhileOp.
810   SmallVector<Location, 3> locs(3, loc);
811   Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
812   builder.setInsertionPointToEnd(before);
813   start = before->getArgument(0);
814   childIdx = before->getArgument(2);
815   compareOperands[0] = start;
816   compareOperands[1] = childIdx;
817   Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
818   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
819 
820   // The after-region of the WhileOp.
821   Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
822   start = after->getArgument(0);
823   child = after->getArgument(1);
824   childIdx = after->getArgument(2);
825   SmallVector<Value> swapOperands{start, childIdx};
826   swapOperands.append(args.begin() + xStartIdx, args.end());
827   createSwap(builder, loc, swapOperands, xPerm, ny);
828   start = childIdx;
829   Value cond2 =
830       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
831   scf::IfOp if2 = builder.create<scf::IfOp>(
832       loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
833   builder.setInsertionPointToStart(&if2.getThenRegion().front());
834   auto [newChild, newChildIdx] = getLargerChild(child);
835   builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
836   builder.setInsertionPointToStart(&if2.getElseRegion().front());
837   builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
838   builder.setInsertionPointAfter(if2);
839   builder.create<scf::YieldOp>(
840       loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});
841 
842   builder.setInsertionPointAfter(ifN);
843   builder.create<func::ReturnOp>(loc);
844 }
845 
846 /// Creates a function to perform heap sort on the values in the range of index
847 /// [lo, hi) with the assumption hi - lo >= 2.
848 //
849 // The generate IR corresponds to this C like algorithm:
850 // void heapSort(lo, hi, data) {
851 //   n = hi - lo
852 //   for i = (n-2)/2 downto 0
853 //     shiftDown(lo, lo+i, n)
854 //
855 //   for l = n downto 2
856 //      swap(lo, lo+l-1)
857 //      shiftdown(lo, lo, l-1)
858 // }
859 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
860                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
861                                uint32_t nTrailingP) {
862   // Heap sort function doesn't have trailing parameters.
863   (void)nTrailingP;
864   assert(nTrailingP == 0);
865   OpBuilder::InsertionGuard insertionGuard(builder);
866   Block *entryBlock = func.addEntryBlock();
867   builder.setInsertionPointToStart(entryBlock);
868 
869   Location loc = func.getLoc();
870   ValueRange args = entryBlock->getArguments();
871   Value lo = args[loIdx];
872   Value hi = args[hiIdx];
873   Value n = builder.create<arith::SubIOp>(loc, hi, lo);
874 
875   // For i = (n-2)/2 downto 0.
876   Value c0 = constantIndex(builder, loc, 0);
877   Value c1 = constantIndex(builder, loc, 1);
878   Value s = createSubTwoDividedByTwo(builder, loc, n);
879   Value up = builder.create<arith::AddIOp>(loc, s, c1);
880   scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
881   builder.setInsertionPointToStart(forI.getBody());
882   Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
883   Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
884   SmallVector<Value> shiftDownOperands = {lo, lopi};
885   shiftDownOperands.append(args.begin() + xStartIdx, args.end());
886   shiftDownOperands.push_back(n);
887   FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
888       builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
889       shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
890   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
891                                shiftDownOperands);
892 
893   builder.setInsertionPointAfter(forI);
894   // For l = n downto 2.
895   up = builder.create<arith::SubIOp>(loc, n, c1);
896   scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
897   builder.setInsertionPointToStart(forL.getBody());
898   Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
899   Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
900   loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
901   SmallVector<Value> swapOperands{lo, loplm1};
902   swapOperands.append(args.begin() + xStartIdx, args.end());
903   createSwap(builder, loc, swapOperands, xPerm, ny);
904   shiftDownOperands[1] = lo;
905   shiftDownOperands[shiftDownOperands.size() - 1] =
906       builder.create<arith::SubIOp>(loc, l, c1);
907   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
908                                shiftDownOperands);
909 
910   builder.setInsertionPointAfter(forL);
911   builder.create<func::ReturnOp>(loc);
912 }
913 
914 /// A helper for generating code to perform quick sort. It partitions [lo, hi),
915 /// recursively calls quick sort to process the smaller partition and returns
916 /// the bigger partition to be processed by the enclosed while-loop.
917 static std::pair<Value, Value>
918 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
919                 ValueRange args, AffineMap xPerm, uint64_t ny,
920                 uint32_t nTrailingP) {
921   MLIRContext *context = module.getContext();
922   Location loc = func.getLoc();
923   Value lo = args[loIdx];
924   Value hi = args[hiIdx];
925   SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
926 
927   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
928       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
929       ny, args.drop_back(nTrailingP), createPartitionFunc);
930   Value p = builder
931                 .create<func::CallOp>(loc, partitionFunc,
932                                       TypeRange{IndexType::get(context)},
933                                       args.drop_back(nTrailingP))
934                 .getResult(0);
935 
936   Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
937   Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
938   // Partition already sorts array with len <= 2
939   Value c2 = constantIndex(builder, loc, 2);
940   Value len = builder.create<arith::SubIOp>(loc, hi, lo);
941   Value lenGtTwo =
942       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
943   scf::IfOp ifLenGtTwo =
944       builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
945   builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
946   // Returns an empty range to mark the entire region is fully sorted.
947   builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
948 
949   // Else len > 2, need recursion.
950   builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
951   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
952                                              lenLow, lenHigh);
953 
954   Value c0 = constantIndex(builder, loc, 0);
955   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
956 
957   auto mayRecursion = [&](Value low, Value high, Value len) {
958     Value cond =
959         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
960     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
961     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
962     SmallVector<Value> operands{low, high};
963     operands.append(args.begin() + xStartIdx, args.end());
964     builder.create<func::CallOp>(loc, func, operands);
965     builder.setInsertionPointAfter(ifOp);
966   };
967 
968   // Recursively call quickSort to process the smaller partition and return
969   // the bigger partition to be processed by the enclosed while-loop.
970   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
971   mayRecursion(lo, p, lenLow);
972   builder.create<scf::YieldOp>(loc, ValueRange{p, hi});
973 
974   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
975   mayRecursion(p, hi, lenHigh);
976   builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
977 
978   builder.setInsertionPointAfter(ifOp);
979   builder.create<scf::YieldOp>(loc, ifOp.getResults());
980 
981   builder.setInsertionPointAfter(ifLenGtTwo);
982   return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
983 }
984 
985 /// Creates a function to perform insertion sort on the values in the range of
986 /// index [lo, hi).
987 //
988 // The generate IR corresponds to this C like algorithm:
989 // void insertionSort(lo, hi, data) {
990 //   for (i = lo+1; i < hi; i++) {
991 //      d = data[i];
992 //      p = binarySearch(lo, i-1, data)
993 //      for (j = 0; j > i - p; j++)
994 //        data[i-j] = data[i-j-1]
995 //      data[p] = d
996 //   }
997 // }
998 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
999                                  func::FuncOp func, AffineMap xPerm,
1000                                  uint64_t ny, uint32_t nTrailingP) {
1001   // Stable sort function doesn't use trailing parameters.
1002   (void)nTrailingP;
1003   assert(nTrailingP == 0);
1004   OpBuilder::InsertionGuard insertionGuard(builder);
1005   Block *entryBlock = func.addEntryBlock();
1006   builder.setInsertionPointToStart(entryBlock);
1007 
1008   MLIRContext *context = module.getContext();
1009   Location loc = func.getLoc();
1010   ValueRange args = entryBlock->getArguments();
1011   Value c1 = constantIndex(builder, loc, 1);
1012   Value lo = args[loIdx];
1013   Value hi = args[hiIdx];
1014   Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
1015 
1016   // Start the outer for-stmt with induction variable i.
1017   scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
1018   builder.setInsertionPointToStart(forOpI.getBody());
1019   Value i = forOpI.getInductionVar();
1020 
1021   // Binary search to find the insertion point p.
1022   SmallVector<Value> operands{lo, i};
1023   operands.append(args.begin() + xStartIdx, args.end());
1024   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
1025       builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
1026       xPerm, ny, operands, createBinarySearchFunc);
1027   Value p = builder
1028                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
1029                                       operands)
1030                 .getResult(0);
1031 
1032   // Move the value at data[i] to a temporary location.
1033   operands[0] = operands[1] = i;
1034   SmallVector<Value> d;
1035   forEachIJPairInAllBuffers(
1036       builder, loc, operands, xPerm, ny,
1037       [&](uint64_t unused, Value i, Value unused2, Value buffer) {
1038         d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
1039       });
1040 
1041   // Start the inner for-stmt with induction variable j, for moving data[p..i)
1042   // to data[p+1..i+1).
1043   Value imp = builder.create<arith::SubIOp>(loc, i, p);
1044   Value c0 = constantIndex(builder, loc, 0);
1045   scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
1046   builder.setInsertionPointToStart(forOpJ.getBody());
1047   Value j = forOpJ.getInductionVar();
1048   Value imj = builder.create<arith::SubIOp>(loc, i, j);
1049   operands[1] = imj;
1050   operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
1051   forEachIJPairInAllBuffers(
1052       builder, loc, operands, xPerm, ny,
1053       [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
1054         Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
1055         builder.create<memref::StoreOp>(loc, t, buffer, imj);
1056       });
1057 
1058   // Store the value at data[i] to data[p].
1059   builder.setInsertionPointAfter(forOpJ);
1060   operands[0] = operands[1] = p;
1061   forEachIJPairInAllBuffers(
1062       builder, loc, operands, xPerm, ny,
1063       [&](uint64_t k, Value p, Value usused, Value buffer) {
1064         builder.create<memref::StoreOp>(loc, d[k], buffer, p);
1065       });
1066 
1067   builder.setInsertionPointAfter(forOpI);
1068   builder.create<func::ReturnOp>(loc);
1069 }
1070 
1071 /// Creates a function to perform quick sort or a hybrid quick sort on the
1072 /// values in the range of index [lo, hi).
1073 //
1074 //
1075 // When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
1076 // void quickSort(lo, hi, data) {
1077 //   while (lo + 1 < hi) {
1078 //        p = partition(low, high, data);
1079 //        if (len(lo, p) < len(p+1, hi)) {
1080 //          quickSort(lo, p, data);
1081 //          lo = p+1;
1082 //        } else {
1083 //          quickSort(p + 1, hi, data);
1084 //          hi = p;
1085 //        }
1086 //   }
1087 // }
1088 //
1089 // When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
1090 // void hybridQuickSort(lo, hi, data, depthLimit) {
1091 //   while (lo + 1 < hi) {
1092 //     len = hi - lo;
1093 //     if (len <= limit) {
1094 //       insertionSort(lo, hi, data);
1095 //     } else {
1096 //       depthLimit --;
1097 //       if (depthLimit <= 0) {
1098 //         heapSort(lo, hi, data);
1099 //       } else {
1100 //          p = partition(low, high, data);
1101 //          if (len(lo, p) < len(p+1, hi)) {
1102 //            quickSort(lo, p, data, depthLimit);
1103 //            lo = p+1;
1104 //          } else {
1105 //            quickSort(p + 1, hi, data, depthLimit);
1106 //            hi = p;
1107 //          }
1108 //       }
1109 //     }
1110 //   }
1111 // }
1112 //
1113 static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
1114                                 func::FuncOp func, AffineMap xPerm, uint64_t ny,
1115                                 uint32_t nTrailingP) {
1116   assert(nTrailingP == 1 || nTrailingP == 0);
1117   bool isHybrid = (nTrailingP == 1);
1118   OpBuilder::InsertionGuard insertionGuard(builder);
1119   Block *entryBlock = func.addEntryBlock();
1120   builder.setInsertionPointToStart(entryBlock);
1121 
1122   Location loc = func.getLoc();
1123   SmallVector<Value> args;
1124   args.append(entryBlock->getArguments().begin(),
1125               entryBlock->getArguments().end());
1126   Value lo = args[loIdx];
1127   Value hi = args[hiIdx];
1128   SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
1129   scf::WhileOp whileOp =
1130       builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});
1131 
1132   // The before-region of the WhileOp.
1133   Block *before =
1134       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
1135   builder.setInsertionPointToEnd(before);
1136   lo = before->getArgument(0);
1137   hi = before->getArgument(1);
1138   Value loP1 =
1139       builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
1140   Value needSort =
1141       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
1142   builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());
1143 
1144   // The after-region of the WhileOp.
1145   Block *after =
1146       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
1147   builder.setInsertionPointToEnd(after);
1148   lo = after->getArgument(0);
1149   hi = after->getArgument(1);
1150   args[0] = lo;
1151   args[1] = hi;
1152 
1153   if (isHybrid) {
1154     Value len = builder.create<arith::SubIOp>(loc, hi, lo);
1155     Value lenLimit = constantIndex(builder, loc, 30);
1156     Value lenCond = builder.create<arith::CmpIOp>(
1157         loc, arith::CmpIPredicate::ule, len, lenLimit);
1158     scf::IfOp lenIf =
1159         builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);
1160 
1161     // When len <= limit.
1162     builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
1163     FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
1164         builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
1165         ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
1166     builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
1167                                  ValueRange(args).drop_back(nTrailingP));
1168     builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1169 
1170     // When len > limit.
1171     builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
1172     Value depthLimit = args.back();
1173     depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
1174                                                constantI64(builder, loc, 1));
1175     Value depthCond =
1176         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1177                                       depthLimit, constantI64(builder, loc, 0));
1178     scf::IfOp depthIf =
1179         builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);
1180 
1181     // When depth exceeds limit.
1182     builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
1183     FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
1184         builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
1185         ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
1186     builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
1187                                  ValueRange(args).drop_back(nTrailingP));
1188     builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1189 
1190     // When depth doesn't exceed limit.
1191     builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
1192     args.back() = depthLimit;
1193     std::tie(lo, hi) =
1194         createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1195     builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1196 
1197     builder.setInsertionPointAfter(depthIf);
1198     lo = depthIf.getResult(0);
1199     hi = depthIf.getResult(1);
1200     builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1201 
1202     builder.setInsertionPointAfter(lenIf);
1203     lo = lenIf.getResult(0);
1204     hi = lenIf.getResult(1);
1205   } else {
1206     std::tie(lo, hi) =
1207         createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
1208   }
1209 
1210   // New [lo, hi) for the next while-loop iteration.
1211   builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1212 
1213   // After the while-loop.
1214   builder.setInsertionPointAfter(whileOp);
1215   builder.create<func::ReturnOp>(loc);
1216 }
1217 
1218 /// Implements the rewriting for operator sort and sort_coo.
1219 template <typename OpTy>
1220 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
1221                                     uint64_t ny, PatternRewriter &rewriter) {
1222   Location loc = op.getLoc();
1223   SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
1224 
1225   // Convert `values` to have dynamic shape and append them to `operands`.
1226   for (Value v : xys) {
1227     auto mtp = getMemRefType(v);
1228     if (!mtp.isDynamicDim(0)) {
1229       auto newMtp =
1230           MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
1231       v = rewriter.create<memref::CastOp>(loc, newMtp, v);
1232     }
1233     operands.push_back(v);
1234   }
1235 
1236   auto insertPoint = op->template getParentOfType<func::FuncOp>();
1237   if (!insertPoint)
1238     return failure();
1239 
1240   SmallString<32> funcName;
1241   FuncGeneratorType funcGenerator;
1242   uint32_t nTrailingP = 0;
1243   switch (op.getAlgorithm()) {
1244   case SparseTensorSortKind::HybridQuickSort: {
1245     funcName = kHybridQuickSortFuncNamePrefix;
1246     funcGenerator = createQuickSortFunc;
1247     nTrailingP = 1;
1248     // As a heuristics, set depthLimit = 2 * log2(n).
1249     Value lo = operands[loIdx];
1250     Value hi = operands[hiIdx];
1251     Value len = rewriter.create<arith::IndexCastOp>(
1252         loc, rewriter.getI64Type(),
1253         rewriter.create<arith::SubIOp>(loc, hi, lo));
1254     Value depthLimit = rewriter.create<arith::SubIOp>(
1255         loc, constantI64(rewriter, loc, 64),
1256         rewriter.create<math::CountLeadingZerosOp>(loc, len));
1257     operands.push_back(depthLimit);
1258     break;
1259   }
1260   case SparseTensorSortKind::QuickSort:
1261     funcName = kQuickSortFuncNamePrefix;
1262     funcGenerator = createQuickSortFunc;
1263     break;
1264   case SparseTensorSortKind::InsertionSortStable:
1265     funcName = kSortStableFuncNamePrefix;
1266     funcGenerator = createSortStableFunc;
1267     break;
1268   case SparseTensorSortKind::HeapSort:
1269     funcName = kHeapSortFuncNamePrefix;
1270     funcGenerator = createHeapSortFunc;
1271     break;
1272   }
1273 
1274   FlatSymbolRefAttr func =
1275       getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
1276                                xPerm, ny, operands, funcGenerator, nTrailingP);
1277   rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
1278   return success();
1279 }
1280 
1281 //===---------------------------------------------------------------------===//
1282 // The actual sparse buffer rewriting rules.
1283 //===---------------------------------------------------------------------===//
1284 
1285 namespace {
1286 /// Sparse rewriting rule for the push_back operator.
1287 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
1288 public:
1289   using OpRewritePattern<PushBackOp>::OpRewritePattern;
1290   PushBackRewriter(MLIRContext *context, bool enableInit)
1291       : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
1292   LogicalResult matchAndRewrite(PushBackOp op,
1293                                 PatternRewriter &rewriter) const override {
1294     // Rewrite push_back(buffer, value, n) to:
1295     // new_size = size(buffer) + n
1296     // if (new_size > capacity(buffer))
1297     //    while new_size > new_capacity
1298     //      new_capacity = new_capacity*2
1299     //    new_buffer = realloc(buffer, new_capacity)
1300     // buffer = new_buffer
1301     // subBuffer = subviewof(buffer)
1302     // linalg.fill subBuffer value
1303     //
1304     // size(buffer) += n
1305     //
1306     // The capacity check is skipped when the attribute inbounds is presented.
1307     Location loc = op->getLoc();
1308     Value c0 = constantIndex(rewriter, loc, 0);
1309     Value buffer = op.getInBuffer();
1310     Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
1311     Value size = op.getCurSize();
1312     Value value = op.getValue();
1313 
1314     Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
1315     Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
1316     auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
1317     bool nIsOne = (nValue && nValue.value() == 1);
1318 
1319     if (!op.getInbounds()) {
1320       Value cond = rewriter.create<arith::CmpIOp>(
1321           loc, arith::CmpIPredicate::ugt, newSize, capacity);
1322 
1323       Value c2 = constantIndex(rewriter, loc, 2);
1324       auto bufferType =
1325           MemRefType::get({ShapedType::kDynamic}, value.getType());
1326       scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
1327                                                   /*else=*/true);
1328       // True branch.
1329       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1330       if (nIsOne) {
1331         capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
1332       } else {
1333         // Use a do-while loop to calculate the new capacity as follows:
1334         //   do { new_capacity *= 2 } while (size > new_capacity)
1335         scf::WhileOp whileOp =
1336             rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
1337 
1338         // The before-region of the WhileOp.
1339         Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
1340                                              {capacity.getType()}, {loc});
1341         rewriter.setInsertionPointToEnd(before);
1342 
1343         capacity =
1344             rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
1345         cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
1346                                               newSize, capacity);
1347         rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
1348         // The after-region of the WhileOp.
1349         Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
1350                                             {capacity.getType()}, {loc});
1351         rewriter.setInsertionPointToEnd(after);
1352         rewriter.create<scf::YieldOp>(loc, after->getArguments());
1353 
1354         rewriter.setInsertionPointAfter(whileOp);
1355         capacity = whileOp.getResult(0);
1356       }
1357 
1358       Value newBuffer =
1359           rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
1360       if (enableBufferInitialization) {
1361         Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
1362         Value fillValue = constantZero(rewriter, loc, value.getType());
1363         Value subBuffer = rewriter.create<memref::SubViewOp>(
1364             loc, newBuffer, /*offset=*/ValueRange{newSize},
1365             /*size=*/ValueRange{fillSize},
1366             /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1367         rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
1368       }
1369       rewriter.create<scf::YieldOp>(loc, newBuffer);
1370 
1371       // False branch.
1372       rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1373       rewriter.create<scf::YieldOp>(loc, buffer);
1374 
1375       // Prepare for adding the value to the end of the buffer.
1376       rewriter.setInsertionPointAfter(ifOp);
1377       buffer = ifOp.getResult(0);
1378     }
1379 
1380     // Add the value to the end of the buffer.
1381     if (nIsOne) {
1382       rewriter.create<memref::StoreOp>(loc, value, buffer, size);
1383     } else {
1384       Value subBuffer = rewriter.create<memref::SubViewOp>(
1385           loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
1386           /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1387       rewriter.create<linalg::FillOp>(loc, value, subBuffer);
1388     }
1389 
1390     // Update the buffer size.
1391     rewriter.replaceOp(op, {buffer, newSize});
1392     return success();
1393   }
1394 
1395 private:
1396   bool enableBufferInitialization;
1397 };
1398 
1399 /// Sparse rewriting rule for the sort_coo operator.
1400 struct SortRewriter : public OpRewritePattern<SortOp> {
1401 public:
1402   using OpRewritePattern<SortOp>::OpRewritePattern;
1403 
1404   LogicalResult matchAndRewrite(SortOp op,
1405                                 PatternRewriter &rewriter) const override {
1406     SmallVector<Value> xys;
1407     xys.push_back(op.getXy());
1408     xys.append(op.getYs().begin(), op.getYs().end());
1409 
1410     auto xPerm = op.getPermMap();
1411     uint64_t ny = 0;
1412     if (auto nyAttr = op.getNyAttr())
1413       ny = nyAttr.getInt();
1414 
1415     return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
1416   }
1417 };
1418 
1419 } // namespace
1420 
1421 //===---------------------------------------------------------------------===//
1422 // Methods that add patterns described in this file to a pattern list.
1423 //===---------------------------------------------------------------------===//
1424 
1425 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
1426                                          bool enableBufferInitialization) {
1427   patterns.add<PushBackRewriter>(patterns.getContext(),
1428                                  enableBufferInitialization);
1429   patterns.add<SortRewriter>(patterns.getContext());
1430 }
1431