1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "Detail/DimLvlMapParser.h"
12 
13 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
17 
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
28 
29 #define GET_ATTRDEF_CLASSES
30 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
31 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
32 
33 // Forward declarations; the following custom print/parsing methods are
34 // referenced by the generated code for SparseTensorTypes.td.
35 static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
36                                          mlir::sparse_tensor::Level &,
37                                          mlir::sparse_tensor::Level &);
38 static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
39                             mlir::sparse_tensor::Level);
40 
41 #define GET_TYPEDEF_CLASSES
42 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
43 
44 using namespace mlir;
45 using namespace mlir::sparse_tensor;
46 
47 // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
48 // well.
49 namespace mlir::sparse_tensor {
50 llvm::hash_code hash_value(LevelType lt) {
51   return llvm::hash_value(static_cast<uint64_t>(lt));
52 }
53 } // namespace mlir::sparse_tensor
54 
55 //===----------------------------------------------------------------------===//
56 // Local Convenience Methods.
57 //===----------------------------------------------------------------------===//
58 
59 static constexpr bool acceptBitWidth(unsigned bitWidth) {
60   switch (bitWidth) {
61   case 0:
62   case 8:
63   case 16:
64   case 32:
65   case 64:
66     return true;
67   default:
68     return false;
69   }
70 }
71 
72 static SmallVector<Size>
73 getSparseFieldShape(const SparseTensorEncodingAttr enc,
74                     std::optional<ArrayRef<int64_t>> dimShape) {
75   assert(enc);
76   // With only the encoding we cannot determine the static shape of leading
77   // batch levels; we therefore return a dynamically shaped memref instead.
78   SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
79   if (dimShape.has_value()) {
80     // If the actual tensor shape is provided, we can then refine the leading
81     // batch dimension.
82     SmallVector<int64_t> lvlShape =
83         enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
84     memrefShape.assign(lvlShape.begin(),
85                        lvlShape.begin() + enc.getBatchLvlRank());
86   }
87   // Another dynamic dimension to store the sparse level.
88   memrefShape.push_back(ShapedType::kDynamic);
89   return memrefShape;
90 }
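// For example, with a 2-D encoding that has no batch levels, the field shape
// computed above is simply {?}; with one leading batch level and a known
// dimension shape of (8, ?), it is refined to {8, ?}.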
91 
92 //===----------------------------------------------------------------------===//
93 // SparseTensorDialect StorageLayout.
94 //===----------------------------------------------------------------------===//
95 
96 static constexpr Level kInvalidLevel = -1u;
97 static constexpr Level kInvalidFieldIndex = -1u;
98 static constexpr FieldIndex kDataFieldStartingIdx = 0;
99 
100 void StorageLayout::foreachField(
101     llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
102                             LevelType)>
103         callback) const {
104   const auto lvlTypes = enc.getLvlTypes();
105   const Level lvlRank = enc.getLvlRank();
106   SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
107   FieldIndex fieldIdx = kDataFieldStartingIdx;
108 
109   ArrayRef cooSegsRef = cooSegs;
110   // Per-level storage.
111   for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
112     const auto lt = lvlTypes[l];
113     if (isWithPosLT(lt)) {
114       if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
115         return;
116     }
117     if (isWithCrdLT(lt)) {
118       if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
119         return;
120     }
121     if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
122       if (!cooSegsRef.front().isSoA) {
123         // AoS COO: all singleton levels are fused into one memref. Skip the
124         // entire COO segment.
125         l = cooSegsRef.front().lvlRange.second;
126       } else {
127         // SoA COO: each singleton level has its own memref.
128         l++;
129       }
130       // Expire handled COO segment.
131       cooSegsRef = cooSegsRef.drop_front();
132     } else {
133       // Non-COO levels.
134       l++;
135     }
136   }
137   // The values array.
138   if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
139                  LevelFormat::Undef)))
140     return;
141   // Put metadata at the end.
142   if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
143                  LevelFormat::Undef)))
144     return;
145 }
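// For example, for a CSR-like encoding with lvlTypes (dense, compressed), the
// callback above is invoked as (0, PosMemRef, 1), (1, CrdMemRef, 1),
// (2, ValMemRef, _), and finally (3, StorageSpec, _); the dense level 0
// contributes no fields of its own.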
146 
147 void sparse_tensor::foreachFieldAndTypeInSparseTensor(
148     SparseTensorType stt,
149     llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
150                             LevelType)>
151         callback) {
152   assert(stt.hasEncoding());
153 
154   SmallVector<int64_t> memrefShape =
155       getSparseFieldShape(stt.getEncoding(), stt.getDimShape());
156 
157   const Type specType = StorageSpecifierType::get(stt.getEncoding());
158   // memref<[batch] x ? x pos>  positions
159   const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
160   // memref<[batch] x ? x crd>  coordinates
161   const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
162   // memref<[batch] x ? x eltType> values
163   const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
164 
165   StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
166                                    callback](FieldIndex fieldIdx,
167                                              SparseTensorFieldKind fieldKind,
168                                              Level lvl, LevelType lt) -> bool {
169     switch (fieldKind) {
170     case SparseTensorFieldKind::StorageSpec:
171       return callback(specType, fieldIdx, fieldKind, lvl, lt);
172     case SparseTensorFieldKind::PosMemRef:
173       return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
174     case SparseTensorFieldKind::CrdMemRef:
175       return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
176     case SparseTensorFieldKind::ValMemRef:
177       return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
178     };
179     llvm_unreachable("unrecognized field kind");
180   });
181 }
182 
183 unsigned StorageLayout::getNumFields() const {
184   unsigned numFields = 0;
185   foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
186                             LevelType) -> bool {
187     numFields++;
188     return true;
189   });
190   return numFields;
191 }
192 
193 unsigned StorageLayout::getNumDataFields() const {
194   unsigned numFields = 0; // one value memref
195   foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
196                             LevelType) -> bool {
197     if (fidx >= kDataFieldStartingIdx)
198       numFields++;
199     return true;
200   });
201   numFields -= 1; // the last field is StorageSpecifier
202   assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
203   return numFields;
204 }
205 
206 std::pair<FieldIndex, unsigned>
207 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
208                                       std::optional<Level> lvl) const {
209   FieldIndex fieldIdx = kInvalidFieldIndex;
210   unsigned stride = 1;
211   if (kind == SparseTensorFieldKind::CrdMemRef) {
212     assert(lvl.has_value());
213     const Level cooStart = SparseTensorType(enc).getAoSCOOStart();
214     const Level lvlRank = enc.getLvlRank();
215     if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
216       lvl = cooStart;
217       stride = lvlRank - cooStart;
218     }
219   }
220   foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
221                                       SparseTensorFieldKind fKind, Level fLvl,
222                                       LevelType lt) -> bool {
223     if ((lvl && fLvl == lvl.value() && kind == fKind) ||
224         (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
225       fieldIdx = fIdx;
226       // Returns false to break the iteration.
227       return false;
228     }
229     return true;
230   });
231   assert(fieldIdx != kInvalidFieldIndex);
232   return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
233 }
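// For example, for an AoS COO encoding with lvlTypes
// (compressed(nonunique), singleton), querying the CrdMemRef of level 1 yields
// the field index of the shared coordinate memref starting at level 0, with
// stride 2.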
234 
235 //===----------------------------------------------------------------------===//
236 // SparseTensorDialect Attribute Methods.
237 //===----------------------------------------------------------------------===//
238 
239 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
240   return isDynamic(v) ? std::nullopt
241                       : std::make_optional(static_cast<uint64_t>(v));
242 }
243 
244 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
245   return getStatic(getOffset());
246 }
247 
248 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
249   return getStatic(getStride());
250 }
251 
252 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
253   return getStatic(getSize());
254 }
255 
256 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
257   return isDynamic(getOffset()) && isDynamic(getStride()) &&
258          isDynamic(getSize());
259 }
260 
261 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
262   return isDynamic(v) ? "?" : std::to_string(v);
263 }
264 
265 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
266   assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
267   os << '(';
268   os << getStaticString(getOffset());
269   os << ", ";
270   os << getStaticString(getSize());
271   os << ", ";
272   os << getStaticString(getStride());
273   os << ')';
274 }
275 
276 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
277   print(printer.getStream());
278 }
279 
280 static ParseResult parseOptionalStaticSlice(int64_t &result,
281                                             AsmParser &parser) {
282   auto parseResult = parser.parseOptionalInteger(result);
283   if (parseResult.has_value()) {
284     if (parseResult.value().succeeded() && result < 0) {
285       parser.emitError(
286           parser.getCurrentLocation(),
287           "expect positive value or ? for slice offset/size/stride");
288       return failure();
289     }
290     return parseResult.value();
291   }
292 
293   // Otherwise, expect a '?', which represents a dynamic slice.
294   result = SparseTensorDimSliceAttr::kDynamic;
295   return parser.parseQuestion();
296 }
297 
298 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
299   int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
300 
301   if (failed(parser.parseLParen()) ||
302       failed(parseOptionalStaticSlice(offset, parser)) ||
303       failed(parser.parseComma()) ||
304       failed(parseOptionalStaticSlice(size, parser)) ||
305       failed(parser.parseComma()) ||
306       failed(parseOptionalStaticSlice(stride, parser)) ||
307       failed(parser.parseRParen()))
308     return {};
309 
310   return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
311                                                      offset, size, stride);
312 }
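// For example, a dimension slice is printed and parsed as the triple
// "(offset, size, stride)", e.g. "(1, 4, 1)" for a fully static slice or
// "(?, ?, ?)" when all three values are dynamic.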
313 
314 LogicalResult
315 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
316                                  int64_t offset, int64_t size, int64_t stride) {
317   if (!isDynamic(offset) && offset < 0)
318     return emitError() << "expect non-negative value or ? for slice offset";
319   if (!isDynamic(size) && size <= 0)
320     return emitError() << "expect positive value or ? for slice size";
321   if (!isDynamic(stride) && stride <= 0)
322     return emitError() << "expect positive value or ? for slice stride";
323   return success();
324 }
325 
326 SparseTensorEncodingAttr
327 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
328   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
329   return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
330                                        AffineMap(), getPosWidth(),
331                                        getCrdWidth());
332 }
333 
334 SparseTensorEncodingAttr
335 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
336   return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
337 }
338 
339 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
340   return withDimToLvl(AffineMap());
341 }
342 
343 SparseTensorEncodingAttr
344 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
345                                         unsigned crdWidth) const {
346   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
347   return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
348                                        getDimToLvl(), getLvlToDim(), posWidth,
349                                        crdWidth);
350 }
351 
352 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
353   return withBitWidths(0, 0);
354 }
355 
356 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
357     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
358   return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
359                                        getDimToLvl(), getLvlToDim(),
360                                        getPosWidth(), getCrdWidth(), dimSlices);
361 }
362 
363 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
364   return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
365 }
366 
367 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
368   ArrayRef<LevelType> lvlTypes = getLvlTypes();
369   auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
370   return std::distance(lastBatch, lvlTypes.rend());
371 }
372 
373 bool SparseTensorEncodingAttr::isAllDense() const {
374   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
375 }
376 
377 bool SparseTensorEncodingAttr::isAllOrdered() const {
378   return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
379 }
380 
381 Type SparseTensorEncodingAttr::getCrdElemType() const {
382   if (!getImpl())
383     return nullptr;
384   if (getCrdWidth())
385     return IntegerType::get(getContext(), getCrdWidth());
386   return IndexType::get(getContext());
387 }
388 
389 Type SparseTensorEncodingAttr::getPosElemType() const {
390   if (!getImpl())
391     return nullptr;
392   if (getPosWidth())
393     return IntegerType::get(getContext(), getPosWidth());
394   return IndexType::get(getContext());
395 }
396 
397 MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
398     std::optional<ArrayRef<int64_t>> dimShape) const {
399   SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
400   return MemRefType::get(shape, getCrdElemType());
401 }
402 
403 MemRefType SparseTensorEncodingAttr::getPosMemRefType(
404     std::optional<ArrayRef<int64_t>> dimShape) const {
405   SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
406   return MemRefType::get(shape, getPosElemType());
407 }
408 
409 bool SparseTensorEncodingAttr::isIdentity() const {
410   return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
411 }
412 
413 bool SparseTensorEncodingAttr::isPermutation() const {
414   return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
415 }
416 
417 Dimension SparseTensorEncodingAttr::getDimRank() const {
418   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
419   const auto dimToLvl = getDimToLvl();
420   return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
421 }
422 
423 Level SparseTensorEncodingAttr::getLvlRank() const {
424   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
425   return getLvlTypes().size();
426 }
427 
428 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
429   if (!getImpl())
430     return LevelFormat::Batch;
431   assert(l < getLvlRank() && "Level is out of bounds");
432   return getLvlTypes()[l];
433 }
434 
435 bool SparseTensorEncodingAttr::isSlice() const {
436   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
437   return !getDimSlices().empty();
438 }
439 
440 SparseTensorDimSliceAttr
441 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
442   assert(isSlice() && "Is not a slice");
443   const auto dimSlices = getDimSlices();
444   assert(dim < dimSlices.size() && "Dimension is out of bounds");
445   return dimSlices[dim];
446 }
447 
448 std::optional<uint64_t>
449 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
450   return getDimSlice(dim).getStaticOffset();
451 }
452 
453 std::optional<uint64_t>
454 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
455   return getDimSlice(dim).getStaticStride();
456 }
457 
458 std::optional<uint64_t>
459 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
460   return getStaticDimSliceOffset(toDim(*this, lvl));
461 }
462 
463 std::optional<uint64_t>
464 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
465   return getStaticDimSliceStride(toDim(*this, lvl));
466 }
467 
468 SmallVector<int64_t>
469 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
470                                          CrdTransDirectionKind dir) const {
471   if (isIdentity())
472     return SmallVector<int64_t>(srcShape);
473 
474   SmallVector<int64_t> ret;
475   unsigned rank =
476       dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
477   ret.reserve(rank);
478 
479   if (isPermutation()) {
480     for (unsigned r = 0; r < rank; r++) {
481       unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
482                                                              : toLvl(*this, r);
483       ret.push_back(srcShape[trans]);
484     }
485     return ret;
486   }
487 
488   // Handle non-permutation maps.
489   AffineMap transMap =
490       dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
491 
492   SmallVector<AffineExpr> dimRep;
493   dimRep.reserve(srcShape.size());
494   for (int64_t sz : srcShape) {
495     if (!ShapedType::isDynamic(sz)) {
496       // Push back the max coordinate for the given dimension/level size.
497       dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
498     } else {
499       // A dynamic size; use an AffineDimExpr to symbolize the value.
500       dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
501     }
502   }
503 
504   for (AffineExpr exp : transMap.getResults()) {
505     // Do constant propagation on the affine map.
506     AffineExpr evalExp =
507         simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
508     // Use the llvm namespace explicitly here to avoid ambiguity.
509     if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
510       ret.push_back(c.getValue() + 1);
511     } else {
512       if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
513           mod && mod.getKind() == AffineExprKind::Mod) {
514         // We can still infer a static bound for expressions of the form
515         // "d % constant", since d % constant lies in [0, constant).
516         if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
517           ret.push_back(bound.getValue());
518           continue;
519         }
520       }
521       ret.push_back(ShapedType::kDynamic);
522     }
523   }
524   assert(ret.size() == rank);
525   return ret;
526 }
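// For example, for the block-sparse mapping
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// the dimension shape (10, 9) translates (dim2lvl) to the level shape
// (5, 3, 2, 3); for dynamic dimension sizes, the "mod" levels still receive the
// static bound given by the constant, while the "floordiv" levels stay dynamic.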
527 
528 ValueRange
529 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
530                                         ValueRange crds,
531                                         CrdTransDirectionKind dir) const {
532   if (!getImpl())
533     return crds;
534 
535   SmallVector<Type> retType(
536       dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
537       builder.getIndexType());
538   auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
539   return transOp.getOutCrds();
540 }
541 
542 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
543   // Open "<{" part.
544   if (failed(parser.parseLess()))
545     return {};
546   if (failed(parser.parseLBrace()))
547     return {};
548 
549   // Process the data from the parsed dictionary value into struct-like data.
550   SmallVector<LevelType> lvlTypes;
551   SmallVector<SparseTensorDimSliceAttr> dimSlices;
552   AffineMap dimToLvl = {};
553   AffineMap lvlToDim = {};
554   unsigned posWidth = 0;
555   unsigned crdWidth = 0;
556   StringRef attrName;
557   SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"};
558   while (succeeded(parser.parseOptionalKeyword(&attrName))) {
559     // Detect admissible keyword.
560     auto *it = find(keys, attrName);
561     if (it == keys.end()) {
562       parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
563       return {};
564     }
565     unsigned keyWordIndex = it - keys.begin();
566     // Consume the `=` after keys
567     if (failed(parser.parseEqual()))
568       return {};
569     // Dispatch on keyword.
570     switch (keyWordIndex) {
571     case 0: { // map
572       ir_detail::DimLvlMapParser cParser(parser);
573       auto res = cParser.parseDimLvlMap();
574       if (failed(res))
575         return {};
576       const auto &dlm = *res;
577 
578       const Level lvlRank = dlm.getLvlRank();
579       for (Level lvl = 0; lvl < lvlRank; lvl++)
580         lvlTypes.push_back(dlm.getLvlType(lvl));
581 
582       const Dimension dimRank = dlm.getDimRank();
583       for (Dimension dim = 0; dim < dimRank; dim++)
584         dimSlices.push_back(dlm.getDimSlice(dim));
585       // NOTE: the old syntax requires an all-or-nothing approach to
586       // `dimSlices`; therefore, if any slice actually exists then we need
587       // to convert null-DSA into default/nop DSA.
588       const auto isDefined = [](SparseTensorDimSliceAttr slice) {
589         return static_cast<bool>(slice.getImpl());
590       };
591       if (llvm::any_of(dimSlices, isDefined)) {
592         const auto defaultSlice =
593             SparseTensorDimSliceAttr::get(parser.getContext());
594         for (Dimension dim = 0; dim < dimRank; dim++)
595           if (!isDefined(dimSlices[dim]))
596             dimSlices[dim] = defaultSlice;
597       } else {
598         dimSlices.clear();
599       }
600 
601       dimToLvl = dlm.getDimToLvlMap(parser.getContext());
602       lvlToDim = dlm.getLvlToDimMap(parser.getContext());
603       break;
604     }
605     case 1: { // posWidth
606       Attribute attr;
607       if (failed(parser.parseAttribute(attr)))
608         return {};
609       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
610       if (!intAttr) {
611         parser.emitError(parser.getNameLoc(),
612                          "expected an integral position bitwidth");
613         return {};
614       }
615       posWidth = intAttr.getInt();
616       break;
617     }
618     case 2: { // crdWidth
619       Attribute attr;
620       if (failed(parser.parseAttribute(attr)))
621         return {};
622       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
623       if (!intAttr) {
624         parser.emitError(parser.getNameLoc(),
625                          "expected an integral index bitwidth");
626         return {};
627       }
628       crdWidth = intAttr.getInt();
629       break;
630     }
631     } // switch
632     // Only the last item can omit the comma.
633     if (parser.parseOptionalComma().failed())
634       break;
635   }
636 
637   // Close "}>" part.
638   if (failed(parser.parseRBrace()))
639     return {};
640   if (failed(parser.parseGreater()))
641     return {};
642 
643   // Construct struct-like storage for attribute.
644   if (!lvlToDim || lvlToDim.isEmpty()) {
645     lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
646   }
647   return parser.getChecked<SparseTensorEncodingAttr>(
648       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
649       dimSlices);
650 }
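// For example, a CSR encoding with explicit bitwidths parses from:
//
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32
//   }>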
651 
652 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
653   auto map = static_cast<AffineMap>(getDimToLvl());
654   // Empty affine map indicates identity map
655   if (!map)
656     map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
657   printer << "<{ map = ";
658   printSymbols(map, printer);
659   printer << '(';
660   printDimensions(map, printer, getDimSlices());
661   printer << ") -> (";
662   printLevels(map, printer, getLvlTypes());
663   printer << ')';
664   // Print remaining members only for non-default values.
665   if (getPosWidth())
666     printer << ", posWidth = " << getPosWidth();
667   if (getCrdWidth())
668     printer << ", crdWidth = " << getCrdWidth();
669   printer << " }>";
670 }
671 
672 void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
673                                             AsmPrinter &printer) const {
674   if (map.getNumSymbols() == 0)
675     return;
676   printer << '[';
677   for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
678     printer << 's' << i << ", ";
679   if (map.getNumSymbols() >= 1)
680     printer << 's' << map.getNumSymbols() - 1;
681   printer << ']';
682 }
683 
684 void SparseTensorEncodingAttr::printDimensions(
685     AffineMap &map, AsmPrinter &printer,
686     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
687   if (!dimSlices.empty()) {
688     for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
689       printer << 'd' << i << " : " << dimSlices[i] << ", ";
690     if (map.getNumDims() >= 1) {
691       printer << 'd' << map.getNumDims() - 1 << " : "
692               << dimSlices[map.getNumDims() - 1];
693     }
694   } else {
695     for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
696       printer << 'd' << i << ", ";
697     if (map.getNumDims() >= 1)
698       printer << 'd' << map.getNumDims() - 1;
699   }
700 }
701 
702 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
703                                            ArrayRef<LevelType> lvlTypes) const {
704   for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
705     map.getResult(i).print(printer.getStream());
706     printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
707   }
708   if (map.getNumResults() >= 1) {
709     auto lastIndex = map.getNumResults() - 1;
710     map.getResult(lastIndex).print(printer.getStream());
711     printer << " : " << toMLIRString(lvlTypes[lastIndex]);
712   }
713 }
714 
715 LogicalResult SparseTensorEncodingAttr::verify(
716     function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
717     AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
718     unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
719   if (!acceptBitWidth(posWidth))
720     return emitError() << "unexpected position bitwidth: " << posWidth;
721   if (!acceptBitWidth(crdWidth))
722     return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
723   if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
724       it != std::end(lvlTypes)) {
725     if (it == lvlTypes.begin() ||
726         (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1))))
727       return emitError() << "expected compressed or loose_compressed level "
728                             "before singleton level";
729     if (!std::all_of(it, lvlTypes.end(),
730                      [](LevelType i) { return isSingletonLT(i); }))
731       return emitError() << "expected all singleton lvlTypes "
732                             "following a singleton level";
733     // We can potentially support mixed SoA/AoS singleton levels.
734     if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) {
735           return it->isa<LevelPropNonDefault::SoA>() ==
736                  i.isa<LevelPropNonDefault::SoA>();
737         })) {
738       return emitError() << "expected all singleton lvlTypes stored in the "
739                             "same memory layout (SoA vs AoS).";
740     }
741   }
742 
743   auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
744   if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
745     return emitError() << "Batch lvlType can only be leading levels.";
746 
747   // SoA property can only be applied on singleton level.
748   auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
749     return lt.isa<LevelPropNonDefault::SoA>();
750   });
751   if (llvm::any_of(soaLvls, [](LevelType lt) {
752         return !lt.isa<LevelFormat::Singleton>();
753       })) {
754     return emitError() << "SoA is only applicable to singleton lvlTypes.";
755   }
756 
757   // TODO: audit formats that actually are supported by backend.
758   if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
759       it != std::end(lvlTypes)) {
760     if (it != lvlTypes.end() - 1)
761       return emitError() << "expected n_out_of_m to be the last level type";
762     if (!std::all_of(lvlTypes.begin(), it,
763                      [](LevelType i) { return isDenseLT(i); }))
764       return emitError() << "expected all dense lvlTypes "
765                             "before a n_out_of_m level";
766     if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
767       if (!isBlockSparsity(dimToLvl)) {
768         return emitError()
769                << "expected 1xm block structure for n_out_of_m level";
770       }
771       auto sizes = getBlockSize(dimToLvl);
772       unsigned coefficient = 0;
773       for (const auto &elem : sizes) {
774         if (elem != 0) {
775           if (elem != coefficient && coefficient != 0) {
776             return emitError() << "expected only one blocked level "
777                                   "with the same coefficients";
778           }
779           coefficient = elem;
780         }
781       }
782       if (coefficient != getM(*it)) {
783         return emitError() << "expected coefficients of Affine expressions "
784                               "to be equal to m of n_out_of_m level";
785       }
786     }
787   }
788   // Before we can check that the level-rank is consistent/coherent
789   // across all fields, we need to define it.  The source-of-truth for
790   // the `getLvlRank` method is the length of the level-types array,
791   // since it must always be provided and have full rank; therefore we
792   // use that same source-of-truth here.
793   const Level lvlRank = lvlTypes.size();
794   if (lvlRank == 0)
795     return emitError() << "expected a non-empty array for lvlTypes";
796   // We save `dimRank` here because we'll also need it to verify `dimSlices`.
797   const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
798   if (dimToLvl) {
799     if (dimToLvl.getNumResults() != lvlRank)
800       return emitError()
801              << "level-rank mismatch between dimToLvl and lvlTypes: "
802              << dimToLvl.getNumResults() << " != " << lvlRank;
803     auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
804     // Symbols can't be inferred but are acceptable.
805     if (!inferRes && dimToLvl.getNumSymbols() == 0)
806       return emitError() << "failed to infer lvlToDim from dimToLvl";
807     if (lvlToDim && (inferRes != lvlToDim))
808       return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
809     if (dimRank > lvlRank)
810       return emitError() << "unexpected dimToLvl mapping from " << dimRank
811                          << " to " << lvlRank;
812   }
813   if (!dimSlices.empty()) {
814     if (dimSlices.size() != dimRank)
815       return emitError()
816              << "dimension-rank mismatch between dimSlices and dimToLvl: "
817              << dimSlices.size() << " != " << dimRank;
818     // Compiler support for `dimSlices` currently requires that the two
819     // ranks agree.  (However, it does allow `dimToLvl` to be a permutation.)
820     if (dimRank != lvlRank)
821       return emitError()
822              << "dimSlices expected dimension-rank to match level-rank: "
823              << dimRank << " != " << lvlRank;
824   }
825   return success();
826 }
827 
828 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
829     ArrayRef<Size> dimShape, Type elementType,
830     function_ref<InFlightDiagnostic()> emitError) const {
831   // Check structural integrity.  In particular, this ensures that the
832   // level-rank is coherent across all the fields.
833   if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
834                     getPosWidth(), getCrdWidth(), getDimSlices())))
835     return failure();
836   // Check integrity with tensor type specifics.  In particular, we
837   // need only check that the dimension-rank of the tensor agrees with
838   // the dimension-rank of the encoding.
839   const Dimension dimRank = dimShape.size();
840   if (dimRank == 0)
841     return emitError() << "expected non-scalar sparse tensor";
842   if (getDimRank() != dimRank)
843     return emitError()
844            << "dimension-rank mismatch between encoding and tensor shape: "
845            << getDimRank() << " != " << dimRank;
846   return success();
847 }
848 
849 //===----------------------------------------------------------------------===//
850 // SparseTensorType Methods.
851 //===----------------------------------------------------------------------===//
852 
853 bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
854                                                       bool isUnique) const {
855   if (!hasEncoding())
856     return false;
857   if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
858     return false;
859   for (Level l = startLvl + 1; l < lvlRank; ++l)
860     if (!isSingletonLvl(l))
861       return false;
862   // If isUnique is true, then make sure that the last level is unique,
863   // that is, when lvlRank == 1, the only compressed level is unique,
864   // and when lvlRank > 1, the last singleton is unique.
865   return !isUnique || isUniqueLvl(lvlRank - 1);
866 }
867 
868 Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const {
869   SmallVector<COOSegment> coo = getCOOSegments();
870   assert(coo.size() == 1 || coo.empty());
871   if (!coo.empty() && coo.front().isAoS()) {
872     return coo.front().lvlRange.first;
873   }
874   return lvlRank;
875 }
876 
877 SmallVector<COOSegment>
878 mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
879   SmallVector<COOSegment> ret;
880   if (!hasEncoding() || lvlRank <= 1)
881     return ret;
882 
883   ArrayRef<LevelType> lts = getLvlTypes();
884   Level l = 0;
885   while (l < lvlRank) {
886     auto lt = lts[l];
887     if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
888       auto cur = lts.begin() + l;
889       auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
890         return !lt.isa<LevelFormat::Singleton>();
891       });
892       unsigned cooLen = std::distance(cur, end);
893       if (cooLen > 1) {
894         // To support mixed SoA/AoS COO we would need to break the segment
895         // when the storage scheme changes; for now we simply assume that all
896         // consecutive singleton levels share the same storage format, as
897         // enforced by the STEA verifier.
898         ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
899                                  lts[l + 1].isa<LevelPropNonDefault::SoA>()});
900       }
901       l += cooLen;
902     } else {
903       l++;
904     }
905   }
906   return ret;
907 }
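// For example, for lvlTypes (dense, compressed(nonunique), singleton, singleton)
// this returns a single AoS COO segment covering the level range [1, 4).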
908 
909 RankedTensorType
910 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
911   SmallVector<LevelType> lvlTypes;
912   lvlTypes.reserve(lvlRank);
913   // A non-unique compressed level at the beginning (unless this is
914   // also the last level, in which case it is unique).
915   lvlTypes.push_back(
916       *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
917   if (lvlRank > 1) {
918     // Followed by n-2 non-unique singleton levels.
919     std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
920                 *buildLevelType(LevelFormat::Singleton, ordered, false));
921     // End with a unique singleton level.
922     lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
923   }
924   auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
925                                            getDimToLvl(), getLvlToDim(),
926                                            getPosWidth(), getCrdWidth());
927   return RankedTensorType::get(getDimShape(), getElementType(), enc);
928 }
929 
930 //===----------------------------------------------------------------------===//
931 // Convenience Methods.
932 //===----------------------------------------------------------------------===//
933 
934 SparseTensorEncodingAttr
935 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
936   if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
937     return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
938   if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
939     return mdtp.getEncoding();
940   return nullptr;
941 }
942 
943 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
944                                              MLIRContext *context) {
945   auto map = static_cast<AffineMap>(dimToLvl);
946   AffineMap lvlToDim;
947   // Return an empty lvlToDim when inference is not successful.
948   if (!map || map.getNumSymbols() != 0) {
949     lvlToDim = AffineMap();
950   } else if (map.isPermutation()) {
951     lvlToDim = inversePermutation(map);
952   } else if (isBlockSparsity(map)) {
953     lvlToDim = inverseBlockSparsity(map, context);
954   }
955   return lvlToDim;
956 }
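// For example, the permutation (d0, d1) -> (d1, d0) is its own inverse, so the
// inferred lvlToDim map is identical to it, whereas a dimToLvl map that uses
// symbols yields an empty lvlToDim map.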
957 
958 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
959                                                     MLIRContext *context) {
960   SmallVector<AffineExpr> lvlExprs;
961   auto numLvls = dimToLvl.getNumResults();
962   lvlExprs.reserve(numLvls);
963   // lvlExprComponents stores information of the floordiv and mod operations
964   // applied to the same dimension, so as to build the lvlToDim map.
965   std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
966   for (unsigned i = 0, n = numLvls; i < n; i++) {
967     auto result = dimToLvl.getResult(i);
968     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
969       if (result.getKind() == AffineExprKind::FloorDiv) {
970         // Position of the dimension in dimToLvl.
971         auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
972         assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
973                "expected only one floordiv for each dimension");
974         SmallVector<AffineExpr, 3> components;
975         // Level variable for floordiv.
976         components.push_back(getAffineDimExpr(i, context));
977         // Multiplier.
978         components.push_back(binOp.getRHS());
979         // Map key is the position of the dimension.
980         lvlExprComponents[pos] = components;
981       } else if (result.getKind() == AffineExprKind::Mod) {
982         auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
983         assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
984                "expected floordiv before mod");
985         // Add level variable for mod to the same vector
986         // of the corresponding floordiv.
987         lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
988       } else {
989         assert(false && "expected floordiv or mod");
990       }
991     } else {
992       lvlExprs.push_back(getAffineDimExpr(i, context));
993     }
994   }
995   // Build lvlExprs from lvlExprComponents.
996   // For example, for il = i floordiv 2 and ii = i mod 2, the components
997   // would be [il, 2, ii]. It could be used to build the AffineExpr
998   // i = il * 2 + ii in lvlToDim.
999   for (auto &components : lvlExprComponents) {
1000     assert(components.second.size() == 3 &&
1001            "expected 3 components to build lvlExprs");
1002     auto mulOp = getAffineBinaryOpExpr(
1003         AffineExprKind::Mul, components.second[0], components.second[1]);
1004     auto addOp =
1005         getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
1006     lvlExprs.push_back(addOp);
1007   }
1008   return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
1009 }
1010 
1011 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
1012   assert(isBlockSparsity(dimToLvl) &&
1013          "expected dimToLvl to be block sparsity for calling getBlockSize");
1014   SmallVector<unsigned> blockSize;
1015   for (auto result : dimToLvl.getResults()) {
1016     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1017       if (result.getKind() == AffineExprKind::Mod) {
1018         blockSize.push_back(
1019             dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1020       }
1021     } else {
1022       blockSize.push_back(0);
1023     }
1024   }
1025   return blockSize;
1026 }
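// For example, for the mapping (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4),
// getBlockSize returns [0, 4]: 0 for the unblocked dimension d0 and 4 for the
// blocked dimension d1.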
1027 
1028 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
1029   if (!dimToLvl)
1030     return false;
1031   std::map<unsigned, int64_t> coeffientMap;
1032   bool hasBlock = false;
1033   for (auto result : dimToLvl.getResults()) {
1034     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1035       // Check for "dim op const".
1036       auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1037       auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1038       if (!dimOp || !conOp || conOp.getValue() <= 0)
1039         return false;
1040       // Inspect "dim / const" or "dim % const".
1041       auto pos = dimOp.getPosition();
1042       if (binOp.getKind() == AffineExprKind::FloorDiv) {
1043         // Expect only one floordiv for each dimension.
1044         if (coeffientMap.find(pos) != coeffientMap.end())
1045           return false;
1046         // Record coefficient of the floordiv.
1047         coeffientMap[pos] = conOp.getValue();
1048       } else if (binOp.getKind() == AffineExprKind::Mod) {
1049         // Expect floordiv before mod.
1050         if (coeffientMap.find(pos) == coeffientMap.end())
1051           return false;
1052         // Expect mod to have the same coefficient as floordiv.
1053         if (conOp.getValue() != coeffientMap[pos])
1054           return false;
1055         hasBlock = true;
1056       } else {
1057         return false;
1058       }
1059     } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
1060       auto pos = dimOp.getPosition();
1061       // Expect dim to be unset.
1062       if (coeffientMap.find(pos) != coeffientMap.end())
1063         return false;
1064       coeffientMap[pos] = 0;
1065     } else {
1066       return false;
1067     }
1068   }
1069   return hasBlock;
1070 }
1071 
1072 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
1073   auto hasNonIdentityMap = [](Value v) {
1074     auto stt = tryGetSparseTensorType(v);
1075     return stt && !stt->isIdentity();
1076   };
1077 
1078   return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
1079          llvm::any_of(op->getResults(), hasNonIdentityMap);
1080 }
1081 
1082 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1083   if (enc) {
1084     assert(enc.isPermutation() && "Non permutation map not supported");
1085     if (const auto dimToLvl = enc.getDimToLvl())
1086       return dimToLvl.getDimPosition(l);
1087   }
1088   return l;
1089 }
1090 
1091 Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1092   if (enc) {
1093     assert(enc.isPermutation() && "Non permutation map not supported");
1094     if (const auto lvlToDim = enc.getLvlToDim())
1095       return lvlToDim.getDimPosition(d);
1096   }
1097   return d;
1098 }
1099 
1100 /// We normalize the sparse tensor encoding attribute by always using an
1101 /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
1102 /// as other variants) lead to the same storage specifier type, and by stripping
1103 /// irrelevant fields that do not alter the sparse tensor memory layout.
1104 static SparseTensorEncodingAttr
1105 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
1106   SmallVector<LevelType> lts;
1107   for (auto lt : enc.getLvlTypes())
1108     lts.push_back(lt.stripStorageIrrelevantProperties());
1109 
1110   return SparseTensorEncodingAttr::get(
1111       enc.getContext(), lts,
1112       AffineMap(), // dimToLvl (irrelevant to storage specifier)
1113       AffineMap(), // lvlToDim (irrelevant to storage specifier)
1114       // Always use `index` for memSize and lvlSize instead of reusing
1115       // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
1116       // value for different bitwidths; it also avoids casting between index and
1117       // integer (as returned by DimOp).
1118       0, 0, enc.getDimSlices());
1119 }
1120 
1121 StorageSpecifierType
1122 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1123   return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
1124 }
1125 
1126 //===----------------------------------------------------------------------===//
1127 // SparseTensorDialect Operations.
1128 //===----------------------------------------------------------------------===//
1129 
1130 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1131   return success(lvl < getSparseTensorType(tensor).getLvlRank());
1132 }
1133 
1134 static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1135   const Type etp = getMemRefType(mem).getElementType();
1136   return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1137 }
1138 
1139 static LogicalResult verifySparsifierGetterSetter(
1140     StorageSpecifierKind mdKind, std::optional<Level> lvl,
1141     TypedValue<StorageSpecifierType> md, Operation *op) {
1142   if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1143     return op->emitError(
1144         "redundant level argument for querying value memory size");
1145   }
1146 
1147   const auto enc = md.getType().getEncoding();
1148   const Level lvlRank = enc.getLvlRank();
1149 
1150   if (mdKind == StorageSpecifierKind::DimOffset ||
1151       mdKind == StorageSpecifierKind::DimStride)
1152     if (!enc.isSlice())
1153       return op->emitError("requested slice data on non-slice tensor");
1154 
1155   if (mdKind != StorageSpecifierKind::ValMemSize) {
1156     if (!lvl)
1157       return op->emitError("missing level argument");
1158 
1159     const Level l = lvl.value();
1160     if (l >= lvlRank)
1161       return op->emitError("requested level is out of bounds");
1162 
1163     if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1164       return op->emitError(
1165           "requested position memory size on a singleton level");
1166   }
1167   return success();
1168 }
1169 
1170 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
1171   switch (kind) {
1172   case SparseTensorFieldKind::CrdMemRef:
1173     return stt.getCrdType();
1174   case SparseTensorFieldKind::PosMemRef:
1175     return stt.getPosType();
1176   case SparseTensorFieldKind::ValMemRef:
1177     return stt.getElementType();
1178   case SparseTensorFieldKind::StorageSpec:
1179     return nullptr;
1180   }
1181   llvm_unreachable("Unrecognizable FieldKind");
1182 }
1183 
1184 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1185                                       SparseTensorType stt,
1186                                       RankedTensorType valTp,
1187                                       TypeRange lvlTps) {
1188   if (requiresStaticShape && !stt.hasStaticDimShape())
1189     return op->emitError("the sparse-tensor must have static shape");
1190   if (!stt.hasEncoding())
1191     return op->emitError("the sparse-tensor must have an encoding attribute");
1192 
1193   // Verifies the trailing COO.
1194   Level cooStartLvl = stt.getAoSCOOStart();
1195   if (cooStartLvl < stt.getLvlRank()) {
1196     // We only support trailing COO for now; it must be the last input.
1197     auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1198     // The coordinates should be in the shape of <? x rank>.
1199     unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1200     if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1201       return op->emitError("input/output trailing COO level-ranks don't match");
1202     }
1203   }
1204 
1205   // Verifies that all types match.
1206   StorageLayout layout(stt.getEncoding());
1207   if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
1208     return op->emitError("inconsistent number of fields between input/output");
1209 
1210   unsigned idx = 0;
1211   bool misMatch = false;
1212   layout.foreachField([&idx, &misMatch, stt, valTp,
1213                        lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
1214                                Level lvl, LevelType lt) -> bool {
1215     if (fKind == SparseTensorFieldKind::StorageSpec)
1216       return true;
1217 
1218     Type inputTp = nullptr;
1219     if (fKind == SparseTensorFieldKind::ValMemRef) {
1220       inputTp = valTp;
1221     } else {
1222       assert(fid == idx && stt.getLvlType(lvl) == lt);
1223       inputTp = lvlTps[idx++];
1224     }
1225     // The input element type and expected element type should match.
1226     Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1227     Type expElemTp = getFieldElemType(stt, fKind);
1228     if (inpElemTp != expElemTp) {
1229       misMatch = true;
1230       return false; // to terminate the iteration
1231     }
1232     return true;
1233   });
1234 
1235   if (misMatch)
1236     return op->emitError("input/output element-types don't match");
1237   return success();
1238 }
1239 
1240 LogicalResult AssembleOp::verify() {
1241   const auto valuesTp = getRankedTensorType(getValues());
1242   const auto lvlsTp = getLevels().getTypes();
1243   const auto resTp = getSparseTensorType(getResult());
1244   return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1245 }
1246 
1247 LogicalResult DisassembleOp::verify() {
1248   if (getOutValues().getType() != getRetValues().getType())
1249     return emitError("output values and return value type mismatch");
1250 
1251   for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1252     if (ot.getType() != rt.getType())
1253       return emitError("output levels and return levels type mismatch");
1254 
1255   const auto valuesTp = getRankedTensorType(getRetValues());
1256   const auto lvlsTp = getRetLevels().getTypes();
1257   const auto srcTp = getSparseTensorType(getTensor());
1258   return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1259 }
1260 
1261 LogicalResult ConvertOp::verify() {
1262   if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
1263     if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
1264       if (tp1.getRank() != tp2.getRank())
1265         return emitError("unexpected conversion mismatch in rank");
1266       auto dstEnc =
1267           llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1268       if (dstEnc && dstEnc.isSlice())
1269         return emitError("cannot convert to a sparse tensor slice");
1270 
1271       auto shape1 = tp1.getShape();
1272       auto shape2 = tp2.getShape();
1273       // Accept size matches between the source and the destination type
1274       // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1275       // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1276       for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1277         if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1278           return emitError("unexpected conversion mismatch in dimension ") << d;
1279       return success();
1280     }
1281   }
1282   return emitError("unexpected type in convert");
1283 }
1284 
1285 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1286   if (getType() == getSource().getType())
1287     return getSource();
1288   return {};
1289 }
1290 
1291 bool ConvertOp::needsExtraSort() {
1292   SparseTensorType srcStt = getSparseTensorType(getSource());
1293   SparseTensorType dstStt = getSparseTensorType(getDest());
1294 
1295   // We do not need an extra sort when returning unordered sparse tensors or
1296   // dense tensors, since dense tensors support random access.
1297   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1298     return false;
1299 
1300   if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1301       srcStt.hasSameDimToLvl(dstStt)) {
1302     return false;
1303   }
1304 
1305   // Source and dest tensors are ordered in different ways. We only do a direct
1306   // dense-to-sparse conversion when the dense input is defined by a sparse
1307   // constant. Note that we could theoretically always convert directly from
1308   // dense inputs by rotating dense loops, but that leads to bad cache locality
1309   // and hurts performance.
1310   if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1311     if (isa<SparseElementsAttr>(constOp.getValue()))
1312       return false;
1313 
1314   return true;
1315 }
1316 
1317 LogicalResult CrdTranslateOp::verify() {
1318   uint64_t inRank = getEncoder().getLvlRank();
1319   uint64_t outRank = getEncoder().getDimRank();
1320 
1321   if (getDirection() == CrdTransDirectionKind::dim2lvl)
1322     std::swap(inRank, outRank);
1323 
1324   if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1325     return emitError("Coordinate rank mismatch with encoding");
1326 
1327   return success();
1328 }
1329 
1330 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1331                                    SmallVectorImpl<OpFoldResult> &results) {
1332   if (getEncoder().isIdentity()) {
1333     results.assign(getInCrds().begin(), getInCrds().end());
1334     return success();
1335   }
1336   if (getEncoder().isPermutation()) {
1337     AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1338                          ? getEncoder().getDimToLvl()
1339                          : getEncoder().getLvlToDim();
1340     for (AffineExpr exp : perm.getResults())
1341       results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1342     return success();
1343   }
1344 
1345   // Fuse dim2lvl/lvl2dim pairs.
1346   auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1347   bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1348                    return v.getDefiningOp() == def;
1349                  });
1350   if (!sameDef)
1351     return failure();
1352 
1353   bool oppositeDir = def.getDirection() != getDirection();
1354   bool sameOracle =
1355       def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1356   bool sameCount = def.getNumResults() == getInCrds().size();
1357   if (!oppositeDir || !sameOracle || !sameCount)
1358     return failure();
1359 
1360   // The definition produces the coordinates in the same order as the input
1361   // coordinates.
1362   bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1363                                 [](auto valuePair) {
1364                                   auto [lhs, rhs] = valuePair;
1365                                   return lhs == rhs;
1366                                 });
1367 
1368   if (!sameOrder)
1369     return failure();
1370   // l1 = dim2lvl (lvl2dim l0)
1371   // ==> l0
1372   results.append(def.getInCrds().begin(), def.getInCrds().end());
1373   return success();
1374 }
1375 
1376 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1377                   int64_t index) {
1378   Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
1379   return build(builder, state, source, val);
1380 }
1381 
1382 LogicalResult LvlOp::verify() {
1383   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1384     auto stt = getSparseTensorType(getSource());
1385     if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1386       emitError("Level index exceeds the rank of the input sparse tensor");
1387   }
1388   return success();
1389 }
1390 
1391 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1392   return getConstantIntValue(getIndex());
1393 }
1394 
1395 Speculation::Speculatability LvlOp::getSpeculatability() {
1396   auto constantIndex = getConstantLvlIndex();
1397   if (!constantIndex)
1398     return Speculation::NotSpeculatable;
1399 
1400   assert(constantIndex <
1401          cast<RankedTensorType>(getSource().getType()).getRank());
1402   return Speculation::Speculatable;
1403 }
1404 
1405 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1406   auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1407   if (!lvlIndex)
1408     return {};
1409 
1410   Level lvl = lvlIndex.getAPSInt().getZExtValue();
1411   auto stt = getSparseTensorType(getSource());
1412   if (lvl >= stt.getLvlRank()) {
1413     // Follows the same convention used by the tensor.dim operation. Out-of-bound
1414     // indices produce undefined behavior but are still valid IR. Don't choke on
1415     // them.
1416     return {};
1417   }
1418 
1419   // Helper lambda to build an IndexAttr.
1420   auto getIndexAttr = [this](int64_t lvlSz) {
1421     return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1422   };
1423 
1424   SmallVector<Size> lvlShape = stt.getLvlShape();
1425   if (!ShapedType::isDynamic(lvlShape[lvl]))
1426     return getIndexAttr(lvlShape[lvl]);
1427 
1428   return {};
1429 }
1430 
1431 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1432                              SparseTensorEncodingAttr dstEnc, Value source) {
1433   auto srcStt = getSparseTensorType(source);
1434   SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1435   SmallVector<int64_t> dstDimShape =
1436       dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1437   auto dstTp =
1438       RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1439   return build(odsBuilder, odsState, dstTp, source);
1440 }
1441 
1442 LogicalResult ReinterpretMapOp::verify() {
1443   auto srcStt = getSparseTensorType(getSource());
1444   auto dstStt = getSparseTensorType(getDest());
1445   ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1446   ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1447 
1448   if (srcLvlTps.size() != dstLvlTps.size())
1449     return emitError("Level rank mismatch between source/dest tensors");
1450 
1451   for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1452     if (srcLvlTp != dstLvlTp)
1453       return emitError("Level type mismatch between source/dest tensors");
1454 
1455   if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1456       srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1457     return emitError("Crd/Pos width mismatch between source/dest tensors");
1458   }
1459 
1460   if (srcStt.getElementType() != dstStt.getElementType())
1461     return emitError("Element type mismatch between source/dest tensors");
1462 
1463   SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1464   SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1465   for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1466     if (srcLvlSz != dstLvlSz) {
1467       // Should we allow one side to have dynamic sizes, e.g., should <?x?> be
1468       // compatible with <3x4>? For now, we require all the level sizes to be
1469       // *exactly* matched for simplicity.
1470       return emitError("Level size mismatch between source/dest tensors");
1471     }
1472   }
1473 
1474   return success();
1475 }
1476 
1477 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1478   if (getSource().getType() == getDest().getType())
1479     return getSource();
1480 
1481   if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1482     // A -> B, B -> A ==> A
1483     if (def.getSource().getType() == getDest().getType())
1484       return def.getSource();
1485   }
1486   return {};
1487 }
1488 
1489 template <typename ToBufferOp>
1490 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1491                                            OpaqueProperties prop,
1492                                            RegionRange region,
1493                                            SmallVectorImpl<mlir::Type> &ret) {
1494   typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1495   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1496   Type elemTp = nullptr;
1497   bool withStride = false;
1498   if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1499     elemTp = stt.getPosType();
1500   } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1501                        std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1502     elemTp = stt.getCrdType();
1503     if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1504       withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1505   } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1506     elemTp = stt.getElementType();
1507   }
1508 
1509   assert(elemTp && "unhandled operation.");
1510   SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1511   bufShape.push_back(ShapedType::kDynamic);
1512 
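  // For AoS COO storage, all COO levels share one coordinate buffer, so a view
  // of a single level is strided; e.g. (illustrative) such a view may have type
  // memref<?xindex, strided<[?], offset: ?>> rather than a contiguous memref.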
1513   auto layout = withStride ? StridedLayoutAttr::get(
1514                                  stt.getContext(), ShapedType::kDynamic,
1515                                  {ShapedType::kDynamic})
1516                            : StridedLayoutAttr();
1517   ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1518   return success();
1519 }
1520 
1521 LogicalResult ToPositionsOp::verify() {
1522   auto stt = getSparseTensorType(getTensor());
1523   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1524     return emitError("requested level is out of bounds");
1525   if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1526     return emitError("unexpected type for positions");
1527   return success();
1528 }
1529 
1530 LogicalResult
1531 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1532                                 ValueRange ops, DictionaryAttr attr,
1533                                 OpaqueProperties prop, RegionRange region,
1534                                 SmallVectorImpl<mlir::Type> &ret) {
1535   return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1536 }
1537 
1538 LogicalResult ToCoordinatesOp::verify() {
1539   auto stt = getSparseTensorType(getTensor());
1540   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1541     return emitError("requested level is out of bounds");
1542   if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1543     return emitError("unexpected type for coordinates");
1544   return success();
1545 }
1546 
1547 LogicalResult
1548 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1549                                   ValueRange ops, DictionaryAttr attr,
1550                                   OpaqueProperties prop, RegionRange region,
1551                                   SmallVectorImpl<mlir::Type> &ret) {
1552   return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1553 }
1554 
1555 LogicalResult ToCoordinatesBufferOp::verify() {
1556   auto stt = getSparseTensorType(getTensor());
1557   if (stt.getAoSCOOStart() >= stt.getLvlRank())
1558     return emitError("expected sparse tensor with a COO region");
1559   return success();
1560 }
1561 
1562 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1563     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1564     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1565     SmallVectorImpl<mlir::Type> &ret) {
1566   return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1567                                                       ret);
1568 }
1569 
1570 LogicalResult ToValuesOp::verify() {
1571   auto stt = getSparseTensorType(getTensor());
1572   auto mtp = getMemRefType(getResult());
1573   if (stt.getElementType() != mtp.getElementType())
1574     return emitError("unexpected mismatch in element types");
1575   return success();
1576 }
1577 
1578 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1579                                            std::optional<Location> loc,
1580                                            ValueRange ops, DictionaryAttr attr,
1581                                            OpaqueProperties prop,
1582                                            RegionRange region,
1583                                            SmallVectorImpl<mlir::Type> &ret) {
1584   return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1585 }
1586 
1587 LogicalResult ToSliceOffsetOp::verify() {
1588   auto rank = getRankedTensorType(getSlice()).getRank();
1589   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1590     return emitError("requested dimension out of bound");
1591   return success();
1592 }
1593 
1594 LogicalResult ToSliceStrideOp::verify() {
1595   auto rank = getRankedTensorType(getSlice()).getRank();
1596   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1597     return emitError("requested dimension out of bound");
1598   return success();
1599 }
1600 
1601 LogicalResult GetStorageSpecifierOp::verify() {
1602   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1603                                       getSpecifier(), getOperation());
1604 }
1605 
1606 template <typename SpecifierOp>
1607 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1608   return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1609 }
1610 
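// A get can be folded to the value written by the closest matching set on the
// same specifier chain, e.g. (illustrative syntax):
//   %s1 = sparse_tensor.storage_specifier.set %s0 lvl_sz at 0 with %v : ...
//   %x  = sparse_tensor.storage_specifier.get %s1 lvl_sz at 0 : ...  // ==> %v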
1611 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1612   const StorageSpecifierKind kind = getSpecifierKind();
1613   const auto lvl = getLevel();
1614   for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1615     if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1616       return op.getValue();
1617   return {};
1618 }
1619 
1620 LogicalResult SetStorageSpecifierOp::verify() {
1621   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1622                                       getSpecifier(), getOperation());
1623 }
1624 
1625 template <class T>
1626 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1627                                         const char *regionName,
1628                                         TypeRange inputTypes, Type outputType) {
1629   unsigned numArgs = region.getNumArguments();
1630   unsigned expectedNum = inputTypes.size();
1631   if (numArgs != expectedNum)
1632     return op->emitError() << regionName << " region must have exactly "
1633                            << expectedNum << " arguments";
1634 
1635   for (unsigned i = 0; i < numArgs; i++) {
1636     Type typ = region.getArgument(i).getType();
1637     if (typ != inputTypes[i])
1638       return op->emitError() << regionName << " region argument " << (i + 1)
1639                              << " type mismatch";
1640   }
1641   Operation *term = region.front().getTerminator();
1642   YieldOp yield = dyn_cast<YieldOp>(term);
1643   if (!yield)
1644     return op->emitError() << regionName
1645                            << " region must end with sparse_tensor.yield";
1646   if (!yield.hasSingleResult() ||
1647       yield.getSingleResult().getType() != outputType)
1648     return op->emitError() << regionName << " region yield type mismatch";
1649 
1650   return success();
1651 }
1652 
1653 LogicalResult BinaryOp::verify() {
1654   NamedAttrList attrs = (*this)->getAttrs();
1655   Type leftType = getX().getType();
1656   Type rightType = getY().getType();
1657   Type outputType = getOutput().getType();
1658   Region &overlap = getOverlapRegion();
1659   Region &left = getLeftRegion();
1660   Region &right = getRightRegion();
1661 
1662   // Check correct number of block arguments and return type for each
1663   // non-empty region.
1664   if (!overlap.empty()) {
1665     if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1666                                   TypeRange{leftType, rightType}, outputType)))
1667       return failure();
1668   }
1669   if (!left.empty()) {
1670     if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1671                                   outputType)))
1672       return failure();
1673   } else if (getLeftIdentity()) {
1674     if (leftType != outputType)
1675       return emitError("left=identity requires first argument to have the same "
1676                        "type as the output");
1677   }
1678   if (!right.empty()) {
1679     if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1680                                   outputType)))
1681       return failure();
1682   } else if (getRightIdentity()) {
1683     if (rightType != outputType)
1684       return emitError("right=identity requires second argument to have the "
1685                        "same type as the output");
1686   }
1687   return success();
1688 }
1689 
1690 LogicalResult UnaryOp::verify() {
1691   Type inputType = getX().getType();
1692   Type outputType = getOutput().getType();
1693 
1694   // Check correct number of block arguments and return type for each
1695   // non-empty region.
1696   Region &present = getPresentRegion();
1697   if (!present.empty()) {
1698     if (failed(verifyNumBlockArgs(this, present, "present",
1699                                   TypeRange{inputType}, outputType)))
1700       return failure();
1701   }
1702   Region &absent = getAbsentRegion();
1703   if (!absent.empty()) {
1704     if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1705                                   outputType)))
1706       return failure();
1707     // Absent branch can only yield invariant values.
1708     Block *absentBlock = &absent.front();
1709     Block *parent = getOperation()->getBlock();
1710     Value absentVal =
1711         cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1712     if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1713       if (arg.getOwner() == parent)
1714         return emitError("absent region cannot yield linalg argument");
1715     } else if (Operation *def = absentVal.getDefiningOp()) {
1716       if (!isa<arith::ConstantOp>(def) &&
1717           (def->getBlock() == absentBlock || def->getBlock() == parent))
1718         return emitError("absent region cannot yield locally computed value");
1719     }
1720   }
1721   return success();
1722 }
1723 
1724 bool ConcatenateOp::needsExtraSort() {
1725   SparseTensorType dstStt = getSparseTensorType(*this);
1726   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1727     return false;
1728 
1729   bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1730     return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1731   });
1732   // TODO: When conDim != 0, as long as conDim corresponds to the first level
1733   // in all input/output buffers and all input/output buffers have the same
1734   // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
1735   // CSC matrices along columns).
1736   bool directLowerable =
1737       allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1738   return !directLowerable;
1739 }
1740 
1741 LogicalResult ConcatenateOp::verify() {
1742   const auto dstTp = getSparseTensorType(*this);
1743   const Dimension concatDim = getDimension();
1744   const Dimension dimRank = dstTp.getDimRank();
1745 
1746   if (getInputs().size() <= 1)
1747     return emitError("Need at least two tensors to concatenate.");
1748 
1749   if (concatDim >= dimRank)
1750     return emitError(llvm::formatv(
1751         "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1752         concatDim, dimRank));
1753 
1754   for (const auto &it : llvm::enumerate(getInputs())) {
1755     const auto i = it.index();
1756     const auto srcTp = getSparseTensorType(it.value());
1757     if (srcTp.hasDynamicDimShape())
1758       return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1759     const Dimension srcDimRank = srcTp.getDimRank();
1760     if (srcDimRank != dimRank)
1761       return emitError(
1762           llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1763                         "from the output tensor (rank={2}).",
1764                         i, srcDimRank, dimRank));
1765   }
1766 
1767   for (Dimension d = 0; d < dimRank; d++) {
1768     const Size dstSh = dstTp.getDimShape()[d];
1769     if (d == concatDim) {
1770       if (!ShapedType::isDynamic(dstSh)) {
1771         // If we reach here, then all inputs have static shapes.  So we
1772         // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1773         // to avoid redundant assertions in the loop.
1774         Size sumSz = 0;
1775         for (const auto src : getInputs())
1776           sumSz += getSparseTensorType(src).getDimShape()[d];
1777         // If all dimensions are statically known, the sum of all the input
1778         // dimensions should be equal to the output dimension.
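        // E.g. (illustrative), concatenating tensor<3x4xf64> and tensor<5x4xf64>
        // along dimension 0 must produce tensor<8x4xf64>.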
1779         if (sumSz != dstSh)
1780           return emitError(
1781               "The concatenation dimension of the output tensor should be the "
1782               "sum of all the concatenation dimensions of the input tensors.");
1783       }
1784     } else {
1785       Size prev = dstSh;
1786       for (const auto src : getInputs()) {
1787         const auto sh = getSparseTensorType(src).getDimShape()[d];
1788         if (!ShapedType::isDynamic(prev) && sh != prev)
1789           return emitError("All dimensions (expect for the concatenating one) "
1790                            "should be equal.");
1791         prev = sh;
1792       }
1793     }
1794   }
1795 
1796   return success();
1797 }
1798 
1799 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1800                        Value curSize, Value inBuffer, Value value) {
1801   build(builder, result, curSize, inBuffer, value, Value());
1802 }
1803 
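// The optional `n` operand appends the value `n` times; when it is a constant
// it must therefore be at least 1. E.g. (illustrative syntax):
//   %buf, %newSize = sparse_tensor.push_back %size, %buffer, %val, %n
//       : index, memref<?xf64>, f64, index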
1804 LogicalResult PushBackOp::verify() {
1805   if (Value n = getN()) {
1806     std::optional<int64_t> nValue = getConstantIntValue(n);
1807     if (nValue && nValue.value() < 1)
1808       return emitOpError("n must be not less than 1");
1809   }
1810   return success();
1811 }
1812 
1813 LogicalResult CompressOp::verify() {
1814   const auto stt = getSparseTensorType(getTensor());
1815   if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1816     return emitOpError("incorrect number of coordinates");
1817   return success();
1818 }
1819 
1820 void ForeachOp::build(
1821     OpBuilder &builder, OperationState &result, Value tensor,
1822     ValueRange initArgs, AffineMapAttr order,
1823     function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1824         bodyBuilder) {
1825   build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1826   // Builds foreach body.
1827   if (!bodyBuilder)
1828     return;
1829   const auto stt = getSparseTensorType(tensor);
1830   const Dimension dimRank = stt.getDimRank();
1831 
1832   // Starts with `dimRank`-many coordinates.
1833   SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1834   // Followed by one value.
1835   blockArgTypes.push_back(stt.getElementType());
1836   // Followed by the reduction variables.
1837   blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1838 
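  // E.g. (illustrative), for a 2-d f64 tensor with a single f64 reduction the
  // body block signature becomes (index, index, f64, f64).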
1839   SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1840 
1841   OpBuilder::InsertionGuard guard(builder);
1842   auto &region = *result.regions.front();
1843   Block *bodyBlock =
1844       builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1845   bodyBuilder(builder, result.location,
1846               bodyBlock->getArguments().slice(0, dimRank),
1847               bodyBlock->getArguments()[dimRank],
1848               bodyBlock->getArguments().drop_front(dimRank + 1));
1849 }
1850 
1851 LogicalResult ForeachOp::verify() {
1852   const auto t = getSparseTensorType(getTensor());
1853   const Dimension dimRank = t.getDimRank();
1854   const auto args = getBody()->getArguments();
1855 
1856   if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1857     return emitError("Level traverse order does not match tensor's level rank");
1858 
1859   if (dimRank + 1 + getInitArgs().size() != args.size())
1860     return emitError("Unmatched number of arguments in the block");
1861 
1862   if (getNumResults() != getInitArgs().size())
1863     return emitError("Mismatch in number of init arguments and results");
1864 
1865   if (getResultTypes() != getInitArgs().getTypes())
1866     return emitError("Mismatch in types of init arguments and results");
1867 
1868   // Cannot mark this const, because the getters aren't.
1869   auto yield = cast<YieldOp>(getBody()->getTerminator());
1870   if (yield.getNumOperands() != getNumResults() ||
1871       yield.getOperands().getTypes() != getResultTypes())
1872     return emitError("Mismatch in types of yield values and results");
1873 
1874   const auto iTp = IndexType::get(getContext());
1875   for (Dimension d = 0; d < dimRank; d++)
1876     if (args[d].getType() != iTp)
1877       return emitError(
1878           llvm::formatv("Expecting Index type for argument at index {0}", d));
1879 
1880   const auto elemTp = t.getElementType();
1881   const auto valueTp = args[dimRank].getType();
1882   if (elemTp != valueTp)
1883     emitError(llvm::formatv("Unmatched element type between input tensor and "
1884                             "block argument, expected:{0}, got: {1}",
1885                             elemTp, valueTp));
1886   return success();
1887 }
1888 
1889 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1890   if (getSparseTensorEncoding(getInputCoo().getType()) ==
1891       getSparseTensorEncoding(getResultCoo().getType()))
1892     return getInputCoo();
1893 
1894   return {};
1895 }
1896 
1897 LogicalResult ReorderCOOOp::verify() {
1898   SparseTensorType srcStt = getSparseTensorType(getInputCoo());
1899   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
1900 
1901   if (!srcStt.isCOOType() || !dstStt.isCOOType())
1902     return emitError("Expected COO sparse tensors only");
1903 
1904   if (!srcStt.hasSameDimToLvl(dstStt))
1905     return emitError("Unmatched dim2lvl map between input and result COO");
1906 
1907   if (srcStt.getPosType() != dstStt.getPosType() ||
1908       srcStt.getCrdType() != dstStt.getCrdType() ||
1909       srcStt.getElementType() != dstStt.getElementType())
1910     return emitError("Unmatched storage format between input and result COO");
1911 
1912   return success();
1913 }
1914 
1915 LogicalResult ReduceOp::verify() {
1916   Type inputType = getX().getType();
1917   Region &formula = getRegion();
1918   return verifyNumBlockArgs(this, formula, "reduce",
1919                             TypeRange{inputType, inputType}, inputType);
1920 }
1921 
1922 LogicalResult SelectOp::verify() {
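// The select region consumes a single value of the input type and must yield
// an i1 that decides whether the element is kept, e.g. (illustrative):
//   ^bb0(%x: f64):
//     %c0 = arith.constant 0.0 : f64
//     %keep = arith.cmpf ogt, %x, %c0 : f64
//     sparse_tensor.yield %keep : i1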
1923   Builder b(getContext());
1924   Type inputType = getX().getType();
1925   Type boolType = b.getI1Type();
1926   Region &formula = getRegion();
1927   return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
1928                             boolType);
1929 }
1930 
1931 LogicalResult SortOp::verify() {
1932   AffineMap xPerm = getPermMap();
1933   uint64_t nx = xPerm.getNumDims();
1934   if (nx < 1)
1935     emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
1936 
1937   if (!xPerm.isPermutation())
1938     emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
1939 
1940   // We can't check the size of the buffers when n or buffer dimensions aren't
1941   // compile-time constants.
1942   std::optional<int64_t> cn = getConstantIntValue(getN());
1943   if (!cn)
1944     return success();
1945 
1946   // Verify dimensions.
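  // The xy buffer interleaves, per entry, the nx key components followed by the
  // ny extra components, so it must hold at least n * (nx + ny) elements, while
  // each y buffer must hold at least n elements (illustrative reading of the
  // checks below).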
1947   const auto checkDim = [&](Value v, Size minSize, const char *message) {
1948     const Size sh = getMemRefType(v).getShape()[0];
1949     if (!ShapedType::isDynamic(sh) && sh < minSize)
1950       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
1951   };
1952   uint64_t n = cn.value();
1953   uint64_t ny = 0;
1954   if (auto nyAttr = getNyAttr())
1955     ny = nyAttr.getInt();
1956   checkDim(getXy(), n * (nx + ny),
1957            "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
1958   for (Value opnd : getYs())
1959     checkDim(opnd, n, "Expected dimension(y) >= n");
1960 
1961   return success();
1962 }
1963 
1964 //===----------------------------------------------------------------------===//
1965 // Sparse Tensor Iteration Operations.
1966 //===----------------------------------------------------------------------===//
1967 
1968 IterSpaceType IteratorType::getIterSpaceType() const {
1969   return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
1970                             getHiLvl());
1971 }
1972 
1973 IteratorType IterSpaceType::getIteratorType() const {
1974   return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
1975 }
1976 
1977 /// Parses a level range in the form "$lo `to` $hi"
1978 /// or simply "$lo" if $hi - $lo = 1
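/// (e.g., "2 to 4" covers levels 2 and 3, while a plain "2" covers level 2).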
1979 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
1980                                    Level &lvlHi) {
1981   if (parser.parseInteger(lvlLo))
1982     return failure();
1983 
1984   if (succeeded(parser.parseOptionalKeyword("to"))) {
1985     if (parser.parseInteger(lvlHi))
1986       return failure();
1987   } else {
1988     lvlHi = lvlLo + 1;
1989   }
1990 
1991   if (lvlHi <= lvlLo)
1992     return parser.emitError(parser.getNameLoc(),
1993                             "expected level upper bound to be larger than lower bound");
1994 
1995   return success();
1996 }
1997 
1998 /// Parses a level range in the form "$lo `to` $hi"
1999 /// or simply "$lo" if $hi - $lo = 1
2000 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2001                                    IntegerAttr &lvlHiAttr) {
2002   Level lvlLo, lvlHi;
2003   if (parseLevelRange(parser, lvlLo, lvlHi))
2004     return failure();
2005 
2006   lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2007   lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2008   return success();
2009 }
2010 
2011 /// Prints a level range in the form "$lo `to` $hi"
2012 /// or simply "$lo" if $hi - $lo = 1
2013 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2014 
2015   if (lo + 1 == hi)
2016     p << lo;
2017   else
2018     p << lo << " to " << hi;
2019 }
2020 
2021 /// Prints a level range in the form "$lo `to` $hi"
2022 /// or simply "$lo" if $hi - $lo = 1
2023 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2024                             IntegerAttr lvlHi) {
2025   unsigned lo = lvlLo.getValue().getZExtValue();
2026   unsigned hi = lvlHi.getValue().getZExtValue();
2027   printLevelRange(p, lo, hi);
2028 }
2029 
2030 static ParseResult
2031 parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
2032                      SmallVectorImpl<OpAsmParser::Argument> &iterators,
2033                      SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
2034   SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2035   SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2036 
2037   // Parses "%iters, ... in %spaces, ..."
2038   if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
2039       parser.parseOperandList(spaces))
2040     return failure();
2041 
2042   if (iterators.size() != spaces.size())
2043     return parser.emitError(
2044         parser.getNameLoc(),
2045         "mismatch in number of sparse iterators and sparse spaces");
2046 
2047   // Parse "at(%crd0, _, ...)"
2048   LevelSet crdUsedLvlSet;
2049   bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
2050   unsigned lvlCrdCnt = 0;
2051   if (hasUsedCrds) {
2052     ParseResult crdList = parser.parseCommaSeparatedList(
2053         OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
2054           if (parser.parseOptionalKeyword("_")) {
2055             if (parser.parseArgument(iterArgs.emplace_back()))
2056               return failure();
2057             // Always use IndexType for the coordinate.
2058             crdUsedLvlSet.set(lvlCrdCnt);
2059             iterArgs.back().type = parser.getBuilder().getIndexType();
2060           }
2061           lvlCrdCnt += 1;
2062           return success();
2063         });
2064     if (failed(crdList)) {
2065       return parser.emitError(
2066           parser.getNameLoc(),
2067           "expecting SSA value or \"_\" for level coordinates");
2068     }
2069   }
2070   // Set the CrdUsedLvl bitset.
2071   state.addAttribute("crdUsedLvls",
2072                      parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2073 
2074   // Parse "iter_args(%arg = %init, ...)"
2075   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2076   if (hasIterArgs)
2077     if (parser.parseAssignmentList(iterArgs, initArgs))
2078       return failure();
2079 
2080   SmallVector<Type> iterSpaceTps;
2081   // parse ": sparse_tensor.iter_space -> ret"
2082   if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
2083     return failure();
2084   if (iterSpaceTps.size() != spaces.size())
2085     return parser.emitError(parser.getNameLoc(),
2086                             "mismatch in number of iteration space operands "
2087                             "and iteration space types");
2088 
2089   for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2090     IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2091     if (!spaceTp)
2092       return parser.emitError(parser.getNameLoc(),
2093                               "expected sparse_tensor.iter_space type for "
2094                               "iteration space operands");
2095     if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
2096       return parser.emitError(parser.getNameLoc(),
2097                               "mismatch in number of iteration space dimension "
2098                               "and specified coordinates");
2099     it.type = spaceTp.getIteratorType();
2100   }
2101 
2102   if (hasIterArgs)
2103     if (parser.parseArrowTypeList(state.types))
2104       return failure();
2105 
2106   // Resolves input operands.
2107   if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2108                              state.operands))
2109     return failure();
2110 
2111   if (hasIterArgs) {
2112     unsigned numCrds = crdUsedLvlSet.count();
2113     // Strip off the leading args that are used for coordinates.
2114     MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
2115     if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2116       return parser.emitError(
2117           parser.getNameLoc(),
2118           "mismatch in number of iteration arguments and return values");
2119     }
2120 
2121     for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2122       it.type = tp;
2123       if (parser.resolveOperand(init, tp, state.operands))
2124         return failure();
2125     }
2126   }
2127   return success();
2128 }
2129 
2130 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2131     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2132     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2133     SmallVectorImpl<mlir::Type> &ret) {
2134 
2135   ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2136   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2137   ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2138                                    adaptor.getHiLvl()));
2139   return success();
2140 }
2141 
2142 LogicalResult ExtractIterSpaceOp::verify() {
2143   if (getLoLvl() >= getHiLvl())
2144     return emitOpError("expected smaller level low than level high");
2145 
2146   TypedValue<IteratorType> pIter = getParentIter();
2147   if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2148     return emitOpError(
2149         "parent iterator should be specified iff level lower bound equals 0");
2150   }
2151 
2152   if (pIter) {
2153     IterSpaceType spaceTp = getResultSpace().getType();
2154     if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2155       return emitOpError(
2156           "mismatch in parent iterator encoding and iteration space encoding.");
2157 
2158     if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2159       return emitOpError("parent iterator should be used to extract an "
2160                          "iteration space from a consecutive level.");
2161   }
2162 
2163   return success();
2164 }
2165 
2166 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2167   OpAsmParser::Argument iterator;
2168   OpAsmParser::UnresolvedOperand iterSpace;
2169 
2170   SmallVector<OpAsmParser::Argument> iters, iterArgs;
2171   if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
2172     return failure();
2173   if (iters.size() != 1)
2174     return parser.emitError(parser.getNameLoc(),
2175                             "expected only one iterator/iteration space");
2176 
2177   iters.append(iterArgs);
2178   Region *body = result.addRegion();
2179   if (parser.parseRegion(*body, iters))
2180     return failure();
2181 
2182   IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2183 
2184   // Parse the optional attribute list.
2185   if (parser.parseOptionalAttrDict(result.attributes))
2186     return failure();
2187 
2188   return success();
2189 }
2190 
2191 /// Prints the initialization list in the form of
2192 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2193 /// where 'inner' values are assumed to be region arguments and 'outer' values
2194 /// are regular SSA values.
2195 static void printInitializationList(OpAsmPrinter &p,
2196                                     Block::BlockArgListType blocksArgs,
2197                                     ValueRange initializers,
2198                                     StringRef prefix = "") {
2199   assert(blocksArgs.size() == initializers.size() &&
2200          "expected same length of arguments and initializers");
2201   if (initializers.empty())
2202     return;
2203 
2204   p << prefix << '(';
2205   llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2206     p << std::get<0>(it) << " = " << std::get<1>(it);
2207   });
2208   p << ")";
2209 }
2210 
2211 static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
2212                               Block::BlockArgListType blocksArgs,
2213                               LevelSet crdUsedLvls) {
2214   if (crdUsedLvls.empty())
2215     return;
2216 
2217   p << " at(";
2218   for (unsigned i = 0; i < spaceDim; i++) {
2219     if (crdUsedLvls[i]) {
2220       p << blocksArgs.front();
2221       blocksArgs = blocksArgs.drop_front();
2222     } else {
2223       p << "_";
2224     }
2225     if (i != spaceDim - 1)
2226       p << ", ";
2227   }
2228   assert(blocksArgs.empty());
2229   p << ")";
2230 }
2231 
2232 void IterateOp::print(OpAsmPrinter &p) {
2233   p << " " << getIterator() << " in " << getIterSpace();
2234   printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
2235   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
2236 
2237   p << " : " << getIterSpace().getType() << " ";
2238   if (!getInitArgs().empty())
2239     p << "-> (" << getInitArgs().getTypes() << ") ";
2240 
2241   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2242                 /*printBlockTerminators=*/!getInitArgs().empty());
2243 }
2244 
2245 LogicalResult IterateOp::verify() {
2246   if (getInitArgs().size() != getNumResults()) {
2247     return emitOpError(
2248         "mismatch in number of loop-carried values and defined values");
2249   }
2250   return success();
2251 }
2252 
2253 LogicalResult IterateOp::verifyRegions() {
2254   if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2255     return emitOpError("mismatch in iterator and iteration space type");
2256   if (getNumRegionIterArgs() != getNumResults())
2257     return emitOpError(
2258         "mismatch in number of basic block args and defined values");
2259 
2260   auto initArgs = getInitArgs();
2261   auto iterArgs = getRegionIterArgs();
2262   auto yieldVals = getYieldedValues();
2263   auto opResults = getResults();
2264   if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2265                         opResults.size()})) {
2266     return emitOpError() << "number mismatch between iter args and results.";
2267   }
2268 
2269   for (auto [i, init, iter, yield, ret] :
2270        llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2271     if (init.getType() != ret.getType())
2272       return emitOpError() << "types mismatch between " << i
2273                            << "th iter operand and defined value";
2274     if (iter.getType() != ret.getType())
2275       return emitOpError() << "types mismatch between " << i
2276                            << "th iter region arg and defined value";
2277     if (yield.getType() != ret.getType())
2278       return emitOpError() << "types mismatch between " << i
2279                            << "th yield value and defined value";
2280   }
2281 
2282   return success();
2283 }
2284 
2285 /// IterateOp implementations of OpInterface methods.
2286 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2287 
2288 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2289   return getInitArgsMutable();
2290 }
2291 
2292 Block::BlockArgListType IterateOp::getRegionIterArgs() {
2293   return getRegion().getArguments().take_back(getNumRegionIterArgs());
2294 }
2295 
2296 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2297   return cast<sparse_tensor::YieldOp>(
2298              getRegion().getBlocks().front().getTerminator())
2299       .getResultsMutable();
2300 }
2301 
2302 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2303 
2304 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2305   return getInitArgs();
2306 }
2307 
2308 void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2309                                     SmallVectorImpl<RegionSuccessor> &regions) {
2310   // Both the operation itself and the region may be branching into the body or
2311   // back into the operation itself.
2312   regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2313   // It is possible for the loop not to enter the body.
2314   regions.push_back(RegionSuccessor(getResults()));
2315 }
2316 
2317 //===----------------------------------------------------------------------===//
2318 // Sparse Tensor Dialect Setups.
2319 //===----------------------------------------------------------------------===//
2320 
2321 /// Materialize a single constant operation from a given attribute value with
2322 /// the desired resultant type.
2323 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2324                                                     Attribute value, Type type,
2325                                                     Location loc) {
2326   if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2327     return op;
2328   return nullptr;
2329 }
2330 
2331 namespace {
2332 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
2333   using OpAsmDialectInterface::OpAsmDialectInterface;
2334 
2335   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
2336     if (isa<SparseTensorEncodingAttr>(attr)) {
2337       os << "sparse";
2338       return AliasResult::OverridableAlias;
2339     }
2340     return AliasResult::NoAlias;
2341   }
2342 };
2343 } // namespace
2344 
2345 void SparseTensorDialect::initialize() {
2346   addInterface<SparseTensorAsmDialectInterface>();
2347   addAttributes<
2348 #define GET_ATTRDEF_LIST
2349 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2350       >();
2351   addTypes<
2352 #define GET_TYPEDEF_LIST
2353 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2354       >();
2355   addOperations<
2356 #define GET_OP_LIST
2357 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2358       >();
2359   declarePromisedInterfaces<
2360       bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2361       NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2362       ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2363 }
2364 
2365 #define GET_OP_CLASSES
2366 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2367 
2368 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
2369