xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp (revision ea4be70cea8509520db8638bb17bcd7b5d8d60ac)
1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensor
10 // primitives with memref operands.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodegenUtils.h"
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
22 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
23 #include "mlir/Support/LLVM.h"
24 
25 using namespace mlir;
26 using namespace mlir::sparse_tensor;
27 
28 //===---------------------------------------------------------------------===//
29 // Helper methods for the actual rewriting rules.
30 //===---------------------------------------------------------------------===//
31 
/// Positions of the leading operands of the generated sort helper functions:
/// the two index operands come first, followed by the buffers.
static constexpr uint64_t loIdx{0};
static constexpr uint64_t hiIdx{1};
static constexpr uint64_t xStartIdx{2};

/// Prefixes of the mangled names of the generated helper functions.
static constexpr const char kLessThanFuncNamePrefix[]{"_sparse_less_than_"};
static constexpr const char kCompareEqFuncNamePrefix[]{"_sparse_compare_eq_"};
static constexpr const char kPartitionFuncNamePrefix[]{"_sparse_partition_"};
static constexpr const char kBinarySearchFuncNamePrefix[]{
    "_sparse_binary_search_"};
static constexpr const char kSortNonstableFuncNamePrefix[]{
    "_sparse_sort_nonstable_"};
static constexpr const char kSortStableFuncNamePrefix[]{
    "_sparse_sort_stable_"};
