xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp (revision f6424d11cb3f6d1ece0e3a4633abfd8427d463ff)
1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensor
10 // primitives with memref operands.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodegenUtils.h"
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Math/IR/Math.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
23 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
24 #include "mlir/Support/LLVM.h"
25 
26 using namespace mlir;
27 using namespace mlir::sparse_tensor;
28 
29 //===---------------------------------------------------------------------===//
30 // Helper methods for the actual rewriting rules.
31 //===---------------------------------------------------------------------===//
32 
// Operand positions shared by the generated sort helpers, whose operand lists
// start with (lo, hi, xs...): the index of `lo`, the index of `hi`, and the
// index of the first x buffer.
static constexpr uint64_t loIdx{0};
static constexpr uint64_t hiIdx{1};
static constexpr uint64_t xStartIdx{2};
36 
// Name prefixes of the private helper functions generated by this file. Each
// prefix is completed with a suffix mangled from the sort operands.
static constexpr char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
static constexpr char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
static constexpr char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_";
static constexpr char kHybridQuickSortFuncNamePrefix[] =
    "_sparse_hybrid_qsort_";
static constexpr char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_";
static constexpr char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
49 
50 using FuncGeneratorType = function_ref<void(
51     OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
52 
53 /// Constructs a function name with this format to facilitate quick sort:
54 ///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
55 ///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
56 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
57                                          StringRef namePrefix, uint64_t nx,
58                                          uint64_t ny, bool isCoo,
59                                          ValueRange operands) {
60   nameOstream << namePrefix << nx << "_"
61               << getMemRefType(operands[xStartIdx]).getElementType();
62 
63   if (isCoo)
64     nameOstream << "_coo_" << ny;
65 
66   uint64_t yBufferOffset = isCoo ? 1 : nx;
67   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
68     nameOstream << "_" << getMemRefType(v).getElementType();
69 }
70 
71 /// Looks up a function that is appropriate for the given operands being
72 /// sorted, and creates such a function if it doesn't exist yet. The
73 /// parameters `nx` and `ny` tell the number of x and y values provided
74 /// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
75 /// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
76 //
77 // All sorting function generators take (lo, hi, xs, ys) in `operands` as
78 // parameters for the sorting functions. Other parameters, such as the recursive
79 // call depth, are appended to the end of the parameter list as
80 // "trailing parameters".
81 static FlatSymbolRefAttr
82 getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
83                          TypeRange resultTypes, StringRef namePrefix,
84                          uint64_t nx, uint64_t ny, bool isCoo,
85                          ValueRange operands, FuncGeneratorType createFunc,
86                          uint32_t nTrailingP = 0) {
87   SmallString<32> nameBuffer;
88   llvm::raw_svector_ostream nameOstream(nameBuffer);
89   getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
90                                operands.drop_back(nTrailingP));
91 
92   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
93   MLIRContext *context = module.getContext();
94   auto result = SymbolRefAttr::get(context, nameOstream.str());
95   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
96 
97   if (!func) {
98     // Create the function.
99     OpBuilder::InsertionGuard insertionGuard(builder);
100     builder.setInsertionPoint(insertPoint);
101     Location loc = insertPoint.getLoc();
102     func = builder.create<func::FuncOp>(
103         loc, nameOstream.str(),
104         FunctionType::get(context, operands.getTypes(), resultTypes));
105     func.setPrivate();
106     createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
107   }
108 
109   return result;
110 }
111 
112 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
113 /// The code to process the value pairs is generated by `bodyBuilder`.
114 static void forEachIJPairInXs(
115     OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
116     bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
117   Value iOffset, jOffset;
118   if (isCoo) {
119     Value cstep = constantIndex(builder, loc, nx + ny);
120     iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
121     jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
122   }
123   for (uint64_t k = 0; k < nx; k++) {
124     scf::IfOp ifOp;
125     Value i, j, buffer;
126     if (isCoo) {
127       Value ck = constantIndex(builder, loc, k);
128       i = builder.create<arith::AddIOp>(loc, ck, iOffset);
129       j = builder.create<arith::AddIOp>(loc, ck, jOffset);
130       buffer = args[xStartIdx];
131     } else {
132       i = args[0];
133       j = args[1];
134       buffer = args[xStartIdx + k];
135     }
136     bodyBuilder(k, i, j, buffer);
137   }
138 }
139 
140 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
141 /// The code to process the value pairs is generated by `bodyBuilder`.
142 static void forEachIJPairInAllBuffers(
143     OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
144     bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
145 
146   // Create code for the first (nx + ny) buffers. When isCoo==true, these
147   // logical buffers are all from the xy buffer of the sort_coo operator.
148   forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
149 
150   uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
151 
152   // Create code for the remaining buffers.
153   Value i = args[0];
154   Value j = args[1];
155   for (const auto &arg :
156        llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
157     bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
158   }
159 }
160 
161 /// Creates a code block for swapping the values in index i and j for all the
162 /// buffers.
163 //
164 // The generated IR corresponds to this C like algorithm:
165 //     swap(x0[i], x0[j]);
166 //     swap(x1[i], x1[j]);
167 //     ...
168 //     swap(xn[i], xn[j]);
169 //     swap(y0[i], y0[j]);
170 //     ...
171 //     swap(yn[i], yn[j]);
172 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
173                        uint64_t nx, uint64_t ny, bool isCoo) {
174   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
175     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
176     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
177     builder.create<memref::StoreOp>(loc, vj, buffer, i);
178     builder.create<memref::StoreOp>(loc, vi, buffer, j);
179   };
180 
181   forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
182 }
183 
184 /// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
185 /// compare each pair is create via `compareBuilder`.
186 static void createCompareFuncImplementation(
187     OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx,
188     uint64_t ny, bool isCoo,
189     function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
190         compareBuilder) {
191   OpBuilder::InsertionGuard insertionGuard(builder);
192 
193   Block *entryBlock = func.addEntryBlock();
194   builder.setInsertionPointToStart(entryBlock);
195   Location loc = func.getLoc();
196   ValueRange args = entryBlock->getArguments();
197 
198   scf::IfOp topIfOp;
199   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
200     scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1));
201     if (k == 0) {
202       topIfOp = ifOp;
203     } else {
204       OpBuilder::InsertionGuard insertionGuard(builder);
205       builder.setInsertionPointAfter(ifOp);
206       builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
207     }
208   };
209 
210   forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
211 
212   builder.setInsertionPointAfter(topIfOp);
213   builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
214 }
215 
216 /// Generates an if-statement to compare whether x[i] is equal to x[j].
217 static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
218                                  Value j, Value x, bool isLastDim) {
219   Value f = constantI1(builder, loc, false);
220   Value t = constantI1(builder, loc, true);
221   Value vi = builder.create<memref::LoadOp>(loc, x, i);
222   Value vj = builder.create<memref::LoadOp>(loc, x, j);
223 
224   Value cond =
225       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
226   scf::IfOp ifOp =
227       builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
228 
229   // x[1] != x[j]:
230   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
231   builder.create<scf::YieldOp>(loc, f);
232 
233   // x[i] == x[j]:
234   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
235   if (isLastDim == 1) {
236     // Finish checking all dimensions.
237     builder.create<scf::YieldOp>(loc, t);
238   }
239 
240   return ifOp;
241 }
242 
243 /// Creates a function to compare whether xs[i] is equal to xs[j].
244 //
245 // The generate IR corresponds to this C like algorithm:
246 //   if (x0[i] != x0[j])
247 //     return false;
248 //   else
249 //     if (x1[i] != x1[j])
250 //       return false;
251 //     else if (x2[2] != x2[j]))
252 //       and so on ...
253 static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
254                                 func::FuncOp func, uint64_t nx, uint64_t ny,
255                                 bool isCoo, uint32_t nTrailingP = 0) {
256   // Compare functions don't use trailing parameters.
257   (void)nTrailingP;
258   assert(nTrailingP == 0);
259   createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
260                                   createEqCompare);
261 }
262 
263 /// Generates an if-statement to compare whether x[i] is less than x[j].
264 static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
265                                        Value i, Value j, Value x,
266                                        bool isLastDim) {
267   Value f = constantI1(builder, loc, false);
268   Value t = constantI1(builder, loc, true);
269   Value vi = builder.create<memref::LoadOp>(loc, x, i);
270   Value vj = builder.create<memref::LoadOp>(loc, x, j);
271 
272   Value cond =
273       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
274   scf::IfOp ifOp =
275       builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
276   // If (x[i] < x[j]).
277   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
278   builder.create<scf::YieldOp>(loc, t);
279 
280   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
281   if (isLastDim == 1) {
282     // Finish checking all dimensions.
283     builder.create<scf::YieldOp>(loc, f);
284   } else {
285     cond =
286         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
287     scf::IfOp ifOp2 =
288         builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
289     // Otherwise if (x[j] < x[i]).
290     builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
291     builder.create<scf::YieldOp>(loc, f);
292 
293     // Otherwise check the remaining dimensions.
294     builder.setInsertionPointAfter(ifOp2);
295     builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
296     // Set up the insertion point for the nested if-stmt that checks the
297     // remaining dimensions.
298     builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
299   }
300 
301   return ifOp;
302 }
303 
304 /// Creates a function to compare whether xs[i] is less than xs[j].
305 //
306 // The generate IR corresponds to this C like algorithm:
307 //   if (x0[i] < x0[j])
308 //     return true;
309 //   else if (x0[j] < x0[i])
310 //     return false;
311 //   else
312 //     if (x1[i] < x1[j])
313 //       return true;
314 //     else if (x1[j] < x1[i]))
315 //       and so on ...
316 static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
317                                func::FuncOp func, uint64_t nx, uint64_t ny,
318                                bool isCoo, uint32_t nTrailingP = 0) {
319   // Compare functions don't use trailing parameters.
320   (void)nTrailingP;
321   assert(nTrailingP == 0);
322   createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
323                                   createLessThanCompare);
324 }
325 
326 /// Creates a function to use a binary search to find the insertion point for
327 /// inserting xs[hi] to the sorted values xs[lo..hi).
328 //
329 // The generate IR corresponds to this C like algorithm:
330 //   p = hi
331 //   while (lo < hi)
332 //      mid = (lo + hi) >> 1
333 //      if (xs[p] < xs[mid])
334 //        hi = mid
335 //      else
336 //        lo = mid - 1
337 //   return lo;
338 //
339 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
340                                    func::FuncOp func, uint64_t nx, uint64_t ny,
341                                    bool isCoo, uint32_t nTrailingP = 0) {
342   // Binary search doesn't use trailing parameters.
343   (void)nTrailingP;
344   assert(nTrailingP == 0);
345   OpBuilder::InsertionGuard insertionGuard(builder);
346   Block *entryBlock = func.addEntryBlock();
347   builder.setInsertionPointToStart(entryBlock);
348 
349   Location loc = func.getLoc();
350   ValueRange args = entryBlock->getArguments();
351   Value p = args[hiIdx];
352   SmallVector<Type, 2> types(2, p.getType()); // Only two types.
353   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
354       loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
355 
356   // The before-region of the WhileOp.
357   Block *before =
358       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
359   builder.setInsertionPointToEnd(before);
360   Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
361                                               before->getArgument(0),
362                                               before->getArgument(1));
363   builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
364 
365   // The after-region of the WhileOp.
366   Block *after =
367       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
368   builder.setInsertionPointToEnd(after);
369   Value lo = after->getArgument(0);
370   Value hi = after->getArgument(1);
371   // Compute mid = (lo + hi) >> 1.
372   Value c1 = constantIndex(builder, loc, 1);
373   Value mid = builder.create<arith::ShRUIOp>(
374       loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
375   Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
376 
377   // Compare xs[p] < xs[mid].
378   SmallVector<Value> compareOperands{p, mid};
379   uint64_t numXBuffers = isCoo ? 1 : nx;
380   compareOperands.append(args.begin() + xStartIdx,
381                          args.begin() + xStartIdx + numXBuffers);
382   Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
383   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
384       builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
385       compareOperands, createLessThanFunc, nTrailingP);
386   Value cond2 = builder
387                     .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
388                                           compareOperands)
389                     .getResult(0);
390 
391   // Update lo and hi for the WhileOp as follows:
392   //   if (xs[p] < xs[mid]))
393   //     hi = mid;
394   //   else
395   //     lo = mid + 1;
396   Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
397   Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
398   builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
399 
400   builder.setInsertionPointAfter(whileOp);
401   builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
402 }
403 
404 /// Creates code to advance i in a loop based on xs[p] as follows:
405 ///   while (xs[i] < xs[p]) i += step (step > 0)
406 /// or
407 ///   while (xs[i] > xs[p]) i += step (step < 0)
408 /// The routine returns i as well as a boolean value to indicate whether
409 /// xs[i] == xs[p].
410 static std::pair<Value, Value>
411 createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
412                ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
413                bool isCoo, int step) {
414   Location loc = func.getLoc();
415   scf::WhileOp whileOp =
416       builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
417 
418   Block *before =
419       builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
420   builder.setInsertionPointToEnd(before);
421   SmallVector<Value> compareOperands;
422   if (step > 0) {
423     compareOperands.push_back(before->getArgument(0));
424     compareOperands.push_back(p);
425   } else {
426     assert(step < 0);
427     compareOperands.push_back(p);
428     compareOperands.push_back(before->getArgument(0));
429   }
430   compareOperands.append(xs.begin(), xs.end());
431   MLIRContext *context = module.getContext();
432   Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
433   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
434       builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
435       compareOperands, createLessThanFunc);
436   Value cond = builder
437                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
438                                          compareOperands)
439                    .getResult(0);
440   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
441 
442   Block *after =
443       builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
444   builder.setInsertionPointToEnd(after);
445   Value cs = constantIndex(builder, loc, step);
446   i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
447   builder.create<scf::YieldOp>(loc, ValueRange{i});
448   i = whileOp.getResult(0);
449 
450   builder.setInsertionPointAfter(whileOp);
451   compareOperands[0] = i;
452   compareOperands[1] = p;
453   FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
454       builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo,
455       compareOperands, createEqCompareFunc);
456   Value compareEq =
457       builder
458           .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
459                                 compareOperands)
460           .getResult(0);
461 
462   return std::make_pair(whileOp.getResult(0), compareEq);
463 }
464 
465 /// Creates a code block to swap the values so that data[mi] is the median among
466 /// data[lo], data[hi], and data[mi].
467 //  The generated code corresponds to this C-like algorithm:
468 //  median = mi
469 //  if (data[mi] < data[lo]).                               (if1)
470 //    if (data[hi] < data[lo])                              (if2)
471 //       median = data[hi] < data[mi] ? mi : hi
472 //    else
473 //       median = lo
474 //  else
475 //    if data[hi] < data[mi]                                (if3)
476 //      median = data[hi] < data[lo] ? lo : hi
477 //  if median != mi swap data[median] with data[mi]
478 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
479                               func::FuncOp func, uint64_t nx, uint64_t ny,
480                               bool isCoo, Value lo, Value hi, Value mi,
481                               ValueRange args) {
482   SmallVector<Value> compareOperands{mi, lo};
483   uint64_t numXBuffers = isCoo ? 1 : nx;
484   compareOperands.append(args.begin() + xStartIdx,
485                          args.begin() + xStartIdx + numXBuffers);
486   Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
487   SmallVector<Type, 1> cmpTypes{i1Type};
488   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
489       builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo,
490       compareOperands, createLessThanFunc);
491   Location loc = func.getLoc();
492   // Compare data[mi] < data[lo].
493   Value cond1 =
494       builder.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
495           .getResult(0);
496   SmallVector<Type, 1> ifTypes{lo.getType()};
497   scf::IfOp ifOp1 =
498       builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
499 
500   // Generate an if-stmt to find the median value, assuming we already know that
501   // data[b] < data[a] and we haven't compare data[c] yet.
502   auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp {
503     compareOperands[0] = c;
504     compareOperands[1] = a;
505     // Compare data[c]] < data[a].
506     Value cond2 =
507         builder
508             .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
509             .getResult(0);
510     scf::IfOp ifOp2 =
511         builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
512     builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
513     compareOperands[0] = c;
514     compareOperands[1] = b;
515     // Compare data[c] < data[b].
516     Value cond3 =
517         builder
518             .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
519             .getResult(0);
520     builder.create<scf::YieldOp>(
521         loc, ValueRange{builder.create<arith::SelectOp>(loc, cond3, b, c)});
522     builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
523     builder.create<scf::YieldOp>(loc, ValueRange{a});
524     return ifOp2;
525   };
526 
527   builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
528   scf::IfOp ifOp2 = createFindMedian(lo, mi, hi);
529   builder.setInsertionPointAfter(ifOp2);
530   builder.create<scf::YieldOp>(loc, ValueRange{ifOp2.getResult(0)});
531 
532   builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
533   scf::IfOp ifOp3 = createFindMedian(mi, lo, hi);
534 
535   builder.setInsertionPointAfter(ifOp3);
536   builder.create<scf::YieldOp>(loc, ValueRange{ifOp3.getResult(0)});
537 
538   builder.setInsertionPointAfter(ifOp1);
539   Value median = ifOp1.getResult(0);
540   Value cond =
541       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, mi, median);
542   scf::IfOp ifOp =
543       builder.create<scf::IfOp>(loc, TypeRange(), cond, /*else=*/false);
544 
545   SmallVector<Value> swapOperands{median, mi};
546   swapOperands.append(args.begin() + xStartIdx, args.end());
547   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
548   createSwap(builder, loc, swapOperands, nx, ny, isCoo);
549   builder.setInsertionPointAfter(ifOp);
550 }
551 
552 /// Creates a function to perform quick sort partition on the values in the
553 /// range of index [lo, hi), assuming lo < hi.
554 //
555 // The generated IR corresponds to this C like algorithm:
556 // int partition(lo, hi, xs) {
557 //   p = (lo+hi)/2  // pivot index
558 //   i = lo
559 //   j = hi-1
560 //   while (i < j) do {
561 //     while (xs[i] < xs[p]) i ++;
562 //     i_eq = (xs[i] == xs[p]);
563 //     while (xs[j] > xs[p]) j --;
564 //     j_eq = (xs[j] == xs[p]);
565 //     if (i < j) {
566 //       swap(xs[i], xs[j])
567 //       if (i == p) {
568 //         p = j;
569 //       } else if (j == p) {
570 //         p = i;
571 //       }
572 //       if (i_eq && j_eq) {
573 //         ++i;
574 //         --j;
575 //       }
576 //     }
577 //   }
578 //   return p
579 //   }
580 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
581                                 func::FuncOp func, uint64_t nx, uint64_t ny,
582                                 bool isCoo, uint32_t nTrailingP = 0) {
583   // Quick sort partition doesn't use trailing parameters.
584   (void)nTrailingP;
585   assert(nTrailingP == 0);
586   OpBuilder::InsertionGuard insertionGuard(builder);
587 
588   Block *entryBlock = func.addEntryBlock();
589   builder.setInsertionPointToStart(entryBlock);
590 
591   Location loc = func.getLoc();
592   ValueRange args = entryBlock->getArguments();
593   Value lo = args[loIdx];
594   Value hi = args[hiIdx];
595   Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
596   Value c1 = constantIndex(builder, loc, 1);
597   Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
598 
599   Value i = lo;
600   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
601   createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
602   SmallVector<Value, 3> operands{i, j, p}; // Exactly three values.
603   SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
604   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
605 
606   // The before-region of the WhileOp.
607   Block *before =
608       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
609   builder.setInsertionPointToEnd(before);
610   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
611                                              before->getArgument(0),
612                                              before->getArgument(1));
613   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
614 
615   // The after-region of the WhileOp.
616   Block *after =
617       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
618   builder.setInsertionPointToEnd(after);
619   i = after->getArgument(0);
620   j = after->getArgument(1);
621   p = after->getArgument(2);
622 
623   uint64_t numXBuffers = isCoo ? 1 : nx;
624   auto [iresult, iCompareEq] =
625       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
626                      i, p, nx, ny, isCoo, 1);
627   i = iresult;
628   auto [jresult, jCompareEq] =
629       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
630                      j, p, nx, ny, isCoo, -1);
631   j = jresult;
632 
633   // If i < j:
634   cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
635   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
636   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
637   SmallVector<Value> swapOperands{i, j};
638   swapOperands.append(args.begin() + xStartIdx, args.end());
639   createSwap(builder, loc, swapOperands, nx, ny, isCoo);
640   // If the pivot is moved, update p with the new pivot.
641   Value icond =
642       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
643   scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
644                                               icond, /*else=*/true);
645   builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
646   builder.create<scf::YieldOp>(loc, ValueRange{j});
647   builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
648   Value jcond =
649       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
650   scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
651                                               jcond, /*else=*/true);
652   builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
653   builder.create<scf::YieldOp>(loc, ValueRange{i});
654   builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
655   builder.create<scf::YieldOp>(loc, ValueRange{p});
656   builder.setInsertionPointAfter(ifOpJ);
657   builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
658   builder.setInsertionPointAfter(ifOpI);
659   Value compareEqIJ =
660       builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
661   scf::IfOp ifOp2 = builder.create<scf::IfOp>(
662       loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
663   builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
664   Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
665   Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
666   builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
667   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
668   builder.create<scf::YieldOp>(loc, ValueRange{i, j});
669   builder.setInsertionPointAfter(ifOp2);
670   builder.create<scf::YieldOp>(
671       loc,
672       ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});
673 
674   // False branch for if i < j:
675   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
676   builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});
677 
678   // Return for the whileOp.
679   builder.setInsertionPointAfter(ifOp);
680   builder.create<scf::YieldOp>(loc, ifOp.getResults());
681 
682   // Return for the function.
683   builder.setInsertionPointAfter(whileOp);
684   builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
685 }
686 
687 /// Computes (n-2)/n, assuming n has index type.
688 static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
689                                       Value n) {
690   Value i2 = constantIndex(builder, loc, 2);
691   Value res = builder.create<arith::SubIOp>(loc, n, i2);
692   Value i1 = constantIndex(builder, loc, 1);
693   return builder.create<arith::ShRUIOp>(loc, res, i1);
694 }
695 
696 /// Creates a function to heapify the subtree with root `start` within the full
697 /// binary tree in the range of index [first, first + n).
698 //
699 // The generated IR corresponds to this C like algorithm:
700 // void shiftDown(first, start, n, data) {
701 //   if (n >= 2) {
702 //     child = start - first
703 //     if ((n-2)/2 >= child) {
704 //       // Left child exists.
705 //       child = child * 2 + 1 // Initialize the bigger child to left child.
706 //       childIndex = child + first
707 //       if (child+1 < n && data[childIndex] < data[childIndex+1])
708 //         // Right child exits and is bigger.
709 //         childIndex++; child++;
710 //       // Shift data[start] down to where it belongs in the subtree.
711 //       while (data[start] < data[childIndex) {
712 //         swap(data[start], data[childIndex])
713 //         start = childIndex
714 //         if ((n - 2)/2 >= child) {
715 //           // Left child exists.
716 //           child = 2*child + 1
717 //           childIndex = child + 1
718 //           if (child + 1) < n && data[childIndex] < data[childIndex+1]
719 //             childIndex++; child++;
720 //         }
721 //       }
722 //     }
723 //   }
724 // }
725 //
726 static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
727                                 func::FuncOp func, uint64_t nx, uint64_t ny,
728                                 bool isCoo, uint32_t nTrailingP) {
729   // The value n is passed in as a trailing parameter.
730   assert(nTrailingP == 1);
731   OpBuilder::InsertionGuard insertionGuard(builder);
732   Block *entryBlock = func.addEntryBlock();
733   builder.setInsertionPointToStart(entryBlock);
734 
735   Location loc = func.getLoc();
736   Value n = entryBlock->getArguments().back();
737   ValueRange args = entryBlock->getArguments().drop_back();
738   Value first = args[loIdx];
739   Value start = args[hiIdx];
740 
741   // If (n >= 2).
742   Value c2 = constantIndex(builder, loc, 2);
743   Value condN =
744       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
745   scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
746   builder.setInsertionPointToStart(&ifN.getThenRegion().front());
747   Value child = builder.create<arith::SubIOp>(loc, start, first);
748 
749   // If ((n-2)/2 >= child).
750   Value t = createSubTwoDividedByTwo(builder, loc, n);
751   Value condNc =
752       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
753   scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);
754 
755   builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
756   Value c1 = constantIndex(builder, loc, 1);
757   SmallVector<Value> compareOperands{start, start};
758   uint64_t numXBuffers = isCoo ? 1 : nx;
759   compareOperands.append(args.begin() + xStartIdx,
760                          args.begin() + xStartIdx + numXBuffers);
761   Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
762   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
763       builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
764       compareOperands, createLessThanFunc);
765 
766   // Generate code to inspect the children of 'r' and return the larger child
767   // as follows:
768   //   child = r * 2 + 1 // Left child.
769   //   childIndex = child + first
770   //   if (child+1 < n && data[childIndex] < data[childIndex+1])
771   //     childIndex ++; child ++ // Right child is bigger.
772   auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
773     Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
774     lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
775     Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
776     Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
777     Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
778                                                 rChild, n);
779     SmallVector<Type, 2> ifTypes(2, r.getType());
780     scf::IfOp if1 =
781         builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
782     builder.setInsertionPointToStart(&if1.getThenRegion().front());
783     Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
784     // Compare data[left] < data[right].
785     compareOperands[0] = lChildIdx;
786     compareOperands[1] = rChildIdx;
787     Value cond2 = builder
788                       .create<func::CallOp>(loc, lessThanFunc,
789                                             TypeRange{i1Type}, compareOperands)
790                       .getResult(0);
791     scf::IfOp if2 =
792         builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
793     builder.setInsertionPointToStart(&if2.getThenRegion().front());
794     builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
795     builder.setInsertionPointToStart(&if2.getElseRegion().front());
796     builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
797     builder.setInsertionPointAfter(if2);
798     builder.create<scf::YieldOp>(loc, if2.getResults());
799     builder.setInsertionPointToStart(&if1.getElseRegion().front());
800     builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
801     builder.setInsertionPointAfter(if1);
802     return std::make_pair(if1.getResult(0), if1.getResult(1));
803   };
804 
805   Value childIdx;
806   std::tie(child, childIdx) = getLargerChild(child);
807 
808   // While (data[start] < data[childIndex]).
809   SmallVector<Type, 3> types(3, child.getType());
810   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
811       loc, types, SmallVector<Value, 2>{start, child, childIdx});
812 
813   // The before-region of the WhileOp.
814   SmallVector<Location, 3> locs(3, loc);
815   Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
816   builder.setInsertionPointToEnd(before);
817   start = before->getArgument(0);
818   childIdx = before->getArgument(2);
819   compareOperands[0] = start;
820   compareOperands[1] = childIdx;
821   Value cond = builder
822                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
823                                          compareOperands)
824                    .getResult(0);
825   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
826 
827   // The after-region of the WhileOp.
828   Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
829   start = after->getArgument(0);
830   child = after->getArgument(1);
831   childIdx = after->getArgument(2);
832   SmallVector<Value> swapOperands{start, childIdx};
833   swapOperands.append(args.begin() + xStartIdx, args.end());
834   createSwap(builder, loc, swapOperands, nx, ny, isCoo);
835   start = childIdx;
836   Value cond2 =
837       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
838   scf::IfOp if2 = builder.create<scf::IfOp>(
839       loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
840   builder.setInsertionPointToStart(&if2.getThenRegion().front());
841   auto [newChild, newChildIdx] = getLargerChild(child);
842   builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
843   builder.setInsertionPointToStart(&if2.getElseRegion().front());
844   builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
845   builder.setInsertionPointAfter(if2);
846   builder.create<scf::YieldOp>(
847       loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});
848 
849   builder.setInsertionPointAfter(ifN);
850   builder.create<func::ReturnOp>(loc);
851 }
852 
853 /// Creates a function to perform heap sort on the values in the range of index
854 /// [lo, hi) with the assumption hi - lo >= 2.
855 //
856 // The generate IR corresponds to this C like algorithm:
857 // void heapSort(lo, hi, data) {
858 //   n = hi - lo
859 //   for i = (n-2)/2 downto 0
860 //     shiftDown(lo, lo+i, n)
861 //
862 //   for l = n downto 2
863 //      swap(lo, lo+l-1)
864 //      shiftdown(lo, lo, l-1)
865 // }
866 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
867                                func::FuncOp func, uint64_t nx, uint64_t ny,
868                                bool isCoo, uint32_t nTrailingP) {
869   // Heap sort function doesn't have trailing parameters.
870   (void)nTrailingP;
871   assert(nTrailingP == 0);
872   OpBuilder::InsertionGuard insertionGuard(builder);
873   Block *entryBlock = func.addEntryBlock();
874   builder.setInsertionPointToStart(entryBlock);
875 
876   Location loc = func.getLoc();
877   ValueRange args = entryBlock->getArguments();
878   Value lo = args[loIdx];
879   Value hi = args[hiIdx];
880   Value n = builder.create<arith::SubIOp>(loc, hi, lo);
881 
882   // For i = (n-2)/2 downto 0.
883   Value c0 = constantIndex(builder, loc, 0);
884   Value c1 = constantIndex(builder, loc, 1);
885   Value s = createSubTwoDividedByTwo(builder, loc, n);
886   Value up = builder.create<arith::AddIOp>(loc, s, c1);
887   scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
888   builder.setInsertionPointToStart(forI.getBody());
889   Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
890   Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
891   SmallVector<Value> shiftDownOperands = {lo, lopi};
892   shiftDownOperands.append(args.begin() + xStartIdx, args.end());
893   shiftDownOperands.push_back(n);
894   FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
895       builder, func, TypeRange(), kShiftDownFuncNamePrefix, nx, ny, isCoo,
896       shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
897   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
898                                shiftDownOperands);
899 
900   builder.setInsertionPointAfter(forI);
901   // For l = n downto 2.
902   up = builder.create<arith::SubIOp>(loc, n, c1);
903   scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
904   builder.setInsertionPointToStart(forL.getBody());
905   Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
906   Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
907   loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
908   SmallVector<Value> swapOperands{lo, loplm1};
909   swapOperands.append(args.begin() + xStartIdx, args.end());
910   createSwap(builder, loc, swapOperands, nx, ny, isCoo);
911   shiftDownOperands[1] = lo;
912   shiftDownOperands[shiftDownOperands.size() - 1] =
913       builder.create<arith::SubIOp>(loc, l, c1);
914   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
915                                shiftDownOperands);
916 
917   builder.setInsertionPointAfter(forL);
918   builder.create<func::ReturnOp>(loc);
919 }
920 
921 /// A helper for generating code to perform quick sort. It partitions [lo, hi),
922 /// recursively calls quick sort to process the smaller partition and returns
923 /// the bigger partition to be processed by the enclosed while-loop.
924 static std::pair<Value, Value>
925 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
926                 ValueRange args, uint64_t nx, uint64_t ny, bool isCoo,
927                 uint32_t nTrailingP) {
928   MLIRContext *context = module.getContext();
929   Location loc = func.getLoc();
930   Value lo = args[loIdx];
931   Value hi = args[hiIdx];
932   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
933       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
934       ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
935   Value p = builder
936                 .create<func::CallOp>(loc, partitionFunc,
937                                       TypeRange{IndexType::get(context)},
938                                       args.drop_back(nTrailingP))
939                 .getResult(0);
940   Value pP1 =
941       builder.create<arith::AddIOp>(loc, p, constantIndex(builder, loc, 1));
942   Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
943   Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
944   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
945                                              lenLow, lenHigh);
946 
947   SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
948   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
949 
950   Value c0 = constantIndex(builder, loc, 0);
951   auto mayRecursion = [&](Value low, Value high, Value len) {
952     Value cond =
953         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
954     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
955     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
956     SmallVector<Value> operands{low, high};
957     operands.append(args.begin() + xStartIdx, args.end());
958     builder.create<func::CallOp>(loc, func, operands);
959     builder.setInsertionPointAfter(ifOp);
960   };
961 
962   // Recursively call quickSort to process the smaller partition and return
963   // the bigger partition to be processed by the enclosed while-loop.
964   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
965   mayRecursion(lo, p, lenLow);
966   builder.create<scf::YieldOp>(loc, ValueRange{pP1, hi});
967 
968   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
969   mayRecursion(pP1, hi, lenHigh);
970   builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
971 
972   builder.setInsertionPointAfter(ifOp);
973   return std::make_pair(ifOp.getResult(0), ifOp.getResult(1));
974 }
975 
976 /// Creates a function to perform insertion sort on the values in the range of
977 /// index [lo, hi).
978 //
979 // The generate IR corresponds to this C like algorithm:
980 // void insertionSort(lo, hi, data) {
981 //   for (i = lo+1; i < hi; i++) {
982 //      d = data[i];
983 //      p = binarySearch(lo, i-1, data)
984 //      for (j = 0; j > i - p; j++)
985 //        data[i-j] = data[i-j-1]
986 //      data[p] = d
987 //   }
988 // }
989 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
990                                  func::FuncOp func, uint64_t nx, uint64_t ny,
991                                  bool isCoo, uint32_t nTrailingP) {
992   // Stable sort function doesn't use trailing parameters.
993   (void)nTrailingP;
994   assert(nTrailingP == 0);
995   OpBuilder::InsertionGuard insertionGuard(builder);
996   Block *entryBlock = func.addEntryBlock();
997   builder.setInsertionPointToStart(entryBlock);
998 
999   MLIRContext *context = module.getContext();
1000   Location loc = func.getLoc();
1001   ValueRange args = entryBlock->getArguments();
1002   Value c1 = constantIndex(builder, loc, 1);
1003   Value lo = args[loIdx];
1004   Value hi = args[hiIdx];
1005   Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
1006 
1007   // Start the outer for-stmt with induction variable i.
1008   scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
1009   builder.setInsertionPointToStart(forOpI.getBody());
1010   Value i = forOpI.getInductionVar();
1011 
1012   // Binary search to find the insertion point p.
1013   SmallVector<Value> operands{lo, i};
1014   operands.append(args.begin() + xStartIdx, args.end());
1015   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
1016       builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
1017       ny, isCoo, operands, createBinarySearchFunc);
1018   Value p = builder
1019                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
1020                                       operands)
1021                 .getResult(0);
1022 
1023   // Move the value at data[i] to a temporary location.
1024   operands[0] = operands[1] = i;
1025   SmallVector<Value> d;
1026   forEachIJPairInAllBuffers(
1027       builder, loc, operands, nx, ny, isCoo,
1028       [&](uint64_t unused, Value i, Value unused2, Value buffer) {
1029         d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
1030       });
1031 
1032   // Start the inner for-stmt with induction variable j, for moving data[p..i)
1033   // to data[p+1..i+1).
1034   Value imp = builder.create<arith::SubIOp>(loc, i, p);
1035   Value c0 = constantIndex(builder, loc, 0);
1036   scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
1037   builder.setInsertionPointToStart(forOpJ.getBody());
1038   Value j = forOpJ.getInductionVar();
1039   Value imj = builder.create<arith::SubIOp>(loc, i, j);
1040   operands[1] = imj;
1041   operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
1042   forEachIJPairInAllBuffers(
1043       builder, loc, operands, nx, ny, isCoo,
1044       [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
1045         Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
1046         builder.create<memref::StoreOp>(loc, t, buffer, imj);
1047       });
1048 
1049   // Store the value at data[i] to data[p].
1050   builder.setInsertionPointAfter(forOpJ);
1051   operands[0] = operands[1] = p;
1052   forEachIJPairInAllBuffers(
1053       builder, loc, operands, nx, ny, isCoo,
1054       [&](uint64_t k, Value p, Value usused, Value buffer) {
1055         builder.create<memref::StoreOp>(loc, d[k], buffer, p);
1056       });
1057 
1058   builder.setInsertionPointAfter(forOpI);
1059   builder.create<func::ReturnOp>(loc);
1060 }
1061 
1062 /// Creates a function to perform quick sort or a hybrid quick sort on the
1063 /// values in the range of index [lo, hi).
1064 //
1065 //
1066 // When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
1067 // void quickSort(lo, hi, data) {
1068 //   while (lo + 1 < hi) {
1069 //        p = partition(low, high, data);
1070 //        if (len(lo, p) < len(p+1, hi)) {
1071 //          quickSort(lo, p, data);
1072 //          lo = p+1;
1073 //        } else {
1074 //          quickSort(p + 1, hi, data);
1075 //          hi = p;
1076 //        }
1077 //   }
1078 // }
1079 //
1080 // When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
1081 // void hybridQuickSort(lo, hi, data, depthLimit) {
1082 //   while (lo + 1 < hi) {
1083 //     len = hi - lo;
1084 //     if (len <= limit) {
1085 //       insertionSort(lo, hi, data);
1086 //     } else {
1087 //       depthLimit --;
1088 //       if (depthLimit <= 0) {
1089 //         heapSort(lo, hi, data);
1090 //       } else {
1091 //          p = partition(low, high, data);
1092 //          if (len(lo, p) < len(p+1, hi)) {
1093 //            quickSort(lo, p, data, depthLimit);
1094 //            lo = p+1;
1095 //          } else {
1096 //            quickSort(p + 1, hi, data, depthLimit);
1097 //            hi = p;
1098 //          }
1099 //       }
1100 //     }
1101 //   }
1102 // }
1103 //
1104 static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
1105                                 func::FuncOp func, uint64_t nx, uint64_t ny,
1106                                 bool isCoo, uint32_t nTrailingP) {
1107   assert(nTrailingP == 1 || nTrailingP == 0);
1108   bool isHybrid = (nTrailingP == 1);
1109   OpBuilder::InsertionGuard insertionGuard(builder);
1110   Block *entryBlock = func.addEntryBlock();
1111   builder.setInsertionPointToStart(entryBlock);
1112 
1113   Location loc = func.getLoc();
1114   SmallVector<Value> args;
1115   args.append(entryBlock->getArguments().begin(),
1116               entryBlock->getArguments().end());
1117   Value lo = args[loIdx];
1118   Value hi = args[hiIdx];
1119   SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
1120   scf::WhileOp whileOp =
1121       builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});
1122 
1123   // The before-region of the WhileOp.
1124   Block *before =
1125       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
1126   builder.setInsertionPointToEnd(before);
1127   lo = before->getArgument(0);
1128   hi = before->getArgument(1);
1129   Value loP1 =
1130       builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
1131   Value needSort =
1132       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
1133   builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());
1134 
1135   // The after-region of the WhileOp.
1136   Block *after =
1137       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
1138   builder.setInsertionPointToEnd(after);
1139   lo = after->getArgument(0);
1140   hi = after->getArgument(1);
1141   args[0] = lo;
1142   args[1] = hi;
1143 
1144   if (isHybrid) {
1145     Value len = builder.create<arith::SubIOp>(loc, hi, lo);
1146     Value lenLimit = constantIndex(builder, loc, 30);
1147     Value lenCond = builder.create<arith::CmpIOp>(
1148         loc, arith::CmpIPredicate::ule, len, lenLimit);
1149     scf::IfOp lenIf =
1150         builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);
1151 
1152     // When len <= limit.
1153     builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
1154     FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
1155         builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
1156         ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
1157     builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
1158                                  ValueRange(args).drop_back(nTrailingP));
1159     builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1160 
1161     // When len > limit.
1162     builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
1163     Value depthLimit = args.back();
1164     depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
1165                                                constantI64(builder, loc, 1));
1166     Value depthCond =
1167         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1168                                       depthLimit, constantI64(builder, loc, 0));
1169     scf::IfOp depthIf =
1170         builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);
1171 
1172     // When depth exceeds limit.
1173     builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
1174     FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
1175         builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
1176         ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
1177     builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
1178                                  ValueRange(args).drop_back(nTrailingP));
1179     builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1180 
1181     // When depth doesn't exceed limit.
1182     builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
1183     args.back() = depthLimit;
1184     std::tie(lo, hi) =
1185         createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
1186     builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1187 
1188     builder.setInsertionPointAfter(depthIf);
1189     lo = depthIf.getResult(0);
1190     hi = depthIf.getResult(1);
1191     builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1192 
1193     builder.setInsertionPointAfter(lenIf);
1194     lo = lenIf.getResult(0);
1195     hi = lenIf.getResult(1);
1196   } else {
1197     std::tie(lo, hi) =
1198         createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
1199   }
1200 
1201   // New [lo, hi) for the next while-loop iteration.
1202   builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1203 
1204   // After the while-loop.
1205   builder.setInsertionPointAfter(whileOp);
1206   builder.create<func::ReturnOp>(loc);
1207 }
1208 
1209 /// Implements the rewriting for operator sort and sort_coo.
1210 template <typename OpTy>
1211 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
1212                                     uint64_t ny, bool isCoo,
1213                                     PatternRewriter &rewriter) {
1214   Location loc = op.getLoc();
1215   SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
1216 
1217   // Convert `values` to have dynamic shape and append them to `operands`.
1218   for (Value v : xys) {
1219     auto mtp = getMemRefType(v);
1220     if (!mtp.isDynamicDim(0)) {
1221       auto newMtp =
1222           MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
1223       v = rewriter.create<memref::CastOp>(loc, newMtp, v);
1224     }
1225     operands.push_back(v);
1226   }
1227 
1228   auto insertPoint = op->template getParentOfType<func::FuncOp>();
1229   if (!insertPoint)
1230     return failure();
1231 
1232   SmallString<32> funcName;
1233   FuncGeneratorType funcGenerator;
1234   uint32_t nTrailingP = 0;
1235   switch (op.getAlgorithm()) {
1236   case SparseTensorSortKind::HybridQuickSort: {
1237     funcName = kHybridQuickSortFuncNamePrefix;
1238     funcGenerator = createQuickSortFunc;
1239     nTrailingP = 1;
1240     // As a heuristics, set depthLimit = 2 * log2(n).
1241     Value lo = operands[loIdx];
1242     Value hi = operands[hiIdx];
1243     Value len = rewriter.create<arith::IndexCastOp>(
1244         loc, rewriter.getI64Type(),
1245         rewriter.create<arith::SubIOp>(loc, hi, lo));
1246     Value depthLimit = rewriter.create<arith::SubIOp>(
1247         loc, constantI64(rewriter, loc, 64),
1248         rewriter.create<math::CountLeadingZerosOp>(loc, len));
1249     operands.push_back(depthLimit);
1250     break;
1251   }
1252   case SparseTensorSortKind::QuickSort:
1253     funcName = kQuickSortFuncNamePrefix;
1254     funcGenerator = createQuickSortFunc;
1255     break;
1256   case SparseTensorSortKind::InsertionSortStable:
1257     funcName = kSortStableFuncNamePrefix;
1258     funcGenerator = createSortStableFunc;
1259     break;
1260   case SparseTensorSortKind::HeapSort:
1261     funcName = kHeapSortFuncNamePrefix;
1262     funcGenerator = createHeapSortFunc;
1263     break;
1264   }
1265 
1266   FlatSymbolRefAttr func =
1267       getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
1268                                ny, isCoo, operands, funcGenerator, nTrailingP);
1269   rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
1270   return success();
1271 }
1272 
1273 //===---------------------------------------------------------------------===//
1274 // The actual sparse buffer rewriting rules.
1275 //===---------------------------------------------------------------------===//
1276 
1277 namespace {
1278 
1279 /// Sparse rewriting rule for the push_back operator.
1280 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
1281 public:
1282   using OpRewritePattern<PushBackOp>::OpRewritePattern;
1283   PushBackRewriter(MLIRContext *context, bool enableInit)
1284       : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
1285   LogicalResult matchAndRewrite(PushBackOp op,
1286                                 PatternRewriter &rewriter) const override {
1287     // Rewrite push_back(buffer, value, n) to:
1288     // new_size = size(buffer) + n
1289     // if (new_size > capacity(buffer))
1290     //    while new_size > new_capacity
1291     //      new_capacity = new_capacity*2
1292     //    new_buffer = realloc(buffer, new_capacity)
1293     // buffer = new_buffer
1294     // subBuffer = subviewof(buffer)
1295     // linalg.fill subBuffer value
1296     //
1297     // size(buffer) += n
1298     //
1299     // The capacity check is skipped when the attribute inbounds is presented.
1300     Location loc = op->getLoc();
1301     Value c0 = constantIndex(rewriter, loc, 0);
1302     Value buffer = op.getInBuffer();
1303     Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
1304     Value size = op.getCurSize();
1305     Value value = op.getValue();
1306 
1307     Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
1308     Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
1309     auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
1310     bool nIsOne = (nValue && nValue.value() == 1);
1311 
1312     if (!op.getInbounds()) {
1313       Value cond = rewriter.create<arith::CmpIOp>(
1314           loc, arith::CmpIPredicate::ugt, newSize, capacity);
1315 
1316       Value c2 = constantIndex(rewriter, loc, 2);
1317       auto bufferType =
1318           MemRefType::get({ShapedType::kDynamic}, value.getType());
1319       scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
1320                                                   /*else=*/true);
1321       // True branch.
1322       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1323       if (nIsOne) {
1324         capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
1325       } else {
1326         // Use a do-while loop to calculate the new capacity as follows:
1327         //   do { new_capacity *= 2 } while (size > new_capacity)
1328         scf::WhileOp whileOp =
1329             rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
1330 
1331         // The before-region of the WhileOp.
1332         Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
1333                                              {capacity.getType()}, {loc});
1334         rewriter.setInsertionPointToEnd(before);
1335 
1336         capacity =
1337             rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
1338         cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
1339                                               newSize, capacity);
1340         rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
1341         // The after-region of the WhileOp.
1342         Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
1343                                             {capacity.getType()}, {loc});
1344         rewriter.setInsertionPointToEnd(after);
1345         rewriter.create<scf::YieldOp>(loc, after->getArguments());
1346 
1347         rewriter.setInsertionPointAfter(whileOp);
1348         capacity = whileOp.getResult(0);
1349       }
1350 
1351       Value newBuffer =
1352           rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
1353       if (enableBufferInitialization) {
1354         Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
1355         Value fillValue = constantZero(rewriter, loc, value.getType());
1356         Value subBuffer = rewriter.create<memref::SubViewOp>(
1357             loc, newBuffer, /*offset=*/ValueRange{newSize},
1358             /*size=*/ValueRange{fillSize},
1359             /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1360         rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
1361       }
1362       rewriter.create<scf::YieldOp>(loc, newBuffer);
1363 
1364       // False branch.
1365       rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1366       rewriter.create<scf::YieldOp>(loc, buffer);
1367 
1368       // Prepare for adding the value to the end of the buffer.
1369       rewriter.setInsertionPointAfter(ifOp);
1370       buffer = ifOp.getResult(0);
1371     }
1372 
1373     // Add the value to the end of the buffer.
1374     if (nIsOne) {
1375       rewriter.create<memref::StoreOp>(loc, value, buffer, size);
1376     } else {
1377       Value subBuffer = rewriter.create<memref::SubViewOp>(
1378           loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
1379           /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1380       rewriter.create<linalg::FillOp>(loc, value, subBuffer);
1381     }
1382 
1383     // Update the buffer size.
1384     rewriter.replaceOp(op, {buffer, newSize});
1385     return success();
1386   }
1387 
1388 private:
1389   bool enableBufferInitialization;
1390 };
1391 
1392 /// Sparse rewriting rule for the sort operator.
1393 struct SortRewriter : public OpRewritePattern<SortOp> {
1394 public:
1395   using OpRewritePattern<SortOp>::OpRewritePattern;
1396 
1397   LogicalResult matchAndRewrite(SortOp op,
1398                                 PatternRewriter &rewriter) const override {
1399     SmallVector<Value> xys(op.getXs());
1400     xys.append(op.getYs().begin(), op.getYs().end());
1401     return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
1402                                  /*isCoo=*/false, rewriter);
1403   }
1404 };
1405 
1406 /// Sparse rewriting rule for the sort_coo operator.
1407 struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
1408 public:
1409   using OpRewritePattern<SortCooOp>::OpRewritePattern;
1410 
1411   LogicalResult matchAndRewrite(SortCooOp op,
1412                                 PatternRewriter &rewriter) const override {
1413     SmallVector<Value> xys;
1414     xys.push_back(op.getXy());
1415     xys.append(op.getYs().begin(), op.getYs().end());
1416     uint64_t nx = 1;
1417     if (auto nxAttr = op.getNxAttr())
1418       nx = nxAttr.getInt();
1419 
1420     uint64_t ny = 0;
1421     if (auto nyAttr = op.getNyAttr())
1422       ny = nyAttr.getInt();
1423 
1424     return matchAndRewriteSortOp(op, xys, nx, ny,
1425                                  /*isCoo=*/true, rewriter);
1426   }
1427 };
1428 
1429 } // namespace
1430 
1431 //===---------------------------------------------------------------------===//
1432 // Methods that add patterns described in this file to a pattern list.
1433 //===---------------------------------------------------------------------===//
1434 
1435 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
1436                                          bool enableBufferInitialization) {
1437   patterns.add<PushBackRewriter>(patterns.getContext(),
1438                                  enableBufferInitialization);
1439   patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
1440 }
1441