xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp (revision 5618d2bea965e127f8d72a2e8fce1b444ce986bf)
1 //===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensor
10 // primitives with memref operands.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodegenUtils.h"
15 
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
22 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
23 #include "mlir/Support/LLVM.h"
24 
25 using namespace mlir;
26 using namespace mlir::sparse_tensor;
27 
28 //===---------------------------------------------------------------------===//
29 // Helper methods for the actual rewriting rules.
30 //===---------------------------------------------------------------------===//
31 
// Positions of the leading operands of the generated sort helper functions:
// two index operands (`lo`/`hi` for the sort, search and partition helpers)
// followed by the memref buffers, the first of which is at `xStartIdx`.
static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;

// Name prefixes of the generated helper functions. The remainder of each
// symbol name is produced by getMangledSortHelperFuncName.
static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
    "_sparse_binary_search_";
static constexpr const char kSortNonstableFuncNamePrefix[] =
    "_sparse_sort_nonstable_";
static constexpr const char kSortStableFuncNamePrefix[] =
    "_sparse_sort_stable_";

/// Callback type used to populate the body of a newly created helper
/// function. It receives the builder, the enclosing module, the (still empty)
/// function, and the number of `xs` buffers (`dim`).
using FuncGeneratorType =
    function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, size_t)>;
