xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (revision af6e1881e0791ac1ee611b62a3d12d9fb03ca142)
1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "Detail/DimLvlMapParser.h"
12 
13 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
17 
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20 #include "mlir/Dialect/Complex/IR/Complex.h"
21 #include "mlir/Dialect/Utils/StaticValueUtils.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/OpImplementation.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "llvm/ADT/Bitset.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/FormatVariadic.h"
30 
31 #define GET_ATTRDEF_CLASSES
32 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
33 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
34 
// Forward declarations: the following custom print/parse methods are
// referenced by the generated code for SparseTensorTypes.td.
37 static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
38                                          mlir::sparse_tensor::Level &,
39                                          mlir::sparse_tensor::Level &);
40 static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
41                             mlir::sparse_tensor::Level);
42 
43 #define GET_TYPEDEF_CLASSES
44 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
45 
46 using namespace mlir;
47 using namespace mlir::sparse_tensor;
48 
49 // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
50 // well.
51 namespace mlir::sparse_tensor {
52 llvm::hash_code hash_value(LevelType lt) {
53   return llvm::hash_value(static_cast<uint64_t>(lt));
54 }
55 } // namespace mlir::sparse_tensor
56 
57 //===----------------------------------------------------------------------===//
58 // Local Convenience Methods.
59 //===----------------------------------------------------------------------===//
60 
/// Returns true when `bitWidth` is an admissible position/coordinate width:
/// 0 (which selects the `index` type) or one of 8, 16, 32 and 64 bits.
static constexpr bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
73 
74 static SmallVector<Size>
75 getSparseFieldShape(const SparseTensorEncodingAttr enc,
76                     std::optional<ArrayRef<int64_t>> dimShape) {
77   assert(enc);
78   // With only encoding, we can not determine the static shape for leading
79   // batch levels, we therefore return a dynamic shape memref instead.
80   SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
81   if (dimShape.has_value()) {
82     // If the actual tensor shape is provided, we can then refine the leading
83     // batch dimension.
84     SmallVector<int64_t> lvlShape =
85         enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
86     memrefShape.assign(lvlShape.begin(),
87                        lvlShape.begin() + enc.getBatchLvlRank());
88   }
89   // Another dynamic dimension to store the sparse level.
90   memrefShape.push_back(ShapedType::kDynamic);
91   return memrefShape;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // SparseTensorDialect StorageLayout.
96 //===----------------------------------------------------------------------===//
97 
98 static constexpr Level kInvalidLevel = -1u;
99 static constexpr Level kInvalidFieldIndex = -1u;
100 static constexpr FieldIndex kDataFieldStartingIdx = 0;
101 
102 void StorageLayout::foreachField(
103     llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
104                             LevelType)>
105         callback) const {
106   const auto lvlTypes = enc.getLvlTypes();
107   const Level lvlRank = enc.getLvlRank();
108   SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
109   FieldIndex fieldIdx = kDataFieldStartingIdx;
110 
111   ArrayRef cooSegsRef = cooSegs;
112   // Per-level storage.
113   for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
114     const auto lt = lvlTypes[l];
115     if (isWithPosLT(lt)) {
116       if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
117         return;
118     }
119     if (isWithCrdLT(lt)) {
120       if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
121         return;
122     }
123     if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
124       if (!cooSegsRef.front().isSoA) {
125         // AoS COO, all singletons are fused into one memrefs. Skips the entire
126         // COO segement.
127         l = cooSegsRef.front().lvlRange.second;
128       } else {
129         // SoA COO, each singleton level has one memref.
130         l++;
131       }
132       // Expire handled COO segment.
133       cooSegsRef = cooSegsRef.drop_front();
134     } else {
135       // Non COO levels.
136       l++;
137     }
138   }
139   // The values array.
140   if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
141                  LevelFormat::Undef)))
142     return;
143   // Put metadata at the end.
144   if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
145                  LevelFormat::Undef)))
146     return;
147 }
148 
149 void sparse_tensor::foreachFieldAndTypeInSparseTensor(
150     SparseTensorType stt,
151     llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
152                             LevelType)>
153         callback) {
154   assert(stt.hasEncoding());
155 
156   SmallVector<int64_t> memrefShape =
157       getSparseFieldShape(stt.getEncoding(), stt.getDimShape());
158 
159   const Type specType = StorageSpecifierType::get(stt.getEncoding());
160   // memref<[batch] x ? x pos>  positions
161   const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
162   // memref<[batch] x ? x crd>  coordinates
163   const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
164   // memref<[batch] x ? x eltType> values
165   const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
166 
167   StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
168                                    callback](FieldIndex fieldIdx,
169                                              SparseTensorFieldKind fieldKind,
170                                              Level lvl, LevelType lt) -> bool {
171     switch (fieldKind) {
172     case SparseTensorFieldKind::StorageSpec:
173       return callback(specType, fieldIdx, fieldKind, lvl, lt);
174     case SparseTensorFieldKind::PosMemRef:
175       return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
176     case SparseTensorFieldKind::CrdMemRef:
177       return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
178     case SparseTensorFieldKind::ValMemRef:
179       return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
180     };
181     llvm_unreachable("unrecognized field kind");
182   });
183 }
184 
185 unsigned StorageLayout::getNumFields() const {
186   unsigned numFields = 0;
187   foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
188                             LevelType) -> bool {
189     numFields++;
190     return true;
191   });
192   return numFields;
193 }
194 
195 unsigned StorageLayout::getNumDataFields() const {
196   unsigned numFields = 0; // one value memref
197   foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
198                             LevelType) -> bool {
199     if (fidx >= kDataFieldStartingIdx)
200       numFields++;
201     return true;
202   });
203   numFields -= 1; // the last field is StorageSpecifier
204   assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
205   return numFields;
206 }
207 
208 std::pair<FieldIndex, unsigned>
209 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
210                                       std::optional<Level> lvl) const {
211   FieldIndex fieldIdx = kInvalidFieldIndex;
212   unsigned stride = 1;
213   if (kind == SparseTensorFieldKind::CrdMemRef) {
214     assert(lvl.has_value());
215     const Level cooStart = enc.getAoSCOOStart();
216     const Level lvlRank = enc.getLvlRank();
217     if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
218       lvl = cooStart;
219       stride = lvlRank - cooStart;
220     }
221   }
222   foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
223                                       SparseTensorFieldKind fKind, Level fLvl,
224                                       LevelType lt) -> bool {
225     if ((lvl && fLvl == lvl.value() && kind == fKind) ||
226         (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
227       fieldIdx = fIdx;
228       // Returns false to break the iteration.
229       return false;
230     }
231     return true;
232   });
233   assert(fieldIdx != kInvalidFieldIndex);
234   return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // SparseTensorDialect Attribute Methods.
239 //===----------------------------------------------------------------------===//
240 
241 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
242   return isDynamic(v) ? std::nullopt
243                       : std::make_optional(static_cast<uint64_t>(v));
244 }
245 
246 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
247   return getStatic(getOffset());
248 }
249 
250 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
251   return getStatic(getStride());
252 }
253 
254 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
255   return getStatic(getSize());
256 }
257 
258 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
259   return isDynamic(getOffset()) && isDynamic(getStride()) &&
260          isDynamic(getSize());
261 }
262 
263 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
264   return isDynamic(v) ? "?" : std::to_string(v);
265 }
266 
267 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
268   assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
269   os << '(';
270   os << getStaticString(getOffset());
271   os << ", ";
272   os << getStaticString(getSize());
273   os << ", ";
274   os << getStaticString(getStride());
275   os << ')';
276 }
277 
278 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
279   print(printer.getStream());
280 }
281 
282 static ParseResult parseOptionalStaticSlice(int64_t &result,
283                                             AsmParser &parser) {
284   auto parseResult = parser.parseOptionalInteger(result);
285   if (parseResult.has_value()) {
286     if (parseResult.value().succeeded() && result < 0) {
287       parser.emitError(
288           parser.getCurrentLocation(),
289           "expect positive value or ? for slice offset/size/stride");
290       return failure();
291     }
292     return parseResult.value();
293   }
294 
295   // Else, and '?' which represented dynamic slice
296   result = SparseTensorDimSliceAttr::kDynamic;
297   return parser.parseQuestion();
298 }
299 
300 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
301   int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
302 
303   if (failed(parser.parseLParen()) ||
304       failed(parseOptionalStaticSlice(offset, parser)) ||
305       failed(parser.parseComma()) ||
306       failed(parseOptionalStaticSlice(size, parser)) ||
307       failed(parser.parseComma()) ||
308       failed(parseOptionalStaticSlice(stride, parser)) ||
309       failed(parser.parseRParen()))
310     return {};
311 
312   return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
313                                                      offset, size, stride);
314 }
315 
316 LogicalResult
317 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
318                                  int64_t offset, int64_t size, int64_t stride) {
319   if (!isDynamic(offset) && offset < 0)
320     return emitError() << "expect non-negative value or ? for slice offset";
321   if (!isDynamic(size) && size <= 0)
322     return emitError() << "expect positive value or ? for slice size";
323   if (!isDynamic(stride) && stride <= 0)
324     return emitError() << "expect positive value or ? for slice stride";
325   return success();
326 }
327 
328 SparseTensorEncodingAttr
329 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
330   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
331   return SparseTensorEncodingAttr::get(
332       getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
333       getCrdWidth(), getExplicitVal(), getImplicitVal());
334 }
335 
336 SparseTensorEncodingAttr
337 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
338   return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
339 }
340 
341 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
342   return withDimToLvl(AffineMap());
343 }
344 
345 SparseTensorEncodingAttr
346 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
347                                         unsigned crdWidth) const {
348   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
349   return SparseTensorEncodingAttr::get(
350       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
351       crdWidth, getExplicitVal(), getImplicitVal());
352 }
353 
354 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
355   return withBitWidths(0, 0);
356 }
357 
358 SparseTensorEncodingAttr
359 SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
360   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
361   return SparseTensorEncodingAttr::get(
362       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
363       getCrdWidth(), explicitVal, getImplicitVal());
364 }
365 
366 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
367   return withExplicitVal(Attribute());
368 }
369 
370 SparseTensorEncodingAttr
371 SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
372   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
373   return SparseTensorEncodingAttr::get(
374       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
375       getCrdWidth(), getExplicitVal(), implicitVal);
376 }
377 
378 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
379   return withImplicitVal(Attribute());
380 }
381 
382 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
383     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
384   return SparseTensorEncodingAttr::get(
385       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
386       getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
387 }
388 
389 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
390   return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
391 }
392 
393 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
394   ArrayRef<LevelType> lvlTypes = getLvlTypes();
395   auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
396   return std::distance(lastBatch, lvlTypes.rend());
397 }
398 
399 bool SparseTensorEncodingAttr::isAllDense() const {
400   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
401 }
402 
403 bool SparseTensorEncodingAttr::isAllOrdered() const {
404   return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
405 }
406 
407 Type SparseTensorEncodingAttr::getCrdElemType() const {
408   if (!getImpl())
409     return nullptr;
410   if (getCrdWidth())
411     return IntegerType::get(getContext(), getCrdWidth());
412   return IndexType::get(getContext());
413 }
414 
415 Type SparseTensorEncodingAttr::getPosElemType() const {
416   if (!getImpl())
417     return nullptr;
418   if (getPosWidth())
419     return IntegerType::get(getContext(), getPosWidth());
420   return IndexType::get(getContext());
421 }
422 
423 MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
424     std::optional<ArrayRef<int64_t>> dimShape) const {
425   SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
426   return MemRefType::get(shape, getCrdElemType());
427 }
428 
429 MemRefType SparseTensorEncodingAttr::getPosMemRefType(
430     std::optional<ArrayRef<int64_t>> dimShape) const {
431   SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
432   return MemRefType::get(shape, getPosElemType());
433 }
434 
435 bool SparseTensorEncodingAttr::isIdentity() const {
436   return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
437 }
438 
439 bool SparseTensorEncodingAttr::isPermutation() const {
440   return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
441 }
442 
443 Dimension SparseTensorEncodingAttr::getDimRank() const {
444   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
445   const auto dimToLvl = getDimToLvl();
446   return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
447 }
448 
449 Level SparseTensorEncodingAttr::getLvlRank() const {
450   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
451   return getLvlTypes().size();
452 }
453 
454 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
455   if (!getImpl())
456     return LevelFormat::Batch;
457   assert(l < getLvlRank() && "Level is out of bounds");
458   return getLvlTypes()[l];
459 }
460 
461 bool SparseTensorEncodingAttr::isSlice() const {
462   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
463   return !getDimSlices().empty();
464 }
465 
466 SparseTensorDimSliceAttr
467 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
468   assert(isSlice() && "Is not a slice");
469   const auto dimSlices = getDimSlices();
470   assert(dim < dimSlices.size() && "Dimension is out of bounds");
471   return dimSlices[dim];
472 }
473 
474 std::optional<uint64_t>
475 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
476   return getDimSlice(dim).getStaticOffset();
477 }
478 
479 std::optional<uint64_t>
480 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
481   return getDimSlice(dim).getStaticStride();
482 }
483 
484 std::optional<uint64_t>
485 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
486   return getStaticDimSliceOffset(toDim(*this, lvl));
487 }
488 
489 std::optional<uint64_t>
490 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
491   return getStaticDimSliceStride(toDim(*this, lvl));
492 }
493 
494 SmallVector<int64_t>
495 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
496                                          CrdTransDirectionKind dir) const {
497   if (isIdentity())
498     return SmallVector<int64_t>(srcShape);
499 
500   SmallVector<int64_t> ret;
501   unsigned rank =
502       dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
503   ret.reserve(rank);
504 
505   if (isPermutation()) {
506     for (unsigned r = 0; r < rank; r++) {
507       unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
508                                                              : toLvl(*this, r);
509       ret.push_back(srcShape[trans]);
510     }
511     return ret;
512   }
513 
514   // Handle non-permutation maps.
515   AffineMap transMap =
516       dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
517 
518   SmallVector<AffineExpr> dimRep;
519   dimRep.reserve(srcShape.size());
520   for (int64_t sz : srcShape) {
521     if (!ShapedType::isDynamic(sz)) {
522       // Push back the max coordinate for the given dimension/level size.
523       dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
524     } else {
525       // A dynamic size, use a AffineDimExpr to symbolize the value.
526       dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
527     }
528   };
529 
530   for (AffineExpr exp : transMap.getResults()) {
531     // Do constant propagation on the affine map.
532     AffineExpr evalExp =
533         simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
534     // use llvm namespace here to avoid ambiguity
535     if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
536       ret.push_back(c.getValue() + 1);
537     } else {
538       if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
539           mod && mod.getKind() == AffineExprKind::Mod) {
540         // We can still infer a static bound for expressions in form
541         // "d % constant" since d % constant \in [0, constant).
542         if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
543           ret.push_back(bound.getValue());
544           continue;
545         }
546       }
547       ret.push_back(ShapedType::kDynamic);
548     }
549   }
550   assert(ret.size() == rank);
551   return ret;
552 }
553 
554 ValueRange
555 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
556                                         ValueRange crds,
557                                         CrdTransDirectionKind dir) const {
558   if (!getImpl())
559     return crds;
560 
561   SmallVector<Type> retType(
562       dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
563       builder.getIndexType());
564   auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
565   return transOp.getOutCrds();
566 }
567 
568 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
569   // Open "<{" part.
570   if (failed(parser.parseLess()))
571     return {};
572   if (failed(parser.parseLBrace()))
573     return {};
574 
575   // Process the data from the parsed dictionary value into struct-like data.
576   SmallVector<LevelType> lvlTypes;
577   SmallVector<SparseTensorDimSliceAttr> dimSlices;
578   AffineMap dimToLvl = {};
579   AffineMap lvlToDim = {};
580   unsigned posWidth = 0;
581   unsigned crdWidth = 0;
582   Attribute explicitVal;
583   Attribute implicitVal;
584   StringRef attrName;
585   SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
586                                     "explicitVal", "implicitVal"};
587   while (succeeded(parser.parseOptionalKeyword(&attrName))) {
588     // Detect admissible keyword.
589     auto *it = find(keys, attrName);
590     if (it == keys.end()) {
591       parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
592       return {};
593     }
594     unsigned keyWordIndex = it - keys.begin();
595     // Consume the `=` after keys
596     if (failed(parser.parseEqual()))
597       return {};
598     // Dispatch on keyword.
599     switch (keyWordIndex) {
600     case 0: { // map
601       ir_detail::DimLvlMapParser cParser(parser);
602       auto res = cParser.parseDimLvlMap();
603       if (failed(res))
604         return {};
605       const auto &dlm = *res;
606 
607       const Level lvlRank = dlm.getLvlRank();
608       for (Level lvl = 0; lvl < lvlRank; lvl++)
609         lvlTypes.push_back(dlm.getLvlType(lvl));
610 
611       const Dimension dimRank = dlm.getDimRank();
612       for (Dimension dim = 0; dim < dimRank; dim++)
613         dimSlices.push_back(dlm.getDimSlice(dim));
614       // NOTE: the old syntax requires an all-or-nothing approach to
615       // `dimSlices`; therefore, if any slice actually exists then we need
616       // to convert null-DSA into default/nop DSA.
617       const auto isDefined = [](SparseTensorDimSliceAttr slice) {
618         return static_cast<bool>(slice.getImpl());
619       };
620       if (llvm::any_of(dimSlices, isDefined)) {
621         const auto defaultSlice =
622             SparseTensorDimSliceAttr::get(parser.getContext());
623         for (Dimension dim = 0; dim < dimRank; dim++)
624           if (!isDefined(dimSlices[dim]))
625             dimSlices[dim] = defaultSlice;
626       } else {
627         dimSlices.clear();
628       }
629 
630       dimToLvl = dlm.getDimToLvlMap(parser.getContext());
631       lvlToDim = dlm.getLvlToDimMap(parser.getContext());
632       break;
633     }
634     case 1: { // posWidth
635       Attribute attr;
636       if (failed(parser.parseAttribute(attr)))
637         return {};
638       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
639       if (!intAttr) {
640         parser.emitError(parser.getNameLoc(),
641                          "expected an integral position bitwidth");
642         return {};
643       }
644       posWidth = intAttr.getInt();
645       break;
646     }
647     case 2: { // crdWidth
648       Attribute attr;
649       if (failed(parser.parseAttribute(attr)))
650         return {};
651       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
652       if (!intAttr) {
653         parser.emitError(parser.getNameLoc(),
654                          "expected an integral index bitwidth");
655         return {};
656       }
657       crdWidth = intAttr.getInt();
658       break;
659     }
660     case 3: { // explicitVal
661       Attribute attr;
662       if (failed(parser.parseAttribute(attr)))
663         return {};
664       if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
665         explicitVal = result;
666       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
667         explicitVal = result;
668       } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
669         explicitVal = result;
670       } else {
671         parser.emitError(parser.getNameLoc(),
672                          "expected a numeric value for explicitVal");
673         return {};
674       }
675       break;
676     }
677     case 4: { // implicitVal
678       Attribute attr;
679       if (failed(parser.parseAttribute(attr)))
680         return {};
681       if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
682         implicitVal = result;
683       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
684         implicitVal = result;
685       } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
686         implicitVal = result;
687       } else {
688         parser.emitError(parser.getNameLoc(),
689                          "expected a numeric value for implicitVal");
690         return {};
691       }
692       break;
693     }
694     } // switch
695     // Only last item can omit the comma.
696     if (parser.parseOptionalComma().failed())
697       break;
698   }
699 
700   // Close "}>" part.
701   if (failed(parser.parseRBrace()))
702     return {};
703   if (failed(parser.parseGreater()))
704     return {};
705 
706   // Construct struct-like storage for attribute.
707   if (!lvlToDim || lvlToDim.isEmpty()) {
708     lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
709   }
710   return parser.getChecked<SparseTensorEncodingAttr>(
711       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
712       explicitVal, implicitVal, dimSlices);
713 }
714 
715 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
716   auto map = static_cast<AffineMap>(getDimToLvl());
717   // Empty affine map indicates identity map
718   if (!map)
719     map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
720   printer << "<{ map = ";
721   printSymbols(map, printer);
722   printer << '(';
723   printDimensions(map, printer, getDimSlices());
724   printer << ") -> (";
725   printLevels(map, printer, getLvlTypes());
726   printer << ')';
727   // Print remaining members only for non-default values.
728   if (getPosWidth())
729     printer << ", posWidth = " << getPosWidth();
730   if (getCrdWidth())
731     printer << ", crdWidth = " << getCrdWidth();
732   if (getExplicitVal()) {
733     printer << ", explicitVal = " << getExplicitVal();
734   }
735   if (getImplicitVal())
736     printer << ", implicitVal = " << getImplicitVal();
737   printer << " }>";
738 }
739 
740 void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
741                                             AsmPrinter &printer) const {
742   if (map.getNumSymbols() == 0)
743     return;
744   printer << '[';
745   for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
746     printer << 's' << i << ", ";
747   if (map.getNumSymbols() >= 1)
748     printer << 's' << map.getNumSymbols() - 1;
749   printer << ']';
750 }
751 
752 void SparseTensorEncodingAttr::printDimensions(
753     AffineMap &map, AsmPrinter &printer,
754     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
755   if (!dimSlices.empty()) {
756     for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
757       printer << 'd' << i << " : " << dimSlices[i] << ", ";
758     if (map.getNumDims() >= 1) {
759       printer << 'd' << map.getNumDims() - 1 << " : "
760               << dimSlices[map.getNumDims() - 1];
761     }
762   } else {
763     for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
764       printer << 'd' << i << ", ";
765     if (map.getNumDims() >= 1)
766       printer << 'd' << map.getNumDims() - 1;
767   }
768 }
769 
770 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
771                                            ArrayRef<LevelType> lvlTypes) const {
772   for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
773     map.getResult(i).print(printer.getStream());
774     printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
775   }
776   if (map.getNumResults() >= 1) {
777     auto lastIndex = map.getNumResults() - 1;
778     map.getResult(lastIndex).print(printer.getStream());
779     printer << " : " << toMLIRString(lvlTypes[lastIndex]);
780   }
781 }
782 
/// Verifies the structural integrity of an encoding built from the given
/// fields, independent of any particular tensor type. In order, it checks:
///   - the position and coordinate bit widths;
///   - every run of singleton levels is preceded by a compressed or
///     loose_compressed level and uses one memory layout (SoA vs AoS);
///   - batch levels appear only as leading levels;
///   - the SoA property appears only on singleton levels;
///   - placement and block structure of an n_out_of_m level;
///   - level-rank / dimension-rank coherence of `dimToLvl`, `lvlToDim`
///     and `dimSlices`.
/// `explicitVal` and `implicitVal` are not inspected by this function.
LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;

  // Verify every COO segment.
  // Each iteration handles one maximal run of singleton levels [it, curCOOEnd).
  auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
  while (it != lvlTypes.end()) {
    // A singleton run must directly follow a (loose) compressed level.
    if (it == lvlTypes.begin() ||
        !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";

    auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
    // NOTE(review): `curCOOEnd` is the first non-singleton level at or after
    // `it`, so every level in [it, curCOOEnd) is a singleton and this check
    // cannot fail as written.
    if (!std::all_of(it, curCOOEnd,
                     [](LevelType i) { return isSingletonLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // We can potentially support mixed SoA/AoS singleton levels.
    // For now every singleton in the run must share the SoA-ness of the first.
    if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
    it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
  }

  // Scanning from the back, the first batch level found and every level
  // before it must all be batch levels, i.e. batch levels form a prefix.
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // SoA property can only be applied on singleton level.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by backend.
  // An n_out_of_m level must be the last level, preceded only by dense levels.
  // If dimToLvl is not rank-preserving, it must also be a 1xm block map whose
  // block size equals the level's m.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      // getBlockSize yields 0 for unblocked levels; all nonzero entries must
      // agree on a single block size.
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coeffiencts of Affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it.  The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    // A user-supplied lvlToDim must agree with the inferred one.
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree.  (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}
900 }
901 
902 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
903     ArrayRef<Size> dimShape, Type elementType,
904     function_ref<InFlightDiagnostic()> emitError) const {
905   // Check structural integrity.  In particular, this ensures that the
906   // level-rank is coherent across all the fields.
907   if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
908                     getPosWidth(), getCrdWidth(), getExplicitVal(),
909                     getImplicitVal(), getDimSlices())))
910     return failure();
911   // Check integrity with tensor type specifics.  In particular, we
912   // need only check that the dimension-rank of the tensor agrees with
913   // the dimension-rank of the encoding.
914   const Dimension dimRank = dimShape.size();
915   if (dimRank == 0)
916     return emitError() << "expected non-scalar sparse tensor";
917   if (getDimRank() != dimRank)
918     return emitError()
919            << "dimension-rank mismatch between encoding and tensor shape: "
920            << getDimRank() << " != " << dimRank;
921   if (auto expVal = getExplicitVal()) {
922     Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
923     if (attrType != elementType) {
924       return emitError() << "explicit value type mismatch between encoding and "
925                          << "tensor element type: " << attrType
926                          << " != " << elementType;
927     }
928   }
929   if (auto impVal = getImplicitVal()) {
930     Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
931     if (attrType != elementType) {
932       return emitError() << "implicit value type mismatch between encoding and "
933                          << "tensor element type: " << attrType
934                          << " != " << elementType;
935     }
936     // Currently, we only support zero as the implicit value.
937     auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
938     auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
939     auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
940     if ((impFVal && impFVal.getValue().isNonZero()) ||
941         (impIntVal && !impIntVal.getValue().isZero()) ||
942         (impComplexVal && (impComplexVal.getImag().isNonZero() ||
943                            impComplexVal.getReal().isNonZero()))) {
944       return emitError() << "implicit value must be zero";
945     }
946   }
947   return success();
948 }
949 
950 Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
951   SmallVector<COOSegment> coo = getCOOSegments();
952   assert(coo.size() == 1 || coo.empty());
953   if (!coo.empty() && coo.front().isAoS()) {
954     return coo.front().lvlRange.first;
955   }
956   return getLvlRank();
957 }
958 
959 SmallVector<COOSegment>
960 mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
961   SmallVector<COOSegment> ret;
962   if (getLvlRank() <= 1)
963     return ret;
964 
965   ArrayRef<LevelType> lts = getLvlTypes();
966   Level l = 0;
967   while (l < getLvlRank()) {
968     auto lt = lts[l];
969     if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
970       auto cur = lts.begin() + l;
971       auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
972         return !lt.isa<LevelFormat::Singleton>();
973       });
974       unsigned cooLen = std::distance(cur, end);
975       if (cooLen > 1) {
976         // To support mixed SoA/AoS COO, we should break the segment when the
977         // storage scheme changes, for now we faithfully assume that all
978         // consecutive singleton levels have the same storage format as verified
979         // STEA.
980         ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
981                                  lts[l + 1].isa<LevelPropNonDefault::SoA>()});
982       }
983       l += cooLen;
984     } else {
985       l++;
986     }
987   }
988   return ret;
989 }
990 
991 //===----------------------------------------------------------------------===//
992 // SparseTensorType Methods.
993 //===----------------------------------------------------------------------===//
994 
995 bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
996                                                       bool isUnique) const {
997   if (!hasEncoding())
998     return false;
999   if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
1000     return false;
1001   for (Level l = startLvl + 1; l < lvlRank; ++l)
1002     if (!isSingletonLvl(l))
1003       return false;
1004   // If isUnique is true, then make sure that the last level is unique,
1005   // that is, when lvlRank == 1, the only compressed level is unique,
1006   // and when lvlRank > 1, the last singleton is unique.
1007   return !isUnique || isUniqueLvl(lvlRank - 1);
1008 }
1009 
1010 RankedTensorType
1011 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
1012   SmallVector<LevelType> lvlTypes;
1013   lvlTypes.reserve(lvlRank);
1014   // A non-unique compressed level at beginning (unless this is
1015   // also the last level, then it is unique).
1016   lvlTypes.push_back(
1017       *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
1018   if (lvlRank > 1) {
1019     // Followed by n-2 non-unique singleton levels.
1020     std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1021                 *buildLevelType(LevelFormat::Singleton, ordered, false));
1022     // Ends by a unique singleton level.
1023     lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
1024   }
1025   auto enc = SparseTensorEncodingAttr::get(
1026       getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
1027       getCrdWidth(), getExplicitVal(), getImplicitVal());
1028   return RankedTensorType::get(getDimShape(), getElementType(), enc);
1029 }
1030 
1031 //===----------------------------------------------------------------------===//
1032 // Convenience Methods.
1033 //===----------------------------------------------------------------------===//
1034 
1035 SparseTensorEncodingAttr
1036 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
1037   if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1038     return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1039   if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1040     return mdtp.getEncoding();
1041   return nullptr;
1042 }
1043 
1044 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
1045                                              MLIRContext *context) {
1046   auto map = static_cast<AffineMap>(dimToLvl);
1047   AffineMap lvlToDim;
1048   // Return an empty lvlToDim when inference is not successful.
1049   if (!map || map.getNumSymbols() != 0) {
1050     lvlToDim = AffineMap();
1051   } else if (map.isPermutation()) {
1052     lvlToDim = inversePermutation(map);
1053   } else if (isBlockSparsity(map)) {
1054     lvlToDim = inverseBlockSparsity(map, context);
1055   }
1056   return lvlToDim;
1057 }
1058 
1059 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
1060                                                     MLIRContext *context) {
1061   SmallVector<AffineExpr> lvlExprs;
1062   auto numLvls = dimToLvl.getNumResults();
1063   lvlExprs.reserve(numLvls);
1064   // lvlExprComponents stores information of the floordiv and mod operations
1065   // applied to the same dimension, so as to build the lvlToDim map.
1066   std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1067   for (unsigned i = 0, n = numLvls; i < n; i++) {
1068     auto result = dimToLvl.getResult(i);
1069     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1070       if (result.getKind() == AffineExprKind::FloorDiv) {
1071         // Position of the dimension in dimToLvl.
1072         auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1073         assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1074                "expected only one floordiv for each dimension");
1075         SmallVector<AffineExpr, 3> components;
1076         // Level variable for floordiv.
1077         components.push_back(getAffineDimExpr(i, context));
1078         // Multiplier.
1079         components.push_back(binOp.getRHS());
1080         // Map key is the position of the dimension.
1081         lvlExprComponents[pos] = components;
1082       } else if (result.getKind() == AffineExprKind::Mod) {
1083         auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1084         assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1085                "expected floordiv before mod");
1086         // Add level variable for mod to the same vector
1087         // of the corresponding floordiv.
1088         lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
1089       } else {
1090         assert(false && "expected floordiv or mod");
1091       }
1092     } else {
1093       lvlExprs.push_back(getAffineDimExpr(i, context));
1094     }
1095   }
1096   // Build lvlExprs from lvlExprComponents.
1097   // For example, for il = i floordiv 2 and ii = i mod 2, the components
1098   // would be [il, 2, ii]. It could be used to build the AffineExpr
1099   // i = il * 2 + ii in lvlToDim.
1100   for (auto &components : lvlExprComponents) {
1101     assert(components.second.size() == 3 &&
1102            "expected 3 components to build lvlExprs");
1103     auto mulOp = getAffineBinaryOpExpr(
1104         AffineExprKind::Mul, components.second[0], components.second[1]);
1105     auto addOp =
1106         getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
1107     lvlExprs.push_back(addOp);
1108   }
1109   return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
1110 }
1111 
1112 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
1113   assert(isBlockSparsity(dimToLvl) &&
1114          "expected dimToLvl to be block sparsity for calling getBlockSize");
1115   SmallVector<unsigned> blockSize;
1116   for (auto result : dimToLvl.getResults()) {
1117     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1118       if (result.getKind() == AffineExprKind::Mod) {
1119         blockSize.push_back(
1120             dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1121       }
1122     } else {
1123       blockSize.push_back(0);
1124     }
1125   }
1126   return blockSize;
1127 }
1128 
1129 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
1130   if (!dimToLvl)
1131     return false;
1132   std::map<unsigned, int64_t> coeffientMap;
1133   bool hasBlock = false;
1134   for (auto result : dimToLvl.getResults()) {
1135     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1136       // Check for "dim op const".
1137       auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1138       auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1139       if (!dimOp || !conOp || conOp.getValue() <= 0)
1140         return false;
1141       // Inspect "dim / const" or "dim % const".
1142       auto pos = dimOp.getPosition();
1143       if (binOp.getKind() == AffineExprKind::FloorDiv) {
1144         // Expect only one floordiv for each dimension.
1145         auto [it, inserted] = coeffientMap.try_emplace(pos);
1146         if (!inserted)
1147           return false;
1148         // Record coefficient of the floordiv.
1149         it->second = conOp.getValue();
1150       } else if (binOp.getKind() == AffineExprKind::Mod) {
1151         // Expect floordiv before mod.
1152         auto it = coeffientMap.find(pos);
1153         if (it == coeffientMap.end())
1154           return false;
1155         // Expect mod to have the same coefficient as floordiv.
1156         if (conOp.getValue() != it->second)
1157           return false;
1158         hasBlock = true;
1159       } else {
1160         return false;
1161       }
1162     } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
1163       auto pos = dimOp.getPosition();
1164       // Expect dim to be unset.
1165       if (!coeffientMap.try_emplace(pos, 0).second)
1166         return false;
1167     } else {
1168       return false;
1169     }
1170   }
1171   return hasBlock;
1172 }
1173 
1174 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
1175   auto hasNonIdentityMap = [](Value v) {
1176     auto stt = tryGetSparseTensorType(v);
1177     return stt && !stt->isIdentity();
1178   };
1179 
1180   return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
1181          llvm::any_of(op->getResults(), hasNonIdentityMap);
1182 }
1183 
1184 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1185   if (enc) {
1186     assert(enc.isPermutation() && "Non permutation map not supported");
1187     if (const auto dimToLvl = enc.getDimToLvl())
1188       return dimToLvl.getDimPosition(l);
1189   }
1190   return l;
1191 }
1192 
1193 Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1194   if (enc) {
1195     assert(enc.isPermutation() && "Non permutation map not supported");
1196     if (const auto lvlToDim = enc.getLvlToDim())
1197       return lvlToDim.getDimPosition(d);
1198   }
1199   return d;
1200 }
1201 
1202 /// We normalized sparse tensor encoding attribute by always using
1203 /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
1204 /// as other variants) lead to the same storage specifier type, and stripping
1205 /// irrelevant fields that do not alter the sparse tensor memory layout.
1206 static SparseTensorEncodingAttr
1207 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
1208   SmallVector<LevelType> lts;
1209   for (auto lt : enc.getLvlTypes())
1210     lts.push_back(lt.stripStorageIrrelevantProperties());
1211 
1212   return SparseTensorEncodingAttr::get(
1213       enc.getContext(), lts,
1214       AffineMap(), // dimToLvl (irrelevant to storage specifier)
1215       AffineMap(), // lvlToDim (irrelevant to storage specifier)
1216       // Always use `index` for memSize and lvlSize instead of reusing
1217       // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
1218       // value for different bitwidth, it also avoids casting between index and
1219       // integer (returned by DimOp)
1220       0, 0,
1221       Attribute(), // explicitVal (irrelevant to storage specifier)
1222       Attribute(), // implicitVal (irrelevant to storage specifier)
1223       enc.getDimSlices());
1224 }
1225 
1226 StorageSpecifierType
1227 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1228   return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
1229 }
1230 
1231 StorageSpecifierType
1232 StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1233                                  MLIRContext *ctx,
1234                                  SparseTensorEncodingAttr encoding) {
1235   return Base::getChecked(emitError, ctx,
1236                           getNormalizedEncodingForSpecifier(encoding));
1237 }
1238 
1239 //===----------------------------------------------------------------------===//
1240 // SparseTensorDialect Operations.
1241 //===----------------------------------------------------------------------===//
1242 
1243 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1244   return success(lvl < getSparseTensorType(tensor).getLvlRank());
1245 }
1246 
1247 static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1248   const Type etp = getMemRefType(mem).getElementType();
1249   return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1250 }
1251 
1252 static LogicalResult verifySparsifierGetterSetter(
1253     StorageSpecifierKind mdKind, std::optional<Level> lvl,
1254     TypedValue<StorageSpecifierType> md, Operation *op) {
1255   if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1256     return op->emitError(
1257         "redundant level argument for querying value memory size");
1258   }
1259 
1260   const auto enc = md.getType().getEncoding();
1261   const Level lvlRank = enc.getLvlRank();
1262 
1263   if (mdKind == StorageSpecifierKind::DimOffset ||
1264       mdKind == StorageSpecifierKind::DimStride)
1265     if (!enc.isSlice())
1266       return op->emitError("requested slice data on non-slice tensor");
1267 
1268   if (mdKind != StorageSpecifierKind::ValMemSize) {
1269     if (!lvl)
1270       return op->emitError("missing level argument");
1271 
1272     const Level l = lvl.value();
1273     if (l >= lvlRank)
1274       return op->emitError("requested level is out of bounds");
1275 
1276     if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1277       return op->emitError(
1278           "requested position memory size on a singleton level");
1279   }
1280   return success();
1281 }
1282 
1283 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
1284   switch (kind) {
1285   case SparseTensorFieldKind::CrdMemRef:
1286     return stt.getCrdType();
1287   case SparseTensorFieldKind::PosMemRef:
1288     return stt.getPosType();
1289   case SparseTensorFieldKind::ValMemRef:
1290     return stt.getElementType();
1291   case SparseTensorFieldKind::StorageSpec:
1292     return nullptr;
1293   }
1294   llvm_unreachable("Unrecognizable FieldKind");
1295 }
1296 
1297 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1298                                       SparseTensorType stt,
1299                                       RankedTensorType valTp,
1300                                       TypeRange lvlTps) {
1301   if (requiresStaticShape && !stt.hasStaticDimShape())
1302     return op->emitError("the sparse-tensor must have static shape");
1303   if (!stt.hasEncoding())
1304     return op->emitError("the sparse-tensor must have an encoding attribute");
1305 
1306   // Verifies the trailing COO.
1307   Level cooStartLvl = stt.getAoSCOOStart();
1308   if (cooStartLvl < stt.getLvlRank()) {
1309     // We only supports trailing COO for now, must be the last input.
1310     auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1311     // The coordinates should be in shape of <? x rank>
1312     unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1313     if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1314       return op->emitError("input/output trailing COO level-ranks don't match");
1315     }
1316   }
1317 
1318   // Verifies that all types match.
1319   StorageLayout layout(stt.getEncoding());
1320   if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
1321     return op->emitError("inconsistent number of fields between input/output");
1322 
1323   unsigned idx = 0;
1324   bool misMatch = false;
1325   layout.foreachField([&idx, &misMatch, stt, valTp,
1326                        lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
1327                                Level lvl, LevelType lt) -> bool {
1328     if (fKind == SparseTensorFieldKind::StorageSpec)
1329       return true;
1330 
1331     Type inputTp = nullptr;
1332     if (fKind == SparseTensorFieldKind::ValMemRef) {
1333       inputTp = valTp;
1334     } else {
1335       assert(fid == idx && stt.getLvlType(lvl) == lt);
1336       inputTp = lvlTps[idx++];
1337     }
1338     // The input element type and expected element type should match.
1339     Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1340     Type expElemTp = getFieldElemType(stt, fKind);
1341     if (inpElemTp != expElemTp) {
1342       misMatch = true;
1343       return false; // to terminate the iteration
1344     }
1345     return true;
1346   });
1347 
1348   if (misMatch)
1349     return op->emitError("input/output element-types don't match");
1350   return success();
1351 }
1352 
1353 LogicalResult AssembleOp::verify() {
1354   RankedTensorType valuesTp = getValues().getType();
1355   const auto lvlsTp = getLevels().getTypes();
1356   const auto resTp = getSparseTensorType(getResult());
1357   return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1358 }
1359 
1360 LogicalResult DisassembleOp::verify() {
1361   if (getOutValues().getType() != getRetValues().getType())
1362     return emitError("output values and return value type mismatch");
1363 
1364   for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1365     if (ot.getType() != rt.getType())
1366       return emitError("output levels and return levels type mismatch");
1367 
1368   RankedTensorType valuesTp = getRetValues().getType();
1369   const auto lvlsTp = getRetLevels().getTypes();
1370   const auto srcTp = getSparseTensorType(getTensor());
1371   return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1372 }
1373 
1374 LogicalResult ConvertOp::verify() {
1375   RankedTensorType tp1 = getSource().getType();
1376   RankedTensorType tp2 = getDest().getType();
1377   if (tp1.getRank() != tp2.getRank())
1378     return emitError("unexpected conversion mismatch in rank");
1379   auto dstEnc =
1380       llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1381   if (dstEnc && dstEnc.isSlice())
1382     return emitError("cannot convert to a sparse tensor slice");
1383 
1384   auto shape1 = tp1.getShape();
1385   auto shape2 = tp2.getShape();
1386   // Accept size matches between the source and the destination type
1387   // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1388   // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1389   for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1390     if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1391       return emitError("unexpected conversion mismatch in dimension ") << d;
1392   return success();
1393 }
1394 
1395 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1396   if (getType() == getSource().getType())
1397     return getSource();
1398   return {};
1399 }
1400 
1401 bool ConvertOp::needsExtraSort() {
1402   SparseTensorType srcStt = getSparseTensorType(getSource());
1403   SparseTensorType dstStt = getSparseTensorType(getDest());
1404 
1405   // We do not need an extra sort when returning unordered sparse tensors or
1406   // dense tensor since dense tensor support random access.
1407   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1408     return false;
1409 
1410   if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1411       srcStt.hasSameDimToLvl(dstStt)) {
1412     return false;
1413   }
1414 
1415   // Source and dest tensors are ordered in different ways. We only do direct
1416   // dense to sparse conversion when the dense input is defined by a sparse
1417   // constant. Note that we can theoretically always directly convert from dense
1418   // inputs by rotating dense loops but it leads to bad cache locality and hurt
1419   // performance.
1420   if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1421     if (isa<SparseElementsAttr>(constOp.getValue()))
1422       return false;
1423 
1424   return true;
1425 }
1426 
1427 LogicalResult CrdTranslateOp::verify() {
1428   uint64_t inRank = getEncoder().getLvlRank();
1429   uint64_t outRank = getEncoder().getDimRank();
1430 
1431   if (getDirection() == CrdTransDirectionKind::dim2lvl)
1432     std::swap(inRank, outRank);
1433 
1434   if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1435     return emitError("Coordinate rank mismatch with encoding");
1436 
1437   return success();
1438 }
1439 
/// Folds a coordinate translation when the result is known statically:
///   - identity encoder: outputs are the inputs unchanged;
///   - permutation encoder: outputs are the inputs permuted by dimToLvl
///     (dim2lvl) or lvlToDim (lvl2dim);
///   - a translation that directly undoes a preceding one: outputs are that
///     earlier op's inputs. This requires the opposite direction, the same
///     dimToLvl, and all inputs being the earlier op's outputs, in order.
/// Fails (leaving the op unchanged) otherwise.
LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    // Output i is the input named by the i-th result of the permutation.
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  // All inputs must come from one and the same CrdTranslateOp.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
                   return v.getDefiningOp() == def;
                 });
  if (!sameDef)
    return failure();

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}
1484 }
1485 
1486 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1487                   int64_t index) {
1488   Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
1489   return build(builder, state, source, val);
1490 }
1491 
1492 LogicalResult LvlOp::verify() {
1493   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1494     auto stt = getSparseTensorType(getSource());
1495     if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1496       return emitError(
1497           "Level index exceeds the rank of the input sparse tensor");
1498   }
1499   return success();
1500 }
1501 
1502 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1503   return getConstantIntValue(getIndex());
1504 }
1505 
1506 Speculation::Speculatability LvlOp::getSpeculatability() {
1507   auto constantIndex = getConstantLvlIndex();
1508   if (!constantIndex)
1509     return Speculation::NotSpeculatable;
1510 
1511   assert(constantIndex <
1512          cast<RankedTensorType>(getSource().getType()).getRank());
1513   return Speculation::Speculatable;
1514 }
1515 
1516 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1517   auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1518   if (!lvlIndex)
1519     return {};
1520 
1521   Level lvl = lvlIndex.getAPSInt().getZExtValue();
1522   auto stt = getSparseTensorType(getSource());
1523   if (lvl >= stt.getLvlRank()) {
1524     // Follows the same convention used by tensor.dim operation. Out of bound
1525     // indices produce undefined behavior but are still valid IR. Don't choke on
1526     // them.
1527     return {};
1528   }
1529 
1530   // Helper lambda to build an IndexAttr.
1531   auto getIndexAttr = [this](int64_t lvlSz) {
1532     return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1533   };
1534 
1535   SmallVector<Size> lvlShape = stt.getLvlShape();
1536   if (!ShapedType::isDynamic(lvlShape[lvl]))
1537     return getIndexAttr(lvlShape[lvl]);
1538 
1539   return {};
1540 }
1541 
1542 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1543                              SparseTensorEncodingAttr dstEnc, Value source) {
1544   auto srcStt = getSparseTensorType(source);
1545   SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1546   SmallVector<int64_t> dstDimShape =
1547       dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1548   auto dstTp =
1549       RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1550   return build(odsBuilder, odsState, dstTp, source);
1551 }
1552 
1553 LogicalResult ReinterpretMapOp::verify() {
1554   auto srcStt = getSparseTensorType(getSource());
1555   auto dstStt = getSparseTensorType(getDest());
1556   ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1557   ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1558 
1559   if (srcLvlTps.size() != dstLvlTps.size())
1560     return emitError("Level rank mismatch between source/dest tensors");
1561 
1562   for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1563     if (srcLvlTp != dstLvlTp)
1564       return emitError("Level type mismatch between source/dest tensors");
1565 
1566   if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1567       srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1568     return emitError("Crd/Pos width mismatch between source/dest tensors");
1569   }
1570 
1571   if (srcStt.getElementType() != dstStt.getElementType())
1572     return emitError("Element type mismatch between source/dest tensors");
1573 
1574   SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1575   SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1576   for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1577     if (srcLvlSz != dstLvlSz) {
1578       // Should we allow one side to be dynamic size, e.g., <?x?> should be
1579       // compatible to <3x4>? For now, we require all the level sizes to be
1580       // *exactly* matched for simplicity.
1581       return emitError("Level size mismatch between source/dest tensors");
1582     }
1583   }
1584 
1585   return success();
1586 }
1587 
1588 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1589   if (getSource().getType() == getDest().getType())
1590     return getSource();
1591 
1592   if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1593     // A -> B, B -> A ==> A
1594     if (def.getSource().getType() == getDest().getType())
1595       return def.getSource();
1596   }
1597   return {};
1598 }
1599 
1600 template <typename ToBufferOp>
1601 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1602                                            OpaqueProperties prop,
1603                                            RegionRange region,
1604                                            SmallVectorImpl<mlir::Type> &ret) {
1605   typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1606   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1607   Type elemTp = nullptr;
1608   bool withStride = false;
1609   if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1610     elemTp = stt.getPosType();
1611   } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1612                        std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1613     elemTp = stt.getCrdType();
1614     if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1615       withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1616   } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1617     elemTp = stt.getElementType();
1618   }
1619 
1620   assert(elemTp && "unhandled operation.");
1621   SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1622   bufShape.push_back(ShapedType::kDynamic);
1623 
1624   auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
1625                                  stt.getContext(), ShapedType::kDynamic,
1626                                  {ShapedType::kDynamic})
1627                            : StridedLayoutAttr();
1628   ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1629   return success();
1630 }
1631 
1632 LogicalResult ToPositionsOp::verify() {
1633   auto stt = getSparseTensorType(getTensor());
1634   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1635     return emitError("requested level is out of bounds");
1636   if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1637     return emitError("unexpected type for positions");
1638   return success();
1639 }
1640 
1641 LogicalResult
1642 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1643                                 ValueRange ops, DictionaryAttr attr,
1644                                 OpaqueProperties prop, RegionRange region,
1645                                 SmallVectorImpl<mlir::Type> &ret) {
1646   return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1647 }
1648 
1649 LogicalResult ToCoordinatesOp::verify() {
1650   auto stt = getSparseTensorType(getTensor());
1651   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1652     return emitError("requested level is out of bounds");
1653   if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1654     return emitError("unexpected type for coordinates");
1655   return success();
1656 }
1657 
1658 LogicalResult
1659 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1660                                   ValueRange ops, DictionaryAttr attr,
1661                                   OpaqueProperties prop, RegionRange region,
1662                                   SmallVectorImpl<mlir::Type> &ret) {
1663   return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1664 }
1665 
1666 LogicalResult ToCoordinatesBufferOp::verify() {
1667   auto stt = getSparseTensorType(getTensor());
1668   if (stt.getAoSCOOStart() >= stt.getLvlRank())
1669     return emitError("expected sparse tensor with a COO region");
1670   return success();
1671 }
1672 
1673 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1674     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1675     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1676     SmallVectorImpl<mlir::Type> &ret) {
1677   return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1678                                                       ret);
1679 }
1680 
1681 LogicalResult ToValuesOp::verify() {
1682   auto stt = getSparseTensorType(getTensor());
1683   auto mtp = getMemRefType(getResult());
1684   if (stt.getElementType() != mtp.getElementType())
1685     return emitError("unexpected mismatch in element types");
1686   return success();
1687 }
1688 
1689 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1690                                            std::optional<Location> loc,
1691                                            ValueRange ops, DictionaryAttr attr,
1692                                            OpaqueProperties prop,
1693                                            RegionRange region,
1694                                            SmallVectorImpl<mlir::Type> &ret) {
1695   return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1696 }
1697 
1698 LogicalResult ToSliceOffsetOp::verify() {
1699   auto rank = getSlice().getType().getRank();
1700   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1701     return emitError("requested dimension out of bound");
1702   return success();
1703 }
1704 
1705 LogicalResult ToSliceStrideOp::verify() {
1706   auto rank = getSlice().getType().getRank();
1707   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1708     return emitError("requested dimension out of bound");
1709   return success();
1710 }
1711 
1712 LogicalResult GetStorageSpecifierOp::verify() {
1713   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1714                                       getSpecifier(), getOperation());
1715 }
1716 
1717 template <typename SpecifierOp>
1718 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1719   return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1720 }
1721 
1722 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1723   const StorageSpecifierKind kind = getSpecifierKind();
1724   const auto lvl = getLevel();
1725   for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1726     if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1727       return op.getValue();
1728   return {};
1729 }
1730 
1731 LogicalResult SetStorageSpecifierOp::verify() {
1732   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1733                                       getSpecifier(), getOperation());
1734 }
1735 
1736 template <class T>
1737 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1738                                         const char *regionName,
1739                                         TypeRange inputTypes, Type outputType) {
1740   unsigned numArgs = region.getNumArguments();
1741   unsigned expectedNum = inputTypes.size();
1742   if (numArgs != expectedNum)
1743     return op->emitError() << regionName << " region must have exactly "
1744                            << expectedNum << " arguments";
1745 
1746   for (unsigned i = 0; i < numArgs; i++) {
1747     Type typ = region.getArgument(i).getType();
1748     if (typ != inputTypes[i])
1749       return op->emitError() << regionName << " region argument " << (i + 1)
1750                              << " type mismatch";
1751   }
1752   Operation *term = region.front().getTerminator();
1753   YieldOp yield = dyn_cast<YieldOp>(term);
1754   if (!yield)
1755     return op->emitError() << regionName
1756                            << " region must end with sparse_tensor.yield";
1757   if (!yield.hasSingleResult() ||
1758       yield.getSingleResult().getType() != outputType)
1759     return op->emitError() << regionName << " region yield type mismatch";
1760 
1761   return success();
1762 }
1763 
1764 LogicalResult BinaryOp::verify() {
1765   NamedAttrList attrs = (*this)->getAttrs();
1766   Type leftType = getX().getType();
1767   Type rightType = getY().getType();
1768   Type outputType = getOutput().getType();
1769   Region &overlap = getOverlapRegion();
1770   Region &left = getLeftRegion();
1771   Region &right = getRightRegion();
1772 
1773   // Check correct number of block arguments and return type for each
1774   // non-empty region.
1775   if (!overlap.empty()) {
1776     if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1777                                   TypeRange{leftType, rightType}, outputType)))
1778       return failure();
1779   }
1780   if (!left.empty()) {
1781     if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1782                                   outputType)))
1783       return failure();
1784   } else if (getLeftIdentity()) {
1785     if (leftType != outputType)
1786       return emitError("left=identity requires first argument to have the same "
1787                        "type as the output");
1788   }
1789   if (!right.empty()) {
1790     if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1791                                   outputType)))
1792       return failure();
1793   } else if (getRightIdentity()) {
1794     if (rightType != outputType)
1795       return emitError("right=identity requires second argument to have the "
1796                        "same type as the output");
1797   }
1798   return success();
1799 }
1800 
1801 LogicalResult UnaryOp::verify() {
1802   Type inputType = getX().getType();
1803   Type outputType = getOutput().getType();
1804 
1805   // Check correct number of block arguments and return type for each
1806   // non-empty region.
1807   Region &present = getPresentRegion();
1808   if (!present.empty()) {
1809     if (failed(verifyNumBlockArgs(this, present, "present",
1810                                   TypeRange{inputType}, outputType)))
1811       return failure();
1812   }
1813   Region &absent = getAbsentRegion();
1814   if (!absent.empty()) {
1815     if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1816                                   outputType)))
1817       return failure();
1818     // Absent branch can only yield invariant values.
1819     Block *absentBlock = &absent.front();
1820     Block *parent = getOperation()->getBlock();
1821     Value absentVal =
1822         cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1823     if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1824       if (arg.getOwner() == parent)
1825         return emitError("absent region cannot yield linalg argument");
1826     } else if (Operation *def = absentVal.getDefiningOp()) {
1827       if (!isa<arith::ConstantOp>(def) &&
1828           (def->getBlock() == absentBlock || def->getBlock() == parent))
1829         return emitError("absent region cannot yield locally computed value");
1830     }
1831   }
1832   return success();
1833 }
1834 
1835 bool ConcatenateOp::needsExtraSort() {
1836   SparseTensorType dstStt = getSparseTensorType(*this);
1837   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1838     return false;
1839 
1840   bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1841     return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1842   });
1843   // TODO: When conDim != 0, as long as conDim corresponding to the first level
1844   // in all input/output buffers, and all input/output buffers have the same
1845   // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
1846   // CSC matrices along column).
1847   bool directLowerable =
1848       allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1849   return !directLowerable;
1850 }
1851 
1852 LogicalResult ConcatenateOp::verify() {
1853   const auto dstTp = getSparseTensorType(*this);
1854   const Dimension concatDim = getDimension();
1855   const Dimension dimRank = dstTp.getDimRank();
1856 
1857   if (getInputs().size() <= 1)
1858     return emitError("Need at least two tensors to concatenate.");
1859 
1860   if (concatDim >= dimRank)
1861     return emitError(llvm::formatv(
1862         "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1863         concatDim, dimRank));
1864 
1865   for (const auto &it : llvm::enumerate(getInputs())) {
1866     const auto i = it.index();
1867     const auto srcTp = getSparseTensorType(it.value());
1868     if (srcTp.hasDynamicDimShape())
1869       return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1870     const Dimension srcDimRank = srcTp.getDimRank();
1871     if (srcDimRank != dimRank)
1872       return emitError(
1873           llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1874                         "from the output tensor (rank={2}).",
1875                         i, srcDimRank, dimRank));
1876   }
1877 
1878   for (Dimension d = 0; d < dimRank; d++) {
1879     const Size dstSh = dstTp.getDimShape()[d];
1880     if (d == concatDim) {
1881       if (!ShapedType::isDynamic(dstSh)) {
1882         // If we reach here, then all inputs have static shapes.  So we
1883         // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1884         // to avoid redundant assertions in the loop.
1885         Size sumSz = 0;
1886         for (const auto src : getInputs())
1887           sumSz += getSparseTensorType(src).getDimShape()[d];
1888         // If all dimension are statically known, the sum of all the input
1889         // dimensions should be equal to the output dimension.
1890         if (sumSz != dstSh)
1891           return emitError(
1892               "The concatenation dimension of the output tensor should be the "
1893               "sum of all the concatenation dimensions of the input tensors.");
1894       }
1895     } else {
1896       Size prev = dstSh;
1897       for (const auto src : getInputs()) {
1898         const auto sh = getSparseTensorType(src).getDimShape()[d];
1899         if (!ShapedType::isDynamic(prev) && sh != prev)
1900           return emitError("All dimensions (expect for the concatenating one) "
1901                            "should be equal.");
1902         prev = sh;
1903       }
1904     }
1905   }
1906 
1907   return success();
1908 }
1909 
1910 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1911                        Value curSize, Value inBuffer, Value value) {
1912   build(builder, result, curSize, inBuffer, value, Value());
1913 }
1914 
1915 LogicalResult PushBackOp::verify() {
1916   if (Value n = getN()) {
1917     std::optional<int64_t> nValue = getConstantIntValue(n);
1918     if (nValue && nValue.value() < 1)
1919       return emitOpError("n must be not less than 1");
1920   }
1921   return success();
1922 }
1923 
1924 LogicalResult CompressOp::verify() {
1925   const auto stt = getSparseTensorType(getTensor());
1926   if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1927     return emitOpError("incorrect number of coordinates");
1928   return success();
1929 }
1930 
1931 void ForeachOp::build(
1932     OpBuilder &builder, OperationState &result, Value tensor,
1933     ValueRange initArgs, AffineMapAttr order,
1934     function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1935         bodyBuilder) {
1936   build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1937   // Builds foreach body.
1938   if (!bodyBuilder)
1939     return;
1940   const auto stt = getSparseTensorType(tensor);
1941   const Dimension dimRank = stt.getDimRank();
1942 
1943   // Starts with `dimRank`-many coordinates.
1944   SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1945   // Followed by one value.
1946   blockArgTypes.push_back(stt.getElementType());
1947   // Followed by the reduction variables.
1948   blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1949 
1950   SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1951 
1952   OpBuilder::InsertionGuard guard(builder);
1953   auto &region = *result.regions.front();
1954   Block *bodyBlock =
1955       builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1956   bodyBuilder(builder, result.location,
1957               bodyBlock->getArguments().slice(0, dimRank),
1958               bodyBlock->getArguments()[dimRank],
1959               bodyBlock->getArguments().drop_front(dimRank + 1));
1960 }
1961 
1962 LogicalResult ForeachOp::verify() {
1963   const auto t = getSparseTensorType(getTensor());
1964   const Dimension dimRank = t.getDimRank();
1965   const auto args = getBody()->getArguments();
1966 
1967   if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1968     return emitError("Level traverse order does not match tensor's level rank");
1969 
1970   if (dimRank + 1 + getInitArgs().size() != args.size())
1971     return emitError("Unmatched number of arguments in the block");
1972 
1973   if (getNumResults() != getInitArgs().size())
1974     return emitError("Mismatch in number of init arguments and results");
1975 
1976   if (getResultTypes() != getInitArgs().getTypes())
1977     return emitError("Mismatch in types of init arguments and results");
1978 
1979   // Cannot mark this const, because the getters aren't.
1980   auto yield = cast<YieldOp>(getBody()->getTerminator());
1981   if (yield.getNumOperands() != getNumResults() ||
1982       yield.getOperands().getTypes() != getResultTypes())
1983     return emitError("Mismatch in types of yield values and results");
1984 
1985   const auto iTp = IndexType::get(getContext());
1986   for (Dimension d = 0; d < dimRank; d++)
1987     if (args[d].getType() != iTp)
1988       return emitError(
1989           llvm::formatv("Expecting Index type for argument at index {0}", d));
1990 
1991   const auto elemTp = t.getElementType();
1992   const auto valueTp = args[dimRank].getType();
1993   if (elemTp != valueTp)
1994     return emitError(
1995         llvm::formatv("Unmatched element type between input tensor and "
1996                       "block argument, expected:{0}, got: {1}",
1997                       elemTp, valueTp));
1998   return success();
1999 }
2000 
2001 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2002   if (getSparseTensorEncoding(getInputCoo().getType()) ==
2003       getSparseTensorEncoding(getResultCoo().getType()))
2004     return getInputCoo();
2005 
2006   return {};
2007 }
2008 
2009 LogicalResult ReorderCOOOp::verify() {
2010   SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2011   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2012 
2013   if (!srcStt.isCOOType() || !dstStt.isCOOType())
2014     return emitError("Expected COO sparse tensors only");
2015 
2016   if (!srcStt.hasSameDimToLvl(dstStt))
2017     return emitError("Unmatched dim2lvl map between input and result COO");
2018 
2019   if (srcStt.getPosType() != dstStt.getPosType() ||
2020       srcStt.getCrdType() != dstStt.getCrdType() ||
2021       srcStt.getElementType() != dstStt.getElementType())
2022     return emitError("Unmatched storage format between input and result COO");
2023 
2024   return success();
2025 }
2026 
2027 LogicalResult ReduceOp::verify() {
2028   Type inputType = getX().getType();
2029   Region &formula = getRegion();
2030   return verifyNumBlockArgs(this, formula, "reduce",
2031                             TypeRange{inputType, inputType}, inputType);
2032 }
2033 
2034 LogicalResult SelectOp::verify() {
2035   Builder b(getContext());
2036   Type inputType = getX().getType();
2037   Type boolType = b.getI1Type();
2038   Region &formula = getRegion();
2039   return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2040                             boolType);
2041 }
2042 
2043 LogicalResult SortOp::verify() {
2044   AffineMap xPerm = getPermMap();
2045   uint64_t nx = xPerm.getNumDims();
2046   if (nx < 1)
2047     return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
2048 
2049   if (!xPerm.isPermutation())
2050     return emitError(
2051         llvm::formatv("Expected a permutation map, got {0}", xPerm));
2052 
2053   // We can't check the size of the buffers when n or buffer dimensions aren't
2054   // compile-time constants.
2055   std::optional<int64_t> cn = getConstantIntValue(getN());
2056   if (!cn)
2057     return success();
2058 
2059   // Verify dimensions.
2060   const auto checkDim = [&](Value v, Size minSize,
2061                             const char *message) -> LogicalResult {
2062     const Size sh = getMemRefType(v).getShape()[0];
2063     if (!ShapedType::isDynamic(sh) && sh < minSize)
2064       return emitError(
2065           llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2066     return success();
2067   };
2068   uint64_t n = cn.value();
2069   uint64_t ny = 0;
2070   if (auto nyAttr = getNyAttr())
2071     ny = nyAttr.getInt();
2072   if (failed(checkDim(getXy(), n * (nx + ny),
2073                       "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2074     return failure();
2075   for (Value opnd : getYs())
2076     if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
2077       return failure();
2078 
2079   return success();
2080 }
2081 
2082 //===----------------------------------------------------------------------===//
2083 // Sparse Tensor Iteration Operations.
2084 //===----------------------------------------------------------------------===//
2085 
2086 IterSpaceType IteratorType::getIterSpaceType() const {
2087   return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2088                             getHiLvl());
2089 }
2090 
2091 IteratorType IterSpaceType::getIteratorType() const {
2092   return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2093 }
2094 
2095 /// Parses a level range in the form "$lo `to` $hi"
2096 /// or simply "$lo" if $hi - $lo = 1
2097 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2098                                    Level &lvlHi) {
2099   if (parser.parseInteger(lvlLo))
2100     return failure();
2101 
2102   if (succeeded(parser.parseOptionalKeyword("to"))) {
2103     if (parser.parseInteger(lvlHi))
2104       return failure();
2105   } else {
2106     lvlHi = lvlLo + 1;
2107   }
2108 
2109   if (lvlHi <= lvlLo)
2110     return parser.emitError(parser.getNameLoc(),
2111                             "expect larger level upper bound than lower bound");
2112 
2113   return success();
2114 }
2115 
2116 /// Parses a level range in the form "$lo `to` $hi"
2117 /// or simply "$lo" if $hi - $lo = 1
2118 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2119                                    IntegerAttr &lvlHiAttr) {
2120   Level lvlLo, lvlHi;
2121   if (parseLevelRange(parser, lvlLo, lvlHi))
2122     return failure();
2123 
2124   lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2125   lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2126   return success();
2127 }
2128 
2129 /// Prints a level range in the form "$lo `to` $hi"
2130 /// or simply "$lo" if $hi - $lo = 1
2131 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2132 
2133   if (lo + 1 == hi)
2134     p << lo;
2135   else
2136     p << lo << " to " << hi;
2137 }
2138 
2139 /// Prints a level range in the form "$lo `to` $hi"
2140 /// or simply "$lo" if $hi - $lo = 1
2141 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2142                             IntegerAttr lvlHi) {
2143   unsigned lo = lvlLo.getValue().getZExtValue();
2144   unsigned hi = lvlHi.getValue().getZExtValue();
2145   printLevelRange(p, lo, hi);
2146 }
2147 
2148 /// Parses a list of `optional` defined list in the form of
2149 /// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
2150 /// corresponding value is not defined (e.g., to represent an undefined
2151 /// coordinate in the sparse iteration space).
2152 static ParseResult parseOptionalDefinedList(
2153     OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
2154     SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
2155     unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2156     OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
2157   unsigned cnt = 0;
2158   ParseResult crdList =
2159       parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
2160         if (parser.parseOptionalKeyword("_")) {
2161           if (parser.parseArgument(definedArgs.emplace_back()))
2162             return failure();
2163           definedSet.set(cnt);
2164         }
2165         cnt += 1;
2166         return success();
2167       });
2168 
2169   if (cnt > maxCnt)
2170     return parser.emitError(parser.getNameLoc(),
2171                             "parsed more value than expected.");
2172 
2173   if (failed(crdList)) {
2174     return parser.emitError(
2175         parser.getNameLoc(),
2176         "expecting SSA value or \"_\" for level coordinates");
2177   }
2178   assert(definedArgs.size() == definedSet.count());
2179   return success();
2180 }
2181 
2182 static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2183                                      Block::BlockArgListType blocksArgs,
2184                                      I64BitSet definedSet) {
2185   if (definedSet.empty())
2186     return;
2187 
2188   for (unsigned i = 0; i < size; i++) {
2189     if (definedSet[i]) {
2190       p << blocksArgs.front();
2191       blocksArgs = blocksArgs.drop_front();
2192     } else {
2193       p << "_";
2194     }
2195     if (i != size - 1)
2196       p << ", ";
2197   }
2198   assert(blocksArgs.empty());
2199 }
2200 
2201 static ParseResult
2202 parseUsedCoordList(OpAsmParser &parser, OperationState &state,
2203                    SmallVectorImpl<OpAsmParser::Argument> &coords) {
2204   // Parse "at(%crd0, _, ...)"
2205   I64BitSet crdUsedLvlSet;
2206   if (succeeded(parser.parseOptionalKeyword("at")) &&
2207       failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
2208     return failure();
2209 
2210   // Always use IndexType for the coordinate.
2211   for (auto &coord : coords)
2212     coord.type = parser.getBuilder().getIndexType();
2213 
2214   // Set the CrdUsedLvl bitset.
2215   state.addAttribute("crdUsedLvls",
2216                      parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2217   return success();
2218 }
2219 
/// Parses the header of a sparse iteration loop (presumably the custom
/// syntax of `sparse_tensor.iterate`):
///
///   %iters, ... in %spaces, ... [at(%crd, _, ...)]
///       [iter_args(%arg = %init, ...)] : <iter_space types> [-> <results>]
///
/// On success:
///  * `iterators` holds one argument per iteration space, each typed with the
///    iterator type of its iteration-space type;
///  * `blockArgs` holds the iter_args arguments (typed with the result types)
///    followed by the used-coordinate arguments (typed as index by
///    parseUsedCoordList);
///  * `state.operands` holds the resolved iteration spaces, followed by the
///    resolved init values;
///  * `state.types` holds the result types, which are only parsed when
///    iter_args is present;
///  * the "crdUsedLvls" attribute has been added to `state`.
static ParseResult
parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
                       SmallVectorImpl<OpAsmParser::Argument> &iterators,
                       SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;

  // Parse "%iters, ... in %spaces, ..."
  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
      parser.parseOperandList(spaces))
    return failure();

  if (iterators.size() != spaces.size())
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of sparse iterators and sparse spaces");

  // Parse the optional "at(...)" clause; coordinates come back index-typed.
  SmallVector<OpAsmParser::Argument> coords;
  if (failed(parseUsedCoordList(parser, state, coords)))
    return failure();
  size_t numCrds = coords.size();

  // Parse "iter_args(%arg = %init, ...)"
  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
  if (hasIterArgs)
    if (parser.parseAssignmentList(blockArgs, initArgs))
      return failure();

  // The coordinate arguments are placed after the iter_args arguments.
  blockArgs.append(coords);

  SmallVector<Type> iterSpaceTps;
  // Parse ": <iter_space types>" (the "-> <results>" part is parsed below).
  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
    return failure();
  if (iterSpaceTps.size() != spaces.size())
    return parser.emitError(parser.getNameLoc(),
                            "mismatch in number of iteration space operands "
                            "and iteration space types");

  // Each iterator argument takes the iterator type of its iteration space.
  // The sizes are equal by the checks above, so zip_equal cannot assert.
  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
    IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
    if (!spaceTp)
      return parser.emitError(parser.getNameLoc(),
                              "expected sparse_tensor.iter_space type for "
                              "iteration space operands");
    it.type = spaceTp.getIteratorType();
  }

  // Result types are only present when there are loop-carried values.
  if (hasIterArgs)
    if (parser.parseArrowTypeList(state.types))
      return failure();

  // Resolves input operands.
  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
                             state.operands))
    return failure();

  if (hasIterArgs) {
    // Strip off the trailing args that were appended for the used
    // coordinates; the remaining leading args are the iter_args.
    MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
      return parser.emitError(
          parser.getNameLoc(),
          "mismatch in number of iteration arguments and return values");
    }

    // Each iter_args argument and its init value take the matching result
    // type. The init values are resolved after the iteration spaces.
    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
      it.type = tp;
      if (parser.resolveOperand(init, tp, state.operands))
        return failure();
    }
  }
  return success();
}
2294 
2295 static ParseResult
2296 parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
2297                          SmallVectorImpl<Value> &spacesVals,
2298                          SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2299 
2300   // Parse "(%spaces, ...)"
2301   SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2302   if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
2303     return failure();
2304 
2305   SmallVector<OpAsmParser::Argument> coords;
2306   if (failed(parseUsedCoordList(parser, state, coords)))
2307     return failure();
2308   size_t numCrds = coords.size();
2309 
2310   // Parse "iter_args(%arg = %init, ...)"
2311   SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2312   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2313   if (hasIterArgs)
2314     if (parser.parseAssignmentList(blockArgs, initArgs))
2315       return failure();
2316   blockArgs.append(coords);
2317 
2318   SmallVector<Type> iterSpaceTps;
2319   // parse ": (sparse_tensor.iter_space, ...) -> ret"
2320   if (parser.parseColon() || parser.parseLParen() ||
2321       parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
2322     return failure();
2323 
2324   if (iterSpaceTps.size() != spaces.size())
2325     return parser.emitError(parser.getNameLoc(),
2326                             "mismatch in number of iteration space operands "
2327                             "and iteration space types");
2328 
2329   if (hasIterArgs)
2330     if (parser.parseArrowTypeList(state.types))
2331       return failure();
2332 
2333   // Resolves input sparse iteration spaces.
2334   if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2335                              spacesVals))
2336     return failure();
2337   state.operands.append(spacesVals);
2338 
2339   if (hasIterArgs) {
2340     // Strip off trailing args that used for coordinates.
2341     MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2342     if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2343       return parser.emitError(
2344           parser.getNameLoc(),
2345           "mismatch in number of iteration arguments and return values");
2346     }
2347 
2348     for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2349       it.type = tp;
2350       if (parser.resolveOperand(init, tp, state.operands))
2351         return failure();
2352     }
2353   }
2354   return success();
2355 }
2356 
2357 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2358     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2359     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2360     SmallVectorImpl<mlir::Type> &ret) {
2361 
2362   ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2363   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2364   ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2365                                    adaptor.getHiLvl()));
2366   return success();
2367 }
2368 
2369 LogicalResult ExtractIterSpaceOp::verify() {
2370   if (getLoLvl() >= getHiLvl())
2371     return emitOpError("expected smaller level low than level high");
2372 
2373   TypedValue<IteratorType> pIter = getParentIter();
2374   if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2375     return emitOpError(
2376         "parent iterator should be specified iff level lower bound equals 0");
2377   }
2378 
2379   if (pIter) {
2380     IterSpaceType spaceTp = getExtractedSpace().getType();
2381     if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2382       return emitOpError(
2383           "mismatch in parent iterator encoding and iteration space encoding.");
2384 
2385     if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2386       return emitOpError("parent iterator should be used to extract an "
2387                          "iteration space from a consecutive level.");
2388   }
2389 
2390   return success();
2391 }
2392 
2393 LogicalResult ExtractValOp::verify() {
2394   auto stt = getSparseTensorType(getTensor());
2395   auto itTp = getIterator().getType();
2396 
2397   if (stt.getEncoding() != itTp.getEncoding())
2398     return emitOpError("mismatch in tensor encoding and iterator encoding.");
2399 
2400   if (stt.getLvlRank() != itTp.getHiLvl())
2401     return emitOpError("must use last-level iterator to extract values. ");
2402 
2403   return success();
2404 }
2405 
2406 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2407   using OpRewritePattern::OpRewritePattern;
2408 
2409   LogicalResult matchAndRewrite(IterateOp iterateOp,
2410                                 PatternRewriter &rewriter) const override {
2411     I64BitSet newUsedLvls(0);
2412     llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2413     for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2414       if (auto crd = iterateOp.getLvlCrd(i)) {
2415         if (crd->getUsers().empty())
2416           toRemove.set(crd->getArgNumber());
2417         else
2418           newUsedLvls.set(i);
2419       }
2420     }
2421 
2422     // All coordinates are used.
2423     if (toRemove.none())
2424       return failure();
2425 
2426     rewriter.startOpModification(iterateOp);
2427     iterateOp.setCrdUsedLvls(newUsedLvls);
2428     iterateOp.getBody()->eraseArguments(toRemove);
2429     rewriter.finalizeOpModification(iterateOp);
2430     return success();
2431   }
2432 };
2433 
2434 void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2435                                             mlir::MLIRContext *context) {
2436   results.add<RemoveUnusedLvlCrds>(context);
2437 }
2438 
2439 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2440                       Value iterSpace, ValueRange initArgs) {
2441   unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2442   // All ones.
2443   I64BitSet set((1 << rank) - 1);
2444   return build(builder, odsState, iterSpace, initArgs, set);
2445 }
2446 
2447 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2448                       Value iterSpace, ValueRange initArgs,
2449                       I64BitSet crdUsedLvls) {
2450   OpBuilder::InsertionGuard guard(builder);
2451 
2452   odsState.addOperands(iterSpace);
2453   odsState.addOperands(initArgs);
2454   odsState.getOrAddProperties<Properties>().crdUsedLvls =
2455       builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
2456   Region *bodyRegion = odsState.addRegion();
2457   odsState.addTypes(initArgs.getTypes());
2458   Block *bodyBlock = builder.createBlock(bodyRegion);
2459 
2460   // Starts with a list of user-provided loop arguments.
2461   for (Value v : initArgs)
2462     bodyBlock->addArgument(v.getType(), v.getLoc());
2463 
2464   // Follows by a list of used coordinates.
2465   for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
2466     bodyBlock->addArgument(builder.getIndexType(), odsState.location);
2467 
2468   // Ends with sparse iterator
2469   bodyBlock->addArgument(
2470       llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
2471       odsState.location);
2472 }
2473 
2474 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2475   OpAsmParser::Argument iterator;
2476   OpAsmParser::UnresolvedOperand iterSpace;
2477 
2478   SmallVector<OpAsmParser::Argument> iters, iterArgs;
2479   if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2480     return failure();
2481   if (iters.size() != 1)
2482     return parser.emitError(parser.getNameLoc(),
2483                             "expected only one iterator/iteration space");
2484 
2485   iterArgs.append(iters);
2486   Region *body = result.addRegion();
2487   if (parser.parseRegion(*body, iterArgs))
2488     return failure();
2489 
2490   IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2491 
2492   // Parse the optional attribute list.
2493   if (parser.parseOptionalAttrDict(result.attributes))
2494     return failure();
2495 
2496   return success();
2497 }
2498 
2499 /// Prints the initialization list in the form of
2500 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2501 /// where 'inner' values are assumed to be region arguments and 'outer' values
2502 /// are regular SSA values.
2503 static void printInitializationList(OpAsmPrinter &p,
2504                                     Block::BlockArgListType blocksArgs,
2505                                     ValueRange initializers,
2506                                     StringRef prefix = "") {
2507   assert(blocksArgs.size() == initializers.size() &&
2508          "expected same length of arguments and initializers");
2509   if (initializers.empty())
2510     return;
2511 
2512   p << prefix << '(';
2513   llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2514     p << std::get<0>(it) << " = " << std::get<1>(it);
2515   });
2516   p << ")";
2517 }
2518 
2519 template <typename SparseLoopOp>
2520 static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2521   if (op.getInitArgs().size() != op.getNumResults()) {
2522     return op.emitOpError(
2523         "mismatch in number of loop-carried values and defined values");
2524   }
2525   if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2526     return op.emitOpError("required out-of-bound coordinates");
2527 
2528   return success();
2529 }
2530 
2531 LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
2532 LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2533 
2534 void IterateOp::print(OpAsmPrinter &p) {
2535   p << " " << getIterator() << " in " << getIterSpace();
2536   if (!getCrdUsedLvls().empty()) {
2537     p << " at(";
2538     printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
2539     p << ")";
2540   }
2541   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
2542 
2543   p << " : " << getIterSpace().getType() << " ";
2544   if (!getInitArgs().empty())
2545     p.printArrowTypeList(getInitArgs().getTypes());
2546 
2547   p << " ";
2548   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2549                 /*printBlockTerminators=*/!getInitArgs().empty());
2550 }
2551 
2552 LogicalResult IterateOp::verifyRegions() {
2553   if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2554     return emitOpError("mismatch in iterator and iteration space type");
2555   if (getNumRegionIterArgs() != getNumResults())
2556     return emitOpError(
2557         "mismatch in number of basic block args and defined values");
2558 
2559   auto initArgs = getInitArgs();
2560   auto iterArgs = getRegionIterArgs();
2561   auto yieldVals = getYieldedValues();
2562   auto opResults = getResults();
2563   if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2564                         opResults.size()})) {
2565     return emitOpError() << "number mismatch between iter args and results.";
2566   }
2567 
2568   for (auto [i, init, iter, yield, ret] :
2569        llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2570     if (init.getType() != ret.getType())
2571       return emitOpError() << "types mismatch between " << i
2572                            << "th iter operand and defined value";
2573     if (iter.getType() != ret.getType())
2574       return emitOpError() << "types mismatch between " << i
2575                            << "th iter region arg and defined value";
2576     if (yield.getType() != ret.getType())
2577       return emitOpError() << "types mismatch between " << i
2578                            << "th yield value and defined value";
2579   }
2580 
2581   return success();
2582 }
2583 
2584 /// OpInterfaces' methods implemented by IterateOp.
2585 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2586 
2587 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2588   return getInitArgsMutable();
2589 }
2590 
2591 Block::BlockArgListType IterateOp::getRegionIterArgs() {
2592   return getRegion().getArguments().take_front(getNumRegionIterArgs());
2593 }
2594 
2595 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2596   return cast<sparse_tensor::YieldOp>(
2597              getRegion().getBlocks().front().getTerminator())
2598       .getResultsMutable();
2599 }
2600 
2601 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2602 
2603 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2604   return getInitArgs();
2605 }
2606 
2607 void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2608                                     SmallVectorImpl<RegionSuccessor> &regions) {
2609   // Both the operation itself and the region may be branching into the body
2610   // or back into the operation itself.
2611   regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2612   // It is possible for loop not to enter the body.
2613   regions.push_back(RegionSuccessor(getResults()));
2614 }
2615 
2616 void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2617                         ValueRange iterSpaces, ValueRange initArgs,
2618                         unsigned numCases) {
2619   unsigned rank =
2620       cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2621   // All ones.
2622   I64BitSet set((1 << rank) - 1);
2623   // Generates all-zero case bits (they only serve as placeholders), which are
2624   // supposed to be overriden later. We need to preallocate all the regions as
2625   // mlir::Region cannot be dynamically added later after the operation is
2626   // created.
2627   SmallVector<int64_t> caseBits(numCases, 0);
2628   ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2629   return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2630                             initArgs, set, cases,
2631                             /*caseRegionsCount=*/numCases);
2632 }
2633 
/// Parses a `sparse_tensor.coiterate` op: the shared loop header (via
/// parseSparseCoIterateLoop), then one region per `case` clause, then an
/// optional attribute dictionary.
ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {

  SmallVector<Value> spaces;
  // The entry-block argument prefix shared by every case region, arranged as
  // ([loop iteration args], [used coordinate list]). Each case appends its own
  // [sparse iterator list] after this prefix (see the insert below).
  SmallVector<OpAsmParser::Argument> blockArgs;
  if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
    return failure();

  // Two operand segments: the iteration spaces, then the init values. There is
  // one init value per result type.
  result.addAttribute("operandSegmentSizes",
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(spaces.size()),
                           static_cast<int32_t>(result.types.size())}));

  SmallVector<Attribute> cases;
  while (succeeded(parser.parseOptionalKeyword("case"))) {
    // Parse one region per case.
    I64BitSet definedItSet;
    SmallVector<OpAsmParser::Argument> definedIts;
    if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
                                 spaces.size(), OpAsmParser::Delimiter::None))
      return failure();

    // The case is keyed by the bit set of iteration spaces it defines.
    cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));

    for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
      // Resolve the iterator type based on the iteration space type.
      // NOTE: llvm::cast asserts (in debug builds) if the space operand's type
      // is not an IterSpaceType.
      auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
      definedIts[i].type = spaceTp.getIteratorType();
    }
    // Prepend the shared (iter args, coordinates) prefix to this case's
    // iterators.
    definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, definedIts))
      return failure();

    CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
  }

  result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
