//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
// well.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr FieldIndex kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO: all singletons are fused into one memref. Skip the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO, each singleton level has one memref.
        l++;
      }
      // Expire handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non-COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
}
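
// For illustration (added commentary, not part of the original code): with a
// standard CSR encoding, i.e. map = (d0, d1) -> (d0 : dense, d1 : compressed),
// the iteration above visits the fields in this order:
//   field 0: PosMemRef for level 1 (row positions)
//   field 1: CrdMemRef for level 1 (column coordinates)
//   field 2: ValMemRef             (values)
//   field 3: StorageSpec           (metadata)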

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  assert(stt.hasEncoding());
  // Construct the basic types.
  const Type crdType = stt.getCrdType();
  const Type posType = stt.getPosType();
  const Type eltType = stt.getElementType();

  SmallVector<int64_t> memrefShape = stt.getBatchLvlShape();
  memrefShape.push_back(ShapedType::kDynamic);

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos>  positions
  const Type posMemType = MemRefType::get(memrefShape, posType);
  // memref<[batch] x ? x crd>  coordinates
  const Type crdMemType = MemRefType::get(memrefShape, crdType);
  // memref<[batch] x ? x eltType> values
  const Type valMemType = MemRefType::get(memrefShape, eltType);

  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
    switch (fieldKind) {
    case SparseTensorFieldKind::StorageSpec:
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::PosMemRef:
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::CrdMemRef:
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::ValMemRef:
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
    }
    llvm_unreachable("unrecognized field kind");
  });
}

unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}

std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = SparseTensorType(enc).getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Returns false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}
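
// Illustrative note (added commentary): for an AoS COO encoding such as
//   map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton),
// all coordinates live in a single memref, so querying the CrdMemRef for
// level 1 yields the field index of the shared coordinates buffer together
// with stride 2 (lvlRank - cooStart).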

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}
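
// Example (illustrative): a slice with offset 1, size 4, stride 2 prints as
// "(1, 4, 2)", and a fully dynamic slice prints as "(?, ?, ?)".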

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}

static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, a '?' which represents a dynamic slice.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}

Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}

LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
                                       AffineMap(), getPosWidth(),
                                       getCrdWidth());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getLvlToDim(), posWidth,
                                       crdWidth);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getLvlToDim(),
                                       getPosWidth(), getCrdWidth(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}

uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
  ArrayRef<LevelType> lvlTypes = getLvlTypes();
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  return std::distance(lastBatch, lvlTypes.rend());
}

bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return LevelFormat::Batch;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}

SmallVector<int64_t>
SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
                                         CrdTransDirectionKind dir) const {
  if (isIdentity())
    return SmallVector<int64_t>(srcShape);

  SmallVector<int64_t> ret;
  unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
  ret.reserve(rank);

  if (isPermutation()) {
    for (unsigned r = 0; r < rank; r++) {
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
                                                             : toLvl(*this, r);
      ret.push_back(srcShape[trans]);
    }
    return ret;
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  SmallVector<AffineExpr> dimRep;
  dimRep.reserve(srcShape.size());
  for (int64_t sz : srcShape) {
    if (!ShapedType::isDynamic(sz)) {
      // Push back the max coordinate for the given dimension/level size.
      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
    } else {
      // A dynamic size, use an AffineDimExpr to symbolize the value.
      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
    }
  }

  for (AffineExpr exp : transMap.getResults()) {
    // Do constant propagation on the affine map.
    AffineExpr evalExp =
        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
    // Use the llvm namespace here to avoid ambiguity.
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
      ret.push_back(c.getValue() + 1);
    } else {
      if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
          mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions in the form
        // "d % constant" since d % constant lies in [0, constant).
        if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
          ret.push_back(bound.getValue());
          continue;
        }
      }
      ret.push_back(ShapedType::kDynamic);
    }
  }
  assert(ret.size() == rank);
  return ret;
}
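
// Worked example (added for illustration): for the 2x3 block-sparse map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3),
// translating the dimension shape [6, 12] in the dim2lvl direction
// substitutes the max coordinates (5 and 11) and yields the level shape
// [3, 4, 2, 3], since e.g. 5 floordiv 2 = 2 implies a level size of 3.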

ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  if (!getImpl())
    return crds;

  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}

Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  StringRef attrName;
  SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after the key.
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    } // switch
    // Only the last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      dimSlices);
}
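
// For reference (illustrative, mirroring the parser above), a typical
// attribute accepted here looks like:
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32
//   }>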

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  auto map = static_cast<AffineMap>(getDimToLvl());
  // Empty affine map indicates identity map
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  printer << " }>";
}

void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {
  if (map.getNumSymbols() == 0)
    return;
  printer << '[';
  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
    printer << 's' << i << ", ";
  if (map.getNumSymbols() >= 1)
    printer << 's' << map.getNumSymbols() - 1;
  printer << ']';
}

void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    if (map.getNumDims() >= 1) {
      printer << 'd' << map.getNumDims() - 1 << " : "
              << dimSlices[map.getNumDims() - 1];
    }
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    if (map.getNumDims() >= 1)
      printer << 'd' << map.getNumDims() - 1;
  }
}

void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                           ArrayRef<LevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
    map.getResult(i).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
  }
  if (map.getNumResults() >= 1) {
    auto lastIndex = map.getNumResults() - 1;
    map.getResult(lastIndex).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
  }
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
      it != std::end(lvlTypes)) {
    if (it == lvlTypes.begin() ||
        (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1))))
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";
    if (!std::all_of(it, lvlTypes.end(),
                     [](LevelType i) { return isSingletonLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // We can potentially support mixed SoA/AoS singleton levels.
    if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
  }

  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // SoA property can only be applied on singleton level.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by backend.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it.  The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree.  (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity.  In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics.  In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  return success();
}

//===----------------------------------------------------------------------===//
// SparseTensorType Methods.
//===----------------------------------------------------------------------===//

bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
                                                      bool isUnique) const {
  if (!hasEncoding())
    return false;
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
    return false;
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique,
  // that is, when lvlRank == 1, the only compressed level is unique,
  // and when lvlRank > 1, the last singleton is unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);
}

Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const {
  SmallVector<COOSegment> coo = getCOOSegments();
  assert(coo.size() == 1 || coo.empty());
  if (!coo.empty() && coo.front().isAoS()) {
    return coo.front().lvlRange.first;
  }
  return lvlRank;
}

SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  if (!hasEncoding() || lvlRank <= 1)
    return ret;

  ArrayRef<LevelType> lts = getLvlTypes();
  Level l = 0;
  while (l < lvlRank) {
    auto lt = lts[l];
    if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
      auto cur = lts.begin() + l;
      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(cur, end);
      if (cooLen > 1) {
        // To support mixed SoA/AoS COO, we should break the segment when the
        // storage scheme changes; for now we faithfully assume that all
        // consecutive singleton levels have the same storage format, as
        // verified by the STEA verifier.
        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
      }
      l += cooLen;
    } else {
      l++;
    }
  }
  return ret;
}
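
// Illustrative example (added commentary): for level types
//   [dense, compressed(nonunique), singleton(nonunique), singleton]
// this returns a single segment covering the level range [1, 4), flagged
// AoS since the singleton levels do not carry the SoA property.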

RankedTensorType
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
  SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);
  // A non-unique compressed level at the beginning (unless this is
  // also the last level, in which case it is unique).
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends with a unique singleton level.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }
  auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
                                           getDimToLvl(), getLvlToDim(),
                                           getPosWidth(), getCrdWidth());
  return RankedTensorType::get(getDimShape(), getElementType(), enc);
}
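
// For illustration: with lvlRank == 3 and ordered == true, the level types
// produced above are
//   [compressed(nonunique), singleton(nonunique), singleton],
// i.e. a classic ordered AoS COO format.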

//===----------------------------------------------------------------------===//
// Convenience Methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}

AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                             MLIRContext *context) {
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}

AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores information of the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier.
        components.push_back(binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add level variable for mod to the same vector
        // of the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. It could be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return AffineMap::get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}
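
// Worked example (added for illustration): for the BSR-style map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// the floordiv/mod components pair up per dimension, producing
//   (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
// as the inverse lvlToDim map.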

SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      blockSize.push_back(0);
    }
  }
  return blockSize;
}
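
// For illustration: on the map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// this returns [2, 3]; an unblocked dimension contributes a 0 entry.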

bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
  if (!dimToLvl)
    return false;
  std::map<unsigned, int64_t> coefficientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        if (coefficientMap.find(pos) != coefficientMap.end())
          return false;
        // Record coefficient of the floordiv.
        coefficientMap[pos] = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect floordiv before mod.
        if (coefficientMap.find(pos) == coefficientMap.end())
          return false;
        // Expect mod to have the same coefficient as floordiv.
        if (conOp.getValue() != coefficientMap[pos])
          return false;
        hasBlock = true;
      } else {
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect dim to be unset.
      if (coefficientMap.find(pos) != coefficientMap.end())
        return false;
      coefficientMap[pos] = 0;
    } else {
      return false;
    }
  }
  return hasBlock;
}
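
// Illustrative behavior (added commentary): isBlockSparsity returns true for
// the BSR-style map above, but false for a plain permutation such as
// (d0, d1) -> (d1, d0), since no floordiv/mod pair is present.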

bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto dimToLvl = enc.getDimToLvl())
      return dimToLvl.getDimPosition(l);
  }
  return l;
}

Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto lvlToDim = enc.getLvlToDim())
      return lvlToDim.getDimPosition(d);
  }
  return d;
}

/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by
/// stripping irrelevant fields that do not alter the sparse tensor memory
/// layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<LevelType> lts;
  for (auto lt : enc.getLvlTypes())
    lts.push_back(lt.stripStorageIrrelevantProperties());

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
      // value for different bitwidths, and it avoids casting between index and
      // integer (as returned by DimOp).
      0, 0, enc.getDimSlices());
}

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verifies the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in the shape of <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult AssembleOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from dense
  // inputs by rotating dense loops, but doing so leads to bad cache locality
  // and hurts performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

  return true;
}

LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");

  return success();
}

LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
                   return v.getDefiningOp() == def;
                 });
  if (!sameDef)
    return failure();

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}

void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {
  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
  return build(builder, state, source, val);
}

LogicalResult LvlOp::verify() {
  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      return emitError(
          "Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}
1337 
1338 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1339   return getConstantIntValue(getIndex());
1340 }
1341 
1342 Speculation::Speculatability LvlOp::getSpeculatability() {
1343   auto constantIndex = getConstantLvlIndex();
1344   if (!constantIndex)
1345     return Speculation::NotSpeculatable;
1346 
1347   assert(constantIndex <
1348          cast<RankedTensorType>(getSource().getType()).getRank());
1349   return Speculation::Speculatable;
1350 }
1351 
1352 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1353   auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1354   if (!lvlIndex)
1355     return {};
1356 
1357   Level lvl = lvlIndex.getAPSInt().getZExtValue();
1358   auto stt = getSparseTensorType(getSource());
1359   if (lvl >= stt.getLvlRank()) {
1360     // Follows the same convention used by tensor.dim operation. Out of bound
1361     // indices produce undefined behavior but are still valid IR. Don't choke on
1362     // them.
1363     return {};
1364   }
1365 
1366   // Helper lambda to build an IndexAttr.
1367   auto getIndexAttr = [this](int64_t lvlSz) {
1368     return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1369   };
1370 
1371   SmallVector<Size> lvlShape = stt.getLvlShape();
1372   if (!ShapedType::isDynamic(lvlShape[lvl]))
1373     return getIndexAttr(lvlShape[lvl]);
1374 
1375   return {};
1376 }
1377 
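// Infers the result type for the builder: the destination dimension shape is
// obtained by translating the source level shape through the lvl2dim
// direction of the new encoding. Schematic example (illustrative only): a
// source with level shape [3, 4] reinterpreted under a transposing dstEnc
// yields a dimension shape of [4, 3].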
1378 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1379                              SparseTensorEncodingAttr dstEnc, Value source) {
1380   auto srcStt = getSparseTensorType(source);
1381   SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1382   SmallVector<int64_t> dstDimShape =
1383       dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1384   auto dstTp =
1385       RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1386   return build(odsBuilder, odsState, dstTp, source);
1387 }
1388 
1389 LogicalResult ReinterpretMapOp::verify() {
1390   auto srcStt = getSparseTensorType(getSource());
1391   auto dstStt = getSparseTensorType(getDest());
1392   ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1393   ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1394 
1395   if (srcLvlTps.size() != dstLvlTps.size())
1396     return emitError("Level rank mismatch between source/dest tensors");
1397 
1398   for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1399     if (srcLvlTp != dstLvlTp)
1400       return emitError("Level type mismatch between source/dest tensors");
1401 
1402   if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1403       srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1404     return emitError("Crd/Pos width mismatch between source/dest tensors");
1405   }
1406 
1407   if (srcStt.getElementType() != dstStt.getElementType())
1408     return emitError("Element type mismatch between source/dest tensors");
1409 
1410   SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1411   SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1412   for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1413     if (srcLvlSz != dstLvlSz) {
      // Should we allow one side to be dynamically sized, e.g., should <?x?>
      // be compatible with <3x4>? For now, we require all level sizes to
      // match *exactly* for simplicity.
1417       return emitError("Level size mismatch between source/dest tensors");
1418     }
1419   }
1420 
1421   return success();
1422 }
1423 
1424 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1425   if (getSource().getType() == getDest().getType())
1426     return getSource();
1427 
1428   if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1429     // A -> B, B -> A ==> A
1430     if (def.getSource().getType() == getDest().getType())
1431       return def.getSource();
1432   }
1433   return {};
1434 }
1435 
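// Shared return-type inference for the to_positions/to_coordinates/
// to_coordinates_buffer/to_values operations: the result is a memref over the
// batch level shape plus one dynamic trailing dimension, typed after the
// requested storage (positions, coordinates, or values). For example
// (illustrative), querying the positions of a tensor with posWidth = 64
// yields a memref<?xi64>; AoS COO coordinates additionally carry a strided
// layout, since they are interleaved views into one buffer.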
1436 template <typename ToBufferOp>
1437 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1438                                            OpaqueProperties prop,
1439                                            RegionRange region,
1440                                            SmallVectorImpl<mlir::Type> &ret) {
1441   typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1442   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1443   Type elemTp = nullptr;
1444   bool withStride = false;
1445   if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1446     elemTp = stt.getPosType();
1447   } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1448                        std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1449     elemTp = stt.getCrdType();
1450     if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1451       withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1452   } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1453     elemTp = stt.getElementType();
1454   }
1455 
1456   assert(elemTp && "unhandled operation.");
1457   SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1458   bufShape.push_back(ShapedType::kDynamic);
1459 
  auto layout = withStride
                    ? StridedLayoutAttr::get(stt.getContext(),
                                             ShapedType::kDynamic,
                                             {ShapedType::kDynamic})
                    : StridedLayoutAttr();
1464   ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1465   return success();
1466 }
1467 
1468 LogicalResult ToPositionsOp::verify() {
1469   auto stt = getSparseTensorType(getTensor());
1470   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1471     return emitError("requested level is out of bounds");
1472   if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1473     return emitError("unexpected type for positions");
1474   return success();
1475 }
1476 
1477 LogicalResult
1478 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1479                                 ValueRange ops, DictionaryAttr attr,
1480                                 OpaqueProperties prop, RegionRange region,
1481                                 SmallVectorImpl<mlir::Type> &ret) {
1482   return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1483 }
1484 
1485 LogicalResult ToCoordinatesOp::verify() {
1486   auto stt = getSparseTensorType(getTensor());
1487   if (failed(lvlIsInBounds(getLevel(), getTensor())))
1488     return emitError("requested level is out of bounds");
1489   if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1490     return emitError("unexpected type for coordinates");
1491   return success();
1492 }
1493 
1494 LogicalResult
1495 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1496                                   ValueRange ops, DictionaryAttr attr,
1497                                   OpaqueProperties prop, RegionRange region,
1498                                   SmallVectorImpl<mlir::Type> &ret) {
1499   return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1500 }
1501 
1502 LogicalResult ToCoordinatesBufferOp::verify() {
1503   auto stt = getSparseTensorType(getTensor());
1504   if (stt.getAoSCOOStart() >= stt.getLvlRank())
1505     return emitError("expected sparse tensor with a COO region");
1506   return success();
1507 }
1508 
1509 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1510     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1511     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1512     SmallVectorImpl<mlir::Type> &ret) {
1513   return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1514                                                       ret);
1515 }
1516 
1517 LogicalResult ToValuesOp::verify() {
1518   auto stt = getSparseTensorType(getTensor());
1519   auto mtp = getMemRefType(getResult());
1520   if (stt.getElementType() != mtp.getElementType())
1521     return emitError("unexpected mismatch in element types");
1522   return success();
1523 }
1524 
1525 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1526                                            std::optional<Location> loc,
1527                                            ValueRange ops, DictionaryAttr attr,
1528                                            OpaqueProperties prop,
1529                                            RegionRange region,
1530                                            SmallVectorImpl<mlir::Type> &ret) {
1531   return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1532 }
1533 
1534 LogicalResult ToSliceOffsetOp::verify() {
1535   auto rank = getRankedTensorType(getSlice()).getRank();
1536   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bounds");
1538   return success();
1539 }
1540 
1541 LogicalResult ToSliceStrideOp::verify() {
1542   auto rank = getRankedTensorType(getSlice()).getRank();
1543   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bounds");
1545   return success();
1546 }
1547 
1548 LogicalResult GetStorageSpecifierOp::verify() {
1549   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1550                                       getSpecifier(), getOperation());
1551 }
1552 
1553 template <typename SpecifierOp>
1554 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1555   return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1556 }
1557 
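// Folds a get_storage_specifier by walking the chain of set_storage_specifier
// producers; the most recent set of the same (kind, level) pair supplies the
// value. Schematically:
//   get(set(%s, kind, lvl, %v), kind, lvl) ==> %v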
1558 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1559   const StorageSpecifierKind kind = getSpecifierKind();
1560   const auto lvl = getLevel();
1561   for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1562     if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1563       return op.getValue();
1564   return {};
1565 }
1566 
1567 LogicalResult SetStorageSpecifierOp::verify() {
1568   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1569                                       getSpecifier(), getOperation());
1570 }
1571 
1572 template <class T>
1573 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1574                                         const char *regionName,
1575                                         TypeRange inputTypes, Type outputType) {
1576   unsigned numArgs = region.getNumArguments();
1577   unsigned expectedNum = inputTypes.size();
1578   if (numArgs != expectedNum)
1579     return op->emitError() << regionName << " region must have exactly "
1580                            << expectedNum << " arguments";
1581 
1582   for (unsigned i = 0; i < numArgs; i++) {
1583     Type typ = region.getArgument(i).getType();
1584     if (typ != inputTypes[i])
1585       return op->emitError() << regionName << " region argument " << (i + 1)
1586                              << " type mismatch";
1587   }
1588   Operation *term = region.front().getTerminator();
1589   YieldOp yield = dyn_cast<YieldOp>(term);
1590   if (!yield)
1591     return op->emitError() << regionName
1592                            << " region must end with sparse_tensor.yield";
1593   if (!yield.getResult() || yield.getResult().getType() != outputType)
1594     return op->emitError() << regionName << " region yield type mismatch";
1595 
1596   return success();
1597 }
1598 
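// All three regions of a binary operation are optional; each non-empty region
// must take block arguments matching the corresponding operand types and
// yield a value of the output type. A schematic example (syntax abbreviated
// for illustration):
//   %0 = sparse_tensor.binary %x, %y : f64, f64 to f64
//          overlap = { ^bb0(%a: f64, %b: f64):
//                        %s = arith.addf %a, %b : f64
//                        sparse_tensor.yield %s : f64 }
//          left = identity
//          right = {}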
1599 LogicalResult BinaryOp::verify() {
1600   NamedAttrList attrs = (*this)->getAttrs();
1601   Type leftType = getX().getType();
1602   Type rightType = getY().getType();
1603   Type outputType = getOutput().getType();
1604   Region &overlap = getOverlapRegion();
1605   Region &left = getLeftRegion();
1606   Region &right = getRightRegion();
1607 
1608   // Check correct number of block arguments and return type for each
1609   // non-empty region.
1610   if (!overlap.empty()) {
1611     if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1612                                   TypeRange{leftType, rightType}, outputType)))
1613       return failure();
1614   }
1615   if (!left.empty()) {
1616     if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1617                                   outputType)))
1618       return failure();
1619   } else if (getLeftIdentity()) {
1620     if (leftType != outputType)
1621       return emitError("left=identity requires first argument to have the same "
1622                        "type as the output");
1623   }
1624   if (!right.empty()) {
1625     if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1626                                   outputType)))
1627       return failure();
1628   } else if (getRightIdentity()) {
1629     if (rightType != outputType)
1630       return emitError("right=identity requires second argument to have the "
1631                        "same type as the output");
1632   }
1633   return success();
1634 }
1635 
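// Besides the block-argument and yield checks shared with BinaryOp, the
// absent region may only yield a value that is invariant to the iteration,
// e.g. an arith.constant; values computed inside the region or taken from the
// enclosing block are rejected by the checks below.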
1636 LogicalResult UnaryOp::verify() {
1637   Type inputType = getX().getType();
1638   Type outputType = getOutput().getType();
1639 
1640   // Check correct number of block arguments and return type for each
1641   // non-empty region.
1642   Region &present = getPresentRegion();
1643   if (!present.empty()) {
1644     if (failed(verifyNumBlockArgs(this, present, "present",
1645                                   TypeRange{inputType}, outputType)))
1646       return failure();
1647   }
1648   Region &absent = getAbsentRegion();
1649   if (!absent.empty()) {
1650     if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1651                                   outputType)))
1652       return failure();
1653     // Absent branch can only yield invariant values.
1654     Block *absentBlock = &absent.front();
1655     Block *parent = getOperation()->getBlock();
1656     Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
1657     if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1658       if (arg.getOwner() == parent)
1659         return emitError("absent region cannot yield linalg argument");
1660     } else if (Operation *def = absentVal.getDefiningOp()) {
1661       if (!isa<arith::ConstantOp>(def) &&
1662           (def->getBlock() == absentBlock || def->getBlock() == parent))
1663         return emitError("absent region cannot yield locally computed value");
1664     }
1665   }
1666   return success();
1667 }
1668 
1669 bool ConcatenateOp::needsExtraSort() {
1670   SparseTensorType dstStt = getSparseTensorType(*this);
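  // An all-dense or unordered output accepts its elements in any order, so a
  // temporary COO buffer (and hence an extra sort) is never required.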
1671   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1672     return false;
1673 
1674   bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1675     return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1676   });
  // TODO: When conDim != 0, the tmp COO buffer is still unnecessary as long
  // as conDim corresponds to the first level in all input/output buffers and
  // all input/output buffers share the same dimToLvl (e.g., concatenating
  // CSC matrices along the column).
1681   bool directLowerable =
1682       allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1683   return !directLowerable;
1684 }
1685 
1686 LogicalResult ConcatenateOp::verify() {
1687   const auto dstTp = getSparseTensorType(*this);
1688   const Dimension concatDim = getDimension();
1689   const Dimension dimRank = dstTp.getDimRank();
1690 
1691   if (getInputs().size() <= 1)
1692     return emitError("Need at least two tensors to concatenate.");
1693 
1694   if (concatDim >= dimRank)
1695     return emitError(llvm::formatv(
1696         "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1697         concatDim, dimRank));
1698 
1699   for (const auto &it : llvm::enumerate(getInputs())) {
1700     const auto i = it.index();
1701     const auto srcTp = getSparseTensorType(it.value());
1702     if (srcTp.hasDynamicDimShape())
1703       return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1704     const Dimension srcDimRank = srcTp.getDimRank();
1705     if (srcDimRank != dimRank)
1706       return emitError(
1707           llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1708                         "from the output tensor (rank={2}).",
1709                         i, srcDimRank, dimRank));
1710   }
1711 
1712   for (Dimension d = 0; d < dimRank; d++) {
1713     const Size dstSh = dstTp.getDimShape()[d];
1714     if (d == concatDim) {
1715       if (!ShapedType::isDynamic(dstSh)) {
1716         // If we reach here, then all inputs have static shapes.  So we
1717         // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1718         // to avoid redundant assertions in the loop.
1719         Size sumSz = 0;
1720         for (const auto src : getInputs())
1721           sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimensions are statically known, the sum of all the input
1723         // dimensions should be equal to the output dimension.
1724         if (sumSz != dstSh)
1725           return emitError(
1726               "The concatenation dimension of the output tensor should be the "
1727               "sum of all the concatenation dimensions of the input tensors.");
1728       }
1729     } else {
1730       Size prev = dstSh;
1731       for (const auto src : getInputs()) {
1732         const auto sh = getSparseTensorType(src).getDimShape()[d];
1733         if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (except for the concatenating one) "
                           "should be equal.");
1736         prev = sh;
1737       }
1738     }
1739   }
1740 
1741   return success();
1742 }
1743 
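// Convenience builder for appending a single value, i.e., without the
// optional repetition count operand `n`.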
1744 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1745                        Value curSize, Value inBuffer, Value value) {
1746   build(builder, result, curSize, inBuffer, value, Value());
1747 }
1748 
1749 LogicalResult PushBackOp::verify() {
1750   if (Value n = getN()) {
1751     std::optional<int64_t> nValue = getConstantIntValue(n);
1752     if (nValue && nValue.value() < 1)
      return emitOpError("n must be at least 1");
1754   }
1755   return success();
1756 }
1757 
1758 LogicalResult CompressOp::verify() {
1759   const auto stt = getSparseTensorType(getTensor());
1760   if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1761     return emitOpError("incorrect number of coordinates");
1762   return success();
1763 }
1764 
1765 void ForeachOp::build(
1766     OpBuilder &builder, OperationState &result, Value tensor,
1767     ValueRange initArgs, AffineMapAttr order,
1768     function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1769         bodyBuilder) {
1770   build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1771   // Builds foreach body.
1772   if (!bodyBuilder)
1773     return;
1774   const auto stt = getSparseTensorType(tensor);
1775   const Dimension dimRank = stt.getDimRank();
1776 
1777   // Starts with `dimRank`-many coordinates.
1778   SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1779   // Followed by one value.
1780   blockArgTypes.push_back(stt.getElementType());
1781   // Followed by the reduction variables.
1782   blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1783 
1784   SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1785 
1786   OpBuilder::InsertionGuard guard(builder);
1787   auto &region = *result.regions.front();
1788   Block *bodyBlock =
1789       builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1790   bodyBuilder(builder, result.location,
1791               bodyBlock->getArguments().slice(0, dimRank),
1792               bodyBlock->getArguments()[dimRank],
1793               bodyBlock->getArguments().drop_front(dimRank + 1));
1794 }
1795 
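// The body block takes, in order: dimRank coordinates, one element value, and
// the reduction arguments carried over from the init args; the checks below
// enforce exactly this layout.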
1796 LogicalResult ForeachOp::verify() {
1797   const auto t = getSparseTensorType(getTensor());
1798   const Dimension dimRank = t.getDimRank();
1799   const auto args = getBody()->getArguments();
1800 
1801   if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1802     return emitError("Level traverse order does not match tensor's level rank");
1803 
1804   if (dimRank + 1 + getInitArgs().size() != args.size())
1805     return emitError("Unmatched number of arguments in the block");
1806 
1807   if (getNumResults() != getInitArgs().size())
1808     return emitError("Mismatch in number of init arguments and results");
1809 
1810   if (getResultTypes() != getInitArgs().getTypes())
1811     return emitError("Mismatch in types of init arguments and results");
1812 
1813   // Cannot mark this const, because the getters aren't.
1814   auto yield = cast<YieldOp>(getBody()->getTerminator());
1815   if (yield.getNumOperands() != getNumResults() ||
1816       yield.getOperands().getTypes() != getResultTypes())
1817     return emitError("Mismatch in types of yield values and results");
1818 
1819   const auto iTp = IndexType::get(getContext());
1820   for (Dimension d = 0; d < dimRank; d++)
1821     if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));
1824 
1825   const auto elemTp = t.getElementType();
1826   const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected: {0}, got: {1}",
                      elemTp, valueTp));
1831   return success();
1832 }
1833 
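// Reordering is a no-op when the input and result COO tensors already share
// the same encoding, so the operation folds to its input. Schematically:
//   reorder_coo(%coo : T) : T ==> %coo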
1834 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1835   if (getSparseTensorEncoding(getInputCoo().getType()) ==
1836       getSparseTensorEncoding(getResultCoo().getType()))
1837     return getInputCoo();
1838 
1839   return {};
1840 }
1841 
1842 LogicalResult ReorderCOOOp::verify() {
1843   SparseTensorType srcStt = getSparseTensorType(getInputCoo());
1844   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
1845 
  if (!srcStt.isCOOType() || !dstStt.isCOOType())
    return emitError("Expected COO sparse tensors only");

  if (!srcStt.hasSameDimToLvl(dstStt))
    return emitError("Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType())
    return emitError("Unmatched storage format between input and result COO");
1856 
1857   return success();
1858 }
1859 
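// The reduce region is a binary formula over two values of the operand type
// that yields a value of the same type, e.g. (schematic region only):
//   ^bb0(%a: f64, %b: f64):
//     %m = arith.mulf %a, %b : f64
//     sparse_tensor.yield %m : f64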
1860 LogicalResult ReduceOp::verify() {
1861   Type inputType = getX().getType();
1862   Region &formula = getRegion();
1863   return verifyNumBlockArgs(this, formula, "reduce",
1864                             TypeRange{inputType, inputType}, inputType);
1865 }
1866 
1867 LogicalResult SelectOp::verify() {
1868   Builder b(getContext());
1869   Type inputType = getX().getType();
1870   Type boolType = b.getI1Type();
1871   Region &formula = getRegion();
1872   return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
1873                             boolType);
1874 }
1875 
1876 LogicalResult SortOp::verify() {
1877   AffineMap xPerm = getPermMap();
1878   uint64_t nx = xPerm.getNumDims();
1879   if (nx < 1)
1880     emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
1881 
1882   if (!xPerm.isPermutation())
1883     emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
1884 
1885   // We can't check the size of the buffers when n or buffer dimensions aren't
1886   // compile-time constants.
1887   std::optional<int64_t> cn = getConstantIntValue(getN());
1888   if (!cn)
1889     return success();
1890 
1891   // Verify dimensions.
  const auto checkDim = [&](Value v, Size minSize,
                            const char *message) -> LogicalResult {
    const Size sh = getMemRefType(v).getShape()[0];
    if (!ShapedType::isDynamic(sh) && sh < minSize)
      return emitError(
          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
    return success();
  };
  uint64_t n = cn.value();
  uint64_t ny = 0;
  if (auto nyAttr = getNyAttr())
    ny = nyAttr.getInt();
  if (failed(checkDim(getXy(), n * (nx + ny),
                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
    return failure();
  for (Value opnd : getYs())
    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
      return failure();
1905 
1906   return success();
1907 }
1908 
1909 LogicalResult YieldOp::verify() {
1910   // Check for compatible parent.
1911   auto *parentOp = (*this)->getParentOp();
  if (isa<BinaryOp, UnaryOp, ReduceOp, SelectOp, ForeachOp>(parentOp))
    return success();
1916 
1917   return emitOpError("expected parent op to be sparse_tensor unary, binary, "
1918                      "reduce, select or foreach");
1919 }
1920 
1921 /// Materialize a single constant operation from a given attribute value with
1922 /// the desired resultant type.
1923 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
1924                                                     Attribute value, Type type,
1925                                                     Location loc) {
1926   if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
1927     return op;
1928   return nullptr;
1929 }
1930 
1931 namespace {
1932 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
1933   using OpAsmDialectInterface::OpAsmDialectInterface;
1934 
1935   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
    if (isa<SparseTensorEncodingAttr>(attr)) {
1937       os << "sparse";
1938       return AliasResult::OverridableAlias;
1939     }
1940     return AliasResult::NoAlias;
1941   }
1942 };
1943 } // namespace
1944 
1945 void SparseTensorDialect::initialize() {
1946   addInterface<SparseTensorAsmDialectInterface>();
1947   addAttributes<
1948 #define GET_ATTRDEF_LIST
1949 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
1950       >();
1951   addTypes<
1952 #define GET_TYPEDEF_LIST
1953 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
1954       >();
1955   addOperations<
1956 #define GET_OP_LIST
1957 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
1958       >();
1959 }
1960 
1961 #define GET_OP_CLASSES
1962 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
1963 
1964 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
1965