48 
49 /// Constructs a function name with this format to facilitate quick sort:
50 ///   <namePrefix><dim>_<x type>_<y0 type>..._<yn type>
51 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
52                                          StringRef namePrefix, size_t dim,
53                                          ValueRange operands) {
54   nameOstream
55       << namePrefix << dim << "_"
56       << operands[xStartIdx].getType().cast<MemRefType>().getElementType();
57 
58   for (Value v : operands.drop_front(xStartIdx + dim))
59     nameOstream << "_" << v.getType().cast<MemRefType>().getElementType();
60 }
61 
62 /// Looks up a function that is appropriate for the given operands being
63 /// sorted, and creates such a function if it doesn't exist yet.
64 static FlatSymbolRefAttr
65 getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
66                          TypeRange resultTypes, StringRef namePrefix,
67                          size_t dim, ValueRange operands,
68                          FuncGeneratorType createFunc) {
69   SmallString<32> nameBuffer;
70   llvm::raw_svector_ostream nameOstream(nameBuffer);
71   getMangledSortHelperFuncName(nameOstream, namePrefix, dim, operands);
72 
73   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
74   MLIRContext *context = module.getContext();
75   auto result = SymbolRefAttr::get(context, nameOstream.str());
76   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
77 
78   if (!func) {
79     // Create the function.
80     OpBuilder::InsertionGuard insertionGuard(builder);
81     builder.setInsertionPoint(insertPoint);
82     Location loc = insertPoint.getLoc();
83     func = builder.create<func::FuncOp>(
84         loc, nameOstream.str(),
85         FunctionType::get(context, operands.getTypes(), resultTypes));
86     func.setPrivate();
87     createFunc(builder, module, func, dim);
88   }
89 
90   return result;
91 }
92 
93 /// Creates a code block for swapping the values in index i and j for all the
94 /// buffers.
95 //
96 // The generated IR corresponds to this C like algorithm:
97 //     swap(x0[i], x0[j]);
98 //     swap(x1[i], x1[j]);
99 //     ...
100 //     swap(xn[i], xn[j]);
101 //     swap(y0[i], y0[j]);
102 //     ...
103 //     swap(yn[i], yn[j]);
104 static void createSwap(OpBuilder &builder, Location loc, ValueRange args) {
105   Value i = args[0];
106   Value j = args[1];
107   for (auto arg : args.drop_front(xStartIdx)) {
108     Value vi = builder.create<memref::LoadOp>(loc, arg, i);
109     Value vj = builder.create<memref::LoadOp>(loc, arg, j);
110     builder.create<memref::StoreOp>(loc, vj, arg, i);
111     builder.create<memref::StoreOp>(loc, vi, arg, j);
112   }
113 }
114 
115 /// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
116 /// compare each pair is create via `compareBuilder`.
117 static void createCompareFuncImplementation(
118     OpBuilder &builder, ModuleOp unused, func::FuncOp func, size_t dim,
119     function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
120         compareBuilder) {
121   OpBuilder::InsertionGuard insertionGuard(builder);
122 
123   Block *entryBlock = func.addEntryBlock();
124   builder.setInsertionPointToStart(entryBlock);
125   Location loc = func.getLoc();
126   ValueRange args = entryBlock->getArguments();
127 
128   scf::IfOp topIfOp;
129   for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) {
130     scf::IfOp ifOp = compareBuilder(builder, loc, args[0], args[1],
131                                     item.value(), (item.index() == dim - 1));
132     if (item.index() == 0) {
133       topIfOp = ifOp;
134     } else {
135       OpBuilder::InsertionGuard insertionGuard(builder);
136       builder.setInsertionPointAfter(ifOp);
137       builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
138     }
139   }
140 
141   builder.setInsertionPointAfter(topIfOp);
142   builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
143 }
144 
145 /// Generates an if-statement to compare whether x[i] is equal to x[j].
146 static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
147                                  Value j, Value x, bool isLastDim) {
148   Value f = constantI1(builder, loc, false);
149   Value t = constantI1(builder, loc, true);
150   Value vi = builder.create<memref::LoadOp>(loc, x, i);
151   Value vj = builder.create<memref::LoadOp>(loc, x, j);
152 
153   Value cond =
154       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
155   scf::IfOp ifOp =
156       builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
157 
158   // x[1] != x[j]:
159   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
160   builder.create<scf::YieldOp>(loc, f);
161 
162   // x[i] == x[j]:
163   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
164   if (isLastDim == 1) {
165     // Finish checking all dimensions.
166     builder.create<scf::YieldOp>(loc, t);
167   }
168 
169   return ifOp;
170 }
171 
172 /// Creates a function to compare whether xs[i] is equal to xs[j].
173 //
174 // The generate IR corresponds to this C like algorithm:
175 //   if (x0[i] != x0[j])
176 //     return false;
177 //   else
178 //     if (x1[i] != x1[j])
179 //       return false;
180 //     else if (x2[2] != x2[j]))
181 //       and so on ...
182 static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
183                                 func::FuncOp func, size_t dim) {
184   createCompareFuncImplementation(builder, unused, func, dim, createEqCompare);
185 }
186 
187 /// Generates an if-statement to compare whether x[i] is less than x[j].
188 static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
189                                        Value i, Value j, Value x,
190                                        bool isLastDim) {
191   Value f = constantI1(builder, loc, false);
192   Value t = constantI1(builder, loc, true);
193   Value vi = builder.create<memref::LoadOp>(loc, x, i);
194   Value vj = builder.create<memref::LoadOp>(loc, x, j);
195 
196   Value cond =
197       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
198   scf::IfOp ifOp =
199       builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
200   // If (x[i] < x[j]).
201   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
202   builder.create<scf::YieldOp>(loc, t);
203 
204   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
205   if (isLastDim == 1) {
206     // Finish checking all dimensions.
207     builder.create<scf::YieldOp>(loc, f);
208   } else {
209     cond =
210         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
211     scf::IfOp ifOp2 =
212         builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
213     // Otherwise if (x[j] < x[i]).
214     builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
215     builder.create<scf::YieldOp>(loc, f);
216 
217     // Otherwise check the remaining dimensions.
218     builder.setInsertionPointAfter(ifOp2);
219     builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
220     // Set up the insertion point for the nested if-stmt that checks the
221     // remaining dimensions.
222     builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
223   }
224 
225   return ifOp;
226 }
227 
228 /// Creates a function to compare whether xs[i] is less than xs[j].
229 //
230 // The generate IR corresponds to this C like algorithm:
231 //   if (x0[i] < x0[j])
232 //     return true;
233 //   else if (x0[j] < x0[i])
234 //     return false;
235 //   else
236 //     if (x1[i] < x1[j])
237 //       return true;
238 //     else if (x1[j] < x1[i]))
239 //       and so on ...
240 static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
241                                func::FuncOp func, size_t dim) {
242   createCompareFuncImplementation(builder, unused, func, dim,
243                                   createLessThanCompare);
244 }
245 
246 /// Creates a function to use a binary search to find the insertion point for
247 /// inserting xs[hi] to the sorted values xs[lo..hi).
248 //
249 // The generate IR corresponds to this C like algorithm:
250 //   p = hi
251 //   while (lo < hi)
252 //      mid = (lo + hi) >> 1
253 //      if (xs[p] < xs[mid])
254 //        hi = mid
255 //      else
256 //        lo = mid - 1
257 //   return lo;
258 //
259 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
260                                    func::FuncOp func, size_t dim) {
261   OpBuilder::InsertionGuard insertionGuard(builder);
262   Block *entryBlock = func.addEntryBlock();
263   builder.setInsertionPointToStart(entryBlock);
264 
265   Location loc = func.getLoc();
266   ValueRange args = entryBlock->getArguments();
267   Value p = args[hiIdx];
268   SmallVector<Type, 2> types(2, p.getType());
269   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
270       loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
271 
272   // The before-region of the WhileOp.
273   Block *before =
274       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
275   builder.setInsertionPointToEnd(before);
276   Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
277                                               before->getArgument(0),
278                                               before->getArgument(1));
279   builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
280 
281   // The after-region of the WhileOp.
282   Block *after =
283       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
284   builder.setInsertionPointToEnd(after);
285   Value lo = after->getArgument(0);
286   Value hi = after->getArgument(1);
287   // Compute mid = (lo + hi) >> 1.
288   Value c1 = constantIndex(builder, loc, 1);
289   Value mid = builder.create<arith::ShRUIOp>(
290       loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
291   Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
292 
293   // Compare xs[p] < xs[mid].
294   SmallVector<Value, 6> compareOperands{p, mid};
295   compareOperands.append(args.begin() + xStartIdx,
296                          args.begin() + xStartIdx + dim);
297   Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
298   FlatSymbolRefAttr lessThanFunc =
299       getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
300                                dim, compareOperands, createLessThanFunc);
301   Value cond2 = builder
302                     .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
303                                           compareOperands)
304                     .getResult(0);
305 
306   // Update lo and hi for the WhileOp as follows:
307   //   if (xs[p] < xs[mid]))
308   //     hi = mid;
309   //   else
310   //     lo = mid + 1;
311   Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
312   Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
313   builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
314 
315   builder.setInsertionPointAfter(whileOp);
316   builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
317 }
318 
319 /// Creates code to advance i in a loop based on xs[p] as follows:
320 ///   while (xs[i] < xs[p]) i += step (step > 0)
321 /// or
322 ///   while (xs[i] > xs[p]) i += step (step < 0)
323 /// The routine returns i as well as a boolean value to indicate whether
324 /// xs[i] == xs[p].
325 static std::pair<Value, Value>
326 createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
327                ValueRange xs, Value i, Value p, size_t dim, int step) {
328   Location loc = func.getLoc();
329   scf::WhileOp whileOp =
330       builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
331 
332   Block *before =
333       builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
334   builder.setInsertionPointToEnd(before);
335   SmallVector<Value, 6> compareOperands;
336   if (step > 0) {
337     compareOperands.push_back(before->getArgument(0));
338     compareOperands.push_back(p);
339   } else {
340     assert(step < 0);
341     compareOperands.push_back(p);
342     compareOperands.push_back(before->getArgument(0));
343   }
344   compareOperands.append(xs.begin(), xs.end());
345   MLIRContext *context = module.getContext();
346   Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
347   FlatSymbolRefAttr lessThanFunc =
348       getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
349                                dim, compareOperands, createLessThanFunc);
350   Value cond = builder
351                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
352                                          compareOperands)
353                    .getResult(0);
354   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
355 
356   Block *after =
357       builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
358   builder.setInsertionPointToEnd(after);
359   Value cs = constantIndex(builder, loc, step);
360   i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
361   builder.create<scf::YieldOp>(loc, ValueRange{i});
362   i = whileOp.getResult(0);
363 
364   builder.setInsertionPointAfter(whileOp);
365   compareOperands[0] = i;
366   compareOperands[1] = p;
367   FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
368       builder, func, {i1Type}, kCompareEqFuncNamePrefix, dim, compareOperands,
369       createEqCompareFunc);
370   Value compareEq =
371       builder
372           .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
373                                 compareOperands)
374           .getResult(0);
375 
376   return std::make_pair(whileOp.getResult(0), compareEq);
377 }
378 
379 /// Creates a function to perform quick sort partition on the values in the
380 /// range of index [lo, hi), assuming lo < hi.
381 //
382 // The generated IR corresponds to this C like algorithm:
383 // int partition(lo, hi, xs) {
384 //   p = (lo+hi)/2  // pivot index
385 //   i = lo
386 //   j = hi-1
387 //   while (i < j) do {
388 //     while (xs[i] < xs[p]) i ++;
389 //     i_eq = (xs[i] == xs[p]);
390 //     while (xs[j] > xs[p]) j --;
391 //     j_eq = (xs[j] == xs[p]);
392 //     if (i < j) {
393 //       swap(xs[i], xs[j])
394 //       if (i == p) {
395 //         p = j;
396 //       } else if (j == p) {
397 //         p = i;
398 //       }
399 //       if (i_eq && j_eq) {
400 //         ++i;
401 //         --j;
402 //       }
403 //     }
404 //   }
405 //   return p
406 //   }
407 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
408                                 func::FuncOp func, size_t dim) {
409   OpBuilder::InsertionGuard insertionGuard(builder);
410 
411   Block *entryBlock = func.addEntryBlock();
412   builder.setInsertionPointToStart(entryBlock);
413 
414   Location loc = func.getLoc();
415   ValueRange args = entryBlock->getArguments();
416   Value lo = args[loIdx];
417   Value hi = args[hiIdx];
418   Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
419   Value c1 = constantIndex(builder, loc, 1);
420   Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
421 
422   Value i = lo;
423   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
424   SmallVector<Value, 4> operands{i, j, p};
425   SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType()};
426   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
427 
428   // The before-region of the WhileOp.
429   Block *before =
430       builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
431   builder.setInsertionPointToEnd(before);
432   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
433                                              before->getArgument(0),
434                                              before->getArgument(1));
435   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
436 
437   // The after-region of the WhileOp.
438   Block *after =
439       builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
440   builder.setInsertionPointToEnd(after);
441   i = after->getArgument(0);
442   j = after->getArgument(1);
443   p = after->getArgument(2);
444 
445   auto [iresult, iCompareEq] = createScanLoop(
446       builder, module, func, args.slice(xStartIdx, dim), i, p, dim, 1);
447   i = iresult;
448   auto [jresult, jCompareEq] = createScanLoop(
449       builder, module, func, args.slice(xStartIdx, dim), j, p, dim, -1);
450   j = jresult;
451 
452   // If i < j:
453   cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
454   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
455   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
456   SmallVector<Value, 6> swapOperands{i, j};
457   swapOperands.append(args.begin() + xStartIdx, args.end());
458   createSwap(builder, loc, swapOperands);
459   // If the pivot is moved, update p with the new pivot.
460   Value icond =
461       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
462   scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
463                                               icond, /*else=*/true);
464   builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
465   builder.create<scf::YieldOp>(loc, ValueRange{j});
466   builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
467   Value jcond =
468       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
469   scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
470                                               jcond, /*else=*/true);
471   builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
472   builder.create<scf::YieldOp>(loc, ValueRange{i});
473   builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
474   builder.create<scf::YieldOp>(loc, ValueRange{p});
475   builder.setInsertionPointAfter(ifOpJ);
476   builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
477   builder.setInsertionPointAfter(ifOpI);
478   Value compareEqIJ =
479       builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
480   scf::IfOp ifOp2 = builder.create<scf::IfOp>(
481       loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
482   builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
483   Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
484   Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
485   builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
486   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
487   builder.create<scf::YieldOp>(loc, ValueRange{i, j});
488   builder.setInsertionPointAfter(ifOp2);
489   builder.create<scf::YieldOp>(
490       loc,
491       ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});
492 
493   // False branch for if i < j:
494   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
495   builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});
496 
497   // Return for the whileOp.
498   builder.setInsertionPointAfter(ifOp);
499   builder.create<scf::YieldOp>(loc, ifOp.getResults());
500 
501   // Return for the function.
502   builder.setInsertionPointAfter(whileOp);
503   builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
504 }
505 
506 /// Creates a function to perform quick sort on the value in the range of
507 /// index [lo, hi).
508 //
509 // The generate IR corresponds to this C like algorithm:
510 // void quickSort(lo, hi, data) {
511 //   if (lo < hi) {
512 //        p = partition(low, high, data);
513 //        quickSort(lo, p, data);
514 //        quickSort(p + 1, hi, data);
515 //   }
516 // }
517 static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
518                                     func::FuncOp func, size_t dim) {
519   OpBuilder::InsertionGuard insertionGuard(builder);
520   Block *entryBlock = func.addEntryBlock();
521   builder.setInsertionPointToStart(entryBlock);
522 
523   MLIRContext *context = module.getContext();
524   Location loc = func.getLoc();
525   ValueRange args = entryBlock->getArguments();
526   Value lo = args[loIdx];
527   Value hi = args[hiIdx];
528   Value cond =
529       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
530   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
531 
532   // The if-stmt true branch.
533   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
534   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
535       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, dim,
536       args, createPartitionFunc);
537   auto p = builder.create<func::CallOp>(
538       loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
539 
540   SmallVector<Value, 6> lowOperands{lo, p.getResult(0)};
541   lowOperands.append(args.begin() + xStartIdx, args.end());
542   builder.create<func::CallOp>(loc, func, lowOperands);
543 
544   SmallVector<Value, 6> highOperands{
545       builder.create<arith::AddIOp>(loc, p.getResult(0),
546                                     constantIndex(builder, loc, 1)),
547       hi};
548   highOperands.append(args.begin() + xStartIdx, args.end());
549   builder.create<func::CallOp>(loc, func, highOperands);
550 
551   // After the if-stmt.
552   builder.setInsertionPointAfter(ifOp);
553   builder.create<func::ReturnOp>(loc);
554 }
555 
556 /// Creates a function to perform insertion sort on the values in the range of
557 /// index [lo, hi).
558 //
559 // The generate IR corresponds to this C like algorithm:
560 // void insertionSort(lo, hi, data) {
561 //   for (i = lo+1; i < hi; i++) {
562 //      d = data[i];
563 //      p = binarySearch(lo, i-1, data)
564 //      for (j = 0; j > i - p; j++)
565 //        data[i-j] = data[i-j-1]
566 //      data[p] = d
567 //   }
568 // }
569 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
570                                  func::FuncOp func, size_t dim) {
571   OpBuilder::InsertionGuard insertionGuard(builder);
572   Block *entryBlock = func.addEntryBlock();
573   builder.setInsertionPointToStart(entryBlock);
574 
575   MLIRContext *context = module.getContext();
576   Location loc = func.getLoc();
577   ValueRange args = entryBlock->getArguments();
578   Value c1 = constantIndex(builder, loc, 1);
579   Value lo = args[loIdx];
580   Value hi = args[hiIdx];
581   Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
582 
583   // Start the outer for-stmt with induction variable i.
584   scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
585   builder.setInsertionPointToStart(forOpI.getBody());
586   Value i = forOpI.getInductionVar();
587 
588   // Binary search to find the insertion point p.
589   SmallVector<Value, 6> operands{lo, i};
590   operands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + dim);
591   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
592       builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
593       dim, operands, createBinarySearchFunc);
594   Value p = builder
595                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
596                                       operands)
597                 .getResult(0);
598 
599   // Move the value at data[i] to a temporary location.
600   ValueRange data = args.drop_front(xStartIdx);
601   SmallVector<Value, 6> d;
602   for (Value v : data)
603     d.push_back(builder.create<memref::LoadOp>(loc, v, i));
604 
605   // Start the inner for-stmt with induction variable j, for moving data[p..i)
606   // to data[p+1..i+1).
607   Value imp = builder.create<arith::SubIOp>(loc, i, p);
608   Value c0 = constantIndex(builder, loc, 0);
609   scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
610   builder.setInsertionPointToStart(forOpJ.getBody());
611   Value j = forOpJ.getInductionVar();
612   Value imj = builder.create<arith::SubIOp>(loc, i, j);
613   Value imjm1 = builder.create<arith::SubIOp>(loc, imj, c1);
614   for (Value v : data) {
615     Value t = builder.create<memref::LoadOp>(loc, v, imjm1);
616     builder.create<memref::StoreOp>(loc, t, v, imj);
617   }
618 
619   // Store the value at data[i] to data[p].
620   builder.setInsertionPointAfter(forOpJ);
621   for (auto it : llvm::zip(d, data))
622     builder.create<memref::StoreOp>(loc, std::get<0>(it), std::get<1>(it), p);
623 
624   builder.setInsertionPointAfter(forOpI);
625   builder.create<func::ReturnOp>(loc);
626 }
627 
628 //===---------------------------------------------------------------------===//
629 // The actual sparse buffer rewriting rules.
630 //===---------------------------------------------------------------------===//
631 
632 namespace {
633 
634 /// Sparse rewriting rule for the push_back operator.
635 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
636 public:
637   using OpRewritePattern<PushBackOp>::OpRewritePattern;
638   PushBackRewriter(MLIRContext *context, bool enableInit)
639       : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
640   LogicalResult matchAndRewrite(PushBackOp op,
641                                 PatternRewriter &rewriter) const override {
642     // Rewrite push_back(buffer, value, n) to:
643     // new_size = size(buffer) + n
644     // if (new_size > capacity(buffer))
645     //    while new_size > new_capacity
646     //      new_capacity = new_capacity*2
647     //    new_buffer = realloc(buffer, new_capacity)
648     // buffer = new_buffer
649     // subBuffer = subviewof(buffer)
650     // linalg.fill subBuffer value
651     //
652     // size(buffer) += n
653     //
654     // The capacity check is skipped when the attribute inbounds is presented.
655     Location loc = op->getLoc();
656     Value c0 = constantIndex(rewriter, loc, 0);
657     Value buffer = op.getInBuffer();
658     Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
659     Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
660     Value bufferSizes = op.getBufferSizes();
661     Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
662     Value value = op.getValue();
663 
664     Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
665     Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
666     auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
667     bool nIsOne = (nValue && nValue.value() == 1);
668 
669     if (!op.getInbounds()) {
670       Value cond = rewriter.create<arith::CmpIOp>(
671           loc, arith::CmpIPredicate::ugt, newSize, capacity);
672 
673       Value c2 = constantIndex(rewriter, loc, 2);
674       auto bufferType =
675           MemRefType::get({ShapedType::kDynamicSize}, value.getType());
676       scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
677                                                   /*else=*/true);
678       // True branch.
679       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
680       if (nIsOne) {
681         capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
682       } else {
683         // Use a do-while loop to calculate the new capacity as follows:
684         //   do { new_capacity *= 2 } while (size > new_capacity)
685         scf::WhileOp whileOp =
686             rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
687 
688         // The before-region of the WhileOp.
689         Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
690                                              {capacity.getType()}, {loc});
691         rewriter.setInsertionPointToEnd(before);
692 
693         capacity =
694             rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
695         cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
696                                               newSize, capacity);
697         rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
698         // The after-region of the WhileOp.
699         Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
700                                             {capacity.getType()}, {loc});
701         rewriter.setInsertionPointToEnd(after);
702         rewriter.create<scf::YieldOp>(loc, after->getArguments());
703 
704         rewriter.setInsertionPointAfter(whileOp);
705         capacity = whileOp.getResult(0);
706       }
707 
708       Value newBuffer =
709           rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
710       if (enableBufferInitialization) {
711         Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
712         Value fillValue = rewriter.create<arith::ConstantOp>(
713             loc, value.getType(), rewriter.getZeroAttr(value.getType()));
714         Value subBuffer = rewriter.create<memref::SubViewOp>(
715             loc, newBuffer, /*offset=*/ValueRange{newSize},
716             /*size=*/ValueRange{fillSize},
717             /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
718         rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
719       }
720       rewriter.create<scf::YieldOp>(loc, newBuffer);
721 
722       // False branch.
723       rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
724       rewriter.create<scf::YieldOp>(loc, buffer);
725 
726       // Prepare for adding the value to the end of the buffer.
727       rewriter.setInsertionPointAfter(ifOp);
728       buffer = ifOp.getResult(0);
729     }
730 
731     // Add the value to the end of the buffer.
732     if (nIsOne) {
733       rewriter.create<memref::StoreOp>(loc, value, buffer, size);
734     } else {
735       Value subBuffer = rewriter.create<memref::SubViewOp>(
736           loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
737           /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
738       rewriter.create<linalg::FillOp>(loc, value, subBuffer);
739     }
740 
741     // Update the buffer size.
742     rewriter.create<memref::StoreOp>(loc, newSize, bufferSizes, idx);
743     rewriter.replaceOp(op, buffer);
744     return success();
745   }
746 
747 private:
748   bool enableBufferInitialization;
749 };
750 
751 /// Sparse rewriting rule for the sort operator.
752 struct SortRewriter : public OpRewritePattern<SortOp> {
753 public:
754   using OpRewritePattern<SortOp>::OpRewritePattern;
755 
756   LogicalResult matchAndRewrite(SortOp op,
757                                 PatternRewriter &rewriter) const override {
758     Location loc = op.getLoc();
759     SmallVector<Value, 6> operands{constantIndex(rewriter, loc, 0), op.getN()};
760 
761     // Convert `values` to have dynamic shape and append them to `operands`.
762     auto addValues = [&](ValueRange values) {
763       for (Value v : values) {
764         auto mtp = v.getType().cast<MemRefType>();
765         if (!mtp.isDynamicDim(0)) {
766           auto newMtp =
767               MemRefType::get({ShapedType::kDynamicSize}, mtp.getElementType());
768           v = rewriter.create<memref::CastOp>(loc, newMtp, v);
769         }
770         operands.push_back(v);
771       }
772     };
773     ValueRange xs = op.getXs();
774     addValues(xs);
775     addValues(op.getYs());
776     auto insertPoint = op->getParentOfType<func::FuncOp>();
777     SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
778                                             : kSortNonstableFuncNamePrefix);
779     FuncGeneratorType funcGenerator =
780         op.getStable() ? createSortStableFunc : createSortNonstableFunc;
781     FlatSymbolRefAttr func =
782         getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
783                                  xs.size(), operands, funcGenerator);
784     rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
785     return success();
786   }
787 };
788 
789 } // namespace
790 
791 //===---------------------------------------------------------------------===//
792 // Methods that add patterns described in this file to a pattern list.
793 //===---------------------------------------------------------------------===//
794 
795 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
796                                          bool enableBufferInitialization) {
797   patterns.add<PushBackRewriter>(patterns.getContext(),
798                                  enableBufferInitialization);
799   patterns.add<SortRewriter>(patterns.getContext());
800 }
801