//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "SparseTensorDescriptor.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// ExecutionEngine/SparseTensorUtils helper functions.
//===----------------------------------------------------------------------===//

OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
  switch (width) {
  case 64:
    return OverheadType::kU64;
  case 32:
    return OverheadType::kU32;
  case 16:
    return OverheadType::kU16;
  case 8:
    return OverheadType::kU8;
  case 0:
    return OverheadType::kIndex;
  }
  llvm_unreachable("Unsupported overhead bitwidth");
}

OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
  if (tp.isIndex())
    return OverheadType::kIndex;
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return overheadTypeEncoding(intTp.getWidth());
  llvm_unreachable("Unknown overhead type");
}

Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return builder.getIndexType();
  case OverheadType::kU64:
    return builder.getIntegerType(64);
  case OverheadType::kU32:
    return builder.getIntegerType(32);
  case OverheadType::kU16:
    return builder.getIntegerType(16);
  case OverheadType::kU8:
    return builder.getIntegerType(8);
  }
  llvm_unreachable("Unknown OverheadType");
}

OverheadType
mlir::sparse_tensor::posTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getPosWidth());
}

OverheadType
mlir::sparse_tensor::crdTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getCrdWidth());
}

// TODO: we ought to add some `static_assert` tests to ensure that the
// `STEA::get{Pos,Crd}Type` methods agree with `getOverheadType(builder,
// {pos,crd}OverheadTypeEncoding(enc))`

// TODO: Adjust the naming convention for the constructors of
// `OverheadType` so we can use the `MLIR_SPARSETENSOR_FOREVERY_O` x-macro
// here instead of `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`; to further reduce
// the possibility of typo bugs or things getting out of sync.
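// For illustration, the mapping implemented by the switch below is:
//   kIndex -> "0", kU64 -> "64", kU32 -> "32", kU16 -> "16", kU8 -> "8".
// The suffix is typically appended to the name of a runtime support
// function when emitting calls (e.g., via createFuncCall below).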
StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return "0";
#define CASE(ONAME, O)                                                         \
  case OverheadType::kU##ONAME:                                                \
    return #ONAME;
    MLIR_SPARSETENSOR_FOREVERY_FIXED_O(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown OverheadType");
}

StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
  return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
}

PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
  if (elemTp.isF64())
    return PrimaryType::kF64;
  if (elemTp.isF32())
    return PrimaryType::kF32;
  if (elemTp.isF16())
    return PrimaryType::kF16;
  if (elemTp.isBF16())
    return PrimaryType::kBF16;
  if (elemTp.isInteger(64))
    return PrimaryType::kI64;
  if (elemTp.isInteger(32))
    return PrimaryType::kI32;
  if (elemTp.isInteger(16))
    return PrimaryType::kI16;
  if (elemTp.isInteger(8))
    return PrimaryType::kI8;
  if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
    auto complexEltTp = complexTp.getElementType();
    if (complexEltTp.isF64())
      return PrimaryType::kC64;
    if (complexEltTp.isF32())
      return PrimaryType::kC32;
  }
  llvm_unreachable("Unknown primary type");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
  switch (pt) {
#define CASE(VNAME, V)                                                         \
  case PrimaryType::k##VNAME:                                                  \
    return #VNAME;
    MLIR_SPARSETENSOR_FOREVERY_V(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown PrimaryType");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
  return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
}

//===----------------------------------------------------------------------===//
// Misc code generators.
//===----------------------------------------------------------------------===//

Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
                             Type dstTp) {
  const Type srcTp = value.getType();
  if (srcTp == dstTp)
    return value;

  // int <=> index
  if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
    return builder.create<arith::IndexCastOp>(loc, dstTp, value);

  const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
  return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}

Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
                                       Value elem, Type dstTp) {
  if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
    // Scalars can only be converted to 0-ranked tensors.
    assert(rtp.getRank() == 0);
    elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
  }
  return sparse_tensor::genCast(builder, loc, elem, dstTp);
}

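// A sketch of the IR genIndexLoad() typically produces (assuming a 32-bit
// overhead memref %mem and an index %i; SSA names are illustrative only):
//   %0 = memref.load %mem[%i] : memref<?xi32>
//   %1 = arith.extui %0 : i32 to i64
//   %2 = arith.index_cast %1 : i64 to index
// The explicit zero-extension to i64 keeps narrow unsigned overhead values
// from being sign-extended by the index cast.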
Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
                                  ValueRange s) {
  Value load = builder.create<memref::LoadOp>(loc, mem, s);
  if (!isa<IndexType>(load.getType())) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
    load =
        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
  }
  return load;
}

mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
  if (isa<FloatType>(tp))
    return builder.getFloatAttr(tp, 1.0);
  if (isa<IndexType>(tp))
    return builder.getIndexAttr(1);
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
  if (isa<RankedTensorType, VectorType>(tp)) {
    auto shapedTp = cast<ShapedType>(tp);
    if (auto one = getOneAttr(builder, shapedTp.getElementType()))
      return DenseElementsAttr::get(shapedTp, one);
  }
  llvm_unreachable("Unsupported attribute type");
}

Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
                                        Value v) {
  Type tp = v.getType();
  Value zero = constantZero(builder, loc, tp);
  if (isa<FloatType>(tp))
    return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                         zero);
  if (tp.isIntOrIndex())
    return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                         zero);
  if (dyn_cast<ComplexType>(tp))
    return builder.create<complex::NotEqualOp>(loc, v, zero);
  llvm_unreachable("Non-numeric type");
}

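// A small worked example of the collapse/expand logic below (values are
// illustrative): collapsing srcShape (2, 3, 4) with reassociation
// [[0, 1], [2]] yields dstShape (2*3, 4) = (6, 4); expanding srcShape (8)
// into the static dst shape <2x?> with reassociation [[0, 1]] yields
// dstShape (2, 8/2) = (2, 4).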
void mlir::sparse_tensor::genReshapeDstShape(
    OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
    ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
    ArrayRef<ReassociationIndices> reassociation) {
  // Collapse shape.
  if (reassociation.size() < srcShape.size()) {
    unsigned start = 0;
    for (const auto &map : llvm::enumerate(reassociation)) {
      auto dstDim = constantIndex(builder, loc, 1);
      for (unsigned i = start; i < start + map.value().size(); i++) {
        dstDim = builder.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
      }
      dstShape.push_back(dstDim);
      start = start + map.value().size();
    }
    assert(start == srcShape.size());
    return;
  }

  // Expand shape.
  assert(reassociation.size() == srcShape.size());
  unsigned start = 0;
  // Expand the i-th dimension in srcShape.
  for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
    const auto &map = reassociation[i];
    auto srcDim = srcShape[i];
    // Iterate through dimensions expanded from the i-th dimension.
    for (unsigned j = start; j < start + map.size(); j++) {
      // At most one dynamically sized dimension may appear among the
      // dimensions expanded from the i-th dimension in srcShape.
      // For example, if srcDim = 8, then the expanded shape could be <2x?x2>,
      // but not <2x?x?>.
      if (staticDstShape[j] == ShapedType::kDynamic) {
        // The expanded dimension has dynamic size. We compute the dimension
        // by dividing srcDim by the product of the static dimensions.
        Size product = 1;
        for (unsigned k = start; k < start + map.size(); k++) {
          if (staticDstShape[k] != ShapedType::kDynamic) {
            product *= staticDstShape[k];
          }
        }
        // Compute the dynamic dimension size.
        Value productVal = constantIndex(builder, loc, product);
        Value dynamicSize =
            builder.create<arith::DivUIOp>(loc, srcDim, productVal);
        dstShape.push_back(dynamicSize);
      } else {
        // The expanded dimension is statically known.
        dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
      }
    }
    start = start + map.size();
  }
  assert(start == staticDstShape.size());
}

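// A sketch of the coordinate remapping performed below (values are
// illustrative): collapsing coordinates (i, j) of a 3x4 source into a
// single dimension of size 12 produces i * 4 + j, while expanding a
// coordinate k of a 12-element source into a 3x4 destination produces
// (k / 4, k % 4).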
void mlir::sparse_tensor::reshapeCvs(
    OpBuilder &builder, Location loc,
    ArrayRef<ReassociationIndices> reassociation, // NOLINT
    ValueRange srcSizes, ValueRange srcCvs,       // NOLINT
    ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs) {
  const unsigned srcRank = srcSizes.size();
  const unsigned dstRank = dstSizes.size();
  assert(srcRank == srcCvs.size() && "Source rank mismatch");
  const bool isCollapse = srcRank > dstRank;
  const ValueRange sizes = isCollapse ? srcSizes : dstSizes;
  // Iterate over reassociation map.
  unsigned i = 0;
  unsigned start = 0;
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in dimension slice.
    Value linear = constantIndex(builder, loc, 1);
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::MulIOp>(loc, linear, sizes[j]);
    }
    // Start expansion.
    Value val;
    if (!isCollapse)
      val = srcCvs[i];
    // Iterate over dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::DivUIOp>(loc, linear, sizes[j]);
      if (isCollapse) {
        const Value mul = builder.create<arith::MulIOp>(loc, srcCvs[j], linear);
        val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
      } else {
        const Value old = val;
        val = builder.create<arith::DivUIOp>(loc, val, linear);
        assert(dstCvs.size() == j);
        dstCvs.push_back(val);
        val = builder.create<arith::RemUIOp>(loc, old, linear);
      }
    }
    // Finalize collapse.
    if (isCollapse) {
      assert(dstCvs.size() == i);
      dstCvs.push_back(val);
    }
    start += map.value().size();
    i++;
  }
  assert(dstCvs.size() == dstRank);
}

FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
                                               TypeRange resultType,
                                               ValueRange operands,
                                               EmitCInterface emitCInterface) {
  MLIRContext *context = module.getContext();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        module.getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}

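// Typical usage of createFuncCall(), taken from genReader() further down in
// this file; the callee is declared on demand by getFunc() if it is not
// already present in the module:
//   Value reader =
//       createFuncCall(builder, loc, "createCheckedSparseTensorReader",
//                      opaqueTp, {tensor, dimShapesBuffer, valTp},
//                      EmitCInterface::On)
//           .getResult(0);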
func::CallOp mlir::sparse_tensor::createFuncCall(
    OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
    ValueRange operands, EmitCInterface emitCInterface) {
  auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  FlatSymbolRefAttr fn =
      getFunc(module, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(loc, resultType, fn, operands);
}

Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) {
  return LLVM::LLVMPointerType::get(ctx);
}

Type mlir::sparse_tensor::getOpaquePointerType(Builder &builder) {
  return getOpaquePointerType(builder.getContext());
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc,
                                     unsigned sz, Type tp, bool staticShape) {
  if (staticShape) {
    auto memTp = MemRefType::get({sz}, tp);
    return builder.create<memref::AllocaOp>(loc, memTp);
  }
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz,
                                     Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc,
                                           Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
                                        ValueRange values) {
  const unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

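// A sketch of the IR emitted by allocDenseTensor() below for a
// tensor<?x32xf64> (SSA names are illustrative only):
//   %mem = memref.alloc(%d0) : memref<?x32xf64>
//   %zero = arith.constant 0.000000e+00 : f64
//   linalg.fill ins(%zero : f64) outs(%mem : memref<?x32xf64>)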
Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
                                            RankedTensorType tensorTp,
                                            ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamic)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc,
                                             Value buffer) {
  builder.create<memref::DeallocOp>(loc, buffer);
}

void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
                                       SmallVectorImpl<Value> &sizes,
                                       Location loc, Value src) {
  const Dimension dimRank = getSparseTensorType(src).getDimRank();
  for (Dimension d = 0; d < dimRank; d++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d));
}

Operation *mlir::sparse_tensor::getTop(Operation *op) {
  for (; isa<scf::ForOp>(op->getParentOp()) ||
         isa<scf::WhileOp>(op->getParentOp()) ||
         isa<scf::ParallelOp>(op->getParentOp()) ||
         isa<scf::IfOp>(op->getParentOp());
       op = op->getParentOp())
    ;
  return op;
}

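// For reference, a SparseElementsAttr as consumed below looks like
// (illustrative values):
//   sparse<[[0, 0], [1, 2]], [1.0, 2.0]> : tensor<3x4xf64>
// i.e., a list of coordinates plus a matching list of values.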
void sparse_tensor::foreachInSparseConstant(
    OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
    function_ref<void(ArrayRef<Value>, Value)> callback) {
  if (!order)
    order = builder.getMultiDimIdentityMap(attr.getType().getRank());

  auto stt = SparseTensorType(getRankedTensorType(attr));
  const Dimension dimRank = stt.getDimRank();
  const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
  const auto values = attr.getValues().getValues<Attribute>();

  // This is like the `Element<V>` class in the runtime library, but for
  // MLIR attributes.  In the future we may want to move this out into
  // a proper class definition to help improve code legibility (e.g.,
  // `first` -> `coords`, `second` -> `value`) as well as being able
  // to factor out analogues of `ElementLT<V>` for the sort below, etc.
  using ElementAttr = std::pair<SmallVector<IntegerAttr>, Attribute>;

  // Construct the COO from the SparseElementsAttr.
  SmallVector<ElementAttr> elems;
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    elems.emplace_back();
    elems.back().second = values[i];
    auto &coords = elems.back().first;
    coords.reserve(dimRank);
    for (Dimension d = 0; d < dimRank; d++)
      coords.push_back(coordinates[i * dimRank + d]);
  }

  // Sort the elements by their level coordinates (with respect to `order`).
  std::sort(elems.begin(), elems.end(),
            [order](const ElementAttr &lhs, const ElementAttr &rhs) {
              if (std::addressof(lhs) == std::addressof(rhs))
                return false;

              auto lhsCoords = llvm::map_to_vector(
                  lhs.first, [](IntegerAttr i) { return i.getInt(); });
              auto rhsCoords = llvm::map_to_vector(
                  rhs.first, [](IntegerAttr i) { return i.getInt(); });

              SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
              SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
              // Sort the element based on the lvl coordinates.
              for (Level l = 0; l < order.getNumResults(); l++) {
                if (lhsLvlCrds[l] == rhsLvlCrds[l])
                  continue;
                return lhsLvlCrds[l] < rhsLvlCrds[l];
              }
              llvm_unreachable("no equal coordinate in sparse element attr");
            });

  SmallVector<Value> cvs;
  cvs.reserve(dimRank);
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    // Remap coordinates.
    cvs.clear();
    for (Dimension d = 0; d < dimRank; d++) {
      auto crd = elems[i].first[d].getInt();
      cvs.push_back(builder.create<arith::ConstantIndexOp>(loc, crd));
    }
    // Remap value.
    Value val;
    if (isa<ComplexType>(attr.getElementType())) {
      auto valAttr = cast<ArrayAttr>(elems[i].second);
      val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
                                                valAttr);
    } else {
      auto valAttr = cast<TypedAttr>(elems[i].second);
      val = builder.create<arith::ConstantOp>(loc, valAttr);
    }
    assert(val);
    callback(cvs, val);
  }
}

SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
                                          size_t size, Value mem,
                                          size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
  assert(offsetIdx == 0 || offsetIdx < size);
#endif // NDEBUG
  SmallVector<Value> vs;
  vs.reserve(size);
  for (unsigned i = 0; i < size; i++) {
    Value v = builder.create<memref::LoadOp>(loc, mem,
                                             constantIndex(builder, loc, i));
    if (i == offsetIdx && offsetVal)
      v = builder.create<arith::AddIOp>(loc, v, offsetVal);
    vs.push_back(v);
  }
  return vs;
}

void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
                             ValueRange vs, size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const size_t vsize = vs.size();
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
  assert(offsetIdx == 0 || offsetIdx < vsize);
#endif // NDEBUG
  for (const auto &v : llvm::enumerate(vs)) {
    const Value w =
        (offsetIdx == v.index() && offsetVal)
            ? builder.create<arith::AddIOp>(loc, v.value(), offsetVal)
            : v.value();
    builder.create<memref::StoreOp>(loc, w, mem,
                                    constantIndex(builder, loc, v.index()));
  }
}

TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
  auto tTp = llvm::cast<TensorType>(tensor.getType());
  auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
  return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
      .getResult();
}

Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
  if (offset.has_value())
    return constantIndex(builder, loc, *offset);
  return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
}

Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
  if (stride.has_value())
    return constantIndex(builder, loc, *stride);
  return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
}

Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
                               SparseTensorType stt, Value tensor,
                               /*out*/ SmallVectorImpl<Value> &dimSizesValues,
                               /*out*/ Value &dimSizesBuffer) {
  // Construct the dimension **shapes** buffer. The buffer contains the static
  // size per dimension, or otherwise a zero for a dynamic size.
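  // For example, a tensor<?x32xf64> yields a shapes buffer holding (0, 32).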
  Dimension dimRank = stt.getDimRank();
  dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
  for (const Size sz : stt.getDimShape()) {
    const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
    dimSizesValues.push_back(constantIndex(builder, loc, s));
  }
  Value dimShapesBuffer = allocaBuffer(builder, loc, dimSizesValues);
  // Create the `CheckedSparseTensorReader`. This reader performs a
  // consistency check on the static sizes, while accepting any actual
  // size for each dynamic dimension.
  Type opaqueTp = getOpaquePointerType(builder);
  Type eltTp = stt.getElementType();
  Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
  Value reader =
      createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
                     {tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
          .getResult(0);
  // For static shapes, the shape buffer can be used right away. For dynamic
  // shapes, use the information from the reader to construct a buffer that
  // supplies the actual size for each dynamic dimension.
  dimSizesBuffer = dimShapesBuffer;
  if (stt.hasDynamicDimShape()) {
    Type indexTp = builder.getIndexType();
    auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
    dimSizesBuffer =
        createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
                       reader, EmitCInterface::On)
            .getResult(0);
    // Also convert the dim shapes values into dim sizes values, just in case
    // subsequent clients need the values (DCE will remove unused).
    for (Dimension d = 0; d < dimRank; d++) {
      if (stt.isDynamicDim(d))
        dimSizesValues[d] = builder.create<memref::LoadOp>(
            loc, dimSizesBuffer, constantIndex(builder, loc, d));
    }
  }
  return reader;
}

Value sparse_tensor::genMapBuffers(
    OpBuilder &builder, Location loc, SparseTensorType stt,
    ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
    /*out*/ SmallVectorImpl<Value> &lvlSizesValues,
    /*out*/ Value &dim2lvlBuffer,
    /*out*/ Value &lvl2dimBuffer) {
  const Dimension dimRank = stt.getDimRank();
  const Level lvlRank = stt.getLvlRank();
  lvlSizesValues.clear();
  lvlSizesValues.reserve(lvlRank);
  // For an identity mapping, the dim2lvl and lvl2dim mappings are
  // identical, as are dimSizes and lvlSizes, so buffers are reused
  // as much as possible.
  if (stt.isIdentity()) {
    assert(dimRank == lvlRank);
    SmallVector<Value> iotaValues;
    iotaValues.reserve(lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      iotaValues.push_back(constantIndex(builder, loc, l));
      lvlSizesValues.push_back(dimSizesValues[l]);
    }
    dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
    return dimSizesBuffer; // now lvlSizesBuffer
  }
  // Otherwise, some code needs to be generated to set up the buffers.
  // This code deals with permutations as well as non-permutations that
  // arise from rank-changing blocking.
  const auto dimToLvl = stt.getDimToLvl();
  const auto lvlToDim = stt.getLvlToDim();
  SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
  SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
  // Generate dim2lvl.
  assert(lvlRank == dimToLvl.getNumResults());
  for (Level l = 0; l < lvlRank; l++) {
    AffineExpr exp = dimToLvl.getResult(l);
    // We expect:
    //    (1) l = d
    //    (2) l = d / c
    //    (3) l = d % c
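    // For example, a 2x2 block-sparse (BSR-like) encoding typically uses
    //   dimToLvl = (d0, d1) -> (d0 floordiv 2, d1 floordiv 2,
    //                           d0 mod 2, d1 mod 2),
    // which exercises cases (2) and (3) below.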
    Dimension d = 0;
    uint64_t cf = 0, cm = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      d = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::FloorDiv: {
      auto floor = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(floor.getLHS()).getPosition();
      cf = cast<AffineConstantExpr>(floor.getRHS()).getValue();
      break;
    }
    case AffineExprKind::Mod: {
      auto mod = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(mod.getLHS()).getPosition();
      cm = cast<AffineConstantExpr>(mod.getRHS()).getValue();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
    }
    dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
    // Compute the level sizes.
    //    (1) l = d        : size(d)
    //    (2) l = d / c    : size(d) / c
    //    (3) l = d % c    : c
    Value lvlSz;
    if (cm == 0) {
      lvlSz = dimSizesValues[d];
      if (cf != 0)
        lvlSz = builder.create<arith::DivUIOp>(loc, lvlSz,
                                               constantIndex(builder, loc, cf));
    } else {
      lvlSz = constantIndex(builder, loc, cm);
    }
    lvlSizesValues.push_back(lvlSz);
  }
  // Generate lvl2dim.
  assert(dimRank == lvlToDim.getNumResults());
  for (Dimension d = 0; d < dimRank; d++) {
    AffineExpr exp = lvlToDim.getResult(d);
    // We expect:
    //    (1) d = l
    //    (2) d = l' * c + l
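    // For example, the inverse of the 2x2 block-sparse map above would be
    //   lvlToDim = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 2 + l3),
    // which exercises case (2) below.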
    Level l = 0, ll = 0;
    uint64_t c = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      l = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::Add: {
      // The mul term is always on the lhs of the add; the level var on the rhs.
      auto add = cast<AffineBinaryOpExpr>(exp);
      assert(add.getLHS().getKind() == AffineExprKind::Mul);
      auto mul = cast<AffineBinaryOpExpr>(add.getLHS());
      ll = cast<AffineDimExpr>(mul.getLHS()).getPosition();
      c = cast<AffineConstantExpr>(mul.getRHS()).getValue();
      l = cast<AffineDimExpr>(add.getRHS()).getPosition();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
    }
    lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
  }
  // Return buffers.
  dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
  lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
  return allocaBuffer(builder, loc, lvlSizesValues); // lvlSizesBuffer
}
734