46 using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
47                                             uint64_t, uint64_t, bool)>;
48 
49 /// Constructs a function name with this format to facilitate quick sort:
50 ///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
51 ///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
52 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
53                                          StringRef namePrefix, uint64_t nx,
54                                          uint64_t ny, bool isCoo,
55                                          ValueRange operands) {
56   nameOstream
57       << namePrefix << nx << "_"
58       << operands[xStartIdx].getType().cast<MemRefType>().getElementType();
59 
60   if (isCoo)
61     nameOstream << "_coo_" << ny;
62 
63   uint64_t yBufferOffset = isCoo ? 1 : nx;
64   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
65     nameOstream << "_" << v.getType().cast<MemRefType>().getElementType();
66 }
67 
68 /// Looks up a function that is appropriate for the given operands being
69 /// sorted, and creates such a function if it doesn't exist yet. The
70 /// parameters `nx` and `ny` tell the number of x and y values provided
71 /// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
72 /// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
73 static FlatSymbolRefAttr
74 getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
75                          TypeRange resultTypes, StringRef namePrefix,
76                          uint64_t nx, uint64_t ny, bool isCoo,
77                          ValueRange operands, FuncGeneratorType createFunc) {
78   SmallString<32> nameBuffer;
79   llvm::raw_svector_ostream nameOstream(nameBuffer);
80   getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
81                                operands);
82 
83   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
84   MLIRContext *context = module.getContext();
85   auto result = SymbolRefAttr::get(context, nameOstream.str());
86   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
87 
88   if (!func) {
89     // Create the function.
90     OpBuilder::InsertionGuard insertionGuard(builder);
91     builder.setInsertionPoint(insertPoint);
92     Location loc = insertPoint.getLoc();
93     func = builder.create<func::FuncOp>(
94         loc, nameOstream.str(),
95         FunctionType::get(context, operands.getTypes(), resultTypes));
96     func.setPrivate();
97     createFunc(builder, module, func, nx, ny, isCoo);
98   }
99 
100   return result;
101 }
102 
103 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
104 /// The code to process the value pairs is generated by `bodyBuilder`.
105 static void forEachIJPairInXs(
106     OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
107     bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
108   Value iOffset, jOffset;
109   if (isCoo) {
110     Value cstep = constantIndex(builder, loc, nx + ny);
111     iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
112     jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
113   }
114   for (uint64_t k = 0; k < nx; k++) {
115     scf::IfOp ifOp;
116     Value i, j, buffer;
117     if (isCoo) {
118       Value ck = constantIndex(builder, loc, k);
119       i = builder.create<arith::AddIOp>(loc, ck, iOffset);
120       j = builder.create<arith::AddIOp>(loc, ck, jOffset);
121       buffer = args[xStartIdx];
122     } else {
123       i = args[0];
124       j = args[1];
125       buffer = args[xStartIdx + k];
126     }
127     bodyBuilder(k, i, j, buffer);
128   }
129 }
130 
131 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
132 /// The code to process the value pairs is generated by `bodyBuilder`.
133 static void forEachIJPairInAllBuffers(
134     OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
135     bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
136 
137   // Create code for the first (nx + ny) buffers. When isCoo==true, these
138   // logical buffers are all from the xy buffer of the sort_coo operator.
139   forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
140 
141   uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
142 
143   // Create code for the remaining buffers.
144   Value i = args[0];
145   Value j = args[1];
146   for (const auto &arg :
147        llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
148     bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
149   }
150 }
151 
152 /// Creates a code block for swapping the values in index i and j for all the
153 /// buffers.
154 //
155 // The generated IR corresponds to this C like algorithm:
156 //     swap(x0[i], x0[j]);
157 //     swap(x1[i], x1[j]);
158 //     ...
159 //     swap(xn[i], xn[j]);
160 //     swap(y0[i], y0[j]);
161 //     ...
162 //     swap(yn[i], yn[j]);
163 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
164                        uint64_t nx, uint64_t ny, bool isCoo) {
165   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
166     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
167     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
168     builder.create<memref::StoreOp>(loc, vj, buffer, i);
169     builder.create<memref::StoreOp>(loc, vi, buffer, j);
170   };
171 
172   forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
173 }
174 
175 /// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
176 /// compare each pair is create via `compareBuilder`.
177 static void createCompareFuncImplementation(
178     OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx,
179     uint64_t ny, bool isCoo,
180     function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
181         compareBuilder) {
182   OpBuilder::InsertionGuard insertionGuard(builder);
183 
184   Block *entryBlock = func.addEntryBlock();
185   builder.setInsertionPointToStart(entryBlock);
186   Location loc = func.getLoc();
187   ValueRange args = entryBlock->getArguments();
188 
189   scf::IfOp topIfOp;
190   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
191     scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1));
192     if (k == 0) {
193       topIfOp = ifOp;
194     } else {
195       OpBuilder::InsertionGuard insertionGuard(builder);
196       builder.setInsertionPointAfter(ifOp);
197       builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
198     }
199   };
200 
201   forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
202 
203   builder.setInsertionPointAfter(topIfOp);
204   builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
205 }
206 
207 /// Generates an if-statement to compare whether x[i] is equal to x[j].
208 static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
209                                  Value j, Value x, bool isLastDim) {
210   Value f = constantI1(builder, loc, false);
211   Value t = constantI1(builder, loc, true);
212   Value vi = builder.create<memref::LoadOp>(loc, x, i);
213   Value vj = builder.create<memref::LoadOp>(loc, x, j);
214 
215   Value cond =
216       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
217   scf::IfOp ifOp =
218       builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
219 
220   // x[1] != x[j]:
221   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
222   builder.create<scf::YieldOp>(loc, f);
223 
224   // x[i] == x[j]:
225   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
226   if (isLastDim == 1) {
227     // Finish checking all dimensions.
228     builder.create<scf::YieldOp>(loc, t);
229   }
230 
231   return ifOp;
232 }
233 
234 /// Creates a function to compare whether xs[i] is equal to xs[j].
235 //
236 // The generate IR corresponds to this C like algorithm:
237 //   if (x0[i] != x0[j])
238 //     return false;
239 //   else
240 //     if (x1[i] != x1[j])
241 //       return false;
242 //     else if (x2[2] != x2[j]))
243 //       and so on ...
244 static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
245                                 func::FuncOp func, uint64_t nx, uint64_t ny,
246                                 bool isCoo) {
247   createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
248                                   createEqCompare);
249 }
250 
251 /// Generates an if-statement to compare whether x[i] is less than x[j].
252 static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
253                                        Value i, Value j, Value x,
254                                        bool isLastDim) {
255   Value f = constantI1(builder, loc, false);
256   Value t = constantI1(builder, loc, true);
257   Value vi = builder.create<memref::LoadOp>(loc, x, i);
258   Value vj = builder.create<memref::LoadOp>(loc, x, j);
259 
260   Value cond =
261       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
262   scf::IfOp ifOp =
263       builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
264   // If (x[i] < x[j]).
265   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
266   builder.create<scf::YieldOp>(loc, t);
267 
268   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
269   if (isLastDim == 1) {
270     // Finish checking all dimensions.
271     builder.create<scf::YieldOp>(loc, f);
272   } else {
273     cond =
274         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
275     scf::IfOp ifOp2 =
276         builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
277     // Otherwise if (x[j] < x[i]).
278     builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
279     builder.create<scf::YieldOp>(loc, f);
280 
281     // Otherwise check the remaining dimensions.
282     builder.setInsertionPointAfter(ifOp2);
283     builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
284     // Set up the insertion point for the nested if-stmt that checks the
285     // remaining dimensions.
286     builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
287   }
288 
289   return ifOp;
290 }
291 
292 /// Creates a function to compare whether xs[i] is less than xs[j].
293 //
294 // The generate IR corresponds to this C like algorithm:
295 //   if (x0[i] < x0[j])
296 //     return true;
297 //   else if (x0[j] < x0[i])
298 //     return false;
299 //   else
300 //     if (x1[i] < x1[j])
301 //       return true;
302 //     else if (x1[j] < x1[i]))
303 //       and so on ...
304 static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
305                                func::FuncOp func, uint64_t nx, uint64_t ny,
306                                bool isCoo) {
307   createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
308                                   createLessThanCompare);
309 }
310 
311 /// Creates a function to use a binary search to find the insertion point for
312 /// inserting xs[hi] to the sorted values xs[lo..hi).
313 //
314 // The generate IR corresponds to this C like algorithm:
315 //   p = hi
316 //   while (lo < hi)
317 //      mid = (lo + hi) >> 1
318 //      if (xs[p] < xs[mid])
319 //        hi = mid
320 //      else
321 //        lo = mid - 1
322 //   return lo;
323 //
324 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
325                                    func::FuncOp func, uint64_t nx, uint64_t ny,
326                                    bool isCoo) {
327   OpBuilder::InsertionGuard insertionGuard(builder);
328   Block *entryBlock = func.addEntryBlock();
329   builder.setInsertionPointToStart(entryBlock);
330 
331   Location loc = func.getLoc();
332   ValueRange args = entryBlock->getArguments();
333   Value p = args[hiIdx];
334   SmallVector<Type, 2> types(2, p.getType());  // only two
335   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
336       loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
337 
338   // The before-region of the WhileOp.
339   Block *before =
340       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
341   builder.setInsertionPointToEnd(before);
342   Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
343                                               before->getArgument(0),
344                                               before->getArgument(1));
345   builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
346 
347   // The after-region of the WhileOp.
348   Block *after =
349       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
350   builder.setInsertionPointToEnd(after);
351   Value lo = after->getArgument(0);
352   Value hi = after->getArgument(1);
353   // Compute mid = (lo + hi) >> 1.
354   Value c1 = constantIndex(builder, loc, 1);
355   Value mid = builder.create<arith::ShRUIOp>(
356       loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
357   Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
358 
359   // Compare xs[p] < xs[mid].
360   SmallVector<Value> compareOperands{p, mid};
361   uint64_t numXBuffers = isCoo ? 1 : nx;
362   compareOperands.append(args.begin() + xStartIdx,
363                          args.begin() + xStartIdx + numXBuffers);
364   Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
365   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
366       builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
367       compareOperands, createLessThanFunc);
368   Value cond2 = builder
369                     .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
370                                           compareOperands)
371                     .getResult(0);
372 
373   // Update lo and hi for the WhileOp as follows:
374   //   if (xs[p] < xs[mid]))
375   //     hi = mid;
376   //   else
377   //     lo = mid + 1;
378   Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
379   Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
380   builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
381 
382   builder.setInsertionPointAfter(whileOp);
383   builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
384 }
385 
386 /// Creates code to advance i in a loop based on xs[p] as follows:
387 ///   while (xs[i] < xs[p]) i += step (step > 0)
388 /// or
389 ///   while (xs[i] > xs[p]) i += step (step < 0)
390 /// The routine returns i as well as a boolean value to indicate whether
391 /// xs[i] == xs[p].
392 static std::pair<Value, Value>
393 createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
394                ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
395                bool isCoo, int step) {
396   Location loc = func.getLoc();
397   scf::WhileOp whileOp =
398       builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
399 
400   Block *before =
401       builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
402   builder.setInsertionPointToEnd(before);
403   SmallVector<Value> compareOperands;
404   if (step > 0) {
405     compareOperands.push_back(before->getArgument(0));
406     compareOperands.push_back(p);
407   } else {
408     assert(step < 0);
409     compareOperands.push_back(p);
410     compareOperands.push_back(before->getArgument(0));
411   }
412   compareOperands.append(xs.begin(), xs.end());
413   MLIRContext *context = module.getContext();
414   Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
415   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
416       builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
417       compareOperands, createLessThanFunc);
418   Value cond = builder
419                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
420                                          compareOperands)
421                    .getResult(0);
422   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
423 
424   Block *after =
425       builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
426   builder.setInsertionPointToEnd(after);
427   Value cs = constantIndex(builder, loc, step);
428   i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
429   builder.create<scf::YieldOp>(loc, ValueRange{i});
430   i = whileOp.getResult(0);
431 
432   builder.setInsertionPointAfter(whileOp);
433   compareOperands[0] = i;
434   compareOperands[1] = p;
435   FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
436       builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo,
437       compareOperands, createEqCompareFunc);
438   Value compareEq =
439       builder
440           .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
441                                 compareOperands)
442           .getResult(0);
443 
444   return std::make_pair(whileOp.getResult(0), compareEq);
445 }
446 
447 /// Creates a function to perform quick sort partition on the values in the
448 /// range of index [lo, hi), assuming lo < hi.
449 //
450 // The generated IR corresponds to this C like algorithm:
451 // int partition(lo, hi, xs) {
452 //   p = (lo+hi)/2  // pivot index
453 //   i = lo
454 //   j = hi-1
455 //   while (i < j) do {
456 //     while (xs[i] < xs[p]) i ++;
457 //     i_eq = (xs[i] == xs[p]);
458 //     while (xs[j] > xs[p]) j --;
459 //     j_eq = (xs[j] == xs[p]);
460 //     if (i < j) {
461 //       swap(xs[i], xs[j])
462 //       if (i == p) {
463 //         p = j;
464 //       } else if (j == p) {
465 //         p = i;
466 //       }
467 //       if (i_eq && j_eq) {
468 //         ++i;
469 //         --j;
470 //       }
471 //     }
472 //   }
473 //   return p
474 //   }
475 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
476                                 func::FuncOp func, uint64_t nx, uint64_t ny,
477                                 bool isCoo) {
478   OpBuilder::InsertionGuard insertionGuard(builder);
479 
480   Block *entryBlock = func.addEntryBlock();
481   builder.setInsertionPointToStart(entryBlock);
482 
483   Location loc = func.getLoc();
484   ValueRange args = entryBlock->getArguments();
485   Value lo = args[loIdx];
486   Value hi = args[hiIdx];
487   Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
488   Value c1 = constantIndex(builder, loc, 1);
489   Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
490 
491   Value i = lo;
492   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
493   SmallVector<Value, 3> operands{i, j, p};  // exactly three
494   SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
495   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
496 
497   // The before-region of the WhileOp.
498   Block *before =
499       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
500   builder.setInsertionPointToEnd(before);
501   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
502                                              before->getArgument(0),
503                                              before->getArgument(1));
504   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
505 
506   // The after-region of the WhileOp.
507   Block *after =
508       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
509   builder.setInsertionPointToEnd(after);
510   i = after->getArgument(0);
511   j = after->getArgument(1);
512   p = after->getArgument(2);
513 
514   uint64_t numXBuffers = isCoo ? 1 : nx;
515   auto [iresult, iCompareEq] =
516       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
517                      i, p, nx, ny, isCoo, 1);
518   i = iresult;
519   auto [jresult, jCompareEq] =
520       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
521                      j, p, nx, ny, isCoo, -1);
522   j = jresult;
523 
524   // If i < j:
525   cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
526   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
527   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
528   SmallVector<Value> swapOperands{i, j};
529   swapOperands.append(args.begin() + xStartIdx, args.end());
530   createSwap(builder, loc, swapOperands, nx, ny, isCoo);
531   // If the pivot is moved, update p with the new pivot.
532   Value icond =
533       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
534   scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
535                                               icond, /*else=*/true);
536   builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
537   builder.create<scf::YieldOp>(loc, ValueRange{j});
538   builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
539   Value jcond =
540       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
541   scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
542                                               jcond, /*else=*/true);
543   builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
544   builder.create<scf::YieldOp>(loc, ValueRange{i});
545   builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
546   builder.create<scf::YieldOp>(loc, ValueRange{p});
547   builder.setInsertionPointAfter(ifOpJ);
548   builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
549   builder.setInsertionPointAfter(ifOpI);
550   Value compareEqIJ =
551       builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
552   scf::IfOp ifOp2 = builder.create<scf::IfOp>(
553       loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
554   builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
555   Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
556   Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
557   builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
558   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
559   builder.create<scf::YieldOp>(loc, ValueRange{i, j});
560   builder.setInsertionPointAfter(ifOp2);
561   builder.create<scf::YieldOp>(
562       loc,
563       ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});
564 
565   // False branch for if i < j:
566   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
567   builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});
568 
569   // Return for the whileOp.
570   builder.setInsertionPointAfter(ifOp);
571   builder.create<scf::YieldOp>(loc, ifOp.getResults());
572 
573   // Return for the function.
574   builder.setInsertionPointAfter(whileOp);
575   builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
576 }
577 
578 /// Creates a function to perform quick sort on the value in the range of
579 /// index [lo, hi).
580 //
581 // The generate IR corresponds to this C like algorithm:
582 // void quickSort(lo, hi, data) {
583 //   if (lo < hi) {
584 //        p = partition(low, high, data);
585 //        quickSort(lo, p, data);
586 //        quickSort(p + 1, hi, data);
587 //   }
588 // }
589 static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
590                                     func::FuncOp func, uint64_t nx, uint64_t ny,
591                                     bool isCoo) {
592   OpBuilder::InsertionGuard insertionGuard(builder);
593   Block *entryBlock = func.addEntryBlock();
594   builder.setInsertionPointToStart(entryBlock);
595 
596   MLIRContext *context = module.getContext();
597   Location loc = func.getLoc();
598   ValueRange args = entryBlock->getArguments();
599   Value lo = args[loIdx];
600   Value hi = args[hiIdx];
601   Value cond =
602       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
603   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
604 
605   // The if-stmt true branch.
606   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
607   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
608       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
609       ny, isCoo, args, createPartitionFunc);
610   auto p = builder.create<func::CallOp>(
611       loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
612 
613   SmallVector<Value> lowOperands{lo, p.getResult(0)};
614   lowOperands.append(args.begin() + xStartIdx, args.end());
615   builder.create<func::CallOp>(loc, func, lowOperands);
616 
617   SmallVector<Value> highOperands{
618       builder.create<arith::AddIOp>(loc, p.getResult(0),
619                                     constantIndex(builder, loc, 1)),
620       hi};
621   highOperands.append(args.begin() + xStartIdx, args.end());
622   builder.create<func::CallOp>(loc, func, highOperands);
623 
624   // After the if-stmt.
625   builder.setInsertionPointAfter(ifOp);
626   builder.create<func::ReturnOp>(loc);
627 }
628 
629 /// Creates a function to perform insertion sort on the values in the range of
630 /// index [lo, hi).
631 //
632 // The generate IR corresponds to this C like algorithm:
633 // void insertionSort(lo, hi, data) {
634 //   for (i = lo+1; i < hi; i++) {
635 //      d = data[i];
636 //      p = binarySearch(lo, i-1, data)
637 //      for (j = 0; j > i - p; j++)
638 //        data[i-j] = data[i-j-1]
639 //      data[p] = d
640 //   }
641 // }
642 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
643                                  func::FuncOp func, uint64_t nx, uint64_t ny,
644                                  bool isCoo) {
645   OpBuilder::InsertionGuard insertionGuard(builder);
646   Block *entryBlock = func.addEntryBlock();
647   builder.setInsertionPointToStart(entryBlock);
648 
649   MLIRContext *context = module.getContext();
650   Location loc = func.getLoc();
651   ValueRange args = entryBlock->getArguments();
652   Value c1 = constantIndex(builder, loc, 1);
653   Value lo = args[loIdx];
654   Value hi = args[hiIdx];
655   Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
656 
657   // Start the outer for-stmt with induction variable i.
658   scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
659   builder.setInsertionPointToStart(forOpI.getBody());
660   Value i = forOpI.getInductionVar();
661 
662   // Binary search to find the insertion point p.
663   SmallVector<Value> operands{lo, i};
664   operands.append(args.begin() + xStartIdx, args.end());
665   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
666       builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
667       ny, isCoo, operands, createBinarySearchFunc);
668   Value p = builder
669                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
670                                       operands)
671                 .getResult(0);
672 
673   // Move the value at data[i] to a temporary location.
674   operands[0] = operands[1] = i;
675   SmallVector<Value> d;
676   forEachIJPairInAllBuffers(
677       builder, loc, operands, nx, ny, isCoo,
678       [&](uint64_t unused, Value i, Value unused2, Value buffer) {
679         d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
680       });
681 
682   // Start the inner for-stmt with induction variable j, for moving data[p..i)
683   // to data[p+1..i+1).
684   Value imp = builder.create<arith::SubIOp>(loc, i, p);
685   Value c0 = constantIndex(builder, loc, 0);
686   scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
687   builder.setInsertionPointToStart(forOpJ.getBody());
688   Value j = forOpJ.getInductionVar();
689   Value imj = builder.create<arith::SubIOp>(loc, i, j);
690   operands[1] = imj;
691   operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
692   forEachIJPairInAllBuffers(
693       builder, loc, operands, nx, ny, isCoo,
694       [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
695         Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
696         builder.create<memref::StoreOp>(loc, t, buffer, imj);
697       });
698 
699   // Store the value at data[i] to data[p].
700   builder.setInsertionPointAfter(forOpJ);
701   operands[0] = operands[1] = p;
702   forEachIJPairInAllBuffers(
703       builder, loc, operands, nx, ny, isCoo,
704       [&](uint64_t k, Value p, Value usused, Value buffer) {
705         builder.create<memref::StoreOp>(loc, d[k], buffer, p);
706       });
707 
708   builder.setInsertionPointAfter(forOpI);
709   builder.create<func::ReturnOp>(loc);
710 }
711 
712 /// Implements the rewriting for operator sort and sort_coo.
713 template <typename OpTy>
714 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
715                                     uint64_t ny, bool isCoo,
716                                     PatternRewriter &rewriter) {
717   Location loc = op.getLoc();
718   SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
719 
720   // Convert `values` to have dynamic shape and append them to `operands`.
721   for (Value v : xys) {
722     auto mtp = v.getType().cast<MemRefType>();
723     if (!mtp.isDynamicDim(0)) {
724       auto newMtp =
725           MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
726       v = rewriter.create<memref::CastOp>(loc, newMtp, v);
727     }
728     operands.push_back(v);
729   }
730   auto insertPoint = op->template getParentOfType<func::FuncOp>();
731   SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
732                                           : kSortNonstableFuncNamePrefix);
733   FuncGeneratorType funcGenerator =
734       op.getStable() ? createSortStableFunc : createSortNonstableFunc;
735   FlatSymbolRefAttr func =
736       getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
737                                ny, isCoo, operands, funcGenerator);
738   rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
739   return success();
740 }
741 
742 //===---------------------------------------------------------------------===//
743 // The actual sparse buffer rewriting rules.
744 //===---------------------------------------------------------------------===//
745 
746 namespace {
747 
/// Sparse rewriting rule for the push_back operator.
///
/// Lowers `push_back` into explicit memref operations: an optional capacity
/// check that reallocates the buffer with a doubled capacity, a store (or a
/// linalg.fill over a subview) of the pushed value, and an update of the size
/// stored in the `bufferSizes` memref at position `idx`. When the pattern is
/// constructed with `enableInit` set, the part of a reallocated buffer beyond
/// the new size is filled with zero.
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
  using OpRewritePattern<PushBackOp>::OpRewritePattern;
  /// Creates the pattern; `enableInit` selects whether the tail of a
  /// reallocated buffer is zero-filled.
  PushBackRewriter(MLIRContext *context, bool enableInit)
      : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
  /// Replaces `op` with the resulting buffer value (either the input buffer
  /// or the result of the capacity-check `scf.if`). Always returns success.
  LogicalResult matchAndRewrite(PushBackOp op,
                                PatternRewriter &rewriter) const override {
    // Rewrite push_back(buffer, value, n) to:
    //   new_size = size(buffer) + n
    //   if (new_size > capacity(buffer)) {
    //     do { new_capacity *= 2 } while (new_size > new_capacity)
    //     buffer = realloc(buffer, new_capacity)
    //   }
    //   fill buffer[size .. size + n) with value
    //   size(buffer) = new_size
    //
    // The capacity check is skipped when the attribute inbounds is present.
    Location loc = op->getLoc();
    Value c0 = constantIndex(rewriter, loc, 0);
    Value buffer = op.getInBuffer();
    // The capacity is the extent of dimension 0 of the input buffer.
    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
    // The current size is read from bufferSizes[idx].
    Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
    Value bufferSizes = op.getBufferSizes();
    Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
    Value value = op.getValue();

    // A missing `n` operand means one element is pushed.
    Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
    Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
    // `nIsOne` holds only when `n` is defined by a constant index op with
    // value 1; it selects the simpler single-element code paths below.
    auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
    bool nIsOne = (nValue && nValue.value() == 1);

    if (!op.getInbounds()) {
      // Unsigned comparison: grow only when the new size exceeds the capacity.
      Value cond = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ugt, newSize, capacity);

      Value c2 = constantIndex(rewriter, loc, 2);
      auto bufferType =
          MemRefType::get({ShapedType::kDynamic}, value.getType());
      // The scf.if yields the buffer to use afterwards: the reallocated one
      // from the then-branch or the original one from the else-branch.
      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
                                                  /*else=*/true);
      // True branch.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (nIsOne) {
        // Pushing a single element: emit exactly one doubling.
        capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
      } else {
        // Use a do-while loop to calculate the new capacity as follows:
        //   do { new_capacity *= 2 } while (new_size > new_capacity)
        scf::WhileOp whileOp =
            rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);

        // The before-region of the WhileOp. It doubles the incoming capacity
        // and continues while the new size still exceeds it.
        Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
                                             {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(before);

        capacity =
            rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
        cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                              newSize, capacity);
        rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
        // The after-region of the WhileOp. It only forwards its argument back
        // to the before-region.
        Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
                                            {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(after);
        rewriter.create<scf::YieldOp>(loc, after->getArguments());

        rewriter.setInsertionPointAfter(whileOp);
        capacity = whileOp.getResult(0);
      }

      Value newBuffer =
          rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
      if (enableBufferInitialization) {
        // Zero-fill newBuffer[new_size .. new_capacity).
        Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
        Value fillValue = constantZero(rewriter, loc, value.getType());
        Value subBuffer = rewriter.create<memref::SubViewOp>(
            loc, newBuffer, /*offset=*/ValueRange{newSize},
            /*size=*/ValueRange{fillSize},
            /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
        rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
      }
      rewriter.create<scf::YieldOp>(loc, newBuffer);

      // False branch.
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      rewriter.create<scf::YieldOp>(loc, buffer);

      // Prepare for adding the value to the end of the buffer.
      rewriter.setInsertionPointAfter(ifOp);
      buffer = ifOp.getResult(0);
    }

    // Add the value to the end of the buffer: a single store at the old size
    // when one element is pushed, otherwise a fill of buffer[size .. size + n).
    if (nIsOne) {
      rewriter.create<memref::StoreOp>(loc, value, buffer, size);
    } else {
      Value subBuffer = rewriter.create<memref::SubViewOp>(
          loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
          /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
      rewriter.create<linalg::FillOp>(loc, value, subBuffer);
    }

    // Update the buffer size.
    rewriter.create<memref::StoreOp>(loc, newSize, bufferSizes, idx);
    rewriter.replaceOp(op, buffer);
    return success();
  }

