xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp (revision 8550aebd57d08d34c11ba438d07e1e0942b97f31)
1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensor
10 // primitives with memref operands.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodegenUtils.h"
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
22 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
23 #include "mlir/Support/LLVM.h"
24 
25 using namespace mlir;
26 using namespace mlir::sparse_tensor;
27 
28 //===---------------------------------------------------------------------===//
29 // Helper methods for the actual rewriting rules.
30 //===---------------------------------------------------------------------===//
31 
32 static constexpr uint64_t loIdx = 0;
33 static constexpr uint64_t hiIdx = 1;
34 static constexpr uint64_t xStartIdx = 2;
35 
36 static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
37 static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
38 static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
39 static constexpr const char kBinarySearchFuncNamePrefix[] =
40     "_sparse_binary_search_";
41 static constexpr const char kSortNonstableFuncNamePrefix[] =
42     "_sparse_sort_nonstable_";
43 static constexpr const char kSortStableFuncNamePrefix[] =
44     "_sparse_sort_stable_";
45 
46 using FuncGeneratorType = function_ref<void(
47     OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
48 
49 /// Constructs a function name with this format to facilitate quick sort:
50 ///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
51 ///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
52 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
53                                          StringRef namePrefix, uint64_t nx,
54                                          uint64_t ny, bool isCoo,
55                                          ValueRange operands) {
56   nameOstream << namePrefix << nx << "_"
57               << getMemRefType(operands[xStartIdx]).getElementType();
58 
59   if (isCoo)
60     nameOstream << "_coo_" << ny;
61 
62   uint64_t yBufferOffset = isCoo ? 1 : nx;
63   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
64     nameOstream << "_" << getMemRefType(v).getElementType();
65 }
66 
67 /// Looks up a function that is appropriate for the given operands being
68 /// sorted, and creates such a function if it doesn't exist yet. The
69 /// parameters `nx` and `ny` tell the number of x and y values provided
70 /// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
71 /// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
72 //
73 // All sorting function generators take (lo, hi, xs, ys) in `operands` as
74 // parameters for the sorting functions. Other parameters, such as the recursive
75 // call depth, are appended to the end of the parameter list as
76 // "trailing parameters".
77 static FlatSymbolRefAttr
78 getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
79                          TypeRange resultTypes, StringRef namePrefix,
80                          uint64_t nx, uint64_t ny, bool isCoo,
81                          ValueRange operands, FuncGeneratorType createFunc,
82                          uint32_t nTrailingP = 0) {
83   SmallString<32> nameBuffer;
84   llvm::raw_svector_ostream nameOstream(nameBuffer);
85   getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
86                                operands.drop_back(nTrailingP));
87 
88   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
89   MLIRContext *context = module.getContext();
90   auto result = SymbolRefAttr::get(context, nameOstream.str());
91   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
92 
93   if (!func) {
94     // Create the function.
95     OpBuilder::InsertionGuard insertionGuard(builder);
96     builder.setInsertionPoint(insertPoint);
97     Location loc = insertPoint.getLoc();
98     func = builder.create<func::FuncOp>(
99         loc, nameOstream.str(),
100         FunctionType::get(context, operands.getTypes(), resultTypes));
101     func.setPrivate();
102     createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
103   }
104 
105   return result;
106 }
107 
108 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
109 /// The code to process the value pairs is generated by `bodyBuilder`.
110 static void forEachIJPairInXs(
111     OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
112     bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
113   Value iOffset, jOffset;
114   if (isCoo) {
115     Value cstep = constantIndex(builder, loc, nx + ny);
116     iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
117     jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
118   }
119   for (uint64_t k = 0; k < nx; k++) {
120     scf::IfOp ifOp;
121     Value i, j, buffer;
122     if (isCoo) {
123       Value ck = constantIndex(builder, loc, k);
124       i = builder.create<arith::AddIOp>(loc, ck, iOffset);
125       j = builder.create<arith::AddIOp>(loc, ck, jOffset);
126       buffer = args[xStartIdx];
127     } else {
128       i = args[0];
129       j = args[1];
130       buffer = args[xStartIdx + k];
131     }
132     bodyBuilder(k, i, j, buffer);
133   }
134 }
135 
136 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
137 /// The code to process the value pairs is generated by `bodyBuilder`.
138 static void forEachIJPairInAllBuffers(
139     OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
140     bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
141 
142   // Create code for the first (nx + ny) buffers. When isCoo==true, these
143   // logical buffers are all from the xy buffer of the sort_coo operator.
144   forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
145 
146   uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
147 
148   // Create code for the remaining buffers.
149   Value i = args[0];
150   Value j = args[1];
151   for (const auto &arg :
152        llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
153     bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
154   }
155 }
156 
157 /// Creates a code block for swapping the values in index i and j for all the
158 /// buffers.
159 //
160 // The generated IR corresponds to this C like algorithm:
161 //     swap(x0[i], x0[j]);
162 //     swap(x1[i], x1[j]);
163 //     ...
164 //     swap(xn[i], xn[j]);
165 //     swap(y0[i], y0[j]);
166 //     ...
167 //     swap(yn[i], yn[j]);
168 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
169                        uint64_t nx, uint64_t ny, bool isCoo) {
170   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
171     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
172     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
173     builder.create<memref::StoreOp>(loc, vj, buffer, i);
174     builder.create<memref::StoreOp>(loc, vi, buffer, j);
175   };
176 
177   forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
178 }
179 
180 /// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
181 /// compare each pair is create via `compareBuilder`.
182 static void createCompareFuncImplementation(
183     OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx,
184     uint64_t ny, bool isCoo,
185     function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
186         compareBuilder) {
187   OpBuilder::InsertionGuard insertionGuard(builder);
188 
189   Block *entryBlock = func.addEntryBlock();
190   builder.setInsertionPointToStart(entryBlock);
191   Location loc = func.getLoc();
192   ValueRange args = entryBlock->getArguments();
193 
194   scf::IfOp topIfOp;
195   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
196     scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1));
197     if (k == 0) {
198       topIfOp = ifOp;
199     } else {
200       OpBuilder::InsertionGuard insertionGuard(builder);
201       builder.setInsertionPointAfter(ifOp);
202       builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
203     }
204   };
205 
206   forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
207 
208   builder.setInsertionPointAfter(topIfOp);
209   builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
210 }
211 
212 /// Generates an if-statement to compare whether x[i] is equal to x[j].
213 static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
214                                  Value j, Value x, bool isLastDim) {
215   Value f = constantI1(builder, loc, false);
216   Value t = constantI1(builder, loc, true);
217   Value vi = builder.create<memref::LoadOp>(loc, x, i);
218   Value vj = builder.create<memref::LoadOp>(loc, x, j);
219 
220   Value cond =
221       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
222   scf::IfOp ifOp =
223       builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
224 
225   // x[1] != x[j]:
226   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
227   builder.create<scf::YieldOp>(loc, f);
228 
229   // x[i] == x[j]:
230   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
231   if (isLastDim == 1) {
232     // Finish checking all dimensions.
233     builder.create<scf::YieldOp>(loc, t);
234   }
235 
236   return ifOp;
237 }
238 
239 /// Creates a function to compare whether xs[i] is equal to xs[j].
240 //
241 // The generate IR corresponds to this C like algorithm:
242 //   if (x0[i] != x0[j])
243 //     return false;
244 //   else
245 //     if (x1[i] != x1[j])
246 //       return false;
247 //     else if (x2[2] != x2[j]))
248 //       and so on ...
249 static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
250                                 func::FuncOp func, uint64_t nx, uint64_t ny,
251                                 bool isCoo, uint32_t nTrailingP = 0) {
252   // Compare functions don't use trailing parameters.
253   (void)nTrailingP;
254   assert(nTrailingP == 0);
255   createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
256                                   createEqCompare);
257 }
258 
259 /// Generates an if-statement to compare whether x[i] is less than x[j].
260 static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
261                                        Value i, Value j, Value x,
262                                        bool isLastDim) {
263   Value f = constantI1(builder, loc, false);
264   Value t = constantI1(builder, loc, true);
265   Value vi = builder.create<memref::LoadOp>(loc, x, i);
266   Value vj = builder.create<memref::LoadOp>(loc, x, j);
267 
268   Value cond =
269       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
270   scf::IfOp ifOp =
271       builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
272   // If (x[i] < x[j]).
273   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
274   builder.create<scf::YieldOp>(loc, t);
275 
276   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
277   if (isLastDim == 1) {
278     // Finish checking all dimensions.
279     builder.create<scf::YieldOp>(loc, f);
280   } else {
281     cond =
282         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
283     scf::IfOp ifOp2 =
284         builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
285     // Otherwise if (x[j] < x[i]).
286     builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
287     builder.create<scf::YieldOp>(loc, f);
288 
289     // Otherwise check the remaining dimensions.
290     builder.setInsertionPointAfter(ifOp2);
291     builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
292     // Set up the insertion point for the nested if-stmt that checks the
293     // remaining dimensions.
294     builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
295   }
296 
297   return ifOp;
298 }
299 
300 /// Creates a function to compare whether xs[i] is less than xs[j].
301 //
302 // The generate IR corresponds to this C like algorithm:
303 //   if (x0[i] < x0[j])
304 //     return true;
305 //   else if (x0[j] < x0[i])
306 //     return false;
307 //   else
308 //     if (x1[i] < x1[j])
309 //       return true;
310 //     else if (x1[j] < x1[i]))
311 //       and so on ...
312 static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
313                                func::FuncOp func, uint64_t nx, uint64_t ny,
314                                bool isCoo, uint32_t nTrailingP = 0) {
315   // Compare functions don't use trailing parameters.
316   (void)nTrailingP;
317   assert(nTrailingP == 0);
318   createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
319                                   createLessThanCompare);
320 }
321 
322 /// Creates a function to use a binary search to find the insertion point for
323 /// inserting xs[hi] to the sorted values xs[lo..hi).
324 //
325 // The generate IR corresponds to this C like algorithm:
326 //   p = hi
327 //   while (lo < hi)
328 //      mid = (lo + hi) >> 1
329 //      if (xs[p] < xs[mid])
330 //        hi = mid
331 //      else
332 //        lo = mid - 1
333 //   return lo;
334 //
335 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
336                                    func::FuncOp func, uint64_t nx, uint64_t ny,
337                                    bool isCoo, uint32_t nTrailingP = 0) {
338   // Binary search doesn't use trailing parameters.
339   (void)nTrailingP;
340   assert(nTrailingP == 0);
341   OpBuilder::InsertionGuard insertionGuard(builder);
342   Block *entryBlock = func.addEntryBlock();
343   builder.setInsertionPointToStart(entryBlock);
344 
345   Location loc = func.getLoc();
346   ValueRange args = entryBlock->getArguments();
347   Value p = args[hiIdx];
348   SmallVector<Type, 2> types(2, p.getType()); // Only two types.
349   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
350       loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
351 
352   // The before-region of the WhileOp.
353   Block *before =
354       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
355   builder.setInsertionPointToEnd(before);
356   Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
357                                               before->getArgument(0),
358                                               before->getArgument(1));
359   builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
360 
361   // The after-region of the WhileOp.
362   Block *after =
363       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
364   builder.setInsertionPointToEnd(after);
365   Value lo = after->getArgument(0);
366   Value hi = after->getArgument(1);
367   // Compute mid = (lo + hi) >> 1.
368   Value c1 = constantIndex(builder, loc, 1);
369   Value mid = builder.create<arith::ShRUIOp>(
370       loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
371   Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
372 
373   // Compare xs[p] < xs[mid].
374   SmallVector<Value> compareOperands{p, mid};
375   uint64_t numXBuffers = isCoo ? 1 : nx;
376   compareOperands.append(args.begin() + xStartIdx,
377                          args.begin() + xStartIdx + numXBuffers);
378   Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
379   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
380       builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
381       compareOperands, createLessThanFunc, nTrailingP);
382   Value cond2 = builder
383                     .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
384                                           compareOperands)
385                     .getResult(0);
386 
387   // Update lo and hi for the WhileOp as follows:
388   //   if (xs[p] < xs[mid]))
389   //     hi = mid;
390   //   else
391   //     lo = mid + 1;
392   Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
393   Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
394   builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
395 
396   builder.setInsertionPointAfter(whileOp);
397   builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
398 }
399 
400 /// Creates code to advance i in a loop based on xs[p] as follows:
401 ///   while (xs[i] < xs[p]) i += step (step > 0)
402 /// or
403 ///   while (xs[i] > xs[p]) i += step (step < 0)
404 /// The routine returns i as well as a boolean value to indicate whether
405 /// xs[i] == xs[p].
406 static std::pair<Value, Value>
407 createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
408                ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
409                bool isCoo, int step) {
410   Location loc = func.getLoc();
411   scf::WhileOp whileOp =
412       builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
413 
414   Block *before =
415       builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
416   builder.setInsertionPointToEnd(before);
417   SmallVector<Value> compareOperands;
418   if (step > 0) {
419     compareOperands.push_back(before->getArgument(0));
420     compareOperands.push_back(p);
421   } else {
422     assert(step < 0);
423     compareOperands.push_back(p);
424     compareOperands.push_back(before->getArgument(0));
425   }
426   compareOperands.append(xs.begin(), xs.end());
427   MLIRContext *context = module.getContext();
428   Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
429   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
430       builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
431       compareOperands, createLessThanFunc);
432   Value cond = builder
433                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
434                                          compareOperands)
435                    .getResult(0);
436   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
437 
438   Block *after =
439       builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
440   builder.setInsertionPointToEnd(after);
441   Value cs = constantIndex(builder, loc, step);
442   i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
443   builder.create<scf::YieldOp>(loc, ValueRange{i});
444   i = whileOp.getResult(0);
445 
446   builder.setInsertionPointAfter(whileOp);
447   compareOperands[0] = i;
448   compareOperands[1] = p;
449   FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
450       builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo,
451       compareOperands, createEqCompareFunc);
452   Value compareEq =
453       builder
454           .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
455                                 compareOperands)
456           .getResult(0);
457 
458   return std::make_pair(whileOp.getResult(0), compareEq);
459 }
460 
461 /// Creates a code block to swap the values so that data[mi] is the median among
462 /// data[lo], data[hi], and data[mi].
463 //  The generated code corresponds to this C-like algorithm:
464 //  median = mi
465 //  if (data[mi] < data[lo]).                               (if1)
466 //    if (data[hi] < data[lo])                              (if2)
467 //       median = data[hi] < data[mi] ? mi : hi
468 //    else
469 //       median = lo
470 //  else
471 //    if data[hi] < data[mi]                                (if3)
472 //      median = data[hi] < data[lo] ? lo : hi
473 //  if median != mi swap data[median] with data[mi]
474 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
475                               func::FuncOp func, uint64_t nx, uint64_t ny,
476                               bool isCoo, Value lo, Value hi, Value mi,
477                               ValueRange args) {
478   SmallVector<Value> compareOperands{mi, lo};
479   uint64_t numXBuffers = isCoo ? 1 : nx;
480   compareOperands.append(args.begin() + xStartIdx,
481                          args.begin() + xStartIdx + numXBuffers);
482   Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
483   SmallVector<Type, 1> cmpTypes{i1Type};
484   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
485       builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo,
486       compareOperands, createLessThanFunc);
487   Location loc = func.getLoc();
488   // Compare data[mi] < data[lo].
489   Value cond1 =
490       builder.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
491           .getResult(0);
492   SmallVector<Type, 1> ifTypes{lo.getType()};
493   scf::IfOp ifOp1 =
494       builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
495 
496   // Generate an if-stmt to find the median value, assuming we already know that
497   // data[b] < data[a] and we haven't compare data[c] yet.
498   auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp {
499     compareOperands[0] = c;
500     compareOperands[1] = a;
501     // Compare data[c]] < data[a].
502     Value cond2 =
503         builder
504             .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
505             .getResult(0);
506     scf::IfOp ifOp2 =
507         builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
508     builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
509     compareOperands[0] = c;
510     compareOperands[1] = b;
511     // Compare data[c] < data[b].
512     Value cond3 =
513         builder
514             .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
515             .getResult(0);
516     builder.create<scf::YieldOp>(
517         loc, ValueRange{builder.create<arith::SelectOp>(loc, cond3, b, c)});
518     builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
519     builder.create<scf::YieldOp>(loc, ValueRange{a});
520     return ifOp2;
521   };
522 
523   builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
524   scf::IfOp ifOp2 = createFindMedian(lo, mi, hi);
525   builder.setInsertionPointAfter(ifOp2);
526   builder.create<scf::YieldOp>(loc, ValueRange{ifOp2.getResult(0)});
527 
528   builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
529   scf::IfOp ifOp3 = createFindMedian(mi, lo, hi);
530 
531   builder.setInsertionPointAfter(ifOp3);
532   builder.create<scf::YieldOp>(loc, ValueRange{ifOp3.getResult(0)});
533 
534   builder.setInsertionPointAfter(ifOp1);
535   Value median = ifOp1.getResult(0);
536   Value cond =
537       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, mi, median);
538   scf::IfOp ifOp =
539       builder.create<scf::IfOp>(loc, TypeRange(), cond, /*else=*/false);
540 
541   SmallVector<Value> swapOperands{median, mi};
542   swapOperands.append(args.begin() + xStartIdx, args.end());
543   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
544   createSwap(builder, loc, swapOperands, nx, ny, isCoo);
545   builder.setInsertionPointAfter(ifOp);
546 }
547 
548 /// Creates a function to perform quick sort partition on the values in the
549 /// range of index [lo, hi), assuming lo < hi.
550 //
551 // The generated IR corresponds to this C like algorithm:
552 // int partition(lo, hi, xs) {
553 //   p = (lo+hi)/2  // pivot index
554 //   i = lo
555 //   j = hi-1
556 //   while (i < j) do {
557 //     while (xs[i] < xs[p]) i ++;
558 //     i_eq = (xs[i] == xs[p]);
559 //     while (xs[j] > xs[p]) j --;
560 //     j_eq = (xs[j] == xs[p]);
561 //     if (i < j) {
562 //       swap(xs[i], xs[j])
563 //       if (i == p) {
564 //         p = j;
565 //       } else if (j == p) {
566 //         p = i;
567 //       }
568 //       if (i_eq && j_eq) {
569 //         ++i;
570 //         --j;
571 //       }
572 //     }
573 //   }
574 //   return p
575 //   }
576 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
577                                 func::FuncOp func, uint64_t nx, uint64_t ny,
578                                 bool isCoo, uint32_t nTrailingP = 0) {
579   // Quick sort partition doesn't use trailing parameters.
580   (void)nTrailingP;
581   assert(nTrailingP == 0);
582   OpBuilder::InsertionGuard insertionGuard(builder);
583 
584   Block *entryBlock = func.addEntryBlock();
585   builder.setInsertionPointToStart(entryBlock);
586 
587   Location loc = func.getLoc();
588   ValueRange args = entryBlock->getArguments();
589   Value lo = args[loIdx];
590   Value hi = args[hiIdx];
591   Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
592   Value c1 = constantIndex(builder, loc, 1);
593   Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
594 
595   Value i = lo;
596   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
597   createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
598   SmallVector<Value, 3> operands{i, j, p}; // Exactly three values.
599   SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
600   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
601 
602   // The before-region of the WhileOp.
603   Block *before =
604       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
605   builder.setInsertionPointToEnd(before);
606   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
607                                              before->getArgument(0),
608                                              before->getArgument(1));
609   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
610 
611   // The after-region of the WhileOp.
612   Block *after =
613       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
614   builder.setInsertionPointToEnd(after);
615   i = after->getArgument(0);
616   j = after->getArgument(1);
617   p = after->getArgument(2);
618 
619   uint64_t numXBuffers = isCoo ? 1 : nx;
620   auto [iresult, iCompareEq] =
621       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
622                      i, p, nx, ny, isCoo, 1);
623   i = iresult;
624   auto [jresult, jCompareEq] =
625       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
626                      j, p, nx, ny, isCoo, -1);
627   j = jresult;
628 
629   // If i < j:
630   cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
631   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
632   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
633   SmallVector<Value> swapOperands{i, j};
634   swapOperands.append(args.begin() + xStartIdx, args.end());
635   createSwap(builder, loc, swapOperands, nx, ny, isCoo);
636   // If the pivot is moved, update p with the new pivot.
637   Value icond =
638       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
639   scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
640                                               icond, /*else=*/true);
641   builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
642   builder.create<scf::YieldOp>(loc, ValueRange{j});
643   builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
644   Value jcond =
645       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
646   scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
647                                               jcond, /*else=*/true);
648   builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
649   builder.create<scf::YieldOp>(loc, ValueRange{i});
650   builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
651   builder.create<scf::YieldOp>(loc, ValueRange{p});
652   builder.setInsertionPointAfter(ifOpJ);
653   builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
654   builder.setInsertionPointAfter(ifOpI);
655   Value compareEqIJ =
656       builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
657   scf::IfOp ifOp2 = builder.create<scf::IfOp>(
658       loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
659   builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
660   Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
661   Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
662   builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
663   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
664   builder.create<scf::YieldOp>(loc, ValueRange{i, j});
665   builder.setInsertionPointAfter(ifOp2);
666   builder.create<scf::YieldOp>(
667       loc,
668       ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});
669 
670   // False branch for if i < j:
671   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
672   builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});
673 
674   // Return for the whileOp.
675   builder.setInsertionPointAfter(ifOp);
676   builder.create<scf::YieldOp>(loc, ifOp.getResults());
677 
678   // Return for the function.
679   builder.setInsertionPointAfter(whileOp);
680   builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
681 }
682 
683 /// Creates a function to perform quick sort on the value in the range of
684 /// index [lo, hi).
685 //
686 // The generate IR corresponds to this C like algorithm:
687 // void quickSort(lo, hi, data) {
688 //   if (lo < hi) {
689 //        p = partition(low, high, data);
690 //        quickSort(lo, p, data);
691 //        quickSort(p + 1, hi, data);
692 //   }
693 // }
694 static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
695                                     func::FuncOp func, uint64_t nx, uint64_t ny,
696                                     bool isCoo, uint32_t nTrailingP) {
697   (void)nTrailingP;
698   OpBuilder::InsertionGuard insertionGuard(builder);
699   Block *entryBlock = func.addEntryBlock();
700   builder.setInsertionPointToStart(entryBlock);
701 
702   MLIRContext *context = module.getContext();
703   Location loc = func.getLoc();
704   ValueRange args = entryBlock->getArguments();
705   Value lo = args[loIdx];
706   Value hi = args[hiIdx];
707   Value cond =
708       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
709   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
710 
711   // The if-stmt true branch.
712   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
713   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
714       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
715       ny, isCoo, args, createPartitionFunc);
716   auto p = builder.create<func::CallOp>(
717       loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
718 
719   SmallVector<Value> lowOperands{lo, p.getResult(0)};
720   lowOperands.append(args.begin() + xStartIdx, args.end());
721   builder.create<func::CallOp>(loc, func, lowOperands);
722 
723   SmallVector<Value> highOperands{
724       builder.create<arith::AddIOp>(loc, p.getResult(0),
725                                     constantIndex(builder, loc, 1)),
726       hi};
727   highOperands.append(args.begin() + xStartIdx, args.end());
728   builder.create<func::CallOp>(loc, func, highOperands);
729 
730   // After the if-stmt.
731   builder.setInsertionPointAfter(ifOp);
732   builder.create<func::ReturnOp>(loc);
733 }
734 
735 /// Creates a function to perform insertion sort on the values in the range of
736 /// index [lo, hi).
737 //
738 // The generate IR corresponds to this C like algorithm:
739 // void insertionSort(lo, hi, data) {
740 //   for (i = lo+1; i < hi; i++) {
741 //      d = data[i];
742 //      p = binarySearch(lo, i-1, data)
743 //      for (j = 0; j > i - p; j++)
744 //        data[i-j] = data[i-j-1]
745 //      data[p] = d
746 //   }
747 // }
748 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
749                                  func::FuncOp func, uint64_t nx, uint64_t ny,
750                                  bool isCoo, uint32_t nTrailingP) {
751   // Stable sort function doesn't use trailing parameters.
752   (void)nTrailingP;
753   assert(nTrailingP == 0);
754   OpBuilder::InsertionGuard insertionGuard(builder);
755   Block *entryBlock = func.addEntryBlock();
756   builder.setInsertionPointToStart(entryBlock);
757 
758   MLIRContext *context = module.getContext();
759   Location loc = func.getLoc();
760   ValueRange args = entryBlock->getArguments();
761   Value c1 = constantIndex(builder, loc, 1);
762   Value lo = args[loIdx];
763   Value hi = args[hiIdx];
764   Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
765 
766   // Start the outer for-stmt with induction variable i.
767   scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
768   builder.setInsertionPointToStart(forOpI.getBody());
769   Value i = forOpI.getInductionVar();
770 
771   // Binary search to find the insertion point p.
772   SmallVector<Value> operands{lo, i};
773   operands.append(args.begin() + xStartIdx, args.end());
774   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
775       builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
776       ny, isCoo, operands, createBinarySearchFunc);
777   Value p = builder
778                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
779                                       operands)
780                 .getResult(0);
781 
782   // Move the value at data[i] to a temporary location.
783   operands[0] = operands[1] = i;
784   SmallVector<Value> d;
785   forEachIJPairInAllBuffers(
786       builder, loc, operands, nx, ny, isCoo,
787       [&](uint64_t unused, Value i, Value unused2, Value buffer) {
788         d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
789       });
790 
791   // Start the inner for-stmt with induction variable j, for moving data[p..i)
792   // to data[p+1..i+1).
793   Value imp = builder.create<arith::SubIOp>(loc, i, p);
794   Value c0 = constantIndex(builder, loc, 0);
795   scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
796   builder.setInsertionPointToStart(forOpJ.getBody());
797   Value j = forOpJ.getInductionVar();
798   Value imj = builder.create<arith::SubIOp>(loc, i, j);
799   operands[1] = imj;
800   operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
801   forEachIJPairInAllBuffers(
802       builder, loc, operands, nx, ny, isCoo,
803       [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
804         Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
805         builder.create<memref::StoreOp>(loc, t, buffer, imj);
806       });
807 
808   // Store the value at data[i] to data[p].
809   builder.setInsertionPointAfter(forOpJ);
810   operands[0] = operands[1] = p;
811   forEachIJPairInAllBuffers(
812       builder, loc, operands, nx, ny, isCoo,
813       [&](uint64_t k, Value p, Value usused, Value buffer) {
814         builder.create<memref::StoreOp>(loc, d[k], buffer, p);
815       });
816 
817   builder.setInsertionPointAfter(forOpI);
818   builder.create<func::ReturnOp>(loc);
819 }
820 
821 /// Implements the rewriting for operator sort and sort_coo.
822 template <typename OpTy>
823 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
824                                     uint64_t ny, bool isCoo,
825                                     PatternRewriter &rewriter) {
826   Location loc = op.getLoc();
827   SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
828 
829   // Convert `values` to have dynamic shape and append them to `operands`.
830   for (Value v : xys) {
831     auto mtp = getMemRefType(v);
832     if (!mtp.isDynamicDim(0)) {
833       auto newMtp =
834           MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
835       v = rewriter.create<memref::CastOp>(loc, newMtp, v);
836     }
837     operands.push_back(v);
838   }
839   bool isStable =
840       (op.getAlgorithm() == SparseTensorSortKind::InsertionSortStable);
841   auto insertPoint = op->template getParentOfType<func::FuncOp>();
842   SmallString<32> funcName(isStable ? kSortStableFuncNamePrefix
843                                     : kSortNonstableFuncNamePrefix);
844   FuncGeneratorType funcGenerator =
845       isStable ? createSortStableFunc : createSortNonstableFunc;
846   uint32_t nTrailingP = 0;
847   FlatSymbolRefAttr func =
848       getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
849                                ny, isCoo, operands, funcGenerator, nTrailingP);
850   rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
851   return success();
852 }
853 
854 //===---------------------------------------------------------------------===//
855 // The actual sparse buffer rewriting rules.
856 //===---------------------------------------------------------------------===//
857 
858 namespace {
859 
/// Sparse rewriting rule for the push_back operator.
///
/// Lowers the op into memref IR. The IR appends `n` copies of `value` (a
/// single copy when `n` is absent) to `inBuffer`, starting at index `curSize`.
/// When needed, it first grows the buffer with memref.realloc. The op is
/// replaced by the (possibly reallocated) buffer and the new size.
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
  using OpRewritePattern<PushBackOp>::OpRewritePattern;
  /// When `enableInit` is set, the tail [new_size, new_capacity) of a
  /// reallocated buffer is zero-filled with linalg.fill.
  PushBackRewriter(MLIRContext *context, bool enableInit)
      : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
  LogicalResult matchAndRewrite(PushBackOp op,
                                PatternRewriter &rewriter) const override {
    // Rewrite push_back(buffer, value, n) to:
    // new_size = size(buffer) + n
    // if (new_size > capacity(buffer))
    //    new_capacity = capacity(buffer)
    //    do
    //      new_capacity = new_capacity*2
    //    while (new_size > new_capacity)     // a single doubling when n is 1
    //    new_buffer = realloc(buffer, new_capacity)
    //    if (enableBufferInitialization)
    //      fill new_buffer[new_size, new_capacity) with zero
    //    buffer = new_buffer
    // subBuffer = subview of buffer[size, size + n)
    // linalg.fill subBuffer value           // a memref.store when n is 1
    //
    // size(buffer) += n
    //
    // The capacity check is skipped when the `inbounds` attribute is present.
    Location loc = op->getLoc();
    Value c0 = constantIndex(rewriter, loc, 0);
    Value buffer = op.getInBuffer();
    // The capacity is the extent of dimension 0 of the buffer.
    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
    Value size = op.getCurSize();
    Value value = op.getValue();

    Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
    Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
    // Detect a constant n == 1 so the common single-element case can use one
    // doubling and a plain memref.store instead of a loop and linalg.fill.
    auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
    bool nIsOne = (nValue && nValue.value() == 1);

    if (!op.getInbounds()) {
      Value cond = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ugt, newSize, capacity);

      Value c2 = constantIndex(rewriter, loc, 2);
      auto bufferType =
          MemRefType::get({ShapedType::kDynamic}, value.getType());
      // The scf.if yields the buffer to use: reallocated or unchanged.
      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
                                                  /*else=*/true);
      // True branch.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      // NOTE(review): growth is by doubling only, so a buffer whose capacity
      // is 0 never grows. Consequences:
      //   - the nIsOne path reallocates to 0 elements and then stores past
      //     the end;
      //   - the do-while below never terminates.
      // Presumably callers guarantee a non-zero capacity; confirm.
      if (nIsOne) {
        // A single doubling is emitted. This presumably suffices because
        // size <= capacity implies size + 1 <= 2 * capacity for capacity >= 1.
        capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
      } else {
        // Use a do-while loop to calculate the new capacity as follows:
        //   do { new_capacity *= 2 } while (new_size > new_capacity)
        scf::WhileOp whileOp =
            rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);

        // The before-region of the WhileOp.
        // It doubles the carried capacity and tests whether it is still too
        // small.
        Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
                                             {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(before);

        capacity =
            rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
        cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                              newSize, capacity);
        rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
        // The after-region of the WhileOp.
        // It only forwards the doubled capacity back to the before-region.
        Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
                                            {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(after);
        rewriter.create<scf::YieldOp>(loc, after->getArguments());

        rewriter.setInsertionPointAfter(whileOp);
        capacity = whileOp.getResult(0);
      }

      Value newBuffer =
          rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
      if (enableBufferInitialization) {
        // Zero-fill the unused tail [newSize, capacity) of the new buffer.
        Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
        Value fillValue = constantZero(rewriter, loc, value.getType());
        Value subBuffer = rewriter.create<memref::SubViewOp>(
            loc, newBuffer, /*offset=*/ValueRange{newSize},
            /*size=*/ValueRange{fillSize},
            /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
        rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
      }
      rewriter.create<scf::YieldOp>(loc, newBuffer);

      // False branch.
      // The existing buffer is large enough and is yielded unchanged.
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      rewriter.create<scf::YieldOp>(loc, buffer);

      // Prepare for adding the value to the end of the buffer.
      rewriter.setInsertionPointAfter(ifOp);
      buffer = ifOp.getResult(0);
    }

    // Add the value to the end of the buffer.
    // A single store for n == 1; otherwise fill buffer[size, size + n).
    if (nIsOne) {
      rewriter.create<memref::StoreOp>(loc, value, buffer, size);
    } else {
      Value subBuffer = rewriter.create<memref::SubViewOp>(
          loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
          /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
      rewriter.create<linalg::FillOp>(loc, value, subBuffer);
    }

    // Update the buffer size.
    // The op's results are replaced by the buffer and the new size.
    rewriter.replaceOp(op, {buffer, newSize});
    return success();
  }

private:
  /// Whether reallocated buffers get their tail zero-initialized.
  bool enableBufferInitialization;
};
972 
973 /// Sparse rewriting rule for the sort operator.
974 struct SortRewriter : public OpRewritePattern<SortOp> {
975 public:
976   using OpRewritePattern<SortOp>::OpRewritePattern;
977 
978   LogicalResult matchAndRewrite(SortOp op,
979                                 PatternRewriter &rewriter) const override {
980     SmallVector<Value> xys(op.getXs());
981     xys.append(op.getYs().begin(), op.getYs().end());
982     return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
983                                  /*isCoo=*/false, rewriter);
984   }
985 };
986 
987 /// Sparse rewriting rule for the sort_coo operator.
988 struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
989 public:
990   using OpRewritePattern<SortCooOp>::OpRewritePattern;
991 
992   LogicalResult matchAndRewrite(SortCooOp op,
993                                 PatternRewriter &rewriter) const override {
994     SmallVector<Value> xys;
995     xys.push_back(op.getXy());
996     xys.append(op.getYs().begin(), op.getYs().end());
997     uint64_t nx = 1;
998     if (auto nxAttr = op.getNxAttr())
999       nx = nxAttr.getInt();
1000 
1001     uint64_t ny = 0;
1002     if (auto nyAttr = op.getNyAttr())
1003       ny = nyAttr.getInt();
1004 
1005     return matchAndRewriteSortOp(op, xys, nx, ny,
1006                                  /*isCoo=*/true, rewriter);
1007   }
1008 };
1009 
1010 } // namespace
1011 
1012 //===---------------------------------------------------------------------===//
1013 // Methods that add patterns described in this file to a pattern list.
1014 //===---------------------------------------------------------------------===//
1015 
1016 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
1017                                          bool enableBufferInitialization) {
1018   patterns.add<PushBackRewriter>(patterns.getContext(),
1019                                  enableBufferInitialization);
1020   patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
1021 }
1022