1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "Detail/DimLvlMapParser.h"
12 
13 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
17 
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20 #include "mlir/Dialect/Complex/IR/Complex.h"
21 #include "mlir/Dialect/Utils/StaticValueUtils.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/OpImplementation.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 
30 #define GET_ATTRDEF_CLASSES
31 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
32 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
33 
34 // Forward declarations: the following custom print/parsing methods are
35 // referenced by the generated code for SparseTensorTypes.td.
36 static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
37                                          mlir::sparse_tensor::Level &,
38                                          mlir::sparse_tensor::Level &);
39 static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
40                             mlir::sparse_tensor::Level);
41 
42 #define GET_TYPEDEF_CLASSES
43 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
44 
45 using namespace mlir;
46 using namespace mlir::sparse_tensor;
47 
48 // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
49 // well.
50 namespace mlir::sparse_tensor {
51 llvm::hash_code hash_value(LevelType lt) {
52   return llvm::hash_value(static_cast<uint64_t>(lt));
53 }
54 } // namespace mlir::sparse_tensor
55 
56 //===----------------------------------------------------------------------===//
57 // Local Convenience Methods.
58 //===----------------------------------------------------------------------===//
59 
60 static constexpr bool acceptBitWidth(unsigned bitWidth) {
61   switch (bitWidth) {
62   case 0:
63   case 8:
64   case 16:
65   case 32:
66   case 64:
67     return true;
68   default:
69     return false;
70   }
71 }
72 
73 static SmallVector<Size>
74 getSparseFieldShape(const SparseTensorEncodingAttr enc,
75                     std::optional<ArrayRef<int64_t>> dimShape) {
76   assert(enc);
77   // With only the encoding, we cannot determine the static shape for leading
78   // batch levels; we therefore return a dynamically shaped memref instead.
79   SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
80   if (dimShape.has_value()) {
81     // If the actual tensor shape is provided, we can then refine the leading
82     // batch dimension.
83     SmallVector<int64_t> lvlShape =
84         enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
85     memrefShape.assign(lvlShape.begin(),
86                        lvlShape.begin() + enc.getBatchLvlRank());
87   }
88   // Another dynamic dimension to store the sparse level.
89   memrefShape.push_back(ShapedType::kDynamic);
90   return memrefShape;
91 }
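// Illustrative example (hypothetical encoding, not taken from this file): for
// an encoding with one leading batch level and dimShape = [4, 8], the result
// is {4, ShapedType::kDynamic}, i.e. a memref<4 x ? x ...> field shape; with
// no dimShape provided it would be {kDynamic, kDynamic}.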
92 
93 //===----------------------------------------------------------------------===//
94 // SparseTensorDialect StorageLayout.
95 //===----------------------------------------------------------------------===//
96 
97 static constexpr Level kInvalidLevel = -1u;
98 static constexpr Level kInvalidFieldIndex = -1u;
99 static constexpr FieldIndex kDataFieldStartingIdx = 0;
100 
101 void StorageLayout::foreachField(
102     llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
103                             LevelType)>
104         callback) const {
105   const auto lvlTypes = enc.getLvlTypes();
106   const Level lvlRank = enc.getLvlRank();
107   SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
108   FieldIndex fieldIdx = kDataFieldStartingIdx;
109 
110   ArrayRef cooSegsRef = cooSegs;
111   // Per-level storage.
112   for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
113     const auto lt = lvlTypes[l];
114     if (isWithPosLT(lt)) {
115       if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
116         return;
117     }
118     if (isWithCrdLT(lt)) {
119       if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
120         return;
121     }
122     if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
123       if (!cooSegsRef.front().isSoA) {
124         // AoS COO, all singletons are fused into one memref. Skip the entire
125         // COO segment.
126         l = cooSegsRef.front().lvlRange.second;
127       } else {
128         // SoA COO, each singleton level has one memref.
129         l++;
130       }
131       // Expire handled COO segment.
132       cooSegsRef = cooSegsRef.drop_front();
133     } else {
134       // Non-COO levels.
135       l++;
136     }
137   }
138   // The values array.
139   if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
140                  LevelFormat::Undef)))
141     return;
142   // Put metadata at the end.
143   if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
144                  LevelFormat::Undef)))
145     return;
146 }
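// Illustrative walk-through (assuming a hypothetical CSR-style encoding with
// lvlTypes = [dense, compressed]): the callback is invoked with
// (0, PosMemRef, lvl 1), (1, CrdMemRef, lvl 1), (2, ValMemRef, kInvalidLevel),
// and finally (3, StorageSpec, kInvalidLevel).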
147 
148 void sparse_tensor::foreachFieldAndTypeInSparseTensor(
149     SparseTensorType stt,
150     llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
151                             LevelType)>
152         callback) {
153   assert(stt.hasEncoding());
154 
155   SmallVector<int64_t> memrefShape =
156       getSparseFieldShape(stt.getEncoding(), stt.getDimShape());
157 
158   const Type specType = StorageSpecifierType::get(stt.getEncoding());
159   // memref<[batch] x ? x pos>  positions
160   const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
161   // memref<[batch] x ? x crd>  coordinates
162   const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
163   // memref<[batch] x ? x eltType> values
164   const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
165 
166   StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
167                                    callback](FieldIndex fieldIdx,
168                                              SparseTensorFieldKind fieldKind,
169                                              Level lvl, LevelType lt) -> bool {
170     switch (fieldKind) {
171     case SparseTensorFieldKind::StorageSpec:
172       return callback(specType, fieldIdx, fieldKind, lvl, lt);
173     case SparseTensorFieldKind::PosMemRef:
174       return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
175     case SparseTensorFieldKind::CrdMemRef:
176       return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
177     case SparseTensorFieldKind::ValMemRef:
178       return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
179     }
180     llvm_unreachable("unrecognized field kind");
181   });
182 }
183 
184 unsigned StorageLayout::getNumFields() const {
185   unsigned numFields = 0;
186   foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
187                             LevelType) -> bool {
188     numFields++;
189     return true;
190   });
191   return numFields;
192 }
193 
194 unsigned StorageLayout::getNumDataFields() const {
195   unsigned numFields = 0; // one value memref
196   foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
197                             LevelType) -> bool {
198     if (fidx >= kDataFieldStartingIdx)
199       numFields++;
200     return true;
201   });
202   numFields -= 1; // the last field is StorageSpecifier
203   assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
204   return numFields;
205 }
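// For the hypothetical CSR-style layout sketched above, getNumFields() returns
// 4 (positions, coordinates, values, specifier) and getNumDataFields() returns
// 3, since the trailing storage specifier is not a data field.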
206 
207 std::pair<FieldIndex, unsigned>
208 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
209                                       std::optional<Level> lvl) const {
210   FieldIndex fieldIdx = kInvalidFieldIndex;
211   unsigned stride = 1;
212   if (kind == SparseTensorFieldKind::CrdMemRef) {
213     assert(lvl.has_value());
214     const Level cooStart = SparseTensorType(enc).getAoSCOOStart();
215     const Level lvlRank = enc.getLvlRank();
216     if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
217       lvl = cooStart;
218       stride = lvlRank - cooStart;
219     }
220   }
221   foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
222                                       SparseTensorFieldKind fKind, Level fLvl,
223                                       LevelType lt) -> bool {
224     if ((lvl && fLvl == lvl.value() && kind == fKind) ||
225         (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
226       fieldIdx = fIdx;
227       // Returns false to break the iteration.
228       return false;
229     }
230     return true;
231   });
232   assert(fieldIdx != kInvalidFieldIndex);
233   return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
234 }
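// Illustrative example (hypothetical 2-D AoS COO encoding with lvlTypes =
// [compressed(nonunique), singleton]): a CrdMemRef query for level 1 is
// redirected to the shared coordinate memref of the COO segment starting at
// level 0, so the returned stride is 2 (= lvlRank - cooStart).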
235 
236 //===----------------------------------------------------------------------===//
237 // SparseTensorDialect Attribute Methods.
238 //===----------------------------------------------------------------------===//
239 
240 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
241   return isDynamic(v) ? std::nullopt
242                       : std::make_optional(static_cast<uint64_t>(v));
243 }
244 
245 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
246   return getStatic(getOffset());
247 }
248 
249 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
250   return getStatic(getStride());
251 }
252 
253 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
254   return getStatic(getSize());
255 }
256 
257 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
258   return isDynamic(getOffset()) && isDynamic(getStride()) &&
259          isDynamic(getSize());
260 }
261 
262 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
263   return isDynamic(v) ? "?" : std::to_string(v);
264 }
265 
266 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
267   assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
268   os << '(';
269   os << getStaticString(getOffset());
270   os << ", ";
271   os << getStaticString(getSize());
272   os << ", ";
273   os << getStaticString(getStride());
274   os << ')';
275 }
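// For example, a slice with static offset 0, dynamic size, and stride 2 prints
// as "(0, ?, 2)".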
276 
277 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
278   print(printer.getStream());
279 }
280 
281 static ParseResult parseOptionalStaticSlice(int64_t &result,
282                                             AsmParser &parser) {
283   auto parseResult = parser.parseOptionalInteger(result);
284   if (parseResult.has_value()) {
285     if (parseResult.value().succeeded() && result < 0) {
286       parser.emitError(
287           parser.getCurrentLocation(),
288           "expect non-negative value or ? for slice offset/size/stride");
289       return failure();
290     }
291     return parseResult.value();
292   }
293 
294   // Otherwise, a '?' represents a dynamic slice.
295   result = SparseTensorDimSliceAttr::kDynamic;
296   return parser.parseQuestion();
297 }
298 
299 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
300   int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
301 
302   if (failed(parser.parseLParen()) ||
303       failed(parseOptionalStaticSlice(offset, parser)) ||
304       failed(parser.parseComma()) ||
305       failed(parseOptionalStaticSlice(size, parser)) ||
306       failed(parser.parseComma()) ||
307       failed(parseOptionalStaticSlice(stride, parser)) ||
308       failed(parser.parseRParen()))
309     return {};
310 
311   return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
312                                                      offset, size, stride);
313 }
314 
315 LogicalResult
316 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
317                                  int64_t offset, int64_t size, int64_t stride) {
318   if (!isDynamic(offset) && offset < 0)
319     return emitError() << "expect non-negative value or ? for slice offset";
320   if (!isDynamic(size) && size <= 0)
321     return emitError() << "expect positive value or ? for slice size";
322   if (!isDynamic(stride) && stride <= 0)
323     return emitError() << "expect positive value or ? for slice stride";
324   return success();
325 }
326 
327 SparseTensorEncodingAttr
328 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
329   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
330   return SparseTensorEncodingAttr::get(
331       getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
332       getCrdWidth(), getExplicitVal(), getImplicitVal());
333 }
334 
335 SparseTensorEncodingAttr
336 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
337   return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
338 }
339 
340 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
341   return withDimToLvl(AffineMap());
342 }
343 
344 SparseTensorEncodingAttr
345 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
346                                         unsigned crdWidth) const {
347   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
348   return SparseTensorEncodingAttr::get(
349       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
350       crdWidth, getExplicitVal(), getImplicitVal());
351 }
352 
353 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
354   return withBitWidths(0, 0);
355 }
356 
357 SparseTensorEncodingAttr
358 SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
359   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
360   return SparseTensorEncodingAttr::get(
361       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
362       getCrdWidth(), explicitVal, getImplicitVal());
363 }
364 
365 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
366   return withExplicitVal(Attribute());
367 }
368 
369 SparseTensorEncodingAttr
370 SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
371   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
372   return SparseTensorEncodingAttr::get(
373       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
374       getCrdWidth(), getExplicitVal(), implicitVal);
375 }
376 
377 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
378   return withImplicitVal(Attribute());
379 }
380 
381 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
382     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
383   return SparseTensorEncodingAttr::get(
384       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
385       getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
386 }
387 
388 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
389   return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
390 }
391 
392 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
393   ArrayRef<LevelType> lvlTypes = getLvlTypes();
394   auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
395   return std::distance(lastBatch, lvlTypes.rend());
396 }
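// For example, lvlTypes = [batch, batch, dense, compressed] yields a batch
// level-rank of 2; with no leading batch levels the result is 0.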
397 
398 bool SparseTensorEncodingAttr::isAllDense() const {
399   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
400 }
401 
402 bool SparseTensorEncodingAttr::isAllOrdered() const {
403   return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
404 }
405 
406 Type SparseTensorEncodingAttr::getCrdElemType() const {
407   if (!getImpl())
408     return nullptr;
409   if (getCrdWidth())
410     return IntegerType::get(getContext(), getCrdWidth());
411   return IndexType::get(getContext());
412 }
413 
414 Type SparseTensorEncodingAttr::getPosElemType() const {
415   if (!getImpl())
416     return nullptr;
417   if (getPosWidth())
418     return IntegerType::get(getContext(), getPosWidth());
419   return IndexType::get(getContext());
420 }
421 
422 MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
423     std::optional<ArrayRef<int64_t>> dimShape) const {
424   SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
425   return MemRefType::get(shape, getCrdElemType());
426 }
427 
428 MemRefType SparseTensorEncodingAttr::getPosMemRefType(
429     std::optional<ArrayRef<int64_t>> dimShape) const {
430   SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
431   return MemRefType::get(shape, getPosElemType());
432 }
433 
434 bool SparseTensorEncodingAttr::isIdentity() const {
435   return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
436 }
437 
438 bool SparseTensorEncodingAttr::isPermutation() const {
439   return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
440 }
441 
442 Dimension SparseTensorEncodingAttr::getDimRank() const {
443   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
444   const auto dimToLvl = getDimToLvl();
445   return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
446 }
447 
448 Level SparseTensorEncodingAttr::getLvlRank() const {
449   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
450   return getLvlTypes().size();
451 }
452 
453 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
454   if (!getImpl())
455     return LevelFormat::Batch;
456   assert(l < getLvlRank() && "Level is out of bounds");
457   return getLvlTypes()[l];
458 }
459 
460 bool SparseTensorEncodingAttr::isSlice() const {
461   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
462   return !getDimSlices().empty();
463 }
464 
465 SparseTensorDimSliceAttr
466 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
467   assert(isSlice() && "Is not a slice");
468   const auto dimSlices = getDimSlices();
469   assert(dim < dimSlices.size() && "Dimension is out of bounds");
470   return dimSlices[dim];
471 }
472 
473 std::optional<uint64_t>
474 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
475   return getDimSlice(dim).getStaticOffset();
476 }
477 
478 std::optional<uint64_t>
479 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
480   return getDimSlice(dim).getStaticStride();
481 }
482 
483 std::optional<uint64_t>
484 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
485   return getStaticDimSliceOffset(toDim(*this, lvl));
486 }
487 
488 std::optional<uint64_t>
489 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
490   return getStaticDimSliceStride(toDim(*this, lvl));
491 }
492 
493 SmallVector<int64_t>
494 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
495                                          CrdTransDirectionKind dir) const {
496   if (isIdentity())
497     return SmallVector<int64_t>(srcShape);
498 
499   SmallVector<int64_t> ret;
500   unsigned rank =
501       dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
502   ret.reserve(rank);
503 
504   if (isPermutation()) {
505     for (unsigned r = 0; r < rank; r++) {
506       unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
507                                                              : toLvl(*this, r);
508       ret.push_back(srcShape[trans]);
509     }
510     return ret;
511   }
512 
513   // Handle non-permutation maps.
514   AffineMap transMap =
515       dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
516 
517   SmallVector<AffineExpr> dimRep;
518   dimRep.reserve(srcShape.size());
519   for (int64_t sz : srcShape) {
520     if (!ShapedType::isDynamic(sz)) {
521       // Push back the max coordinate for the given dimension/level size.
522       dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
523     } else {
524       // A dynamic size; use an AffineDimExpr to symbolize the value.
525       dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
526     }
527   }
528 
529   for (AffineExpr exp : transMap.getResults()) {
530     // Do constant propagation on the affine map.
531     AffineExpr evalExp =
532         simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
533     // Use the llvm namespace here to avoid ambiguity.
534     if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
535       ret.push_back(c.getValue() + 1);
536     } else {
537       if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
538           mod && mod.getKind() == AffineExprKind::Mod) {
539         // We can still infer a static bound for expressions of the form
540         // "d % constant", since d % constant lies in [0, constant).
541         if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
542           ret.push_back(bound.getValue());
543           continue;
544         }
545       }
546       ret.push_back(ShapedType::kDynamic);
547     }
548   }
549   assert(ret.size() == rank);
550   return ret;
551 }
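// Illustrative example (hypothetical 2x3 block-sparse map): with dimToLvl
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// a dimension shape [6, 12] translates (dim2lvl) to the level shape
// [3, 4, 2, 3]. For a dynamic dimension size, the floordiv result becomes
// ShapedType::kDynamic, while the "d mod c" result is still bounded by c.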
552 
553 ValueRange
554 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
555                                         ValueRange crds,
556                                         CrdTransDirectionKind dir) const {
557   if (!getImpl())
558     return crds;
559 
560   SmallVector<Type> retType(
561       dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
562       builder.getIndexType());
563   auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
564   return transOp.getOutCrds();
565 }
566 
567 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
568   // Open "<{" part.
569   if (failed(parser.parseLess()))
570     return {};
571   if (failed(parser.parseLBrace()))
572     return {};
573 
574   // Process the data from the parsed dictionary value into struct-like data.
575   SmallVector<LevelType> lvlTypes;
576   SmallVector<SparseTensorDimSliceAttr> dimSlices;
577   AffineMap dimToLvl = {};
578   AffineMap lvlToDim = {};
579   unsigned posWidth = 0;
580   unsigned crdWidth = 0;
581   Attribute explicitVal;
582   Attribute implicitVal;
583   StringRef attrName;
584   SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
585                                     "explicitVal", "implicitVal"};
586   while (succeeded(parser.parseOptionalKeyword(&attrName))) {
587     // Detect admissible keyword.
588     auto *it = find(keys, attrName);
589     if (it == keys.end()) {
590       parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
591       return {};
592     }
593     unsigned keyWordIndex = it - keys.begin();
594     // Consume the `=` after the keyword.
595     if (failed(parser.parseEqual()))
596       return {};
597     // Dispatch on keyword.
598     switch (keyWordIndex) {
599     case 0: { // map
600       ir_detail::DimLvlMapParser cParser(parser);
601       auto res = cParser.parseDimLvlMap();
602       if (failed(res))
603         return {};
604       const auto &dlm = *res;
605 
606       const Level lvlRank = dlm.getLvlRank();
607       for (Level lvl = 0; lvl < lvlRank; lvl++)
608         lvlTypes.push_back(dlm.getLvlType(lvl));
609 
610       const Dimension dimRank = dlm.getDimRank();
611       for (Dimension dim = 0; dim < dimRank; dim++)
612         dimSlices.push_back(dlm.getDimSlice(dim));
613       // NOTE: the old syntax requires an all-or-nothing approach to
614       // `dimSlices`; therefore, if any slice actually exists then we need
615       // to convert null-DSA into default/nop DSA.
616       const auto isDefined = [](SparseTensorDimSliceAttr slice) {
617         return static_cast<bool>(slice.getImpl());
618       };
619       if (llvm::any_of(dimSlices, isDefined)) {
620         const auto defaultSlice =
621             SparseTensorDimSliceAttr::get(parser.getContext());
622         for (Dimension dim = 0; dim < dimRank; dim++)
623           if (!isDefined(dimSlices[dim]))
624             dimSlices[dim] = defaultSlice;
625       } else {
626         dimSlices.clear();
627       }
628 
629       dimToLvl = dlm.getDimToLvlMap(parser.getContext());
630       lvlToDim = dlm.getLvlToDimMap(parser.getContext());
631       break;
632     }
633     case 1: { // posWidth
634       Attribute attr;
635       if (failed(parser.parseAttribute(attr)))
636         return {};
637       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
638       if (!intAttr) {
639         parser.emitError(parser.getNameLoc(),
640                          "expected an integral position bitwidth");
641         return {};
642       }
643       posWidth = intAttr.getInt();
644       break;
645     }
646     case 2: { // crdWidth
647       Attribute attr;
648       if (failed(parser.parseAttribute(attr)))
649         return {};
650       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
651       if (!intAttr) {
652         parser.emitError(parser.getNameLoc(),
653                          "expected an integral index bitwidth");
654         return {};
655       }
656       crdWidth = intAttr.getInt();
657       break;
658     }
659     case 3: { // explicitVal
660       Attribute attr;
661       if (failed(parser.parseAttribute(attr)))
662         return {};
663       if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
664         explicitVal = result;
665       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
666         explicitVal = result;
667       } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
668         explicitVal = result;
669       } else {
670         parser.emitError(parser.getNameLoc(),
671                          "expected a numeric value for explicitVal");
672         return {};
673       }
674       break;
675     }
676     case 4: { // implicitVal
677       Attribute attr;
678       if (failed(parser.parseAttribute(attr)))
679         return {};
680       if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
681         implicitVal = result;
682       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
683         implicitVal = result;
684       } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
685         implicitVal = result;
686       } else {
687         parser.emitError(parser.getNameLoc(),
688                          "expected a numeric value for implicitVal");
689         return {};
690       }
691       break;
692     }
693     } // switch
694     // Only the last item can omit the comma.
695     if (parser.parseOptionalComma().failed())
696       break;
697   }
698 
699   // Close "}>" part.
700   if (failed(parser.parseRBrace()))
701     return {};
702   if (failed(parser.parseGreater()))
703     return {};
704 
705   // Construct struct-like storage for attribute.
706   if (!lvlToDim || lvlToDim.isEmpty()) {
707     lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
708   }
709   return parser.getChecked<SparseTensorEncodingAttr>(
710       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
711       explicitVal, implicitVal, dimSlices);
712 }
713 
714 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
715   auto map = static_cast<AffineMap>(getDimToLvl());
716   // Empty affine map indicates identity map
717   if (!map)
718     map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
719   printer << "<{ map = ";
720   printSymbols(map, printer);
721   printer << '(';
722   printDimensions(map, printer, getDimSlices());
723   printer << ") -> (";
724   printLevels(map, printer, getLvlTypes());
725   printer << ')';
726   // Print remaining members only for non-default values.
727   if (getPosWidth())
728     printer << ", posWidth = " << getPosWidth();
729   if (getCrdWidth())
730     printer << ", crdWidth = " << getCrdWidth();
731   if (getExplicitVal()) {
732     printer << ", explicitVal = " << getExplicitVal();
733   }
734   if (getImplicitVal())
735     printer << ", implicitVal = " << getImplicitVal();
736   printer << " }>";
737 }
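// For reference, a hypothetical CSR-style attribute prints roughly as
//   <{ map = (d0, d1) -> (d0 : dense, d1 : compressed),
//      posWidth = 32, crdWidth = 32 }>
// with the enclosing #sparse_tensor.encoding wrapper added by the generic
// attribute printing machinery.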
738 
739 void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
740                                             AsmPrinter &printer) const {
741   if (map.getNumSymbols() == 0)
742     return;
743   printer << '[';
744   for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
745     printer << 's' << i << ", ";
746   if (map.getNumSymbols() >= 1)
747     printer << 's' << map.getNumSymbols() - 1;
748   printer << ']';
749 }
750 
751 void SparseTensorEncodingAttr::printDimensions(
752     AffineMap &map, AsmPrinter &printer,
753     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
754   if (!dimSlices.empty()) {
755     for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
756       printer << 'd' << i << " : " << dimSlices[i] << ", ";
757     if (map.getNumDims() >= 1) {
758       printer << 'd' << map.getNumDims() - 1 << " : "
759               << dimSlices[map.getNumDims() - 1];
760     }
761   } else {
762     for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
763       printer << 'd' << i << ", ";
764     if (map.getNumDims() >= 1)
765       printer << 'd' << map.getNumDims() - 1;
766   }
767 }
768 
769 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
770                                            ArrayRef<LevelType> lvlTypes) const {
771   for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
772     map.getResult(i).print(printer.getStream());
773     printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
774   }
775   if (map.getNumResults() >= 1) {
776     auto lastIndex = map.getNumResults() - 1;
777     map.getResult(lastIndex).print(printer.getStream());
778     printer << " : " << toMLIRString(lvlTypes[lastIndex]);
779   }
780 }
781 
782 LogicalResult SparseTensorEncodingAttr::verify(
783     function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
784     AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
785     unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
786     ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
787   if (!acceptBitWidth(posWidth))
788     return emitError() << "unexpected position bitwidth: " << posWidth;
789   if (!acceptBitWidth(crdWidth))
790     return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
791   if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
792       it != std::end(lvlTypes)) {
793     if (it == lvlTypes.begin() ||
794         (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1))))
795       return emitError() << "expected compressed or loose_compressed level "
796                             "before singleton level";
797     if (!std::all_of(it, lvlTypes.end(),
798                      [](LevelType i) { return isSingletonLT(i); }))
799       return emitError() << "expected all singleton lvlTypes "
800                             "following a singleton level";
801     // We can potentially support mixed SoA/AoS singleton levels.
802     if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) {
803           return it->isa<LevelPropNonDefault::SoA>() ==
804                  i.isa<LevelPropNonDefault::SoA>();
805         })) {
806       return emitError() << "expected all singleton lvlTypes stored in the "
807                             "same memory layout (SoA vs AoS).";
808     }
809   }
810 
811   auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
812   if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
813     return emitError() << "batch lvlTypes can only be the leading levels.";
814 
815   // SoA property can only be applied on singleton level.
816   auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
817     return lt.isa<LevelPropNonDefault::SoA>();
818   });
819   if (llvm::any_of(soaLvls, [](LevelType lt) {
820         return !lt.isa<LevelFormat::Singleton>();
821       })) {
822     return emitError() << "SoA is only applicable to singleton lvlTypes.";
823   }
824 
825   // TODO: audit formats that actually are supported by backend.
826   if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
827       it != std::end(lvlTypes)) {
828     if (it != lvlTypes.end() - 1)
829       return emitError() << "expected n_out_of_m to be the last level type";
830     if (!std::all_of(lvlTypes.begin(), it,
831                      [](LevelType i) { return isDenseLT(i); }))
832       return emitError() << "expected all dense lvlTypes "
833                             "before a n_out_of_m level";
834     if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
835       if (!isBlockSparsity(dimToLvl)) {
836         return emitError()
837                << "expected 1xm block structure for n_out_of_m level";
838       }
839       auto sizes = getBlockSize(dimToLvl);
840       unsigned coefficient = 0;
841       for (const auto &elem : sizes) {
842         if (elem != 0) {
843           if (elem != coefficient && coefficient != 0) {
844             return emitError() << "expected only one blocked level "
845                                   "with the same coefficients";
846           }
847           coefficient = elem;
848         }
849       }
850       if (coefficient != getM(*it)) {
851         return emitError() << "expected coefficients of Affine expressions "
852                               "to be equal to m of n_out_of_m level";
853       }
854     }
855   }
856   // Before we can check that the level-rank is consistent/coherent
857   // across all fields, we need to define it.  The source-of-truth for
858   // the `getLvlRank` method is the length of the level-types array,
859   // since it must always be provided and have full rank; therefore we
860   // use that same source-of-truth here.
861   const Level lvlRank = lvlTypes.size();
862   if (lvlRank == 0)
863     return emitError() << "expected a non-empty array for lvlTypes";
864   // We save `dimRank` here because we'll also need it to verify `dimSlices`.
865   const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
866   if (dimToLvl) {
867     if (dimToLvl.getNumResults() != lvlRank)
868       return emitError()
869              << "level-rank mismatch between dimToLvl and lvlTypes: "
870              << dimToLvl.getNumResults() << " != " << lvlRank;
871     auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
872     // Symbols can't be inferred but are acceptable.
873     if (!inferRes && dimToLvl.getNumSymbols() == 0)
874       return emitError() << "failed to infer lvlToDim from dimToLvl";
875     if (lvlToDim && (inferRes != lvlToDim))
876       return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
877     if (dimRank > lvlRank)
878       return emitError() << "unexpected dimToLvl mapping from " << dimRank
879                          << " to " << lvlRank;
880   }
881   if (!dimSlices.empty()) {
882     if (dimSlices.size() != dimRank)
883       return emitError()
884              << "dimension-rank mismatch between dimSlices and dimToLvl: "
885              << dimSlices.size() << " != " << dimRank;
886     // Compiler support for `dimSlices` currently requires that the two
887     // ranks agree.  (However, it does allow `dimToLvl` to be a permutation.)
888     if (dimRank != lvlRank)
889       return emitError()
890              << "dimSlices expected dimension-rank to match level-rank: "
891              << dimRank << " != " << lvlRank;
892   }
893   return success();
894 }
895 
896 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
897     ArrayRef<Size> dimShape, Type elementType,
898     function_ref<InFlightDiagnostic()> emitError) const {
899   // Check structural integrity.  In particular, this ensures that the
900   // level-rank is coherent across all the fields.
901   if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
902                     getPosWidth(), getCrdWidth(), getExplicitVal(),
903                     getImplicitVal(), getDimSlices())))
904     return failure();
905   // Check integrity with tensor type specifics.  In particular, we
906   // need only check that the dimension-rank of the tensor agrees with
907   // the dimension-rank of the encoding.
908   const Dimension dimRank = dimShape.size();
909   if (dimRank == 0)
910     return emitError() << "expected non-scalar sparse tensor";
911   if (getDimRank() != dimRank)
912     return emitError()
913            << "dimension-rank mismatch between encoding and tensor shape: "
914            << getDimRank() << " != " << dimRank;
915   return success();
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // SparseTensorType Methods.
920 //===----------------------------------------------------------------------===//
921 
922 bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
923                                                       bool isUnique) const {
924   if (!hasEncoding())
925     return false;
926   if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
927     return false;
928   for (Level l = startLvl + 1; l < lvlRank; ++l)
929     if (!isSingletonLvl(l))
930       return false;
931   // If isUnique is true, then make sure that the last level is unique,
932   // that is, when lvlRank == 1, the only compressed level is unique,
933   // and when lvlRank > 1, the last singleton is unique.
934   return !isUnique || isUniqueLvl(lvlRank - 1);
935 }
936 
937 Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const {
938   SmallVector<COOSegment> coo = getCOOSegments();
939   assert(coo.size() == 1 || coo.empty());
940   if (!coo.empty() && coo.front().isAoS()) {
941     return coo.front().lvlRange.first;
942   }
943   return lvlRank;
944 }
945 
946 SmallVector<COOSegment>
947 mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
948   SmallVector<COOSegment> ret;
949   if (!hasEncoding() || lvlRank <= 1)
950     return ret;
951 
952   ArrayRef<LevelType> lts = getLvlTypes();
953   Level l = 0;
954   while (l < lvlRank) {
955     auto lt = lts[l];
956     if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
957       auto cur = lts.begin() + l;
958       auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
959         return !lt.isa<LevelFormat::Singleton>();
960       });
961       unsigned cooLen = std::distance(cur, end);
962       if (cooLen > 1) {
963         // To support mixed SoA/AoS COO, we would have to break the segment when
964         // the storage scheme changes; for now we assume that all consecutive
965         // singleton levels have the same storage format, as verified by the
966         // STEA verifier.
967         ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
968                                  lts[l + 1].isa<LevelPropNonDefault::SoA>()});
969       }
970       l += cooLen;
971     } else {
972       l++;
973     }
974   }
975   return ret;
976 }
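// Illustrative examples (hypothetical encodings): lvlTypes =
// [compressed(nonunique), singleton] yields a single AoS segment covering
// levels [0, 2), whereas a CSR-style [dense, compressed] yields no segments.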
977 
978 RankedTensorType
979 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
980   SmallVector<LevelType> lvlTypes;
981   lvlTypes.reserve(lvlRank);
982   // A non-unique compressed level at the beginning (unless this is
983   // also the last level, in which case it is unique).
984   lvlTypes.push_back(
985       *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
986   if (lvlRank > 1) {
987     // Followed by n-2 non-unique singleton levels.
988     std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
989                 *buildLevelType(LevelFormat::Singleton, ordered, false));
990     // Ends with a unique singleton level.
991     lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
992   }
993   auto enc = SparseTensorEncodingAttr::get(
994       getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
995       getCrdWidth(), getExplicitVal(), getImplicitVal());
996   return RankedTensorType::get(getDimShape(), getElementType(), enc);
997 }
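// For example, for a 3-level tensor the resulting COO level types are
// [compressed(nonunique), singleton(nonunique), singleton], all carrying the
// requested ordering property.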
998 
999 //===----------------------------------------------------------------------===//
1000 // Convenience Methods.
1001 //===----------------------------------------------------------------------===//
1002 
1003 SparseTensorEncodingAttr
1004 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
1005   if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1006     return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1007   if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1008     return mdtp.getEncoding();
1009   return nullptr;
1010 }
1011 
1012 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
1013                                              MLIRContext *context) {
1014   auto map = static_cast<AffineMap>(dimToLvl);
1015   AffineMap lvlToDim;
1016   // Return an empty lvlToDim when inference is not successful.
1017   if (!map || map.getNumSymbols() != 0) {
1018     lvlToDim = AffineMap();
1019   } else if (map.isPermutation()) {
1020     lvlToDim = inversePermutation(map);
1021   } else if (isBlockSparsity(map)) {
1022     lvlToDim = inverseBlockSparsity(map, context);
1023   }
1024   return lvlToDim;
1025 }
1026 
1027 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
1028                                                     MLIRContext *context) {
1029   SmallVector<AffineExpr> lvlExprs;
1030   auto numLvls = dimToLvl.getNumResults();
1031   lvlExprs.reserve(numLvls);
1032   // lvlExprComponents stores information about the floordiv and mod operations
1033   // applied to the same dimension, so as to build the lvlToDim map.
1034   std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1035   for (unsigned i = 0, n = numLvls; i < n; i++) {
1036     auto result = dimToLvl.getResult(i);
1037     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1038       if (result.getKind() == AffineExprKind::FloorDiv) {
1039         // Position of the dimension in dimToLvl.
1040         auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1041         assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1042                "expected only one floordiv for each dimension");
1043         SmallVector<AffineExpr, 3> components;
1044         // Level variable for floordiv.
1045         components.push_back(getAffineDimExpr(i, context));
1046         // Multiplier.
1047         components.push_back(binOp.getRHS());
1048         // Map key is the position of the dimension.
1049         lvlExprComponents[pos] = components;
1050       } else if (result.getKind() == AffineExprKind::Mod) {
1051         auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1052         assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1053                "expected floordiv before mod");
1054         // Add level variable for mod to the same vector
1055         // of the corresponding floordiv.
1056         lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
1057       } else {
1058         assert(false && "expected floordiv or mod");
1059       }
1060     } else {
1061       lvlExprs.push_back(getAffineDimExpr(i, context));
1062     }
1063   }
1064   // Build lvlExprs from lvlExprComponents.
1065   // For example, for il = i floordiv 2 and ii = i mod 2, the components
1066   // would be [il, 2, ii]. It could be used to build the AffineExpr
1067   // i = il * 2 + ii in lvlToDim.
1068   for (auto &components : lvlExprComponents) {
1069     assert(components.second.size() == 3 &&
1070            "expected 3 components to build lvlExprs");
1071     auto mulOp = getAffineBinaryOpExpr(
1072         AffineExprKind::Mul, components.second[0], components.second[1]);
1073     auto addOp =
1074         getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
1075     lvlExprs.push_back(addOp);
1076   }
1077   return AffineMap::get(dimToLvl.getNumResults(), 0, lvlExprs, context);
1078 }
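// Illustrative example (hypothetical 2x3 block-sparse map): the dimToLvl map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// yields the inverse lvlToDim map
//   (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 3 + d3).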
1079 
1080 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
1081   assert(isBlockSparsity(dimToLvl) &&
1082          "expected dimToLvl to be block sparsity for calling getBlockSize");
1083   SmallVector<unsigned> blockSize;
1084   for (auto result : dimToLvl.getResults()) {
1085     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1086       if (result.getKind() == AffineExprKind::Mod) {
1087         blockSize.push_back(
1088             dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1089       }
1090     } else {
1091       blockSize.push_back(0);
1092     }
1093   }
1094   return blockSize;
1095 }
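// For the hypothetical 2x3 block-sparse map above, getBlockSize returns
// [2, 3]; a result that is a plain dimension (no floordiv/mod) contributes 0.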
1096 
1097 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
1098   if (!dimToLvl)
1099     return false;
1100   std::map<unsigned, int64_t> coefficientMap;
1101   bool hasBlock = false;
1102   for (auto result : dimToLvl.getResults()) {
1103     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1104       // Check for "dim op const".
1105       auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1106       auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1107       if (!dimOp || !conOp || conOp.getValue() <= 0)
1108         return false;
1109       // Inspect "dim / const" or "dim % const".
1110       auto pos = dimOp.getPosition();
1111       if (binOp.getKind() == AffineExprKind::FloorDiv) {
1112         // Expect only one floordiv for each dimension.
1113         if (coefficientMap.find(pos) != coefficientMap.end())
1114           return false;
1115         // Record coefficient of the floordiv.
1116         coefficientMap[pos] = conOp.getValue();
1117       } else if (binOp.getKind() == AffineExprKind::Mod) {
1118         // Expect floordiv before mod.
1119         if (coefficientMap.find(pos) == coefficientMap.end())
1120           return false;
1121         // Expect mod to have the same coefficient as floordiv.
1122         if (conOp.getValue() != coefficientMap[pos])
1123           return false;
1124         hasBlock = true;
1125       } else {
1126         return false;
1127       }
1128     } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
1129       auto pos = dimOp.getPosition();
1130       // Expect dim to be unset.
1131       if (coefficientMap.find(pos) != coefficientMap.end())
1132         return false;
1133       coefficientMap[pos] = 0;
1134     } else {
1135       return false;
1136     }
1137   }
1138   return hasBlock;
1139 }
1140 
1141 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
1142   auto hasNonIdentityMap = [](Value v) {
1143     auto stt = tryGetSparseTensorType(v);
1144     return stt && !stt->isIdentity();
1145   };
1146 
1147   return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
1148          llvm::any_of(op->getResults(), hasNonIdentityMap);
1149 }
1150 
1151 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1152   if (enc) {
1153     assert(enc.isPermutation() && "Non permutation map not supported");
1154     if (const auto dimToLvl = enc.getDimToLvl())
1155       return dimToLvl.getDimPosition(l);
1156   }
1157   return l;
1158 }
1159 
1160 Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1161   if (enc) {
1162     assert(enc.isPermutation() && "Non permutation map not supported");
1163     if (const auto lvlToDim = enc.getLvlToDim())
1164       return lvlToDim.getDimPosition(d);
1165   }
1166   return d;
1167 }
1168 
1169 /// We normalize the sparse tensor encoding attribute by always using
1170 /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
1171 /// as other variants) lead to the same storage specifier type, and by stripping
1172 /// irrelevant fields that do not alter the sparse tensor memory layout.
1173 static SparseTensorEncodingAttr
1174 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
1175   SmallVector<LevelType> lts;
1176   for (auto lt : enc.getLvlTypes())
1177     lts.push_back(lt.stripStorageIrrelevantProperties());
1178 
1179   return SparseTensorEncodingAttr::get(
1180       enc.getContext(), lts,
1181       AffineMap(), // dimToLvl (irrelevant to storage specifier)
1182       AffineMap(), // lvlToDim (irrelevant to storage specifier)
1183       // Always use `index` for memSize and lvlSize instead of reusing
1184       // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
1185       // value for different bitwidths; it also avoids casting between index and
1186       // integer (as returned by DimOp).
1187       0, 0,
1188       Attribute(), // explicitVal (irrelevant to storage specifier)
1189       Attribute(), // implicitVal (irrelevant to storage specifier)
1190       enc.getDimSlices());
1191 }
1192 
1193 StorageSpecifierType
1194 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1195   return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
1196 }
1197 
1198 //===----------------------------------------------------------------------===//
1199 // SparseTensorDialect Operations.
1200 //===----------------------------------------------------------------------===//
1201 
1202 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1203   return success(lvl < getSparseTensorType(tensor).getLvlRank());
1204 }
1205 
1206 static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1207   const Type etp = getMemRefType(mem).getElementType();
1208   return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1209 }
1210 
1211 static LogicalResult verifySparsifierGetterSetter(
1212     StorageSpecifierKind mdKind, std::optional<Level> lvl,
1213     TypedValue<StorageSpecifierType> md, Operation *op) {
1214   if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1215     return op->emitError(
1216         "redundant level argument for querying value memory size");
1217   }
1218 
1219   const auto enc = md.getType().getEncoding();
1220   const Level lvlRank = enc.getLvlRank();
1221 
1222   if (mdKind == StorageSpecifierKind::DimOffset ||
1223       mdKind == StorageSpecifierKind::DimStride)
1224     if (!enc.isSlice())
1225       return op->emitError("requested slice data on non-slice tensor");
1226 
1227   if (mdKind != StorageSpecifierKind::ValMemSize) {
1228     if (!lvl)
1229       return op->emitError("missing level argument");
1230 
1231     const Level l = lvl.value();
1232     if (l >= lvlRank)
1233       return op->emitError("requested level is out of bounds");
1234 
1235     if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1236       return op->emitError(
1237           "requested position memory size on a singleton level");
1238   }
1239   return success();
1240 }
1241 
1242 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
1243   switch (kind) {
1244   case SparseTensorFieldKind::CrdMemRef:
1245     return stt.getCrdType();
1246   case SparseTensorFieldKind::PosMemRef:
1247     return stt.getPosType();
1248   case SparseTensorFieldKind::ValMemRef:
1249     return stt.getElementType();
1250   case SparseTensorFieldKind::StorageSpec:
1251     return nullptr;
1252   }
1253   llvm_unreachable("Unrecognizable FieldKind");
1254 }
1255 
1256 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1257                                       SparseTensorType stt,
1258                                       RankedTensorType valTp,
1259                                       TypeRange lvlTps) {
1260   if (requiresStaticShape && !stt.hasStaticDimShape())
1261     return op->emitError("the sparse-tensor must have static shape");
1262   if (!stt.hasEncoding())
1263     return op->emitError("the sparse-tensor must have an encoding attribute");
1264 
1265   // Verifies the trailing COO.
1266   Level cooStartLvl = stt.getAoSCOOStart();
1267   if (cooStartLvl < stt.getLvlRank()) {
1268     // We only support trailing COO for now; it must be the last input.
1269     auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1270     // The coordinates should be in the shape of <? x rank>.
1271     unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1272     if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1273       return op->emitError("input/output trailing COO level-ranks don't match");
1274     }
1275   }
1276 
1277   // Verifies that all types match.
1278   StorageLayout layout(stt.getEncoding());
1279   if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
1280     return op->emitError("inconsistent number of fields between input/output");
1281 
1282   unsigned idx = 0;
1283   bool misMatch = false;
1284   layout.foreachField([&idx, &misMatch, stt, valTp,
1285                        lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
1286                                Level lvl, LevelType lt) -> bool {
1287     if (fKind == SparseTensorFieldKind::StorageSpec)
1288       return true;
1289 
1290     Type inputTp = nullptr;
1291     if (fKind == SparseTensorFieldKind::ValMemRef) {
1292       inputTp = valTp;
1293     } else {
1294       assert(fid == idx && stt.getLvlType(lvl) == lt);
1295       inputTp = lvlTps[idx++];
1296     }
1297     // The input element type and expected element type should match.
1298     Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1299     Type expElemTp = getFieldElemType(stt, fKind);
1300     if (inpElemTp != expElemTp) {
1301       misMatch = true;
1302       return false; // to terminate the iteration
1303     }
1304     return true;
1305   });
1306 
1307   if (misMatch)
1308     return op->emitError("input/output element-types don't match");
1309   return success();
1310 }
1311 
1312 LogicalResult AssembleOp::verify() {
1313   const auto valuesTp = getRankedTensorType(getValues());
1314   const auto lvlsTp = getLevels().getTypes();
1315   const auto resTp = getSparseTensorType(getResult());
1316   return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1317 }
1318 
1319 LogicalResult DisassembleOp::verify() {
1320   if (getOutValues().getType() != getRetValues().getType())
1321     return emitError("output values and return value type mismatch");
1322 
1323   for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1324     if (ot.getType() != rt.getType())
1325       return emitError("output levels and return levels type mismatch");
1326 
1327   const auto valuesTp = getRankedTensorType(getRetValues());
1328   const auto lvlsTp = getRetLevels().getTypes();
1329   const auto srcTp = getSparseTensorType(getTensor());
1330   return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1331 }
1332 
1333 LogicalResult ConvertOp::verify() {
1334   if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
1335     if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
1336       if (tp1.getRank() != tp2.getRank())
1337         return emitError("unexpected conversion mismatch in rank");
1338       auto dstEnc =
1339           llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1340       if (dstEnc && dstEnc.isSlice())
1341         return emitError("cannot convert to a sparse tensor slice");
1342 
1343       auto shape1 = tp1.getShape();
1344       auto shape2 = tp2.getShape();
1345       // Accept size matches between the source and the destination type
1346       // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1347       // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1348       for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1349         if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1350           return emitError("unexpected conversion mismatch in dimension ") << d;
1351       return success();
1352     }
1353   }
1354   return emitError("unexpected type in convert");
1355 }
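
// For illustration of the shape rule above (a sketch; `#CSR` stands for an
// assumed sparse encoding alias, %a and %b for arbitrary source values):
//
//   %ok  = sparse_tensor.convert %a : tensor<10x10xf64> to tensor<10x?xf64, #CSR>
//   %bad = sparse_tensor.convert %b : tensor<10x?xf64>  to tensor<10x10xf64, #CSR>
//
// The first conversion is accepted (matching a static size against ? needs no
// check), while the second is rejected because matching ? against 10 would
// require a runtime assertion.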
1356 
1357 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1358   if (getType() == getSource().getType())
1359     return getSource();
1360   return {};
1361 }
1362 
1363 bool ConvertOp::needsExtraSort() {
1364   SparseTensorType srcStt = getSparseTensorType(getSource());
1365   SparseTensorType dstStt = getSparseTensorType(getDest());
1366 
1367   // We do not need an extra sort when returning unordered sparse tensors or
1368   // dense tensors, since dense tensors support random access.
1369   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1370     return false;
1371 
1372   if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1373       srcStt.hasSameDimToLvl(dstStt)) {
1374     return false;
1375   }
1376 
1377   // Source and dest tensors are ordered in different ways. We only do direct
1378   // dense to sparse conversion when the dense input is defined by a sparse
1379   // constant. Note that we can theoretically always directly convert from dense
1380   // inputs by rotating dense loops, but that leads to bad cache locality and
1381   // hurts performance.
1382   if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1383     if (isa<SparseElementsAttr>(constOp.getValue()))
1384       return false;
1385 
1386   return true;
1387 }
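
// A sketch of the one case where a dense source still lowers directly (the
// value names and the `#CSR` alias are illustrative):
//
//   %cst = arith.constant sparse<[[0, 0], [1, 2]], [1.0, 2.0]> : tensor<2x3xf64>
//   %0   = sparse_tensor.convert %cst : tensor<2x3xf64> to tensor<2x3xf64, #CSR>
//
// Because the dense operand is defined by a sparse-elements constant, the
// conversion can consume it directly without materializing and sorting an
// intermediate COO buffer.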
1388 
1389 LogicalResult CrdTranslateOp::verify() {
1390   uint64_t inRank = getEncoder().getLvlRank();
1391   uint64_t outRank = getEncoder().getDimRank();
1392 
1393   if (getDirection() == CrdTransDirectionKind::dim2lvl)
1394     std::swap(inRank, outRank);
1395 
1396   if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1397     return emitError("Coordinate rank mismatch with encoding");
1398 
1399   return success();
1400 }
1401 
1402 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1403                                    SmallVectorImpl<OpFoldResult> &results) {
1404   if (getEncoder().isIdentity()) {
1405     results.assign(getInCrds().begin(), getInCrds().end());
1406     return success();
1407   }
1408   if (getEncoder().isPermutation()) {
1409     AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1410                          ? getEncoder().getDimToLvl()
1411                          : getEncoder().getLvlToDim();
1412     for (AffineExpr exp : perm.getResults())
1413       results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1414     return success();
1415   }
1416 
1417   // Fuse dim2lvl/lvl2dim pairs.
1418   auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1419   bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1420                    return v.getDefiningOp() == def;
1421                  });
1422   if (!sameDef)
1423     return failure();
1424 
1425   bool oppositeDir = def.getDirection() != getDirection();
1426   bool sameOracle =
1427       def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1428   bool sameCount = def.getNumResults() == getInCrds().size();
1429   if (!oppositeDir || !sameOracle || !sameCount)
1430     return failure();
1431 
1432   // The definition produces the coordinates in the same order as the input
1433   // coordinates.
1434   bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1435                                 [](auto valuePair) {
1436                                   auto [lhs, rhs] = valuePair;
1437                                   return lhs == rhs;
1438                                 });
1439 
1440   if (!sameOrder)
1441     return failure();
1442   // l1 = dim2lvl (lvl2dim l0)
1443   // ==> l0
1444   results.append(def.getInCrds().begin(), def.getInCrds().end());
1445   return success();
1446 }
1447 
1448 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1449                   int64_t index) {
1450   Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
1451   return build(builder, state, source, val);
1452 }
1453 
1454 LogicalResult LvlOp::verify() {
1455   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1456     auto stt = getSparseTensorType(getSource());
1457     if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1458       return emitError("Level index exceeds the rank of the input sparse tensor");
1459   }
1460   return success();
1461 }
1462 
1463 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1464   return getConstantIntValue(getIndex());
1465 }
1466 
1467 Speculation::Speculatability LvlOp::getSpeculatability() {
1468   auto constantIndex = getConstantLvlIndex();
1469   if (!constantIndex)
1470     return Speculation::NotSpeculatable;
1471 
1472   assert(constantIndex <
1473          cast<RankedTensorType>(getSource().getType()).getRank());
1474   return Speculation::Speculatable;
1475 }
1476 
1477 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1478   auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1479   if (!lvlIndex)
1480     return {};
1481 
1482   Level lvl = lvlIndex.getAPSInt().getZExtValue();
1483   auto stt = getSparseTensorType(getSource());
1484   if (lvl >= stt.getLvlRank()) {
1485     // Follows the same convention used by the tensor.dim operation:
1486     // out-of-bounds indices produce undefined behavior but are still valid
1487     // IR, so don't choke on them.
1488     return {};
1489   }
1490 
1491   // Helper lambda to build an IndexAttr.
1492   auto getIndexAttr = [this](int64_t lvlSz) {
1493     return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1494   };
1495 
1496   SmallVector<Size> lvlShape = stt.getLvlShape();
1497   if (!ShapedType::isDynamic(lvlShape[lvl]))
1498     return getIndexAttr(lvlShape[lvl]);
1499 
1500   return {};
1501 }
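
// Worked example (illustrative; assume an encoding whose dimToLvl is the
// identity): for a source of type tensor<4x?xf64, #CSR>, querying level 0
// folds to the constant index 4, while querying level 1 does not fold since
// that level size is dynamic.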
1502 
1503 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1504                              SparseTensorEncodingAttr dstEnc, Value source) {
1505   auto srcStt = getSparseTensorType(source);
1506   SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1507   SmallVector<int64_t> dstDimShape =
1508       dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1509   auto dstTp =
1510       RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1511   return build(odsBuilder, odsState, dstTp, source);
1512 }
1513 
1514 LogicalResult ReinterpretMapOp::verify() {
1515   auto srcStt = getSparseTensorType(getSource());
1516   auto dstStt = getSparseTensorType(getDest());
1517   ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1518   ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1519 
1520   if (srcLvlTps.size() != dstLvlTps.size())
1521     return emitError("Level rank mismatch between source/dest tensors");
1522 
1523   for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1524     if (srcLvlTp != dstLvlTp)
1525       return emitError("Level type mismatch between source/dest tensors");
1526 
1527   if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1528       srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1529     return emitError("Crd/Pos width mismatch between source/dest tensors");
1530   }
1531 
1532   if (srcStt.getElementType() != dstStt.getElementType())
1533     return emitError("Element type mismatch between source/dest tensors");
1534 
1535   SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1536   SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1537   for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1538     if (srcLvlSz != dstLvlSz) {
1539       // Should we allow one side to have dynamic sizes, e.g., should <?x?>
1540       // be compatible with <3x4>? For now, we require all level sizes to
1541       // match *exactly* for simplicity.
1542       return emitError("Level size mismatch between source/dest tensors");
1543     }
1544   }
1545 
1546   return success();
1547 }
1548 
1549 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1550   if (getSource().getType() == getDest().getType())
1551     return getSource();
1552 
1553   if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1554     // A -> B, B -> A ==> A
1555     if (def.getSource().getType() == getDest().getType())
1556       return def.getSource();
1557   }
1558   return {};
1559 }
1560 
1561 template <typename ToBufferOp>
1562 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1563                                            OpaqueProperties prop,
1564                                            RegionRange region,
1565                                            SmallVectorImpl<mlir::Type> &ret) {
1566   typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1567   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1568   Type elemTp = nullptr;
1569   bool withStride = false;
1570   if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1571     elemTp = stt.getPosType();
1572   } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1573                        std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1574     elemTp = stt.getCrdType();
1575     if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1576       withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1577   } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1578     elemTp = stt.getElementType();
1579   }
1580 
1581   assert(elemTp && "unhandled operation.");
1582   SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1583   bufShape.push_back(ShapedType::kDynamic);
1584 
1585   auto layout = withStride ? StridedLayoutAttr::get(
1586                                  stt.getContext(), ShapedType::kDynamic,
1587                                  {ShapedType::kDynamic})
1588                            : StridedLayoutAttr();
1589   ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1590   return success();
1591 }
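
// For illustration (assuming index-typed positions/coordinates and no batch
// levels): ToPositionsOp and ToValuesOp infer plain dynamically sized buffers
// such as memref<?xindex> and memref<?xf64>, whereas ToCoordinatesOp on a
// level inside an AoS COO region infers a strided view like
// memref<?xindex, strided<[?], offset: ?>>, because the coordinates of all
// COO levels are interleaved in a single buffer.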
1592 
1593 LogicalResult ToPositionsOp::verify() {
1594   auto stt = getSparseTensorType(getTensor());
1595   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1596     return emitError("requested level is out of bounds");
1597   if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1598     return emitError("unexpected type for positions");
1599   return success();
1600 }
1601 
1602 LogicalResult
1603 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1604                                 ValueRange ops, DictionaryAttr attr,
1605                                 OpaqueProperties prop, RegionRange region,
1606                                 SmallVectorImpl<mlir::Type> &ret) {
1607   return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1608 }
1609 
1610 LogicalResult ToCoordinatesOp::verify() {
1611   auto stt = getSparseTensorType(getTensor());
1612   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1613     return emitError("requested level is out of bounds");
1614   if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1615     return emitError("unexpected type for coordinates");
1616   return success();
1617 }
1618 
1619 LogicalResult
1620 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1621                                   ValueRange ops, DictionaryAttr attr,
1622                                   OpaqueProperties prop, RegionRange region,
1623                                   SmallVectorImpl<mlir::Type> &ret) {
1624   return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1625 }
1626 
1627 LogicalResult ToCoordinatesBufferOp::verify() {
1628   auto stt = getSparseTensorType(getTensor());
1629   if (stt.getAoSCOOStart() >= stt.getLvlRank())
1630     return emitError("expected sparse tensor with a COO region");
1631   return success();
1632 }
1633 
1634 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1635     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1636     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1637     SmallVectorImpl<mlir::Type> &ret) {
1638   return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1639                                                       ret);
1640 }
1641 
1642 LogicalResult ToValuesOp::verify() {
1643   auto stt = getSparseTensorType(getTensor());
1644   auto mtp = getMemRefType(getResult());
1645   if (stt.getElementType() != mtp.getElementType())
1646     return emitError("unexpected mismatch in element types");
1647   return success();
1648 }
1649 
1650 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1651                                            std::optional<Location> loc,
1652                                            ValueRange ops, DictionaryAttr attr,
1653                                            OpaqueProperties prop,
1654                                            RegionRange region,
1655                                            SmallVectorImpl<mlir::Type> &ret) {
1656   return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1657 }
1658 
1659 LogicalResult ToSliceOffsetOp::verify() {
1660   auto rank = getRankedTensorType(getSlice()).getRank();
1661   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1662     return emitError("requested dimension out of bounds");
1663   return success();
1664 }
1665 
1666 LogicalResult ToSliceStrideOp::verify() {
1667   auto rank = getRankedTensorType(getSlice()).getRank();
1668   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1669     return emitError("requested dimension out of bounds");
1670   return success();
1671 }
1672 
1673 LogicalResult GetStorageSpecifierOp::verify() {
1674   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1675                                       getSpecifier(), getOperation());
1676 }
1677 
1678 template <typename SpecifierOp>
1679 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1680   return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1681 }
1682 
1683 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1684   const StorageSpecifierKind kind = getSpecifierKind();
1685   const auto lvl = getLevel();
1686   for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1687     if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1688       return op.getValue();
1689   return {};
1690 }
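
// Illustrative folding chain (pseudo-IR; value names assumed, types omitted):
//
//   %s1 = sparse_tensor.storage_specifier.set %s0 lvl_sz at 0 with %v0
//   %s2 = sparse_tensor.storage_specifier.set %s1 crd_mem_sz at 0 with %v1
//   %r  = sparse_tensor.storage_specifier.get %s2 lvl_sz at 0
//
// The getter folds to %v0: the walk above skips the intervening setter
// because it targets a different specifier field.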
1691 
1692 LogicalResult SetStorageSpecifierOp::verify() {
1693   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1694                                       getSpecifier(), getOperation());
1695 }
1696 
1697 template <class T>
1698 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1699                                         const char *regionName,
1700                                         TypeRange inputTypes, Type outputType) {
1701   unsigned numArgs = region.getNumArguments();
1702   unsigned expectedNum = inputTypes.size();
1703   if (numArgs != expectedNum)
1704     return op->emitError() << regionName << " region must have exactly "
1705                            << expectedNum << " arguments";
1706 
1707   for (unsigned i = 0; i < numArgs; i++) {
1708     Type typ = region.getArgument(i).getType();
1709     if (typ != inputTypes[i])
1710       return op->emitError() << regionName << " region argument " << (i + 1)
1711                              << " type mismatch";
1712   }
1713   Operation *term = region.front().getTerminator();
1714   YieldOp yield = dyn_cast<YieldOp>(term);
1715   if (!yield)
1716     return op->emitError() << regionName
1717                            << " region must end with sparse_tensor.yield";
1718   if (!yield.hasSingleResult() ||
1719       yield.getSingleResult().getType() != outputType)
1720     return op->emitError() << regionName << " region yield type mismatch";
1721 
1722   return success();
1723 }
1724 
1725 LogicalResult BinaryOp::verify() {
1726   NamedAttrList attrs = (*this)->getAttrs();
1727   Type leftType = getX().getType();
1728   Type rightType = getY().getType();
1729   Type outputType = getOutput().getType();
1730   Region &overlap = getOverlapRegion();
1731   Region &left = getLeftRegion();
1732   Region &right = getRightRegion();
1733 
1734   // Check correct number of block arguments and return type for each
1735   // non-empty region.
1736   if (!overlap.empty()) {
1737     if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1738                                   TypeRange{leftType, rightType}, outputType)))
1739       return failure();
1740   }
1741   if (!left.empty()) {
1742     if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1743                                   outputType)))
1744       return failure();
1745   } else if (getLeftIdentity()) {
1746     if (leftType != outputType)
1747       return emitError("left=identity requires first argument to have the same "
1748                        "type as the output");
1749   }
1750   if (!right.empty()) {
1751     if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1752                                   outputType)))
1753       return failure();
1754   } else if (getRightIdentity()) {
1755     if (rightType != outputType)
1756       return emitError("right=identity requires second argument to have the "
1757                        "same type as the output");
1758   }
1759   return success();
1760 }
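
// A well-formed instance of what the verifier above accepts (a sketch; value
// names and types are illustrative):
//
//   %r = sparse_tensor.binary %a, %b : f64, f64 to f64
//     overlap = {
//       ^bb0(%x: f64, %y: f64):
//         %sum = arith.addf %x, %y : f64
//         sparse_tensor.yield %sum : f64
//     }
//     left = identity
//     right = identity
//
// The overlap region takes one argument per operand and yields the output
// type; `identity` is only legal here because the operand types equal the
// output type.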
1761 
1762 LogicalResult UnaryOp::verify() {
1763   Type inputType = getX().getType();
1764   Type outputType = getOutput().getType();
1765 
1766   // Check correct number of block arguments and return type for each
1767   // non-empty region.
1768   Region &present = getPresentRegion();
1769   if (!present.empty()) {
1770     if (failed(verifyNumBlockArgs(this, present, "present",
1771                                   TypeRange{inputType}, outputType)))
1772       return failure();
1773   }
1774   Region &absent = getAbsentRegion();
1775   if (!absent.empty()) {
1776     if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1777                                   outputType)))
1778       return failure();
1779     // Absent branch can only yield invariant values.
1780     Block *absentBlock = &absent.front();
1781     Block *parent = getOperation()->getBlock();
1782     Value absentVal =
1783         cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1784     if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1785       if (arg.getOwner() == parent)
1786         return emitError("absent region cannot yield linalg argument");
1787     } else if (Operation *def = absentVal.getDefiningOp()) {
1788       if (!isa<arith::ConstantOp>(def) &&
1789           (def->getBlock() == absentBlock || def->getBlock() == parent))
1790         return emitError("absent region cannot yield locally computed value");
1791     }
1792   }
1793   return success();
1794 }
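
// A sketch of the invariance rule for the absent branch (value names are
// illustrative):
//
//   absent = {
//     %c0 = arith.constant 0.0 : f64
//     sparse_tensor.yield %c0 : f64
//   }
//
// This is accepted because the yielded value is a constant; yielding a block
// argument of the enclosing operation, or a value computed inside the absent
// block by anything other than a constant op, is rejected.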
1795 
1796 bool ConcatenateOp::needsExtraSort() {
1797   SparseTensorType dstStt = getSparseTensorType(*this);
1798   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1799     return false;
1800 
1801   bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1802     return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1803   });
1804   // TODO: When conDim != 0, as long as conDim corresponds to the first level
1805   // in all input/output buffers and all input/output buffers have the same
1806   // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
1807   // CSC matrices along the column dimension).
1808   bool directLowerable =
1809       allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1810   return !directLowerable;
1811 }
1812 
1813 LogicalResult ConcatenateOp::verify() {
1814   const auto dstTp = getSparseTensorType(*this);
1815   const Dimension concatDim = getDimension();
1816   const Dimension dimRank = dstTp.getDimRank();
1817 
1818   if (getInputs().size() <= 1)
1819     return emitError("Need at least two tensors to concatenate.");
1820 
1821   if (concatDim >= dimRank)
1822     return emitError(llvm::formatv(
1823         "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1824         concatDim, dimRank));
1825 
1826   for (const auto &it : llvm::enumerate(getInputs())) {
1827     const auto i = it.index();
1828     const auto srcTp = getSparseTensorType(it.value());
1829     if (srcTp.hasDynamicDimShape())
1830       return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1831     const Dimension srcDimRank = srcTp.getDimRank();
1832     if (srcDimRank != dimRank)
1833       return emitError(
1834           llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1835                         "from the output tensor (rank={2}).",
1836                         i, srcDimRank, dimRank));
1837   }
1838 
1839   for (Dimension d = 0; d < dimRank; d++) {
1840     const Size dstSh = dstTp.getDimShape()[d];
1841     if (d == concatDim) {
1842       if (!ShapedType::isDynamic(dstSh)) {
1843         // If we reach here, then all inputs have static shapes.  So we
1844         // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1845         // to avoid redundant assertions in the loop.
1846         Size sumSz = 0;
1847         for (const auto src : getInputs())
1848           sumSz += getSparseTensorType(src).getDimShape()[d];
1849         // If all dimensions are statically known, the sum of all the input
1850         // dimensions should equal the output dimension.
1851         if (sumSz != dstSh)
1852           return emitError(
1853               "The concatenation dimension of the output tensor should be the "
1854               "sum of all the concatenation dimensions of the input tensors.");
1855       }
1856     } else {
1857       Size prev = dstSh;
1858       for (const auto src : getInputs()) {
1859         const auto sh = getSparseTensorType(src).getDimShape()[d];
1860         if (!ShapedType::isDynamic(prev) && sh != prev)
1861           return emitError("All dimensions (except for the concatenating one) "
1862                            "should be equal.");
1863         prev = sh;
1864       }
1865     }
1866   }
1867 
1868   return success();
1869 }
1870 
1871 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1872                        Value curSize, Value inBuffer, Value value) {
1873   build(builder, result, curSize, inBuffer, value, Value());
1874 }
1875 
1876 LogicalResult PushBackOp::verify() {
1877   if (Value n = getN()) {
1878     std::optional<int64_t> nValue = getConstantIntValue(n);
1879     if (nValue && nValue.value() < 1)
1880       return emitOpError("n must not be less than 1");
1881   }
1882   return success();
1883 }
1884 
1885 LogicalResult CompressOp::verify() {
1886   const auto stt = getSparseTensorType(getTensor());
1887   if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1888     return emitOpError("incorrect number of coordinates");
1889   return success();
1890 }
1891 
1892 void ForeachOp::build(
1893     OpBuilder &builder, OperationState &result, Value tensor,
1894     ValueRange initArgs, AffineMapAttr order,
1895     function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1896         bodyBuilder) {
1897   build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1898   // Builds foreach body.
1899   if (!bodyBuilder)
1900     return;
1901   const auto stt = getSparseTensorType(tensor);
1902   const Dimension dimRank = stt.getDimRank();
1903 
1904   // Starts with `dimRank`-many coordinates.
1905   SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1906   // Followed by one value.
1907   blockArgTypes.push_back(stt.getElementType());
1908   // Followed by the reduction variables.
1909   blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1910 
1911   SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1912 
1913   OpBuilder::InsertionGuard guard(builder);
1914   auto &region = *result.regions.front();
1915   Block *bodyBlock =
1916       builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1917   bodyBuilder(builder, result.location,
1918               bodyBlock->getArguments().slice(0, dimRank),
1919               bodyBlock->getArguments()[dimRank],
1920               bodyBlock->getArguments().drop_front(dimRank + 1));
1921 }
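
// A minimal sketch of how a client could drive this builder (hypothetical
// caller code; `loc` and `tensor` are assumed to exist in the caller):
//
//   builder.create<ForeachOp>(
//       loc, tensor, /*initArgs=*/ValueRange{}, /*order=*/AffineMapAttr(),
//       [&](OpBuilder &b, Location l, ValueRange coords, Value val,
//           ValueRange reduc) {
//         // ... emit per-element code using `coords` and `val` here ...
//         b.create<YieldOp>(l, reduc);
//       });
//
// The callback receives dimRank coordinates, the stored value, and the
// reduction arguments, matching the block signature assembled above.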
1922 
1923 LogicalResult ForeachOp::verify() {
1924   const auto t = getSparseTensorType(getTensor());
1925   const Dimension dimRank = t.getDimRank();
1926   const auto args = getBody()->getArguments();
1927 
1928   if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1929     return emitError("Level traverse order does not match tensor's level rank");
1930 
1931   if (dimRank + 1 + getInitArgs().size() != args.size())
1932     return emitError("Unmatched number of arguments in the block");
1933 
1934   if (getNumResults() != getInitArgs().size())
1935     return emitError("Mismatch in number of init arguments and results");
1936 
1937   if (getResultTypes() != getInitArgs().getTypes())
1938     return emitError("Mismatch in types of init arguments and results");
1939 
1940   // Cannot mark this const, because the getters aren't.
1941   auto yield = cast<YieldOp>(getBody()->getTerminator());
1942   if (yield.getNumOperands() != getNumResults() ||
1943       yield.getOperands().getTypes() != getResultTypes())
1944     return emitError("Mismatch in types of yield values and results");
1945 
1946   const auto iTp = IndexType::get(getContext());
1947   for (Dimension d = 0; d < dimRank; d++)
1948     if (args[d].getType() != iTp)
1949       return emitError(
1950           llvm::formatv("Expecting Index type for argument at index {0}", d));
1951 
1952   const auto elemTp = t.getElementType();
1953   const auto valueTp = args[dimRank].getType();
1954   if (elemTp != valueTp)
1955     return emitError(llvm::formatv("Unmatched element type between input "
1956                                    "tensor and block argument, expected: {0}, "
1957                                    "got: {1}", elemTp, valueTp));
1958   return success();
1959 }
1960 
1961 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1962   if (getSparseTensorEncoding(getInputCoo().getType()) ==
1963       getSparseTensorEncoding(getResultCoo().getType()))
1964     return getInputCoo();
1965 
1966   return {};
1967 }
1968 
1969 LogicalResult ReorderCOOOp::verify() {
1970   SparseTensorType srcStt = getSparseTensorType(getInputCoo());
1971   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
1972 
1973   if (!srcStt.isCOOType() || !dstStt.isCOOType())
1974     return emitError("Expected COO sparse tensors only");
1975 
1976   if (!srcStt.hasSameDimToLvl(dstStt))
1977     return emitError("Unmatched dim2lvl map between input and result COO");
1978 
1979   if (srcStt.getPosType() != dstStt.getPosType() ||
1980       srcStt.getCrdType() != dstStt.getCrdType() ||
1981       srcStt.getElementType() != dstStt.getElementType())
1982     return emitError("Unmatched storage format between input and result COO");
1983 
1984   return success();
1985 }
1986 
1987 LogicalResult ReduceOp::verify() {
1988   Type inputType = getX().getType();
1989   Region &formula = getRegion();
1990   return verifyNumBlockArgs(this, formula, "reduce",
1991                             TypeRange{inputType, inputType}, inputType);
1992 }
1993 
1994 LogicalResult SelectOp::verify() {
1995   Builder b(getContext());
1996   Type inputType = getX().getType();
1997   Type boolType = b.getI1Type();
1998   Region &formula = getRegion();
1999   return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2000                             boolType);
2001 }
2002 
2003 LogicalResult SortOp::verify() {
2004   AffineMap xPerm = getPermMap();
2005   uint64_t nx = xPerm.getNumDims();
2006   if (nx < 1)
2007     return emitError(llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));
2008 
2009   if (!xPerm.isPermutation())
2010     return emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
2011 
2012   // We can't check the size of the buffers when n or buffer dimensions aren't
2013   // compile-time constants.
2014   std::optional<int64_t> cn = getConstantIntValue(getN());
2015   if (!cn)
2016     return success();
2017 
2018   // Verify dimensions.
2019   const auto checkDim = [&](Value v, Size minSize, const char *message) {
2020     const Size sh = getMemRefType(v).getShape()[0];
2021     if (!ShapedType::isDynamic(sh) && sh < minSize)
2022       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2023   };
2024   uint64_t n = cn.value();
2025   uint64_t ny = 0;
2026   if (auto nyAttr = getNyAttr())
2027     ny = nyAttr.getInt();
2028   checkDim(getXy(), n * (nx + ny),
2029            "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
2030   for (Value opnd : getYs())
2031     checkDim(opnd, n, "Expected dimension(y) >= n");
2032 
2033   return success();
2034 }
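
// Worked size example for the checks above (numbers are illustrative): with
// n = 3, rank(perm_map) = 2 and ny = 1, the xy buffer must hold at least
// 3 * (2 + 1) = 9 elements, and every y buffer must hold at least 3 elements;
// dynamically sized buffers are not checked here.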
2035 
2036 //===----------------------------------------------------------------------===//
2037 // Sparse Tensor Iteration Operations.
2038 //===----------------------------------------------------------------------===//
2039 
2040 IterSpaceType IteratorType::getIterSpaceType() const {
2041   return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2042                             getHiLvl());
2043 }
2044 
2045 IteratorType IterSpaceType::getIteratorType() const {
2046   return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2047 }
2048 
2049 /// Parses a level range in the form "$lo `to` $hi"
2050 /// or simply "$lo" if $hi - $lo = 1
2051 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2052                                    Level &lvlHi) {
2053   if (parser.parseInteger(lvlLo))
2054     return failure();
2055 
2056   if (succeeded(parser.parseOptionalKeyword("to"))) {
2057     if (parser.parseInteger(lvlHi))
2058       return failure();
2059   } else {
2060     lvlHi = lvlLo + 1;
2061   }
2062 
2063   if (lvlHi <= lvlLo)
2064     return parser.emitError(parser.getNameLoc(),
2065                             "expect larger level upper bound than lower bound");
2066 
2067   return success();
2068 }
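
// For example, "2" parses as the single-level range [2, 3), and "1 to 4"
// parses as [1, 4); an input such as "3 to 2" trips the bound check above.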
2069 
2070 /// Parses a level range in the form "$lo `to` $hi"
2071 /// or simply "$lo" if $hi - $lo = 1
2072 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2073                                    IntegerAttr &lvlHiAttr) {
2074   Level lvlLo, lvlHi;
2075   if (parseLevelRange(parser, lvlLo, lvlHi))
2076     return failure();
2077 
2078   lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2079   lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2080   return success();
2081 }
2082 
2083 /// Prints a level range in the form "$lo `to` $hi"
2084 /// or simply "$lo" if $hi - $lo = 1
2085 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2086 
2087   if (lo + 1 == hi)
2088     p << lo;
2089   else
2090     p << lo << " to " << hi;
2091 }
2092 
2093 /// Prints a level range in the form "$lo `to` $hi"
2094 /// or simply "$lo" if $hi - $lo = 1
2095 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2096                             IntegerAttr lvlHi) {
2097   unsigned lo = lvlLo.getValue().getZExtValue();
2098   unsigned hi = lvlHi.getValue().getZExtValue();
2099   printLevelRange(p, lo, hi);
2100 }
2101 
2102 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2103     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2104     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2105     SmallVectorImpl<mlir::Type> &ret) {
2106 
2107   ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2108   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2109   ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2110                                    adaptor.getHiLvl()));
2111   return success();
2112 }
2113 
2114 LogicalResult ExtractIterSpaceOp::verify() {
2115   if (getLoLvl() >= getHiLvl())
2116     return emitOpError("expected level lower bound to be smaller than upper bound");
2117 
2118   TypedValue<IteratorType> pIter = getParentIter();
2119   if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2120     return emitOpError(
2121         "parent iterator should be specified iff level lower bound equals 0");
2122   }
2123 
2124   if (pIter) {
2125     IterSpaceType spaceTp = getResultSpace().getType();
2126     if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2127       return emitOpError(
2128           "mismatch in parent iterator encoding and iteration space encoding.");
2129 
2130     if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2131       return emitOpError("parent iterator should be used to extract an "
2132                          "iteration space from a consecutive level.");
2133   }
2134 
2135   return success();
2136 }
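
// Worked example of the rules above (level numbers are illustrative): a
// first extraction over levels [0, 2) takes no parent iterator; a follow-up
// extraction over levels [2, 3) must be given an iterator over that [0, 2)
// space, so its lower bound (2) matches the parent iterator's level upper
// bound.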
2137 
2138 /// Materialize a single constant operation from a given attribute value with
2139 /// the desired resultant type.
2140 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2141                                                     Attribute value, Type type,
2142                                                     Location loc) {
2143   if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2144     return op;
2145   return nullptr;
2146 }
2147 
2148 namespace {
2149 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
2150   using OpAsmDialectInterface::OpAsmDialectInterface;
2151 
2152   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
2153     if (isa<SparseTensorEncodingAttr>(attr)) {
2154       os << "sparse";
2155       return AliasResult::OverridableAlias;
2156     }
2157     return AliasResult::NoAlias;
2158   }
2159 };
2160 } // namespace
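
// With this interface, printed modules abbreviate encodings with a "sparse"
// alias, e.g. (attribute contents are illustrative):
//
//   #sparse = #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed)
//   }>
//   ... tensor<8x8xf64, #sparse> ...
//
// Distinct encodings in the same module receive suffixed aliases such as
// #sparse1.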
2161 
2162 void SparseTensorDialect::initialize() {
2163   addInterface<SparseTensorAsmDialectInterface>();
2164   addAttributes<
2165 #define GET_ATTRDEF_LIST
2166 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2167       >();
2168   addTypes<
2169 #define GET_TYPEDEF_LIST
2170 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2171       >();
2172   addOperations<
2173 #define GET_OP_LIST
2174 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2175       >();
2176   declarePromisedInterfaces<
2177       bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2178       NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2179       ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2180 }
2181 
2182 #define GET_OP_CLASSES
2183 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2184 
2185 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
2186