private:
  // Whether the tail of a reallocated buffer is filled with zero.
  bool enableBufferInitialization;
};
863 
864 /// Sparse rewriting rule for the sort operator.
865 struct SortRewriter : public OpRewritePattern<SortOp> {
866 public:
867   using OpRewritePattern<SortOp>::OpRewritePattern;
868 
869   LogicalResult matchAndRewrite(SortOp op,
870                                 PatternRewriter &rewriter) const override {
871     SmallVector<Value> xys(op.getXs());
872     xys.append(op.getYs().begin(), op.getYs().end());
873     return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
874                                  /*isCoo=*/false, rewriter);
875   }
876 };
877 
878 /// Sparse rewriting rule for the sort_coo operator.
879 struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
880 public:
881   using OpRewritePattern<SortCooOp>::OpRewritePattern;
882 
883   LogicalResult matchAndRewrite(SortCooOp op,
884                                 PatternRewriter &rewriter) const override {
885     SmallVector<Value> xys;
886     xys.push_back(op.getXy());
887     xys.append(op.getYs().begin(), op.getYs().end());
888     uint64_t nx = 1;
889     if (auto nxAttr = op.getNxAttr())
890       nx = nxAttr.getInt();
891 
892     uint64_t ny = 0;
893     if (auto nyAttr = op.getNyAttr())
894       ny = nyAttr.getInt();
895 
896     return matchAndRewriteSortOp(op, xys, nx, ny,
897                                  /*isCoo=*/true, rewriter);
898   }
899 };
900 
901 } // namespace
902 
903 //===---------------------------------------------------------------------===//
904 // Methods that add patterns described in this file to a pattern list.
905 //===---------------------------------------------------------------------===//
906 
907 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
908                                          bool enableBufferInitialization) {
909   patterns.add<PushBackRewriter>(patterns.getContext(),
910                                  enableBufferInitialization);
911   patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
912 }
913