xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp (revision 204234a69c068032a1adac31f00b51f3b9efa778)
1365777ecSAart Bik //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
2365777ecSAart Bik //
3365777ecSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4365777ecSAart Bik // See https://llvm.org/LICENSE.txt for license information.
5365777ecSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6365777ecSAart Bik //
7365777ecSAart Bik //===----------------------------------------------------------------------===//
8365777ecSAart Bik 
#include "CodegenUtils.h"
#include "SparseTensorDescriptor.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include <algorithm>
#include <optional>
22365777ecSAart Bik 
23365777ecSAart Bik using namespace mlir;
24365777ecSAart Bik using namespace mlir::sparse_tensor;
25365777ecSAart Bik 
26365777ecSAart Bik //===----------------------------------------------------------------------===//
27365777ecSAart Bik // ExecutionEngine/SparseTensorUtils helper functions.
28365777ecSAart Bik //===----------------------------------------------------------------------===//
29365777ecSAart Bik 
30365777ecSAart Bik OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
31365777ecSAart Bik   switch (width) {
32365777ecSAart Bik   case 64:
33365777ecSAart Bik     return OverheadType::kU64;
34365777ecSAart Bik   case 32:
35365777ecSAart Bik     return OverheadType::kU32;
36365777ecSAart Bik   case 16:
37365777ecSAart Bik     return OverheadType::kU16;
38365777ecSAart Bik   case 8:
39365777ecSAart Bik     return OverheadType::kU8;
40365777ecSAart Bik   case 0:
41365777ecSAart Bik     return OverheadType::kIndex;
42365777ecSAart Bik   }
43365777ecSAart Bik   llvm_unreachable("Unsupported overhead bitwidth");
44365777ecSAart Bik }
45365777ecSAart Bik 
46365777ecSAart Bik OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
47365777ecSAart Bik   if (tp.isIndex())
48365777ecSAart Bik     return OverheadType::kIndex;
49365777ecSAart Bik   if (auto intTp = dyn_cast<IntegerType>(tp))
50365777ecSAart Bik     return overheadTypeEncoding(intTp.getWidth());
51365777ecSAart Bik   llvm_unreachable("Unknown overhead type");
52365777ecSAart Bik }
53365777ecSAart Bik 
54365777ecSAart Bik Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
55365777ecSAart Bik   switch (ot) {
56365777ecSAart Bik   case OverheadType::kIndex:
57365777ecSAart Bik     return builder.getIndexType();
58365777ecSAart Bik   case OverheadType::kU64:
59365777ecSAart Bik     return builder.getIntegerType(64);
60365777ecSAart Bik   case OverheadType::kU32:
61365777ecSAart Bik     return builder.getIntegerType(32);
62365777ecSAart Bik   case OverheadType::kU16:
63365777ecSAart Bik     return builder.getIntegerType(16);
64365777ecSAart Bik   case OverheadType::kU8:
65365777ecSAart Bik     return builder.getIntegerType(8);
66365777ecSAart Bik   }
67365777ecSAart Bik   llvm_unreachable("Unknown OverheadType");
68365777ecSAart Bik }
69365777ecSAart Bik 
70365777ecSAart Bik OverheadType
71365777ecSAart Bik mlir::sparse_tensor::posTypeEncoding(SparseTensorEncodingAttr enc) {
72365777ecSAart Bik   return overheadTypeEncoding(enc.getPosWidth());
73365777ecSAart Bik }
74365777ecSAart Bik 
75365777ecSAart Bik OverheadType
76365777ecSAart Bik mlir::sparse_tensor::crdTypeEncoding(SparseTensorEncodingAttr enc) {
77365777ecSAart Bik   return overheadTypeEncoding(enc.getCrdWidth());
78365777ecSAart Bik }
79365777ecSAart Bik 
// TODO: we ought to add some `static_assert` tests to ensure that the
// `STEA::get{Pos,Crd}Type` methods agree with `getOverheadType(builder,
// {pos,crd}TypeEncoding(enc))`.
83365777ecSAart Bik 
84365777ecSAart Bik // TODO: Adjust the naming convention for the constructors of
85365777ecSAart Bik // `OverheadType` so we can use the `MLIR_SPARSETENSOR_FOREVERY_O` x-macro
86365777ecSAart Bik // here instead of `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`; to further reduce
87365777ecSAart Bik // the possibility of typo bugs or things getting out of sync.
88365777ecSAart Bik StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
89365777ecSAart Bik   switch (ot) {
90365777ecSAart Bik   case OverheadType::kIndex:
91365777ecSAart Bik     return "0";
92365777ecSAart Bik #define CASE(ONAME, O)                                                         \
93365777ecSAart Bik   case OverheadType::kU##ONAME:                                                \
94365777ecSAart Bik     return #ONAME;
95365777ecSAart Bik     MLIR_SPARSETENSOR_FOREVERY_FIXED_O(CASE)
96365777ecSAart Bik #undef CASE
97365777ecSAart Bik   }
98365777ecSAart Bik   llvm_unreachable("Unknown OverheadType");
99365777ecSAart Bik }
100365777ecSAart Bik 
101365777ecSAart Bik StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
102365777ecSAart Bik   return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
103365777ecSAart Bik }
104365777ecSAart Bik 
105365777ecSAart Bik PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
106365777ecSAart Bik   if (elemTp.isF64())
107365777ecSAart Bik     return PrimaryType::kF64;
108365777ecSAart Bik   if (elemTp.isF32())
109365777ecSAart Bik     return PrimaryType::kF32;
110365777ecSAart Bik   if (elemTp.isF16())
111365777ecSAart Bik     return PrimaryType::kF16;
112365777ecSAart Bik   if (elemTp.isBF16())
113365777ecSAart Bik     return PrimaryType::kBF16;
114365777ecSAart Bik   if (elemTp.isInteger(64))
115365777ecSAart Bik     return PrimaryType::kI64;
116365777ecSAart Bik   if (elemTp.isInteger(32))
117365777ecSAart Bik     return PrimaryType::kI32;
118365777ecSAart Bik   if (elemTp.isInteger(16))
119365777ecSAart Bik     return PrimaryType::kI16;
120365777ecSAart Bik   if (elemTp.isInteger(8))
121365777ecSAart Bik     return PrimaryType::kI8;
122365777ecSAart Bik   if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
123365777ecSAart Bik     auto complexEltTp = complexTp.getElementType();
124365777ecSAart Bik     if (complexEltTp.isF64())
125365777ecSAart Bik       return PrimaryType::kC64;
126365777ecSAart Bik     if (complexEltTp.isF32())
127365777ecSAart Bik       return PrimaryType::kC32;
128365777ecSAart Bik   }
129365777ecSAart Bik   llvm_unreachable("Unknown primary type");
130365777ecSAart Bik }
131365777ecSAart Bik 
132365777ecSAart Bik StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
133365777ecSAart Bik   switch (pt) {
134365777ecSAart Bik #define CASE(VNAME, V)                                                         \
135365777ecSAart Bik   case PrimaryType::k##VNAME:                                                  \
136365777ecSAart Bik     return #VNAME;
137365777ecSAart Bik     MLIR_SPARSETENSOR_FOREVERY_V(CASE)
138365777ecSAart Bik #undef CASE
139365777ecSAart Bik   }
140365777ecSAart Bik   llvm_unreachable("Unknown PrimaryType");
141365777ecSAart Bik }
142365777ecSAart Bik 
143365777ecSAart Bik StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
144365777ecSAart Bik   return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
145365777ecSAart Bik }
146365777ecSAart Bik 
147365777ecSAart Bik //===----------------------------------------------------------------------===//
148365777ecSAart Bik // Misc code generators.
149365777ecSAart Bik //===----------------------------------------------------------------------===//
150365777ecSAart Bik 
151365777ecSAart Bik Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
152365777ecSAart Bik                              Type dstTp) {
153365777ecSAart Bik   const Type srcTp = value.getType();
154365777ecSAart Bik   if (srcTp == dstTp)
155365777ecSAart Bik     return value;
156365777ecSAart Bik 
157365777ecSAart Bik   // int <=> index
158365777ecSAart Bik   if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
159365777ecSAart Bik     return builder.create<arith::IndexCastOp>(loc, dstTp, value);
160365777ecSAart Bik 
161365777ecSAart Bik   const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
162365777ecSAart Bik   const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
163365777ecSAart Bik   return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
164365777ecSAart Bik }
165365777ecSAart Bik 
166365777ecSAart Bik Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
167365777ecSAart Bik                                        Value elem, Type dstTp) {
168*a5757c5bSChristian Sigg   if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
169365777ecSAart Bik     // Scalars can only be converted to 0-ranked tensors.
170365777ecSAart Bik     assert(rtp.getRank() == 0);
171365777ecSAart Bik     elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
172365777ecSAart Bik     return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
173365777ecSAart Bik   }
174365777ecSAart Bik   return sparse_tensor::genCast(builder, loc, elem, dstTp);
175365777ecSAart Bik }
176365777ecSAart Bik 
177365777ecSAart Bik Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
17852b69aa3SPeiming Liu                                   ValueRange s) {
179365777ecSAart Bik   Value load = builder.create<memref::LoadOp>(loc, mem, s);
180365777ecSAart Bik   if (!isa<IndexType>(load.getType())) {
181365777ecSAart Bik     if (load.getType().getIntOrFloatBitWidth() < 64)
182365777ecSAart Bik       load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
183365777ecSAart Bik     load =
184365777ecSAart Bik         builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
185365777ecSAart Bik   }
186365777ecSAart Bik   return load;
187365777ecSAart Bik }
188365777ecSAart Bik 
189365777ecSAart Bik mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
190365777ecSAart Bik   if (isa<FloatType>(tp))
191365777ecSAart Bik     return builder.getFloatAttr(tp, 1.0);
192365777ecSAart Bik   if (isa<IndexType>(tp))
193365777ecSAart Bik     return builder.getIndexAttr(1);
194365777ecSAart Bik   if (auto intTp = dyn_cast<IntegerType>(tp))
195365777ecSAart Bik     return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
196365777ecSAart Bik   if (isa<RankedTensorType, VectorType>(tp)) {
197365777ecSAart Bik     auto shapedTp = cast<ShapedType>(tp);
198365777ecSAart Bik     if (auto one = getOneAttr(builder, shapedTp.getElementType()))
199365777ecSAart Bik       return DenseElementsAttr::get(shapedTp, one);
200365777ecSAart Bik   }
201365777ecSAart Bik   llvm_unreachable("Unsupported attribute type");
202365777ecSAart Bik }
203365777ecSAart Bik 
204365777ecSAart Bik Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
205365777ecSAart Bik                                         Value v) {
206365777ecSAart Bik   Type tp = v.getType();
207365777ecSAart Bik   Value zero = constantZero(builder, loc, tp);
208365777ecSAart Bik   if (isa<FloatType>(tp))
209365777ecSAart Bik     return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
210365777ecSAart Bik                                          zero);
211365777ecSAart Bik   if (tp.isIntOrIndex())
212365777ecSAart Bik     return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
213365777ecSAart Bik                                          zero);
214365777ecSAart Bik   if (dyn_cast<ComplexType>(tp))
215365777ecSAart Bik     return builder.create<complex::NotEqualOp>(loc, v, zero);
216365777ecSAart Bik   llvm_unreachable("Non-numeric type");
217365777ecSAart Bik }
218365777ecSAart Bik 
219365777ecSAart Bik void mlir::sparse_tensor::genReshapeDstShape(
220365777ecSAart Bik     OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
221365777ecSAart Bik     ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
222365777ecSAart Bik     ArrayRef<ReassociationIndices> reassociation) {
223365777ecSAart Bik   // Collapse shape.
224365777ecSAart Bik   if (reassociation.size() < srcShape.size()) {
225365777ecSAart Bik     unsigned start = 0;
226365777ecSAart Bik     for (const auto &map : llvm::enumerate(reassociation)) {
227365777ecSAart Bik       auto dstDim = constantIndex(builder, loc, 1);
228365777ecSAart Bik       for (unsigned i = start; i < start + map.value().size(); i++) {
229365777ecSAart Bik         dstDim = builder.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
230365777ecSAart Bik       }
231365777ecSAart Bik       dstShape.push_back(dstDim);
232365777ecSAart Bik       start = start + map.value().size();
233365777ecSAart Bik     }
234365777ecSAart Bik     assert(start == srcShape.size());
235365777ecSAart Bik     return;
236365777ecSAart Bik   }
237365777ecSAart Bik 
238365777ecSAart Bik   // Expand shape.
239365777ecSAart Bik   assert(reassociation.size() == srcShape.size());
240365777ecSAart Bik   unsigned start = 0;
241365777ecSAart Bik   // Expand the i-th dimension in srcShape.
242365777ecSAart Bik   for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
243365777ecSAart Bik     const auto &map = reassociation[i];
244365777ecSAart Bik     auto srcDim = srcShape[i];
245365777ecSAart Bik     // Iterate through dimensions expanded from the i-th dimension.
246365777ecSAart Bik     for (unsigned j = start; j < start + map.size(); j++) {
247365777ecSAart Bik       // There can be only one dynamic sized dimension among dimensions
248365777ecSAart Bik       // expanded from the i-th dimension in srcShape.
249365777ecSAart Bik       // For example, if srcDim = 8, then the expanded shape could be <2x?x2>,
250365777ecSAart Bik       // but not <2x?x?>.
251365777ecSAart Bik       if (staticDstShape[j] == ShapedType::kDynamic) {
252365777ecSAart Bik         // The expanded dimension has dynamic size. We compute the dimension
253365777ecSAart Bik         // by dividing srcDim by the product of the static dimensions.
254365777ecSAart Bik         Size product = 1;
255365777ecSAart Bik         for (unsigned k = start; k < start + map.size(); k++) {
256365777ecSAart Bik           if (staticDstShape[k] != ShapedType::kDynamic) {
257365777ecSAart Bik             product *= staticDstShape[k];
258365777ecSAart Bik           }
259365777ecSAart Bik         }
260365777ecSAart Bik         // Compute the dynamic dimension size.
261365777ecSAart Bik         Value productVal = constantIndex(builder, loc, product);
262365777ecSAart Bik         Value dynamicSize =
263365777ecSAart Bik             builder.create<arith::DivUIOp>(loc, srcDim, productVal);
264365777ecSAart Bik         dstShape.push_back(dynamicSize);
265365777ecSAart Bik       } else {
266365777ecSAart Bik         // The expanded dimension is statically known.
267365777ecSAart Bik         dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
268365777ecSAart Bik       }
269365777ecSAart Bik     }
270365777ecSAart Bik     start = start + map.size();
271365777ecSAart Bik   }
272365777ecSAart Bik   assert(start == staticDstShape.size());
273365777ecSAart Bik }
274365777ecSAart Bik 
275365777ecSAart Bik void mlir::sparse_tensor::reshapeCvs(
276365777ecSAart Bik     OpBuilder &builder, Location loc,
277365777ecSAart Bik     ArrayRef<ReassociationIndices> reassociation, // NOLINT
278365777ecSAart Bik     ValueRange srcSizes, ValueRange srcCvs,       // NOLINT
279365777ecSAart Bik     ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs) {
280365777ecSAart Bik   const unsigned srcRank = srcSizes.size();
281365777ecSAart Bik   const unsigned dstRank = dstSizes.size();
282365777ecSAart Bik   assert(srcRank == srcCvs.size() && "Source rank mismatch");
283365777ecSAart Bik   const bool isCollapse = srcRank > dstRank;
284365777ecSAart Bik   const ValueRange sizes = isCollapse ? srcSizes : dstSizes;
285365777ecSAart Bik   // Iterate over reassociation map.
286365777ecSAart Bik   unsigned i = 0;
287365777ecSAart Bik   unsigned start = 0;
288365777ecSAart Bik   for (const auto &map : llvm::enumerate(reassociation)) {
289365777ecSAart Bik     // Prepare strides information in dimension slice.
290365777ecSAart Bik     Value linear = constantIndex(builder, loc, 1);
291365777ecSAart Bik     for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
292365777ecSAart Bik       linear = builder.create<arith::MulIOp>(loc, linear, sizes[j]);
293365777ecSAart Bik     }
294365777ecSAart Bik     // Start expansion.
295365777ecSAart Bik     Value val;
296365777ecSAart Bik     if (!isCollapse)
297365777ecSAart Bik       val = srcCvs[i];
298365777ecSAart Bik     // Iterate over dimension slice.
299365777ecSAart Bik     for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
300365777ecSAart Bik       linear = builder.create<arith::DivUIOp>(loc, linear, sizes[j]);
301365777ecSAart Bik       if (isCollapse) {
302365777ecSAart Bik         const Value mul = builder.create<arith::MulIOp>(loc, srcCvs[j], linear);
303365777ecSAart Bik         val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
304365777ecSAart Bik       } else {
305365777ecSAart Bik         const Value old = val;
306365777ecSAart Bik         val = builder.create<arith::DivUIOp>(loc, val, linear);
307365777ecSAart Bik         assert(dstCvs.size() == j);
308365777ecSAart Bik         dstCvs.push_back(val);
309365777ecSAart Bik         val = builder.create<arith::RemUIOp>(loc, old, linear);
310365777ecSAart Bik       }
311365777ecSAart Bik     }
312365777ecSAart Bik     // Finalize collapse.
313365777ecSAart Bik     if (isCollapse) {
314365777ecSAart Bik       assert(dstCvs.size() == i);
315365777ecSAart Bik       dstCvs.push_back(val);
316365777ecSAart Bik     }
317365777ecSAart Bik     start += map.value().size();
318365777ecSAart Bik     i++;
319365777ecSAart Bik   }
320365777ecSAart Bik   assert(dstCvs.size() == dstRank);
321365777ecSAart Bik }
322365777ecSAart Bik 
323365777ecSAart Bik FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
324365777ecSAart Bik                                                TypeRange resultType,
325365777ecSAart Bik                                                ValueRange operands,
326365777ecSAart Bik                                                EmitCInterface emitCInterface) {
327365777ecSAart Bik   MLIRContext *context = module.getContext();
328365777ecSAart Bik   auto result = SymbolRefAttr::get(context, name);
329365777ecSAart Bik   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
330365777ecSAart Bik   if (!func) {
331365777ecSAart Bik     OpBuilder moduleBuilder(module.getBodyRegion());
332365777ecSAart Bik     func = moduleBuilder.create<func::FuncOp>(
333365777ecSAart Bik         module.getLoc(), name,
334365777ecSAart Bik         FunctionType::get(context, operands.getTypes(), resultType));
335365777ecSAart Bik     func.setPrivate();
336365777ecSAart Bik     if (static_cast<bool>(emitCInterface))
337365777ecSAart Bik       func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
338365777ecSAart Bik                     UnitAttr::get(context));
339365777ecSAart Bik   }
340365777ecSAart Bik   return result;
341365777ecSAart Bik }
342365777ecSAart Bik 
343365777ecSAart Bik func::CallOp mlir::sparse_tensor::createFuncCall(
344365777ecSAart Bik     OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
345365777ecSAart Bik     ValueRange operands, EmitCInterface emitCInterface) {
346365777ecSAart Bik   auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
347365777ecSAart Bik   FlatSymbolRefAttr fn =
348365777ecSAart Bik       getFunc(module, name, resultType, operands, emitCInterface);
349365777ecSAart Bik   return builder.create<func::CallOp>(loc, resultType, fn, operands);
350365777ecSAart Bik }
351365777ecSAart Bik 
352365777ecSAart Bik Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) {
353365777ecSAart Bik   return LLVM::LLVMPointerType::get(ctx);
354365777ecSAart Bik }
355365777ecSAart Bik 
356365777ecSAart Bik Type mlir::sparse_tensor::getOpaquePointerType(Builder &builder) {
357365777ecSAart Bik   return getOpaquePointerType(builder.getContext());
358365777ecSAart Bik }
359365777ecSAart Bik 
360365777ecSAart Bik Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc,
361365777ecSAart Bik                                      unsigned sz, Type tp, bool staticShape) {
362365777ecSAart Bik   if (staticShape) {
363365777ecSAart Bik     auto memTp = MemRefType::get({sz}, tp);
364365777ecSAart Bik     return builder.create<memref::AllocaOp>(loc, memTp);
365365777ecSAart Bik   }
366365777ecSAart Bik   return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
367365777ecSAart Bik }
368365777ecSAart Bik 
369365777ecSAart Bik Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz,
370365777ecSAart Bik                                      Type tp) {
371365777ecSAart Bik   auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
372365777ecSAart Bik   return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
373365777ecSAart Bik }
374365777ecSAart Bik 
375365777ecSAart Bik Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc,
376365777ecSAart Bik                                            Type tp) {
377365777ecSAart Bik   return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
378365777ecSAart Bik }
379365777ecSAart Bik 
380365777ecSAart Bik Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
381365777ecSAart Bik                                         ValueRange values) {
382365777ecSAart Bik   const unsigned sz = values.size();
383365777ecSAart Bik   assert(sz >= 1);
384365777ecSAart Bik   Value buffer = genAlloca(builder, loc, sz, values[0].getType());
385365777ecSAart Bik   for (unsigned i = 0; i < sz; i++) {
386365777ecSAart Bik     Value idx = constantIndex(builder, loc, i);
387365777ecSAart Bik     builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
388365777ecSAart Bik   }
389365777ecSAart Bik   return buffer;
390365777ecSAart Bik }
391365777ecSAart Bik 
392365777ecSAart Bik Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
393365777ecSAart Bik                                             RankedTensorType tensorTp,
394365777ecSAart Bik                                             ValueRange sizes) {
395365777ecSAart Bik   Type elemTp = tensorTp.getElementType();
396365777ecSAart Bik   auto shape = tensorTp.getShape();
397365777ecSAart Bik   auto memTp = MemRefType::get(shape, elemTp);
398365777ecSAart Bik   SmallVector<Value> dynamicSizes;
399365777ecSAart Bik   for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
400365777ecSAart Bik     if (shape[i] == ShapedType::kDynamic)
401365777ecSAart Bik       dynamicSizes.push_back(sizes[i]);
402365777ecSAart Bik   }
403365777ecSAart Bik   Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
404365777ecSAart Bik   Value zero = constantZero(builder, loc, elemTp);
405365777ecSAart Bik   builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
406365777ecSAart Bik   return mem;
407365777ecSAart Bik }
408365777ecSAart Bik 
409365777ecSAart Bik void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc,
410365777ecSAart Bik                                              Value buffer) {
411365777ecSAart Bik   builder.create<memref::DeallocOp>(loc, buffer);
412365777ecSAart Bik }
413365777ecSAart Bik 
414365777ecSAart Bik void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
415365777ecSAart Bik                                        SmallVectorImpl<Value> &sizes,
416365777ecSAart Bik                                        Location loc, Value src) {
417365777ecSAart Bik   const Dimension dimRank = getSparseTensorType(src).getDimRank();
418365777ecSAart Bik   for (Dimension d = 0; d < dimRank; d++)
419365777ecSAart Bik     sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d));
420365777ecSAart Bik }
421365777ecSAart Bik 
422365777ecSAart Bik Operation *mlir::sparse_tensor::getTop(Operation *op) {
423365777ecSAart Bik   for (; isa<scf::ForOp>(op->getParentOp()) ||
424365777ecSAart Bik          isa<scf::WhileOp>(op->getParentOp()) ||
425365777ecSAart Bik          isa<scf::ParallelOp>(op->getParentOp()) ||
426365777ecSAart Bik          isa<scf::IfOp>(op->getParentOp());
427365777ecSAart Bik        op = op->getParentOp())
428365777ecSAart Bik     ;
429365777ecSAart Bik   return op;
430365777ecSAart Bik }
431365777ecSAart Bik 
432365777ecSAart Bik void sparse_tensor::foreachInSparseConstant(
433365777ecSAart Bik     OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
434365777ecSAart Bik     function_ref<void(ArrayRef<Value>, Value)> callback) {
435365777ecSAart Bik   if (!order)
436365777ecSAart Bik     order = builder.getMultiDimIdentityMap(attr.getType().getRank());
437365777ecSAart Bik 
438365777ecSAart Bik   auto stt = SparseTensorType(getRankedTensorType(attr));
439365777ecSAart Bik   const Dimension dimRank = stt.getDimRank();
440365777ecSAart Bik   const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
441365777ecSAart Bik   const auto values = attr.getValues().getValues<Attribute>();
442365777ecSAart Bik 
443365777ecSAart Bik   // This is like the `Element<V>` class in the runtime library, but for
444365777ecSAart Bik   // MLIR attributes.  In the future we may want to move this out into
445365777ecSAart Bik   // a proper class definition to help improve code legibility (e.g.,
446365777ecSAart Bik   // `first` -> `coords`, `second` -> `value`) as well as being able
447365777ecSAart Bik   // to factor out analogues of `ElementLT<V>` for the sort below, etc.
448365777ecSAart Bik   using ElementAttr = std::pair<SmallVector<IntegerAttr>, Attribute>;
449365777ecSAart Bik 
450365777ecSAart Bik   // Construct the COO from the SparseElementsAttr.
451365777ecSAart Bik   SmallVector<ElementAttr> elems;
452365777ecSAart Bik   for (size_t i = 0, nse = values.size(); i < nse; i++) {
453365777ecSAart Bik     elems.emplace_back();
454365777ecSAart Bik     elems.back().second = values[i];
455365777ecSAart Bik     auto &coords = elems.back().first;
456365777ecSAart Bik     coords.reserve(dimRank);
457365777ecSAart Bik     for (Dimension d = 0; d < dimRank; d++)
458365777ecSAart Bik       coords.push_back(coordinates[i * dimRank + d]);
459365777ecSAart Bik   }
460365777ecSAart Bik 
461365777ecSAart Bik   // Sorts the sparse element attribute based on coordinates.
462365777ecSAart Bik   std::sort(elems.begin(), elems.end(),
463365777ecSAart Bik             [order](const ElementAttr &lhs, const ElementAttr &rhs) {
464365777ecSAart Bik               if (std::addressof(lhs) == std::addressof(rhs))
465365777ecSAart Bik                 return false;
466365777ecSAart Bik 
467365777ecSAart Bik               auto lhsCoords = llvm::map_to_vector(
468365777ecSAart Bik                   lhs.first, [](IntegerAttr i) { return i.getInt(); });
469365777ecSAart Bik               auto rhsCoords = llvm::map_to_vector(
470365777ecSAart Bik                   rhs.first, [](IntegerAttr i) { return i.getInt(); });
471365777ecSAart Bik 
472365777ecSAart Bik               SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
473365777ecSAart Bik               SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
474365777ecSAart Bik               // Sort the element based on the lvl coordinates.
475365777ecSAart Bik               for (Level l = 0; l < order.getNumResults(); l++) {
476365777ecSAart Bik                 if (lhsLvlCrds[l] == rhsLvlCrds[l])
477365777ecSAart Bik                   continue;
478365777ecSAart Bik                 return lhsLvlCrds[l] < rhsLvlCrds[l];
479365777ecSAart Bik               }
480365777ecSAart Bik               llvm_unreachable("no equal coordinate in sparse element attr");
481365777ecSAart Bik             });
482365777ecSAart Bik 
483365777ecSAart Bik   SmallVector<Value> cvs;
484365777ecSAart Bik   cvs.reserve(dimRank);
485365777ecSAart Bik   for (size_t i = 0, nse = values.size(); i < nse; i++) {
486365777ecSAart Bik     // Remap coordinates.
487365777ecSAart Bik     cvs.clear();
488365777ecSAart Bik     for (Dimension d = 0; d < dimRank; d++) {
489365777ecSAart Bik       auto crd = elems[i].first[d].getInt();
490365777ecSAart Bik       cvs.push_back(builder.create<arith::ConstantIndexOp>(loc, crd));
491365777ecSAart Bik     }
492365777ecSAart Bik     // Remap value.
493365777ecSAart Bik     Value val;
494365777ecSAart Bik     if (isa<ComplexType>(attr.getElementType())) {
495365777ecSAart Bik       auto valAttr = cast<ArrayAttr>(elems[i].second);
496365777ecSAart Bik       val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
497365777ecSAart Bik                                                 valAttr);
498365777ecSAart Bik     } else {
499365777ecSAart Bik       auto valAttr = cast<TypedAttr>(elems[i].second);
500365777ecSAart Bik       val = builder.create<arith::ConstantOp>(loc, valAttr);
501365777ecSAart Bik     }
502365777ecSAart Bik     assert(val);
503365777ecSAart Bik     callback(cvs, val);
504365777ecSAart Bik   }
505365777ecSAart Bik }
506365777ecSAart Bik 
507365777ecSAart Bik SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
508365777ecSAart Bik                                           size_t size, Value mem,
509365777ecSAart Bik                                           size_t offsetIdx, Value offsetVal) {
510365777ecSAart Bik #ifndef NDEBUG
511365777ecSAart Bik   const auto memTp = cast<MemRefType>(mem.getType());
512365777ecSAart Bik   assert(memTp.getRank() == 1);
513365777ecSAart Bik   const Size memSh = memTp.getDimSize(0);
514365777ecSAart Bik   assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
515365777ecSAart Bik   assert(offsetIdx == 0 || offsetIdx < size);
516365777ecSAart Bik #endif // NDEBUG
517365777ecSAart Bik   SmallVector<Value> vs;
518365777ecSAart Bik   vs.reserve(size);
519365777ecSAart Bik   for (unsigned i = 0; i < size; i++) {
520365777ecSAart Bik     Value v = builder.create<memref::LoadOp>(loc, mem,
521365777ecSAart Bik                                              constantIndex(builder, loc, i));
522365777ecSAart Bik     if (i == offsetIdx && offsetVal)
523365777ecSAart Bik       v = builder.create<arith::AddIOp>(loc, v, offsetVal);
524365777ecSAart Bik     vs.push_back(v);
525365777ecSAart Bik   }
526365777ecSAart Bik   return vs;
527365777ecSAart Bik }
528365777ecSAart Bik 
529365777ecSAart Bik void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
530365777ecSAart Bik                              ValueRange vs, size_t offsetIdx, Value offsetVal) {
531365777ecSAart Bik #ifndef NDEBUG
532365777ecSAart Bik   const size_t vsize = vs.size();
533365777ecSAart Bik   const auto memTp = cast<MemRefType>(mem.getType());
534365777ecSAart Bik   assert(memTp.getRank() == 1);
535365777ecSAart Bik   const Size memSh = memTp.getDimSize(0);
536365777ecSAart Bik   assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
537365777ecSAart Bik   assert(offsetIdx == 0 || offsetIdx < vsize);
538365777ecSAart Bik #endif // NDEBUG
539365777ecSAart Bik   for (const auto &v : llvm::enumerate(vs)) {
540365777ecSAart Bik     const Value w =
541365777ecSAart Bik         (offsetIdx == v.index() && offsetVal)
542365777ecSAart Bik             ? builder.create<arith::AddIOp>(loc, v.value(), offsetVal)
543365777ecSAart Bik             : v.value();
544365777ecSAart Bik     builder.create<memref::StoreOp>(loc, w, mem,
545365777ecSAart Bik                                     constantIndex(builder, loc, v.index()));
546365777ecSAart Bik   }
547365777ecSAart Bik }
548365777ecSAart Bik 
549365777ecSAart Bik TypedValue<BaseMemRefType>
550365777ecSAart Bik sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
551365777ecSAart Bik   auto tTp = llvm::cast<TensorType>(tensor.getType());
552365777ecSAart Bik   auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
553365777ecSAart Bik   return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
554365777ecSAart Bik       .getResult();
555365777ecSAart Bik }
556365777ecSAart Bik 
557365777ecSAart Bik Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
558365777ecSAart Bik                                                Value tensor, Dimension dim) {
559365777ecSAart Bik   auto enc = getSparseTensorEncoding(tensor.getType());
560365777ecSAart Bik   assert(enc && enc.isSlice());
561365777ecSAart Bik   std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
562365777ecSAart Bik   if (offset.has_value())
563365777ecSAart Bik     return constantIndex(builder, loc, *offset);
564365777ecSAart Bik   return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
565365777ecSAart Bik }
566365777ecSAart Bik 
567365777ecSAart Bik Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
568365777ecSAart Bik                                                Value tensor, Dimension dim) {
569365777ecSAart Bik   auto enc = getSparseTensorEncoding(tensor.getType());
570365777ecSAart Bik   assert(enc && enc.isSlice());
571365777ecSAart Bik   std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
572365777ecSAart Bik   if (stride.has_value())
573365777ecSAart Bik     return constantIndex(builder, loc, *stride);
574365777ecSAart Bik   return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
575365777ecSAart Bik }
576365777ecSAart Bik 
577365777ecSAart Bik Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
578365777ecSAart Bik                                SparseTensorType stt, Value tensor,
579365777ecSAart Bik                                /*out*/ SmallVectorImpl<Value> &dimSizesValues,
580365777ecSAart Bik                                /*out*/ Value &dimSizesBuffer) {
581365777ecSAart Bik   // Construct the dimension **shapes** buffer. The buffer contains the static
582365777ecSAart Bik   // size per dimension, or otherwise a zero for a dynamic size.
583365777ecSAart Bik   Dimension dimRank = stt.getDimRank();
584365777ecSAart Bik   dimSizesValues.clear();
585365777ecSAart Bik   dimSizesValues.reserve(dimRank);
586365777ecSAart Bik   for (const Size sz : stt.getDimShape()) {
587365777ecSAart Bik     const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
588365777ecSAart Bik     dimSizesValues.push_back(constantIndex(builder, loc, s));
589365777ecSAart Bik   }
590365777ecSAart Bik   Value dimShapesBuffer = allocaBuffer(builder, loc, dimSizesValues);
591365777ecSAart Bik   // Create the `CheckedSparseTensorReader`. This reader performs a
592365777ecSAart Bik   // consistency check on the static sizes, but accepts any size
593365777ecSAart Bik   // of each dimension with a dynamic size.
594365777ecSAart Bik   Type opaqueTp = getOpaquePointerType(builder);
595365777ecSAart Bik   Type eltTp = stt.getElementType();
596365777ecSAart Bik   Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
597365777ecSAart Bik   Value reader =
598365777ecSAart Bik       createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
599365777ecSAart Bik                      {tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
600365777ecSAart Bik           .getResult(0);
601365777ecSAart Bik   // For static shapes, the shape buffer can be used right away. For dynamic
602365777ecSAart Bik   // shapes, use the information from the reader to construct a buffer that
603365777ecSAart Bik   // supplies the actual size for each dynamic dimension.
604365777ecSAart Bik   dimSizesBuffer = dimShapesBuffer;
605365777ecSAart Bik   if (stt.hasDynamicDimShape()) {
606365777ecSAart Bik     Type indexTp = builder.getIndexType();
607365777ecSAart Bik     auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
608365777ecSAart Bik     dimSizesBuffer =
609365777ecSAart Bik         createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
610365777ecSAart Bik                        reader, EmitCInterface::On)
611365777ecSAart Bik             .getResult(0);
612365777ecSAart Bik     // Also convert the dim shapes values into dim sizes values, just in case
613365777ecSAart Bik     // subsequent clients need the values (DCE will remove unused).
614365777ecSAart Bik     for (Dimension d = 0; d < dimRank; d++) {
615365777ecSAart Bik       if (stt.isDynamicDim(d))
616365777ecSAart Bik         dimSizesValues[d] = builder.create<memref::LoadOp>(
617365777ecSAart Bik             loc, dimSizesBuffer, constantIndex(builder, loc, d));
618365777ecSAart Bik     }
619365777ecSAart Bik   }
620365777ecSAart Bik   return reader;
621365777ecSAart Bik }
622365777ecSAart Bik 
623365777ecSAart Bik Value sparse_tensor::genMapBuffers(
624365777ecSAart Bik     OpBuilder &builder, Location loc, SparseTensorType stt,
625365777ecSAart Bik     ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
626365777ecSAart Bik     /*out*/ SmallVectorImpl<Value> &lvlSizesValues,
627365777ecSAart Bik     /*out*/ Value &dim2lvlBuffer,
628365777ecSAart Bik     /*out*/ Value &lvl2dimBuffer) {
629365777ecSAart Bik   const Dimension dimRank = stt.getDimRank();
630365777ecSAart Bik   const Level lvlRank = stt.getLvlRank();
631365777ecSAart Bik   lvlSizesValues.clear();
632365777ecSAart Bik   lvlSizesValues.reserve(lvlRank);
633365777ecSAart Bik   // For an identity mapping, the dim2lvl and lvl2dim mappings are
634365777ecSAart Bik   // identical as are dimSizes and lvlSizes, so buffers are reused
635365777ecSAart Bik   // as much as possible.
636365777ecSAart Bik   if (stt.isIdentity()) {
637365777ecSAart Bik     assert(dimRank == lvlRank);
638365777ecSAart Bik     SmallVector<Value> iotaValues;
639365777ecSAart Bik     iotaValues.reserve(lvlRank);
640365777ecSAart Bik     for (Level l = 0; l < lvlRank; l++) {
641365777ecSAart Bik       iotaValues.push_back(constantIndex(builder, loc, l));
642365777ecSAart Bik       lvlSizesValues.push_back(dimSizesValues[l]);
643365777ecSAart Bik     }
644365777ecSAart Bik     dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
645365777ecSAart Bik     return dimSizesBuffer; // now lvlSizesBuffer
646365777ecSAart Bik   }
647365777ecSAart Bik   // Otherwise, some code needs to be generated to set up the buffers.
648365777ecSAart Bik   // This code deals with permutations as well as non-permutations that
649365777ecSAart Bik   // arise from rank changing blocking.
650365777ecSAart Bik   const auto dimToLvl = stt.getDimToLvl();
651365777ecSAart Bik   const auto lvlToDim = stt.getLvlToDim();
652365777ecSAart Bik   SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
653365777ecSAart Bik   SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
654365777ecSAart Bik   // Generate dim2lvl.
655365777ecSAart Bik   assert(lvlRank == dimToLvl.getNumResults());
656365777ecSAart Bik   for (Level l = 0; l < lvlRank; l++) {
657365777ecSAart Bik     AffineExpr exp = dimToLvl.getResult(l);
658365777ecSAart Bik     // We expect:
659365777ecSAart Bik     //    (1) l = d
660365777ecSAart Bik     //    (2) l = d / c
661365777ecSAart Bik     //    (3) l = d % c
662365777ecSAart Bik     Dimension d = 0;
663365777ecSAart Bik     uint64_t cf = 0, cm = 0;
664365777ecSAart Bik     switch (exp.getKind()) {
665365777ecSAart Bik     case AffineExprKind::DimId: {
666365777ecSAart Bik       d = cast<AffineDimExpr>(exp).getPosition();
667365777ecSAart Bik       break;
668365777ecSAart Bik     }
669365777ecSAart Bik     case AffineExprKind::FloorDiv: {
670365777ecSAart Bik       auto floor = cast<AffineBinaryOpExpr>(exp);
671365777ecSAart Bik       d = cast<AffineDimExpr>(floor.getLHS()).getPosition();
672365777ecSAart Bik       cf = cast<AffineConstantExpr>(floor.getRHS()).getValue();
673365777ecSAart Bik       break;
674365777ecSAart Bik     }
675365777ecSAart Bik     case AffineExprKind::Mod: {
676365777ecSAart Bik       auto mod = cast<AffineBinaryOpExpr>(exp);
677365777ecSAart Bik       d = cast<AffineDimExpr>(mod.getLHS()).getPosition();
678365777ecSAart Bik       cm = cast<AffineConstantExpr>(mod.getRHS()).getValue();
679365777ecSAart Bik       break;
680365777ecSAart Bik     }
681365777ecSAart Bik     default:
682365777ecSAart Bik       llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
683365777ecSAart Bik     }
684365777ecSAart Bik     dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
685365777ecSAart Bik     // Compute the level sizes.
686365777ecSAart Bik     //    (1) l = d        : size(d)
687365777ecSAart Bik     //    (2) l = d / c    : size(d) / c
688365777ecSAart Bik     //    (3) l = d % c    : c
689365777ecSAart Bik     Value lvlSz;
690365777ecSAart Bik     if (cm == 0) {
691365777ecSAart Bik       lvlSz = dimSizesValues[d];
692365777ecSAart Bik       if (cf != 0)
693365777ecSAart Bik         lvlSz = builder.create<arith::DivUIOp>(loc, lvlSz,
694365777ecSAart Bik                                                constantIndex(builder, loc, cf));
695365777ecSAart Bik     } else {
696365777ecSAart Bik       lvlSz = constantIndex(builder, loc, cm);
697365777ecSAart Bik     }
698365777ecSAart Bik     lvlSizesValues.push_back(lvlSz);
699365777ecSAart Bik   }
700365777ecSAart Bik   // Generate lvl2dim.
701365777ecSAart Bik   assert(dimRank == lvlToDim.getNumResults());
702365777ecSAart Bik   for (Dimension d = 0; d < dimRank; d++) {
703365777ecSAart Bik     AffineExpr exp = lvlToDim.getResult(d);
704365777ecSAart Bik     // We expect:
705365777ecSAart Bik     //    (1) d = l
706365777ecSAart Bik     //    (2) d = l' * c + l
707365777ecSAart Bik     Level l = 0, ll = 0;
708365777ecSAart Bik     uint64_t c = 0;
709365777ecSAart Bik     switch (exp.getKind()) {
710365777ecSAart Bik     case AffineExprKind::DimId: {
711365777ecSAart Bik       l = cast<AffineDimExpr>(exp).getPosition();
712365777ecSAart Bik       break;
713365777ecSAart Bik     }
714365777ecSAart Bik     case AffineExprKind::Add: {
715365777ecSAart Bik       // Always mul on lhs, symbol/constant on rhs.
716365777ecSAart Bik       auto add = cast<AffineBinaryOpExpr>(exp);
717365777ecSAart Bik       assert(add.getLHS().getKind() == AffineExprKind::Mul);
718365777ecSAart Bik       auto mul = cast<AffineBinaryOpExpr>(add.getLHS());
719365777ecSAart Bik       ll = cast<AffineDimExpr>(mul.getLHS()).getPosition();
720365777ecSAart Bik       c = cast<AffineConstantExpr>(mul.getRHS()).getValue();
721365777ecSAart Bik       l = cast<AffineDimExpr>(add.getRHS()).getPosition();
722365777ecSAart Bik       break;
723365777ecSAart Bik     }
724365777ecSAart Bik     default:
725365777ecSAart Bik       llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
726365777ecSAart Bik     }
727365777ecSAart Bik     lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
728365777ecSAart Bik   }
729365777ecSAart Bik   // Return buffers.
730365777ecSAart Bik   dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
731365777ecSAart Bik   lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
732365777ecSAart Bik   return allocaBuffer(builder, loc, lvlSizesValues); // lvlSizesBuffer
733365777ecSAart Bik }
734