xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (revision c44202574ff9a8c0632aba30c2765b134557435f)
1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "Detail/DimLvlMapParser.h"
12 
13 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
17 
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20 #include "mlir/Dialect/Complex/IR/Complex.h"
21 #include "mlir/Dialect/Utils/StaticValueUtils.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/OpImplementation.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "llvm/ADT/Bitset.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/FormatVariadic.h"
30 
31 #define GET_ATTRDEF_CLASSES
32 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
33 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
34 
35 // Forward declarations: the following custom print/parsing methods are
36 // referenced by the code generated from SparseTensorTypes.td.
37 static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
38                                          mlir::sparse_tensor::Level &,
39                                          mlir::sparse_tensor::Level &);
40 static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
41                             mlir::sparse_tensor::Level);
42 
43 #define GET_TYPEDEF_CLASSES
44 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
45 
46 using namespace mlir;
47 using namespace mlir::sparse_tensor;
48 
49 // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
50 // well.
51 namespace mlir::sparse_tensor {
52 llvm::hash_code hash_value(LevelType lt) {
53   return llvm::hash_value(static_cast<uint64_t>(lt));
54 }
55 } // namespace mlir::sparse_tensor
56 
57 //===----------------------------------------------------------------------===//
58 // Local Convenience Methods.
59 //===----------------------------------------------------------------------===//
60 
61 static constexpr bool acceptBitWidth(unsigned bitWidth) {
62   switch (bitWidth) {
63   case 0:
64   case 8:
65   case 16:
66   case 32:
67   case 64:
68     return true;
69   default:
70     return false;
71   }
72 }
73 
74 static SmallVector<Size>
75 getSparseFieldShape(const SparseTensorEncodingAttr enc,
76                     std::optional<ArrayRef<int64_t>> dimShape) {
77   assert(enc);
78   // With only the encoding, we cannot determine the static shape of the
79   // leading batch levels; we therefore return a dynamically shaped memref.
80   SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
81   if (dimShape.has_value()) {
82     // If the actual tensor shape is provided, we can then refine the leading
83     // batch dimension.
84     SmallVector<int64_t> lvlShape =
85         enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
86     memrefShape.assign(lvlShape.begin(),
87                        lvlShape.begin() + enc.getBatchLvlRank());
88   }
89   // Append another dynamic dimension to store the sparse level entries.
90   memrefShape.push_back(ShapedType::kDynamic);
91   return memrefShape;
92 }
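
// For illustration (hypothetical shapes): given an encoding with one leading
// batch level and an identity dimToLvl, getSparseFieldShape above returns the
// shape [?, ?] when no dimension shape is provided, and [4, ?] when a
// dimension shape such as [4, 100] refines the batch level; the trailing
// dynamic dimension always holds the per-level sparse storage.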
93 
94 //===----------------------------------------------------------------------===//
95 // SparseTensorDialect StorageLayout.
96 //===----------------------------------------------------------------------===//
97 
98 static constexpr Level kInvalidLevel = -1u;
99 static constexpr Level kInvalidFieldIndex = -1u;
100 static constexpr FieldIndex kDataFieldStartingIdx = 0;
101 
102 void StorageLayout::foreachField(
103     llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
104                             LevelType)>
105         callback) const {
106   const auto lvlTypes = enc.getLvlTypes();
107   const Level lvlRank = enc.getLvlRank();
108   SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
109   FieldIndex fieldIdx = kDataFieldStartingIdx;
110 
111   ArrayRef cooSegsRef = cooSegs;
112   // Per-level storage.
113   for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
114     const auto lt = lvlTypes[l];
115     if (isWithPosLT(lt)) {
116       if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
117         return;
118     }
119     if (isWithCrdLT(lt)) {
120       if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
121         return;
122     }
123     if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
124       if (!cooSegsRef.front().isSoA) {
125         // AoS COO: all singletons are fused into one memref. Skip the entire
126         // COO segment.
127         l = cooSegsRef.front().lvlRange.second;
128       } else {
129         // SoA COO: each singleton level has its own memref.
130         l++;
131       }
132       // Expire handled COO segment.
133       cooSegsRef = cooSegsRef.drop_front();
134     } else {
135       // Non-COO levels.
136       l++;
137     }
138   }
139   // The values array.
140   if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
141                  LevelFormat::Undef)))
142     return;
143   // Put metadata at the end.
144   if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
145                  LevelFormat::Undef)))
146     return;
147 }
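
// For illustration: with a hypothetical CSR-style encoding
//   map = (d0, d1) -> (d0 : dense, d1 : compressed)
// the callback above is invoked for the positions of level 1 (field 0), the
// coordinates of level 1 (field 1), the values array (field 2), and the
// storage specifier (field 3); dense levels contribute no fields. For an AoS
// COO segment, only the coordinates field of the segment's first level is
// visited, since the singleton coordinates are fused into that single memref.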
148 
149 void sparse_tensor::foreachFieldAndTypeInSparseTensor(
150     SparseTensorType stt,
151     llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
152                             LevelType)>
153         callback) {
154   assert(stt.hasEncoding());
155 
156   SmallVector<int64_t> memrefShape =
157       getSparseFieldShape(stt.getEncoding(), stt.getDimShape());
158 
159   const Type specType = StorageSpecifierType::get(stt.getEncoding());
160   // memref<[batch] x ? x pos>  positions
161   const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
162   // memref<[batch] x ? x crd>  coordinates
163   const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
164   // memref<[batch] x ? x eltType> values
165   const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
166 
167   StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
168                                    callback](FieldIndex fieldIdx,
169                                              SparseTensorFieldKind fieldKind,
170                                              Level lvl, LevelType lt) -> bool {
171     switch (fieldKind) {
172     case SparseTensorFieldKind::StorageSpec:
173       return callback(specType, fieldIdx, fieldKind, lvl, lt);
174     case SparseTensorFieldKind::PosMemRef:
175       return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
176     case SparseTensorFieldKind::CrdMemRef:
177       return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
178     case SparseTensorFieldKind::ValMemRef:
179       return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
180     };
181     llvm_unreachable("unrecognized field kind");
182   });
183 }
184 
185 unsigned StorageLayout::getNumFields() const {
186   unsigned numFields = 0;
187   foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
188                             LevelType) -> bool {
189     numFields++;
190     return true;
191   });
192   return numFields;
193 }
194 
195 unsigned StorageLayout::getNumDataFields() const {
196   unsigned numFields = 0; // one value memref
197   foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
198                             LevelType) -> bool {
199     if (fidx >= kDataFieldStartingIdx)
200       numFields++;
201     return true;
202   });
203   numFields -= 1; // the last field is StorageSpecifier
204   assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
205   return numFields;
206 }
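
// For illustration: with the CSR-style encoding sketched above, getNumFields()
// returns 4 (positions, coordinates, values, storage specifier) and
// getNumDataFields() returns 3, since the trailing storage specifier is not a
// data field.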
207 
208 std::pair<FieldIndex, unsigned>
209 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
210                                       std::optional<Level> lvl) const {
211   FieldIndex fieldIdx = kInvalidFieldIndex;
212   unsigned stride = 1;
213   if (kind == SparseTensorFieldKind::CrdMemRef) {
214     assert(lvl.has_value());
215     const Level cooStart = enc.getAoSCOOStart();
216     const Level lvlRank = enc.getLvlRank();
217     if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
218       lvl = cooStart;
219       stride = lvlRank - cooStart;
220     }
221   }
222   foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
223                                       SparseTensorFieldKind fKind, Level fLvl,
224                                       LevelType lt) -> bool {
225     if ((lvl && fLvl == lvl.value() && kind == fKind) ||
226         (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
227       fieldIdx = fIdx;
228       // Returns false to break the iteration.
229       return false;
230     }
231     return true;
232   });
233   assert(fieldIdx != kInvalidFieldIndex);
234   return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
235 }
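
// For illustration: for a hypothetical 2-level AoS COO encoding
//   (d0 : compressed(nonunique), d1 : singleton)
// requesting the CrdMemRef field of level 1 redirects to the fused coordinate
// field of level 0 and reports a stride of 2 (one coordinate pair per entry).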
236 
237 //===----------------------------------------------------------------------===//
238 // SparseTensorDialect Attribute Methods.
239 //===----------------------------------------------------------------------===//
240 
241 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
242   return isDynamic(v) ? std::nullopt
243                       : std::make_optional(static_cast<uint64_t>(v));
244 }
245 
246 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
247   return getStatic(getOffset());
248 }
249 
250 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
251   return getStatic(getStride());
252 }
253 
254 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
255   return getStatic(getSize());
256 }
257 
258 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
259   return isDynamic(getOffset()) && isDynamic(getStride()) &&
260          isDynamic(getSize());
261 }
262 
263 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
264   return isDynamic(v) ? "?" : std::to_string(v);
265 }
266 
267 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
268   assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
269   os << '(';
270   os << getStaticString(getOffset());
271   os << ", ";
272   os << getStaticString(getSize());
273   os << ", ";
274   os << getStaticString(getStride());
275   os << ')';
276 }
277 
278 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
279   print(printer.getStream());
280 }
281 
282 static ParseResult parseOptionalStaticSlice(int64_t &result,
283                                             AsmParser &parser) {
284   auto parseResult = parser.parseOptionalInteger(result);
285   if (parseResult.has_value()) {
286     if (parseResult.value().succeeded() && result < 0) {
287       parser.emitError(
288           parser.getCurrentLocation(),
289           "expect positive value or ? for slice offset/size/stride");
290       return failure();
291     }
292     return parseResult.value();
293   }
294 
295   // Otherwise, expect a '?' representing a dynamic slice.
296   result = SparseTensorDimSliceAttr::kDynamic;
297   return parser.parseQuestion();
298 }
299 
300 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
301   int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
302 
303   if (failed(parser.parseLParen()) ||
304       failed(parseOptionalStaticSlice(offset, parser)) ||
305       failed(parser.parseComma()) ||
306       failed(parseOptionalStaticSlice(size, parser)) ||
307       failed(parser.parseComma()) ||
308       failed(parseOptionalStaticSlice(stride, parser)) ||
309       failed(parser.parseRParen()))
310     return {};
311 
312   return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
313                                                      offset, size, stride);
314 }
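
// For illustration (sketch of the textual forms): the parser above accepts
// "(1, 4, 2)" for a static offset/size/stride and "(?, ?, ?)" when all three
// are dynamic, e.g. within a hypothetical sliced encoding such as
//   map = (d0 : #sparse_tensor<slice(1, 4, 2)>,
//          d1 : #sparse_tensor<slice(?, ?, ?)>) -> (d0 : dense, d1 : compressed)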
315 
316 LogicalResult
317 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
318                                  int64_t offset, int64_t size, int64_t stride) {
319   if (!isDynamic(offset) && offset < 0)
320     return emitError() << "expect non-negative value or ? for slice offset";
321   if (!isDynamic(size) && size <= 0)
322     return emitError() << "expect positive value or ? for slice size";
323   if (!isDynamic(stride) && stride <= 0)
324     return emitError() << "expect positive value or ? for slice stride";
325   return success();
326 }
327 
328 SparseTensorEncodingAttr
329 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
330   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
331   return SparseTensorEncodingAttr::get(
332       getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
333       getCrdWidth(), getExplicitVal(), getImplicitVal());
334 }
335 
336 SparseTensorEncodingAttr
337 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
338   return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
339 }
340 
341 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
342   return withDimToLvl(AffineMap());
343 }
344 
345 SparseTensorEncodingAttr
346 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
347                                         unsigned crdWidth) const {
348   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
349   return SparseTensorEncodingAttr::get(
350       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
351       crdWidth, getExplicitVal(), getImplicitVal());
352 }
353 
354 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
355   return withBitWidths(0, 0);
356 }
357 
358 SparseTensorEncodingAttr
359 SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
360   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
361   return SparseTensorEncodingAttr::get(
362       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
363       getCrdWidth(), explicitVal, getImplicitVal());
364 }
365 
366 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
367   return withExplicitVal(Attribute());
368 }
369 
370 SparseTensorEncodingAttr
371 SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
372   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
373   return SparseTensorEncodingAttr::get(
374       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
375       getCrdWidth(), getExplicitVal(), implicitVal);
376 }
377 
378 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
379   return withImplicitVal(Attribute());
380 }
381 
382 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
383     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
384   return SparseTensorEncodingAttr::get(
385       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
386       getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
387 }
388 
389 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
390   return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
391 }
392 
393 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
394   ArrayRef<LevelType> lvlTypes = getLvlTypes();
395   auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
396   return std::distance(lastBatch, lvlTypes.rend());
397 }
398 
399 bool SparseTensorEncodingAttr::isAllDense() const {
400   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
401 }
402 
403 bool SparseTensorEncodingAttr::isAllOrdered() const {
404   return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
405 }
406 
407 Type SparseTensorEncodingAttr::getCrdElemType() const {
408   if (!getImpl())
409     return nullptr;
410   if (getCrdWidth())
411     return IntegerType::get(getContext(), getCrdWidth());
412   return IndexType::get(getContext());
413 }
414 
415 Type SparseTensorEncodingAttr::getPosElemType() const {
416   if (!getImpl())
417     return nullptr;
418   if (getPosWidth())
419     return IntegerType::get(getContext(), getPosWidth());
420   return IndexType::get(getContext());
421 }
422 
423 MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
424     std::optional<ArrayRef<int64_t>> dimShape) const {
425   SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
426   return MemRefType::get(shape, getCrdElemType());
427 }
428 
429 MemRefType SparseTensorEncodingAttr::getPosMemRefType(
430     std::optional<ArrayRef<int64_t>> dimShape) const {
431   SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
432   return MemRefType::get(shape, getPosElemType());
433 }
434 
435 bool SparseTensorEncodingAttr::isIdentity() const {
436   return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
437 }
438 
439 bool SparseTensorEncodingAttr::isPermutation() const {
440   return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
441 }
442 
443 Dimension SparseTensorEncodingAttr::getDimRank() const {
444   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
445   const auto dimToLvl = getDimToLvl();
446   return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
447 }
448 
449 Level SparseTensorEncodingAttr::getLvlRank() const {
450   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
451   return getLvlTypes().size();
452 }
453 
454 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
455   if (!getImpl())
456     return LevelFormat::Batch;
457   assert(l < getLvlRank() && "Level is out of bounds");
458   return getLvlTypes()[l];
459 }
460 
461 bool SparseTensorEncodingAttr::isSlice() const {
462   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
463   return !getDimSlices().empty();
464 }
465 
466 SparseTensorDimSliceAttr
467 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
468   assert(isSlice() && "Is not a slice");
469   const auto dimSlices = getDimSlices();
470   assert(dim < dimSlices.size() && "Dimension is out of bounds");
471   return dimSlices[dim];
472 }
473 
474 std::optional<uint64_t>
475 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
476   return getDimSlice(dim).getStaticOffset();
477 }
478 
479 std::optional<uint64_t>
480 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
481   return getDimSlice(dim).getStaticStride();
482 }
483 
484 std::optional<uint64_t>
485 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
486   return getStaticDimSliceOffset(toDim(*this, lvl));
487 }
488 
489 std::optional<uint64_t>
490 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
491   return getStaticDimSliceStride(toDim(*this, lvl));
492 }
493 
494 SmallVector<int64_t>
495 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
496                                          CrdTransDirectionKind dir) const {
497   if (isIdentity())
498     return SmallVector<int64_t>(srcShape);
499 
500   SmallVector<int64_t> ret;
501   unsigned rank =
502       dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
503   ret.reserve(rank);
504 
505   if (isPermutation()) {
506     for (unsigned r = 0; r < rank; r++) {
507       unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
508                                                              : toLvl(*this, r);
509       ret.push_back(srcShape[trans]);
510     }
511     return ret;
512   }
513 
514   // Handle non-permutation maps.
515   AffineMap transMap =
516       dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
517 
518   SmallVector<AffineExpr> dimRep;
519   dimRep.reserve(srcShape.size());
520   for (int64_t sz : srcShape) {
521     if (!ShapedType::isDynamic(sz)) {
522       // Push back the max coordinate for the given dimension/level size.
523       dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
524     } else {
525       // A dynamic size; use an AffineDimExpr to symbolize the value.
526       dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
527     }
528   };
529 
530   for (AffineExpr exp : transMap.getResults()) {
531     // Do constant propagation on the affine map.
532     AffineExpr evalExp =
533         simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
534     // Use the llvm namespace here to avoid ambiguity.
535     if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
536       ret.push_back(c.getValue() + 1);
537     } else {
538       if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
539           mod && mod.getKind() == AffineExprKind::Mod) {
540         // We can still infer a static bound for expressions of the form
541         // "d % constant", since d % constant is in [0, constant).
542         if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
543           ret.push_back(bound.getValue());
544           continue;
545         }
546       }
547       ret.push_back(ShapedType::kDynamic);
548     }
549   }
550   assert(ret.size() == rank);
551   return ret;
552 }
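
// For illustration: with a hypothetical block map
//   (d0, d1) -> (d0 floordiv 2, d1, d0 mod 2)
// a dimension shape [10, 8] translates (dim2lvl) to the level shape [5, 8, 2],
// and a dynamic shape [?, 8] translates to [?, 8, 2], since "d0 mod 2" is
// still known to be bounded by 2.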
553 
554 ValueRange
555 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
556                                         ValueRange crds,
557                                         CrdTransDirectionKind dir) const {
558   if (!getImpl())
559     return crds;
560 
561   SmallVector<Type> retType(
562       dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
563       builder.getIndexType());
564   auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
565   return transOp.getOutCrds();
566 }
567 
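// The parser below accepts the textual encoding format, for example (sketch):
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32,
//     explicitVal = 1.0 : f64,
//     implicitVal = 0.0 : f64
//   }>
// where the keys other than `map` default to 0 / unset when omitted.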
568 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
569   // Open "<{" part.
570   if (failed(parser.parseLess()))
571     return {};
572   if (failed(parser.parseLBrace()))
573     return {};
574 
575   // Process the data from the parsed dictionary value into struct-like data.
576   SmallVector<LevelType> lvlTypes;
577   SmallVector<SparseTensorDimSliceAttr> dimSlices;
578   AffineMap dimToLvl = {};
579   AffineMap lvlToDim = {};
580   unsigned posWidth = 0;
581   unsigned crdWidth = 0;
582   Attribute explicitVal;
583   Attribute implicitVal;
584   StringRef attrName;
585   SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
586                                     "explicitVal", "implicitVal"};
587   while (succeeded(parser.parseOptionalKeyword(&attrName))) {
588     // Detect admissible keyword.
589     auto *it = find(keys, attrName);
590     if (it == keys.end()) {
591       parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
592       return {};
593     }
594     unsigned keyWordIndex = it - keys.begin();
595     // Consume the `=` after the key.
596     if (failed(parser.parseEqual()))
597       return {};
598     // Dispatch on keyword.
599     switch (keyWordIndex) {
600     case 0: { // map
601       ir_detail::DimLvlMapParser cParser(parser);
602       auto res = cParser.parseDimLvlMap();
603       if (failed(res))
604         return {};
605       const auto &dlm = *res;
606 
607       const Level lvlRank = dlm.getLvlRank();
608       for (Level lvl = 0; lvl < lvlRank; lvl++)
609         lvlTypes.push_back(dlm.getLvlType(lvl));
610 
611       const Dimension dimRank = dlm.getDimRank();
612       for (Dimension dim = 0; dim < dimRank; dim++)
613         dimSlices.push_back(dlm.getDimSlice(dim));
614       // NOTE: the old syntax requires an all-or-nothing approach to
615       // `dimSlices`; therefore, if any slice actually exists then we need
616       // to convert null-DSA into default/nop DSA.
617       const auto isDefined = [](SparseTensorDimSliceAttr slice) {
618         return static_cast<bool>(slice.getImpl());
619       };
620       if (llvm::any_of(dimSlices, isDefined)) {
621         const auto defaultSlice =
622             SparseTensorDimSliceAttr::get(parser.getContext());
623         for (Dimension dim = 0; dim < dimRank; dim++)
624           if (!isDefined(dimSlices[dim]))
625             dimSlices[dim] = defaultSlice;
626       } else {
627         dimSlices.clear();
628       }
629 
630       dimToLvl = dlm.getDimToLvlMap(parser.getContext());
631       lvlToDim = dlm.getLvlToDimMap(parser.getContext());
632       break;
633     }
634     case 1: { // posWidth
635       Attribute attr;
636       if (failed(parser.parseAttribute(attr)))
637         return {};
638       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
639       if (!intAttr) {
640         parser.emitError(parser.getNameLoc(),
641                          "expected an integral position bitwidth");
642         return {};
643       }
644       posWidth = intAttr.getInt();
645       break;
646     }
647     case 2: { // crdWidth
648       Attribute attr;
649       if (failed(parser.parseAttribute(attr)))
650         return {};
651       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
652       if (!intAttr) {
653         parser.emitError(parser.getNameLoc(),
654                          "expected an integral index bitwidth");
655         return {};
656       }
657       crdWidth = intAttr.getInt();
658       break;
659     }
660     case 3: { // explicitVal
661       Attribute attr;
662       if (failed(parser.parseAttribute(attr)))
663         return {};
664       if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
665         explicitVal = result;
666       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
667         explicitVal = result;
668       } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
669         explicitVal = result;
670       } else {
671         parser.emitError(parser.getNameLoc(),
672                          "expected a numeric value for explicitVal");
673         return {};
674       }
675       break;
676     }
677     case 4: { // implicitVal
678       Attribute attr;
679       if (failed(parser.parseAttribute(attr)))
680         return {};
681       if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
682         implicitVal = result;
683       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
684         implicitVal = result;
685       } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
686         implicitVal = result;
687       } else {
688         parser.emitError(parser.getNameLoc(),
689                          "expected a numeric value for implicitVal");
690         return {};
691       }
692       break;
693     }
694     } // switch
695     // Only last item can omit the comma.
696     if (parser.parseOptionalComma().failed())
697       break;
698   }
699 
700   // Close "}>" part.
701   if (failed(parser.parseRBrace()))
702     return {};
703   if (failed(parser.parseGreater()))
704     return {};
705 
706   // Construct struct-like storage for attribute.
707   if (!lvlToDim || lvlToDim.isEmpty()) {
708     lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
709   }
710   return parser.getChecked<SparseTensorEncodingAttr>(
711       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
712       explicitVal, implicitVal, dimSlices);
713 }
714 
715 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
716   auto map = static_cast<AffineMap>(getDimToLvl());
717   // An empty affine map indicates the identity map.
718   if (!map)
719     map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
720   printer << "<{ map = ";
721   printSymbols(map, printer);
722   printer << '(';
723   printDimensions(map, printer, getDimSlices());
724   printer << ") -> (";
725   printLevels(map, printer, getLvlTypes());
726   printer << ')';
727   // Print remaining members only for non-default values.
728   if (getPosWidth())
729     printer << ", posWidth = " << getPosWidth();
730   if (getCrdWidth())
731     printer << ", crdWidth = " << getCrdWidth();
732   if (getExplicitVal()) {
733     printer << ", explicitVal = " << getExplicitVal();
734   }
735   if (getImplicitVal())
736     printer << ", implicitVal = " << getImplicitVal();
737   printer << " }>";
738 }
739 
740 void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
741                                             AsmPrinter &printer) const {
742   if (map.getNumSymbols() == 0)
743     return;
744   printer << '[';
745   for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
746     printer << 's' << i << ", ";
747   if (map.getNumSymbols() >= 1)
748     printer << 's' << map.getNumSymbols() - 1;
749   printer << ']';
750 }
751 
752 void SparseTensorEncodingAttr::printDimensions(
753     AffineMap &map, AsmPrinter &printer,
754     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
755   if (!dimSlices.empty()) {
756     for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
757       printer << 'd' << i << " : " << dimSlices[i] << ", ";
758     if (map.getNumDims() >= 1) {
759       printer << 'd' << map.getNumDims() - 1 << " : "
760               << dimSlices[map.getNumDims() - 1];
761     }
762   } else {
763     for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
764       printer << 'd' << i << ", ";
765     if (map.getNumDims() >= 1)
766       printer << 'd' << map.getNumDims() - 1;
767   }
768 }
769 
770 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
771                                            ArrayRef<LevelType> lvlTypes) const {
772   for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
773     map.getResult(i).print(printer.getStream());
774     printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
775   }
776   if (map.getNumResults() >= 1) {
777     auto lastIndex = map.getNumResults() - 1;
778     map.getResult(lastIndex).print(printer.getStream());
779     printer << " : " << toMLIRString(lvlTypes[lastIndex]);
780   }
781 }
782 
783 LogicalResult SparseTensorEncodingAttr::verify(
784     function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
785     AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
786     unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
787     ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
788   if (!acceptBitWidth(posWidth))
789     return emitError() << "unexpected position bitwidth: " << posWidth;
790   if (!acceptBitWidth(crdWidth))
791     return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
792 
793   // Verify every COO segment.
794   auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
795   while (it != lvlTypes.end()) {
796     if (it == lvlTypes.begin() ||
797         !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
798       return emitError() << "expected compressed or loose_compressed level "
799                             "before singleton level";
800 
801     auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
802     if (!std::all_of(it, curCOOEnd,
803                      [](LevelType i) { return isSingletonLT(i); }))
804       return emitError() << "expected all singleton lvlTypes "
805                             "following a singleton level";
806     // TODO: we could potentially support mixed SoA/AoS singleton levels.
807     if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
808           return it->isa<LevelPropNonDefault::SoA>() ==
809                  i.isa<LevelPropNonDefault::SoA>();
810         })) {
811       return emitError() << "expected all singleton lvlTypes stored in the "
812                             "same memory layout (SoA vs AoS).";
813     }
814     it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
815   }
816 
817   auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
818   if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
819     return emitError() << "Batch lvlType can only be leading levels.";
820 
821   // SoA property can only be applied on singleton level.
822   auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
823     return lt.isa<LevelPropNonDefault::SoA>();
824   });
825   if (llvm::any_of(soaLvls, [](LevelType lt) {
826         return !lt.isa<LevelFormat::Singleton>();
827       })) {
828     return emitError() << "SoA is only applicable to singleton lvlTypes.";
829   }
830 
831   // TODO: audit formats that actually are supported by backend.
832   if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
833       it != std::end(lvlTypes)) {
834     if (it != lvlTypes.end() - 1)
835       return emitError() << "expected n_out_of_m to be the last level type";
836     if (!std::all_of(lvlTypes.begin(), it,
837                      [](LevelType i) { return isDenseLT(i); }))
838       return emitError() << "expected all dense lvlTypes "
839                             "before a n_out_of_m level";
840     if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
841       if (!isBlockSparsity(dimToLvl)) {
842         return emitError()
843                << "expected 1xm block structure for n_out_of_m level";
844       }
845       auto sizes = getBlockSize(dimToLvl);
846       unsigned coefficient = 0;
847       for (const auto &elem : sizes) {
848         if (elem != 0) {
849           if (elem != coefficient && coefficient != 0) {
850             return emitError() << "expected only one blocked level "
851                                   "with the same coefficients";
852           }
853           coefficient = elem;
854         }
855       }
856       if (coefficient != getM(*it)) {
857         return emitError() << "expected coeffiencts of Affine expressions "
858                               "to be equal to m of n_out_of_m level";
859       }
860     }
861   }
862   // Before we can check that the level-rank is consistent/coherent
863   // across all fields, we need to define it.  The source-of-truth for
864   // the `getLvlRank` method is the length of the level-types array,
865   // since it must always be provided and have full rank; therefore we
866   // use that same source-of-truth here.
867   const Level lvlRank = lvlTypes.size();
868   if (lvlRank == 0)
869     return emitError() << "expected a non-empty array for lvlTypes";
870   // We save `dimRank` here because we'll also need it to verify `dimSlices`.
871   const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
872   if (dimToLvl) {
873     if (dimToLvl.getNumResults() != lvlRank)
874       return emitError()
875              << "level-rank mismatch between dimToLvl and lvlTypes: "
876              << dimToLvl.getNumResults() << " != " << lvlRank;
877     auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
878     // Symbols can't be inferred but are acceptable.
879     if (!inferRes && dimToLvl.getNumSymbols() == 0)
880       return emitError() << "failed to infer lvlToDim from dimToLvl";
881     if (lvlToDim && (inferRes != lvlToDim))
882       return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
883     if (dimRank > lvlRank)
884       return emitError() << "unexpected dimToLvl mapping from " << dimRank
885                          << " to " << lvlRank;
886   }
887   if (!dimSlices.empty()) {
888     if (dimSlices.size() != dimRank)
889       return emitError()
890              << "dimension-rank mismatch between dimSlices and dimToLvl: "
891              << dimSlices.size() << " != " << dimRank;
892     // Compiler support for `dimSlices` currently requires that the two
893     // ranks agree.  (However, it does allow `dimToLvl` to be a permutation.)
894     if (dimRank != lvlRank)
895       return emitError()
896              << "dimSlices expected dimension-rank to match level-rank: "
897              << dimRank << " != " << lvlRank;
898   }
899   return success();
900 }
901 
902 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
903     ArrayRef<Size> dimShape, Type elementType,
904     function_ref<InFlightDiagnostic()> emitError) const {
905   // Check structural integrity.  In particular, this ensures that the
906   // level-rank is coherent across all the fields.
907   if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
908                     getPosWidth(), getCrdWidth(), getExplicitVal(),
909                     getImplicitVal(), getDimSlices())))
910     return failure();
911   // Check integrity with tensor type specifics.  In particular, we
912   // need only check that the dimension-rank of the tensor agrees with
913   // the dimension-rank of the encoding.
914   const Dimension dimRank = dimShape.size();
915   if (dimRank == 0)
916     return emitError() << "expected non-scalar sparse tensor";
917   if (getDimRank() != dimRank)
918     return emitError()
919            << "dimension-rank mismatch between encoding and tensor shape: "
920            << getDimRank() << " != " << dimRank;
921   if (auto expVal = getExplicitVal()) {
922     Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
923     if (attrType != elementType) {
924       return emitError() << "explicit value type mismatch between encoding and "
925                          << "tensor element type: " << attrType
926                          << " != " << elementType;
927     }
928   }
929   if (auto impVal = getImplicitVal()) {
930     Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
931     if (attrType != elementType) {
932       return emitError() << "implicit value type mismatch between encoding and "
933                          << "tensor element type: " << attrType
934                          << " != " << elementType;
935     }
936     // Currently, we only support zero as the implicit value.
937     auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
938     auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
939     auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
940     if ((impFVal && impFVal.getValue().isNonZero()) ||
941         (impIntVal && !impIntVal.getValue().isZero()) ||
942         (impComplexVal && (impComplexVal.getImag().isNonZero() ||
943                            impComplexVal.getReal().isNonZero()))) {
944       return emitError() << "implicit value must be zero";
945     }
946   }
947   return success();
948 }
949 
950 Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
951   SmallVector<COOSegment> coo = getCOOSegments();
952   assert(coo.size() == 1 || coo.empty());
953   if (!coo.empty() && coo.front().isAoS()) {
954     return coo.front().lvlRange.first;
955   }
956   return getLvlRank();
957 }
958 
959 SmallVector<COOSegment>
960 mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
961   SmallVector<COOSegment> ret;
962   if (getLvlRank() <= 1)
963     return ret;
964 
965   ArrayRef<LevelType> lts = getLvlTypes();
966   Level l = 0;
967   while (l < getLvlRank()) {
968     auto lt = lts[l];
969     if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
970       auto cur = lts.begin() + l;
971       auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
972         return !lt.isa<LevelFormat::Singleton>();
973       });
974       unsigned cooLen = std::distance(cur, end);
975       if (cooLen > 1) {
976         // To support mixed SoA/AoS COO, we would need to break the segment when
977         // the storage scheme changes; for now we assume that all consecutive
978         // singleton levels share the same storage format, as enforced by the
979         // encoding verifier.
980         ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
981                                  lts[l + 1].isa<LevelPropNonDefault::SoA>()});
982       }
983       l += cooLen;
984     } else {
985       l++;
986     }
987   }
988   return ret;
989 }
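
// For illustration: for a hypothetical level layout
//   (d0 : dense, d1 : compressed(nonunique), d2 : singleton)
// getCOOSegments returns a single AoS segment covering levels [1, 3), and
// getAoSCOOStart returns 1; a purely dense or single-level encoding yields no
// segments and an AoS start equal to the level rank.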
990 
991 //===----------------------------------------------------------------------===//
992 // SparseTensorType Methods.
993 //===----------------------------------------------------------------------===//
994 
995 bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
996                                                       bool isUnique) const {
997   if (!hasEncoding())
998     return false;
999   if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
1000     return false;
1001   for (Level l = startLvl + 1; l < lvlRank; ++l)
1002     if (!isSingletonLvl(l))
1003       return false;
1004   // If isUnique is true, then make sure that the last level is unique,
1005   // that is, when lvlRank == 1, the only compressed level is unique,
1006   // and when lvlRank > 1, the last singleton is unique.
1007   return !isUnique || isUniqueLvl(lvlRank - 1);
1008 }
1009 
1010 RankedTensorType
1011 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
1012   SmallVector<LevelType> lvlTypes;
1013   lvlTypes.reserve(lvlRank);
1014   // A non-unique compressed level at the beginning (unless this is
1015   // also the last level, in which case it is unique).
1016   lvlTypes.push_back(
1017       *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
1018   if (lvlRank > 1) {
1019     // Followed by n-2 non-unique singleton levels.
1020     std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1021                 *buildLevelType(LevelFormat::Singleton, ordered, false));
1022     // Ends with a unique singleton level.
1023     lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
1024   }
1025   auto enc = SparseTensorEncodingAttr::get(
1026       getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
1027       getCrdWidth(), getExplicitVal(), getImplicitVal());
1028   return RankedTensorType::get(getDimShape(), getElementType(), enc);
1029 }
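
// For illustration: for a 3-level sparse tensor type, getCOOType produces the
// level types (compressed(nonunique), singleton(nonunique), singleton), with
// the requested `ordered` flag applied uniformly and all other encoding fields
// carried over unchanged.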
1030 
1031 //===----------------------------------------------------------------------===//
1032 // Convenience Methods.
1033 //===----------------------------------------------------------------------===//
1034 
1035 SparseTensorEncodingAttr
1036 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
1037   if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1038     return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1039   if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1040     return mdtp.getEncoding();
1041   return nullptr;
1042 }
1043 
1044 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
1045                                              MLIRContext *context) {
1046   auto map = static_cast<AffineMap>(dimToLvl);
1047   AffineMap lvlToDim;
1048   // Return an empty lvlToDim when inference is not successful.
1049   if (!map || map.getNumSymbols() != 0) {
1050     lvlToDim = AffineMap();
1051   } else if (map.isPermutation()) {
1052     lvlToDim = inversePermutation(map);
1053   } else if (isBlockSparsity(map)) {
1054     lvlToDim = inverseBlockSparsity(map, context);
1055   }
1056   return lvlToDim;
1057 }
1058 
1059 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
1060                                                     MLIRContext *context) {
1061   SmallVector<AffineExpr> lvlExprs;
1062   auto numLvls = dimToLvl.getNumResults();
1063   lvlExprs.reserve(numLvls);
1064   // lvlExprComponents stores information about the floordiv and mod operations
1065   // applied to the same dimension, so as to build the lvlToDim map.
1066   std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1067   for (unsigned i = 0, n = numLvls; i < n; i++) {
1068     auto result = dimToLvl.getResult(i);
1069     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1070       if (result.getKind() == AffineExprKind::FloorDiv) {
1071         // Position of the dimension in dimToLvl.
1072         auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1073         assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1074                "expected only one floordiv for each dimension");
1075         SmallVector<AffineExpr, 3> components;
1076         // Level variable for floordiv.
1077         components.push_back(getAffineDimExpr(i, context));
1078         // Multiplier.
1079         components.push_back(binOp.getRHS());
1080         // Map key is the position of the dimension.
1081         lvlExprComponents[pos] = components;
1082       } else if (result.getKind() == AffineExprKind::Mod) {
1083         auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1084         assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1085                "expected floordiv before mod");
1086         // Add the level variable for mod to the same vector as that of the
1087         // corresponding floordiv.
1088         lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
1089       } else {
1090         assert(false && "expected floordiv or mod");
1091       }
1092     } else {
1093       lvlExprs.push_back(getAffineDimExpr(i, context));
1094     }
1095   }
1096   // Build lvlExprs from lvlExprComponents.
1097   // For example, for il = i floordiv 2 and ii = i mod 2, the components
1098   // would be [il, 2, ii]. It could be used to build the AffineExpr
1099   // i = il * 2 + ii in lvlToDim.
1100   for (auto &components : lvlExprComponents) {
1101     assert(components.second.size() == 3 &&
1102            "expected 3 components to build lvlExprs");
1103     auto mulOp = getAffineBinaryOpExpr(
1104         AffineExprKind::Mul, components.second[0], components.second[1]);
1105     auto addOp =
1106         getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
1107     lvlExprs.push_back(addOp);
1108   }
1109   return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
1110 }
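
// For illustration: for a hypothetical 2x3 block-sparse (BSR-like) map
//   dimToLvl = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// the inferred inverse is
//   lvlToDim = (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 3 + d3)
// reconstructing each dimension from its floordiv/mod level pair.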
1111 
1112 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
1113   assert(isBlockSparsity(dimToLvl) &&
1114          "expected dimToLvl to be block sparsity for calling getBlockSize");
1115   SmallVector<unsigned> blockSize;
1116   for (auto result : dimToLvl.getResults()) {
1117     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1118       if (result.getKind() == AffineExprKind::Mod) {
1119         blockSize.push_back(
1120             dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1121       }
1122     } else {
1123       blockSize.push_back(0);
1124     }
1125   }
1126   return blockSize;
1127 }
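
// For illustration: for the BSR-like map above, getBlockSize returns [2, 3];
// a dimension that is not blocked (a plain AffineDimExpr result) contributes 0,
// e.g. (d0, d1) -> (d0, d1 floordiv 3, d1 mod 3) yields [0, 3].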
1128 
1129 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
1130   if (!dimToLvl)
1131     return false;
1132   std::map<unsigned, int64_t> coeffientMap;
1133   bool hasBlock = false;
1134   for (auto result : dimToLvl.getResults()) {
1135     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1136       // Check for "dim op const".
1137       auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1138       auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1139       if (!dimOp || !conOp || conOp.getValue() <= 0)
1140         return false;
1141       // Inspect "dim / const" or "dim % const".
1142       auto pos = dimOp.getPosition();
1143       if (binOp.getKind() == AffineExprKind::FloorDiv) {
1144         // Expect only one floordiv for each dimension.
1145         if (coeffientMap.find(pos) != coeffientMap.end())
1146           return false;
1147         // Record coefficient of the floordiv.
1148         coeffientMap[pos] = conOp.getValue();
1149       } else if (binOp.getKind() == AffineExprKind::Mod) {
1150         // Expect floordiv before mod.
1151         if (coeffientMap.find(pos) == coeffientMap.end())
1152           return false;
1153         // Expect mod to have the same coefficient as floordiv.
1154         if (conOp.getValue() != coeffientMap[pos])
1155           return false;
1156         hasBlock = true;
1157       } else {
1158         return false;
1159       }
1160     } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
1161       auto pos = dimOp.getPosition();
1162       // Expect dim to be unset.
1163       if (coeffientMap.find(pos) != coeffientMap.end())
1164         return false;
1165       coeffientMap[pos] = 0;
1166     } else {
1167       return false;
1168     }
1169   }
1170   return hasBlock;
1171 }
1172 
1173 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
1174   auto hasNonIdentityMap = [](Value v) {
1175     auto stt = tryGetSparseTensorType(v);
1176     return stt && !stt->isIdentity();
1177   };
1178 
1179   return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
1180          llvm::any_of(op->getResults(), hasNonIdentityMap);
1181 }
1182 
1183 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1184   if (enc) {
1185     assert(enc.isPermutation() && "Non permutation map not supported");
1186     if (const auto dimToLvl = enc.getDimToLvl())
1187       return dimToLvl.getDimPosition(l);
1188   }
1189   return l;
1190 }
1191 
1192 Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1193   if (enc) {
1194     assert(enc.isPermutation() && "Non permutation map not supported");
1195     if (const auto lvlToDim = enc.getLvlToDim())
1196       return lvlToDim.getDimPosition(d);
1197   }
1198   return d;
1199 }
1200 
1201 /// We normalize the sparse tensor encoding attribute by always using an
1202 /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
1203 /// as other variants) lead to the same storage specifier type, and by stripping
1204 /// fields that are irrelevant to the sparse tensor memory layout.
1205 static SparseTensorEncodingAttr
1206 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
1207   SmallVector<LevelType> lts;
1208   for (auto lt : enc.getLvlTypes())
1209     lts.push_back(lt.stripStorageIrrelevantProperties());
1210 
1211   return SparseTensorEncodingAttr::get(
1212       enc.getContext(), lts,
1213       AffineMap(), // dimToLvl (irrelevant to storage specifier)
1214       AffineMap(), // lvlToDim (irrelevant to storage specifier)
1215       // Always use `index` for memSize and lvlSize instead of reusing
1216       // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
1217       // value for different bitwidths, and it avoids casting between index and
1218       // integer (as returned by DimOp).
1219       0, 0,
1220       Attribute(), // explicitVal (irrelevant to storage specifier)
1221       Attribute(), // implicitVal (irrelevant to storage specifier)
1222       enc.getDimSlices());
1223 }
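
// For illustration: under this normalization, encodings that differ only in
// ordering/uniqueness properties (e.g. "compressed(nonunique, nonordered)"
// versus "compressed(nonunique)") or only in dimToLvl, posWidth, or crdWidth
// map to the same storage specifier type, since none of these change the
// memory layout tracked by the specifier.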
1224 
1225 StorageSpecifierType
1226 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1227   return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
1228 }
1229 
1230 StorageSpecifierType
1231 StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1232                                  MLIRContext *ctx,
1233                                  SparseTensorEncodingAttr encoding) {
1234   return Base::getChecked(emitError, ctx,
1235                           getNormalizedEncodingForSpecifier(encoding));
1236 }
1237 
1238 //===----------------------------------------------------------------------===//
1239 // SparseTensorDialect Operations.
1240 //===----------------------------------------------------------------------===//
1241 
1242 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1243   return success(lvl < getSparseTensorType(tensor).getLvlRank());
1244 }
1245 
1246 static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1247   const Type etp = getMemRefType(mem).getElementType();
1248   return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1249 }
1250 
1251 static LogicalResult verifySparsifierGetterSetter(
1252     StorageSpecifierKind mdKind, std::optional<Level> lvl,
1253     TypedValue<StorageSpecifierType> md, Operation *op) {
1254   if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1255     return op->emitError(
1256         "redundant level argument for querying value memory size");
1257   }
1258 
1259   const auto enc = md.getType().getEncoding();
1260   const Level lvlRank = enc.getLvlRank();
1261 
1262   if (mdKind == StorageSpecifierKind::DimOffset ||
1263       mdKind == StorageSpecifierKind::DimStride)
1264     if (!enc.isSlice())
1265       return op->emitError("requested slice data on non-slice tensor");
1266 
1267   if (mdKind != StorageSpecifierKind::ValMemSize) {
1268     if (!lvl)
1269       return op->emitError("missing level argument");
1270 
1271     const Level l = lvl.value();
1272     if (l >= lvlRank)
1273       return op->emitError("requested level is out of bounds");
1274 
1275     if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1276       return op->emitError(
1277           "requested position memory size on a singleton level");
1278   }
1279   return success();
1280 }
1281 
1282 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
1283   switch (kind) {
1284   case SparseTensorFieldKind::CrdMemRef:
1285     return stt.getCrdType();
1286   case SparseTensorFieldKind::PosMemRef:
1287     return stt.getPosType();
1288   case SparseTensorFieldKind::ValMemRef:
1289     return stt.getElementType();
1290   case SparseTensorFieldKind::StorageSpec:
1291     return nullptr;
1292   }
1293   llvm_unreachable("Unrecognizable FieldKind");
1294 }
1295 
1296 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1297                                       SparseTensorType stt,
1298                                       RankedTensorType valTp,
1299                                       TypeRange lvlTps) {
1300   if (requiresStaticShape && !stt.hasStaticDimShape())
1301     return op->emitError("the sparse-tensor must have static shape");
1302   if (!stt.hasEncoding())
1303     return op->emitError("the sparse-tensor must have an encoding attribute");
1304 
1305   // Verifies the trailing COO.
1306   Level cooStartLvl = stt.getAoSCOOStart();
1307   if (cooStartLvl < stt.getLvlRank()) {
1308     // We only support trailing COO for now; it must be the last input.
1309     auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1310     // The coordinates should have shape <? x rank>.
1311     unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1312     if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1313       return op->emitError("input/output trailing COO level-ranks don't match");
1314     }
1315   }
1316 
1317   // Verifies that all types match.
1318   StorageLayout layout(stt.getEncoding());
1319   if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
1320     return op->emitError("inconsistent number of fields between input/output");
1321 
1322   unsigned idx = 0;
1323   bool misMatch = false;
1324   layout.foreachField([&idx, &misMatch, stt, valTp,
1325                        lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
1326                                Level lvl, LevelType lt) -> bool {
1327     if (fKind == SparseTensorFieldKind::StorageSpec)
1328       return true;
1329 
1330     Type inputTp = nullptr;
1331     if (fKind == SparseTensorFieldKind::ValMemRef) {
1332       inputTp = valTp;
1333     } else {
1334       assert(fid == idx && stt.getLvlType(lvl) == lt);
1335       inputTp = lvlTps[idx++];
1336     }
1337     // The input element type and expected element type should match.
1338     Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1339     Type expElemTp = getFieldElemType(stt, fKind);
1340     if (inpElemTp != expElemTp) {
1341       misMatch = true;
1342       return false; // to terminate the iteration
1343     }
1344     return true;
1345   });
1346 
1347   if (misMatch)
1348     return op->emitError("input/output element-types don't match");
1349   return success();
1350 }
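
// For illustration: for a CSR-like destination encoding, the level inputs to
// assemble/disassemble are expected in storage order, i.e. the positions
// tensor of the compressed level followed by its coordinates tensor, with the
// values supplied as a separate operand; getNumDataFields() therefore equals
// the number of level inputs plus one.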
1351 
1352 LogicalResult AssembleOp::verify() {
1353   const auto valuesTp = getRankedTensorType(getValues());
1354   const auto lvlsTp = getLevels().getTypes();
1355   const auto resTp = getSparseTensorType(getResult());
1356   return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1357 }
1358 
1359 LogicalResult DisassembleOp::verify() {
1360   if (getOutValues().getType() != getRetValues().getType())
1361     return emitError("output values and return value type mismatch");
1362 
1363   for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1364     if (ot.getType() != rt.getType())
1365       return emitError("output levels and return levels type mismatch");
1366 
1367   const auto valuesTp = getRankedTensorType(getRetValues());
1368   const auto lvlsTp = getRetLevels().getTypes();
1369   const auto srcTp = getSparseTensorType(getTensor());
1370   return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1371 }
1372 
1373 LogicalResult ConvertOp::verify() {
1374   if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
1375     if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
1376       if (tp1.getRank() != tp2.getRank())
1377         return emitError("unexpected conversion mismatch in rank");
1378       auto dstEnc =
1379           llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1380       if (dstEnc && dstEnc.isSlice())
1381         return emitError("cannot convert to a sparse tensor slice");
1382 
1383       auto shape1 = tp1.getShape();
1384       auto shape2 = tp2.getShape();
1385       // Accept size matches between the source and the destination type
1386       // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1387       // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1388       for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1389         if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1390           return emitError("unexpected conversion mismatch in dimension ") << d;
1391       return success();
1392     }
1393   }
1394   return emitError("unexpected type in convert");
1395 }
1396 
1397 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1398   if (getType() == getSource().getType())
1399     return getSource();
1400   return {};
1401 }
1402 
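// For example (illustrative only): converting CSR to CSC generally needs an
// extra sort because the two formats use different dimToLvl maps, whereas
// converting into an all-dense destination never does, since dense tensors
// allow random access.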
1403 bool ConvertOp::needsExtraSort() {
1404   SparseTensorType srcStt = getSparseTensorType(getSource());
1405   SparseTensorType dstStt = getSparseTensorType(getDest());
1406 
1407   // We do not need an extra sort when returning unordered sparse tensors or
1408   // dense tensors, since dense tensors support random access.
1409   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1410     return false;
1411 
1412   if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1413       srcStt.hasSameDimToLvl(dstStt)) {
1414     return false;
1415   }
1416 
1417   // Source and dest tensors are ordered in different ways. We only do direct
1418   // dense to sparse conversion when the dense input is defined by a sparse
1419   // constant. Note that we can theoretically always directly convert from dense
1420   // inputs by rotating dense loops, but it leads to bad cache locality and
1421   // hurts performance.
1422   if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1423     if (isa<SparseElementsAttr>(constOp.getValue()))
1424       return false;
1425 
1426   return true;
1427 }
1428 
1429 LogicalResult CrdTranslateOp::verify() {
1430   uint64_t inRank = getEncoder().getLvlRank();
1431   uint64_t outRank = getEncoder().getDimRank();
1432 
1433   if (getDirection() == CrdTransDirectionKind::dim2lvl)
1434     std::swap(inRank, outRank);
1435 
1436   if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1437     return emitError("Coordinate rank mismatch with encoding");
1438 
1439   return success();
1440 }
1441 
1442 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1443                                    SmallVectorImpl<OpFoldResult> &results) {
1444   if (getEncoder().isIdentity()) {
1445     results.assign(getInCrds().begin(), getInCrds().end());
1446     return success();
1447   }
1448   if (getEncoder().isPermutation()) {
1449     AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1450                          ? getEncoder().getDimToLvl()
1451                          : getEncoder().getLvlToDim();
1452     for (AffineExpr exp : perm.getResults())
1453       results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1454     return success();
1455   }
1456 
1457   // Fuse dim2lvl/lvl2dim pairs.
1458   auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1459   bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1460                    return v.getDefiningOp() == def;
1461                  });
1462   if (!sameDef)
1463     return failure();
1464 
1465   bool oppositeDir = def.getDirection() != getDirection();
1466   bool sameOracle =
1467       def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1468   bool sameCount = def.getNumResults() == getInCrds().size();
1469   if (!oppositeDir || !sameOracle || !sameCount)
1470     return failure();
1471 
1472   // The definition produces the coordinates in the same order as the input
1473   // coordinates.
1474   bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1475                                 [](auto valuePair) {
1476                                   auto [lhs, rhs] = valuePair;
1477                                   return lhs == rhs;
1478                                 });
1479 
1480   if (!sameOrder)
1481     return failure();
1482   // l1 = dim2lvl (lvl2dim l0)
1483   // ==> l0
1484   results.append(def.getInCrds().begin(), def.getInCrds().end());
1485   return success();
1486 }
1487 
1488 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1489                   int64_t index) {
1490   Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
1491   return build(builder, state, source, val);
1492 }
1493 
1494 LogicalResult LvlOp::verify() {
1495   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1496     auto stt = getSparseTensorType(getSource());
1497     if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1498       return emitError("Level index exceeds the rank of the input sparse tensor");
1499   }
1500   return success();
1501 }
1502 
1503 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1504   return getConstantIntValue(getIndex());
1505 }
1506 
1507 Speculation::Speculatability LvlOp::getSpeculatability() {
1508   auto constantIndex = getConstantLvlIndex();
1509   if (!constantIndex)
1510     return Speculation::NotSpeculatable;
1511 
1512   assert(constantIndex <
1513          cast<RankedTensorType>(getSource().getType()).getRank());
1514   return Speculation::Speculatable;
1515 }
1516 
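// Folding example (a sketch assuming static level sizes): querying level 1 of
// a tensor whose level shape is [3, 4] folds to the index constant 4, while a
// dynamically sized level leaves the op untouched.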
1517 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1518   auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1519   if (!lvlIndex)
1520     return {};
1521 
1522   Level lvl = lvlIndex.getAPSInt().getZExtValue();
1523   auto stt = getSparseTensorType(getSource());
1524   if (lvl >= stt.getLvlRank()) {
1525     // Follows the same convention used by tensor.dim operation. Out of bound
1526     // indices produce undefined behavior but are still valid IR. Don't choke on
1527     // them.
1528     return {};
1529   }
1530 
1531   // Helper lambda to build an IndexAttr.
1532   auto getIndexAttr = [this](int64_t lvlSz) {
1533     return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1534   };
1535 
1536   SmallVector<Size> lvlShape = stt.getLvlShape();
1537   if (!ShapedType::isDynamic(lvlShape[lvl]))
1538     return getIndexAttr(lvlShape[lvl]);
1539 
1540   return {};
1541 }
1542 
1543 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1544                              SparseTensorEncodingAttr dstEnc, Value source) {
1545   auto srcStt = getSparseTensorType(source);
1546   SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1547   SmallVector<int64_t> dstDimShape =
1548       dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1549   auto dstTp =
1550       RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1551   return build(odsBuilder, odsState, dstTp, source);
1552 }
1553 
1554 LogicalResult ReinterpretMapOp::verify() {
1555   auto srcStt = getSparseTensorType(getSource());
1556   auto dstStt = getSparseTensorType(getDest());
1557   ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1558   ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1559 
1560   if (srcLvlTps.size() != dstLvlTps.size())
1561     return emitError("Level rank mismatch between source/dest tensors");
1562 
1563   for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1564     if (srcLvlTp != dstLvlTp)
1565       return emitError("Level type mismatch between source/dest tensors");
1566 
1567   if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1568       srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1569     return emitError("Crd/Pos width mismatch between source/dest tensors");
1570   }
1571 
1572   if (srcStt.getElementType() != dstStt.getElementType())
1573     return emitError("Element type mismatch between source/dest tensors");
1574 
1575   SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1576   SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1577   for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1578     if (srcLvlSz != dstLvlSz) {
1579       // Should we allow one side to have a dynamic size, e.g., should <?x?> be
1580       // compatible with <3x4>? For now, we require all the level sizes to
1581       // match *exactly* for simplicity.
1582       return emitError("Level size mismatch between source/dest tensors");
1583     }
1584   }
1585 
1586   return success();
1587 }
1588 
1589 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1590   if (getSource().getType() == getDest().getType())
1591     return getSource();
1592 
1593   if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1594     // A -> B, B -> A ==> A
1595     if (def.getSource().getType() == getDest().getType())
1596       return def.getSource();
1597   }
1598   return {};
1599 }
1600 
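// Infers the result memref type for ToPositionsOp, ToCoordinatesOp,
// ToCoordinatesBufferOp, and ToValuesOp. As an illustration (a sketch, assuming
// a non-batched encoding): with posWidth = 64 the positions buffer infers as
// memref<?xi64>, while coordinates extracted from an AoS COO region carry a
// dynamic stride, e.g. memref<?xindex, strided<[?], offset: ?>>.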
1601 template <typename ToBufferOp>
1602 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1603                                            OpaqueProperties prop,
1604                                            RegionRange region,
1605                                            SmallVectorImpl<mlir::Type> &ret) {
1606   typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1607   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1608   Type elemTp = nullptr;
1609   bool withStride = false;
1610   if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1611     elemTp = stt.getPosType();
1612   } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1613                        std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1614     elemTp = stt.getCrdType();
1615     if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1616       withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1617   } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1618     elemTp = stt.getElementType();
1619   }
1620 
1621   assert(elemTp && "unhandled operation.");
1622   SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1623   bufShape.push_back(ShapedType::kDynamic);
1624 
1625   auto layout = withStride ? StridedLayoutAttr::get(stt.getContext(),
1626                                                     ShapedType::kDynamic,
1627                                                     {ShapedType::kDynamic})
1628                            : StridedLayoutAttr();
1629   ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1630   return success();
1631 }
1632 
1633 LogicalResult ToPositionsOp::verify() {
1634   auto stt = getSparseTensorType(getTensor());
1635   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1636     return emitError("requested level is out of bounds");
1637   if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1638     return emitError("unexpected type for positions");
1639   return success();
1640 }
1641 
1642 LogicalResult
1643 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1644                                 ValueRange ops, DictionaryAttr attr,
1645                                 OpaqueProperties prop, RegionRange region,
1646                                 SmallVectorImpl<mlir::Type> &ret) {
1647   return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1648 }
1649 
1650 LogicalResult ToCoordinatesOp::verify() {
1651   auto stt = getSparseTensorType(getTensor());
1652   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1653     return emitError("requested level is out of bounds");
1654   if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1655     return emitError("unexpected type for coordinates");
1656   return success();
1657 }
1658 
1659 LogicalResult
1660 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1661                                   ValueRange ops, DictionaryAttr attr,
1662                                   OpaqueProperties prop, RegionRange region,
1663                                   SmallVectorImpl<mlir::Type> &ret) {
1664   return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1665 }
1666 
1667 LogicalResult ToCoordinatesBufferOp::verify() {
1668   auto stt = getSparseTensorType(getTensor());
1669   if (stt.getAoSCOOStart() >= stt.getLvlRank())
1670     return emitError("expected sparse tensor with a COO region");
1671   return success();
1672 }
1673 
1674 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1675     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1676     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1677     SmallVectorImpl<mlir::Type> &ret) {
1678   return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1679                                                       ret);
1680 }
1681 
1682 LogicalResult ToValuesOp::verify() {
1683   auto stt = getSparseTensorType(getTensor());
1684   auto mtp = getMemRefType(getResult());
1685   if (stt.getElementType() != mtp.getElementType())
1686     return emitError("unexpected mismatch in element types");
1687   return success();
1688 }
1689 
1690 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1691                                            std::optional<Location> loc,
1692                                            ValueRange ops, DictionaryAttr attr,
1693                                            OpaqueProperties prop,
1694                                            RegionRange region,
1695                                            SmallVectorImpl<mlir::Type> &ret) {
1696   return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1697 }
1698 
1699 LogicalResult ToSliceOffsetOp::verify() {
1700   auto rank = getRankedTensorType(getSlice()).getRank();
1701   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1702     return emitError("requested dimension out of bound");
1703   return success();
1704 }
1705 
1706 LogicalResult ToSliceStrideOp::verify() {
1707   auto rank = getRankedTensorType(getSlice()).getRank();
1708   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1709     return emitError("requested dimension out of bound");
1710   return success();
1711 }
1712 
1713 LogicalResult GetStorageSpecifierOp::verify() {
1714   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1715                                       getSpecifier(), getOperation());
1716 }
1717 
1718 template <typename SpecifierOp>
1719 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1720   return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1721 }
1722 
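// Folds a storage-specifier read against an earlier write of the same field.
// Illustrative sketch (pseudo-notation, not the exact op syntax):
//   %s1 = set %s0[lvl_sz at 0] with %v
//   %x  = get %s1[lvl_sz at 0]
// folds %x directly to %v by walking the chain of set-specifier definitions.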
1723 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1724   const StorageSpecifierKind kind = getSpecifierKind();
1725   const auto lvl = getLevel();
1726   for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1727     if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1728       return op.getValue();
1729   return {};
1730 }
1731 
1732 LogicalResult SetStorageSpecifierOp::verify() {
1733   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1734                                       getSpecifier(), getOperation());
1735 }
1736 
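// Verifies that `region` has exactly `inputTypes.size()` arguments of the
// given types and is terminated by a sparse_tensor.yield whose single result
// has type `outputType`. Shared by the block-style ops (binary, unary,
// reduce, select).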
1737 template <class T>
1738 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1739                                         const char *regionName,
1740                                         TypeRange inputTypes, Type outputType) {
1741   unsigned numArgs = region.getNumArguments();
1742   unsigned expectedNum = inputTypes.size();
1743   if (numArgs != expectedNum)
1744     return op->emitError() << regionName << " region must have exactly "
1745                            << expectedNum << " arguments";
1746 
1747   for (unsigned i = 0; i < numArgs; i++) {
1748     Type typ = region.getArgument(i).getType();
1749     if (typ != inputTypes[i])
1750       return op->emitError() << regionName << " region argument " << (i + 1)
1751                              << " type mismatch";
1752   }
1753   Operation *term = region.front().getTerminator();
1754   YieldOp yield = dyn_cast<YieldOp>(term);
1755   if (!yield)
1756     return op->emitError() << regionName
1757                            << " region must end with sparse_tensor.yield";
1758   if (!yield.hasSingleResult() ||
1759       yield.getSingleResult().getType() != outputType)
1760     return op->emitError() << regionName << " region yield type mismatch";
1761 
1762   return success();
1763 }
1764 
1765 LogicalResult BinaryOp::verify() {
1766   NamedAttrList attrs = (*this)->getAttrs();
1767   Type leftType = getX().getType();
1768   Type rightType = getY().getType();
1769   Type outputType = getOutput().getType();
1770   Region &overlap = getOverlapRegion();
1771   Region &left = getLeftRegion();
1772   Region &right = getRightRegion();
1773 
1774   // Check correct number of block arguments and return type for each
1775   // non-empty region.
1776   if (!overlap.empty()) {
1777     if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1778                                   TypeRange{leftType, rightType}, outputType)))
1779       return failure();
1780   }
1781   if (!left.empty()) {
1782     if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1783                                   outputType)))
1784       return failure();
1785   } else if (getLeftIdentity()) {
1786     if (leftType != outputType)
1787       return emitError("left=identity requires first argument to have the same "
1788                        "type as the output");
1789   }
1790   if (!right.empty()) {
1791     if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1792                                   outputType)))
1793       return failure();
1794   } else if (getRightIdentity()) {
1795     if (rightType != outputType)
1796       return emitError("right=identity requires second argument to have the "
1797                        "same type as the output");
1798   }
1799   return success();
1800 }
1801 
1802 LogicalResult UnaryOp::verify() {
1803   Type inputType = getX().getType();
1804   Type outputType = getOutput().getType();
1805 
1806   // Check correct number of block arguments and return type for each
1807   // non-empty region.
1808   Region &present = getPresentRegion();
1809   if (!present.empty()) {
1810     if (failed(verifyNumBlockArgs(this, present, "present",
1811                                   TypeRange{inputType}, outputType)))
1812       return failure();
1813   }
1814   Region &absent = getAbsentRegion();
1815   if (!absent.empty()) {
1816     if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1817                                   outputType)))
1818       return failure();
1819     // Absent branch can only yield invariant values.
1820     Block *absentBlock = &absent.front();
1821     Block *parent = getOperation()->getBlock();
1822     Value absentVal =
1823         cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1824     if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1825       if (arg.getOwner() == parent)
1826         return emitError("absent region cannot yield linalg argument");
1827     } else if (Operation *def = absentVal.getDefiningOp()) {
1828       if (!isa<arith::ConstantOp>(def) &&
1829           (def->getBlock() == absentBlock || def->getBlock() == parent))
1830         return emitError("absent region cannot yield locally computed value");
1831     }
1832   }
1833   return success();
1834 }
1835 
1836 bool ConcatenateOp::needsExtraSort() {
1837   SparseTensorType dstStt = getSparseTensorType(*this);
1838   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1839     return false;
1840 
1841   bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1842     return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1843   });
1844   // TODO: When conDim != 0, as long as conDim corresponds to the first level
1845   // in all input/output buffers, and all input/output buffers have the same
1846   // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
1847   // CSC matrices along columns).
1848   bool directLowerable =
1849       allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1850   return !directLowerable;
1851 }
1852 
1853 LogicalResult ConcatenateOp::verify() {
1854   const auto dstTp = getSparseTensorType(*this);
1855   const Dimension concatDim = getDimension();
1856   const Dimension dimRank = dstTp.getDimRank();
1857 
1858   if (getInputs().size() <= 1)
1859     return emitError("Need at least two tensors to concatenate.");
1860 
1861   if (concatDim >= dimRank)
1862     return emitError(llvm::formatv(
1863         "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1864         concatDim, dimRank));
1865 
1866   for (const auto &it : llvm::enumerate(getInputs())) {
1867     const auto i = it.index();
1868     const auto srcTp = getSparseTensorType(it.value());
1869     if (srcTp.hasDynamicDimShape())
1870       return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1871     const Dimension srcDimRank = srcTp.getDimRank();
1872     if (srcDimRank != dimRank)
1873       return emitError(
1874           llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1875                         "from the output tensor (rank={2}).",
1876                         i, srcDimRank, dimRank));
1877   }
1878 
1879   for (Dimension d = 0; d < dimRank; d++) {
1880     const Size dstSh = dstTp.getDimShape()[d];
1881     if (d == concatDim) {
1882       if (!ShapedType::isDynamic(dstSh)) {
1883         // If we reach here, then all inputs have static shapes.  So we
1884         // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1885         // to avoid redundant assertions in the loop.
1886         Size sumSz = 0;
1887         for (const auto src : getInputs())
1888           sumSz += getSparseTensorType(src).getDimShape()[d];
1889         // If all dimensions are statically known, the sum of all the input
1890         // dimensions should be equal to the output dimension.
1891         if (sumSz != dstSh)
1892           return emitError(
1893               "The concatenation dimension of the output tensor should be the "
1894               "sum of all the concatenation dimensions of the input tensors.");
1895       }
1896     } else {
1897       Size prev = dstSh;
1898       for (const auto src : getInputs()) {
1899         const auto sh = getSparseTensorType(src).getDimShape()[d];
1900         if (!ShapedType::isDynamic(prev) && sh != prev)
1901           return emitError("All dimensions (expect for the concatenating one) "
1902                            "should be equal.");
1903         prev = sh;
1904       }
1905     }
1906   }
1907 
1908   return success();
1909 }
1910 
1911 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1912                        Value curSize, Value inBuffer, Value value) {
1913   build(builder, result, curSize, inBuffer, value, Value());
1914 }
1915 
1916 LogicalResult PushBackOp::verify() {
1917   if (Value n = getN()) {
1918     std::optional<int64_t> nValue = getConstantIntValue(n);
1919     if (nValue && nValue.value() < 1)
1920       return emitOpError("n must be not less than 1");
1921   }
1922   return success();
1923 }
1924 
1925 LogicalResult CompressOp::verify() {
1926   const auto stt = getSparseTensorType(getTensor());
1927   if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1928     return emitOpError("incorrect number of coordinates");
1929   return success();
1930 }
1931 
1932 void ForeachOp::build(
1933     OpBuilder &builder, OperationState &result, Value tensor,
1934     ValueRange initArgs, AffineMapAttr order,
1935     function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1936         bodyBuilder) {
1937   build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1938   // Builds foreach body.
1939   if (!bodyBuilder)
1940     return;
1941   const auto stt = getSparseTensorType(tensor);
1942   const Dimension dimRank = stt.getDimRank();
1943 
1944   // Starts with `dimRank`-many coordinates.
1945   SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1946   // Followed by one value.
1947   blockArgTypes.push_back(stt.getElementType());
1948   // Followed by the reduction variables.
1949   blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1950 
1951   SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1952 
1953   OpBuilder::InsertionGuard guard(builder);
1954   auto &region = *result.regions.front();
1955   Block *bodyBlock =
1956       builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1957   bodyBuilder(builder, result.location,
1958               bodyBlock->getArguments().slice(0, dimRank),
1959               bodyBlock->getArguments()[dimRank],
1960               bodyBlock->getArguments().drop_front(dimRank + 1));
1961 }
1962 
1963 LogicalResult ForeachOp::verify() {
1964   const auto t = getSparseTensorType(getTensor());
1965   const Dimension dimRank = t.getDimRank();
1966   const auto args = getBody()->getArguments();
1967 
1968   if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1969     return emitError("Level traverse order does not match tensor's level rank");
1970 
1971   if (dimRank + 1 + getInitArgs().size() != args.size())
1972     return emitError("Unmatched number of arguments in the block");
1973 
1974   if (getNumResults() != getInitArgs().size())
1975     return emitError("Mismatch in number of init arguments and results");
1976 
1977   if (getResultTypes() != getInitArgs().getTypes())
1978     return emitError("Mismatch in types of init arguments and results");
1979 
1980   // Cannot mark this const, because the getters aren't.
1981   auto yield = cast<YieldOp>(getBody()->getTerminator());
1982   if (yield.getNumOperands() != getNumResults() ||
1983       yield.getOperands().getTypes() != getResultTypes())
1984     return emitError("Mismatch in types of yield values and results");
1985 
1986   const auto iTp = IndexType::get(getContext());
1987   for (Dimension d = 0; d < dimRank; d++)
1988     if (args[d].getType() != iTp)
1989       return emitError(
1990           llvm::formatv("Expecting Index type for argument at index {0}", d));
1991 
1992   const auto elemTp = t.getElementType();
1993   const auto valueTp = args[dimRank].getType();
1994   if (elemTp != valueTp)
1995     return emitError(
1996         llvm::formatv("Unmatched element type between input tensor and "
1997                       "block argument, expected:{0}, got: {1}", elemTp, valueTp));
1998   return success();
1999 }
2000 
2001 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2002   if (getSparseTensorEncoding(getInputCoo().getType()) ==
2003       getSparseTensorEncoding(getResultCoo().getType()))
2004     return getInputCoo();
2005 
2006   return {};
2007 }
2008 
2009 LogicalResult ReorderCOOOp::verify() {
2010   SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2011   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2012 
2013   if (!srcStt.isCOOType() || !dstStt.isCOOType())
2014     return emitError("Expected COO sparse tensors only");
2015 
2016   if (!srcStt.hasSameDimToLvl(dstStt))
2017     return emitError("Unmatched dim2lvl map between input and result COO");
2018 
2019   if (srcStt.getPosType() != dstStt.getPosType() ||
2020       srcStt.getCrdType() != dstStt.getCrdType() ||
2021       srcStt.getElementType() != dstStt.getElementType())
2022     return emitError("Unmatched storage format between input and result COO");
2023 
2024   return success();
2025 }
2026 
2027 LogicalResult ReduceOp::verify() {
2028   Type inputType = getX().getType();
2029   Region &formula = getRegion();
2030   return verifyNumBlockArgs(this, formula, "reduce",
2031                             TypeRange{inputType, inputType}, inputType);
2032 }
2033 
2034 LogicalResult SelectOp::verify() {
2035   Builder b(getContext());
2036   Type inputType = getX().getType();
2037   Type boolType = b.getI1Type();
2038   Region &formula = getRegion();
2039   return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2040                             boolType);
2041 }
2042 
2043 LogicalResult SortOp::verify() {
2044   AffineMap xPerm = getPermMap();
2045   uint64_t nx = xPerm.getNumDims();
2046   if (nx < 1)
2047     return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
2048 
2049   if (!xPerm.isPermutation())
2050     return emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
2051 
2052   // We can't check the size of the buffers when n or buffer dimensions aren't
2053   // compile-time constants.
2054   std::optional<int64_t> cn = getConstantIntValue(getN());
2055   if (!cn)
2056     return success();
2057 
2058   // Verify dimensions.
2059   const auto checkDim = [&](Value v, Size minSize, const char *message) {
2060     const Size sh = getMemRefType(v).getShape()[0];
2061     if (!ShapedType::isDynamic(sh) && sh < minSize)
2062       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2063   };
2064   uint64_t n = cn.value();
2065   uint64_t ny = 0;
2066   if (auto nyAttr = getNyAttr())
2067     ny = nyAttr.getInt();
2068   checkDim(getXy(), n * (nx + ny),
2069            "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
2070   for (Value opnd : getYs())
2071     checkDim(opnd, n, "Expected dimension(y) >= n");
2072 
2073   return success();
2074 }
2075 
2076 //===----------------------------------------------------------------------===//
2077 // Sparse Tensor Iteration Operations.
2078 //===----------------------------------------------------------------------===//
2079 
2080 IterSpaceType IteratorType::getIterSpaceType() const {
2081   return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2082                             getHiLvl());
2083 }
2084 
2085 IteratorType IterSpaceType::getIteratorType() const {
2086   return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2087 }
2088 
2089 /// Parses a level range in the form "$lo `to` $hi"
2090 /// or simply "$lo" if $hi - $lo = 1
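/// For example (illustrative): "2 to 4" parses as lvlLo = 2, lvlHi = 4,
/// while a plain "2" parses as lvlLo = 2, lvlHi = 3.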
2091 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2092                                    Level &lvlHi) {
2093   if (parser.parseInteger(lvlLo))
2094     return failure();
2095 
2096   if (succeeded(parser.parseOptionalKeyword("to"))) {
2097     if (parser.parseInteger(lvlHi))
2098       return failure();
2099   } else {
2100     lvlHi = lvlLo + 1;
2101   }
2102 
2103   if (lvlHi <= lvlLo)
2104     return parser.emitError(parser.getNameLoc(),
2105                             "expect larger level upper bound than lower bound");
2106 
2107   return success();
2108 }
2109 
2110 /// Parses a level range in the form "$lo `to` $hi"
2111 /// or simply "$lo" if $hi - $lo = 1
2112 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2113                                    IntegerAttr &lvlHiAttr) {
2114   Level lvlLo, lvlHi;
2115   if (parseLevelRange(parser, lvlLo, lvlHi))
2116     return failure();
2117 
2118   lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2119   lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2120   return success();
2121 }
2122 
2123 /// Prints a level range in the form "$lo `to` $hi"
2124 /// or simply "$lo" if $hi - $lo = 1
2125 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2126 
2127   if (lo + 1 == hi)
2128     p << lo;
2129   else
2130     p << lo << " to " << hi;
2131 }
2132 
2133 /// Prints a level range in the form "$lo `to` $hi"
2134 /// or simply "$lo" if $hi - $lo = 1
2135 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2136                             IntegerAttr lvlHi) {
2137   unsigned lo = lvlLo.getValue().getZExtValue();
2138   unsigned hi = lvlHi.getValue().getZExtValue();
2139   printLevelRange(p, lo, hi);
2140 }
2141 
2142 /// Parses a list of optionally defined values in the form of
2143 /// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
2144 /// corresponding value is not defined (e.g., to represent an undefined
2145 /// coordinate in the sparse iteration space).
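/// For example (illustrative): "(%v0, _, %v1)" defines values at positions
/// 0 and 2, so `definedSet` becomes 0b101 and `definedArgs` holds %v0, %v1.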
2146 static ParseResult parseOptionalDefinedList(
2147     OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
2148     SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
2149     unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2150     OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
2151   unsigned cnt = 0;
2152   ParseResult crdList =
2153       parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
2154         if (parser.parseOptionalKeyword("_")) {
2155           if (parser.parseArgument(definedArgs.emplace_back()))
2156             return failure();
2157           definedSet.set(cnt);
2158         }
2159         cnt += 1;
2160         return success();
2161       });
2162 
2163   if (cnt > maxCnt)
2164     return parser.emitError(parser.getNameLoc(),
2165                             "parsed more values than expected.");
2166 
2167   if (failed(crdList)) {
2168     return parser.emitError(
2169         parser.getNameLoc(),
2170         "expecting SSA value or \"_\" for level coordinates");
2171   }
2172   assert(definedArgs.size() == definedSet.count());
2173   return success();
2174 }
2175 
2176 static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2177                                      Block::BlockArgListType blocksArgs,
2178                                      I64BitSet definedSet) {
2179   if (definedSet.empty())
2180     return;
2181 
2182   for (unsigned i = 0; i < size; i++) {
2183     if (definedSet[i]) {
2184       p << blocksArgs.front();
2185       blocksArgs = blocksArgs.drop_front();
2186     } else {
2187       p << "_";
2188     }
2189     if (i != size - 1)
2190       p << ", ";
2191   }
2192   assert(blocksArgs.empty());
2193 }
2194 
2195 static ParseResult
2196 parseUsedCoordList(OpAsmParser &parser, OperationState &state,
2197                    SmallVectorImpl<OpAsmParser::Argument> &coords) {
2198   // Parse "at(%crd0, _, ...)"
2199   I64BitSet crdUsedLvlSet;
2200   if (succeeded(parser.parseOptionalKeyword("at")) &&
2201       failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
2202     return failure();
2203 
2204   // Always use IndexType for the coordinate.
2205   for (auto &coord : coords)
2206     coord.type = parser.getBuilder().getIndexType();
2207 
2208   // Set the CrdUsedLvl bitset.
2209   state.addAttribute("crdUsedLvls",
2210                      parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2211   return success();
2212 }
2213 
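// Parses the loop header shared by sparse_tensor.iterate. A rough sketch of
// the accepted textual form (assumed here for illustration; the op definition
// in SparseTensorOps.td is authoritative):
//   %it in %space at(%crd, _) iter_args(%arg = %init)
//       : !sparse_tensor.iter_space<...> -> index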
2214 static ParseResult
2215 parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
2216                        SmallVectorImpl<OpAsmParser::Argument> &iterators,
2217                        SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2218   SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2219   SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2220 
2221   // Parse "%iters, ... in %spaces, ..."
2222   if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
2223       parser.parseOperandList(spaces))
2224     return failure();
2225 
2226   if (iterators.size() != spaces.size())
2227     return parser.emitError(
2228         parser.getNameLoc(),
2229         "mismatch in number of sparse iterators and sparse spaces");
2230 
2231   if (failed(parseUsedCoordList(parser, state, blockArgs)))
2232     return failure();
2233   size_t numCrds = blockArgs.size();
2234 
2235   // Parse "iter_args(%arg = %init, ...)"
2236   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2237   if (hasIterArgs)
2238     if (parser.parseAssignmentList(blockArgs, initArgs))
2239       return failure();
2240 
2241   SmallVector<Type> iterSpaceTps;
2242   // parse ": sparse_tensor.iter_space -> ret"
2243   if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
2244     return failure();
2245   if (iterSpaceTps.size() != spaces.size())
2246     return parser.emitError(parser.getNameLoc(),
2247                             "mismatch in number of iteration space operands "
2248                             "and iteration space types");
2249 
2250   for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2251     IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2252     if (!spaceTp)
2253       return parser.emitError(parser.getNameLoc(),
2254                               "expected sparse_tensor.iter_space type for "
2255                               "iteration space operands");
2256     it.type = spaceTp.getIteratorType();
2257   }
2258 
2259   if (hasIterArgs)
2260     if (parser.parseArrowTypeList(state.types))
2261       return failure();
2262 
2263   // Resolves input operands.
2264   if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2265                              state.operands))
2266     return failure();
2267 
2268   if (hasIterArgs) {
2269     // Strip off the leading args that are used for coordinates.
2270     MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
2271     if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2272       return parser.emitError(
2273           parser.getNameLoc(),
2274           "mismatch in number of iteration arguments and return values");
2275     }
2276 
2277     for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2278       it.type = tp;
2279       if (parser.resolveOperand(init, tp, state.operands))
2280         return failure();
2281     }
2282   }
2283   return success();
2284 }
2285 
2286 static ParseResult
2287 parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
2288                          SmallVectorImpl<Value> &spacesVals,
2289                          SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2290 
2291   // Parse "(%spaces, ...)"
2292   SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2293   if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
2294     return failure();
2295 
2296   SmallVector<OpAsmParser::Argument> coords;
2297   if (failed(parseUsedCoordList(parser, state, coords)))
2298     return failure();
2299   size_t numCrds = coords.size();
2300 
2301   // Parse "iter_args(%arg = %init, ...)"
2302   SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2303   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2304   if (hasIterArgs)
2305     if (parser.parseAssignmentList(blockArgs, initArgs))
2306       return failure();
2307   blockArgs.append(coords);
2308 
2309   SmallVector<Type> iterSpaceTps;
2310   // parse ": (sparse_tensor.iter_space, ...) -> ret"
2311   if (parser.parseColon() || parser.parseLParen() ||
2312       parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
2313     return failure();
2314 
2315   if (iterSpaceTps.size() != spaces.size())
2316     return parser.emitError(parser.getNameLoc(),
2317                             "mismatch in number of iteration space operands "
2318                             "and iteration space types");
2319 
2320   if (hasIterArgs)
2321     if (parser.parseArrowTypeList(state.types))
2322       return failure();
2323 
2324   // Resolves input sparse iteration spaces.
2325   if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2326                              spacesVals))
2327     return failure();
2328   state.operands.append(spacesVals);
2329 
2330   if (hasIterArgs) {
2331     // Strip off the trailing args that are used for coordinates.
2332     MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2333     if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2334       return parser.emitError(
2335           parser.getNameLoc(),
2336           "mismatch in number of iteration arguments and return values");
2337     }
2338 
2339     for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2340       it.type = tp;
2341       if (parser.resolveOperand(init, tp, state.operands))
2342         return failure();
2343     }
2344   }
2345   return success();
2346 }
2347 
2348 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2349     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2350     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2351     SmallVectorImpl<mlir::Type> &ret) {
2352 
2353   ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2354   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2355   ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2356                                    adaptor.getHiLvl()));
2357   return success();
2358 }
2359 
2360 LogicalResult ExtractIterSpaceOp::verify() {
2361   if (getLoLvl() >= getHiLvl())
2362     return emitOpError("expected smaller level low than level high");
2363 
2364   TypedValue<IteratorType> pIter = getParentIter();
2365   if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2366     return emitOpError(
2367         "parent iterator should be specified iff level lower bound equals 0");
2368   }
2369 
2370   if (pIter) {
2371     IterSpaceType spaceTp = getExtractedSpace().getType();
2372     if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2373       return emitOpError(
2374           "mismatch in parent iterator encoding and iteration space encoding.");
2375 
2376     if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2377       return emitOpError("parent iterator should be used to extract an "
2378                          "iteration space from a consecutive level.");
2379   }
2380 
2381   return success();
2382 }
2383 
2384 LogicalResult ExtractValOp::verify() {
2385   auto stt = getSparseTensorType(getTensor());
2386   auto itTp = getIterator().getType();
2387 
2388   if (stt.getEncoding() != itTp.getEncoding())
2389     return emitOpError("mismatch in tensor encoding and iterator encoding.");
2390 
2391   if (stt.getLvlRank() != itTp.getHiLvl())
2392     return emitOpError("must use last-level iterator to extract values. ");
2393 
2394   return success();
2395 }
2396 
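/// Canonicalization pattern that drops the block arguments of level
/// coordinates that have no users and clears the corresponding bits in
/// crdUsedLvls, so the loop only materializes coordinates it actually needs.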
2397 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2398   using OpRewritePattern::OpRewritePattern;
2399 
2400   LogicalResult matchAndRewrite(IterateOp iterateOp,
2401                                 PatternRewriter &rewriter) const override {
2402     I64BitSet newUsedLvls(0);
2403     llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2404     for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2405       if (auto crd = iterateOp.getLvlCrd(i)) {
2406         if (crd->getUsers().empty())
2407           toRemove.set(crd->getArgNumber());
2408         else
2409           newUsedLvls.set(i);
2410       }
2411     }
2412 
2413     // All coordinates are used.
2414     if (toRemove.none())
2415       return failure();
2416 
2417     rewriter.startOpModification(iterateOp);
2418     iterateOp.setCrdUsedLvls(newUsedLvls);
2419     iterateOp.getBody()->eraseArguments(toRemove);
2420     rewriter.finalizeOpModification(iterateOp);
2421     return success();
2422   }
2423 };
2424 
2425 void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2426                                             mlir::MLIRContext *context) {
2427   results.add<RemoveUnusedLvlCrds>(context);
2428 }
2429 
2430 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2431                       Value iterSpace, ValueRange initArgs) {
2432   unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2433   // All ones.
2434   I64BitSet set((1 << rank) - 1);
2435   return build(builder, odsState, iterSpace, initArgs, set);
2436 }
2437 
2438 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2439                       Value iterSpace, ValueRange initArgs,
2440                       I64BitSet crdUsedLvls) {
2441   OpBuilder::InsertionGuard guard(builder);
2442 
2443   odsState.addOperands(iterSpace);
2444   odsState.addOperands(initArgs);
2445   odsState.getOrAddProperties<Properties>().crdUsedLvls =
2446       builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
2447   Region *bodyRegion = odsState.addRegion();
2448   odsState.addTypes(initArgs.getTypes());
2449   Block *bodyBlock = builder.createBlock(bodyRegion);
2450 
2451   // First argument, sparse iterator
2452   bodyBlock->addArgument(
2453       llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
2454       odsState.location);
2455 
2456   // Followed by a list of used coordinates.
2457   for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
2458     bodyBlock->addArgument(builder.getIndexType(), odsState.location);
2459 
2460   // Followed by a list of user-provided loop arguments.
2461   for (Value v : initArgs)
2462     bodyBlock->addArgument(v.getType(), v.getLoc());
2463 }
2464 
2465 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2466   OpAsmParser::Argument iterator;
2467   OpAsmParser::UnresolvedOperand iterSpace;
2468 
2469   SmallVector<OpAsmParser::Argument> iters, iterArgs;
2470   if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2471     return failure();
2472   if (iters.size() != 1)
2473     return parser.emitError(parser.getNameLoc(),
2474                             "expected only one iterator/iteration space");
2475 
2476   iters.append(iterArgs);
2477   Region *body = result.addRegion();
2478   if (parser.parseRegion(*body, iters))
2479     return failure();
2480 
2481   IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2482 
2483   // Parse the optional attribute list.
2484   if (parser.parseOptionalAttrDict(result.attributes))
2485     return failure();
2486 
2487   return success();
2488 }
2489 
2490 /// Prints the initialization list in the form of
2491 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2492 /// where 'inner' values are assumed to be region arguments and 'outer' values
2493 /// are regular SSA values.
2494 static void printInitializationList(OpAsmPrinter &p,
2495                                     Block::BlockArgListType blocksArgs,
2496                                     ValueRange initializers,
2497                                     StringRef prefix = "") {
2498   assert(blocksArgs.size() == initializers.size() &&
2499          "expected same length of arguments and initializers");
2500   if (initializers.empty())
2501     return;
2502 
2503   p << prefix << '(';
2504   llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2505     p << std::get<0>(it) << " = " << std::get<1>(it);
2506   });
2507   p << ")";
2508 }
2509 
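// Shared verification for IterateOp and CoIterateOp: the number of
// loop-carried values must match the number of results, and the coordinates
// used must stay within the iteration space dimension.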
2510 template <typename SparseLoopOp>
2511 static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2512   if (op.getInitArgs().size() != op.getNumResults()) {
2513     return op.emitOpError(
2514         "mismatch in number of loop-carried values and defined values");
2515   }
2516   if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2517     return op.emitOpError("required out-of-bound coordinates");
2518 
2519   return success();
2520 }
2521 
2522 LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
2523 LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2524 
2525 void IterateOp::print(OpAsmPrinter &p) {
2526   p << " " << getIterator() << " in " << getIterSpace();
2527   if (!getCrdUsedLvls().empty()) {
2528     p << " at(";
2529     printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
2530     p << ")";
2531   }
2532   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
2533 
2534   p << " : " << getIterSpace().getType() << " ";
2535   if (!getInitArgs().empty())
2536     p.printArrowTypeList(getInitArgs().getTypes());
2537 
2538   p << " ";
2539   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2540                 /*printBlockTerminators=*/!getInitArgs().empty());
2541 }
2542 
2543 LogicalResult IterateOp::verifyRegions() {
2544   if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2545     return emitOpError("mismatch in iterator and iteration space type");
2546   if (getNumRegionIterArgs() != getNumResults())
2547     return emitOpError(
2548         "mismatch in number of basic block args and defined values");
2549 
2550   auto initArgs = getInitArgs();
2551   auto iterArgs = getRegionIterArgs();
2552   auto yieldVals = getYieldedValues();
2553   auto opResults = getResults();
2554   if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2555                         opResults.size()})) {
2556     return emitOpError() << "number mismatch between iter args and results.";
2557   }
2558 
2559   for (auto [i, init, iter, yield, ret] :
2560        llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2561     if (init.getType() != ret.getType())
2562       return emitOpError() << "types mismatch between " << i
2563                            << "th iter operand and defined value";
2564     if (iter.getType() != ret.getType())
2565       return emitOpError() << "types mismatch between " << i
2566                            << "th iter region arg and defined value";
2567     if (yield.getType() != ret.getType())
2568       return emitOpError() << "types mismatch between " << i
2569                            << "th yield value and defined value";
2570   }
2571 
2572   return success();
2573 }
2574 
2575 /// OpInterfaces' methods implemented by IterateOp.
2576 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2577 
2578 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2579   return getInitArgsMutable();
2580 }
2581 
2582 Block::BlockArgListType IterateOp::getRegionIterArgs() {
2583   return getRegion().getArguments().take_back(getNumRegionIterArgs());
2584 }
2585 
2586 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2587   return cast<sparse_tensor::YieldOp>(
2588              getRegion().getBlocks().front().getTerminator())
2589       .getResultsMutable();
2590 }
2591 
2592 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2593 
2594 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2595   return getInitArgs();
2596 }
2597 
2598 void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2599                                     SmallVectorImpl<RegionSuccessor> &regions) {
2600   // Both the operation itself and the region may be branching into the body
2601   // or back into the operation itself.
2602   regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2603   // It is possible for the loop not to enter the body.
2604   regions.push_back(RegionSuccessor(getResults()));
2605 }
2606 
2607 void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2608                         ValueRange iterSpaces, ValueRange initArgs,
2609                         unsigned numCases) {
2610   unsigned rank =
2611       cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2612   // All ones.
2613   I64BitSet set((1 << rank) - 1);
2614   // Generates all-zero case bits (they only serve as placeholders), which are
2615   // supposed to be overridden later. We need to preallocate all the regions, as
2616   // an mlir::Region cannot be added dynamically after the operation has been
2617   // created.
2618   SmallVector<int64_t> caseBits(numCases, 0);
2619   ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2620   return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2621                             initArgs, set, cases,
2622                             /*caseRegionsCount=*/numCases);
2623 }
2624 
2625 ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
2626 
2627   SmallVector<Value> spaces;
2628   // The block argument list of each region, arranged in the order of
2629   // ([used coordinate list], [loop iteration args], [sparse iterator list]).
2630   SmallVector<OpAsmParser::Argument> blockArgs;
2631   if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
2632     return failure();
2633 
2634   result.addAttribute("operandSegmentSizes",
2635                       parser.getBuilder().getDenseI32ArrayAttr(
2636                           {static_cast<int32_t>(spaces.size()),
2637                            static_cast<int32_t>(result.types.size())}));
2638 
2639   SmallVector<Attribute> cases;
2640   while (succeeded(parser.parseOptionalKeyword("case"))) {
2641     // Parse one region per case.
2642     I64BitSet definedItSet;
2643     SmallVector<OpAsmParser::Argument> definedIts;
2644     if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
2645                                  spaces.size(), OpAsmParser::Delimiter::None))
2646       return failure();
2647 
2648     cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
2649 
2650     for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
2651       // Resolve the iterator type based on the iteration space type.
2652       auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
2653       definedIts[i].type = spaceTp.getIteratorType();
2654     }
2655     definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2656     Region *body = result.addRegion();
2657     if (parser.parseRegion(*body, definedIts))
2658       return failure();
2659 
2660     CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2661   }
2662 
2663   result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
2664 
2665   // Parse the optional attribute list.
2666   if (parser.parseOptionalAttrDict(result.attributes))
2667     return failure();
2668 
2669   return success();
2670 }
2671 
2672 void CoIterateOp::print(OpAsmPrinter &p) {
2673   p << " (";
2674   llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2675   p << ")";
2676 
2677   if (!getCrdUsedLvls().empty()) {
2678     p << " at(";
2679     printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
2680     p << ")";
2681   }
2682 
2683   printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
2684 
2685   p << " : (" << getIterSpaces().getTypes() << ")";
2686   if (!getInitArgs().empty())
2687     p.printArrowTypeList(getInitArgs().getTypes());
2688 
2689   for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2690     p.printNewline();
2691     p << "case ";
2692     printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
2693                              getRegionDefinedSpace(idx));
2694     p << " ";
2695     p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
2696                   /*printBlockTerminators=*/!getInitArgs().empty());
2697   }
2698 }
2699 
2700 ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2701   return cast<sparse_tensor::YieldOp>(
2702              getRegion(regionIdx).getBlocks().front().getTerminator())
2703       .getResults();
2704 }
2705 
2706 LogicalResult CoIterateOp::verifyRegions() {
2707   for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2708     if (getNumRegionIterArgs() != getNumResults())
2709       return emitOpError(
2710           "mismatch in number of basic block args and defined values");
2711 
2712     auto initArgs = getInitArgs();
2713     auto iterArgs = getRegionIterArgs(r);
2714     auto yieldVals = getYieldedValues(r);
2715     auto opResults = getResults();
2716     if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2717                           opResults.size()})) {
2718       return emitOpError()
2719              << "number mismatch between iter args and results on " << r
2720              << "th region";
2721     }
2722 
2723     for (auto [i, init, iter, yield, ret] :
2724          llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2725       if (init.getType() != ret.getType())
2726         return emitOpError()
2727                << "types mismatch between " << i
2728                << "th iter operand and defined value on " << r << "th region";
2729       if (iter.getType() != ret.getType())
2730         return emitOpError() << "types mismatch between " << i
2731                              << "th iter region arg and defined value on " << r
2732                              << "th region";
2733       if (yield.getType() != ret.getType())
2734         return emitOpError()
2735                << "types mismatch between " << i
2736                << "th yield value and defined value on " << r << "th region";
2737     }
2738   }
2739 
2740   auto cases = getRegionDefinedSpaces();
2741   llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2742   if (set.size() != getNumRegions())
2743     return emitOpError("contains duplicated cases.");
2744 
2745   return success();
2746 }
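// Note on the verifier above: per region it requires that the init args,
// region iter args, yielded values, and op results agree in count and in
// element-wise type, and across regions it rejects two `case`s that declare
// the same subset of defined spaces (e.g., two identical `case %it1, _`
// regions, to use a hypothetical name).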
2747 
2748 //===----------------------------------------------------------------------===//
2749 // Sparse Tensor Dialect Setups.
2750 //===----------------------------------------------------------------------===//
2751 
2752 /// Materialize a single constant operation from a given attribute value with
2753 /// the desired resultant type.
2754 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2755                                                     Attribute value, Type type,
2756                                                     Location loc) {
2757   if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2758     return op;
2759   return nullptr;
2760 }
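// For example (a sketch of the generic folding flow, not specific to this
// dialect): when an op from this dialect folds to an attribute, the folder
// calls this hook and the value is materialized as an `arith.constant` of the
// requested type, provided arith accepts that attribute/type pair.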
2761 
2762 namespace {
2763 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
2764   using OpAsmDialectInterface::OpAsmDialectInterface;
2765 
2766   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
2767     if (isa<SparseTensorEncodingAttr>(attr)) {
2768       os << "sparse";
2769       return AliasResult::OverridableAlias;
2770     }
2771     return AliasResult::NoAlias;
2772   }
2773 };
2774 } // namespace
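// Illustrative effect of the alias interface above (a sketch; the printer
// picks the final alias name and may append a suffix such as #sparse1):
//
//   #sparse = #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed)
//   }>
//   ...
//   func.func @foo(%arg0: tensor<?x?xf64, #sparse>) { ... }
//
// rather than repeating the full encoding at every use of the type.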
2775 
2776 void SparseTensorDialect::initialize() {
2777   addInterface<SparseTensorAsmDialectInterface>();
2778   addAttributes<
2779 #define GET_ATTRDEF_LIST
2780 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2781       >();
2782   addTypes<
2783 #define GET_TYPEDEF_LIST
2784 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2785       >();
2786   addOperations<
2787 #define GET_OP_LIST
2788 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2789       >();
2790   declarePromisedInterfaces<
2791       bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2792       NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2793       ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2794 }
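// A minimal usage sketch (assumes the usual MLIR registration flow; the names
// below are standard MLIR APIs, nothing sparse-tensor specific):
//
//   DialectRegistry registry;
//   registry.insert<sparse_tensor::SparseTensorDialect>();
//   MLIRContext ctx(registry);
//   ctx.getOrLoadDialect<sparse_tensor::SparseTensorDialect>();  // runs initialize()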
2795 
2796 #define GET_OP_CLASSES
2797 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2798 
2799 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
2800