xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp (revision 0c7f1c152021249a97640d0ec9396e3885b9dbcc)
1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensor
10 // primitives with memref operands.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodegenUtils.h"
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
22 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
23 #include "mlir/Support/LLVM.h"
24 
25 using namespace mlir;
26 using namespace mlir::sparse_tensor;
27 
28 //===---------------------------------------------------------------------===//
29 // Helper methods for the actual rewriting rules.
30 //===---------------------------------------------------------------------===//
31 
/// Positions of the fixed leading operands shared by the generated sort helper
/// functions: the two index operands come first, and the buffers being sorted
/// start right after them.
static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = loIdx + 1;
static constexpr uint64_t xStartIdx = hiIdx + 1;
35 
// Name prefixes of the helper functions that are generated on demand. The
// remainder of each symbol name is produced by getMangledSortHelperFuncName.
static constexpr char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
static constexpr char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
static constexpr char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr char kBinarySearchFuncNamePrefix[] = "_sparse_binary_search_";
static constexpr char kSortNonstableFuncNamePrefix[] =
    "_sparse_sort_nonstable_";
static constexpr char kSortStableFuncNamePrefix[] = "_sparse_sort_stable_";
45 
46 using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
47                                             uint64_t, uint64_t, bool)>;
48 
49 /// Constructs a function name with this format to facilitate quick sort:
50 ///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
51 ///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
52 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
53                                          StringRef namePrefix, uint64_t nx,
54                                          uint64_t ny, bool isCoo,
55                                          ValueRange operands) {
56   nameOstream << namePrefix << nx << "_"
57               << getMemRefType(operands[xStartIdx]).getElementType();
58 
59   if (isCoo)
60     nameOstream << "_coo_" << ny;
61 
62   uint64_t yBufferOffset = isCoo ? 1 : nx;
63   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
64     nameOstream << "_" << getMemRefType(v).getElementType();
65 }
66 
67 /// Looks up a function that is appropriate for the given operands being
68 /// sorted, and creates such a function if it doesn't exist yet. The
69 /// parameters `nx` and `ny` tell the number of x and y values provided
70 /// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
71 /// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
72 static FlatSymbolRefAttr
73 getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
74                          TypeRange resultTypes, StringRef namePrefix,
75                          uint64_t nx, uint64_t ny, bool isCoo,
76                          ValueRange operands, FuncGeneratorType createFunc) {
77   SmallString<32> nameBuffer;
78   llvm::raw_svector_ostream nameOstream(nameBuffer);
79   getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
80                                operands);
81 
82   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
83   MLIRContext *context = module.getContext();
84   auto result = SymbolRefAttr::get(context, nameOstream.str());
85   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
86 
87   if (!func) {
88     // Create the function.
89     OpBuilder::InsertionGuard insertionGuard(builder);
90     builder.setInsertionPoint(insertPoint);
91     Location loc = insertPoint.getLoc();
92     func = builder.create<func::FuncOp>(
93         loc, nameOstream.str(),
94         FunctionType::get(context, operands.getTypes(), resultTypes));
95     func.setPrivate();
96     createFunc(builder, module, func, nx, ny, isCoo);
97   }
98 
99   return result;
100 }
101 
102 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
103 /// The code to process the value pairs is generated by `bodyBuilder`.
104 static void forEachIJPairInXs(
105     OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
106     bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
107   Value iOffset, jOffset;
108   if (isCoo) {
109     Value cstep = constantIndex(builder, loc, nx + ny);
110     iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
111     jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
112   }
113   for (uint64_t k = 0; k < nx; k++) {
114     scf::IfOp ifOp;
115     Value i, j, buffer;
116     if (isCoo) {
117       Value ck = constantIndex(builder, loc, k);
118       i = builder.create<arith::AddIOp>(loc, ck, iOffset);
119       j = builder.create<arith::AddIOp>(loc, ck, jOffset);
120       buffer = args[xStartIdx];
121     } else {
122       i = args[0];
123       j = args[1];
124       buffer = args[xStartIdx + k];
125     }
126     bodyBuilder(k, i, j, buffer);
127   }
128 }
129 
130 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
131 /// The code to process the value pairs is generated by `bodyBuilder`.
132 static void forEachIJPairInAllBuffers(
133     OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
134     bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
135 
136   // Create code for the first (nx + ny) buffers. When isCoo==true, these
137   // logical buffers are all from the xy buffer of the sort_coo operator.
138   forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
139 
140   uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
141 
142   // Create code for the remaining buffers.
143   Value i = args[0];
144   Value j = args[1];
145   for (const auto &arg :
146        llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
147     bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
148   }
149 }
150 
151 /// Creates a code block for swapping the values in index i and j for all the
152 /// buffers.
153 //
154 // The generated IR corresponds to this C like algorithm:
155 //     swap(x0[i], x0[j]);
156 //     swap(x1[i], x1[j]);
157 //     ...
158 //     swap(xn[i], xn[j]);
159 //     swap(y0[i], y0[j]);
160 //     ...
161 //     swap(yn[i], yn[j]);
162 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
163                        uint64_t nx, uint64_t ny, bool isCoo) {
164   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
165     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
166     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
167     builder.create<memref::StoreOp>(loc, vj, buffer, i);
168     builder.create<memref::StoreOp>(loc, vi, buffer, j);
169   };
170 
171   forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
172 }
173 
174 /// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
175 /// compare each pair is create via `compareBuilder`.
176 static void createCompareFuncImplementation(
177     OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx,
178     uint64_t ny, bool isCoo,
179     function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
180         compareBuilder) {
181   OpBuilder::InsertionGuard insertionGuard(builder);
182 
183   Block *entryBlock = func.addEntryBlock();
184   builder.setInsertionPointToStart(entryBlock);
185   Location loc = func.getLoc();
186   ValueRange args = entryBlock->getArguments();
187 
188   scf::IfOp topIfOp;
189   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
190     scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1));
191     if (k == 0) {
192       topIfOp = ifOp;
193     } else {
194       OpBuilder::InsertionGuard insertionGuard(builder);
195       builder.setInsertionPointAfter(ifOp);
196       builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
197     }
198   };
199 
200   forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
201 
202   builder.setInsertionPointAfter(topIfOp);
203   builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
204 }
205 
206 /// Generates an if-statement to compare whether x[i] is equal to x[j].
207 static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
208                                  Value j, Value x, bool isLastDim) {
209   Value f = constantI1(builder, loc, false);
210   Value t = constantI1(builder, loc, true);
211   Value vi = builder.create<memref::LoadOp>(loc, x, i);
212   Value vj = builder.create<memref::LoadOp>(loc, x, j);
213 
214   Value cond =
215       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
216   scf::IfOp ifOp =
217       builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
218 
219   // x[1] != x[j]:
220   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
221   builder.create<scf::YieldOp>(loc, f);
222 
223   // x[i] == x[j]:
224   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
225   if (isLastDim == 1) {
226     // Finish checking all dimensions.
227     builder.create<scf::YieldOp>(loc, t);
228   }
229 
230   return ifOp;
231 }
232 
233 /// Creates a function to compare whether xs[i] is equal to xs[j].
234 //
235 // The generate IR corresponds to this C like algorithm:
236 //   if (x0[i] != x0[j])
237 //     return false;
238 //   else
239 //     if (x1[i] != x1[j])
240 //       return false;
241 //     else if (x2[2] != x2[j]))
242 //       and so on ...
243 static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
244                                 func::FuncOp func, uint64_t nx, uint64_t ny,
245                                 bool isCoo) {
246   createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
247                                   createEqCompare);
248 }
249 
250 /// Generates an if-statement to compare whether x[i] is less than x[j].
251 static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
252                                        Value i, Value j, Value x,
253                                        bool isLastDim) {
254   Value f = constantI1(builder, loc, false);
255   Value t = constantI1(builder, loc, true);
256   Value vi = builder.create<memref::LoadOp>(loc, x, i);
257   Value vj = builder.create<memref::LoadOp>(loc, x, j);
258 
259   Value cond =
260       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
261   scf::IfOp ifOp =
262       builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
263   // If (x[i] < x[j]).
264   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
265   builder.create<scf::YieldOp>(loc, t);
266 
267   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
268   if (isLastDim == 1) {
269     // Finish checking all dimensions.
270     builder.create<scf::YieldOp>(loc, f);
271   } else {
272     cond =
273         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
274     scf::IfOp ifOp2 =
275         builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
276     // Otherwise if (x[j] < x[i]).
277     builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
278     builder.create<scf::YieldOp>(loc, f);
279 
280     // Otherwise check the remaining dimensions.
281     builder.setInsertionPointAfter(ifOp2);
282     builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
283     // Set up the insertion point for the nested if-stmt that checks the
284     // remaining dimensions.
285     builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
286   }
287 
288   return ifOp;
289 }
290 
291 /// Creates a function to compare whether xs[i] is less than xs[j].
292 //
293 // The generate IR corresponds to this C like algorithm:
294 //   if (x0[i] < x0[j])
295 //     return true;
296 //   else if (x0[j] < x0[i])
297 //     return false;
298 //   else
299 //     if (x1[i] < x1[j])
300 //       return true;
301 //     else if (x1[j] < x1[i]))
302 //       and so on ...
303 static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
304                                func::FuncOp func, uint64_t nx, uint64_t ny,
305                                bool isCoo) {
306   createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
307                                   createLessThanCompare);
308 }
309 
310 /// Creates a function to use a binary search to find the insertion point for
311 /// inserting xs[hi] to the sorted values xs[lo..hi).
312 //
313 // The generate IR corresponds to this C like algorithm:
314 //   p = hi
315 //   while (lo < hi)
316 //      mid = (lo + hi) >> 1
317 //      if (xs[p] < xs[mid])
318 //        hi = mid
319 //      else
320 //        lo = mid - 1
321 //   return lo;
322 //
323 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
324                                    func::FuncOp func, uint64_t nx, uint64_t ny,
325                                    bool isCoo) {
326   OpBuilder::InsertionGuard insertionGuard(builder);
327   Block *entryBlock = func.addEntryBlock();
328   builder.setInsertionPointToStart(entryBlock);
329 
330   Location loc = func.getLoc();
331   ValueRange args = entryBlock->getArguments();
332   Value p = args[hiIdx];
333   SmallVector<Type, 2> types(2, p.getType()); // only two
334   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
335       loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
336 
337   // The before-region of the WhileOp.
338   Block *before =
339       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
340   builder.setInsertionPointToEnd(before);
341   Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
342                                               before->getArgument(0),
343                                               before->getArgument(1));
344   builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
345 
346   // The after-region of the WhileOp.
347   Block *after =
348       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
349   builder.setInsertionPointToEnd(after);
350   Value lo = after->getArgument(0);
351   Value hi = after->getArgument(1);
352   // Compute mid = (lo + hi) >> 1.
353   Value c1 = constantIndex(builder, loc, 1);
354   Value mid = builder.create<arith::ShRUIOp>(
355       loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
356   Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
357 
358   // Compare xs[p] < xs[mid].
359   SmallVector<Value> compareOperands{p, mid};
360   uint64_t numXBuffers = isCoo ? 1 : nx;
361   compareOperands.append(args.begin() + xStartIdx,
362                          args.begin() + xStartIdx + numXBuffers);
363   Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
364   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
365       builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
366       compareOperands, createLessThanFunc);
367   Value cond2 = builder
368                     .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
369                                           compareOperands)
370                     .getResult(0);
371 
372   // Update lo and hi for the WhileOp as follows:
373   //   if (xs[p] < xs[mid]))
374   //     hi = mid;
375   //   else
376   //     lo = mid + 1;
377   Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
378   Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
379   builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
380 
381   builder.setInsertionPointAfter(whileOp);
382   builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
383 }
384 
385 /// Creates code to advance i in a loop based on xs[p] as follows:
386 ///   while (xs[i] < xs[p]) i += step (step > 0)
387 /// or
388 ///   while (xs[i] > xs[p]) i += step (step < 0)
389 /// The routine returns i as well as a boolean value to indicate whether
390 /// xs[i] == xs[p].
391 static std::pair<Value, Value>
392 createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
393                ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
394                bool isCoo, int step) {
395   Location loc = func.getLoc();
396   scf::WhileOp whileOp =
397       builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
398 
399   Block *before =
400       builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
401   builder.setInsertionPointToEnd(before);
402   SmallVector<Value> compareOperands;
403   if (step > 0) {
404     compareOperands.push_back(before->getArgument(0));
405     compareOperands.push_back(p);
406   } else {
407     assert(step < 0);
408     compareOperands.push_back(p);
409     compareOperands.push_back(before->getArgument(0));
410   }
411   compareOperands.append(xs.begin(), xs.end());
412   MLIRContext *context = module.getContext();
413   Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
414   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
415       builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
416       compareOperands, createLessThanFunc);
417   Value cond = builder
418                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
419                                          compareOperands)
420                    .getResult(0);
421   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
422 
423   Block *after =
424       builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
425   builder.setInsertionPointToEnd(after);
426   Value cs = constantIndex(builder, loc, step);
427   i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
428   builder.create<scf::YieldOp>(loc, ValueRange{i});
429   i = whileOp.getResult(0);
430 
431   builder.setInsertionPointAfter(whileOp);
432   compareOperands[0] = i;
433   compareOperands[1] = p;
434   FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
435       builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo,
436       compareOperands, createEqCompareFunc);
437   Value compareEq =
438       builder
439           .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
440                                 compareOperands)
441           .getResult(0);
442 
443   return std::make_pair(whileOp.getResult(0), compareEq);
444 }
445 
446 /// Creates a code block to swap the values so that data[mi] is the median among
447 /// data[lo], data[hi], and data[mi].
448 //  The generated code corresponds to this C-like algorithm:
449 //  median = mi
450 //  if (data[mi] < data[lo]).                               (if1)
451 //    if (data[hi] < data[lo])                              (if2)
452 //       median = data[hi] < data[mi] ? mi : hi
453 //    else
454 //       median = lo
455 //  else
456 //    if data[hi] < data[mi]                                (if3)
457 //      median = data[hi] < data[lo] ? lo : hi
458 //  if median != mi swap data[median] with data[mi]
459 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
460                               func::FuncOp func, uint64_t nx, uint64_t ny,
461                               bool isCoo, Value lo, Value hi, Value mi,
462                               ValueRange args) {
463   SmallVector<Value> compareOperands{mi, lo};
464   uint64_t numXBuffers = isCoo ? 1 : nx;
465   compareOperands.append(args.begin() + xStartIdx,
466                          args.begin() + xStartIdx + numXBuffers);
467   Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
468   SmallVector<Type, 1> cmpTypes{i1Type};
469   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
470       builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo,
471       compareOperands, createLessThanFunc);
472   Location loc = func.getLoc();
473   // Compare data[mi] < data[lo].
474   Value cond1 =
475       builder.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
476           .getResult(0);
477   SmallVector<Type, 1> ifTypes{lo.getType()};
478   scf::IfOp ifOp1 =
479       builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
480 
481   // Generate an if-stmt to find the median value, assuming we already know that
482   // data[b] < data[a] and we haven't compare data[c] yet.
483   auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp {
484     compareOperands[0] = c;
485     compareOperands[1] = a;
486     // Compare data[c]] < data[a].
487     Value cond2 =
488         builder
489             .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
490             .getResult(0);
491     scf::IfOp ifOp2 =
492         builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
493     builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
494     compareOperands[0] = c;
495     compareOperands[1] = b;
496     // Compare data[c] < data[b].
497     Value cond3 =
498         builder
499             .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
500             .getResult(0);
501     builder.create<scf::YieldOp>(
502         loc, ValueRange{builder.create<arith::SelectOp>(loc, cond3, b, c)});
503     builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
504     builder.create<scf::YieldOp>(loc, ValueRange{a});
505     return ifOp2;
506   };
507 
508   builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
509   scf::IfOp ifOp2 = createFindMedian(lo, mi, hi);
510   builder.setInsertionPointAfter(ifOp2);
511   builder.create<scf::YieldOp>(loc, ValueRange{ifOp2.getResult(0)});
512 
513   builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
514   scf::IfOp ifOp3 = createFindMedian(mi, lo, hi);
515 
516   builder.setInsertionPointAfter(ifOp3);
517   builder.create<scf::YieldOp>(loc, ValueRange{ifOp3.getResult(0)});
518 
519   builder.setInsertionPointAfter(ifOp1);
520   Value median = ifOp1.getResult(0);
521   Value cond =
522       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, mi, median);
523   scf::IfOp ifOp =
524       builder.create<scf::IfOp>(loc, TypeRange(), cond, /*else=*/false);
525 
526   SmallVector<Value> swapOperands{median, mi};
527   swapOperands.append(args.begin() + xStartIdx, args.end());
528   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
529   createSwap(builder, loc, swapOperands, nx, ny, isCoo);
530   builder.setInsertionPointAfter(ifOp);
531 }
532 
533 /// Creates a function to perform quick sort partition on the values in the
534 /// range of index [lo, hi), assuming lo < hi.
535 //
536 // The generated IR corresponds to this C like algorithm:
537 // int partition(lo, hi, xs) {
538 //   p = (lo+hi)/2  // pivot index
539 //   i = lo
540 //   j = hi-1
541 //   while (i < j) do {
542 //     while (xs[i] < xs[p]) i ++;
543 //     i_eq = (xs[i] == xs[p]);
544 //     while (xs[j] > xs[p]) j --;
545 //     j_eq = (xs[j] == xs[p]);
546 //     if (i < j) {
547 //       swap(xs[i], xs[j])
548 //       if (i == p) {
549 //         p = j;
550 //       } else if (j == p) {
551 //         p = i;
552 //       }
553 //       if (i_eq && j_eq) {
554 //         ++i;
555 //         --j;
556 //       }
557 //     }
558 //   }
559 //   return p
560 //   }
561 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
562                                 func::FuncOp func, uint64_t nx, uint64_t ny,
563                                 bool isCoo) {
564   OpBuilder::InsertionGuard insertionGuard(builder);
565 
566   Block *entryBlock = func.addEntryBlock();
567   builder.setInsertionPointToStart(entryBlock);
568 
569   Location loc = func.getLoc();
570   ValueRange args = entryBlock->getArguments();
571   Value lo = args[loIdx];
572   Value hi = args[hiIdx];
573   Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
574   Value c1 = constantIndex(builder, loc, 1);
575   Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
576 
577   Value i = lo;
578   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
579   createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
580   SmallVector<Value, 3> operands{i, j, p}; // Exactly three values.
581   SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
582   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
583 
584   // The before-region of the WhileOp.
585   Block *before =
586       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
587   builder.setInsertionPointToEnd(before);
588   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
589                                              before->getArgument(0),
590                                              before->getArgument(1));
591   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
592 
593   // The after-region of the WhileOp.
594   Block *after =
595       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
596   builder.setInsertionPointToEnd(after);
597   i = after->getArgument(0);
598   j = after->getArgument(1);
599   p = after->getArgument(2);
600 
601   uint64_t numXBuffers = isCoo ? 1 : nx;
602   auto [iresult, iCompareEq] =
603       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
604                      i, p, nx, ny, isCoo, 1);
605   i = iresult;
606   auto [jresult, jCompareEq] =
607       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
608                      j, p, nx, ny, isCoo, -1);
609   j = jresult;
610 
611   // If i < j:
612   cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
613   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
614   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
615   SmallVector<Value> swapOperands{i, j};
616   swapOperands.append(args.begin() + xStartIdx, args.end());
617   createSwap(builder, loc, swapOperands, nx, ny, isCoo);
618   // If the pivot is moved, update p with the new pivot.
619   Value icond =
620       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
621   scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
622                                               icond, /*else=*/true);
623   builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
624   builder.create<scf::YieldOp>(loc, ValueRange{j});
625   builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
626   Value jcond =
627       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
628   scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
629                                               jcond, /*else=*/true);
630   builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
631   builder.create<scf::YieldOp>(loc, ValueRange{i});
632   builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
633   builder.create<scf::YieldOp>(loc, ValueRange{p});
634   builder.setInsertionPointAfter(ifOpJ);
635   builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
636   builder.setInsertionPointAfter(ifOpI);
637   Value compareEqIJ =
638       builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
639   scf::IfOp ifOp2 = builder.create<scf::IfOp>(
640       loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
641   builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
642   Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
643   Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
644   builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
645   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
646   builder.create<scf::YieldOp>(loc, ValueRange{i, j});
647   builder.setInsertionPointAfter(ifOp2);
648   builder.create<scf::YieldOp>(
649       loc,
650       ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});
651 
652   // False branch for if i < j:
653   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
654   builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});
655 
656   // Return for the whileOp.
657   builder.setInsertionPointAfter(ifOp);
658   builder.create<scf::YieldOp>(loc, ifOp.getResults());
659 
660   // Return for the function.
661   builder.setInsertionPointAfter(whileOp);
662   builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
663 }
664 
665 /// Creates a function to perform quick sort on the value in the range of
666 /// index [lo, hi).
667 //
668 // The generate IR corresponds to this C like algorithm:
669 // void quickSort(lo, hi, data) {
670 //   if (lo < hi) {
671 //        p = partition(low, high, data);
672 //        quickSort(lo, p, data);
673 //        quickSort(p + 1, hi, data);
674 //   }
675 // }
676 static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
677                                     func::FuncOp func, uint64_t nx, uint64_t ny,
678                                     bool isCoo) {
679   OpBuilder::InsertionGuard insertionGuard(builder);
680   Block *entryBlock = func.addEntryBlock();
681   builder.setInsertionPointToStart(entryBlock);
682 
683   MLIRContext *context = module.getContext();
684   Location loc = func.getLoc();
685   ValueRange args = entryBlock->getArguments();
686   Value lo = args[loIdx];
687   Value hi = args[hiIdx];
688   Value cond =
689       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
690   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
691 
692   // The if-stmt true branch.
693   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
694   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
695       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
696       ny, isCoo, args, createPartitionFunc);
697   auto p = builder.create<func::CallOp>(
698       loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
699 
700   SmallVector<Value> lowOperands{lo, p.getResult(0)};
701   lowOperands.append(args.begin() + xStartIdx, args.end());
702   builder.create<func::CallOp>(loc, func, lowOperands);
703 
704   SmallVector<Value> highOperands{
705       builder.create<arith::AddIOp>(loc, p.getResult(0),
706                                     constantIndex(builder, loc, 1)),
707       hi};
708   highOperands.append(args.begin() + xStartIdx, args.end());
709   builder.create<func::CallOp>(loc, func, highOperands);
710 
711   // After the if-stmt.
712   builder.setInsertionPointAfter(ifOp);
713   builder.create<func::ReturnOp>(loc);
714 }
715 
716 /// Creates a function to perform insertion sort on the values in the range of
717 /// index [lo, hi).
718 //
719 // The generate IR corresponds to this C like algorithm:
720 // void insertionSort(lo, hi, data) {
721 //   for (i = lo+1; i < hi; i++) {
722 //      d = data[i];
723 //      p = binarySearch(lo, i-1, data)
724 //      for (j = 0; j > i - p; j++)
725 //        data[i-j] = data[i-j-1]
726 //      data[p] = d
727 //   }
728 // }
729 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
730                                  func::FuncOp func, uint64_t nx, uint64_t ny,
731                                  bool isCoo) {
732   OpBuilder::InsertionGuard insertionGuard(builder);
733   Block *entryBlock = func.addEntryBlock();
734   builder.setInsertionPointToStart(entryBlock);
735 
736   MLIRContext *context = module.getContext();
737   Location loc = func.getLoc();
738   ValueRange args = entryBlock->getArguments();
739   Value c1 = constantIndex(builder, loc, 1);
740   Value lo = args[loIdx];
741   Value hi = args[hiIdx];
742   Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
743 
744   // Start the outer for-stmt with induction variable i.
745   scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
746   builder.setInsertionPointToStart(forOpI.getBody());
747   Value i = forOpI.getInductionVar();
748 
749   // Binary search to find the insertion point p.
750   SmallVector<Value> operands{lo, i};
751   operands.append(args.begin() + xStartIdx, args.end());
752   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
753       builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
754       ny, isCoo, operands, createBinarySearchFunc);
755   Value p = builder
756                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
757                                       operands)
758                 .getResult(0);
759 
760   // Move the value at data[i] to a temporary location.
761   operands[0] = operands[1] = i;
762   SmallVector<Value> d;
763   forEachIJPairInAllBuffers(
764       builder, loc, operands, nx, ny, isCoo,
765       [&](uint64_t unused, Value i, Value unused2, Value buffer) {
766         d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
767       });
768 
769   // Start the inner for-stmt with induction variable j, for moving data[p..i)
770   // to data[p+1..i+1).
771   Value imp = builder.create<arith::SubIOp>(loc, i, p);
772   Value c0 = constantIndex(builder, loc, 0);
773   scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
774   builder.setInsertionPointToStart(forOpJ.getBody());
775   Value j = forOpJ.getInductionVar();
776   Value imj = builder.create<arith::SubIOp>(loc, i, j);
777   operands[1] = imj;
778   operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
779   forEachIJPairInAllBuffers(
780       builder, loc, operands, nx, ny, isCoo,
781       [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
782         Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
783         builder.create<memref::StoreOp>(loc, t, buffer, imj);
784       });
785 
786   // Store the value at data[i] to data[p].
787   builder.setInsertionPointAfter(forOpJ);
788   operands[0] = operands[1] = p;
789   forEachIJPairInAllBuffers(
790       builder, loc, operands, nx, ny, isCoo,
791       [&](uint64_t k, Value p, Value usused, Value buffer) {
792         builder.create<memref::StoreOp>(loc, d[k], buffer, p);
793       });
794 
795   builder.setInsertionPointAfter(forOpI);
796   builder.create<func::ReturnOp>(loc);
797 }
798 
799 /// Implements the rewriting for operator sort and sort_coo.
800 template <typename OpTy>
801 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
802                                     uint64_t ny, bool isCoo,
803                                     PatternRewriter &rewriter) {
804   Location loc = op.getLoc();
805   SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
806 
807   // Convert `values` to have dynamic shape and append them to `operands`.
808   for (Value v : xys) {
809     auto mtp = getMemRefType(v);
810     if (!mtp.isDynamicDim(0)) {
811       auto newMtp =
812           MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
813       v = rewriter.create<memref::CastOp>(loc, newMtp, v);
814     }
815     operands.push_back(v);
816   }
817   bool isStable =
818       (op.getAlgorithm() == SparseTensorSortKind::InsertionSortStable);
819   auto insertPoint = op->template getParentOfType<func::FuncOp>();
820   SmallString<32> funcName(isStable ? kSortStableFuncNamePrefix
821                                     : kSortNonstableFuncNamePrefix);
822   FuncGeneratorType funcGenerator =
823       isStable ? createSortStableFunc : createSortNonstableFunc;
824   FlatSymbolRefAttr func =
825       getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
826                                ny, isCoo, operands, funcGenerator);
827   rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
828   return success();
829 }
830 
831 //===---------------------------------------------------------------------===//
832 // The actual sparse buffer rewriting rules.
833 //===---------------------------------------------------------------------===//
834 
835 namespace {
836 
837 /// Sparse rewriting rule for the push_back operator.
838 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
839 public:
840   using OpRewritePattern<PushBackOp>::OpRewritePattern;
841   PushBackRewriter(MLIRContext *context, bool enableInit)
842       : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
843   LogicalResult matchAndRewrite(PushBackOp op,
844                                 PatternRewriter &rewriter) const override {
845     // Rewrite push_back(buffer, value, n) to:
846     // new_size = size(buffer) + n
847     // if (new_size > capacity(buffer))
848     //    while new_size > new_capacity
849     //      new_capacity = new_capacity*2
850     //    new_buffer = realloc(buffer, new_capacity)
851     // buffer = new_buffer
852     // subBuffer = subviewof(buffer)
853     // linalg.fill subBuffer value
854     //
855     // size(buffer) += n
856     //
857     // The capacity check is skipped when the attribute inbounds is presented.
858     Location loc = op->getLoc();
859     Value c0 = constantIndex(rewriter, loc, 0);
860     Value buffer = op.getInBuffer();
861     Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
862     Value size = op.getCurSize();
863     Value value = op.getValue();
864 
865     Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
866     Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
867     auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
868     bool nIsOne = (nValue && nValue.value() == 1);
869 
870     if (!op.getInbounds()) {
871       Value cond = rewriter.create<arith::CmpIOp>(
872           loc, arith::CmpIPredicate::ugt, newSize, capacity);
873 
874       Value c2 = constantIndex(rewriter, loc, 2);
875       auto bufferType =
876           MemRefType::get({ShapedType::kDynamic}, value.getType());
877       scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
878                                                   /*else=*/true);
879       // True branch.
880       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
881       if (nIsOne) {
882         capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
883       } else {
884         // Use a do-while loop to calculate the new capacity as follows:
885         //   do { new_capacity *= 2 } while (size > new_capacity)
886         scf::WhileOp whileOp =
887             rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
888 
889         // The before-region of the WhileOp.
890         Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
891                                              {capacity.getType()}, {loc});
892         rewriter.setInsertionPointToEnd(before);
893 
894         capacity =
895             rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
896         cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
897                                               newSize, capacity);
898         rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
899         // The after-region of the WhileOp.
900         Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
901                                             {capacity.getType()}, {loc});
902         rewriter.setInsertionPointToEnd(after);
903         rewriter.create<scf::YieldOp>(loc, after->getArguments());
904 
905         rewriter.setInsertionPointAfter(whileOp);
906         capacity = whileOp.getResult(0);
907       }
908 
909       Value newBuffer =
910           rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
911       if (enableBufferInitialization) {
912         Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
913         Value fillValue = constantZero(rewriter, loc, value.getType());
914         Value subBuffer = rewriter.create<memref::SubViewOp>(
915             loc, newBuffer, /*offset=*/ValueRange{newSize},
916             /*size=*/ValueRange{fillSize},
917             /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
918         rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
919       }
920       rewriter.create<scf::YieldOp>(loc, newBuffer);
921 
922       // False branch.
923       rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
924       rewriter.create<scf::YieldOp>(loc, buffer);
925 
926       // Prepare for adding the value to the end of the buffer.
927       rewriter.setInsertionPointAfter(ifOp);
928       buffer = ifOp.getResult(0);
929     }
930 
931     // Add the value to the end of the buffer.
932     if (nIsOne) {
933       rewriter.create<memref::StoreOp>(loc, value, buffer, size);
934     } else {
935       Value subBuffer = rewriter.create<memref::SubViewOp>(
936           loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
937           /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
938       rewriter.create<linalg::FillOp>(loc, value, subBuffer);
939     }
940 
941     // Update the buffer size.
942     rewriter.replaceOp(op, {buffer, newSize});
943     return success();
944   }
945 
946 private:
947   bool enableBufferInitialization;
948 };
949 
950 /// Sparse rewriting rule for the sort operator.
951 struct SortRewriter : public OpRewritePattern<SortOp> {
952 public:
953   using OpRewritePattern<SortOp>::OpRewritePattern;
954 
955   LogicalResult matchAndRewrite(SortOp op,
956                                 PatternRewriter &rewriter) const override {
957     SmallVector<Value> xys(op.getXs());
958     xys.append(op.getYs().begin(), op.getYs().end());
959     return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
960                                  /*isCoo=*/false, rewriter);
961   }
962 };
963 
964 /// Sparse rewriting rule for the sort_coo operator.
965 struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
966 public:
967   using OpRewritePattern<SortCooOp>::OpRewritePattern;
968 
969   LogicalResult matchAndRewrite(SortCooOp op,
970                                 PatternRewriter &rewriter) const override {
971     SmallVector<Value> xys;
972     xys.push_back(op.getXy());
973     xys.append(op.getYs().begin(), op.getYs().end());
974     uint64_t nx = 1;
975     if (auto nxAttr = op.getNxAttr())
976       nx = nxAttr.getInt();
977 
978     uint64_t ny = 0;
979     if (auto nyAttr = op.getNyAttr())
980       ny = nyAttr.getInt();
981 
982     return matchAndRewriteSortOp(op, xys, nx, ny,
983                                  /*isCoo=*/true, rewriter);
984   }
985 };
986 
987 } // namespace
988 
989 //===---------------------------------------------------------------------===//
990 // Methods that add patterns described in this file to a pattern list.
991 //===---------------------------------------------------------------------===//
992 
993 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
994                                          bool enableBufferInitialization) {
995   patterns.add<PushBackRewriter>(patterns.getContext(),
996                                  enableBufferInitialization);
997   patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
998 }
999