//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
// well.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

static SmallVector<Size>
getSparseFieldShape(const SparseTensorEncodingAttr enc,
                    std::optional<ArrayRef<int64_t>> dimShape) {
  assert(enc);
  // With only the encoding, we cannot determine the static shape for leading
  // batch levels; we therefore return a dynamic-shape memref instead.
  SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
  if (dimShape.has_value()) {
    // If the actual tensor shape is provided, we can then refine the leading
    // batch dimensions.
    SmallVector<int64_t> lvlShape =
        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
    memrefShape.assign(lvlShape.begin(),
                       lvlShape.begin() + enc.getBatchLvlRank());
  }
  // Append one more dynamic dimension to store the sparse level.
  memrefShape.push_back(ShapedType::kDynamic);
  return memrefShape;
}
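
// For illustration: given a batched encoding with lvlTypes =
// [batch, dense, compressed] and dimShape = {4, ?, 100} under an identity
// map, the result is {4, ?}: the batch size 4 is refined from the tensor
// shape, followed by one dynamic dimension for the sparse level storage.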

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr FieldIndex kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO: all singletons are fused into one memref. Skip the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO: each singleton level has its own memref.
        l++;
      }
      // Expire the handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non-COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
}
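
// For example, for a CSR encoding (lvlTypes = [dense, compressed]) the
// callback above is invoked with the following fields, in order:
//   fieldIdx 0: PosMemRef for level 1 (compressed)
//   fieldIdx 1: CrdMemRef for level 1 (compressed)
//   fieldIdx 2: ValMemRef (the values array)
//   fieldIdx 3: StorageSpec (the metadata)
// The dense level 0 contributes no memrefs of its own.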

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  assert(stt.hasEncoding());

  SmallVector<int64_t> memrefShape =
      getSparseFieldShape(stt.getEncoding(), stt.getDimShape());

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos>  positions
  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
  // memref<[batch] x ? x crd>  coordinates
  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
  // memref<[batch] x ? x eltType> values
  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());

  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
    switch (fieldKind) {
    case SparseTensorFieldKind::StorageSpec:
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::PosMemRef:
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::CrdMemRef:
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::ValMemRef:
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
    };
    llvm_unreachable("unrecognized field kind");
  });
}

unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}

std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = SparseTensorType(enc).getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Returns false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}
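
// For example, for an AoS COO encoding with lvlTypes =
// [compressed(nonunique), singleton, singleton], the three levels share a
// single coordinate memref, so a CrdMemRef query for level 1 or 2 is
// redirected to the memref at cooStart (level 0) with stride 3.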

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}

static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect non-negative value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, expect a '?' that represents a dynamic slice value.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}

Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}

LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}
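
// For example, "(1, 4, 2)" denotes a slice with static offset 1, size 4,
// and stride 2, while "(?, ?, ?)" denotes a completely dynamic slice.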

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
                                       AffineMap(), getPosWidth(),
                                       getCrdWidth());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getLvlToDim(), posWidth,
                                       crdWidth);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getLvlToDim(),
                                       getPosWidth(), getCrdWidth(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}

uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
  ArrayRef<LevelType> lvlTypes = getLvlTypes();
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  return std::distance(lastBatch, lvlTypes.rend());
}
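
// E.g., lvlTypes = [batch, batch, dense, compressed] has a batch level rank
// of 2, while an encoding without batch levels has a batch level rank of 0.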

bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}

Type SparseTensorEncodingAttr::getCrdElemType() const {
  if (!getImpl())
    return nullptr;
  if (getCrdWidth())
    return IntegerType::get(getContext(), getCrdWidth());
  return IndexType::get(getContext());
}

Type SparseTensorEncodingAttr::getPosElemType() const {
  if (!getImpl())
    return nullptr;
  if (getPosWidth())
    return IntegerType::get(getContext(), getPosWidth());
  return IndexType::get(getContext());
}

MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getCrdElemType());
}

MemRefType SparseTensorEncodingAttr::getPosMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getPosElemType());
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return LevelFormat::Batch;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}

SmallVector<int64_t>
SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
                                         CrdTransDirectionKind dir) const {
  if (isIdentity())
    return SmallVector<int64_t>(srcShape);

  SmallVector<int64_t> ret;
  unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
  ret.reserve(rank);

  if (isPermutation()) {
    for (unsigned r = 0; r < rank; r++) {
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
                                                             : toLvl(*this, r);
      ret.push_back(srcShape[trans]);
    }
    return ret;
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  SmallVector<AffineExpr> dimRep;
  dimRep.reserve(srcShape.size());
  for (int64_t sz : srcShape) {
    if (!ShapedType::isDynamic(sz)) {
      // Push back the max coordinate for the given dimension/level size.
      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
    } else {
      // A dynamic size; use an AffineDimExpr to symbolize the value.
      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
    }
  }

  for (AffineExpr exp : transMap.getResults()) {
    // Do constant propagation on the affine map.
    AffineExpr evalExp =
        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
    // Use the llvm namespace here to avoid ambiguity.
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
      ret.push_back(c.getValue() + 1);
    } else {
      if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
          mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions of the form
        // "d % constant", since d % constant lies in [0, constant).
        if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
          ret.push_back(bound.getValue());
          continue;
        }
      }
      ret.push_back(ShapedType::kDynamic);
    }
  }
  assert(ret.size() == rank);
  return ret;
}
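
// For illustration, applying the block map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// in the dim2lvl direction to srcShape = {?, 6} yields {?, 2, 2, 3}:
// "d0 floordiv 2" stays dynamic, "d1 floordiv 3" constant-folds (via the
// max coordinate 5) to size 2, "d0 mod 2" is bounded by its modulus, and
// "d1 mod 3" constant-folds to size 3.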

ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  if (!getImpl())
    return crds;

  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}

Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open the "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  StringRef attrName;
  SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect an admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after the key.
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on the keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    } // switch
    // Only the last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close the "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for the attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      dimSlices);
}
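
// A typical example accepted by the parser above (CSR with explicit
// bitwidths):
//
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32
//   }>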

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  auto map = static_cast<AffineMap>(getDimToLvl());
  // An empty affine map indicates the identity map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print the remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  printer << " }>";
}

void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {
  if (map.getNumSymbols() == 0)
    return;
  printer << '[';
  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
    printer << 's' << i << ", ";
  if (map.getNumSymbols() >= 1)
    printer << 's' << map.getNumSymbols() - 1;
  printer << ']';
}

void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    if (map.getNumDims() >= 1) {
      printer << 'd' << map.getNumDims() - 1 << " : "
              << dimSlices[map.getNumDims() - 1];
    }
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    if (map.getNumDims() >= 1)
      printer << 'd' << map.getNumDims() - 1;
  }
}

void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                           ArrayRef<LevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
    map.getResult(i).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
  }
  if (map.getNumResults() >= 1) {
    auto lastIndex = map.getNumResults() - 1;
    map.getResult(lastIndex).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
  }
}


LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
      it != std::end(lvlTypes)) {
    if (it == lvlTypes.begin() ||
        (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1))))
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";
    if (!std::all_of(it, lvlTypes.end(),
                     [](LevelType i) { return isSingletonLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // TODO: We can potentially support mixed SoA/AoS singleton levels.
    if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
  }

  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "batch lvlTypes can only be leading levels";

  // The SoA property can only be applied to singleton levels.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by backend.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it.  The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree.  (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity.  In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics.  In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  return success();
}

//===----------------------------------------------------------------------===//
// SparseTensorType Methods.
//===----------------------------------------------------------------------===//

bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
                                                      bool isUnique) const {
  if (!hasEncoding())
    return false;
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
    return false;
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique;
  // that is, when lvlRank == 1, the only compressed level is unique,
  // and when lvlRank > 1, the last singleton is unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);
}

Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const {
  SmallVector<COOSegment> coo = getCOOSegments();
  assert(coo.size() == 1 || coo.empty());
  if (!coo.empty() && coo.front().isAoS()) {
    return coo.front().lvlRange.first;
  }
  return lvlRank;
}

SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  if (!hasEncoding() || lvlRank <= 1)
    return ret;

  ArrayRef<LevelType> lts = getLvlTypes();
  Level l = 0;
  while (l < lvlRank) {
    auto lt = lts[l];
    if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
      auto cur = lts.begin() + l;
      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(cur, end);
      if (cooLen > 1) {
        // To support mixed SoA/AoS COO we would need to break the segment
        // when the storage scheme changes; for now we faithfully assume that
        // all consecutive singleton levels have the same storage format, as
        // enforced by the STEA verifier.
        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
      }
      l += cooLen;
    } else {
      l++;
    }
  }
  return ret;
}
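
// For example, lvlTypes = [dense, compressed(nonunique), singleton] yields a
// single AoS COO segment covering the level range [1, 3), whereas
// [compressed(nonunique), singleton(soa)] yields a single SoA segment
// covering [0, 2).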

RankedTensorType
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
  SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);
  // A non-unique compressed level at the beginning (unless this is
  // also the last level, in which case it is unique).
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ending with a unique singleton level.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }
  auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
                                           getDimToLvl(), getLvlToDim(),
                                           getPosWidth(), getCrdWidth());
  return RankedTensorType::get(getDimShape(), getElementType(), enc);
}
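
// E.g., for a rank-3 tensor this produces lvlTypes =
// [compressed(nonunique), singleton(nonunique), singleton]: one compressed
// level followed by singleton levels, where only the last level is unique.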

//===----------------------------------------------------------------------===//
// Convenience Methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}

AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                             MLIRContext *context) {
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}

AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores the floordiv and mod operations applied to the
  // same dimension, which are then used to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier.
        components.push_back(binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add the level variable for mod to the same vector
        // as the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        llvm_unreachable("expected floordiv or mod");
      }
    } else {
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. They are used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return AffineMap::get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}
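
// For example, inverting
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// pairs each floordiv with its matching mod to produce
//   (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3).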

SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      blockSize.push_back(0);
    }
  }
  return blockSize;
}
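
// E.g., for (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4) this returns {0, 4}:
// a 0 for the unblocked dimension d0 and a block size of 4 for d1.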

bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
  if (!dimToLvl)
    return false;
  std::map<unsigned, int64_t> coefficientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        if (coefficientMap.find(pos) != coefficientMap.end())
          return false;
        // Record the coefficient of the floordiv.
        coefficientMap[pos] = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect a floordiv before the mod.
        if (coefficientMap.find(pos) == coefficientMap.end())
          return false;
        // Expect the mod to have the same coefficient as the floordiv.
        if (conOp.getValue() != coefficientMap[pos])
          return false;
        hasBlock = true;
      } else {
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect the dim to be unset.
      if (coefficientMap.find(pos) != coefficientMap.end())
        return false;
      coefficientMap[pos] = 0;
    } else {
      return false;
    }
  }
  return hasBlock;
}
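
// For instance, the BSR map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// is accepted, while (d0, d1) -> (d0 mod 2, d0 floordiv 2, d1) is rejected
// because the mod appears before its matching floordiv.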

bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto dimToLvl = enc.getDimToLvl())
      return dimToLvl.getDimPosition(l);
  }
  return l;
}

Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto lvlToDim = enc.getLvlToDim())
      return lvlToDim.getDimPosition(d);
  }
  return d;
}

/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by
/// stripping irrelevant fields that do not alter the sparse tensor memory
/// layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<LevelType> lts;
  for (auto lt : enc.getLvlTypes())
    lts.push_back(lt.stripStorageIrrelevantProperties());

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
      // value for different bitwidths; it also avoids casting between index
      // and integer (as returned by DimOp).
      0, 0, enc.getDimSlices());
}

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verifies the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in the shape of <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult AssembleOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from
  // dense inputs by rotating dense loops, but it leads to bad cache locality
  // and hurts performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

  return true;
}

LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");

  return success();
}
1321 
1322 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1323                                    SmallVectorImpl<OpFoldResult> &results) {
1324   if (getEncoder().isIdentity()) {
1325     results.assign(getInCrds().begin(), getInCrds().end());
1326     return success();
1327   }
1328   if (getEncoder().isPermutation()) {
1329     AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1330                          ? getEncoder().getDimToLvl()
1331                          : getEncoder().getLvlToDim();
1332     for (AffineExpr exp : perm.getResults())
1333       results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1334     return success();
1335   }
1336 
1337   // Fuse dim2lvl/lvl2dim pairs.
1338   auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1339   bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1340                    return v.getDefiningOp() == def;
1341                  });
1342   if (!sameDef)
1343     return failure();
1344 
1345   bool oppositeDir = def.getDirection() != getDirection();
1346   bool sameOracle =
1347       def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1348   bool sameCount = def.getNumResults() == getInCrds().size();
1349   if (!oppositeDir || !sameOracle || !sameCount)
1350     return failure();
1351 
1352   // The definition produces the coordinates in the same order as the input
1353   // coordinates.
1354   bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1355                                 [](auto valuePair) {
1356                                   auto [lhs, rhs] = valuePair;
1357                                   return lhs == rhs;
1358                                 });
1359 
1360   if (!sameOrder)
1361     return failure();
1362   // l1 = dim2lvl (lvl2dim l0)
1363   // ==> l0
1364   results.append(def.getInCrds().begin(), def.getInCrds().end());
1365   return success();
1366 }
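
// Sketched IR for the pair fusion above (syntax abbreviated, unverified):
//
//   %d0, %d1 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1] as #enc
//              : index, index
//   %r0, %r1 = sparse_tensor.crd_translate dim_to_lvl [%d0, %d1] as #enc
//              : index, index
//
// folds %r0, %r1 to the original %l0, %l1.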
1367 
1368 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1369                   int64_t index) {
1370   Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
1371   return build(builder, state, source, val);
1372 }
1373 
1374 LogicalResult LvlOp::verify() {
1375   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1376     auto stt = getSparseTensorType(getSource());
1377     if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      return emitError(
          "Level index exceeds the rank of the input sparse tensor");
1379   }
1380   return success();
1381 }
1382 
1383 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1384   return getConstantIntValue(getIndex());
1385 }
1386 
1387 Speculation::Speculatability LvlOp::getSpeculatability() {
1388   auto constantIndex = getConstantLvlIndex();
1389   if (!constantIndex)
1390     return Speculation::NotSpeculatable;
1391 
  assert(static_cast<int64_t>(*constantIndex) <
         cast<RankedTensorType>(getSource().getType()).getRank());
1394   return Speculation::Speculatable;
1395 }
1396 
1397 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1398   auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1399   if (!lvlIndex)
1400     return {};
1401 
1402   Level lvl = lvlIndex.getAPSInt().getZExtValue();
1403   auto stt = getSparseTensorType(getSource());
1404   if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by the tensor.dim operation:
    // out-of-bounds indices produce undefined behavior but are still valid
    // IR. Don't choke on them.
1408     return {};
1409   }
1410 
  // Helper lambda to build an index-typed IntegerAttr.
1412   auto getIndexAttr = [this](int64_t lvlSz) {
1413     return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1414   };
1415 
1416   SmallVector<Size> lvlShape = stt.getLvlShape();
1417   if (!ShapedType::isDynamic(lvlShape[lvl]))
1418     return getIndexAttr(lvlShape[lvl]);
1419 
1420   return {};
1421 }
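
// For instance (sketched IR; #CSR hypothetical): with an identity dimToLvl
// and static level shape 16x32,
//
//   %c1 = arith.constant 1 : index
//   %sz = sparse_tensor.lvl %t, %c1 : tensor<16x32xf64, #CSR>
//
// folds to a constant 32 : index; a dynamic level size is left unfolded.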
1422 
1423 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1424                              SparseTensorEncodingAttr dstEnc, Value source) {
1425   auto srcStt = getSparseTensorType(source);
1426   SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1427   SmallVector<int64_t> dstDimShape =
1428       dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1429   auto dstTp =
1430       RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1431   return build(odsBuilder, odsState, dstTp, source);
1432 }
1433 
1434 LogicalResult ReinterpretMapOp::verify() {
1435   auto srcStt = getSparseTensorType(getSource());
1436   auto dstStt = getSparseTensorType(getDest());
1437   ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1438   ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1439 
1440   if (srcLvlTps.size() != dstLvlTps.size())
1441     return emitError("Level rank mismatch between source/dest tensors");
1442 
1443   for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1444     if (srcLvlTp != dstLvlTp)
1445       return emitError("Level type mismatch between source/dest tensors");
1446 
1447   if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1448       srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1449     return emitError("Crd/Pos width mismatch between source/dest tensors");
1450   }
1451 
1452   if (srcStt.getElementType() != dstStt.getElementType())
1453     return emitError("Element type mismatch between source/dest tensors");
1454 
1455   SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1456   SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1457   for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1458     if (srcLvlSz != dstLvlSz) {
      // Should we allow one side to have dynamic sizes, e.g., should <?x?> be
      // compatible with <3x4>? For now, we require all level sizes to match
      // *exactly* for simplicity.
1462       return emitError("Level size mismatch between source/dest tensors");
1463     }
1464   }
1465 
1466   return success();
1467 }
1468 
1469 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1470   if (getSource().getType() == getDest().getType())
1471     return getSource();
1472 
1473   if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1474     // A -> B, B -> A ==> A
1475     if (def.getSource().getType() == getDest().getType())
1476       return def.getSource();
1477   }
1478   return {};
1479 }
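
// Sketched round-trip for the A -> B, B -> A rule above (#A/#B are
// hypothetical encodings):
//
//   %1 = sparse_tensor.reinterpret_map %0 : tensor<..., #A> to tensor<..., #B>
//   %2 = sparse_tensor.reinterpret_map %1 : tensor<..., #B> to tensor<..., #A>
//
// so %2 folds directly to %0.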
1480 
1481 template <typename ToBufferOp>
1482 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1483                                            OpaqueProperties prop,
1484                                            RegionRange region,
1485                                            SmallVectorImpl<mlir::Type> &ret) {
1486   typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1487   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1488   Type elemTp = nullptr;
1489   bool withStride = false;
1490   if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1491     elemTp = stt.getPosType();
1492   } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1493                        std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1494     elemTp = stt.getCrdType();
1495     if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1496       withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1497   } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1498     elemTp = stt.getElementType();
1499   }
1500 
1501   assert(elemTp && "unhandled operation.");
1502   SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1503   bufShape.push_back(ShapedType::kDynamic);
1504 
  auto layout = withStride
                    ? StridedLayoutAttr::get(stt.getContext(),
                                             ShapedType::kDynamic,
                                             {ShapedType::kDynamic})
                    : StridedLayoutAttr();
1509   ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1510   return success();
1511 }
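
// As a concrete illustration of the inference above (editor's note): for a
// 2-D tensor with index-typed positions/coordinates and no batch levels,
// ToPositionsOp and ToCoordinatesOp infer memref<?xindex> and ToValuesOp
// infers memref<?xf64> for an f64 element type; only coordinate buffers at
// or past the AoS COO start receive the dynamic strided layout.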
1512 
1513 LogicalResult ToPositionsOp::verify() {
1514   auto stt = getSparseTensorType(getTensor());
1515   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1516     return emitError("requested level is out of bounds");
1517   if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1518     return emitError("unexpected type for positions");
1519   return success();
1520 }
1521 
1522 LogicalResult
1523 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1524                                 ValueRange ops, DictionaryAttr attr,
1525                                 OpaqueProperties prop, RegionRange region,
1526                                 SmallVectorImpl<mlir::Type> &ret) {
1527   return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1528 }
1529 
1530 LogicalResult ToCoordinatesOp::verify() {
1531   auto stt = getSparseTensorType(getTensor());
1532   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1533     return emitError("requested level is out of bounds");
1534   if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1535     return emitError("unexpected type for coordinates");
1536   return success();
1537 }
1538 
1539 LogicalResult
1540 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1541                                   ValueRange ops, DictionaryAttr attr,
1542                                   OpaqueProperties prop, RegionRange region,
1543                                   SmallVectorImpl<mlir::Type> &ret) {
1544   return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1545 }
1546 
1547 LogicalResult ToCoordinatesBufferOp::verify() {
1548   auto stt = getSparseTensorType(getTensor());
1549   if (stt.getAoSCOOStart() >= stt.getLvlRank())
1550     return emitError("expected sparse tensor with a COO region");
1551   return success();
1552 }
1553 
1554 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1555     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1556     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1557     SmallVectorImpl<mlir::Type> &ret) {
1558   return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1559                                                       ret);
1560 }
1561 
1562 LogicalResult ToValuesOp::verify() {
1563   auto stt = getSparseTensorType(getTensor());
1564   auto mtp = getMemRefType(getResult());
1565   if (stt.getElementType() != mtp.getElementType())
1566     return emitError("unexpected mismatch in element types");
1567   return success();
1568 }
1569 
1570 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1571                                            std::optional<Location> loc,
1572                                            ValueRange ops, DictionaryAttr attr,
1573                                            OpaqueProperties prop,
1574                                            RegionRange region,
1575                                            SmallVectorImpl<mlir::Type> &ret) {
1576   return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1577 }
1578 
1579 LogicalResult ToSliceOffsetOp::verify() {
1580   auto rank = getRankedTensorType(getSlice()).getRank();
1581   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bounds");
1583   return success();
1584 }
1585 
1586 LogicalResult ToSliceStrideOp::verify() {
1587   auto rank = getRankedTensorType(getSlice()).getRank();
1588   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bounds");
1590   return success();
1591 }
1592 
1593 LogicalResult GetStorageSpecifierOp::verify() {
1594   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1595                                       getSpecifier(), getOperation());
1596 }
1597 
1598 template <typename SpecifierOp>
1599 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1600   return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1601 }
1602 
1603 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1604   const StorageSpecifierKind kind = getSpecifierKind();
1605   const auto lvl = getLevel();
1606   for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1607     if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1608       return op.getValue();
1609   return {};
1610 }
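
// Sketched fold example (storage_specifier syntax abbreviated and types
// omitted; unverified):
//
//   %s1 = sparse_tensor.storage_specifier.set %s0 lvl_sz at 0 with %sz
//   %v  = sparse_tensor.storage_specifier.get %s1 lvl_sz at 0
//
// Walking the set-chain above folds %v directly to %sz.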
1611 
1612 LogicalResult SetStorageSpecifierOp::verify() {
1613   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1614                                       getSpecifier(), getOperation());
1615 }
1616 
1617 template <class T>
1618 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1619                                         const char *regionName,
1620                                         TypeRange inputTypes, Type outputType) {
1621   unsigned numArgs = region.getNumArguments();
1622   unsigned expectedNum = inputTypes.size();
1623   if (numArgs != expectedNum)
1624     return op->emitError() << regionName << " region must have exactly "
1625                            << expectedNum << " arguments";
1626 
1627   for (unsigned i = 0; i < numArgs; i++) {
1628     Type typ = region.getArgument(i).getType();
1629     if (typ != inputTypes[i])
1630       return op->emitError() << regionName << " region argument " << (i + 1)
1631                              << " type mismatch";
1632   }
1633   Operation *term = region.front().getTerminator();
1634   YieldOp yield = dyn_cast<YieldOp>(term);
1635   if (!yield)
1636     return op->emitError() << regionName
1637                            << " region must end with sparse_tensor.yield";
1638   if (!yield.hasSingleResult() ||
1639       yield.getSingleResult().getType() != outputType)
1640     return op->emitError() << regionName << " region yield type mismatch";
1641 
1642   return success();
1643 }
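
// For example (sketched from the dialect docs), a reduce region that passes
// this check has one block argument per input type and yields the output
// type:
//
//   %r = sparse_tensor.reduce %a, %b, %cf1 : f64 {
//     ^bb0(%x: f64, %y: f64):
//       %s = arith.mulf %x, %y : f64
//       sparse_tensor.yield %s : f64
//   }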
1644 
1645 LogicalResult BinaryOp::verify() {
1646   NamedAttrList attrs = (*this)->getAttrs();
1647   Type leftType = getX().getType();
1648   Type rightType = getY().getType();
1649   Type outputType = getOutput().getType();
1650   Region &overlap = getOverlapRegion();
1651   Region &left = getLeftRegion();
1652   Region &right = getRightRegion();
1653 
1654   // Check correct number of block arguments and return type for each
1655   // non-empty region.
1656   if (!overlap.empty()) {
1657     if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1658                                   TypeRange{leftType, rightType}, outputType)))
1659       return failure();
1660   }
1661   if (!left.empty()) {
1662     if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1663                                   outputType)))
1664       return failure();
1665   } else if (getLeftIdentity()) {
1666     if (leftType != outputType)
1667       return emitError("left=identity requires first argument to have the same "
1668                        "type as the output");
1669   }
1670   if (!right.empty()) {
1671     if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1672                                   outputType)))
1673       return failure();
1674   } else if (getRightIdentity()) {
1675     if (rightType != outputType)
1676       return emitError("right=identity requires second argument to have the "
1677                        "same type as the output");
1678   }
1679   return success();
1680 }
1681 
1682 LogicalResult UnaryOp::verify() {
1683   Type inputType = getX().getType();
1684   Type outputType = getOutput().getType();
1685 
1686   // Check correct number of block arguments and return type for each
1687   // non-empty region.
1688   Region &present = getPresentRegion();
1689   if (!present.empty()) {
1690     if (failed(verifyNumBlockArgs(this, present, "present",
1691                                   TypeRange{inputType}, outputType)))
1692       return failure();
1693   }
1694   Region &absent = getAbsentRegion();
1695   if (!absent.empty()) {
1696     if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1697                                   outputType)))
1698       return failure();
1699     // Absent branch can only yield invariant values.
1700     Block *absentBlock = &absent.front();
1701     Block *parent = getOperation()->getBlock();
1702     Value absentVal =
1703         cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1704     if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1705       if (arg.getOwner() == parent)
1706         return emitError("absent region cannot yield linalg argument");
1707     } else if (Operation *def = absentVal.getDefiningOp()) {
1708       if (!isa<arith::ConstantOp>(def) &&
1709           (def->getBlock() == absentBlock || def->getBlock() == parent))
1710         return emitError("absent region cannot yield locally computed value");
1711     }
1712   }
1713   return success();
1714 }
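
// Sketched example of the invariance rule above: an absent region may yield
// a constant defined anywhere,
//
//   absent = { sparse_tensor.yield %cst : f64 }   // ok if %cst is a constant
//
// but yielding a value computed inside the absent block, or an enclosing
// linalg block argument, is rejected.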
1715 
1716 bool ConcatenateOp::needsExtraSort() {
1717   SparseTensorType dstStt = getSparseTensorType(*this);
1718   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1719     return false;
1720 
1721   bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1722     return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1723   });
  // TODO: When conDim != 0, the tmp COO buffer is still unnecessary as long
  // as conDim corresponds to the first level in all input/output buffers and
  // all input/output buffers share the same dimToLvl (e.g., concatenating
  // CSC matrices along columns).
1728   bool directLowerable =
1729       allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1730   return !directLowerable;
1731 }
1732 
1733 LogicalResult ConcatenateOp::verify() {
1734   const auto dstTp = getSparseTensorType(*this);
1735   const Dimension concatDim = getDimension();
1736   const Dimension dimRank = dstTp.getDimRank();
1737 
1738   if (getInputs().size() <= 1)
1739     return emitError("Need at least two tensors to concatenate.");
1740 
1741   if (concatDim >= dimRank)
1742     return emitError(llvm::formatv(
1743         "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1744         concatDim, dimRank));
1745 
1746   for (const auto &it : llvm::enumerate(getInputs())) {
1747     const auto i = it.index();
1748     const auto srcTp = getSparseTensorType(it.value());
1749     if (srcTp.hasDynamicDimShape())
1750       return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1751     const Dimension srcDimRank = srcTp.getDimRank();
1752     if (srcDimRank != dimRank)
1753       return emitError(
1754           llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1755                         "from the output tensor (rank={2}).",
1756                         i, srcDimRank, dimRank));
1757   }
1758 
1759   for (Dimension d = 0; d < dimRank; d++) {
1760     const Size dstSh = dstTp.getDimShape()[d];
1761     if (d == concatDim) {
1762       if (!ShapedType::isDynamic(dstSh)) {
1763         // If we reach here, then all inputs have static shapes.  So we
1764         // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1765         // to avoid redundant assertions in the loop.
1766         Size sumSz = 0;
1767         for (const auto src : getInputs())
1768           sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimensions are statically known, the sum of all input
        // dimension sizes must equal the output dimension size.
1771         if (sumSz != dstSh)
1772           return emitError(
1773               "The concatenation dimension of the output tensor should be the "
1774               "sum of all the concatenation dimensions of the input tensors.");
1775       }
1776     } else {
1777       Size prev = dstSh;
1778       for (const auto src : getInputs()) {
1779         const auto sh = getSparseTensorType(src).getDimShape()[d];
1780         if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (except for the concatenating one) "
                           "should be equal.");
1783         prev = sh;
1784       }
1785     }
1786   }
1787 
1788   return success();
1789 }
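
// Sketched example of a well-formed concatenation (#CSR hypothetical):
//
//   %0 = sparse_tensor.concatenate %a, %b {dimension = 0 : index}
//        : tensor<3x4xf64, #CSR>, tensor<5x4xf64, #CSR>
//          to tensor<8x4xf64, #CSR>
//
// The concat dimension sums up (3 + 5 == 8) and all other dimensions match.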
1790 
1791 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1792                        Value curSize, Value inBuffer, Value value) {
1793   build(builder, result, curSize, inBuffer, value, Value());
1794 }
1795 
1796 LogicalResult PushBackOp::verify() {
1797   if (Value n = getN()) {
1798     std::optional<int64_t> nValue = getConstantIntValue(n);
1799     if (nValue && nValue.value() < 1)
      return emitOpError("n must be at least 1");
1801   }
1802   return success();
1803 }
1804 
1805 LogicalResult CompressOp::verify() {
1806   const auto stt = getSparseTensorType(getTensor());
1807   if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1808     return emitOpError("incorrect number of coordinates");
1809   return success();
1810 }
1811 
1812 void ForeachOp::build(
1813     OpBuilder &builder, OperationState &result, Value tensor,
1814     ValueRange initArgs, AffineMapAttr order,
1815     function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1816         bodyBuilder) {
1817   build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1818   // Builds foreach body.
1819   if (!bodyBuilder)
1820     return;
1821   const auto stt = getSparseTensorType(tensor);
1822   const Dimension dimRank = stt.getDimRank();
1823 
1824   // Starts with `dimRank`-many coordinates.
1825   SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1826   // Followed by one value.
1827   blockArgTypes.push_back(stt.getElementType());
1828   // Followed by the reduction variables.
1829   blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1830 
1831   SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1832 
1833   OpBuilder::InsertionGuard guard(builder);
1834   auto &region = *result.regions.front();
1835   Block *bodyBlock =
1836       builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1837   bodyBuilder(builder, result.location,
1838               bodyBlock->getArguments().slice(0, dimRank),
1839               bodyBlock->getArguments()[dimRank],
1840               bodyBlock->getArguments().drop_front(dimRank + 1));
1841 }
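
// Sketched builder usage (editor's addition; variable names hypothetical,
// overload resolution unverified):
//
//   builder.create<ForeachOp>(loc, tensor, ValueRange{init},
//       /*order=*/AffineMapAttr(),
//       [&](OpBuilder &b, Location l, ValueRange crds, Value val,
//           ValueRange reduc) {
//         // First dimRank coordinates, then the value, then reduction args.
//         b.create<sparse_tensor::YieldOp>(l, reduc);
//       });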
1842 
1843 LogicalResult ForeachOp::verify() {
1844   const auto t = getSparseTensorType(getTensor());
1845   const Dimension dimRank = t.getDimRank();
1846   const auto args = getBody()->getArguments();
1847 
1848   if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError(
        "Level traversal order does not match the tensor's level rank");
1850 
1851   if (dimRank + 1 + getInitArgs().size() != args.size())
1852     return emitError("Unmatched number of arguments in the block");
1853 
1854   if (getNumResults() != getInitArgs().size())
1855     return emitError("Mismatch in number of init arguments and results");
1856 
1857   if (getResultTypes() != getInitArgs().getTypes())
1858     return emitError("Mismatch in types of init arguments and results");
1859 
1860   // Cannot mark this const, because the getters aren't.
1861   auto yield = cast<YieldOp>(getBody()->getTerminator());
1862   if (yield.getNumOperands() != getNumResults() ||
1863       yield.getOperands().getTypes() != getResultTypes())
1864     return emitError("Mismatch in types of yield values and results");
1865 
1866   const auto iTp = IndexType::get(getContext());
1867   for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected: {0}, got: {1}",
                      elemTp, valueTp));
1878   return success();
1879 }
1880 
1881 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1882   if (getSparseTensorEncoding(getInputCoo().getType()) ==
1883       getSparseTensorEncoding(getResultCoo().getType()))
1884     return getInputCoo();
1885 
1886   return {};
1887 }
1888 
1889 LogicalResult ReorderCOOOp::verify() {
1890   SparseTensorType srcStt = getSparseTensorType(getInputCoo());
1891   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
1892 
  if (!srcStt.isCOOType() || !dstStt.isCOOType())
    return emitError("Expected COO sparse tensors only");

  if (!srcStt.hasSameDimToLvl(dstStt))
    return emitError("Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType())
    return emitError("Unmatched storage format between input and result COO");
1903 
1904   return success();
1905 }
1906 
1907 LogicalResult ReduceOp::verify() {
1908   Type inputType = getX().getType();
1909   Region &formula = getRegion();
1910   return verifyNumBlockArgs(this, formula, "reduce",
1911                             TypeRange{inputType, inputType}, inputType);
1912 }
1913 
1914 LogicalResult SelectOp::verify() {
1915   Builder b(getContext());
1916   Type inputType = getX().getType();
1917   Type boolType = b.getI1Type();
1918   Region &formula = getRegion();
1919   return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
1920                             boolType);
1921 }
1922 
1923 LogicalResult SortOp::verify() {
1924   AffineMap xPerm = getPermMap();
1925   uint64_t nx = xPerm.getNumDims();
  if (nx < 1)
    return emitError(
        llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));

  if (!xPerm.isPermutation())
    return emitError(
        llvm::formatv("Expected a permutation map, got {0}", xPerm));
1931 
1932   // We can't check the size of the buffers when n or buffer dimensions aren't
1933   // compile-time constants.
1934   std::optional<int64_t> cn = getConstantIntValue(getN());
1935   if (!cn)
1936     return success();
1937 
1938   // Verify dimensions.
  const auto checkDim = [&](Value v, Size minSize,
                            const char *message) -> LogicalResult {
    const Size sh = getMemRefType(v).getShape()[0];
    if (!ShapedType::isDynamic(sh) && sh < minSize)
      return emitError(
          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
    return success();
  };
  uint64_t n = cn.value();
  uint64_t ny = 0;
  if (auto nyAttr = getNyAttr())
    ny = nyAttr.getInt();
  if (failed(checkDim(getXy(), n * (nx + ny),
                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
    return failure();
  for (Value opnd : getYs())
    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
      return failure();
1952 
1953   return success();
1954 }
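
// E.g. (worked instance of the checks above): with n == 3, a 2-D permutation
// map (nx == 2), and ny == 1, the xy buffer must hold at least
// 3 * (2 + 1) == 9 elements, and each y buffer at least 3.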
1955 
1956 /// Materialize a single constant operation from a given attribute value with
1957 /// the desired resultant type.
1958 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
1959                                                     Attribute value, Type type,
1960                                                     Location loc) {
1961   if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
1962     return op;
1963   return nullptr;
1964 }
1965 
1966 namespace {
1967 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
1968   using OpAsmDialectInterface::OpAsmDialectInterface;
1969 
1970   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
    if (isa<SparseTensorEncodingAttr>(attr)) {
1972       os << "sparse";
1973       return AliasResult::OverridableAlias;
1974     }
1975     return AliasResult::NoAlias;
1976   }
1977 };
1978 } // namespace
1979 
1980 void SparseTensorDialect::initialize() {
1981   addInterface<SparseTensorAsmDialectInterface>();
1982   addAttributes<
1983 #define GET_ATTRDEF_LIST
1984 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
1985       >();
1986   addTypes<
1987 #define GET_TYPEDEF_LIST
1988 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
1989       >();
1990   addOperations<
1991 #define GET_OP_LIST
1992 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
1993       >();
1994   declarePromisedInterfaces<
1995       bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
1996       NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
1997       ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
1998 }
1999 
2000 #define GET_OP_CLASSES
2001 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2002 
2003 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
2004