xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (revision a02010b3e97b5f01d4ff921b353f4a25a29c45cd)
1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "Detail/DimLvlMapParser.h"
12 
13 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
17 
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20 #include "mlir/Dialect/Complex/IR/Complex.h"
21 #include "mlir/Dialect/Utils/StaticValueUtils.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/OpImplementation.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "llvm/ADT/Bitset.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/FormatVariadic.h"
30 
31 #define GET_ATTRDEF_CLASSES
32 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
33 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
34 
35 // Forward declarations; the following custom print/parsing methods are
36 // referenced by the generated code for SparseTensorTypes.td.
37 static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
38                                          mlir::sparse_tensor::Level &,
39                                          mlir::sparse_tensor::Level &);
40 static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
41                             mlir::sparse_tensor::Level);
42 
43 #define GET_TYPEDEF_CLASSES
44 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
45 
46 using namespace mlir;
47 using namespace mlir::sparse_tensor;
48 
49 // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
50 // well.
51 namespace mlir::sparse_tensor {
52 llvm::hash_code hash_value(LevelType lt) {
53   return llvm::hash_value(static_cast<uint64_t>(lt));
54 }
55 } // namespace mlir::sparse_tensor
56 
57 //===----------------------------------------------------------------------===//
58 // Local Convenience Methods.
59 //===----------------------------------------------------------------------===//
60 
61 static constexpr bool acceptBitWidth(unsigned bitWidth) {
62   switch (bitWidth) {
63   case 0:
64   case 8:
65   case 16:
66   case 32:
67   case 64:
68     return true;
69   default:
70     return false;
71   }
72 }
73 
74 static SmallVector<Size>
75 getSparseFieldShape(const SparseTensorEncodingAttr enc,
76                     std::optional<ArrayRef<int64_t>> dimShape) {
77   assert(enc);
78   // With only the encoding, we cannot determine the static shape for leading
79   // batch levels; we therefore return a dynamic-shape memref instead.
80   SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
81   if (dimShape.has_value()) {
82     // If the actual tensor shape is provided, we can then refine the leading
83     // batch dimension.
84     SmallVector<int64_t> lvlShape =
85         enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
86     memrefShape.assign(lvlShape.begin(),
87                        lvlShape.begin() + enc.getBatchLvlRank());
88   }
89   // Another dynamic dimension to store the sparse level.
90   memrefShape.push_back(ShapedType::kDynamic);
91   return memrefShape;
92 }
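
// For illustration (assuming an identity dimToLvl): with one leading batch
// level and a known dimension shape of [8, ?], the batch portion refines to
// the static size 8 and one trailing dynamic dimension is appended for the
// sparse storage, so the positions, coordinates, and values fields are all
// typed as memref<8x?xT> for their respective element types T.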
93 
94 //===----------------------------------------------------------------------===//
95 // SparseTensorDialect StorageLayout.
96 //===----------------------------------------------------------------------===//
97 
98 static constexpr Level kInvalidLevel = -1u;
99 static constexpr Level kInvalidFieldIndex = -1u;
100 static constexpr FieldIndex kDataFieldStartingIdx = 0;
101 
102 void StorageLayout::foreachField(
103     llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
104                             LevelType)>
105         callback) const {
106   const auto lvlTypes = enc.getLvlTypes();
107   const Level lvlRank = enc.getLvlRank();
108   SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
109   FieldIndex fieldIdx = kDataFieldStartingIdx;
110 
111   ArrayRef cooSegsRef = cooSegs;
112   // Per-level storage.
113   for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
114     const auto lt = lvlTypes[l];
115     if (isWithPosLT(lt)) {
116       if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
117         return;
118     }
119     if (isWithCrdLT(lt)) {
120       if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
121         return;
122     }
123     if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
124       if (!cooSegsRef.front().isSoA) {
125         // AoS COO: all singletons are fused into one memref. Skip the entire
126         // COO segment.
127         l = cooSegsRef.front().lvlRange.second;
128       } else {
129         // SoA COO, each singleton level has one memref.
130         l++;
131       }
132       // Expire handled COO segment.
133       cooSegsRef = cooSegsRef.drop_front();
134     } else {
135       // Non COO levels.
136       l++;
137     }
138   }
139   // The values array.
140   if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
141                  LevelFormat::Undef)))
142     return;
143   // Put metadata at the end.
144   if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
145                  LevelFormat::Undef)))
146     return;
147 }
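
// As a concrete example: for a CSR encoding (dense, compressed), the callback
// above fires for field 0 (PosMemRef of level 1), field 1 (CrdMemRef of level
// 1), field 2 (ValMemRef), and field 3 (StorageSpec), i.e. one positions and
// one coordinates array for the compressed level, then values and specifier.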
148 
149 void sparse_tensor::foreachFieldAndTypeInSparseTensor(
150     SparseTensorType stt,
151     llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
152                             LevelType)>
153         callback) {
154   assert(stt.hasEncoding());
155 
156   SmallVector<int64_t> memrefShape =
157       getSparseFieldShape(stt.getEncoding(), stt.getDimShape());
158 
159   const Type specType = StorageSpecifierType::get(stt.getEncoding());
160   // memref<[batch] x ? x pos>  positions
161   const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
162   // memref<[batch] x ? x crd>  coordinates
163   const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
164   // memref<[batch] x ? x eltType> values
165   const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
166 
167   StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
168                                    callback](FieldIndex fieldIdx,
169                                              SparseTensorFieldKind fieldKind,
170                                              Level lvl, LevelType lt) -> bool {
171     switch (fieldKind) {
172     case SparseTensorFieldKind::StorageSpec:
173       return callback(specType, fieldIdx, fieldKind, lvl, lt);
174     case SparseTensorFieldKind::PosMemRef:
175       return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
176     case SparseTensorFieldKind::CrdMemRef:
177       return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
178     case SparseTensorFieldKind::ValMemRef:
179       return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
180     };
181     llvm_unreachable("unrecognized field kind");
182   });
183 }
184 
185 unsigned StorageLayout::getNumFields() const {
186   unsigned numFields = 0;
187   foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
188                             LevelType) -> bool {
189     numFields++;
190     return true;
191   });
192   return numFields;
193 }
194 
195 unsigned StorageLayout::getNumDataFields() const {
196   unsigned numFields = 0; // one value memref
197   foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
198                             LevelType) -> bool {
199     if (fidx >= kDataFieldStartingIdx)
200       numFields++;
201     return true;
202   });
203   numFields -= 1; // the last field is StorageSpecifier
204   assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
205   return numFields;
206 }
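
// For the CSR example above, getNumFields() == 4 while getNumDataFields() == 3
// (positions, coordinates, and values, excluding the storage specifier).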
207 
208 std::pair<FieldIndex, unsigned>
209 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
210                                       std::optional<Level> lvl) const {
211   FieldIndex fieldIdx = kInvalidFieldIndex;
212   unsigned stride = 1;
213   if (kind == SparseTensorFieldKind::CrdMemRef) {
214     assert(lvl.has_value());
215     const Level cooStart = enc.getAoSCOOStart();
216     const Level lvlRank = enc.getLvlRank();
217     if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
218       lvl = cooStart;
219       stride = lvlRank - cooStart;
220     }
221   }
222   foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
223                                       SparseTensorFieldKind fKind, Level fLvl,
224                                       LevelType lt) -> bool {
225     if ((lvl && fLvl == lvl.value() && kind == fKind) ||
226         (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
227       fieldIdx = fIdx;
228       // Returns false to break the iteration.
229       return false;
230     }
231     return true;
232   });
233   assert(fieldIdx != kInvalidFieldIndex);
234   return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
235 }
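
// For example: with an AoS COO encoding (compressed(nonunique), singleton),
// both levels share a single fused coordinates memref, so a query for the
// CrdMemRef of level 1 is redirected to the crd field of level 0 (field index
// 1, after the positions field) and reports a stride of 2.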
236 
237 //===----------------------------------------------------------------------===//
238 // SparseTensorDialect Attribute Methods.
239 //===----------------------------------------------------------------------===//
240 
241 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
242   return isDynamic(v) ? std::nullopt
243                       : std::make_optional(static_cast<uint64_t>(v));
244 }
245 
246 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
247   return getStatic(getOffset());
248 }
249 
250 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
251   return getStatic(getStride());
252 }
253 
254 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
255   return getStatic(getSize());
256 }
257 
258 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
259   return isDynamic(getOffset()) && isDynamic(getStride()) &&
260          isDynamic(getSize());
261 }
262 
263 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
264   return isDynamic(v) ? "?" : std::to_string(v);
265 }
266 
267 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
268   assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
269   os << '(';
270   os << getStaticString(getOffset());
271   os << ", ";
272   os << getStaticString(getSize());
273   os << ", ";
274   os << getStaticString(getStride());
275   os << ')';
276 }
277 
278 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
279   print(printer.getStream());
280 }
281 
282 static ParseResult parseOptionalStaticSlice(int64_t &result,
283                                             AsmParser &parser) {
284   auto parseResult = parser.parseOptionalInteger(result);
285   if (parseResult.has_value()) {
286     if (parseResult.value().succeeded() && result < 0) {
287       parser.emitError(
288           parser.getCurrentLocation(),
289           "expect positive value or ? for slice offset/size/stride");
290       return failure();
291     }
292     return parseResult.value();
293   }
294 
295   // Else, a '?' represents a dynamic slice.
296   result = SparseTensorDimSliceAttr::kDynamic;
297   return parser.parseQuestion();
298 }
299 
300 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
301   int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
302 
303   if (failed(parser.parseLParen()) ||
304       failed(parseOptionalStaticSlice(offset, parser)) ||
305       failed(parser.parseComma()) ||
306       failed(parseOptionalStaticSlice(size, parser)) ||
307       failed(parser.parseComma()) ||
308       failed(parseOptionalStaticSlice(stride, parser)) ||
309       failed(parser.parseRParen()))
310     return {};
311 
312   return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
313                                                      offset, size, stride);
314 }
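
// Examples of the textual form handled above: "(1, 4, 2)" denotes a slice
// with offset 1, size 4, and stride 2, while "(?, ?, ?)" leaves all three
// dynamic.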
315 
316 LogicalResult
317 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
318                                  int64_t offset, int64_t size, int64_t stride) {
319   if (!isDynamic(offset) && offset < 0)
320     return emitError() << "expect non-negative value or ? for slice offset";
321   if (!isDynamic(size) && size <= 0)
322     return emitError() << "expect positive value or ? for slice size";
323   if (!isDynamic(stride) && stride <= 0)
324     return emitError() << "expect positive value or ? for slice stride";
325   return success();
326 }
327 
328 SparseTensorEncodingAttr
329 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
330   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
331   return SparseTensorEncodingAttr::get(
332       getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
333       getCrdWidth(), getExplicitVal(), getImplicitVal());
334 }
335 
336 SparseTensorEncodingAttr
337 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
338   return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
339 }
340 
341 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
342   return withDimToLvl(AffineMap());
343 }
344 
345 SparseTensorEncodingAttr
346 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
347                                         unsigned crdWidth) const {
348   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
349   return SparseTensorEncodingAttr::get(
350       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
351       crdWidth, getExplicitVal(), getImplicitVal());
352 }
353 
354 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
355   return withBitWidths(0, 0);
356 }
357 
358 SparseTensorEncodingAttr
359 SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
360   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
361   return SparseTensorEncodingAttr::get(
362       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
363       getCrdWidth(), explicitVal, getImplicitVal());
364 }
365 
366 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
367   return withExplicitVal(Attribute());
368 }
369 
370 SparseTensorEncodingAttr
371 SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
372   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
373   return SparseTensorEncodingAttr::get(
374       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
375       getCrdWidth(), getExplicitVal(), implicitVal);
376 }
377 
378 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
379   return withImplicitVal(Attribute());
380 }
381 
382 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
383     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
384   return SparseTensorEncodingAttr::get(
385       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
386       getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
387 }
388 
389 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
390   return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
391 }
392 
393 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
394   ArrayRef<LevelType> lvlTypes = getLvlTypes();
395   auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
396   return std::distance(lastBatch, lvlTypes.rend());
397 }
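
// For example: for level types (batch, batch, dense, compressed) this returns
// 2, the number of leading batch levels.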
398 
399 bool SparseTensorEncodingAttr::isAllDense() const {
400   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
401 }
402 
403 bool SparseTensorEncodingAttr::isAllOrdered() const {
404   return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
405 }
406 
407 Type SparseTensorEncodingAttr::getCrdElemType() const {
408   if (!getImpl())
409     return nullptr;
410   if (getCrdWidth())
411     return IntegerType::get(getContext(), getCrdWidth());
412   return IndexType::get(getContext());
413 }
414 
415 Type SparseTensorEncodingAttr::getPosElemType() const {
416   if (!getImpl())
417     return nullptr;
418   if (getPosWidth())
419     return IntegerType::get(getContext(), getPosWidth());
420   return IndexType::get(getContext());
421 }
422 
423 MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
424     std::optional<ArrayRef<int64_t>> dimShape) const {
425   SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
426   return MemRefType::get(shape, getCrdElemType());
427 }
428 
429 MemRefType SparseTensorEncodingAttr::getPosMemRefType(
430     std::optional<ArrayRef<int64_t>> dimShape) const {
431   SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
432   return MemRefType::get(shape, getPosElemType());
433 }
434 
435 bool SparseTensorEncodingAttr::isIdentity() const {
436   return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
437 }
438 
439 bool SparseTensorEncodingAttr::isPermutation() const {
440   return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
441 }
442 
443 Dimension SparseTensorEncodingAttr::getDimRank() const {
444   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
445   const auto dimToLvl = getDimToLvl();
446   return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
447 }
448 
449 Level SparseTensorEncodingAttr::getLvlRank() const {
450   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
451   return getLvlTypes().size();
452 }
453 
454 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
455   if (!getImpl())
456     return LevelFormat::Batch;
457   assert(l < getLvlRank() && "Level is out of bounds");
458   return getLvlTypes()[l];
459 }
460 
461 bool SparseTensorEncodingAttr::isSlice() const {
462   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
463   return !getDimSlices().empty();
464 }
465 
466 SparseTensorDimSliceAttr
467 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
468   assert(isSlice() && "Is not a slice");
469   const auto dimSlices = getDimSlices();
470   assert(dim < dimSlices.size() && "Dimension is out of bounds");
471   return dimSlices[dim];
472 }
473 
474 std::optional<uint64_t>
475 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
476   return getDimSlice(dim).getStaticOffset();
477 }
478 
479 std::optional<uint64_t>
480 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
481   return getDimSlice(dim).getStaticStride();
482 }
483 
484 std::optional<uint64_t>
485 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
486   return getStaticDimSliceOffset(toDim(*this, lvl));
487 }
488 
489 std::optional<uint64_t>
490 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
491   return getStaticDimSliceStride(toDim(*this, lvl));
492 }
493 
494 SmallVector<int64_t>
495 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
496                                          CrdTransDirectionKind dir) const {
497   if (isIdentity())
498     return SmallVector<int64_t>(srcShape);
499 
500   SmallVector<int64_t> ret;
501   unsigned rank =
502       dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
503   ret.reserve(rank);
504 
505   if (isPermutation()) {
506     for (unsigned r = 0; r < rank; r++) {
507       unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
508                                                              : toLvl(*this, r);
509       ret.push_back(srcShape[trans]);
510     }
511     return ret;
512   }
513 
514   // Handle non-permutation maps.
515   AffineMap transMap =
516       dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
517 
518   SmallVector<AffineExpr> dimRep;
519   dimRep.reserve(srcShape.size());
520   for (int64_t sz : srcShape) {
521     if (!ShapedType::isDynamic(sz)) {
522       // Push back the max coordinate for the given dimension/level size.
523       dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
524     } else {
525       // A dynamic size; use an AffineDimExpr to symbolize the value.
526       dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
527     }
528   };
529 
530   for (AffineExpr exp : transMap.getResults()) {
531     // Do constant propagation on the affine map.
532     AffineExpr evalExp =
533         simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
534     // Use the llvm namespace here to avoid ambiguity.
535     if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
536       ret.push_back(c.getValue() + 1);
537     } else {
538       if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
539           mod && mod.getKind() == AffineExprKind::Mod) {
540         // We can still infer a static bound for expressions of the form
541         // "d % constant", since d % constant \in [0, constant).
542         if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
543           ret.push_back(bound.getValue());
544           continue;
545         }
546       }
547       ret.push_back(ShapedType::kDynamic);
548     }
549   }
550   assert(ret.size() == rank);
551   return ret;
552 }
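
// Worked example: under the 2x3 block map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// the dimension shape [6, 9] translates (dim2lvl) to the level shape
// [3, 3, 2, 3]. If d0 were dynamic, "d0 mod 2" would still receive the static
// bound 2, whereas "d0 floordiv 2" would become dynamic.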
553 
554 ValueRange
555 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
556                                         ValueRange crds,
557                                         CrdTransDirectionKind dir) const {
558   if (!getImpl())
559     return crds;
560 
561   SmallVector<Type> retType(
562       dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
563       builder.getIndexType());
564   auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
565   return transOp.getOutCrds();
566 }
567 
568 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
569   // Open "<{" part.
570   if (failed(parser.parseLess()))
571     return {};
572   if (failed(parser.parseLBrace()))
573     return {};
574 
575   // Process the data from the parsed dictionary value into struct-like data.
576   SmallVector<LevelType> lvlTypes;
577   SmallVector<SparseTensorDimSliceAttr> dimSlices;
578   AffineMap dimToLvl = {};
579   AffineMap lvlToDim = {};
580   unsigned posWidth = 0;
581   unsigned crdWidth = 0;
582   Attribute explicitVal;
583   Attribute implicitVal;
584   StringRef attrName;
585   SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
586                                     "explicitVal", "implicitVal"};
587   while (succeeded(parser.parseOptionalKeyword(&attrName))) {
588     // Detect admissible keyword.
589     auto *it = find(keys, attrName);
590     if (it == keys.end()) {
591       parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
592       return {};
593     }
594     unsigned keyWordIndex = it - keys.begin();
595     // Consume the `=` after the key.
596     if (failed(parser.parseEqual()))
597       return {};
598     // Dispatch on keyword.
599     switch (keyWordIndex) {
600     case 0: { // map
601       ir_detail::DimLvlMapParser cParser(parser);
602       auto res = cParser.parseDimLvlMap();
603       if (failed(res))
604         return {};
605       const auto &dlm = *res;
606 
607       const Level lvlRank = dlm.getLvlRank();
608       for (Level lvl = 0; lvl < lvlRank; lvl++)
609         lvlTypes.push_back(dlm.getLvlType(lvl));
610 
611       const Dimension dimRank = dlm.getDimRank();
612       for (Dimension dim = 0; dim < dimRank; dim++)
613         dimSlices.push_back(dlm.getDimSlice(dim));
614       // NOTE: the old syntax requires an all-or-nothing approach to
615       // `dimSlices`; therefore, if any slice actually exists then we need
616       // to convert null-DSA into default/nop DSA.
617       const auto isDefined = [](SparseTensorDimSliceAttr slice) {
618         return static_cast<bool>(slice.getImpl());
619       };
620       if (llvm::any_of(dimSlices, isDefined)) {
621         const auto defaultSlice =
622             SparseTensorDimSliceAttr::get(parser.getContext());
623         for (Dimension dim = 0; dim < dimRank; dim++)
624           if (!isDefined(dimSlices[dim]))
625             dimSlices[dim] = defaultSlice;
626       } else {
627         dimSlices.clear();
628       }
629 
630       dimToLvl = dlm.getDimToLvlMap(parser.getContext());
631       lvlToDim = dlm.getLvlToDimMap(parser.getContext());
632       break;
633     }
634     case 1: { // posWidth
635       Attribute attr;
636       if (failed(parser.parseAttribute(attr)))
637         return {};
638       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
639       if (!intAttr) {
640         parser.emitError(parser.getNameLoc(),
641                          "expected an integral position bitwidth");
642         return {};
643       }
644       posWidth = intAttr.getInt();
645       break;
646     }
647     case 2: { // crdWidth
648       Attribute attr;
649       if (failed(parser.parseAttribute(attr)))
650         return {};
651       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
652       if (!intAttr) {
653         parser.emitError(parser.getNameLoc(),
654                          "expected an integral index bitwidth");
655         return {};
656       }
657       crdWidth = intAttr.getInt();
658       break;
659     }
660     case 3: { // explicitVal
661       Attribute attr;
662       if (failed(parser.parseAttribute(attr)))
663         return {};
664       if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
665         explicitVal = result;
666       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
667         explicitVal = result;
668       } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
669         explicitVal = result;
670       } else {
671         parser.emitError(parser.getNameLoc(),
672                          "expected a numeric value for explicitVal");
673         return {};
674       }
675       break;
676     }
677     case 4: { // implicitVal
678       Attribute attr;
679       if (failed(parser.parseAttribute(attr)))
680         return {};
681       if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
682         implicitVal = result;
683       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
684         implicitVal = result;
685       } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
686         implicitVal = result;
687       } else {
688         parser.emitError(parser.getNameLoc(),
689                          "expected a numeric value for implicitVal");
690         return {};
691       }
692       break;
693     }
694     } // switch
695     // Only the last item may omit the comma.
696     if (parser.parseOptionalComma().failed())
697       break;
698   }
699 
700   // Close "}>" part.
701   if (failed(parser.parseRBrace()))
702     return {};
703   if (failed(parser.parseGreater()))
704     return {};
705 
706   // Construct struct-like storage for attribute.
707   if (!lvlToDim || lvlToDim.isEmpty()) {
708     lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
709   }
710   return parser.getChecked<SparseTensorEncodingAttr>(
711       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
712       explicitVal, implicitVal, dimSlices);
713 }
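
// An illustrative example of the syntax accepted above (a CSR encoding with
// explicit position/coordinate bitwidths):
//
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32
//   }>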
714 
715 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
716   auto map = static_cast<AffineMap>(getDimToLvl());
717   // Empty affine map indicates identity map
718   if (!map)
719     map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
720   printer << "<{ map = ";
721   printSymbols(map, printer);
722   printer << '(';
723   printDimensions(map, printer, getDimSlices());
724   printer << ") -> (";
725   printLevels(map, printer, getLvlTypes());
726   printer << ')';
727   // Print remaining members only for non-default values.
728   if (getPosWidth())
729     printer << ", posWidth = " << getPosWidth();
730   if (getCrdWidth())
731     printer << ", crdWidth = " << getCrdWidth();
732   if (getExplicitVal()) {
733     printer << ", explicitVal = " << getExplicitVal();
734   }
735   if (getImplicitVal())
736     printer << ", implicitVal = " << getImplicitVal();
737   printer << " }>";
738 }
739 
740 void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
741                                             AsmPrinter &printer) const {
742   if (map.getNumSymbols() == 0)
743     return;
744   printer << '[';
745   for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
746     printer << 's' << i << ", ";
747   if (map.getNumSymbols() >= 1)
748     printer << 's' << map.getNumSymbols() - 1;
749   printer << ']';
750 }
751 
752 void SparseTensorEncodingAttr::printDimensions(
753     AffineMap &map, AsmPrinter &printer,
754     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
755   if (!dimSlices.empty()) {
756     for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
757       printer << 'd' << i << " : " << dimSlices[i] << ", ";
758     if (map.getNumDims() >= 1) {
759       printer << 'd' << map.getNumDims() - 1 << " : "
760               << dimSlices[map.getNumDims() - 1];
761     }
762   } else {
763     for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
764       printer << 'd' << i << ", ";
765     if (map.getNumDims() >= 1)
766       printer << 'd' << map.getNumDims() - 1;
767   }
768 }
769 
770 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
771                                            ArrayRef<LevelType> lvlTypes) const {
772   for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
773     map.getResult(i).print(printer.getStream());
774     printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
775   }
776   if (map.getNumResults() >= 1) {
777     auto lastIndex = map.getNumResults() - 1;
778     map.getResult(lastIndex).print(printer.getStream());
779     printer << " : " << toMLIRString(lvlTypes[lastIndex]);
780   }
781 }
782 
783 LogicalResult SparseTensorEncodingAttr::verify(
784     function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
785     AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
786     unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
787     ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
788   if (!acceptBitWidth(posWidth))
789     return emitError() << "unexpected position bitwidth: " << posWidth;
790   if (!acceptBitWidth(crdWidth))
791     return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
792 
793   // Verify every COO segment.
794   auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
795   while (it != lvlTypes.end()) {
796     if (it == lvlTypes.begin() ||
797         !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
798       return emitError() << "expected compressed or loose_compressed level "
799                             "before singleton level";
800 
801     auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
802     if (!std::all_of(it, curCOOEnd,
803                      [](LevelType i) { return isSingletonLT(i); }))
804       return emitError() << "expected all singleton lvlTypes "
805                             "following a singleton level";
806     // We can potentially support mixed SoA/AoS singleton levels.
807     if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
808           return it->isa<LevelPropNonDefault::SoA>() ==
809                  i.isa<LevelPropNonDefault::SoA>();
810         })) {
811       return emitError() << "expected all singleton lvlTypes stored in the "
812                             "same memory layout (SoA vs AoS).";
813     }
814     it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
815   }
816 
817   auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
818   if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
819     return emitError() << "Batch lvlType can only be leading levels.";
820 
821   // SoA property can only be applied on singleton level.
822   auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
823     return lt.isa<LevelPropNonDefault::SoA>();
824   });
825   if (llvm::any_of(soaLvls, [](LevelType lt) {
826         return !lt.isa<LevelFormat::Singleton>();
827       })) {
828     return emitError() << "SoA is only applicable to singleton lvlTypes.";
829   }
830 
831   // TODO: audit formats that actually are supported by backend.
832   if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
833       it != std::end(lvlTypes)) {
834     if (it != lvlTypes.end() - 1)
835       return emitError() << "expected n_out_of_m to be the last level type";
836     if (!std::all_of(lvlTypes.begin(), it,
837                      [](LevelType i) { return isDenseLT(i); }))
838       return emitError() << "expected all dense lvlTypes "
839                             "before a n_out_of_m level";
840     if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
841       if (!isBlockSparsity(dimToLvl)) {
842         return emitError()
843                << "expected 1xm block structure for n_out_of_m level";
844       }
845       auto sizes = getBlockSize(dimToLvl);
846       unsigned coefficient = 0;
847       for (const auto &elem : sizes) {
848         if (elem != 0) {
849           if (elem != coefficient && coefficient != 0) {
850             return emitError() << "expected only one blocked level "
851                                   "with the same coefficients";
852           }
853           coefficient = elem;
854         }
855       }
856       if (coefficient != getM(*it)) {
857         return emitError() << "expected coeffiencts of Affine expressions "
858                               "to be equal to m of n_out_of_m level";
859       }
860     }
861   }
862   // Before we can check that the level-rank is consistent/coherent
863   // across all fields, we need to define it.  The source-of-truth for
864   // the `getLvlRank` method is the length of the level-types array,
865   // since it must always be provided and have full rank; therefore we
866   // use that same source-of-truth here.
867   const Level lvlRank = lvlTypes.size();
868   if (lvlRank == 0)
869     return emitError() << "expected a non-empty array for lvlTypes";
870   // We save `dimRank` here because we'll also need it to verify `dimSlices`.
871   const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
872   if (dimToLvl) {
873     if (dimToLvl.getNumResults() != lvlRank)
874       return emitError()
875              << "level-rank mismatch between dimToLvl and lvlTypes: "
876              << dimToLvl.getNumResults() << " != " << lvlRank;
877     auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
878     // Symbols can't be inferred but are acceptable.
879     if (!inferRes && dimToLvl.getNumSymbols() == 0)
880       return emitError() << "failed to infer lvlToDim from dimToLvl";
881     if (lvlToDim && (inferRes != lvlToDim))
882       return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
883     if (dimRank > lvlRank)
884       return emitError() << "unexpected dimToLvl mapping from " << dimRank
885                          << " to " << lvlRank;
886   }
887   if (!dimSlices.empty()) {
888     if (dimSlices.size() != dimRank)
889       return emitError()
890              << "dimension-rank mismatch between dimSlices and dimToLvl: "
891              << dimSlices.size() << " != " << dimRank;
892     // Compiler support for `dimSlices` currently requires that the two
893     // ranks agree.  (However, it does allow `dimToLvl` to be a permutation.)
894     if (dimRank != lvlRank)
895       return emitError()
896              << "dimSlices expected dimension-rank to match level-rank: "
897              << dimRank << " != " << lvlRank;
898   }
899   return success();
900 }
901 
902 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
903     ArrayRef<Size> dimShape, Type elementType,
904     function_ref<InFlightDiagnostic()> emitError) const {
905   // Check structural integrity.  In particular, this ensures that the
906   // level-rank is coherent across all the fields.
907   if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
908                     getPosWidth(), getCrdWidth(), getExplicitVal(),
909                     getImplicitVal(), getDimSlices())))
910     return failure();
911   // Check integrity with tensor type specifics.  In particular, we
912   // need only check that the dimension-rank of the tensor agrees with
913   // the dimension-rank of the encoding.
914   const Dimension dimRank = dimShape.size();
915   if (dimRank == 0)
916     return emitError() << "expected non-scalar sparse tensor";
917   if (getDimRank() != dimRank)
918     return emitError()
919            << "dimension-rank mismatch between encoding and tensor shape: "
920            << getDimRank() << " != " << dimRank;
921   if (auto expVal = getExplicitVal()) {
922     Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
923     if (attrType != elementType) {
924       return emitError() << "explicit value type mismatch between encoding and "
925                          << "tensor element type: " << attrType
926                          << " != " << elementType;
927     }
928   }
929   if (auto impVal = getImplicitVal()) {
930     Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
931     if (attrType != elementType) {
932       return emitError() << "implicit value type mismatch between encoding and "
933                          << "tensor element type: " << attrType
934                          << " != " << elementType;
935     }
936     // Currently, we only support zero as the implicit value.
937     auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
938     auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
939     auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
940     if ((impFVal && impFVal.getValue().isNonZero()) ||
941         (impIntVal && !impIntVal.getValue().isZero()) ||
942         (impComplexVal && (impComplexVal.getImag().isNonZero() ||
943                            impComplexVal.getReal().isNonZero()))) {
944       return emitError() << "implicit value must be zero";
945     }
946   }
947   return success();
948 }
949 
950 Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
951   SmallVector<COOSegment> coo = getCOOSegments();
952   assert(coo.size() == 1 || coo.empty());
953   if (!coo.empty() && coo.front().isAoS()) {
954     return coo.front().lvlRange.first;
955   }
956   return getLvlRank();
957 }
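
// For example: for an AoS COO encoding (compressed(nonunique), singleton) this
// returns 0, whereas for CSR (dense, compressed), which has no COO segment, it
// returns the level rank.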
958 
959 SmallVector<COOSegment>
960 mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
961   SmallVector<COOSegment> ret;
962   if (getLvlRank() <= 1)
963     return ret;
964 
965   ArrayRef<LevelType> lts = getLvlTypes();
966   Level l = 0;
967   while (l < getLvlRank()) {
968     auto lt = lts[l];
969     if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
970       auto cur = lts.begin() + l;
971       auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
972         return !lt.isa<LevelFormat::Singleton>();
973       });
974       unsigned cooLen = std::distance(cur, end);
975       if (cooLen > 1) {
976         // To support mixed SoA/AoS COO, we should break the segment when the
977         // storage scheme changes; for now we assume that all consecutive
978         // singleton levels share the same storage format, as enforced by the
979         // STEA verifier.
980         ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
981                                  lts[l + 1].isa<LevelPropNonDefault::SoA>()});
982       }
983       l += cooLen;
984     } else {
985       l++;
986     }
987   }
988   return ret;
989 }
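
// For example: for level types
//   (dense, compressed(nonunique), singleton(soa), singleton(soa))
// this returns a single segment covering levels [1, 4) with isSoA = true.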
990 
991 //===----------------------------------------------------------------------===//
992 // SparseTensorType Methods.
993 //===----------------------------------------------------------------------===//
994 
995 bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
996                                                       bool isUnique) const {
997   if (!hasEncoding())
998     return false;
999   if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
1000     return false;
1001   for (Level l = startLvl + 1; l < lvlRank; ++l)
1002     if (!isSingletonLvl(l))
1003       return false;
1004   // If isUnique is true, then make sure that the last level is unique,
1005   // that is, when lvlRank == 1, the only compressed level is unique,
1006   // and when lvlRank > 1, the last singleton is unique.
1007   return !isUnique || isUniqueLvl(lvlRank - 1);
1008 }
1009 
1010 RankedTensorType
1011 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
1012   SmallVector<LevelType> lvlTypes;
1013   lvlTypes.reserve(lvlRank);
1014   // A non-unique compressed level at the beginning (unless this is
1015   // also the last level, in which case it is unique).
1016   lvlTypes.push_back(
1017       *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
1018   if (lvlRank > 1) {
1019     // Followed by n-2 non-unique singleton levels.
1020     std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1021                 *buildLevelType(LevelFormat::Singleton, ordered, false));
1022     // Ends by a unique singleton level.
1023     lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
1024   }
1025   auto enc = SparseTensorEncodingAttr::get(
1026       getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
1027       getCrdWidth(), getExplicitVal(), getImplicitVal());
1028   return RankedTensorType::get(getDimShape(), getElementType(), enc);
1029 }
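
// For example: with ordered == true and a 3-level type, this produces the
// level types (compressed(nonunique), singleton(nonunique), singleton), i.e.
// a trailing COO format in which only the last level is unique.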
1030 
1031 //===----------------------------------------------------------------------===//
1032 // Convenience Methods.
1033 //===----------------------------------------------------------------------===//
1034 
1035 SparseTensorEncodingAttr
1036 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
1037   if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1038     return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1039   if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1040     return mdtp.getEncoding();
1041   return nullptr;
1042 }
1043 
1044 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
1045                                              MLIRContext *context) {
1046   auto map = static_cast<AffineMap>(dimToLvl);
1047   AffineMap lvlToDim;
1048   // Return an empty lvlToDim when inference is not successful.
1049   if (!map || map.getNumSymbols() != 0) {
1050     lvlToDim = AffineMap();
1051   } else if (map.isPermutation()) {
1052     lvlToDim = inversePermutation(map);
1053   } else if (isBlockSparsity(map)) {
1054     lvlToDim = inverseBlockSparsity(map, context);
1055   }
1056   return lvlToDim;
1057 }
1058 
1059 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
1060                                                     MLIRContext *context) {
1061   SmallVector<AffineExpr> lvlExprs;
1062   auto numLvls = dimToLvl.getNumResults();
1063   lvlExprs.reserve(numLvls);
1064   // lvlExprComponents stores information about the floordiv and mod operations
1065   // applied to the same dimension, so as to build the lvlToDim map.
1066   std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1067   for (unsigned i = 0, n = numLvls; i < n; i++) {
1068     auto result = dimToLvl.getResult(i);
1069     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1070       if (result.getKind() == AffineExprKind::FloorDiv) {
1071         // Position of the dimension in dimToLvl.
1072         auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1073         assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1074                "expected only one floordiv for each dimension");
1075         SmallVector<AffineExpr, 3> components;
1076         // Level variable for floordiv.
1077         components.push_back(getAffineDimExpr(i, context));
1078         // Multiplier.
1079         components.push_back(binOp.getRHS());
1080         // Map key is the position of the dimension.
1081         lvlExprComponents[pos] = components;
1082       } else if (result.getKind() == AffineExprKind::Mod) {
1083         auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1084         assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1085                "expected floordiv before mod");
1086         // Add level variable for mod to the same vector
1087         // of the corresponding floordiv.
1088         lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
1089       } else {
1090         assert(false && "expected floordiv or mod");
1091       }
1092     } else {
1093       lvlExprs.push_back(getAffineDimExpr(i, context));
1094     }
1095   }
1096   // Build lvlExprs from lvlExprComponents.
1097   // For example, for il = i floordiv 2 and ii = i mod 2, the components
1098   // would be [il, 2, ii]. It could be used to build the AffineExpr
1099   // i = il * 2 + ii in lvlToDim.
1100   for (auto &components : lvlExprComponents) {
1101     assert(components.second.size() == 3 &&
1102            "expected 3 components to build lvlExprs");
1103     auto mulOp = getAffineBinaryOpExpr(
1104         AffineExprKind::Mul, components.second[0], components.second[1]);
1105     auto addOp =
1106         getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
1107     lvlExprs.push_back(addOp);
1108   }
1109   return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
1110 }
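
// For example: inverting the 2x3 block map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// yields the level-to-dimension map
//   (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 3 + d3).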
1111 
1112 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
1113   assert(isBlockSparsity(dimToLvl) &&
1114          "expected dimToLvl to be block sparsity for calling getBlockSize");
1115   SmallVector<unsigned> blockSize;
1116   for (auto result : dimToLvl.getResults()) {
1117     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1118       if (result.getKind() == AffineExprKind::Mod) {
1119         blockSize.push_back(
1120             dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1121       }
1122     } else {
1123       blockSize.push_back(0);
1124     }
1125   }
1126   return blockSize;
1127 }
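
// For example: for the 2x3 block map above this returns [2, 3]; for a map
// such as (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4) it returns [0, 4], where
// the 0 marks the unblocked dimension.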
1128 
1129 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
1130   if (!dimToLvl)
1131     return false;
1132   std::map<unsigned, int64_t> coeffientMap;
1133   bool hasBlock = false;
1134   for (auto result : dimToLvl.getResults()) {
1135     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1136       // Check for "dim op const".
1137       auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1138       auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1139       if (!dimOp || !conOp || conOp.getValue() <= 0)
1140         return false;
1141       // Inspect "dim / const" or "dim % const".
1142       auto pos = dimOp.getPosition();
1143       if (binOp.getKind() == AffineExprKind::FloorDiv) {
1144         // Expect only one floordiv for each dimension.
1145         if (coeffientMap.find(pos) != coeffientMap.end())
1146           return false;
1147         // Record coefficient of the floordiv.
1148         coeffientMap[pos] = conOp.getValue();
1149       } else if (binOp.getKind() == AffineExprKind::Mod) {
1150         // Expect floordiv before mod.
1151         if (coeffientMap.find(pos) == coeffientMap.end())
1152           return false;
1153         // Expect mod to have the same coefficient as floordiv.
1154         if (conOp.getValue() != coeffientMap[pos])
1155           return false;
1156         hasBlock = true;
1157       } else {
1158         return false;
1159       }
1160     } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
1161       auto pos = dimOp.getPosition();
1162       // Expect dim to be unset.
1163       if (coeffientMap.find(pos) != coeffientMap.end())
1164         return false;
1165       coeffientMap[pos] = 0;
1166     } else {
1167       return false;
1168     }
1169   }
1170   return hasBlock;
1171 }
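
// For example: (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// is recognized as block sparsity, whereas a plain permutation such as
// (d0, d1) -> (d1, d0) is not (hasBlock stays false), and a map whose mod
// appears before the matching floordiv is rejected.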
1172 
1173 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
1174   auto hasNonIdentityMap = [](Value v) {
1175     auto stt = tryGetSparseTensorType(v);
1176     return stt && !stt->isIdentity();
1177   };
1178 
1179   return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
1180          llvm::any_of(op->getResults(), hasNonIdentityMap);
1181 }
1182 
1183 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1184   if (enc) {
1185     assert(enc.isPermutation() && "Non permutation map not supported");
1186     if (const auto dimToLvl = enc.getDimToLvl())
1187       return dimToLvl.getDimPosition(l);
1188   }
1189   return l;
1190 }
1191 
1192 Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1193   if (enc) {
1194     assert(enc.isPermutation() && "Non permutation map not supported");
1195     if (const auto lvlToDim = enc.getLvlToDim())
1196       return lvlToDim.getDimPosition(d);
1197   }
1198   return d;
1199 }
1200 
1201 /// We normalize the sparse tensor encoding attribute by always using
1202 /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
1203 /// as other variants) lead to the same storage specifier type, and by stripping
1204 /// irrelevant fields that do not alter the sparse tensor memory layout.
1205 static SparseTensorEncodingAttr
1206 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
1207   SmallVector<LevelType> lts;
1208   for (auto lt : enc.getLvlTypes())
1209     lts.push_back(lt.stripStorageIrrelevantProperties());
1210 
1211   return SparseTensorEncodingAttr::get(
1212       enc.getContext(), lts,
1213       AffineMap(), // dimToLvl (irrelevant to storage specifier)
1214       AffineMap(), // lvlToDim (irrelevant to storage specifier)
1215       // Always use `index` for memSize and lvlSize instead of reusing
1216       // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
1217       // value for different bitwidths, and it avoids casting between index and
1218       // integer (as returned by DimOp).
1219       0, 0,
1220       Attribute(), // explicitVal (irrelevant to storage specifier)
1221       Attribute(), // implicitVal (irrelevant to storage specifier)
1222       enc.getDimSlices());
1223 }
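
// For example: encodings that differ only in posWidth/crdWidth, in the
// dimToLvl mapping, or in storage-irrelevant level properties (e.g.
// "compressed(nonunique, nonordered)" vs. "compressed(nonunique)") normalize
// to the same storage specifier type.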
1224 
1225 StorageSpecifierType
1226 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1227   return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
1228 }
1229 
1230 //===----------------------------------------------------------------------===//
1231 // SparseTensorDialect Operations.
1232 //===----------------------------------------------------------------------===//
1233 
1234 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1235   return success(lvl < getSparseTensorType(tensor).getLvlRank());
1236 }
1237 
1238 static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1239   const Type etp = getMemRefType(mem).getElementType();
1240   return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1241 }
1242 
1243 static LogicalResult verifySparsifierGetterSetter(
1244     StorageSpecifierKind mdKind, std::optional<Level> lvl,
1245     TypedValue<StorageSpecifierType> md, Operation *op) {
1246   if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1247     return op->emitError(
1248         "redundant level argument for querying value memory size");
1249   }
1250 
1251   const auto enc = md.getType().getEncoding();
1252   const Level lvlRank = enc.getLvlRank();
1253 
1254   if (mdKind == StorageSpecifierKind::DimOffset ||
1255       mdKind == StorageSpecifierKind::DimStride)
1256     if (!enc.isSlice())
1257       return op->emitError("requested slice data on non-slice tensor");
1258 
1259   if (mdKind != StorageSpecifierKind::ValMemSize) {
1260     if (!lvl)
1261       return op->emitError("missing level argument");
1262 
1263     const Level l = lvl.value();
1264     if (l >= lvlRank)
1265       return op->emitError("requested level is out of bounds");
1266 
1267     if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1268       return op->emitError(
1269           "requested position memory size on a singleton level");
1270   }
1271   return success();
1272 }
1273 
1274 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
1275   switch (kind) {
1276   case SparseTensorFieldKind::CrdMemRef:
1277     return stt.getCrdType();
1278   case SparseTensorFieldKind::PosMemRef:
1279     return stt.getPosType();
1280   case SparseTensorFieldKind::ValMemRef:
1281     return stt.getElementType();
1282   case SparseTensorFieldKind::StorageSpec:
1283     return nullptr;
1284   }
1285   llvm_unreachable("Unrecognizable FieldKind");
1286 }
1287 
1288 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1289                                       SparseTensorType stt,
1290                                       RankedTensorType valTp,
1291                                       TypeRange lvlTps) {
1292   if (requiresStaticShape && !stt.hasStaticDimShape())
1293     return op->emitError("the sparse-tensor must have static shape");
1294   if (!stt.hasEncoding())
1295     return op->emitError("the sparse-tensor must have an encoding attribute");
1296 
1297   // Verifies the trailing COO.
1298   Level cooStartLvl = stt.getAoSCOOStart();
1299   if (cooStartLvl < stt.getLvlRank()) {
1300     // We only support trailing COO for now; it must be the last input.
1301     auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1302     // The coordinates should be in the shape of <? x rank>.
1303     unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1304     if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1305       return op->emitError("input/output trailing COO level-ranks don't match");
1306     }
1307   }
1308 
1309   // Verifies that all types match.
1310   StorageLayout layout(stt.getEncoding());
1311   if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
1312     return op->emitError("inconsistent number of fields between input/output");
1313 
1314   unsigned idx = 0;
1315   bool misMatch = false;
1316   layout.foreachField([&idx, &misMatch, stt, valTp,
1317                        lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
1318                                Level lvl, LevelType lt) -> bool {
1319     if (fKind == SparseTensorFieldKind::StorageSpec)
1320       return true;
1321 
1322     Type inputTp = nullptr;
1323     if (fKind == SparseTensorFieldKind::ValMemRef) {
1324       inputTp = valTp;
1325     } else {
1326       assert(fid == idx && stt.getLvlType(lvl) == lt);
1327       inputTp = lvlTps[idx++];
1328     }
1329     // The input element type and expected element type should match.
1330     Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1331     Type expElemTp = getFieldElemType(stt, fKind);
1332     if (inpElemTp != expElemTp) {
1333       misMatch = true;
1334       return false; // to terminate the iteration
1335     }
1336     return true;
1337   });
1338 
1339   if (misMatch)
1340     return op->emitError("input/output element-types don't match");
1341   return success();
1342 }
1343 
1344 LogicalResult AssembleOp::verify() {
1345   const auto valuesTp = getRankedTensorType(getValues());
1346   const auto lvlsTp = getLevels().getTypes();
1347   const auto resTp = getSparseTensorType(getResult());
1348   return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1349 }
1350 
1351 LogicalResult DisassembleOp::verify() {
1352   if (getOutValues().getType() != getRetValues().getType())
1353     return emitError("output values and return value type mismatch");
1354 
1355   for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1356     if (ot.getType() != rt.getType())
1357       return emitError("output levels and return levels type mismatch");
1358 
1359   const auto valuesTp = getRankedTensorType(getRetValues());
1360   const auto lvlsTp = getRetLevels().getTypes();
1361   const auto srcTp = getSparseTensorType(getTensor());
1362   return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1363 }
1364 
1365 LogicalResult ConvertOp::verify() {
1366   if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
1367     if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
1368       if (tp1.getRank() != tp2.getRank())
1369         return emitError("unexpected conversion mismatch in rank");
1370       auto dstEnc =
1371           llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1372       if (dstEnc && dstEnc.isSlice())
1373         return emitError("cannot convert to a sparse tensor slice");
1374 
1375       auto shape1 = tp1.getShape();
1376       auto shape2 = tp2.getShape();
1377       // Accept size matches between the source and the destination type
1378       // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1379       // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1380       for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1381         if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1382           return emitError("unexpected conversion mismatch in dimension ") << d;
1383       return success();
1384     }
1385   }
1386   return emitError("unexpected type in convert");
1387 }
1388 
1389 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1390   if (getType() == getSource().getType())
1391     return getSource();
1392   return {};
1393 }
1394 
1395 bool ConvertOp::needsExtraSort() {
1396   SparseTensorType srcStt = getSparseTensorType(getSource());
1397   SparseTensorType dstStt = getSparseTensorType(getDest());
1398 
1399   // We do not need an extra sort when returning unordered sparse tensors or
1400   // dense tensors, since dense tensors support random access.
1401   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1402     return false;
1403 
1404   if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1405       srcStt.hasSameDimToLvl(dstStt)) {
1406     return false;
1407   }
1408 
1409   // Source and dest tensors are ordered in different ways. We only do direct
1410   // dense to sparse conversion when the dense input is defined by a sparse
1411   // constant. Note that we can theoretically always directly convert from dense
1412   // inputs by rotating dense loops, but it leads to bad cache locality and
1413   // hurts performance.
1414   if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1415     if (isa<SparseElementsAttr>(constOp.getValue()))
1416       return false;
1417 
1418   return true;
1419 }
1420 
1421 LogicalResult CrdTranslateOp::verify() {
1422   uint64_t inRank = getEncoder().getLvlRank();
1423   uint64_t outRank = getEncoder().getDimRank();
1424 
1425   if (getDirection() == CrdTransDirectionKind::dim2lvl)
1426     std::swap(inRank, outRank);
1427 
1428   if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1429     return emitError("Coordinate rank mismatch with encoding");
1430 
1431   return success();
1432 }
1433 
1434 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1435                                    SmallVectorImpl<OpFoldResult> &results) {
1436   if (getEncoder().isIdentity()) {
1437     results.assign(getInCrds().begin(), getInCrds().end());
1438     return success();
1439   }
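  // For a pure permutation map, e.g., (d0, d1) -> (d1, d0), simply forward the
  // input coordinates in permuted order.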
1440   if (getEncoder().isPermutation()) {
1441     AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1442                          ? getEncoder().getDimToLvl()
1443                          : getEncoder().getLvlToDim();
1444     for (AffineExpr exp : perm.getResults())
1445       results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1446     return success();
1447   }
1448 
1449   // Fuse dim2lvl/lvl2dim pairs.
1450   auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1451   bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1452                    return v.getDefiningOp() == def;
1453                  });
1454   if (!sameDef)
1455     return failure();
1456 
1457   bool oppositeDir = def.getDirection() != getDirection();
1458   bool sameOracle =
1459       def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1460   bool sameCount = def.getNumResults() == getInCrds().size();
1461   if (!oppositeDir || !sameOracle || !sameCount)
1462     return failure();
1463 
1464   // The definition produces the coordinates in the same order as the input
1465   // coordinates.
1466   bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1467                                 [](auto valuePair) {
1468                                   auto [lhs, rhs] = valuePair;
1469                                   return lhs == rhs;
1470                                 });
1471 
1472   if (!sameOrder)
1473     return failure();
1474   // l1 = dim2lvl (lvl2dim l0)
1475   // ==> l0
1476   results.append(def.getInCrds().begin(), def.getInCrds().end());
1477   return success();
1478 }
1479 
1480 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1481                   int64_t index) {
1482   Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
1483   return build(builder, state, source, val);
1484 }
1485 
1486 LogicalResult LvlOp::verify() {
1487   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1488     auto stt = getSparseTensorType(getSource());
1489     if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1490       return emitError("Level index exceeds the rank of the input sparse tensor");
1491   }
1492   return success();
1493 }
1494 
1495 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1496   return getConstantIntValue(getIndex());
1497 }
1498 
1499 Speculation::Speculatability LvlOp::getSpeculatability() {
1500   auto constantIndex = getConstantLvlIndex();
1501   if (!constantIndex)
1502     return Speculation::NotSpeculatable;
1503 
1504   assert(constantIndex <
1505          cast<RankedTensorType>(getSource().getType()).getRank());
1506   return Speculation::Speculatable;
1507 }
1508 
1509 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1510   auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1511   if (!lvlIndex)
1512     return {};
1513 
1514   Level lvl = lvlIndex.getAPSInt().getZExtValue();
1515   auto stt = getSparseTensorType(getSource());
1516   if (lvl >= stt.getLvlRank()) {
1517     // Follows the same convention used by the tensor.dim operation:
1518     // out-of-bound indices produce undefined behavior but are still valid IR.
1519     // Don't choke on them.
1520     return {};
1521   }
1522 
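  // With a static level size, the query folds to a constant index, e.g.,
  // querying level 1 of a tensor with level shape [4, 8] folds to 8.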
1523   // Helper lambda to build an IndexAttr.
1524   auto getIndexAttr = [this](int64_t lvlSz) {
1525     return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1526   };
1527 
1528   SmallVector<Size> lvlShape = stt.getLvlShape();
1529   if (!ShapedType::isDynamic(lvlShape[lvl]))
1530     return getIndexAttr(lvlShape[lvl]);
1531 
1532   return {};
1533 }
1534 
1535 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1536                              SparseTensorEncodingAttr dstEnc, Value source) {
1537   auto srcStt = getSparseTensorType(source);
1538   SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1539   SmallVector<int64_t> dstDimShape =
1540       dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1541   auto dstTp =
1542       RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1543   return build(odsBuilder, odsState, dstTp, source);
1544 }
1545 
1546 LogicalResult ReinterpretMapOp::verify() {
1547   auto srcStt = getSparseTensorType(getSource());
1548   auto dstStt = getSparseTensorType(getDest());
1549   ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1550   ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1551 
1552   if (srcLvlTps.size() != dstLvlTps.size())
1553     return emitError("Level rank mismatch between source/dest tensors");
1554 
1555   for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1556     if (srcLvlTp != dstLvlTp)
1557       return emitError("Level type mismatch between source/dest tensors");
1558 
1559   if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1560       srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1561     return emitError("Crd/Pos width mismatch between source/dest tensors");
1562   }
1563 
1564   if (srcStt.getElementType() != dstStt.getElementType())
1565     return emitError("Element type mismatch between source/dest tensors");
1566 
1567   SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1568   SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1569   for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1570     if (srcLvlSz != dstLvlSz) {
1571       // Should we allow one side to have a dynamic size, e.g., should <?x?>
1572       // be compatible with <3x4>? For now, we require all the level sizes to
1573       // match *exactly* for simplicity.
1574       return emitError("Level size mismatch between source/dest tensors");
1575     }
1576   }
1577 
1578   return success();
1579 }
1580 
1581 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1582   if (getSource().getType() == getDest().getType())
1583     return getSource();
1584 
1585   if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1586     // A -> B, B -> A ==> A
1587     if (def.getSource().getType() == getDest().getType())
1588       return def.getSource();
1589   }
1590   return {};
1591 }
1592 
1593 template <typename ToBufferOp>
1594 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1595                                            OpaqueProperties prop,
1596                                            RegionRange region,
1597                                            SmallVectorImpl<mlir::Type> &ret) {
1598   typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1599   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1600   Type elemTp = nullptr;
1601   bool withStride = false;
1602   if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1603     elemTp = stt.getPosType();
1604   } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1605                        std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1606     elemTp = stt.getCrdType();
1607     if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1608       withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1609   } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1610     elemTp = stt.getElementType();
1611   }
1612 
1613   assert(elemTp && "unhandled operation.");
1614   SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1615   bufShape.push_back(ShapedType::kDynamic);
1616 
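  // Only coordinates within the AoS COO region are exposed as a strided view
  // into the interleaved coordinates buffer; all other buffers use the default
  // (identity) layout.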
1617   auto layout = withStride ? StridedLayoutAttr::get(stt.getContext(),
1618                                                     ShapedType::kDynamic,
1619                                                     {ShapedType::kDynamic})
1620                            : StridedLayoutAttr();
1621   ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1622   return success();
1623 }
1624 
1625 LogicalResult ToPositionsOp::verify() {
1626   auto stt = getSparseTensorType(getTensor());
1627   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1628     return emitError("requested level is out of bounds");
1629   if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1630     return emitError("unexpected type for positions");
1631   return success();
1632 }
1633 
1634 LogicalResult
1635 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1636                                 ValueRange ops, DictionaryAttr attr,
1637                                 OpaqueProperties prop, RegionRange region,
1638                                 SmallVectorImpl<mlir::Type> &ret) {
1639   return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1640 }
1641 
1642 LogicalResult ToCoordinatesOp::verify() {
1643   auto stt = getSparseTensorType(getTensor());
1644   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1645     return emitError("requested level is out of bounds");
1646   if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1647     return emitError("unexpected type for coordinates");
1648   return success();
1649 }
1650 
1651 LogicalResult
1652 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1653                                   ValueRange ops, DictionaryAttr attr,
1654                                   OpaqueProperties prop, RegionRange region,
1655                                   SmallVectorImpl<mlir::Type> &ret) {
1656   return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1657 }
1658 
1659 LogicalResult ToCoordinatesBufferOp::verify() {
1660   auto stt = getSparseTensorType(getTensor());
1661   if (stt.getAoSCOOStart() >= stt.getLvlRank())
1662     return emitError("expected sparse tensor with a COO region");
1663   return success();
1664 }
1665 
1666 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1667     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1668     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1669     SmallVectorImpl<mlir::Type> &ret) {
1670   return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1671                                                       ret);
1672 }
1673 
1674 LogicalResult ToValuesOp::verify() {
1675   auto stt = getSparseTensorType(getTensor());
1676   auto mtp = getMemRefType(getResult());
1677   if (stt.getElementType() != mtp.getElementType())
1678     return emitError("unexpected mismatch in element types");
1679   return success();
1680 }
1681 
1682 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1683                                            std::optional<Location> loc,
1684                                            ValueRange ops, DictionaryAttr attr,
1685                                            OpaqueProperties prop,
1686                                            RegionRange region,
1687                                            SmallVectorImpl<mlir::Type> &ret) {
1688   return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1689 }
1690 
1691 LogicalResult ToSliceOffsetOp::verify() {
1692   auto rank = getRankedTensorType(getSlice()).getRank();
1693   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1694     return emitError("requested dimension out of bound");
1695   return success();
1696 }
1697 
1698 LogicalResult ToSliceStrideOp::verify() {
1699   auto rank = getRankedTensorType(getSlice()).getRank();
1700   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1701     return emitError("requested dimension out of bound");
1702   return success();
1703 }
1704 
1705 LogicalResult GetStorageSpecifierOp::verify() {
1706   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1707                                       getSpecifier(), getOperation());
1708 }
1709 
1710 template <typename SpecifierOp>
1711 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1712   return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1713 }
1714 
1715 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1716   const StorageSpecifierKind kind = getSpecifierKind();
1717   const auto lvl = getLevel();
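  // Walk the chain of SetStorageSpecifierOp definitions feeding this specifier;
  // the most recent set with a matching kind and level supplies the folded
  // value.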
1718   for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1719     if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1720       return op.getValue();
1721   return {};
1722 }
1723 
1724 LogicalResult SetStorageSpecifierOp::verify() {
1725   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1726                                       getSpecifier(), getOperation());
1727 }
1728 
1729 template <class T>
1730 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1731                                         const char *regionName,
1732                                         TypeRange inputTypes, Type outputType) {
1733   unsigned numArgs = region.getNumArguments();
1734   unsigned expectedNum = inputTypes.size();
1735   if (numArgs != expectedNum)
1736     return op->emitError() << regionName << " region must have exactly "
1737                            << expectedNum << " arguments";
1738 
1739   for (unsigned i = 0; i < numArgs; i++) {
1740     Type typ = region.getArgument(i).getType();
1741     if (typ != inputTypes[i])
1742       return op->emitError() << regionName << " region argument " << (i + 1)
1743                              << " type mismatch";
1744   }
1745   Operation *term = region.front().getTerminator();
1746   YieldOp yield = dyn_cast<YieldOp>(term);
1747   if (!yield)
1748     return op->emitError() << regionName
1749                            << " region must end with sparse_tensor.yield";
1750   if (!yield.hasSingleResult() ||
1751       yield.getSingleResult().getType() != outputType)
1752     return op->emitError() << regionName << " region yield type mismatch";
1753 
1754   return success();
1755 }
1756 
1757 LogicalResult BinaryOp::verify() {
1758   NamedAttrList attrs = (*this)->getAttrs();
1759   Type leftType = getX().getType();
1760   Type rightType = getY().getType();
1761   Type outputType = getOutput().getType();
1762   Region &overlap = getOverlapRegion();
1763   Region &left = getLeftRegion();
1764   Region &right = getRightRegion();
1765 
1766   // Check correct number of block arguments and return type for each
1767   // non-empty region.
1768   if (!overlap.empty()) {
1769     if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1770                                   TypeRange{leftType, rightType}, outputType)))
1771       return failure();
1772   }
1773   if (!left.empty()) {
1774     if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1775                                   outputType)))
1776       return failure();
1777   } else if (getLeftIdentity()) {
1778     if (leftType != outputType)
1779       return emitError("left=identity requires first argument to have the same "
1780                        "type as the output");
1781   }
1782   if (!right.empty()) {
1783     if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1784                                   outputType)))
1785       return failure();
1786   } else if (getRightIdentity()) {
1787     if (rightType != outputType)
1788       return emitError("right=identity requires second argument to have the "
1789                        "same type as the output");
1790   }
1791   return success();
1792 }
1793 
1794 LogicalResult UnaryOp::verify() {
1795   Type inputType = getX().getType();
1796   Type outputType = getOutput().getType();
1797 
1798   // Check correct number of block arguments and return type for each
1799   // non-empty region.
1800   Region &present = getPresentRegion();
1801   if (!present.empty()) {
1802     if (failed(verifyNumBlockArgs(this, present, "present",
1803                                   TypeRange{inputType}, outputType)))
1804       return failure();
1805   }
1806   Region &absent = getAbsentRegion();
1807   if (!absent.empty()) {
1808     if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1809                                   outputType)))
1810       return failure();
1811     // Absent branch can only yield invariant values.
1812     Block *absentBlock = &absent.front();
1813     Block *parent = getOperation()->getBlock();
1814     Value absentVal =
1815         cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1816     if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1817       if (arg.getOwner() == parent)
1818         return emitError("absent region cannot yield linalg argument");
1819     } else if (Operation *def = absentVal.getDefiningOp()) {
1820       if (!isa<arith::ConstantOp>(def) &&
1821           (def->getBlock() == absentBlock || def->getBlock() == parent))
1822         return emitError("absent region cannot yield locally computed value");
1823     }
1824   }
1825   return success();
1826 }
1827 
1828 bool ConcatenateOp::needsExtraSort() {
1829   SparseTensorType dstStt = getSparseTensorType(*this);
1830   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1831     return false;
1832 
1833   bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1834     return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1835   });
1836   // TODO: When conDim != 0, as long as conDim corresponds to the first level
1837   // in all input/output buffers, and all input/output buffers have the same
1838   // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
1839   // CSC matrices along the column).
1840   bool directLowerable =
1841       allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1842   return !directLowerable;
1843 }
1844 
1845 LogicalResult ConcatenateOp::verify() {
1846   const auto dstTp = getSparseTensorType(*this);
1847   const Dimension concatDim = getDimension();
1848   const Dimension dimRank = dstTp.getDimRank();
1849 
1850   if (getInputs().size() <= 1)
1851     return emitError("Need at least two tensors to concatenate.");
1852 
1853   if (concatDim >= dimRank)
1854     return emitError(llvm::formatv(
1855         "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1856         concatDim, dimRank));
1857 
1858   for (const auto &it : llvm::enumerate(getInputs())) {
1859     const auto i = it.index();
1860     const auto srcTp = getSparseTensorType(it.value());
1861     if (srcTp.hasDynamicDimShape())
1862       return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1863     const Dimension srcDimRank = srcTp.getDimRank();
1864     if (srcDimRank != dimRank)
1865       return emitError(
1866           llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1867                         "from the output tensor (rank={2}).",
1868                         i, srcDimRank, dimRank));
1869   }
1870 
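  // For example, concatenating 2x4 and 3x4 inputs along dimension 0 requires an
  // output of shape 5x4: the concatenated dimension must be the sum of the
  // input sizes, and every other dimension must match.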
1871   for (Dimension d = 0; d < dimRank; d++) {
1872     const Size dstSh = dstTp.getDimShape()[d];
1873     if (d == concatDim) {
1874       if (!ShapedType::isDynamic(dstSh)) {
1875         // If we reach here, then all inputs have static shapes.  So we
1876         // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1877         // to avoid redundant assertions in the loop.
1878         Size sumSz = 0;
1879         for (const auto src : getInputs())
1880           sumSz += getSparseTensorType(src).getDimShape()[d];
1881         // If all dimensions are statically known, the sum of all the input
1882         // dimensions should be equal to the output dimension.
1883         if (sumSz != dstSh)
1884           return emitError(
1885               "The concatenation dimension of the output tensor should be the "
1886               "sum of all the concatenation dimensions of the input tensors.");
1887       }
1888     } else {
1889       Size prev = dstSh;
1890       for (const auto src : getInputs()) {
1891         const auto sh = getSparseTensorType(src).getDimShape()[d];
1892         if (!ShapedType::isDynamic(prev) && sh != prev)
1893           return emitError("All dimensions (except for the concatenating one) "
1894                            "should be equal.");
1895         prev = sh;
1896       }
1897     }
1898   }
1899 
1900   return success();
1901 }
1902 
1903 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1904                        Value curSize, Value inBuffer, Value value) {
1905   build(builder, result, curSize, inBuffer, value, Value());
1906 }
1907 
1908 LogicalResult PushBackOp::verify() {
1909   if (Value n = getN()) {
1910     std::optional<int64_t> nValue = getConstantIntValue(n);
1911     if (nValue && nValue.value() < 1)
1912       return emitOpError("n must not be less than 1");
1913   }
1914   return success();
1915 }
1916 
1917 LogicalResult CompressOp::verify() {
1918   const auto stt = getSparseTensorType(getTensor());
1919   if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1920     return emitOpError("incorrect number of coordinates");
1921   return success();
1922 }
1923 
1924 void ForeachOp::build(
1925     OpBuilder &builder, OperationState &result, Value tensor,
1926     ValueRange initArgs, AffineMapAttr order,
1927     function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1928         bodyBuilder) {
1929   build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1930   // Builds foreach body.
1931   if (!bodyBuilder)
1932     return;
1933   const auto stt = getSparseTensorType(tensor);
1934   const Dimension dimRank = stt.getDimRank();
1935 
1936   // Starts with `dimRank`-many coordinates.
1937   SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1938   // Followed by one value.
1939   blockArgTypes.push_back(stt.getElementType());
1940   // Followed by the reduction variables.
1941   blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1942 
1943   SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1944 
1945   OpBuilder::InsertionGuard guard(builder);
1946   auto &region = *result.regions.front();
1947   Block *bodyBlock =
1948       builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1949   bodyBuilder(builder, result.location,
1950               bodyBlock->getArguments().slice(0, dimRank),
1951               bodyBlock->getArguments()[dimRank],
1952               bodyBlock->getArguments().drop_front(dimRank + 1));
1953 }
1954 
1955 LogicalResult ForeachOp::verify() {
1956   const auto t = getSparseTensorType(getTensor());
1957   const Dimension dimRank = t.getDimRank();
1958   const auto args = getBody()->getArguments();
1959 
1960   if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1961     return emitError("Level traverse order does not match tensor's level rank");
1962 
1963   if (dimRank + 1 + getInitArgs().size() != args.size())
1964     return emitError("Unmatched number of arguments in the block");
1965 
1966   if (getNumResults() != getInitArgs().size())
1967     return emitError("Mismatch in number of init arguments and results");
1968 
1969   if (getResultTypes() != getInitArgs().getTypes())
1970     return emitError("Mismatch in types of init arguments and results");
1971 
1972   // Cannot mark this const, because the getters aren't.
1973   auto yield = cast<YieldOp>(getBody()->getTerminator());
1974   if (yield.getNumOperands() != getNumResults() ||
1975       yield.getOperands().getTypes() != getResultTypes())
1976     return emitError("Mismatch in types of yield values and results");
1977 
1978   const auto iTp = IndexType::get(getContext());
1979   for (Dimension d = 0; d < dimRank; d++)
1980     if (args[d].getType() != iTp)
1981       return emitError(
1982           llvm::formatv("Expecting Index type for argument at index {0}", d));
1983 
1984   const auto elemTp = t.getElementType();
1985   const auto valueTp = args[dimRank].getType();
1986   if (elemTp != valueTp)
1987     return emitError(llvm::formatv(
1988         "Unmatched element type between input tensor and block argument, "
1989         "expected:{0}, got: {1}", elemTp, valueTp));
1990   return success();
1991 }
1992 
1993 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1994   if (getSparseTensorEncoding(getInputCoo().getType()) ==
1995       getSparseTensorEncoding(getResultCoo().getType()))
1996     return getInputCoo();
1997 
1998   return {};
1999 }
2000 
2001 LogicalResult ReorderCOOOp::verify() {
2002   SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2003   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2004 
2005   if (!srcStt.isCOOType() || !dstStt.isCOOType())
2006     return emitError("Expected COO sparse tensors only");
2007 
2008   if (!srcStt.hasSameDimToLvl(dstStt))
2009     return emitError("Unmatched dim2lvl map between input and result COO");
2010 
2011   if (srcStt.getPosType() != dstStt.getPosType() ||
2012       srcStt.getCrdType() != dstStt.getCrdType() ||
2013       srcStt.getElementType() != dstStt.getElementType())
2014     return emitError("Unmatched storage format between input and result COO");
2015 
2016   return success();
2017 }
2018 
2019 LogicalResult ReduceOp::verify() {
2020   Type inputType = getX().getType();
2021   Region &formula = getRegion();
2022   return verifyNumBlockArgs(this, formula, "reduce",
2023                             TypeRange{inputType, inputType}, inputType);
2024 }
2025 
2026 LogicalResult SelectOp::verify() {
2027   Builder b(getContext());
2028   Type inputType = getX().getType();
2029   Type boolType = b.getI1Type();
2030   Region &formula = getRegion();
2031   return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2032                             boolType);
2033 }
2034 
2035 LogicalResult SortOp::verify() {
2036   AffineMap xPerm = getPermMap();
2037   uint64_t nx = xPerm.getNumDims();
2038   if (nx < 1)
2039     return emitError(llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));
2040 
2041   if (!xPerm.isPermutation())
2042     return emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
2043 
2044   // We can't check the size of the buffers when n or buffer dimensions aren't
2045   // compile-time constants.
2046   std::optional<int64_t> cn = getConstantIntValue(getN());
2047   if (!cn)
2048     return success();
2049 
2050   // Verify dimensions.
2051   const auto checkDim = [&](Value v, Size minSize, const char *message) {
2052     const Size sh = getMemRefType(v).getShape()[0];
2053     if (!ShapedType::isDynamic(sh) && sh < minSize)
2054       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2055   };
2056   uint64_t n = cn.value();
2057   uint64_t ny = 0;
2058   if (auto nyAttr = getNyAttr())
2059     ny = nyAttr.getInt();
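  // For example, with n = 10, rank(perm_map) = 2, and ny = 1, the xy buffer
  // must hold at least 10 * (2 + 1) = 30 entries and each y buffer at least 10
  // entries.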
2060   checkDim(getXy(), n * (nx + ny),
2061            "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
2062   for (Value opnd : getYs())
2063     checkDim(opnd, n, "Expected dimension(y) >= n");
2064 
2065   return success();
2066 }
2067 
2068 //===----------------------------------------------------------------------===//
2069 // Sparse Tensor Iteration Operations.
2070 //===----------------------------------------------------------------------===//
2071 
2072 IterSpaceType IteratorType::getIterSpaceType() const {
2073   return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2074                             getHiLvl());
2075 }
2076 
2077 IteratorType IterSpaceType::getIteratorType() const {
2078   return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2079 }
2080 
2081 /// Parses a level range in the form "$lo `to` $hi"
2082 /// or simply "$lo" if $hi - $lo = 1
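/// For example, "2 to 4" yields lo = 2, hi = 4, while "2" yields lo = 2,
/// hi = 3.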
2083 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2084                                    Level &lvlHi) {
2085   if (parser.parseInteger(lvlLo))
2086     return failure();
2087 
2088   if (succeeded(parser.parseOptionalKeyword("to"))) {
2089     if (parser.parseInteger(lvlHi))
2090       return failure();
2091   } else {
2092     lvlHi = lvlLo + 1;
2093   }
2094 
2095   if (lvlHi <= lvlLo)
2096     return parser.emitError(parser.getNameLoc(),
2097                             "expected larger level upper bound than lower bound");
2098 
2099   return success();
2100 }
2101 
2102 /// Parses a level range in the form "$lo `to` $hi"
2103 /// or simply "$lo" if $hi - $lo = 1
2104 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2105                                    IntegerAttr &lvlHiAttr) {
2106   Level lvlLo, lvlHi;
2107   if (parseLevelRange(parser, lvlLo, lvlHi))
2108     return failure();
2109 
2110   lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2111   lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2112   return success();
2113 }
2114 
2115 /// Prints a level range in the form "$lo `to` $hi"
2116 /// or simply "$lo" if $hi - $lo = 1
2117 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2118 
2119   if (lo + 1 == hi)
2120     p << lo;
2121   else
2122     p << lo << " to " << hi;
2123 }
2124 
2125 /// Prints a level range in the form "$lo `to` $hi"
2126 /// or simply "$lo" if $hi - $lo = 1
2127 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2128                             IntegerAttr lvlHi) {
2129   unsigned lo = lvlLo.getValue().getZExtValue();
2130   unsigned hi = lvlHi.getValue().getZExtValue();
2131   printLevelRange(p, lo, hi);
2132 }
2133 
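/// Parses the custom assembly of a sparse space loop. A representative form,
/// assembled from the fragments parsed below, looks roughly like
///   %it in %space at(%crd0, _) iter_args(%arg = %init)
///       : !sparse_tensor.iter_space<...> -> index
/// (illustrative only; the op definitions give the authoritative syntax).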
2134 static ParseResult
2135 parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
2136                      SmallVectorImpl<OpAsmParser::Argument> &iterators,
2137                      SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
2138   SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2139   SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2140 
2141   // Parse "%iters, ... in %spaces, ..."
2142   if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
2143       parser.parseOperandList(spaces))
2144     return failure();
2145 
2146   if (iterators.size() != spaces.size())
2147     return parser.emitError(
2148         parser.getNameLoc(),
2149         "mismatch in number of sparse iterators and sparse spaces");
2150 
2151   // Parse "at(%crd0, _, ...)"
2152   LevelSet crdUsedLvlSet;
2153   bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
2154   unsigned lvlCrdCnt = 0;
2155   if (hasUsedCrds) {
2156     ParseResult crdList = parser.parseCommaSeparatedList(
2157         OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
2158           if (parser.parseOptionalKeyword("_")) {
2159             if (parser.parseArgument(iterArgs.emplace_back()))
2160               return failure();
2161             // Always use IndexType for the coordinate.
2162             crdUsedLvlSet.set(lvlCrdCnt);
2163             iterArgs.back().type = parser.getBuilder().getIndexType();
2164           }
2165           lvlCrdCnt += 1;
2166           return success();
2167         });
2168     if (failed(crdList)) {
2169       return parser.emitError(
2170           parser.getNameLoc(),
2171           "expecting SSA value or \"_\" for level coordinates");
2172     }
2173   }
2174   // Set the CrdUsedLvl bitset.
2175   state.addAttribute("crdUsedLvls",
2176                      parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2177 
2178   // Parse "iter_args(%arg = %init, ...)"
2179   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2180   if (hasIterArgs)
2181     if (parser.parseAssignmentList(iterArgs, initArgs))
2182       return failure();
2183 
2184   SmallVector<Type> iterSpaceTps;
2185   // parse ": sparse_tensor.iter_space -> ret"
2186   if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
2187     return failure();
2188   if (iterSpaceTps.size() != spaces.size())
2189     return parser.emitError(parser.getNameLoc(),
2190                             "mismatch in number of iteration space operands "
2191                             "and iteration space types");
2192 
2193   for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2194     IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2195     if (!spaceTp)
2196       return parser.emitError(parser.getNameLoc(),
2197                               "expected sparse_tensor.iter_space type for "
2198                               "iteration space operands");
2199     if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
2200       return parser.emitError(parser.getNameLoc(),
2201                               "mismatch in number of iteration space dimension "
2202                               "and specified coordinates");
2203     it.type = spaceTp.getIteratorType();
2204   }
2205 
2206   if (hasIterArgs)
2207     if (parser.parseArrowTypeList(state.types))
2208       return failure();
2209 
2210   // Resolves input operands.
2211   if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2212                              state.operands))
2213     return failure();
2214 
2215   if (hasIterArgs) {
2216     unsigned numCrds = crdUsedLvlSet.count();
2217     // Strip off leading args that used for coordinates.
2218     MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
2219     if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2220       return parser.emitError(
2221           parser.getNameLoc(),
2222           "mismatch in number of iteration arguments and return values");
2223     }
2224 
2225     for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2226       it.type = tp;
2227       if (parser.resolveOperand(init, tp, state.operands))
2228         return failure();
2229     }
2230   }
2231   return success();
2232 }
2233 
2234 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2235     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2236     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2237     SmallVectorImpl<mlir::Type> &ret) {
2238 
2239   ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2240   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2241   ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2242                                    adaptor.getHiLvl()));
2243   return success();
2244 }
2245 
2246 LogicalResult ExtractIterSpaceOp::verify() {
2247   if (getLoLvl() >= getHiLvl())
2248     return emitOpError("expected smaller level low than level high");
2249 
2250   TypedValue<IteratorType> pIter = getParentIter();
2251   if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2252     return emitOpError(
2253         "parent iterator must be specified iff level lower bound is nonzero");
2254   }
2255 
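  // For example, an iteration space over levels 2 to 4 must be extracted
  // through a parent iterator whose range ends at level 2, while a space
  // starting at level 0 takes no parent iterator.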
2256   if (pIter) {
2257     IterSpaceType spaceTp = getExtractedSpace().getType();
2258     if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2259       return emitOpError(
2260           "mismatch in parent iterator encoding and iteration space encoding.");
2261 
2262     if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2263       return emitOpError("parent iterator should be used to extract an "
2264                          "iteration space from a consecutive level.");
2265   }
2266 
2267   return success();
2268 }
2269 
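/// Removes block arguments for level coordinates that have no uses, e.g., an
/// iterate op over a 2-D space "at(%i, %j)" where %j is never used is
/// rewritten to use "at(%i, _)".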
2270 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2271   using OpRewritePattern::OpRewritePattern;
2272 
2273   LogicalResult matchAndRewrite(IterateOp iterateOp,
2274                                 PatternRewriter &rewriter) const override {
2275     LevelSet newUsedLvls(0);
2276     llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2277     for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2278       if (auto crd = iterateOp.getLvlCrd(i)) {
2279         if (crd->getUsers().empty())
2280           toRemove.set(crd->getArgNumber());
2281         else
2282           newUsedLvls.set(i);
2283       }
2284     }
2285 
2286     // All coordinates are used.
2287     if (toRemove.none())
2288       return failure();
2289 
2290     rewriter.startOpModification(iterateOp);
2291     iterateOp.setCrdUsedLvls(newUsedLvls);
2292     iterateOp.getBody()->eraseArguments(toRemove);
2293     rewriter.finalizeOpModification(iterateOp);
2294     return success();
2295   }
2296 };
2297 
2298 void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2299                                             mlir::MLIRContext *context) {
2300   results.add<RemoveUnusedLvlCrds>(context);
2301 }
2302 
2303 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2304                       Value iterSpace, ValueRange initArgs) {
2305   unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2306   // All ones.
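  // E.g., for a 3-level space this is 0b111, marking every coordinate as used.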
2307   LevelSet set((1ULL << rank) - 1);
2308   return build(builder, odsState, iterSpace, initArgs, set);
2309 }
2310 
2311 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2312                       Value iterSpace, ValueRange initArgs,
2313                       LevelSet crdUsedLvls) {
2314   OpBuilder::InsertionGuard guard(builder);
2315 
2316   odsState.addOperands(iterSpace);
2317   odsState.addOperands(initArgs);
2318   odsState.getOrAddProperties<Properties>().crdUsedLvls =
2319       builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
2320   Region *bodyRegion = odsState.addRegion();
2321   odsState.addTypes(initArgs.getTypes());
2322   Block *bodyBlock = builder.createBlock(bodyRegion);
2323 
2324   // First argument, sparse iterator
2325   bodyBlock->addArgument(
2326       llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
2327       odsState.location);
2328 
2329   // Followed by a list of used coordinates.
2330   for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
2331     bodyBlock->addArgument(builder.getIndexType(), odsState.location);
2332 
2333   // Followed by a list of user-provided loop arguments.
2334   for (Value v : initArgs)
2335     bodyBlock->addArgument(v.getType(), v.getLoc());
2336 }
2337 
2338 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2339   OpAsmParser::Argument iterator;
2340   OpAsmParser::UnresolvedOperand iterSpace;
2341 
2342   SmallVector<OpAsmParser::Argument> iters, iterArgs;
2343   if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
2344     return failure();
2345   if (iters.size() != 1)
2346     return parser.emitError(parser.getNameLoc(),
2347                             "expected only one iterator/iteration space");
2348 
2349   iters.append(iterArgs);
2350   Region *body = result.addRegion();
2351   if (parser.parseRegion(*body, iters))
2352     return failure();
2353 
2354   IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2355 
2356   // Parse the optional attribute list.
2357   if (parser.parseOptionalAttrDict(result.attributes))
2358     return failure();
2359 
2360   return success();
2361 }
2362 
2363 /// Prints the initialization list in the form of
2364 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2365 /// where 'inner' values are assumed to be region arguments and 'outer' values
2366 /// are regular SSA values.
2367 static void printInitializationList(OpAsmPrinter &p,
2368                                     Block::BlockArgListType blocksArgs,
2369                                     ValueRange initializers,
2370                                     StringRef prefix = "") {
2371   assert(blocksArgs.size() == initializers.size() &&
2372          "expected same length of arguments and initializers");
2373   if (initializers.empty())
2374     return;
2375 
2376   p << prefix << '(';
2377   llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2378     p << std::get<0>(it) << " = " << std::get<1>(it);
2379   });
2380   p << ")";
2381 }
2382 
2383 static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
2384                               Block::BlockArgListType blocksArgs,
2385                               LevelSet crdUsedLvls) {
2386   if (crdUsedLvls.empty())
2387     return;
2388 
2389   p << " at(";
2390   for (unsigned i = 0; i < spaceDim; i++) {
2391     if (crdUsedLvls[i]) {
2392       p << blocksArgs.front();
2393       blocksArgs = blocksArgs.drop_front();
2394     } else {
2395       p << "_";
2396     }
2397     if (i != spaceDim - 1)
2398       p << ", ";
2399   }
2400   assert(blocksArgs.empty());
2401   p << ")";
2402 }
2403 
2404 void IterateOp::print(OpAsmPrinter &p) {
2405   p << " " << getIterator() << " in " << getIterSpace();
2406   printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
2407   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
2408 
2409   p << " : " << getIterSpace().getType() << " ";
2410   if (!getInitArgs().empty())
2411     p << "-> (" << getInitArgs().getTypes() << ") ";
2412 
2413   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2414                 /*printBlockTerminators=*/!getInitArgs().empty());
2415 }
2416 
2417 LogicalResult IterateOp::verify() {
2418   if (getInitArgs().size() != getNumResults()) {
2419     return emitOpError(
2420         "mismatch in number of loop-carried values and defined values");
2421   }
2422   if (getCrdUsedLvls().max() > getSpaceDim())
2423     return emitOpError("requested out-of-bound coordinates");
2424 
2425   return success();
2426 }
2427 
2428 LogicalResult IterateOp::verifyRegions() {
2429   if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2430     return emitOpError("mismatch in iterator and iteration space type");
2431   if (getNumRegionIterArgs() != getNumResults())
2432     return emitOpError(
2433         "mismatch in number of basic block args and defined values");
2434 
2435   auto initArgs = getInitArgs();
2436   auto iterArgs = getRegionIterArgs();
2437   auto yieldVals = getYieldedValues();
2438   auto opResults = getResults();
2439   if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2440                         opResults.size()})) {
2441     return emitOpError() << "number mismatch between iter args and results.";
2442   }
2443 
2444   for (auto [i, init, iter, yield, ret] :
2445        llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2446     if (init.getType() != ret.getType())
2447       return emitOpError() << "type mismatch between " << i
2448                            << "th iter operand and defined value";
2449     if (iter.getType() != ret.getType())
2450       return emitOpError() << "type mismatch between " << i
2451                            << "th iter region arg and defined value";
2452     if (yield.getType() != ret.getType())
2453       return emitOpError() << "type mismatch between " << i
2454                            << "th yield value and defined value";
2455   }
2456 
2457   return success();
2458 }
2459 
2460 /// OpInterfaces' methods implemented by IterateOp.
2461 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2462 
2463 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2464   return getInitArgsMutable();
2465 }
2466 
2467 Block::BlockArgListType IterateOp::getRegionIterArgs() {
2468   return getRegion().getArguments().take_back(getNumRegionIterArgs());
2469 }
2470 
2471 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2472   return cast<sparse_tensor::YieldOp>(
2473              getRegion().getBlocks().front().getTerminator())
2474       .getResultsMutable();
2475 }
2476 
2477 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2478 
2479 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2480   return getInitArgs();
2481 }
2482 
2483 void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2484                                     SmallVectorImpl<RegionSuccessor> &regions) {
2485   // Both the operation itself and the region may be branching into the body or
2486   // back into the operation itself.
2487   regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2488   // It is possible for loop not to enter the body.
2489   regions.push_back(RegionSuccessor(getResults()));
2490 }
2491 
2492 //===----------------------------------------------------------------------===//
2493 // Sparse Tensor Dialect Setups.
2494 //===----------------------------------------------------------------------===//
2495 
2496 /// Materialize a single constant operation from a given attribute value with
2497 /// the desired resultant type.
2498 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2499                                                     Attribute value, Type type,
2500                                                     Location loc) {
2501   if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2502     return op;
2503   return nullptr;
2504 }
2505 
2506 namespace {
2507 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
2508   using OpAsmDialectInterface::OpAsmDialectInterface;
2509 
2510   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
2511     if (isa<SparseTensorEncodingAttr>(attr)) {
2512       os << "sparse";
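      // Sparse tensor encodings are then printed through aliases such as
      // `#sparse` at the top of the module.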
2513       return AliasResult::OverridableAlias;
2514     }
2515     return AliasResult::NoAlias;
2516   }
2517 };
2518 } // namespace
2519 
2520 void SparseTensorDialect::initialize() {
2521   addInterface<SparseTensorAsmDialectInterface>();
2522   addAttributes<
2523 #define GET_ATTRDEF_LIST
2524 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2525       >();
2526   addTypes<
2527 #define GET_TYPEDEF_LIST
2528 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2529       >();
2530   addOperations<
2531 #define GET_OP_LIST
2532 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2533       >();
2534   declarePromisedInterfaces<
2535       bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2536       NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2537       ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2538 }
2539 
2540 #define GET_OP_CLASSES
2541 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2542 
2543 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
2544