2680 
2681 void CoIterateOp::print(OpAsmPrinter &p) {
2682   p << " (";
2683   llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2684   p << ")";
2685 
2686   if (!getCrdUsedLvls().empty()) {
2687     p << " at(";
2688     printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
2689     p << ")";
2690   }
2691 
2692   printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
2693 
2694   p << " : (" << getIterSpaces().getTypes() << ")";
2695   if (!getInitArgs().empty())
2696     p.printArrowTypeList(getInitArgs().getTypes());
2697 
2698   for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2699     p.printNewline();
2700     p << "case ";
2701     printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
2702                              getRegionDefinedSpace(idx));
2703     p << " ";
2704     p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
2705                   /*printBlockTerminators=*/!getInitArgs().empty());
2706   }
2707 }
2708 
2709 ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2710   return cast<sparse_tensor::YieldOp>(
2711              getRegion(regionIdx).getBlocks().front().getTerminator())
2712       .getResults();
2713 }
2714 
2715 LogicalResult CoIterateOp::verifyRegions() {
2716   for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2717     if (getNumRegionIterArgs() != getNumResults())
2718       return emitOpError(
2719           "mismatch in number of basic block args and defined values");
2720 
2721     auto initArgs = getInitArgs();
2722     auto iterArgs = getRegionIterArgs(r);
2723     auto yieldVals = getYieldedValues(r);
2724     auto opResults = getResults();
2725     if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2726                           opResults.size()})) {
2727       return emitOpError()
2728              << "number mismatch between iter args and results on " << r
2729              << "th region";
2730     }
2731 
2732     for (auto [i, init, iter, yield, ret] :
2733          llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2734       if (init.getType() != ret.getType())
2735         return emitOpError()
2736                << "types mismatch between " << i
2737                << "th iter operand and defined value on " << r << "th region";
2738       if (iter.getType() != ret.getType())
2739         return emitOpError() << "types mismatch between " << i
2740                              << "th iter region arg and defined value on " << r
2741                              << "th region";
2742       if (yield.getType() != ret.getType())
2743         return emitOpError()
2744                << "types mismatch between " << i
2745                << "th yield value and defined value on " << r << "th region";
2746     }
2747   }
2748 
2749   auto cases = getRegionDefinedSpaces();
2750   llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2751   if (set.size() != getNumRegions())
2752     return emitOpError("contains duplicated cases.");
2753 
2754   return success();
2755 }
2756 
2757 SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2758   SmallVector<Region *> ret;
2759   I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2760   for (Region &r : getCaseRegions())
2761     if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2762       ret.push_back(&r);
2763 
2764   return ret;
2765 }
2766 
2767 //===----------------------------------------------------------------------===//
2768 // Sparse Tensor Dialect Setups.
2769 //===----------------------------------------------------------------------===//
2770 
2771 /// Materialize a single constant operation from a given attribute value with
2772 /// the desired resultant type.
2773 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2774                                                     Attribute value, Type type,
2775                                                     Location loc) {
2776   if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2777     return op;
2778   return nullptr;
2779 }
2780 
2781 namespace {
2782 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
2783   using OpAsmDialectInterface::OpAsmDialectInterface;
2784 
2785   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
2786     if (isa<SparseTensorEncodingAttr>(attr)) {
2787       os << "sparse";
2788       return AliasResult::OverridableAlias;
2789     }
2790     return AliasResult::NoAlias;
2791   }
2792 };
2793 } // namespace
2794 
/// Registers the sparse tensor dialect's assembly interface, attributes, types
/// and operations. Attributes, types and operations come from the
/// TableGen-generated lists.
void SparseTensorDialect::initialize() {
  addInterface<SparseTensorAsmDialectInterface>();
  // Attribute definitions from SparseTensorAttrDefs.td.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  // Type definitions from SparseTensorTypes.td.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  // Operation definitions from SparseTensorOps.td.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  // Declares that the listed ops promise a BufferizableOpInterface
  // implementation. The implementation is not attached here; presumably it is
  // registered as an external model elsewhere (TODO confirm).
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}
2814 
2815 #define GET_OP_CLASSES
2816 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2817 
2818 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
2819