xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp (revision 4176ce61f156c8daa1e6bf8cc86d5e61e9413149)
1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensor
10 // primitives with memref operands.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodegenUtils.h"
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Math/IR/Math.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
23 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
24 #include "mlir/Support/LLVM.h"
25 
26 using namespace mlir;
27 using namespace mlir::sparse_tensor;
28 
29 //===---------------------------------------------------------------------===//
30 // Helper methods for the actual rewriting rules.
31 //===---------------------------------------------------------------------===//
32 
// Positions of the leading operands shared by the generated sort helper
// functions: operand 0 is `lo`, operand 1 is `hi`, and the buffers being
// sorted start at operand 2.
static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;

// Name prefixes used when mangling the names of the generated helper
// functions; getMangledSortHelperFuncName appends the operand information.
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
    "_sparse_binary_search_";
static constexpr const char kHybridQuickSortFuncNamePrefix[] =
    "_sparse_hybrid_qsort_";
static constexpr const char kSortStableFuncNamePrefix[] =
    "_sparse_sort_stable_";
static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";

// Signature of the callbacks that generate the body of a helper function.
// getMangledSortHelperFunc invokes them as
// (builder, module, func, nx, ny, isCoo, nTrailingP).
using FuncGeneratorType = function_ref<void(
    OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
50 
51 /// Constructs a function name with this format to facilitate quick sort:
52 ///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
53 ///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
54 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
55                                          StringRef namePrefix, uint64_t nx,
56                                          uint64_t ny, bool isCoo,
57                                          ValueRange operands) {
58   nameOstream << namePrefix << nx << "_"
59               << getMemRefType(operands[xStartIdx]).getElementType();
60 
61   if (isCoo)
62     nameOstream << "_coo_" << ny;
63 
64   uint64_t yBufferOffset = isCoo ? 1 : nx;
65   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
66     nameOstream << "_" << getMemRefType(v).getElementType();
67 }
68 
69 /// Looks up a function that is appropriate for the given operands being
70 /// sorted, and creates such a function if it doesn't exist yet. The
71 /// parameters `nx` and `ny` tell the number of x and y values provided
72 /// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
73 /// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
74 //
75 // All sorting function generators take (lo, hi, xs, ys) in `operands` as
76 // parameters for the sorting functions. Other parameters, such as the recursive
77 // call depth, are appended to the end of the parameter list as
78 // "trailing parameters".
79 static FlatSymbolRefAttr
80 getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
81                          TypeRange resultTypes, StringRef namePrefix,
82                          uint64_t nx, uint64_t ny, bool isCoo,
83                          ValueRange operands, FuncGeneratorType createFunc,
84                          uint32_t nTrailingP = 0) {
85   SmallString<32> nameBuffer;
86   llvm::raw_svector_ostream nameOstream(nameBuffer);
87   getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
88                                operands.drop_back(nTrailingP));
89 
90   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
91   MLIRContext *context = module.getContext();
92   auto result = SymbolRefAttr::get(context, nameOstream.str());
93   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
94 
95   if (!func) {
96     // Create the function.
97     OpBuilder::InsertionGuard insertionGuard(builder);
98     builder.setInsertionPoint(insertPoint);
99     Location loc = insertPoint.getLoc();
100     func = builder.create<func::FuncOp>(
101         loc, nameOstream.str(),
102         FunctionType::get(context, operands.getTypes(), resultTypes));
103     func.setPrivate();
104     createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
105   }
106 
107   return result;
108 }
109 
110 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
111 /// The code to process the value pairs is generated by `bodyBuilder`.
112 static void forEachIJPairInXs(
113     OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
114     bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
115   Value iOffset, jOffset;
116   if (isCoo) {
117     Value cstep = constantIndex(builder, loc, nx + ny);
118     iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
119     jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
120   }
121   for (uint64_t k = 0; k < nx; k++) {
122     scf::IfOp ifOp;
123     Value i, j, buffer;
124     if (isCoo) {
125       Value ck = constantIndex(builder, loc, k);
126       i = builder.create<arith::AddIOp>(loc, ck, iOffset);
127       j = builder.create<arith::AddIOp>(loc, ck, jOffset);
128       buffer = args[xStartIdx];
129     } else {
130       i = args[0];
131       j = args[1];
132       buffer = args[xStartIdx + k];
133     }
134     bodyBuilder(k, i, j, buffer);
135   }
136 }
137 
138 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
139 /// The code to process the value pairs is generated by `bodyBuilder`.
140 static void forEachIJPairInAllBuffers(
141     OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
142     bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
143 
144   // Create code for the first (nx + ny) buffers. When isCoo==true, these
145   // logical buffers are all from the xy buffer of the sort_coo operator.
146   forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
147 
148   uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
149 
150   // Create code for the remaining buffers.
151   Value i = args[0];
152   Value j = args[1];
153   for (const auto &arg :
154        llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
155     bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
156   }
157 }
158 
159 /// Creates a code block for swapping the values in index i and j for all the
160 /// buffers.
161 //
162 // The generated IR corresponds to this C like algorithm:
163 //     swap(x0[i], x0[j]);
164 //     swap(x1[i], x1[j]);
165 //     ...
166 //     swap(xn[i], xn[j]);
167 //     swap(y0[i], y0[j]);
168 //     ...
169 //     swap(yn[i], yn[j]);
170 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
171                        uint64_t nx, uint64_t ny, bool isCoo) {
172   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
173     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
174     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
175     builder.create<memref::StoreOp>(loc, vj, buffer, i);
176     builder.create<memref::StoreOp>(loc, vi, buffer, j);
177   };
178 
179   forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
180 }
181 
182 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
183 /// each pair is create via `compareBuilder`.
184 static Value createInlinedCompareImplementation(
185     OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
186     bool isCoo,
187     function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
188         compareBuilder) {
189   Value result;
190   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
191     bool isFirstDim = (k == 0);
192     bool isLastDim = (k == nx - 1);
193     Value val =
194         compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
195     if (isFirstDim) {
196       result = val;
197     } else if (!isLastDim) {
198       OpBuilder::InsertionGuard insertionGuard(builder);
199       auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
200       builder.setInsertionPointAfter(ifOp);
201       builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
202     }
203   };
204 
205   forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
206 
207   builder.setInsertionPointAfterValue(result);
208   return result;
209 }
210 
211 /// Generates code to compare whether x[i] is equal to x[j] and returns the
212 /// result of the comparison.
213 static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
214                              Value x, bool isFirstDim, bool isLastDim) {
215   Value vi = builder.create<memref::LoadOp>(loc, x, i);
216   Value vj = builder.create<memref::LoadOp>(loc, x, j);
217 
218   Value res;
219   if (isLastDim) {
220     res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
221     // For 1D, we create a compare without any control flow. Otherwise, we
222     // create YieldOp to return the result in the nested if-stmt.
223     if (!isFirstDim)
224       builder.create<scf::YieldOp>(loc, res);
225   } else {
226     Value ne =
227         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
228     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
229                                                ne, /*else=*/true);
230     // If (x[i] != x[j]).
231     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
232     Value f = constantI1(builder, loc, false);
233     builder.create<scf::YieldOp>(loc, f);
234 
235     // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
236     // checks the remaining dimensions.
237     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
238     res = ifOp.getResult(0);
239   }
240 
241   return res;
242 }
243 
244 /// Creates code to compare whether xs[i] is equal to xs[j].
245 //
246 // The generate IR corresponds to this C like algorithm:
247 //   if (x0[i] != x0[j])
248 //     return false;
249 //   else
250 //     if (x1[i] != x1[j])
251 //       return false;
252 //     else if (x2[2] != x2[j]))
253 //       and so on ...
254 static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
255                                     ValueRange args, uint64_t nx, uint64_t ny,
256                                     bool isCoo, uint32_t nTrailingP = 0) {
257   // Compare functions don't use trailing parameters.
258   (void)nTrailingP;
259   assert(nTrailingP == 0);
260   return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
261                                             createEqCompare);
262 }
263 
264 /// Generates code to compare whether x[i] is less than x[j] and returns the
265 /// result of the comparison.
266 static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
267                                    Value j, Value x, bool isFirstDim,
268                                    bool isLastDim) {
269   Value vi = builder.create<memref::LoadOp>(loc, x, i);
270   Value vj = builder.create<memref::LoadOp>(loc, x, j);
271 
272   Value res;
273   if (isLastDim) {
274     res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
275     // For 1D, we create a compare without any control flow. Otherwise, we
276     // create YieldOp to return the result in the nested if-stmt.
277     if (!isFirstDim)
278       builder.create<scf::YieldOp>(loc, res);
279   } else {
280     Value ne =
281         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
282     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
283                                                ne, /*else=*/true);
284     // If (x[i] != x[j]).
285     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
286     Value lt =
287         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
288     builder.create<scf::YieldOp>(loc, lt);
289 
290     // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
291     // checks the remaining dimensions.
292     builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
293     res = ifOp.getResult(0);
294   }
295 
296   return res;
297 }
298 
299 /// Creates code to compare whether xs[i] is less than xs[j].
300 //
301 // The generate IR corresponds to this C like algorithm:
302 //   if (x0[i] != x0[j])
303 //     return x0[i] < x0[j];
304 //   else if (x1[j] != x1[i])
305 //     return x1[i] < x1[j];
306 //   else
307 //       and so on ...
308 static Value createInlinedLessThan(OpBuilder &builder, Location loc,
309                                    ValueRange args, uint64_t nx, uint64_t ny,
310                                    bool isCoo, uint32_t nTrailingP = 0) {
311   // Compare functions don't use trailing parameters.
312   (void)nTrailingP;
313   assert(nTrailingP == 0);
314   return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
315                                             createLessThanCompare);
316 }
317 
318 /// Creates a function to use a binary search to find the insertion point for
319 /// inserting xs[hi] to the sorted values xs[lo..hi).
320 //
321 // The generate IR corresponds to this C like algorithm:
322 //   p = hi
323 //   while (lo < hi)
324 //      mid = (lo + hi) >> 1
325 //      if (xs[p] < xs[mid])
326 //        hi = mid
327 //      else
328 //        lo = mid - 1
329 //   return lo;
330 //
331 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
332                                    func::FuncOp func, uint64_t nx, uint64_t ny,
333                                    bool isCoo, uint32_t nTrailingP = 0) {
334   // Binary search doesn't use trailing parameters.
335   (void)nTrailingP;
336   assert(nTrailingP == 0);
337   OpBuilder::InsertionGuard insertionGuard(builder);
338   Block *entryBlock = func.addEntryBlock();
339   builder.setInsertionPointToStart(entryBlock);
340 
341   Location loc = func.getLoc();
342   ValueRange args = entryBlock->getArguments();
343   Value p = args[hiIdx];
344   SmallVector<Type, 2> types(2, p.getType()); // Only two types.
345   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
346       loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
347 
348   // The before-region of the WhileOp.
349   Block *before =
350       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
351   builder.setInsertionPointToEnd(before);
352   Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
353                                               before->getArgument(0),
354                                               before->getArgument(1));
355   builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
356 
357   // The after-region of the WhileOp.
358   Block *after =
359       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
360   builder.setInsertionPointToEnd(after);
361   Value lo = after->getArgument(0);
362   Value hi = after->getArgument(1);
363   // Compute mid = (lo + hi) >> 1.
364   Value c1 = constantIndex(builder, loc, 1);
365   Value mid = builder.create<arith::ShRUIOp>(
366       loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
367   Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
368 
369   // Compare xs[p] < xs[mid].
370   SmallVector<Value> compareOperands{p, mid};
371   uint64_t numXBuffers = isCoo ? 1 : nx;
372   compareOperands.append(args.begin() + xStartIdx,
373                          args.begin() + xStartIdx + numXBuffers);
374   Value cond2 =
375       createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
376   // Update lo and hi for the WhileOp as follows:
377   //   if (xs[p] < xs[mid]))
378   //     hi = mid;
379   //   else
380   //     lo = mid + 1;
381   Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
382   Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
383   builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
384 
385   builder.setInsertionPointAfter(whileOp);
386   builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
387 }
388 
389 /// Creates code to advance i in a loop based on xs[p] as follows:
390 ///   while (xs[i] < xs[p]) i += step (step > 0)
391 /// or
392 ///   while (xs[i] > xs[p]) i += step (step < 0)
393 /// The routine returns i as well as a boolean value to indicate whether
394 /// xs[i] == xs[p].
395 static std::pair<Value, Value>
396 createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
397                ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
398                bool isCoo, int step) {
399   Location loc = func.getLoc();
400   scf::WhileOp whileOp =
401       builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
402 
403   Block *before =
404       builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
405   builder.setInsertionPointToEnd(before);
406   SmallVector<Value> compareOperands;
407   if (step > 0) {
408     compareOperands.push_back(before->getArgument(0));
409     compareOperands.push_back(p);
410   } else {
411     assert(step < 0);
412     compareOperands.push_back(p);
413     compareOperands.push_back(before->getArgument(0));
414   }
415   compareOperands.append(xs.begin(), xs.end());
416   Value cond =
417       createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
418   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
419 
420   Block *after =
421       builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
422   builder.setInsertionPointToEnd(after);
423   Value cs = constantIndex(builder, loc, step);
424   i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
425   builder.create<scf::YieldOp>(loc, ValueRange{i});
426   i = whileOp.getResult(0);
427 
428   builder.setInsertionPointAfter(whileOp);
429   compareOperands[0] = i;
430   compareOperands[1] = p;
431   Value compareEq =
432       createInlinedEqCompare(builder, loc, compareOperands, nx, ny, isCoo);
433 
434   return std::make_pair(whileOp.getResult(0), compareEq);
435 }
436 
437 /// Creates and returns an IfOp to compare two elements and swap the elements
438 /// if compareFunc(data[b], data[a]) returns true. The new insertion point is
439 /// right after the swap instructions.
440 static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
441                                        uint64_t nx, uint64_t ny, bool isCoo,
442                                        SmallVectorImpl<Value> &swapOperands,
443                                        SmallVectorImpl<Value> &compareOperands,
444                                        Value a, Value b) {
445   // Compare(data[b], data[a]).
446   compareOperands[0] = b;
447   compareOperands[1] = a;
448   Value cond =
449       createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
450   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
451   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
452   swapOperands[0] = b;
453   swapOperands[1] = a;
454   createSwap(builder, loc, swapOperands, nx, ny, isCoo);
455   return ifOp;
456 }
457 
458 /// Creates code to insert the 3rd element to a list of two sorted elements.
459 static void createInsert3rd(OpBuilder &builder, Location loc, uint64_t nx,
460                             uint64_t ny, bool isCoo,
461                             SmallVectorImpl<Value> &swapOperands,
462                             SmallVectorImpl<Value> &compareOperands, Value v0,
463                             Value v1, Value v2) {
464   scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo,
465                                          swapOperands, compareOperands, v1, v2);
466   createCompareThenSwap(builder, loc, nx, ny, isCoo, swapOperands,
467                         compareOperands, v0, v1);
468   builder.setInsertionPointAfter(ifOp);
469 }
470 
471 /// Creates code to sort 3 elements.
472 static void createSort3(OpBuilder &builder, Location loc, uint64_t nx,
473                         uint64_t ny, bool isCoo,
474                         SmallVectorImpl<Value> &swapOperands,
475                         SmallVectorImpl<Value> &compareOperands, Value v0,
476                         Value v1, Value v2) {
477   // Sort the first 2 elements.
478   scf::IfOp ifOp1 = createCompareThenSwap(
479       builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0, v1);
480   builder.setInsertionPointAfter(ifOp1);
481 
482   // Insert the 3th element.
483   createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
484                   v0, v1, v2);
485 }
486 
487 /// Creates code to sort 5 elements.
488 static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
489                         uint64_t ny, bool isCoo,
490                         SmallVectorImpl<Value> &swapOperands,
491                         SmallVectorImpl<Value> &compareOperands, Value v0,
492                         Value v1, Value v2, Value v3, Value v4) {
493   // Sort the first 3 elements.
494   createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0,
495               v1, v2);
496 
497   auto insert4th = [&]() {
498     scf::IfOp ifOp = createCompareThenSwap(
499         builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v2, v3);
500     createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
501                     v0, v1, v2);
502     builder.setInsertionPointAfter(ifOp);
503   };
504 
505   // Insert the 4th element.
506   insert4th();
507 
508   // Insert the 5th element.
509   scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo,
510                                          swapOperands, compareOperands, v3, v4);
511   insert4th();
512   builder.setInsertionPointAfter(ifOp);
513 }
514 
515 /// Creates a code block to swap the values in indices lo, mi, and hi so that
516 /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
517 /// the number of values in range [lo, hi) is more than a threshold, we also
518 /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
519 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
520                               func::FuncOp func, uint64_t nx, uint64_t ny,
521                               bool isCoo, Value lo, Value hi, Value mi,
522                               ValueRange args) {
523   SmallVector<Value> compareOperands{mi, lo};
524   uint64_t numXBuffers = isCoo ? 1 : nx;
525   compareOperands.append(args.begin() + xStartIdx,
526                          args.begin() + xStartIdx + numXBuffers);
527   SmallVector<Value> swapOperands{mi, lo};
528   swapOperands.append(args.begin() + xStartIdx, args.end());
529   Location loc = func.getLoc();
530   Value c1 = constantIndex(builder, loc, 1);
531   Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1);
532   Value len = builder.create<arith::SubIOp>(loc, hiP1, lo);
533   Value lenThreshold = constantIndex(builder, loc, 1000);
534   Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
535                                                 len, lenThreshold);
536   scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
537 
538   // When len < 1000, choose pivot from median of 3 values.
539   builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
540   createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo,
541               mi, hi);
542 
543   // When len >= 1000, choose pivot from median of 5 values.
544   builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
545   Value miP1 = builder.create<arith::AddIOp>(loc, hi, c1);
546   Value a = builder.create<arith::AddIOp>(loc, lo, miP1);
547   // Value a is the middle between [loc, mi].
548   a = builder.create<arith::ShRUIOp>(loc, a, c1);
549   Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
550   // Value b is the middle between [mi, hi].
551   b = builder.create<arith::ShRUIOp>(loc, b, c1);
552   createSort5(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo, a,
553               mi, b, hi);
554 
555   builder.setInsertionPointAfter(lenIf);
556 }
557 
/// Creates a function to perform quick sort partition on the values in the
/// range of index [lo, hi), assuming lo < hi. The body is generated into
/// `func`, whose arguments are (lo, hi, buffers...). The generated function
/// returns the third value carried by its main loop, which is set to j + 1
/// when the loop exits.
//
// The generated IR corresponds to this C like algorithm:
// int partition(lo, hi, xs) {
//   p = (lo+hi)/2  // pivot index
//   i = lo
//   j = hi-1
//   while (true) do {
//     while (xs[i] < xs[p]) i ++;
//     i_eq = (xs[i] == xs[p]);
//     while (xs[j] > xs[p]) j --;
//     j_eq = (xs[j] == xs[p]);
//
//     if (i >= j) return j + 1;
//
//     if (i < j) {
//       swap(xs[i], xs[j])
//       if (i == p) {
//         p = j;
//       } else if (j == p) {
//         p = i;
//       }
//       if (i_eq && j_eq) {
//         ++i;
//         --j;
//       }
//     }
//   }
// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo, uint32_t nTrailingP = 0) {
  // Quick sort partition doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  // Restore the caller's insertion point when this function returns.
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  // Pivot index p = (lo + hi) >> 1.
  Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
  Value c1 = constantIndex(builder, loc, 1);
  Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);

  // Scan indices: i starts at lo, j starts at hi - 1.
  Value i = lo;
  Value j = builder.create<arith::SubIOp>(loc, hi, c1);
  // Order data[i], data[p] and data[j] before partitioning.
  createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
  Value trueVal = constantI1(builder, loc, true); // The value for while (true)
  // The loop carries (i, j, p, continue-flag).
  SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
  SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
                             trueVal.getType()};
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

  // The before-region of the WhileOp: forward all values and loop as long as
  // the carried continue-flag is true.
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
                                      {loc, loc, loc, loc});
  builder.setInsertionPointToEnd(before);
  builder.create<scf::ConditionOp>(loc, before->getArgument(3),
                                   before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
  builder.setInsertionPointToEnd(after);
  i = after->getArgument(0);
  j = after->getArgument(1);
  p = after->getArgument(2);

  // Advance i upwards and j downwards to the next elements that are not on
  // the correct side of the pivot; also record whether each one equals the
  // pivot. Only the x buffers take part in the comparisons.
  uint64_t numXBuffers = isCoo ? 1 : nx;
  auto [iresult, iCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     i, p, nx, ny, isCoo, 1);
  i = iresult;
  auto [jresult, jCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     j, p, nx, ny, isCoo, -1);
  j = jresult;

  // If i < j:
  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  // Swap element i and element j in all buffers (xs and ys).
  SmallVector<Value> swapOperands{i, j};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  // If the pivot is moved, update p with the new pivot.
  Value icond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
  scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              icond, /*else=*/true);
  // i == p: the pivot now lives at j.
  builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{j});
  // Otherwise, j == p: the pivot now lives at i; else p is unchanged.
  builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
  Value jcond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
  scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              jcond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{p});
  // Forward the result of the inner if-stmt out of the else-region of ifOpI.
  builder.setInsertionPointAfter(ifOpJ);
  builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
  builder.setInsertionPointAfter(ifOpI);
  // When both scanned elements equal the pivot, step i and j past them;
  // otherwise keep i and j as they are.
  Value compareEqIJ =
      builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(
      loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
  Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
  Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
  builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j});
  // Yield the updated (i, j, p) and request another loop iteration.
  builder.setInsertionPointAfter(ifOp2);
  builder.create<scf::YieldOp>(
      loc,
      ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
                 /*cont=*/constantI1(builder, loc, true)});

  // False branch for if i < j (i.e., i >= j): store j + 1 in the slot of p,
  // which becomes the function result, and stop the loop.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  p = builder.create<arith::AddIOp>(loc, j,
                                    constantOne(builder, loc, j.getType()));
  builder.create<scf::YieldOp>(
      loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});

  // Return for the whileOp.
  builder.setInsertionPointAfter(ifOp);
  builder.create<scf::YieldOp>(loc, ifOp.getResults());

  // Return for the function: the third loop-carried value.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
}
699 
700 /// Computes (n-2)/n, assuming n has index type.
701 static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
702                                       Value n) {
703   Value i2 = constantIndex(builder, loc, 2);
704   Value res = builder.create<arith::SubIOp>(loc, n, i2);
705   Value i1 = constantIndex(builder, loc, 1);
706   return builder.create<arith::ShRUIOp>(loc, res, i1);
707 }
708 
/// Creates a function to heapify the subtree with root `start` within the full
/// binary tree in the range of index [first, first + n).
///
/// The body is emitted into the entry block of `func`. The entry block
/// arguments are read as: `first` (slot loIdx), `start` (slot hiIdx), the data
/// buffers (from slot xStartIdx on), and `n` as the single trailing argument.
/// `nx`, `ny` and `isCoo` are forwarded to the compare and swap helpers.
//
// The generated IR corresponds to this C like algorithm:
// void shiftDown(first, start, n, data) {
//   if (n >= 2) {
//     child = start - first
//     if ((n-2)/2 >= child) {
//       // Left child exists.
//       child = child * 2 + 1 // Initialize the bigger child to left child.
//       childIndex = child + first
//       if (child+1 < n && data[childIndex] < data[childIndex+1])
//         // Right child exists and is bigger.
//         childIndex++; child++;
//       // Shift data[start] down to where it belongs in the subtree.
//       while (data[start] < data[childIndex]) {
//         swap(data[start], data[childIndex])
//         start = childIndex
//         if ((n - 2)/2 >= child) {
//           // Left child exists.
//           child = 2*child + 1
//           childIndex = child + first
//           if (child + 1) < n && data[childIndex] < data[childIndex+1]
//             childIndex++; child++;
//         }
//       }
//     }
//   }
// }
//
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
                                bool isCoo, uint32_t nTrailingP) {
  // The value n is passed in as a trailing parameter.
  assert(nTrailingP == 1);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  // Split the entry block arguments: the last one is n, everything before it
  // is (first, start, buffers...).
  Value n = entryBlock->getArguments().back();
  ValueRange args = entryBlock->getArguments().drop_back();
  Value first = args[loIdx];
  Value start = args[hiIdx];

  // If (n >= 2).
  Value c2 = constantIndex(builder, loc, 2);
  Value condN =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
  scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
  builder.setInsertionPointToStart(&ifN.getThenRegion().front());
  // `child` is the position of `start` relative to `first`.
  Value child = builder.create<arith::SubIOp>(loc, start, first);

  // If ((n-2)/2 >= child).
  Value t = createSubTwoDividedByTwo(builder, loc, n);
  Value condNc =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
  scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);

  builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
  Value c1 = constantIndex(builder, loc, 1);
  // Operands for the inlined comparisons: slots 0 and 1 hold the two indices
  // to compare and are overwritten before every comparison; the remaining
  // slots hold the buffers taking part in the comparison (a single buffer
  // when isCoo is set, otherwise the first nx buffers).
  SmallVector<Value> compareOperands{start, start};
  uint64_t numXBuffers = isCoo ? 1 : nx;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);

  // Generate code to inspect the children of 'r' and return the larger child
  // as follows:
  //   child = r * 2 + 1 // Left child.
  //   childIndex = child + first
  //   if (child+1 < n && data[childIndex] < data[childIndex+1])
  //     childIndex ++; child ++ // Right child is bigger.
  //
  // The lambda returns the pair (child, childIndex). It overwrites the first
  // two slots of `compareOperands` and leaves the builder positioned right
  // after the outer scf.if it creates.
  auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
    // r * 2 is emitted as a left shift by one.
    Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
    lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
    Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
    Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
    // The right child exists only when rChild < n.
    Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                rChild, n);
    SmallVector<Type, 2> ifTypes(2, r.getType());
    scf::IfOp if1 =
        builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
    builder.setInsertionPointToStart(&if1.getThenRegion().front());
    Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
    // Compare data[left] < data[right].
    compareOperands[0] = lChildIdx;
    compareOperands[1] = rChildIdx;
    Value cond2 =
        createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
    scf::IfOp if2 =
        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
    // Right child is bigger: yield it.
    builder.setInsertionPointToStart(&if2.getThenRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
    // Otherwise keep the left child.
    builder.setInsertionPointToStart(&if2.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
    builder.setInsertionPointAfter(if2);
    builder.create<scf::YieldOp>(loc, if2.getResults());
    // No right child: the left child is the only candidate.
    builder.setInsertionPointToStart(&if1.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
    builder.setInsertionPointAfter(if1);
    return std::make_pair(if1.getResult(0), if1.getResult(1));
  };

  Value childIdx;
  std::tie(child, childIdx) = getLargerChild(child);

  // While (data[start] < data[childIndex]).
  // The loop carries three values: (start, child, childIndex).
  SmallVector<Type, 3> types(3, child.getType());
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{start, child, childIdx});

  // The before-region of the WhileOp.
  SmallVector<Location, 3> locs(3, loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  builder.setInsertionPointToEnd(before);
  start = before->getArgument(0);
  childIdx = before->getArgument(2);
  compareOperands[0] = start;
  compareOperands[1] = childIdx;
  Value cond =
      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region of the WhileOp.
  // NOTE(review): no explicit insertion point is set for this block;
  // presumably createBlock leaves the builder positioned inside the new
  // block -- confirm against OpBuilder::createBlock.
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
  start = after->getArgument(0);
  child = after->getArgument(1);
  childIdx = after->getArgument(2);
  // Swap data[start] and data[childIndex] in all buffers passed to the
  // function.
  SmallVector<Value> swapOperands{start, childIdx};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
  start = childIdx;
  // If ((n-2)/2 >= child), i.e., the new position still has a left child,
  // pick its larger child for the next iteration.
  Value cond2 =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
  scf::IfOp if2 = builder.create<scf::IfOp>(
      loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
  builder.setInsertionPointToStart(&if2.getThenRegion().front());
  auto [newChild, newChildIdx] = getLargerChild(child);
  builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
  // Otherwise carry (child, childIndex) unchanged. Since `start` now equals
  // childIndex, the loop condition then compares an element with itself,
  // which presumably ends the loop.
  builder.setInsertionPointToStart(&if2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
  builder.setInsertionPointAfter(if2);
  builder.create<scf::YieldOp>(
      loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});

  // Return for the function, placed after the outermost scf.if.
  builder.setInsertionPointAfter(ifN);
  builder.create<func::ReturnOp>(loc);
}
857 
858 /// Creates a function to perform heap sort on the values in the range of index
859 /// [lo, hi) with the assumption hi - lo >= 2.
860 //
861 // The generate IR corresponds to this C like algorithm:
862 // void heapSort(lo, hi, data) {
863 //   n = hi - lo
864 //   for i = (n-2)/2 downto 0
865 //     shiftDown(lo, lo+i, n)
866 //
867 //   for l = n downto 2
868 //      swap(lo, lo+l-1)
869 //      shiftdown(lo, lo, l-1)
870 // }
871 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
872                                func::FuncOp func, uint64_t nx, uint64_t ny,
873                                bool isCoo, uint32_t nTrailingP) {
874   // Heap sort function doesn't have trailing parameters.
875   (void)nTrailingP;
876   assert(nTrailingP == 0);
877   OpBuilder::InsertionGuard insertionGuard(builder);
878   Block *entryBlock = func.addEntryBlock();
879   builder.setInsertionPointToStart(entryBlock);
880 
881   Location loc = func.getLoc();
882   ValueRange args = entryBlock->getArguments();
883   Value lo = args[loIdx];
884   Value hi = args[hiIdx];
885   Value n = builder.create<arith::SubIOp>(loc, hi, lo);
886 
887   // For i = (n-2)/2 downto 0.
888   Value c0 = constantIndex(builder, loc, 0);
889   Value c1 = constantIndex(builder, loc, 1);
890   Value s = createSubTwoDividedByTwo(builder, loc, n);
891   Value up = builder.create<arith::AddIOp>(loc, s, c1);
892   scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
893   builder.setInsertionPointToStart(forI.getBody());
894   Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
895   Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
896   SmallVector<Value> shiftDownOperands = {lo, lopi};
897   shiftDownOperands.append(args.begin() + xStartIdx, args.end());
898   shiftDownOperands.push_back(n);
899   FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
900       builder, func, TypeRange(), kShiftDownFuncNamePrefix, nx, ny, isCoo,
901       shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
902   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
903                                shiftDownOperands);
904 
905   builder.setInsertionPointAfter(forI);
906   // For l = n downto 2.
907   up = builder.create<arith::SubIOp>(loc, n, c1);
908   scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
909   builder.setInsertionPointToStart(forL.getBody());
910   Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
911   Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
912   loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
913   SmallVector<Value> swapOperands{lo, loplm1};
914   swapOperands.append(args.begin() + xStartIdx, args.end());
915   createSwap(builder, loc, swapOperands, nx, ny, isCoo);
916   shiftDownOperands[1] = lo;
917   shiftDownOperands[shiftDownOperands.size() - 1] =
918       builder.create<arith::SubIOp>(loc, l, c1);
919   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
920                                shiftDownOperands);
921 
922   builder.setInsertionPointAfter(forL);
923   builder.create<func::ReturnOp>(loc);
924 }
925 
926 /// A helper for generating code to perform quick sort. It partitions [lo, hi),
927 /// recursively calls quick sort to process the smaller partition and returns
928 /// the bigger partition to be processed by the enclosed while-loop.
929 static std::pair<Value, Value>
930 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
931                 ValueRange args, uint64_t nx, uint64_t ny, bool isCoo,
932                 uint32_t nTrailingP) {
933   MLIRContext *context = module.getContext();
934   Location loc = func.getLoc();
935   Value lo = args[loIdx];
936   Value hi = args[hiIdx];
937   SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
938 
939   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
940       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
941       ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
942   Value p = builder
943                 .create<func::CallOp>(loc, partitionFunc,
944                                       TypeRange{IndexType::get(context)},
945                                       args.drop_back(nTrailingP))
946                 .getResult(0);
947 
948   Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
949   Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
950   // Partition already sorts array with len <= 2
951   Value c2 = constantIndex(builder, loc, 2);
952   Value len = builder.create<arith::SubIOp>(loc, hi, lo);
953   Value lenGtTwo =
954       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
955   scf::IfOp ifLenGtTwo =
956       builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
957   builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
958   // Returns an empty range to mark the entire region is fully sorted.
959   builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
960 
961   // Else len > 2, need recursion.
962   builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
963   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
964                                              lenLow, lenHigh);
965 
966   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
967 
968   Value c0 = constantIndex(builder, loc, 0);
969   auto mayRecursion = [&](Value low, Value high, Value len) {
970     Value cond =
971         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
972     scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
973     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
974     SmallVector<Value> operands{low, high};
975     operands.append(args.begin() + xStartIdx, args.end());
976     builder.create<func::CallOp>(loc, func, operands);
977     builder.setInsertionPointAfter(ifOp);
978   };
979 
980   // Recursively call quickSort to process the smaller partition and return
981   // the bigger partition to be processed by the enclosed while-loop.
982   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
983   mayRecursion(lo, p, lenLow);
984   builder.create<scf::YieldOp>(loc, ValueRange{p, hi});
985 
986   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
987   mayRecursion(p, hi, lenHigh);
988   builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
989 
990   builder.setInsertionPointAfter(ifOp);
991   builder.create<scf::YieldOp>(loc, ifOp.getResults());
992 
993   builder.setInsertionPointAfter(ifLenGtTwo);
994   return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
995 }
996 
997 /// Creates a function to perform insertion sort on the values in the range of
998 /// index [lo, hi).
999 //
1000 // The generate IR corresponds to this C like algorithm:
1001 // void insertionSort(lo, hi, data) {
1002 //   for (i = lo+1; i < hi; i++) {
1003 //      d = data[i];
1004 //      p = binarySearch(lo, i-1, data)
1005 //      for (j = 0; j > i - p; j++)
1006 //        data[i-j] = data[i-j-1]
1007 //      data[p] = d
1008 //   }
1009 // }
1010 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
1011                                  func::FuncOp func, uint64_t nx, uint64_t ny,
1012                                  bool isCoo, uint32_t nTrailingP) {
1013   // Stable sort function doesn't use trailing parameters.
1014   (void)nTrailingP;
1015   assert(nTrailingP == 0);
1016   OpBuilder::InsertionGuard insertionGuard(builder);
1017   Block *entryBlock = func.addEntryBlock();
1018   builder.setInsertionPointToStart(entryBlock);
1019 
1020   MLIRContext *context = module.getContext();
1021   Location loc = func.getLoc();
1022   ValueRange args = entryBlock->getArguments();
1023   Value c1 = constantIndex(builder, loc, 1);
1024   Value lo = args[loIdx];
1025   Value hi = args[hiIdx];
1026   Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
1027 
1028   // Start the outer for-stmt with induction variable i.
1029   scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
1030   builder.setInsertionPointToStart(forOpI.getBody());
1031   Value i = forOpI.getInductionVar();
1032 
1033   // Binary search to find the insertion point p.
1034   SmallVector<Value> operands{lo, i};
1035   operands.append(args.begin() + xStartIdx, args.end());
1036   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
1037       builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
1038       ny, isCoo, operands, createBinarySearchFunc);
1039   Value p = builder
1040                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
1041                                       operands)
1042                 .getResult(0);
1043 
1044   // Move the value at data[i] to a temporary location.
1045   operands[0] = operands[1] = i;
1046   SmallVector<Value> d;
1047   forEachIJPairInAllBuffers(
1048       builder, loc, operands, nx, ny, isCoo,
1049       [&](uint64_t unused, Value i, Value unused2, Value buffer) {
1050         d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
1051       });
1052 
1053   // Start the inner for-stmt with induction variable j, for moving data[p..i)
1054   // to data[p+1..i+1).
1055   Value imp = builder.create<arith::SubIOp>(loc, i, p);
1056   Value c0 = constantIndex(builder, loc, 0);
1057   scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
1058   builder.setInsertionPointToStart(forOpJ.getBody());
1059   Value j = forOpJ.getInductionVar();
1060   Value imj = builder.create<arith::SubIOp>(loc, i, j);
1061   operands[1] = imj;
1062   operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
1063   forEachIJPairInAllBuffers(
1064       builder, loc, operands, nx, ny, isCoo,
1065       [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
1066         Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
1067         builder.create<memref::StoreOp>(loc, t, buffer, imj);
1068       });
1069 
1070   // Store the value at data[i] to data[p].
1071   builder.setInsertionPointAfter(forOpJ);
1072   operands[0] = operands[1] = p;
1073   forEachIJPairInAllBuffers(
1074       builder, loc, operands, nx, ny, isCoo,
1075       [&](uint64_t k, Value p, Value usused, Value buffer) {
1076         builder.create<memref::StoreOp>(loc, d[k], buffer, p);
1077       });
1078 
1079   builder.setInsertionPointAfter(forOpI);
1080   builder.create<func::ReturnOp>(loc);
1081 }
1082 
1083 /// Creates a function to perform quick sort or a hybrid quick sort on the
1084 /// values in the range of index [lo, hi).
1085 //
1086 //
1087 // When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
1088 // void quickSort(lo, hi, data) {
1089 //   while (lo + 1 < hi) {
1090 //        p = partition(low, high, data);
1091 //        if (len(lo, p) < len(p+1, hi)) {
1092 //          quickSort(lo, p, data);
1093 //          lo = p+1;
1094 //        } else {
1095 //          quickSort(p + 1, hi, data);
1096 //          hi = p;
1097 //        }
1098 //   }
1099 // }
1100 //
1101 // When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
1102 // void hybridQuickSort(lo, hi, data, depthLimit) {
1103 //   while (lo + 1 < hi) {
1104 //     len = hi - lo;
1105 //     if (len <= limit) {
1106 //       insertionSort(lo, hi, data);
1107 //     } else {
1108 //       depthLimit --;
1109 //       if (depthLimit <= 0) {
1110 //         heapSort(lo, hi, data);
1111 //       } else {
1112 //          p = partition(low, high, data);
1113 //          if (len(lo, p) < len(p+1, hi)) {
1114 //            quickSort(lo, p, data, depthLimit);
1115 //            lo = p+1;
1116 //          } else {
1117 //            quickSort(p + 1, hi, data, depthLimit);
1118 //            hi = p;
1119 //          }
1120 //       }
1121 //     }
1122 //   }
1123 // }
1124 //
1125 static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
1126                                 func::FuncOp func, uint64_t nx, uint64_t ny,
1127                                 bool isCoo, uint32_t nTrailingP) {
1128   assert(nTrailingP == 1 || nTrailingP == 0);
1129   bool isHybrid = (nTrailingP == 1);
1130   OpBuilder::InsertionGuard insertionGuard(builder);
1131   Block *entryBlock = func.addEntryBlock();
1132   builder.setInsertionPointToStart(entryBlock);
1133 
1134   Location loc = func.getLoc();
1135   SmallVector<Value> args;
1136   args.append(entryBlock->getArguments().begin(),
1137               entryBlock->getArguments().end());
1138   Value lo = args[loIdx];
1139   Value hi = args[hiIdx];
1140   SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
1141   scf::WhileOp whileOp =
1142       builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});
1143 
1144   // The before-region of the WhileOp.
1145   Block *before =
1146       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
1147   builder.setInsertionPointToEnd(before);
1148   lo = before->getArgument(0);
1149   hi = before->getArgument(1);
1150   Value loP1 =
1151       builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
1152   Value needSort =
1153       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
1154   builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());
1155 
1156   // The after-region of the WhileOp.
1157   Block *after =
1158       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
1159   builder.setInsertionPointToEnd(after);
1160   lo = after->getArgument(0);
1161   hi = after->getArgument(1);
1162   args[0] = lo;
1163   args[1] = hi;
1164 
1165   if (isHybrid) {
1166     Value len = builder.create<arith::SubIOp>(loc, hi, lo);
1167     Value lenLimit = constantIndex(builder, loc, 30);
1168     Value lenCond = builder.create<arith::CmpIOp>(
1169         loc, arith::CmpIPredicate::ule, len, lenLimit);
1170     scf::IfOp lenIf =
1171         builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);
1172 
1173     // When len <= limit.
1174     builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
1175     FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
1176         builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
1177         ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
1178     builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
1179                                  ValueRange(args).drop_back(nTrailingP));
1180     builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1181 
1182     // When len > limit.
1183     builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
1184     Value depthLimit = args.back();
1185     depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
1186                                                constantI64(builder, loc, 1));
1187     Value depthCond =
1188         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
1189                                       depthLimit, constantI64(builder, loc, 0));
1190     scf::IfOp depthIf =
1191         builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);
1192 
1193     // When depth exceeds limit.
1194     builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
1195     FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
1196         builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
1197         ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
1198     builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
1199                                  ValueRange(args).drop_back(nTrailingP));
1200     builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
1201 
1202     // When depth doesn't exceed limit.
1203     builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
1204     args.back() = depthLimit;
1205     std::tie(lo, hi) =
1206         createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
1207     builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1208 
1209     builder.setInsertionPointAfter(depthIf);
1210     lo = depthIf.getResult(0);
1211     hi = depthIf.getResult(1);
1212     builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1213 
1214     builder.setInsertionPointAfter(lenIf);
1215     lo = lenIf.getResult(0);
1216     hi = lenIf.getResult(1);
1217   } else {
1218     std::tie(lo, hi) =
1219         createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
1220   }
1221 
1222   // New [lo, hi) for the next while-loop iteration.
1223   builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
1224 
1225   // After the while-loop.
1226   builder.setInsertionPointAfter(whileOp);
1227   builder.create<func::ReturnOp>(loc);
1228 }
1229 
1230 /// Implements the rewriting for operator sort and sort_coo.
1231 template <typename OpTy>
1232 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
1233                                     uint64_t ny, bool isCoo,
1234                                     PatternRewriter &rewriter) {
1235   Location loc = op.getLoc();
1236   SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
1237 
1238   // Convert `values` to have dynamic shape and append them to `operands`.
1239   for (Value v : xys) {
1240     auto mtp = getMemRefType(v);
1241     if (!mtp.isDynamicDim(0)) {
1242       auto newMtp =
1243           MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
1244       v = rewriter.create<memref::CastOp>(loc, newMtp, v);
1245     }
1246     operands.push_back(v);
1247   }
1248 
1249   auto insertPoint = op->template getParentOfType<func::FuncOp>();
1250   if (!insertPoint)
1251     return failure();
1252 
1253   SmallString<32> funcName;
1254   FuncGeneratorType funcGenerator;
1255   uint32_t nTrailingP = 0;
1256   switch (op.getAlgorithm()) {
1257   case SparseTensorSortKind::HybridQuickSort: {
1258     funcName = kHybridQuickSortFuncNamePrefix;
1259     funcGenerator = createQuickSortFunc;
1260     nTrailingP = 1;
1261     // As a heuristics, set depthLimit = 2 * log2(n).
1262     Value lo = operands[loIdx];
1263     Value hi = operands[hiIdx];
1264     Value len = rewriter.create<arith::IndexCastOp>(
1265         loc, rewriter.getI64Type(),
1266         rewriter.create<arith::SubIOp>(loc, hi, lo));
1267     Value depthLimit = rewriter.create<arith::SubIOp>(
1268         loc, constantI64(rewriter, loc, 64),
1269         rewriter.create<math::CountLeadingZerosOp>(loc, len));
1270     operands.push_back(depthLimit);
1271     break;
1272   }
1273   case SparseTensorSortKind::QuickSort:
1274     funcName = kQuickSortFuncNamePrefix;
1275     funcGenerator = createQuickSortFunc;
1276     break;
1277   case SparseTensorSortKind::InsertionSortStable:
1278     funcName = kSortStableFuncNamePrefix;
1279     funcGenerator = createSortStableFunc;
1280     break;
1281   case SparseTensorSortKind::HeapSort:
1282     funcName = kHeapSortFuncNamePrefix;
1283     funcGenerator = createHeapSortFunc;
1284     break;
1285   }
1286 
1287   FlatSymbolRefAttr func =
1288       getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
1289                                ny, isCoo, operands, funcGenerator, nTrailingP);
1290   rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
1291   return success();
1292 }
1293 
1294 //===---------------------------------------------------------------------===//
1295 // The actual sparse buffer rewriting rules.
1296 //===---------------------------------------------------------------------===//
1297 
1298 namespace {
1299 
1300 /// Sparse rewriting rule for the push_back operator.
1301 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
1302 public:
1303   using OpRewritePattern<PushBackOp>::OpRewritePattern;
1304   PushBackRewriter(MLIRContext *context, bool enableInit)
1305       : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
1306   LogicalResult matchAndRewrite(PushBackOp op,
1307                                 PatternRewriter &rewriter) const override {
1308     // Rewrite push_back(buffer, value, n) to:
1309     // new_size = size(buffer) + n
1310     // if (new_size > capacity(buffer))
1311     //    while new_size > new_capacity
1312     //      new_capacity = new_capacity*2
1313     //    new_buffer = realloc(buffer, new_capacity)
1314     // buffer = new_buffer
1315     // subBuffer = subviewof(buffer)
1316     // linalg.fill subBuffer value
1317     //
1318     // size(buffer) += n
1319     //
1320     // The capacity check is skipped when the attribute inbounds is presented.
1321     Location loc = op->getLoc();
1322     Value c0 = constantIndex(rewriter, loc, 0);
1323     Value buffer = op.getInBuffer();
1324     Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
1325     Value size = op.getCurSize();
1326     Value value = op.getValue();
1327 
1328     Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
1329     Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
1330     auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
1331     bool nIsOne = (nValue && nValue.value() == 1);
1332 
1333     if (!op.getInbounds()) {
1334       Value cond = rewriter.create<arith::CmpIOp>(
1335           loc, arith::CmpIPredicate::ugt, newSize, capacity);
1336 
1337       Value c2 = constantIndex(rewriter, loc, 2);
1338       auto bufferType =
1339           MemRefType::get({ShapedType::kDynamic}, value.getType());
1340       scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
1341                                                   /*else=*/true);
1342       // True branch.
1343       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1344       if (nIsOne) {
1345         capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
1346       } else {
1347         // Use a do-while loop to calculate the new capacity as follows:
1348         //   do { new_capacity *= 2 } while (size > new_capacity)
1349         scf::WhileOp whileOp =
1350             rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
1351 
1352         // The before-region of the WhileOp.
1353         Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
1354                                              {capacity.getType()}, {loc});
1355         rewriter.setInsertionPointToEnd(before);
1356 
1357         capacity =
1358             rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
1359         cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
1360                                               newSize, capacity);
1361         rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
1362         // The after-region of the WhileOp.
1363         Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
1364                                             {capacity.getType()}, {loc});
1365         rewriter.setInsertionPointToEnd(after);
1366         rewriter.create<scf::YieldOp>(loc, after->getArguments());
1367 
1368         rewriter.setInsertionPointAfter(whileOp);
1369         capacity = whileOp.getResult(0);
1370       }
1371 
1372       Value newBuffer =
1373           rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
1374       if (enableBufferInitialization) {
1375         Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
1376         Value fillValue = constantZero(rewriter, loc, value.getType());
1377         Value subBuffer = rewriter.create<memref::SubViewOp>(
1378             loc, newBuffer, /*offset=*/ValueRange{newSize},
1379             /*size=*/ValueRange{fillSize},
1380             /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1381         rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
1382       }
1383       rewriter.create<scf::YieldOp>(loc, newBuffer);
1384 
1385       // False branch.
1386       rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
1387       rewriter.create<scf::YieldOp>(loc, buffer);
1388 
1389       // Prepare for adding the value to the end of the buffer.
1390       rewriter.setInsertionPointAfter(ifOp);
1391       buffer = ifOp.getResult(0);
1392     }
1393 
1394     // Add the value to the end of the buffer.
1395     if (nIsOne) {
1396       rewriter.create<memref::StoreOp>(loc, value, buffer, size);
1397     } else {
1398       Value subBuffer = rewriter.create<memref::SubViewOp>(
1399           loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
1400           /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
1401       rewriter.create<linalg::FillOp>(loc, value, subBuffer);
1402     }
1403 
1404     // Update the buffer size.
1405     rewriter.replaceOp(op, {buffer, newSize});
1406     return success();
1407   }
1408 
1409 private:
1410   bool enableBufferInitialization;
1411 };
1412 
1413 /// Sparse rewriting rule for the sort operator.
1414 struct SortRewriter : public OpRewritePattern<SortOp> {
1415 public:
1416   using OpRewritePattern<SortOp>::OpRewritePattern;
1417 
1418   LogicalResult matchAndRewrite(SortOp op,
1419                                 PatternRewriter &rewriter) const override {
1420     SmallVector<Value> xys(op.getXs());
1421     xys.append(op.getYs().begin(), op.getYs().end());
1422     return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
1423                                  /*isCoo=*/false, rewriter);
1424   }
1425 };
1426 
1427 /// Sparse rewriting rule for the sort_coo operator.
1428 struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
1429 public:
1430   using OpRewritePattern<SortCooOp>::OpRewritePattern;
1431 
1432   LogicalResult matchAndRewrite(SortCooOp op,
1433                                 PatternRewriter &rewriter) const override {
1434     SmallVector<Value> xys;
1435     xys.push_back(op.getXy());
1436     xys.append(op.getYs().begin(), op.getYs().end());
1437     uint64_t nx = 1;
1438     if (auto nxAttr = op.getNxAttr())
1439       nx = nxAttr.getInt();
1440 
1441     uint64_t ny = 0;
1442     if (auto nyAttr = op.getNyAttr())
1443       ny = nyAttr.getInt();
1444 
1445     return matchAndRewriteSortOp(op, xys, nx, ny,
1446                                  /*isCoo=*/true, rewriter);
1447   }
1448 };
1449 
1450 } // namespace
1451 
1452 //===---------------------------------------------------------------------===//
1453 // Methods that add patterns described in this file to a pattern list.
1454 //===---------------------------------------------------------------------===//
1455 
1456 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
1457                                          bool enableBufferInitialization) {
1458   patterns.add<PushBackRewriter>(patterns.getContext(),
1459                                  enableBufferInitialization);
1460   patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
1461 }
1462