xref: /llvm-project/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (revision af6e1881e0791ac1ee611b62a3d12d9fb03ca142)
1319072f4SAart Bik //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2319072f4SAart Bik //
3319072f4SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4319072f4SAart Bik // See https://llvm.org/LICENSE.txt for license information.
5319072f4SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6319072f4SAart Bik //
7319072f4SAart Bik //===----------------------------------------------------------------------===//
8319072f4SAart Bik 
9239737edSMehdi Amini #include <utility>
10239737edSMehdi Amini 
116b88c852SAart Bik #include "Detail/DimLvlMapParser.h"
126b88c852SAart Bik 
13ee3ee131SAart Bik #include "mlir/Dialect/SparseTensor/IR/Enums.h"
14319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15afe78db7SPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
16f708a549Swren romano #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
175517208dSAart Bik 
18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
19513cdb82SJustin Fargnoli #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20e71eacc5SYinying Li #include "mlir/Dialect/Complex/IR/Complex.h"
21cb7bda2aSMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h"
22319072f4SAart Bik #include "mlir/IR/Builders.h"
230a292199SAart Bik #include "mlir/IR/DialectImplementation.h"
2465e7cd13SRiver Riddle #include "mlir/IR/Matchers.h"
25319072f4SAart Bik #include "mlir/IR/OpImplementation.h"
261f07853fSPeiming Liu #include "mlir/IR/PatternMatch.h"
27a43d79afSPeiming Liu #include "llvm/ADT/Bitset.h"
280a292199SAart Bik #include "llvm/ADT/TypeSwitch.h"
29de907138SPeiming Liu #include "llvm/Support/FormatVariadic.h"
30319072f4SAart Bik 
3171cc0f1cSPeiming Liu #define GET_ATTRDEF_CLASSES
3271cc0f1cSPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
3371cc0f1cSPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
3471cc0f1cSPeiming Liu 
35481bd5d4SPeiming Liu // Forward declarations, following custom print/parsing methods are referenced
36481bd5d4SPeiming Liu // by the generated code for SparseTensorTypes.td.
37481bd5d4SPeiming Liu static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
38481bd5d4SPeiming Liu                                          mlir::sparse_tensor::Level &,
39481bd5d4SPeiming Liu                                          mlir::sparse_tensor::Level &);
40481bd5d4SPeiming Liu static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
41481bd5d4SPeiming Liu                             mlir::sparse_tensor::Level);
42481bd5d4SPeiming Liu 
4371cc0f1cSPeiming Liu #define GET_TYPEDEF_CLASSES
4471cc0f1cSPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
4571cc0f1cSPeiming Liu 
46319072f4SAart Bik using namespace mlir;
47319072f4SAart Bik using namespace mlir::sparse_tensor;
48319072f4SAart Bik 
49aaf91645SPeiming Liu // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
50aaf91645SPeiming Liu // well.
51aaf91645SPeiming Liu namespace mlir::sparse_tensor {
52aaf91645SPeiming Liu llvm::hash_code hash_value(LevelType lt) {
53aaf91645SPeiming Liu   return llvm::hash_value(static_cast<uint64_t>(lt));
54aaf91645SPeiming Liu }
55aaf91645SPeiming Liu } // namespace mlir::sparse_tensor
56aaf91645SPeiming Liu 
570a292199SAart Bik //===----------------------------------------------------------------------===//
5845288085SAart Bik // Local Convenience Methods.
597a1077baSwren romano //===----------------------------------------------------------------------===//
607a1077baSwren romano 
// Returns true for the bit widths accepted for positions/coordinates:
// 0 or one of the standard integer widths 8, 16, 32 and 64.
static constexpr bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
737a1077baSwren romano 
7462fa12adSPeiming Liu static SmallVector<Size>
7562fa12adSPeiming Liu getSparseFieldShape(const SparseTensorEncodingAttr enc,
7662fa12adSPeiming Liu                     std::optional<ArrayRef<int64_t>> dimShape) {
7762fa12adSPeiming Liu   assert(enc);
7862fa12adSPeiming Liu   // With only encoding, we can not determine the static shape for leading
7962fa12adSPeiming Liu   // batch levels, we therefore return a dynamic shape memref instead.
8062fa12adSPeiming Liu   SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
8162fa12adSPeiming Liu   if (dimShape.has_value()) {
8262fa12adSPeiming Liu     // If the actual tensor shape is provided, we can then refine the leading
8362fa12adSPeiming Liu     // batch dimension.
8462fa12adSPeiming Liu     SmallVector<int64_t> lvlShape =
8562fa12adSPeiming Liu         enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
8662fa12adSPeiming Liu     memrefShape.assign(lvlShape.begin(),
8762fa12adSPeiming Liu                        lvlShape.begin() + enc.getBatchLvlRank());
8862fa12adSPeiming Liu   }
8962fa12adSPeiming Liu   // Another dynamic dimension to store the sparse level.
9062fa12adSPeiming Liu   memrefShape.push_back(ShapedType::kDynamic);
9162fa12adSPeiming Liu   return memrefShape;
9262fa12adSPeiming Liu }
9362fa12adSPeiming Liu 
947a1077baSwren romano //===----------------------------------------------------------------------===//
957e83a1afSAart Bik // SparseTensorDialect StorageLayout.
96afe78db7SPeiming Liu //===----------------------------------------------------------------------===//
97afe78db7SPeiming Liu 
98afe78db7SPeiming Liu static constexpr Level kInvalidLevel = -1u;
99afe78db7SPeiming Liu static constexpr Level kInvalidFieldIndex = -1u;
100afe78db7SPeiming Liu static constexpr FieldIndex kDataFieldStartingIdx = 0;
101afe78db7SPeiming Liu 
102afe78db7SPeiming Liu void StorageLayout::foreachField(
103afe78db7SPeiming Liu     llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
1041944c4f7SAart Bik                             LevelType)>
105afe78db7SPeiming Liu         callback) const {
106afe78db7SPeiming Liu   const auto lvlTypes = enc.getLvlTypes();
107afe78db7SPeiming Liu   const Level lvlRank = enc.getLvlRank();
10883f3b1cbSYinying Li   SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
109afe78db7SPeiming Liu   FieldIndex fieldIdx = kDataFieldStartingIdx;
110f740366fSPeiming Liu 
111f740366fSPeiming Liu   ArrayRef cooSegsRef = cooSegs;
112afe78db7SPeiming Liu   // Per-level storage.
113f740366fSPeiming Liu   for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
1141dd387e1SAart Bik     const auto lt = lvlTypes[l];
1151dd387e1SAart Bik     if (isWithPosLT(lt)) {
1161dd387e1SAart Bik       if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
1172e2011daSAart Bik         return;
118de560888SPeiming Liu     }
1191dd387e1SAart Bik     if (isWithCrdLT(lt)) {
1201dd387e1SAart Bik       if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
1212e2011daSAart Bik         return;
122afe78db7SPeiming Liu     }
123f740366fSPeiming Liu     if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
124f740366fSPeiming Liu       if (!cooSegsRef.front().isSoA) {
125f740366fSPeiming Liu         // AoS COO, all singletons are fused into one memrefs. Skips the entire
126f740366fSPeiming Liu         // COO segement.
127f740366fSPeiming Liu         l = cooSegsRef.front().lvlRange.second;
128f740366fSPeiming Liu       } else {
129f740366fSPeiming Liu         // SoA COO, each singleton level has one memref.
130f740366fSPeiming Liu         l++;
131f740366fSPeiming Liu       }
132f740366fSPeiming Liu       // Expire handled COO segment.
133f740366fSPeiming Liu       cooSegsRef = cooSegsRef.drop_front();
134f740366fSPeiming Liu     } else {
135f740366fSPeiming Liu       // Non COO levels.
136f740366fSPeiming Liu       l++;
137f740366fSPeiming Liu     }
138afe78db7SPeiming Liu   }
139afe78db7SPeiming Liu   // The values array.
1402e2011daSAart Bik   if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
141aaf91645SPeiming Liu                  LevelFormat::Undef)))
1422e2011daSAart Bik     return;
143afe78db7SPeiming Liu   // Put metadata at the end.
1442e2011daSAart Bik   if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
145aaf91645SPeiming Liu                  LevelFormat::Undef)))
1462e2011daSAart Bik     return;
147afe78db7SPeiming Liu }
148afe78db7SPeiming Liu 
149afe78db7SPeiming Liu void sparse_tensor::foreachFieldAndTypeInSparseTensor(
150afe78db7SPeiming Liu     SparseTensorType stt,
151afe78db7SPeiming Liu     llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
1521944c4f7SAart Bik                             LevelType)>
153afe78db7SPeiming Liu         callback) {
154afe78db7SPeiming Liu   assert(stt.hasEncoding());
155afe78db7SPeiming Liu 
15662fa12adSPeiming Liu   SmallVector<int64_t> memrefShape =
15762fa12adSPeiming Liu       getSparseFieldShape(stt.getEncoding(), stt.getDimShape());
1580d1f9576SPeiming Liu 
159afe78db7SPeiming Liu   const Type specType = StorageSpecifierType::get(stt.getEncoding());
1600d1f9576SPeiming Liu   // memref<[batch] x ? x pos>  positions
16162fa12adSPeiming Liu   const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
1620d1f9576SPeiming Liu   // memref<[batch] x ? x crd>  coordinates
16362fa12adSPeiming Liu   const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
1640d1f9576SPeiming Liu   // memref<[batch] x ? x eltType> values
16562fa12adSPeiming Liu   const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
166afe78db7SPeiming Liu 
1671944c4f7SAart Bik   StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
1681944c4f7SAart Bik                                    callback](FieldIndex fieldIdx,
1691944c4f7SAart Bik                                              SparseTensorFieldKind fieldKind,
1701944c4f7SAart Bik                                              Level lvl, LevelType lt) -> bool {
171afe78db7SPeiming Liu     switch (fieldKind) {
172afe78db7SPeiming Liu     case SparseTensorFieldKind::StorageSpec:
1731dd387e1SAart Bik       return callback(specType, fieldIdx, fieldKind, lvl, lt);
174afe78db7SPeiming Liu     case SparseTensorFieldKind::PosMemRef:
1751dd387e1SAart Bik       return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
176afe78db7SPeiming Liu     case SparseTensorFieldKind::CrdMemRef:
1771dd387e1SAart Bik       return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
178afe78db7SPeiming Liu     case SparseTensorFieldKind::ValMemRef:
1791dd387e1SAart Bik       return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
180afe78db7SPeiming Liu     };
181afe78db7SPeiming Liu     llvm_unreachable("unrecognized field kind");
182afe78db7SPeiming Liu   });
183afe78db7SPeiming Liu }
184afe78db7SPeiming Liu 
185afe78db7SPeiming Liu unsigned StorageLayout::getNumFields() const {
186afe78db7SPeiming Liu   unsigned numFields = 0;
187afe78db7SPeiming Liu   foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
1881944c4f7SAart Bik                             LevelType) -> bool {
189afe78db7SPeiming Liu     numFields++;
190afe78db7SPeiming Liu     return true;
191afe78db7SPeiming Liu   });
192afe78db7SPeiming Liu   return numFields;
193afe78db7SPeiming Liu }
194afe78db7SPeiming Liu 
195afe78db7SPeiming Liu unsigned StorageLayout::getNumDataFields() const {
196afe78db7SPeiming Liu   unsigned numFields = 0; // one value memref
197afe78db7SPeiming Liu   foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
1981944c4f7SAart Bik                             LevelType) -> bool {
199afe78db7SPeiming Liu     if (fidx >= kDataFieldStartingIdx)
200afe78db7SPeiming Liu       numFields++;
201afe78db7SPeiming Liu     return true;
202afe78db7SPeiming Liu   });
203afe78db7SPeiming Liu   numFields -= 1; // the last field is StorageSpecifier
204afe78db7SPeiming Liu   assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
205afe78db7SPeiming Liu   return numFields;
206afe78db7SPeiming Liu }
207afe78db7SPeiming Liu 
208afe78db7SPeiming Liu std::pair<FieldIndex, unsigned>
209afe78db7SPeiming Liu StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
210afe78db7SPeiming Liu                                       std::optional<Level> lvl) const {
211afe78db7SPeiming Liu   FieldIndex fieldIdx = kInvalidFieldIndex;
212afe78db7SPeiming Liu   unsigned stride = 1;
213afe78db7SPeiming Liu   if (kind == SparseTensorFieldKind::CrdMemRef) {
214afe78db7SPeiming Liu     assert(lvl.has_value());
21583f3b1cbSYinying Li     const Level cooStart = enc.getAoSCOOStart();
216afe78db7SPeiming Liu     const Level lvlRank = enc.getLvlRank();
217afe78db7SPeiming Liu     if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
218afe78db7SPeiming Liu       lvl = cooStart;
219afe78db7SPeiming Liu       stride = lvlRank - cooStart;
220afe78db7SPeiming Liu     }
221afe78db7SPeiming Liu   }
222afe78db7SPeiming Liu   foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
223afe78db7SPeiming Liu                                       SparseTensorFieldKind fKind, Level fLvl,
2241944c4f7SAart Bik                                       LevelType lt) -> bool {
225afe78db7SPeiming Liu     if ((lvl && fLvl == lvl.value() && kind == fKind) ||
226afe78db7SPeiming Liu         (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
227afe78db7SPeiming Liu       fieldIdx = fIdx;
228afe78db7SPeiming Liu       // Returns false to break the iteration.
229afe78db7SPeiming Liu       return false;
230afe78db7SPeiming Liu     }
231afe78db7SPeiming Liu     return true;
232afe78db7SPeiming Liu   });
233afe78db7SPeiming Liu   assert(fieldIdx != kInvalidFieldIndex);
234afe78db7SPeiming Liu   return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
235afe78db7SPeiming Liu }
236afe78db7SPeiming Liu 
237afe78db7SPeiming Liu //===----------------------------------------------------------------------===//
2387e83a1afSAart Bik // SparseTensorDialect Attribute Methods.
2390a292199SAart Bik //===----------------------------------------------------------------------===//
2400a292199SAart Bik 
2417a1077baSwren romano std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
2427a1077baSwren romano   return isDynamic(v) ? std::nullopt
2437a1077baSwren romano                       : std::make_optional(static_cast<uint64_t>(v));
2440a292199SAart Bik }
2457a1077baSwren romano 
2467a1077baSwren romano std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
2477a1077baSwren romano   return getStatic(getOffset());
2487a1077baSwren romano }
2497a1077baSwren romano 
2507a1077baSwren romano std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
2517a1077baSwren romano   return getStatic(getStride());
2527a1077baSwren romano }
2537a1077baSwren romano 
2547a1077baSwren romano std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
2557a1077baSwren romano   return getStatic(getSize());
2567a1077baSwren romano }
2577a1077baSwren romano 
2587a1077baSwren romano bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
2597a1077baSwren romano   return isDynamic(getOffset()) && isDynamic(getStride()) &&
2607a1077baSwren romano          isDynamic(getSize());
2617a1077baSwren romano }
2627a1077baSwren romano 
2637a1077baSwren romano std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
2647a1077baSwren romano   return isDynamic(v) ? "?" : std::to_string(v);
2650a292199SAart Bik }
2660a292199SAart Bik 
26711a4f5bdSAart Bik void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
268cad46467Swren romano   assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
26911a4f5bdSAart Bik   os << '(';
27011a4f5bdSAart Bik   os << getStaticString(getOffset());
27111a4f5bdSAart Bik   os << ", ";
27211a4f5bdSAart Bik   os << getStaticString(getSize());
27311a4f5bdSAart Bik   os << ", ";
27411a4f5bdSAart Bik   os << getStaticString(getStride());
27511a4f5bdSAart Bik   os << ')';
27611a4f5bdSAart Bik }
27711a4f5bdSAart Bik 
278885a1f43SPeiming Liu void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
27911a4f5bdSAart Bik   print(printer.getStream());
280885a1f43SPeiming Liu }
281885a1f43SPeiming Liu 
282885a1f43SPeiming Liu static ParseResult parseOptionalStaticSlice(int64_t &result,
283885a1f43SPeiming Liu                                             AsmParser &parser) {
284885a1f43SPeiming Liu   auto parseResult = parser.parseOptionalInteger(result);
285885a1f43SPeiming Liu   if (parseResult.has_value()) {
286885a1f43SPeiming Liu     if (parseResult.value().succeeded() && result < 0) {
287885a1f43SPeiming Liu       parser.emitError(
288885a1f43SPeiming Liu           parser.getCurrentLocation(),
289885a1f43SPeiming Liu           "expect positive value or ? for slice offset/size/stride");
290885a1f43SPeiming Liu       return failure();
291885a1f43SPeiming Liu     }
292885a1f43SPeiming Liu     return parseResult.value();
293885a1f43SPeiming Liu   }
294885a1f43SPeiming Liu 
295885a1f43SPeiming Liu   // Else, and '?' which represented dynamic slice
296885a1f43SPeiming Liu   result = SparseTensorDimSliceAttr::kDynamic;
297885a1f43SPeiming Liu   return parser.parseQuestion();
298885a1f43SPeiming Liu }
299885a1f43SPeiming Liu 
300885a1f43SPeiming Liu Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
3017a1077baSwren romano   int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
302885a1f43SPeiming Liu 
303885a1f43SPeiming Liu   if (failed(parser.parseLParen()) ||
304885a1f43SPeiming Liu       failed(parseOptionalStaticSlice(offset, parser)) ||
305885a1f43SPeiming Liu       failed(parser.parseComma()) ||
306885a1f43SPeiming Liu       failed(parseOptionalStaticSlice(size, parser)) ||
307885a1f43SPeiming Liu       failed(parser.parseComma()) ||
308885a1f43SPeiming Liu       failed(parseOptionalStaticSlice(stride, parser)) ||
309885a1f43SPeiming Liu       failed(parser.parseRParen()))
310885a1f43SPeiming Liu     return {};
311885a1f43SPeiming Liu 
312885a1f43SPeiming Liu   return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
313885a1f43SPeiming Liu                                                      offset, size, stride);
314885a1f43SPeiming Liu }
315885a1f43SPeiming Liu 
316885a1f43SPeiming Liu LogicalResult
317885a1f43SPeiming Liu SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
318885a1f43SPeiming Liu                                  int64_t offset, int64_t size, int64_t stride) {
3197a1077baSwren romano   if (!isDynamic(offset) && offset < 0)
3207a1077baSwren romano     return emitError() << "expect non-negative value or ? for slice offset";
3217a1077baSwren romano   if (!isDynamic(size) && size <= 0)
3227a1077baSwren romano     return emitError() << "expect positive value or ? for slice size";
3237a1077baSwren romano   if (!isDynamic(stride) && stride <= 0)
3247a1077baSwren romano     return emitError() << "expect positive value or ? for slice stride";
325885a1f43SPeiming Liu   return success();
326885a1f43SPeiming Liu }
327885a1f43SPeiming Liu 
32876647fceSwren romano SparseTensorEncodingAttr
32976647fceSwren romano SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
33076647fceSwren romano   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
331a10d67f9SYinying Li   return SparseTensorEncodingAttr::get(
332a10d67f9SYinying Li       getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
333a10d67f9SYinying Li       getCrdWidth(), getExplicitVal(), getImplicitVal());
33476647fceSwren romano }
33576647fceSwren romano 
33676647fceSwren romano SparseTensorEncodingAttr
33776647fceSwren romano SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
33876647fceSwren romano   return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
33976647fceSwren romano }
34076647fceSwren romano 
34176647fceSwren romano SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
34276647fceSwren romano   return withDimToLvl(AffineMap());
34376647fceSwren romano }
34476647fceSwren romano 
34576647fceSwren romano SparseTensorEncodingAttr
34676647fceSwren romano SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
34776647fceSwren romano                                         unsigned crdWidth) const {
34876647fceSwren romano   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
349a10d67f9SYinying Li   return SparseTensorEncodingAttr::get(
350a10d67f9SYinying Li       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
351a10d67f9SYinying Li       crdWidth, getExplicitVal(), getImplicitVal());
35296fef4dcSwren romano }
35396fef4dcSwren romano 
35485dbb3fcSPeiming Liu SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
35576647fceSwren romano   return withBitWidths(0, 0);
35685dbb3fcSPeiming Liu }
35785dbb3fcSPeiming Liu 
358a10d67f9SYinying Li SparseTensorEncodingAttr
359a10d67f9SYinying Li SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
360a10d67f9SYinying Li   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
361a10d67f9SYinying Li   return SparseTensorEncodingAttr::get(
362a10d67f9SYinying Li       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
363a10d67f9SYinying Li       getCrdWidth(), explicitVal, getImplicitVal());
364a10d67f9SYinying Li }
365a10d67f9SYinying Li 
366a10d67f9SYinying Li SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
367a10d67f9SYinying Li   return withExplicitVal(Attribute());
368a10d67f9SYinying Li }
369a10d67f9SYinying Li 
370a10d67f9SYinying Li SparseTensorEncodingAttr
371a10d67f9SYinying Li SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
372a10d67f9SYinying Li   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
373a10d67f9SYinying Li   return SparseTensorEncodingAttr::get(
374a10d67f9SYinying Li       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
375a10d67f9SYinying Li       getCrdWidth(), getExplicitVal(), implicitVal);
376a10d67f9SYinying Li }
377a10d67f9SYinying Li 
378a10d67f9SYinying Li SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
379a10d67f9SYinying Li   return withImplicitVal(Attribute());
380a10d67f9SYinying Li }
381a10d67f9SYinying Li 
382af2bec7cSwren romano SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
383af2bec7cSwren romano     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
384a10d67f9SYinying Li   return SparseTensorEncodingAttr::get(
385a10d67f9SYinying Li       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
386a10d67f9SYinying Li       getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
387af2bec7cSwren romano }
388af2bec7cSwren romano 
389af2bec7cSwren romano SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
390af2bec7cSwren romano   return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
391af2bec7cSwren romano }
392af2bec7cSwren romano 
3930d1f9576SPeiming Liu uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
3940d1f9576SPeiming Liu   ArrayRef<LevelType> lvlTypes = getLvlTypes();
3950d1f9576SPeiming Liu   auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
3960d1f9576SPeiming Liu   return std::distance(lastBatch, lvlTypes.rend());
3970d1f9576SPeiming Liu }
3980d1f9576SPeiming Liu 
399a3672addSPeiming Liu bool SparseTensorEncodingAttr::isAllDense() const {
4001dd387e1SAart Bik   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
401f708a549Swren romano }
402f708a549Swren romano 
403f708a549Swren romano bool SparseTensorEncodingAttr::isAllOrdered() const {
4041dd387e1SAart Bik   return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
405a3672addSPeiming Liu }
406a3672addSPeiming Liu 
40762fa12adSPeiming Liu Type SparseTensorEncodingAttr::getCrdElemType() const {
40862fa12adSPeiming Liu   if (!getImpl())
40962fa12adSPeiming Liu     return nullptr;
41062fa12adSPeiming Liu   if (getCrdWidth())
41162fa12adSPeiming Liu     return IntegerType::get(getContext(), getCrdWidth());
41262fa12adSPeiming Liu   return IndexType::get(getContext());
41362fa12adSPeiming Liu }
41462fa12adSPeiming Liu 
41562fa12adSPeiming Liu Type SparseTensorEncodingAttr::getPosElemType() const {
41662fa12adSPeiming Liu   if (!getImpl())
41762fa12adSPeiming Liu     return nullptr;
41862fa12adSPeiming Liu   if (getPosWidth())
41962fa12adSPeiming Liu     return IntegerType::get(getContext(), getPosWidth());
42062fa12adSPeiming Liu   return IndexType::get(getContext());
42162fa12adSPeiming Liu }
42262fa12adSPeiming Liu 
42362fa12adSPeiming Liu MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
42462fa12adSPeiming Liu     std::optional<ArrayRef<int64_t>> dimShape) const {
42562fa12adSPeiming Liu   SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
42662fa12adSPeiming Liu   return MemRefType::get(shape, getCrdElemType());
42762fa12adSPeiming Liu }
42862fa12adSPeiming Liu 
42962fa12adSPeiming Liu MemRefType SparseTensorEncodingAttr::getPosMemRefType(
43062fa12adSPeiming Liu     std::optional<ArrayRef<int64_t>> dimShape) const {
43162fa12adSPeiming Liu   SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
43262fa12adSPeiming Liu   return MemRefType::get(shape, getPosElemType());
43362fa12adSPeiming Liu }
43462fa12adSPeiming Liu 
43576647fceSwren romano bool SparseTensorEncodingAttr::isIdentity() const {
43676647fceSwren romano   return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
43776647fceSwren romano }
43876647fceSwren romano 
43976647fceSwren romano bool SparseTensorEncodingAttr::isPermutation() const {
44076647fceSwren romano   return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
44176647fceSwren romano }
44276647fceSwren romano 
44376647fceSwren romano Dimension SparseTensorEncodingAttr::getDimRank() const {
44476647fceSwren romano   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
44576647fceSwren romano   const auto dimToLvl = getDimToLvl();
44676647fceSwren romano   return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
447f708a549Swren romano }
448f708a549Swren romano 
449f708a549Swren romano Level SparseTensorEncodingAttr::getLvlRank() const {
450f708a549Swren romano   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
451a0615d02Swren romano   return getLvlTypes().size();
452f708a549Swren romano }
453f708a549Swren romano 
4541944c4f7SAart Bik LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
455f708a549Swren romano   if (!getImpl())
45652b69aa3SPeiming Liu     return LevelFormat::Batch;
457f708a549Swren romano   assert(l < getLvlRank() && "Level is out of bounds");
458a0615d02Swren romano   return getLvlTypes()[l];
459a3672addSPeiming Liu }
460a3672addSPeiming Liu 
461867e1964Swren romano bool SparseTensorEncodingAttr::isSlice() const {
462867e1964Swren romano   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
463867e1964Swren romano   return !getDimSlices().empty();
464867e1964Swren romano }
465867e1964Swren romano 
466867e1964Swren romano SparseTensorDimSliceAttr
467867e1964Swren romano SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
468867e1964Swren romano   assert(isSlice() && "Is not a slice");
469867e1964Swren romano   const auto dimSlices = getDimSlices();
470867e1964Swren romano   assert(dim < dimSlices.size() && "Dimension is out of bounds");
471867e1964Swren romano   return dimSlices[dim];
472867e1964Swren romano }
473867e1964Swren romano 
474885a1f43SPeiming Liu std::optional<uint64_t>
475f708a549Swren romano SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
476867e1964Swren romano   return getDimSlice(dim).getStaticOffset();
477885a1f43SPeiming Liu }
478885a1f43SPeiming Liu 
479885a1f43SPeiming Liu std::optional<uint64_t>
480f708a549Swren romano SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
481867e1964Swren romano   return getDimSlice(dim).getStaticStride();
482885a1f43SPeiming Liu }
483885a1f43SPeiming Liu 
484885a1f43SPeiming Liu std::optional<uint64_t>
485f708a549Swren romano SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
4864e2f1521SPeiming Liu   return getStaticDimSliceOffset(toDim(*this, lvl));
487885a1f43SPeiming Liu }
488885a1f43SPeiming Liu 
489885a1f43SPeiming Liu std::optional<uint64_t>
490f708a549Swren romano SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
4914e2f1521SPeiming Liu   return getStaticDimSliceStride(toDim(*this, lvl));
492885a1f43SPeiming Liu }
493885a1f43SPeiming Liu 
494d808d922SPeiming Liu SmallVector<int64_t>
495e10dc60aSAart Bik SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
496d808d922SPeiming Liu                                          CrdTransDirectionKind dir) const {
497d808d922SPeiming Liu   if (isIdentity())
498d808d922SPeiming Liu     return SmallVector<int64_t>(srcShape);
499d808d922SPeiming Liu 
500d808d922SPeiming Liu   SmallVector<int64_t> ret;
501d808d922SPeiming Liu   unsigned rank =
502d808d922SPeiming Liu       dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
503d808d922SPeiming Liu   ret.reserve(rank);
504d808d922SPeiming Liu 
505d808d922SPeiming Liu   if (isPermutation()) {
506d808d922SPeiming Liu     for (unsigned r = 0; r < rank; r++) {
5074e2f1521SPeiming Liu       unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
5084e2f1521SPeiming Liu                                                              : toLvl(*this, r);
509d808d922SPeiming Liu       ret.push_back(srcShape[trans]);
510d808d922SPeiming Liu     }
511d808d922SPeiming Liu     return ret;
512d808d922SPeiming Liu   }
513d808d922SPeiming Liu 
514d808d922SPeiming Liu   // Handle non-permutation maps.
515d808d922SPeiming Liu   AffineMap transMap =
516d808d922SPeiming Liu       dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
517d808d922SPeiming Liu 
518d808d922SPeiming Liu   SmallVector<AffineExpr> dimRep;
519d808d922SPeiming Liu   dimRep.reserve(srcShape.size());
520d808d922SPeiming Liu   for (int64_t sz : srcShape) {
521d808d922SPeiming Liu     if (!ShapedType::isDynamic(sz)) {
522d808d922SPeiming Liu       // Push back the max coordinate for the given dimension/level size.
523d808d922SPeiming Liu       dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
524d808d922SPeiming Liu     } else {
525d808d922SPeiming Liu       // A dynamic size, use a AffineDimExpr to symbolize the value.
526d808d922SPeiming Liu       dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
527d808d922SPeiming Liu     }
528d808d922SPeiming Liu   };
529d808d922SPeiming Liu 
530d808d922SPeiming Liu   for (AffineExpr exp : transMap.getResults()) {
531d808d922SPeiming Liu     // Do constant propagation on the affine map.
532d808d922SPeiming Liu     AffineExpr evalExp =
533d808d922SPeiming Liu         simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
5341609f1c2Slong.chen     // use llvm namespace here to avoid ambiguity
5351609f1c2Slong.chen     if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
536d808d922SPeiming Liu       ret.push_back(c.getValue() + 1);
537ef222988SPeiming Liu     } else {
5381609f1c2Slong.chen       if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
539ef222988SPeiming Liu           mod && mod.getKind() == AffineExprKind::Mod) {
540ef222988SPeiming Liu         // We can still infer a static bound for expressions in form
541ef222988SPeiming Liu         // "d % constant" since d % constant \in [0, constant).
5421609f1c2Slong.chen         if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
543ef222988SPeiming Liu           ret.push_back(bound.getValue());
544ef222988SPeiming Liu           continue;
545ef222988SPeiming Liu         }
546ef222988SPeiming Liu       }
547d808d922SPeiming Liu       ret.push_back(ShapedType::kDynamic);
548d808d922SPeiming Liu     }
549ef222988SPeiming Liu   }
550ef222988SPeiming Liu   assert(ret.size() == rank);
551d808d922SPeiming Liu   return ret;
552d808d922SPeiming Liu }
553d808d922SPeiming Liu 
5546456e0bbSPeiming Liu ValueRange
5556456e0bbSPeiming Liu SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
5566456e0bbSPeiming Liu                                         ValueRange crds,
5576456e0bbSPeiming Liu                                         CrdTransDirectionKind dir) const {
5586456e0bbSPeiming Liu   if (!getImpl())
5596456e0bbSPeiming Liu     return crds;
5606456e0bbSPeiming Liu 
5616456e0bbSPeiming Liu   SmallVector<Type> retType(
5626456e0bbSPeiming Liu       dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
5636456e0bbSPeiming Liu       builder.getIndexType());
5646456e0bbSPeiming Liu   auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
5656456e0bbSPeiming Liu   return transOp.getOutCrds();
5666456e0bbSPeiming Liu }
5676456e0bbSPeiming Liu 
568f97e72aaSMehdi Amini Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
5692e2011daSAart Bik   // Open "<{" part.
5702e2011daSAart Bik   if (failed(parser.parseLess()))
5712e2011daSAart Bik     return {};
5722e2011daSAart Bik   if (failed(parser.parseLBrace()))
5732e2011daSAart Bik     return {};
574885a1f43SPeiming Liu 
5750a292199SAart Bik   // Process the data from the parsed dictionary value into struct-like data.
5761944c4f7SAart Bik   SmallVector<LevelType> lvlTypes;
577540d5e0cSwren romano   SmallVector<SparseTensorDimSliceAttr> dimSlices;
57876647fceSwren romano   AffineMap dimToLvl = {};
5797b9fb1c2SYinying Li   AffineMap lvlToDim = {};
58084cd51bbSwren romano   unsigned posWidth = 0;
58184cd51bbSwren romano   unsigned crdWidth = 0;
582a10d67f9SYinying Li   Attribute explicitVal;
583a10d67f9SYinying Li   Attribute implicitVal;
584885a1f43SPeiming Liu   StringRef attrName;
585a10d67f9SYinying Li   SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
586a10d67f9SYinying Li                                     "explicitVal", "implicitVal"};
587885a1f43SPeiming Liu   while (succeeded(parser.parseOptionalKeyword(&attrName))) {
588bb44a6b7SAart Bik     // Detect admissible keyword.
589bb44a6b7SAart Bik     auto *it = find(keys, attrName);
590bb44a6b7SAart Bik     if (it == keys.end()) {
591885a1f43SPeiming Liu       parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
592885a1f43SPeiming Liu       return {};
593885a1f43SPeiming Liu     }
594bb44a6b7SAart Bik     unsigned keyWordIndex = it - keys.begin();
595885a1f43SPeiming Liu     // Consume the `=` after keys
5962e2011daSAart Bik     if (failed(parser.parseEqual()))
5972e2011daSAart Bik       return {};
598bb44a6b7SAart Bik     // Dispatch on keyword.
599bb44a6b7SAart Bik     switch (keyWordIndex) {
600fb5047f5SYinying Li     case 0: { // map
6016b88c852SAart Bik       ir_detail::DimLvlMapParser cParser(parser);
6026b88c852SAart Bik       auto res = cParser.parseDimLvlMap();
6032e2011daSAart Bik       if (failed(res))
6042e2011daSAart Bik         return {};
605cad46467Swren romano       const auto &dlm = *res;
606cad46467Swren romano 
607cad46467Swren romano       const Level lvlRank = dlm.getLvlRank();
608cad46467Swren romano       for (Level lvl = 0; lvl < lvlRank; lvl++)
609cad46467Swren romano         lvlTypes.push_back(dlm.getLvlType(lvl));
610cad46467Swren romano 
611cad46467Swren romano       const Dimension dimRank = dlm.getDimRank();
612cad46467Swren romano       for (Dimension dim = 0; dim < dimRank; dim++)
613cad46467Swren romano         dimSlices.push_back(dlm.getDimSlice(dim));
614cad46467Swren romano       // NOTE: the old syntax requires an all-or-nothing approach to
615cad46467Swren romano       // `dimSlices`; therefore, if any slice actually exists then we need
616cad46467Swren romano       // to convert null-DSA into default/nop DSA.
617cad46467Swren romano       const auto isDefined = [](SparseTensorDimSliceAttr slice) {
618cad46467Swren romano         return static_cast<bool>(slice.getImpl());
619cad46467Swren romano       };
620cad46467Swren romano       if (llvm::any_of(dimSlices, isDefined)) {
621cad46467Swren romano         const auto defaultSlice =
622cad46467Swren romano             SparseTensorDimSliceAttr::get(parser.getContext());
623cad46467Swren romano         for (Dimension dim = 0; dim < dimRank; dim++)
624cad46467Swren romano           if (!isDefined(dimSlices[dim]))
625cad46467Swren romano             dimSlices[dim] = defaultSlice;
626cad46467Swren romano       } else {
627cad46467Swren romano         dimSlices.clear();
628cad46467Swren romano       }
629cad46467Swren romano 
630cad46467Swren romano       dimToLvl = dlm.getDimToLvlMap(parser.getContext());
6317b9fb1c2SYinying Li       lvlToDim = dlm.getLvlToDimMap(parser.getContext());
632bb44a6b7SAart Bik       break;
6330a292199SAart Bik     }
634fb5047f5SYinying Li     case 1: { // posWidth
635fb5047f5SYinying Li       Attribute attr;
6362e2011daSAart Bik       if (failed(parser.parseAttribute(attr)))
6372e2011daSAart Bik         return {};
638fb5047f5SYinying Li       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
6392e2011daSAart Bik       if (!intAttr) {
6402e2011daSAart Bik         parser.emitError(parser.getNameLoc(),
6412e2011daSAart Bik                          "expected an integral position bitwidth");
6422e2011daSAart Bik         return {};
6432e2011daSAart Bik       }
644fb5047f5SYinying Li       posWidth = intAttr.getInt();
645fb5047f5SYinying Li       break;
646fb5047f5SYinying Li     }
647fb5047f5SYinying Li     case 2: { // crdWidth
648fb5047f5SYinying Li       Attribute attr;
6492e2011daSAart Bik       if (failed(parser.parseAttribute(attr)))
6502e2011daSAart Bik         return {};
651fb5047f5SYinying Li       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
6522e2011daSAart Bik       if (!intAttr) {
6532e2011daSAart Bik         parser.emitError(parser.getNameLoc(),
6542e2011daSAart Bik                          "expected an integral index bitwidth");
6552e2011daSAart Bik         return {};
6562e2011daSAart Bik       }
657fb5047f5SYinying Li       crdWidth = intAttr.getInt();
658fb5047f5SYinying Li       break;
659fb5047f5SYinying Li     }
660a10d67f9SYinying Li     case 3: { // explicitVal
661a10d67f9SYinying Li       Attribute attr;
662a10d67f9SYinying Li       if (failed(parser.parseAttribute(attr)))
663a10d67f9SYinying Li         return {};
664a10d67f9SYinying Li       if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
665a10d67f9SYinying Li         explicitVal = result;
666a10d67f9SYinying Li       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
667a10d67f9SYinying Li         explicitVal = result;
668e71eacc5SYinying Li       } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
669e71eacc5SYinying Li         explicitVal = result;
670a10d67f9SYinying Li       } else {
671a10d67f9SYinying Li         parser.emitError(parser.getNameLoc(),
672a10d67f9SYinying Li                          "expected a numeric value for explicitVal");
673a10d67f9SYinying Li         return {};
674a10d67f9SYinying Li       }
675a10d67f9SYinying Li       break;
676a10d67f9SYinying Li     }
677a10d67f9SYinying Li     case 4: { // implicitVal
678a10d67f9SYinying Li       Attribute attr;
679a10d67f9SYinying Li       if (failed(parser.parseAttribute(attr)))
680a10d67f9SYinying Li         return {};
681a10d67f9SYinying Li       if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
682a10d67f9SYinying Li         implicitVal = result;
683a10d67f9SYinying Li       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
684a10d67f9SYinying Li         implicitVal = result;
685e71eacc5SYinying Li       } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
686e71eacc5SYinying Li         implicitVal = result;
687a10d67f9SYinying Li       } else {
688a10d67f9SYinying Li         parser.emitError(parser.getNameLoc(),
689a10d67f9SYinying Li                          "expected a numeric value for implicitVal");
690a10d67f9SYinying Li         return {};
691a10d67f9SYinying Li       }
692a10d67f9SYinying Li       break;
693a10d67f9SYinying Li     }
694bb44a6b7SAart Bik     } // switch
695bb44a6b7SAart Bik     // Only last item can omit the comma.
696885a1f43SPeiming Liu     if (parser.parseOptionalComma().failed())
697885a1f43SPeiming Liu       break;
6980a292199SAart Bik   }
699885a1f43SPeiming Liu 
7002e2011daSAart Bik   // Close "}>" part.
7012e2011daSAart Bik   if (failed(parser.parseRBrace()))
7022e2011daSAart Bik     return {};
7032e2011daSAart Bik   if (failed(parser.parseGreater()))
7042e2011daSAart Bik     return {};
705885a1f43SPeiming Liu 
7060a292199SAart Bik   // Construct struct-like storage for attribute.
7077b9fb1c2SYinying Li   if (!lvlToDim || lvlToDim.isEmpty()) {
7087b9fb1c2SYinying Li     lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
7097b9fb1c2SYinying Li   }
710c48e9087SAart Bik   return parser.getChecked<SparseTensorEncodingAttr>(
711836411b9SAart Bik       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
712a10d67f9SYinying Li       explicitVal, implicitVal, dimSlices);
7130a292199SAart Bik }
7140a292199SAart Bik 
715f97e72aaSMehdi Amini void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
7166280e231SYinying Li   auto map = static_cast<AffineMap>(getDimToLvl());
7176280e231SYinying Li   // Empty affine map indicates identity map
7186280e231SYinying Li   if (!map)
7196280e231SYinying Li     map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
7206280e231SYinying Li   printer << "<{ map = ";
7216280e231SYinying Li   printSymbols(map, printer);
7226280e231SYinying Li   printer << '(';
7236280e231SYinying Li   printDimensions(map, printer, getDimSlices());
7246280e231SYinying Li   printer << ") -> (";
7256280e231SYinying Li   printLevels(map, printer, getLvlTypes());
7266280e231SYinying Li   printer << ')';
727e3d64ccfSAart Bik   // Print remaining members only for non-default values.
72884cd51bbSwren romano   if (getPosWidth())
72984cd51bbSwren romano     printer << ", posWidth = " << getPosWidth();
73084cd51bbSwren romano   if (getCrdWidth())
73184cd51bbSwren romano     printer << ", crdWidth = " << getCrdWidth();
732a10d67f9SYinying Li   if (getExplicitVal()) {
733a10d67f9SYinying Li     printer << ", explicitVal = " << getExplicitVal();
734a10d67f9SYinying Li   }
735a10d67f9SYinying Li   if (getImplicitVal())
736a10d67f9SYinying Li     printer << ", implicitVal = " << getImplicitVal();
7376280e231SYinying Li   printer << " }>";
738885a1f43SPeiming Liu }
739885a1f43SPeiming Liu 
7406280e231SYinying Li void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
7416280e231SYinying Li                                             AsmPrinter &printer) const {
7426280e231SYinying Li   if (map.getNumSymbols() == 0)
7436280e231SYinying Li     return;
7446280e231SYinying Li   printer << '[';
7456280e231SYinying Li   for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
7466280e231SYinying Li     printer << 's' << i << ", ";
7476280e231SYinying Li   if (map.getNumSymbols() >= 1)
7486280e231SYinying Li     printer << 's' << map.getNumSymbols() - 1;
7496280e231SYinying Li   printer << ']';
7506280e231SYinying Li }
7516280e231SYinying Li 
7526280e231SYinying Li void SparseTensorEncodingAttr::printDimensions(
7536280e231SYinying Li     AffineMap &map, AsmPrinter &printer,
7546280e231SYinying Li     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
7556280e231SYinying Li   if (!dimSlices.empty()) {
7566280e231SYinying Li     for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
7576280e231SYinying Li       printer << 'd' << i << " : " << dimSlices[i] << ", ";
7586280e231SYinying Li     if (map.getNumDims() >= 1) {
7596280e231SYinying Li       printer << 'd' << map.getNumDims() - 1 << " : "
7606280e231SYinying Li               << dimSlices[map.getNumDims() - 1];
7616280e231SYinying Li     }
7626280e231SYinying Li   } else {
7636280e231SYinying Li     for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
7646280e231SYinying Li       printer << 'd' << i << ", ";
7656280e231SYinying Li     if (map.getNumDims() >= 1)
7666280e231SYinying Li       printer << 'd' << map.getNumDims() - 1;
7676280e231SYinying Li   }
7686280e231SYinying Li }
7696280e231SYinying Li 
7701944c4f7SAart Bik void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
7711944c4f7SAart Bik                                            ArrayRef<LevelType> lvlTypes) const {
7726280e231SYinying Li   for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
7736280e231SYinying Li     map.getResult(i).print(printer.getStream());
774ced1fac8SYinying Li     printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
7756280e231SYinying Li   }
7766280e231SYinying Li   if (map.getNumResults() >= 1) {
7776280e231SYinying Li     auto lastIndex = map.getNumResults() - 1;
7786280e231SYinying Li     map.getResult(lastIndex).print(printer.getStream());
779ced1fac8SYinying Li     printer << " : " << toMLIRString(lvlTypes[lastIndex]);
7806280e231SYinying Li   }
7810a292199SAart Bik }
7820a292199SAart Bik 
/// Structural verifier for the encoding attribute. Checks run in a fixed
/// order and the first failing one determines the emitted diagnostic:
///   1. position/coordinate bit-widths are acceptable (`acceptBitWidth`);
///   2. every run of singleton levels is preceded by a compressed or
///      loose_compressed level, and all singletons in a run share the same
///      SoA/AoS layout;
///   3. batch levels only appear as a leading prefix of the levels;
///   4. the SoA property is only used on singleton levels;
///   5. an n_out_of_m level, if present, is the last level, preceded only by
///      dense levels, and (for non-square dimToLvl maps) matches a 1xm block
///      structure whose block size equals its m;
///   6. rank consistency between lvlTypes, dimToLvl, lvlToDim and dimSlices.
LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;

  // Verify every COO segment.
  // `it` walks from the start of one singleton run to the start of the next.
  auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
  while (it != lvlTypes.end()) {
    // A singleton run must be opened by a (loose) compressed level.
    if (it == lvlTypes.begin() ||
        !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";

    // End of the current contiguous singleton run.
    auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
    // NOTE(review): curCOOEnd comes from find_if_not(isSingletonLT), so every
    // level in [it, curCOOEnd) is already a singleton; this check appears to
    // always pass.
    if (!std::all_of(it, curCOOEnd,
                     [](LevelType i) { return isSingletonLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // We can potentially support mixed SoA/AoS singleton levels.
    // For now, every singleton in the run must match the first one's SoA bit.
    if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
    // Advance to the next singleton run, if any.
    it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
  }

  // `lastBatch` is the last batch level (searching from the back); it and all
  // levels before it must be batch, i.e. batch levels form a leading prefix.
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // SoA property can only be applied on singleton level.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by backend.
  // n_out_of_m checks; `it` below is the first n_out_of_m level and shadows
  // the COO iterator above.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    // Only non-square maps (e.g. a blocked dimension) need the block checks.
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      // All non-zero block sizes must agree; presumably 0 marks an unblocked
      // dimension in getBlockSize's result (TODO confirm).
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      // The common block size must equal the m of the n_out_of_m level.
      if (coefficient != getM(*it)) {
        return emitError() << "expected coeffiencts of Affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it.  The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    // A user-provided lvlToDim must match the inferred inverse exactly.
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree.  (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}
9010a292199SAart Bik 
9020a292199SAart Bik LogicalResult SparseTensorEncodingAttr::verifyEncoding(
90322212ca7SAart Bik     ArrayRef<Size> dimShape, Type elementType,
9040a292199SAart Bik     function_ref<InFlightDiagnostic()> emitError) const {
905f708a549Swren romano   // Check structural integrity.  In particular, this ensures that the
906f708a549Swren romano   // level-rank is coherent across all the fields.
9072e2011daSAart Bik   if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
908a10d67f9SYinying Li                     getPosWidth(), getCrdWidth(), getExplicitVal(),
909a10d67f9SYinying Li                     getImplicitVal(), getDimSlices())))
9102e2011daSAart Bik     return failure();
911f708a549Swren romano   // Check integrity with tensor type specifics.  In particular, we
912f708a549Swren romano   // need only check that the dimension-rank of the tensor agrees with
913f708a549Swren romano   // the dimension-rank of the encoding.
914f708a549Swren romano   const Dimension dimRank = dimShape.size();
915f708a549Swren romano   if (dimRank == 0)
9164aa9b398SAart Bik     return emitError() << "expected non-scalar sparse tensor";
91776647fceSwren romano   if (getDimRank() != dimRank)
91876647fceSwren romano     return emitError()
91976647fceSwren romano            << "dimension-rank mismatch between encoding and tensor shape: "
92076647fceSwren romano            << getDimRank() << " != " << dimRank;
92183f3b1cbSYinying Li   if (auto expVal = getExplicitVal()) {
92283f3b1cbSYinying Li     Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
92383f3b1cbSYinying Li     if (attrType != elementType) {
92483f3b1cbSYinying Li       return emitError() << "explicit value type mismatch between encoding and "
92583f3b1cbSYinying Li                          << "tensor element type: " << attrType
92683f3b1cbSYinying Li                          << " != " << elementType;
92783f3b1cbSYinying Li     }
92883f3b1cbSYinying Li   }
92983f3b1cbSYinying Li   if (auto impVal = getImplicitVal()) {
93083f3b1cbSYinying Li     Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
93183f3b1cbSYinying Li     if (attrType != elementType) {
93283f3b1cbSYinying Li       return emitError() << "implicit value type mismatch between encoding and "
93383f3b1cbSYinying Li                          << "tensor element type: " << attrType
93483f3b1cbSYinying Li                          << " != " << elementType;
93583f3b1cbSYinying Li     }
93683f3b1cbSYinying Li     // Currently, we only support zero as the implicit value.
93783f3b1cbSYinying Li     auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
93883f3b1cbSYinying Li     auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
93983f3b1cbSYinying Li     auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
94083f3b1cbSYinying Li     if ((impFVal && impFVal.getValue().isNonZero()) ||
94183f3b1cbSYinying Li         (impIntVal && !impIntVal.getValue().isZero()) ||
94283f3b1cbSYinying Li         (impComplexVal && (impComplexVal.getImag().isNonZero() ||
94383f3b1cbSYinying Li                            impComplexVal.getReal().isNonZero()))) {
94483f3b1cbSYinying Li       return emitError() << "implicit value must be zero";
94583f3b1cbSYinying Li     }
94683f3b1cbSYinying Li   }
9470a292199SAart Bik   return success();
9480a292199SAart Bik }
9490a292199SAart Bik 
95083f3b1cbSYinying Li Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
951f740366fSPeiming Liu   SmallVector<COOSegment> coo = getCOOSegments();
9525248a987SPeiming Liu   assert(coo.size() == 1 || coo.empty());
9535248a987SPeiming Liu   if (!coo.empty() && coo.front().isAoS()) {
954f740366fSPeiming Liu     return coo.front().lvlRange.first;
955f740366fSPeiming Liu   }
95683f3b1cbSYinying Li   return getLvlRank();
9575b729503SAart Bik }
9585b729503SAart Bik 
959f740366fSPeiming Liu SmallVector<COOSegment>
96083f3b1cbSYinying Li mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
961f740366fSPeiming Liu   SmallVector<COOSegment> ret;
96283f3b1cbSYinying Li   if (getLvlRank() <= 1)
963f740366fSPeiming Liu     return ret;
964f740366fSPeiming Liu 
965f740366fSPeiming Liu   ArrayRef<LevelType> lts = getLvlTypes();
966f740366fSPeiming Liu   Level l = 0;
96783f3b1cbSYinying Li   while (l < getLvlRank()) {
968f740366fSPeiming Liu     auto lt = lts[l];
969f740366fSPeiming Liu     if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
970f740366fSPeiming Liu       auto cur = lts.begin() + l;
971f740366fSPeiming Liu       auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
972f740366fSPeiming Liu         return !lt.isa<LevelFormat::Singleton>();
973f740366fSPeiming Liu       });
974f740366fSPeiming Liu       unsigned cooLen = std::distance(cur, end);
975f740366fSPeiming Liu       if (cooLen > 1) {
976f740366fSPeiming Liu         // To support mixed SoA/AoS COO, we should break the segment when the
977f740366fSPeiming Liu         // storage scheme changes, for now we faithfully assume that all
978f740366fSPeiming Liu         // consecutive singleton levels have the same storage format as verified
979f740366fSPeiming Liu         // STEA.
980f740366fSPeiming Liu         ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
981f740366fSPeiming Liu                                  lts[l + 1].isa<LevelPropNonDefault::SoA>()});
982f740366fSPeiming Liu       }
983f740366fSPeiming Liu       l += cooLen;
984f740366fSPeiming Liu     } else {
985f740366fSPeiming Liu       l++;
986f740366fSPeiming Liu     }
987f740366fSPeiming Liu   }
988f740366fSPeiming Liu   return ret;
989f740366fSPeiming Liu }
990f740366fSPeiming Liu 
99183f3b1cbSYinying Li //===----------------------------------------------------------------------===//
99283f3b1cbSYinying Li // SparseTensorType Methods.
99383f3b1cbSYinying Li //===----------------------------------------------------------------------===//
99483f3b1cbSYinying Li 
99583f3b1cbSYinying Li bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
99683f3b1cbSYinying Li                                                       bool isUnique) const {
99783f3b1cbSYinying Li   if (!hasEncoding())
99883f3b1cbSYinying Li     return false;
99983f3b1cbSYinying Li   if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
100083f3b1cbSYinying Li     return false;
100183f3b1cbSYinying Li   for (Level l = startLvl + 1; l < lvlRank; ++l)
100283f3b1cbSYinying Li     if (!isSingletonLvl(l))
100383f3b1cbSYinying Li       return false;
100483f3b1cbSYinying Li   // If isUnique is true, then make sure that the last level is unique,
100583f3b1cbSYinying Li   // that is, when lvlRank == 1, the only compressed level is unique,
100683f3b1cbSYinying Li   // and when lvlRank > 1, the last singleton is unique.
100783f3b1cbSYinying Li   return !isUnique || isUniqueLvl(lvlRank - 1);
100883f3b1cbSYinying Li }
100983f3b1cbSYinying Li 
101045288085SAart Bik RankedTensorType
101145288085SAart Bik mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
101245288085SAart Bik   SmallVector<LevelType> lvlTypes;
101345288085SAart Bik   lvlTypes.reserve(lvlRank);
101445288085SAart Bik   // A non-unique compressed level at beginning (unless this is
101545288085SAart Bik   // also the last level, then it is unique).
101645288085SAart Bik   lvlTypes.push_back(
101745288085SAart Bik       *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
101845288085SAart Bik   if (lvlRank > 1) {
101945288085SAart Bik     // Followed by n-2 non-unique singleton levels.
102045288085SAart Bik     std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
102145288085SAart Bik                 *buildLevelType(LevelFormat::Singleton, ordered, false));
102245288085SAart Bik     // Ends by a unique singleton level.
102345288085SAart Bik     lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
102445288085SAart Bik   }
1025a10d67f9SYinying Li   auto enc = SparseTensorEncodingAttr::get(
1026a10d67f9SYinying Li       getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
1027a10d67f9SYinying Li       getCrdWidth(), getExplicitVal(), getImplicitVal());
102845288085SAart Bik   return RankedTensorType::get(getDimShape(), getElementType(), enc);
102945288085SAart Bik }
103045288085SAart Bik 
103145288085SAart Bik //===----------------------------------------------------------------------===//
103245288085SAart Bik // Convenience Methods.
10334d068619SAart Bik //===----------------------------------------------------------------------===//
10344d068619SAart Bik 
103596a23911SAart Bik SparseTensorEncodingAttr
103696a23911SAart Bik mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
1037c1fa60b4STres Popp   if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1038c1fa60b4STres Popp     return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1039c1fa60b4STres Popp   if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
104071cc0f1cSPeiming Liu     return mdtp.getEncoding();
104196a23911SAart Bik   return nullptr;
104296a23911SAart Bik }
104396a23911SAart Bik 
1044d4088e7dSYinying Li AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
1045d4088e7dSYinying Li                                              MLIRContext *context) {
1046d4088e7dSYinying Li   auto map = static_cast<AffineMap>(dimToLvl);
1047d4088e7dSYinying Li   AffineMap lvlToDim;
1048d4088e7dSYinying Li   // Return an empty lvlToDim when inference is not successful.
1049d4088e7dSYinying Li   if (!map || map.getNumSymbols() != 0) {
1050d4088e7dSYinying Li     lvlToDim = AffineMap();
1051d4088e7dSYinying Li   } else if (map.isPermutation()) {
1052d4088e7dSYinying Li     lvlToDim = inversePermutation(map);
10537b9fb1c2SYinying Li   } else if (isBlockSparsity(map)) {
1054d4088e7dSYinying Li     lvlToDim = inverseBlockSparsity(map, context);
1055d4088e7dSYinying Li   }
1056d4088e7dSYinying Li   return lvlToDim;
1057d4088e7dSYinying Li }
1058d4088e7dSYinying Li 
1059d4088e7dSYinying Li AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
1060d4088e7dSYinying Li                                                     MLIRContext *context) {
1061d4088e7dSYinying Li   SmallVector<AffineExpr> lvlExprs;
1062d4088e7dSYinying Li   auto numLvls = dimToLvl.getNumResults();
1063d4088e7dSYinying Li   lvlExprs.reserve(numLvls);
1064d4088e7dSYinying Li   // lvlExprComponents stores information of the floordiv and mod operations
1065d4088e7dSYinying Li   // applied to the same dimension, so as to build the lvlToDim map.
1066d4088e7dSYinying Li   std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1067d4088e7dSYinying Li   for (unsigned i = 0, n = numLvls; i < n; i++) {
1068d4088e7dSYinying Li     auto result = dimToLvl.getResult(i);
10691609f1c2Slong.chen     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1070d4088e7dSYinying Li       if (result.getKind() == AffineExprKind::FloorDiv) {
1071d4088e7dSYinying Li         // Position of the dimension in dimToLvl.
10721609f1c2Slong.chen         auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1073d4088e7dSYinying Li         assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1074d4088e7dSYinying Li                "expected only one floordiv for each dimension");
1075d4088e7dSYinying Li         SmallVector<AffineExpr, 3> components;
1076d4088e7dSYinying Li         // Level variable for floordiv.
1077d4088e7dSYinying Li         components.push_back(getAffineDimExpr(i, context));
1078d4088e7dSYinying Li         // Multiplier.
1079d4088e7dSYinying Li         components.push_back(binOp.getRHS());
1080d4088e7dSYinying Li         // Map key is the position of the dimension.
1081d4088e7dSYinying Li         lvlExprComponents[pos] = components;
1082d4088e7dSYinying Li       } else if (result.getKind() == AffineExprKind::Mod) {
10831609f1c2Slong.chen         auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1084d4088e7dSYinying Li         assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1085d4088e7dSYinying Li                "expected floordiv before mod");
1086d4088e7dSYinying Li         // Add level variable for mod to the same vector
1087d4088e7dSYinying Li         // of the corresponding floordiv.
1088d4088e7dSYinying Li         lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
1089d4088e7dSYinying Li       } else {
1090d4088e7dSYinying Li         assert(false && "expected floordiv or mod");
1091d4088e7dSYinying Li       }
1092d4088e7dSYinying Li     } else {
1093d4088e7dSYinying Li       lvlExprs.push_back(getAffineDimExpr(i, context));
1094d4088e7dSYinying Li     }
1095d4088e7dSYinying Li   }
1096d4088e7dSYinying Li   // Build lvlExprs from lvlExprComponents.
1097d4088e7dSYinying Li   // For example, for il = i floordiv 2 and ii = i mod 2, the components
1098d4088e7dSYinying Li   // would be [il, 2, ii]. It could be used to build the AffineExpr
1099d4088e7dSYinying Li   // i = il * 2 + ii in lvlToDim.
1100d4088e7dSYinying Li   for (auto &components : lvlExprComponents) {
1101d4088e7dSYinying Li     assert(components.second.size() == 3 &&
1102d4088e7dSYinying Li            "expected 3 components to build lvlExprs");
1103d4088e7dSYinying Li     auto mulOp = getAffineBinaryOpExpr(
1104d4088e7dSYinying Li         AffineExprKind::Mul, components.second[0], components.second[1]);
1105d4088e7dSYinying Li     auto addOp =
1106d4088e7dSYinying Li         getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
1107d4088e7dSYinying Li     lvlExprs.push_back(addOp);
1108d4088e7dSYinying Li   }
1109d4088e7dSYinying Li   return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
1110d4088e7dSYinying Li }
1111d4088e7dSYinying Li 
11127b9fb1c2SYinying Li SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
11137b9fb1c2SYinying Li   assert(isBlockSparsity(dimToLvl) &&
11147b9fb1c2SYinying Li          "expected dimToLvl to be block sparsity for calling getBlockSize");
11157b9fb1c2SYinying Li   SmallVector<unsigned> blockSize;
11167b9fb1c2SYinying Li   for (auto result : dimToLvl.getResults()) {
11171609f1c2Slong.chen     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
11187b9fb1c2SYinying Li       if (result.getKind() == AffineExprKind::Mod) {
11197b9fb1c2SYinying Li         blockSize.push_back(
11201609f1c2Slong.chen             dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
11217b9fb1c2SYinying Li       }
11227b9fb1c2SYinying Li     } else {
11237b9fb1c2SYinying Li       blockSize.push_back(0);
11247b9fb1c2SYinying Li     }
11257b9fb1c2SYinying Li   }
11267b9fb1c2SYinying Li   return blockSize;
11277b9fb1c2SYinying Li }
11287b9fb1c2SYinying Li 
11297b9fb1c2SYinying Li bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
11307b9fb1c2SYinying Li   if (!dimToLvl)
11317b9fb1c2SYinying Li     return false;
11327b9fb1c2SYinying Li   std::map<unsigned, int64_t> coeffientMap;
113331b72b07SYinying Li   bool hasBlock = false;
11347b9fb1c2SYinying Li   for (auto result : dimToLvl.getResults()) {
11351609f1c2Slong.chen     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
11363d3e46ccSAart Bik       // Check for "dim op const".
11373d3e46ccSAart Bik       auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
11383d3e46ccSAart Bik       auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
113931b72b07SYinying Li       if (!dimOp || !conOp || conOp.getValue() <= 0)
11403d3e46ccSAart Bik         return false;
11413d3e46ccSAart Bik       // Inspect "dim / const" or "dim % const".
11423d3e46ccSAart Bik       auto pos = dimOp.getPosition();
11433d3e46ccSAart Bik       if (binOp.getKind() == AffineExprKind::FloorDiv) {
11447b9fb1c2SYinying Li         // Expect only one floordiv for each dimension.
1145*af6e1881SKazu Hirata         auto [it, inserted] = coeffientMap.try_emplace(pos);
1146*af6e1881SKazu Hirata         if (!inserted)
11477b9fb1c2SYinying Li           return false;
11483d3e46ccSAart Bik         // Record coefficient of the floordiv.
1149*af6e1881SKazu Hirata         it->second = conOp.getValue();
11503d3e46ccSAart Bik       } else if (binOp.getKind() == AffineExprKind::Mod) {
11517b9fb1c2SYinying Li         // Expect floordiv before mod.
1152*af6e1881SKazu Hirata         auto it = coeffientMap.find(pos);
1153*af6e1881SKazu Hirata         if (it == coeffientMap.end())
11547b9fb1c2SYinying Li           return false;
11557b9fb1c2SYinying Li         // Expect mod to have the same coefficient as floordiv.
1156*af6e1881SKazu Hirata         if (conOp.getValue() != it->second)
11577b9fb1c2SYinying Li           return false;
115831b72b07SYinying Li         hasBlock = true;
115931b72b07SYinying Li       } else {
116031b72b07SYinying Li         return false;
116131b72b07SYinying Li       }
116231b72b07SYinying Li     } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
116331b72b07SYinying Li       auto pos = dimOp.getPosition();
116431b72b07SYinying Li       // Expect dim to be unset.
11652077fb80SKazu Hirata       if (!coeffientMap.try_emplace(pos, 0).second)
116631b72b07SYinying Li         return false;
11677b9fb1c2SYinying Li     } else {
11687b9fb1c2SYinying Li       return false;
11697b9fb1c2SYinying Li     }
11707b9fb1c2SYinying Li   }
117131b72b07SYinying Li   return hasBlock;
11727b9fb1c2SYinying Li }
11737b9fb1c2SYinying Li 
117406a65ce5SPeiming Liu bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
117506a65ce5SPeiming Liu   auto hasNonIdentityMap = [](Value v) {
117606a65ce5SPeiming Liu     auto stt = tryGetSparseTensorType(v);
117706a65ce5SPeiming Liu     return stt && !stt->isIdentity();
117806a65ce5SPeiming Liu   };
117906a65ce5SPeiming Liu 
118006a65ce5SPeiming Liu   return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
118106a65ce5SPeiming Liu          llvm::any_of(op->getResults(), hasNonIdentityMap);
118206a65ce5SPeiming Liu }
118306a65ce5SPeiming Liu 
11844e2f1521SPeiming Liu Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1185f231821eSAart Bik   if (enc) {
11864e2f1521SPeiming Liu     assert(enc.isPermutation() && "Non permutation map not supported");
11874e2f1521SPeiming Liu     if (const auto dimToLvl = enc.getDimToLvl())
118876647fceSwren romano       return dimToLvl.getDimPosition(l);
1189f231821eSAart Bik   }
1190f708a549Swren romano   return l;
1191f231821eSAart Bik }
1192f231821eSAart Bik 
11934e2f1521SPeiming Liu Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1194f231821eSAart Bik   if (enc) {
11954e2f1521SPeiming Liu     assert(enc.isPermutation() && "Non permutation map not supported");
11964e2f1521SPeiming Liu     if (const auto lvlToDim = enc.getLvlToDim())
11974e2f1521SPeiming Liu       return lvlToDim.getDimPosition(d);
1198f231821eSAart Bik   }
1199f231821eSAart Bik   return d;
1200f231821eSAart Bik }
1201f231821eSAart Bik 
1202083ddffeSPeiming Liu /// We normalized sparse tensor encoding attribute by always using
12031dd387e1SAart Bik /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
1204083ddffeSPeiming Liu /// as other variants) lead to the same storage specifier type, and stripping
12056db397a8SPeiming Liu /// irrelevant fields that do not alter the sparse tensor memory layout.
1206083ddffeSPeiming Liu static SparseTensorEncodingAttr
1207083ddffeSPeiming Liu getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
12081944c4f7SAart Bik   SmallVector<LevelType> lts;
12091dd387e1SAart Bik   for (auto lt : enc.getLvlTypes())
12105248a987SPeiming Liu     lts.push_back(lt.stripStorageIrrelevantProperties());
1211083ddffeSPeiming Liu 
1212083ddffeSPeiming Liu   return SparseTensorEncodingAttr::get(
12131dd387e1SAart Bik       enc.getContext(), lts,
121476647fceSwren romano       AffineMap(), // dimToLvl (irrelevant to storage specifier)
1215836411b9SAart Bik       AffineMap(), // lvlToDim (irrelevant to storage specifier)
121684cd51bbSwren romano       // Always use `index` for memSize and lvlSize instead of reusing
12176db397a8SPeiming Liu       // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
12186db397a8SPeiming Liu       // value for different bitwidth, it also avoids casting between index and
12196db397a8SPeiming Liu       // integer (returned by DimOp)
1220a10d67f9SYinying Li       0, 0,
1221a10d67f9SYinying Li       Attribute(), // explicitVal (irrelevant to storage specifier)
1222a10d67f9SYinying Li       Attribute(), // implicitVal (irrelevant to storage specifier)
1223a10d67f9SYinying Li       enc.getDimSlices());
1224083ddffeSPeiming Liu }
1225083ddffeSPeiming Liu 
1226083ddffeSPeiming Liu StorageSpecifierType
1227083ddffeSPeiming Liu StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1228083ddffeSPeiming Liu   return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
1229083ddffeSPeiming Liu }
1230083ddffeSPeiming Liu 
12317359a6b7SMatthias Springer StorageSpecifierType
12327359a6b7SMatthias Springer StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
12337359a6b7SMatthias Springer                                  MLIRContext *ctx,
12347359a6b7SMatthias Springer                                  SparseTensorEncodingAttr encoding) {
12357359a6b7SMatthias Springer   return Base::getChecked(emitError, ctx,
12367359a6b7SMatthias Springer                           getNormalizedEncodingForSpecifier(encoding));
12377359a6b7SMatthias Springer }
12387359a6b7SMatthias Springer 
123971cc0f1cSPeiming Liu //===----------------------------------------------------------------------===//
124071cc0f1cSPeiming Liu // SparseTensorDialect Operations.
124196a23911SAart Bik //===----------------------------------------------------------------------===//
124296a23911SAart Bik 
124384cd51bbSwren romano static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
124484cd51bbSwren romano   return success(lvl < getSparseTensorType(tensor).getLvlRank());
124596a23911SAart Bik }
124696a23911SAart Bik 
124784cd51bbSwren romano static LogicalResult isMatchingWidth(Value mem, unsigned width) {
124884cd51bbSwren romano   const Type etp = getMemRefType(mem).getElementType();
1249743fbcb7Swren romano   return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
125096a23911SAart Bik }
125196a23911SAart Bik 
125222426110SRamkumar Ramachandra static LogicalResult verifySparsifierGetterSetter(
125384cd51bbSwren romano     StorageSpecifierKind mdKind, std::optional<Level> lvl,
125422426110SRamkumar Ramachandra     TypedValue<StorageSpecifierType> md, Operation *op) {
1255f708a549Swren romano   if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
125671cc0f1cSPeiming Liu     return op->emitError(
1257f708a549Swren romano         "redundant level argument for querying value memory size");
125871cc0f1cSPeiming Liu   }
125971cc0f1cSPeiming Liu 
1260f708a549Swren romano   const auto enc = md.getType().getEncoding();
1261f708a549Swren romano   const Level lvlRank = enc.getLvlRank();
126271cc0f1cSPeiming Liu 
12636db397a8SPeiming Liu   if (mdKind == StorageSpecifierKind::DimOffset ||
12646db397a8SPeiming Liu       mdKind == StorageSpecifierKind::DimStride)
12656db397a8SPeiming Liu     if (!enc.isSlice())
12666db397a8SPeiming Liu       return op->emitError("requested slice data on non-slice tensor");
12678237cac6SPeiming Liu 
126871cc0f1cSPeiming Liu   if (mdKind != StorageSpecifierKind::ValMemSize) {
1269f708a549Swren romano     if (!lvl)
1270f708a549Swren romano       return op->emitError("missing level argument");
127171cc0f1cSPeiming Liu 
127284cd51bbSwren romano     const Level l = lvl.value();
1273f708a549Swren romano     if (l >= lvlRank)
127484cd51bbSwren romano       return op->emitError("requested level is out of bounds");
127571cc0f1cSPeiming Liu 
127684cd51bbSwren romano     if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
127771cc0f1cSPeiming Liu       return op->emitError(
127884cd51bbSwren romano           "requested position memory size on a singleton level");
127971cc0f1cSPeiming Liu   }
128071cc0f1cSPeiming Liu   return success();
128171cc0f1cSPeiming Liu }
128271cc0f1cSPeiming Liu 
1283de560888SPeiming Liu static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
1284de560888SPeiming Liu   switch (kind) {
1285de560888SPeiming Liu   case SparseTensorFieldKind::CrdMemRef:
1286de560888SPeiming Liu     return stt.getCrdType();
1287de560888SPeiming Liu   case SparseTensorFieldKind::PosMemRef:
1288de560888SPeiming Liu     return stt.getPosType();
1289de560888SPeiming Liu   case SparseTensorFieldKind::ValMemRef:
1290de560888SPeiming Liu     return stt.getElementType();
1291de560888SPeiming Liu   case SparseTensorFieldKind::StorageSpec:
1292de560888SPeiming Liu     return nullptr;
1293de560888SPeiming Liu   }
1294de560888SPeiming Liu   llvm_unreachable("Unrecognizable FieldKind");
12957864d736SPeiming Liu }
12967864d736SPeiming Liu 
1297de560888SPeiming Liu static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1298de560888SPeiming Liu                                       SparseTensorType stt,
1299de560888SPeiming Liu                                       RankedTensorType valTp,
1300de560888SPeiming Liu                                       TypeRange lvlTps) {
1301de560888SPeiming Liu   if (requiresStaticShape && !stt.hasStaticDimShape())
1302de560888SPeiming Liu     return op->emitError("the sparse-tensor must have static shape");
1303de560888SPeiming Liu   if (!stt.hasEncoding())
1304de560888SPeiming Liu     return op->emitError("the sparse-tensor must have an encoding attribute");
1305de560888SPeiming Liu 
1306de560888SPeiming Liu   // Verifies the trailing COO.
13075248a987SPeiming Liu   Level cooStartLvl = stt.getAoSCOOStart();
1308de560888SPeiming Liu   if (cooStartLvl < stt.getLvlRank()) {
1309de560888SPeiming Liu     // We only supports trailing COO for now, must be the last input.
131068f58812STres Popp     auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1311de560888SPeiming Liu     // The coordinates should be in shape of <? x rank>
1312de560888SPeiming Liu     unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1313de560888SPeiming Liu     if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
131477f8297cSMatthias Springer       return op->emitError("input/output trailing COO level-ranks don't match");
1315de560888SPeiming Liu     }
1316de560888SPeiming Liu   }
1317de560888SPeiming Liu 
1318de560888SPeiming Liu   // Verifies that all types match.
1319de560888SPeiming Liu   StorageLayout layout(stt.getEncoding());
1320de560888SPeiming Liu   if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
1321de560888SPeiming Liu     return op->emitError("inconsistent number of fields between input/output");
1322de560888SPeiming Liu 
1323de560888SPeiming Liu   unsigned idx = 0;
1324de560888SPeiming Liu   bool misMatch = false;
1325de560888SPeiming Liu   layout.foreachField([&idx, &misMatch, stt, valTp,
1326de560888SPeiming Liu                        lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
13271944c4f7SAart Bik                                Level lvl, LevelType lt) -> bool {
1328de560888SPeiming Liu     if (fKind == SparseTensorFieldKind::StorageSpec)
1329de560888SPeiming Liu       return true;
1330de560888SPeiming Liu 
1331de560888SPeiming Liu     Type inputTp = nullptr;
1332de560888SPeiming Liu     if (fKind == SparseTensorFieldKind::ValMemRef) {
1333de560888SPeiming Liu       inputTp = valTp;
1334de560888SPeiming Liu     } else {
13351dd387e1SAart Bik       assert(fid == idx && stt.getLvlType(lvl) == lt);
1336de560888SPeiming Liu       inputTp = lvlTps[idx++];
1337de560888SPeiming Liu     }
1338de560888SPeiming Liu     // The input element type and expected element type should match.
133968f58812STres Popp     Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1340de560888SPeiming Liu     Type expElemTp = getFieldElemType(stt, fKind);
1341de560888SPeiming Liu     if (inpElemTp != expElemTp) {
1342de560888SPeiming Liu       misMatch = true;
1343de560888SPeiming Liu       return false; // to terminate the iteration
1344de560888SPeiming Liu     }
1345de560888SPeiming Liu     return true;
1346de560888SPeiming Liu   });
1347de560888SPeiming Liu 
1348de560888SPeiming Liu   if (misMatch)
1349de560888SPeiming Liu     return op->emitError("input/output element-types don't match");
1350de560888SPeiming Liu   return success();
1351de560888SPeiming Liu }
1352de560888SPeiming Liu 
13536ca47eb4SPeiming Liu LogicalResult AssembleOp::verify() {
135477f8297cSMatthias Springer   RankedTensorType valuesTp = getValues().getType();
1355de560888SPeiming Liu   const auto lvlsTp = getLevels().getTypes();
1356de560888SPeiming Liu   const auto resTp = getSparseTensorType(getResult());
1357de560888SPeiming Liu   return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
13586dbca86dSPeiming Liu }
13596dbca86dSPeiming Liu 
13606ca47eb4SPeiming Liu LogicalResult DisassembleOp::verify() {
1361b2e6b735SPeiming Liu   if (getOutValues().getType() != getRetValues().getType())
1362b2e6b735SPeiming Liu     return emitError("output values and return value type mismatch");
1363d4db5289SPeiming Liu 
1364b2e6b735SPeiming Liu   for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1365b2e6b735SPeiming Liu     if (ot.getType() != rt.getType())
1366b2e6b735SPeiming Liu       return emitError("output levels and return levels type mismatch");
1367b2e6b735SPeiming Liu 
136877f8297cSMatthias Springer   RankedTensorType valuesTp = getRetValues().getType();
1369b2e6b735SPeiming Liu   const auto lvlsTp = getRetLevels().getTypes();
1370b2e6b735SPeiming Liu   const auto srcTp = getSparseTensorType(getTensor());
1371b2e6b735SPeiming Liu   return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
13726dbca86dSPeiming Liu }
13736dbca86dSPeiming Liu 
1374b98dc035SRiver Riddle LogicalResult ConvertOp::verify() {
137577f8297cSMatthias Springer   RankedTensorType tp1 = getSource().getType();
137677f8297cSMatthias Springer   RankedTensorType tp2 = getDest().getType();
13771e6ef0cfSAart Bik   if (tp1.getRank() != tp2.getRank())
1378b98dc035SRiver Riddle     return emitError("unexpected conversion mismatch in rank");
137933267f40SPeiming Liu   auto dstEnc =
1380c1fa60b4STres Popp       llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
138133267f40SPeiming Liu   if (dstEnc && dstEnc.isSlice())
138233267f40SPeiming Liu     return emitError("cannot convert to a sparse tensor slice");
138333267f40SPeiming Liu 
1384697ea09dSAart Bik   auto shape1 = tp1.getShape();
1385697ea09dSAart Bik   auto shape2 = tp2.getShape();
13869d1db3d4SAart Bik   // Accept size matches between the source and the destination type
13879d1db3d4SAart Bik   // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
13889d1db3d4SAart Bik   // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1389f708a549Swren romano   for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1390399638f9SAliia Khasanova     if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1391b98dc035SRiver Riddle       return emitError("unexpected conversion mismatch in dimension ") << d;
1392697ea09dSAart Bik   return success();
1393697ea09dSAart Bik }
1394697ea09dSAart Bik 
13957df76121SMarkus Böck OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1396f248d0b2SPeiming Liu   if (getType() == getSource().getType())
13970128f801Sbixia1     return getSource();
13980128f801Sbixia1   return {};
13990128f801Sbixia1 }
14000128f801Sbixia1 
1401761c9dd9SPeiming Liu bool ConvertOp::needsExtraSort() {
1402dda3dc5eSPeiming Liu   SparseTensorType srcStt = getSparseTensorType(getSource());
1403dda3dc5eSPeiming Liu   SparseTensorType dstStt = getSparseTensorType(getDest());
1404dda3dc5eSPeiming Liu 
1405761c9dd9SPeiming Liu   // We do not need an extra sort when returning unordered sparse tensors or
1406761c9dd9SPeiming Liu   // dense tensor since dense tensor support random access.
1407dda3dc5eSPeiming Liu   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1408761c9dd9SPeiming Liu     return false;
1409dda3dc5eSPeiming Liu 
1410dda3dc5eSPeiming Liu   if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1411dda3dc5eSPeiming Liu       srcStt.hasSameDimToLvl(dstStt)) {
1412761c9dd9SPeiming Liu     return false;
1413dda3dc5eSPeiming Liu   }
1414dda3dc5eSPeiming Liu 
1415dda3dc5eSPeiming Liu   // Source and dest tensors are ordered in different ways. We only do direct
1416dda3dc5eSPeiming Liu   // dense to sparse conversion when the dense input is defined by a sparse
1417dda3dc5eSPeiming Liu   // constant. Note that we can theoretically always directly convert from dense
1418dda3dc5eSPeiming Liu   // inputs by rotating dense loops but it leads to bad cache locality and hurt
1419dda3dc5eSPeiming Liu   // performance.
1420dda3dc5eSPeiming Liu   if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1421dda3dc5eSPeiming Liu     if (isa<SparseElementsAttr>(constOp.getValue()))
1422dda3dc5eSPeiming Liu       return false;
1423761c9dd9SPeiming Liu 
1424761c9dd9SPeiming Liu   return true;
1425dda3dc5eSPeiming Liu }
1426dda3dc5eSPeiming Liu 
1427ff21a90eSPeiming Liu LogicalResult CrdTranslateOp::verify() {
1428ff21a90eSPeiming Liu   uint64_t inRank = getEncoder().getLvlRank();
1429ff21a90eSPeiming Liu   uint64_t outRank = getEncoder().getDimRank();
1430ff21a90eSPeiming Liu 
1431ff21a90eSPeiming Liu   if (getDirection() == CrdTransDirectionKind::dim2lvl)
1432ff21a90eSPeiming Liu     std::swap(inRank, outRank);
1433ff21a90eSPeiming Liu 
1434ff21a90eSPeiming Liu   if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1435ff21a90eSPeiming Liu     return emitError("Coordinate rank mismatch with encoding");
1436ff21a90eSPeiming Liu 
1437ff21a90eSPeiming Liu   return success();
1438ff21a90eSPeiming Liu }
1439ff21a90eSPeiming Liu 
1440ff21a90eSPeiming Liu LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1441ff21a90eSPeiming Liu                                    SmallVectorImpl<OpFoldResult> &results) {
14426456e0bbSPeiming Liu   if (getEncoder().isIdentity()) {
14436456e0bbSPeiming Liu     results.assign(getInCrds().begin(), getInCrds().end());
14446456e0bbSPeiming Liu     return success();
14456456e0bbSPeiming Liu   }
1446ff21a90eSPeiming Liu   if (getEncoder().isPermutation()) {
1447ff21a90eSPeiming Liu     AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1448ff21a90eSPeiming Liu                          ? getEncoder().getDimToLvl()
1449ff21a90eSPeiming Liu                          : getEncoder().getLvlToDim();
1450ff21a90eSPeiming Liu     for (AffineExpr exp : perm.getResults())
14511609f1c2Slong.chen       results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1452ff21a90eSPeiming Liu     return success();
1453ff21a90eSPeiming Liu   }
1454ff21a90eSPeiming Liu 
1455ff21a90eSPeiming Liu   // Fuse dim2lvl/lvl2dim pairs.
1456ff21a90eSPeiming Liu   auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1457ff21a90eSPeiming Liu   bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1458ff21a90eSPeiming Liu                    return v.getDefiningOp() == def;
1459ff21a90eSPeiming Liu                  });
1460ff21a90eSPeiming Liu   if (!sameDef)
1461ff21a90eSPeiming Liu     return failure();
1462ff21a90eSPeiming Liu 
1463ff21a90eSPeiming Liu   bool oppositeDir = def.getDirection() != getDirection();
1464ff21a90eSPeiming Liu   bool sameOracle =
1465ff21a90eSPeiming Liu       def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1466ff21a90eSPeiming Liu   bool sameCount = def.getNumResults() == getInCrds().size();
1467ff21a90eSPeiming Liu   if (!oppositeDir || !sameOracle || !sameCount)
1468ff21a90eSPeiming Liu     return failure();
1469ff21a90eSPeiming Liu 
1470ff21a90eSPeiming Liu   // The definition produces the coordinates in the same order as the input
1471ff21a90eSPeiming Liu   // coordinates.
1472ff21a90eSPeiming Liu   bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1473ff21a90eSPeiming Liu                                 [](auto valuePair) {
1474ff21a90eSPeiming Liu                                   auto [lhs, rhs] = valuePair;
1475ff21a90eSPeiming Liu                                   return lhs == rhs;
1476ff21a90eSPeiming Liu                                 });
1477ff21a90eSPeiming Liu 
1478ff21a90eSPeiming Liu   if (!sameOrder)
1479ff21a90eSPeiming Liu     return failure();
1480ff21a90eSPeiming Liu   // l1 = dim2lvl (lvl2dim l0)
1481ff21a90eSPeiming Liu   // ==> l0
1482ff21a90eSPeiming Liu   results.append(def.getInCrds().begin(), def.getInCrds().end());
1483ff21a90eSPeiming Liu   return success();
1484ff21a90eSPeiming Liu }
1485ff21a90eSPeiming Liu 
1486c780352dSPeiming Liu void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1487c780352dSPeiming Liu                   int64_t index) {
1488c780352dSPeiming Liu   Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
1489c780352dSPeiming Liu   return build(builder, state, source, val);
1490c780352dSPeiming Liu }
1491c780352dSPeiming Liu 
1492f0f5fdf7SPeiming Liu LogicalResult LvlOp::verify() {
1493f0f5fdf7SPeiming Liu   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1494f0f5fdf7SPeiming Liu     auto stt = getSparseTensorType(getSource());
1495f0f5fdf7SPeiming Liu     if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
149677f8297cSMatthias Springer       return emitError(
149777f8297cSMatthias Springer           "Level index exceeds the rank of the input sparse tensor");
1498f0f5fdf7SPeiming Liu   }
1499f0f5fdf7SPeiming Liu   return success();
1500f0f5fdf7SPeiming Liu }
1501f0f5fdf7SPeiming Liu 
1502f0f5fdf7SPeiming Liu std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1503f0f5fdf7SPeiming Liu   return getConstantIntValue(getIndex());
1504f0f5fdf7SPeiming Liu }
1505f0f5fdf7SPeiming Liu 
1506f0f5fdf7SPeiming Liu Speculation::Speculatability LvlOp::getSpeculatability() {
1507f0f5fdf7SPeiming Liu   auto constantIndex = getConstantLvlIndex();
1508f0f5fdf7SPeiming Liu   if (!constantIndex)
1509f0f5fdf7SPeiming Liu     return Speculation::NotSpeculatable;
1510f0f5fdf7SPeiming Liu 
1511f0f5fdf7SPeiming Liu   assert(constantIndex <
1512f0f5fdf7SPeiming Liu          cast<RankedTensorType>(getSource().getType()).getRank());
1513f0f5fdf7SPeiming Liu   return Speculation::Speculatable;
1514f0f5fdf7SPeiming Liu }
1515f0f5fdf7SPeiming Liu 
1516f0f5fdf7SPeiming Liu OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1517f0f5fdf7SPeiming Liu   auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1518f0f5fdf7SPeiming Liu   if (!lvlIndex)
1519f0f5fdf7SPeiming Liu     return {};
1520f0f5fdf7SPeiming Liu 
1521f0f5fdf7SPeiming Liu   Level lvl = lvlIndex.getAPSInt().getZExtValue();
1522f0f5fdf7SPeiming Liu   auto stt = getSparseTensorType(getSource());
1523f0f5fdf7SPeiming Liu   if (lvl >= stt.getLvlRank()) {
1524f0f5fdf7SPeiming Liu     // Follows the same convention used by tensor.dim operation. Out of bound
1525f0f5fdf7SPeiming Liu     // indices produce undefined behavior but are still valid IR. Don't choke on
1526f0f5fdf7SPeiming Liu     // them.
1527f0f5fdf7SPeiming Liu     return {};
1528f0f5fdf7SPeiming Liu   }
1529f0f5fdf7SPeiming Liu 
1530f0f5fdf7SPeiming Liu   // Helper lambda to build an IndexAttr.
1531f0f5fdf7SPeiming Liu   auto getIndexAttr = [this](int64_t lvlSz) {
1532f0f5fdf7SPeiming Liu     return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1533f0f5fdf7SPeiming Liu   };
1534f0f5fdf7SPeiming Liu 
1535deedf554SPeiming Liu   SmallVector<Size> lvlShape = stt.getLvlShape();
1536deedf554SPeiming Liu   if (!ShapedType::isDynamic(lvlShape[lvl]))
1537deedf554SPeiming Liu     return getIndexAttr(lvlShape[lvl]);
1538f0f5fdf7SPeiming Liu 
1539f0f5fdf7SPeiming Liu   return {};
1540f0f5fdf7SPeiming Liu }
1541f0f5fdf7SPeiming Liu 
1542d808d922SPeiming Liu void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1543d808d922SPeiming Liu                              SparseTensorEncodingAttr dstEnc, Value source) {
1544d808d922SPeiming Liu   auto srcStt = getSparseTensorType(source);
1545d808d922SPeiming Liu   SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1546d808d922SPeiming Liu   SmallVector<int64_t> dstDimShape =
1547e10dc60aSAart Bik       dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1548d808d922SPeiming Liu   auto dstTp =
1549d808d922SPeiming Liu       RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1550d808d922SPeiming Liu   return build(odsBuilder, odsState, dstTp, source);
1551d808d922SPeiming Liu }
1552d808d922SPeiming Liu 
1553d808d922SPeiming Liu LogicalResult ReinterpretMapOp::verify() {
1554d808d922SPeiming Liu   auto srcStt = getSparseTensorType(getSource());
1555d808d922SPeiming Liu   auto dstStt = getSparseTensorType(getDest());
15561944c4f7SAart Bik   ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
15571944c4f7SAart Bik   ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1558d808d922SPeiming Liu 
1559d808d922SPeiming Liu   if (srcLvlTps.size() != dstLvlTps.size())
1560d808d922SPeiming Liu     return emitError("Level rank mismatch between source/dest tensors");
1561d808d922SPeiming Liu 
1562d808d922SPeiming Liu   for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1563d808d922SPeiming Liu     if (srcLvlTp != dstLvlTp)
1564d808d922SPeiming Liu       return emitError("Level type mismatch between source/dest tensors");
1565d808d922SPeiming Liu 
1566d808d922SPeiming Liu   if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1567d808d922SPeiming Liu       srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1568d808d922SPeiming Liu     return emitError("Crd/Pos width mismatch between source/dest tensors");
1569d808d922SPeiming Liu   }
1570d808d922SPeiming Liu 
1571d808d922SPeiming Liu   if (srcStt.getElementType() != dstStt.getElementType())
1572d808d922SPeiming Liu     return emitError("Element type mismatch between source/dest tensors");
1573d808d922SPeiming Liu 
157422212ca7SAart Bik   SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
157522212ca7SAart Bik   SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1576d808d922SPeiming Liu   for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1577d808d922SPeiming Liu     if (srcLvlSz != dstLvlSz) {
1578d808d922SPeiming Liu       // Should we allow one side to be dynamic size, e.g., <?x?> should be
1579d808d922SPeiming Liu       // compatible to <3x4>? For now, we require all the level sizes to be
1580d808d922SPeiming Liu       // *exactly* matched for simplicity.
1581d808d922SPeiming Liu       return emitError("Level size mismatch between source/dest tensors");
1582d808d922SPeiming Liu     }
1583d808d922SPeiming Liu   }
1584d808d922SPeiming Liu 
1585d808d922SPeiming Liu   return success();
1586d808d922SPeiming Liu }
1587d808d922SPeiming Liu 
1588d808d922SPeiming Liu OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1589d808d922SPeiming Liu   if (getSource().getType() == getDest().getType())
1590d808d922SPeiming Liu     return getSource();
1591d808d922SPeiming Liu 
1592d808d922SPeiming Liu   if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1593d808d922SPeiming Liu     // A -> B, B -> A ==> A
1594d808d922SPeiming Liu     if (def.getSource().getType() == getDest().getType())
1595d808d922SPeiming Liu       return def.getSource();
1596d808d922SPeiming Liu   }
1597d808d922SPeiming Liu   return {};
1598d808d922SPeiming Liu }
1599d808d922SPeiming Liu 
16006bc7c9dfSPeiming Liu template <typename ToBufferOp>
16016bc7c9dfSPeiming Liu static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
16026bc7c9dfSPeiming Liu                                            OpaqueProperties prop,
16036bc7c9dfSPeiming Liu                                            RegionRange region,
16046bc7c9dfSPeiming Liu                                            SmallVectorImpl<mlir::Type> &ret) {
16056bc7c9dfSPeiming Liu   typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
16066bc7c9dfSPeiming Liu   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
16076bc7c9dfSPeiming Liu   Type elemTp = nullptr;
16086bc7c9dfSPeiming Liu   bool withStride = false;
16096bc7c9dfSPeiming Liu   if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
16106bc7c9dfSPeiming Liu     elemTp = stt.getPosType();
16116bc7c9dfSPeiming Liu   } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
16126bc7c9dfSPeiming Liu                        std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
16136bc7c9dfSPeiming Liu     elemTp = stt.getCrdType();
16146bc7c9dfSPeiming Liu     if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
16156bc7c9dfSPeiming Liu       withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
16166bc7c9dfSPeiming Liu   } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
16176bc7c9dfSPeiming Liu     elemTp = stt.getElementType();
16186bc7c9dfSPeiming Liu   }
16196bc7c9dfSPeiming Liu 
16206bc7c9dfSPeiming Liu   assert(elemTp && "unhandled operation.");
16216bc7c9dfSPeiming Liu   SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
16226bc7c9dfSPeiming Liu   bufShape.push_back(ShapedType::kDynamic);
16236bc7c9dfSPeiming Liu 
16246bc7c9dfSPeiming Liu   auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
16256bc7c9dfSPeiming Liu                                  stt.getContext(), ShapedType::kDynamic,
16266bc7c9dfSPeiming Liu                                  {ShapedType::kDynamic})
16276bc7c9dfSPeiming Liu                            : StridedLayoutAttr();
16286bc7c9dfSPeiming Liu   ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
16296bc7c9dfSPeiming Liu   return success();
16306bc7c9dfSPeiming Liu }
16316bc7c9dfSPeiming Liu 
163284cd51bbSwren romano LogicalResult ToPositionsOp::verify() {
16335b729503SAart Bik   auto stt = getSparseTensorType(getTensor());
163484cd51bbSwren romano   if (failed(lvlIsInBounds(getLevel(), getTensor())))
163584cd51bbSwren romano     return emitError("requested level is out of bounds");
16365b729503SAart Bik   if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
163784cd51bbSwren romano     return emitError("unexpected type for positions");
163896a23911SAart Bik   return success();
163996a23911SAart Bik }
164096a23911SAart Bik 
16416bc7c9dfSPeiming Liu LogicalResult
16426bc7c9dfSPeiming Liu ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
16436bc7c9dfSPeiming Liu                                 ValueRange ops, DictionaryAttr attr,
16446bc7c9dfSPeiming Liu                                 OpaqueProperties prop, RegionRange region,
16456bc7c9dfSPeiming Liu                                 SmallVectorImpl<mlir::Type> &ret) {
16466bc7c9dfSPeiming Liu   return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
16476bc7c9dfSPeiming Liu }
16486bc7c9dfSPeiming Liu 
164984cd51bbSwren romano LogicalResult ToCoordinatesOp::verify() {
16505b729503SAart Bik   auto stt = getSparseTensorType(getTensor());
165184cd51bbSwren romano   if (failed(lvlIsInBounds(getLevel(), getTensor())))
165284cd51bbSwren romano     return emitError("requested level is out of bounds");
16535b729503SAart Bik   if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
165484cd51bbSwren romano     return emitError("unexpected type for coordinates");
165596a23911SAart Bik   return success();
165696a23911SAart Bik }
165796a23911SAart Bik 
16586bc7c9dfSPeiming Liu LogicalResult
16596bc7c9dfSPeiming Liu ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
16606bc7c9dfSPeiming Liu                                   ValueRange ops, DictionaryAttr attr,
16616bc7c9dfSPeiming Liu                                   OpaqueProperties prop, RegionRange region,
16626bc7c9dfSPeiming Liu                                   SmallVectorImpl<mlir::Type> &ret) {
16636bc7c9dfSPeiming Liu   return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
16646bc7c9dfSPeiming Liu }
16656bc7c9dfSPeiming Liu 
166684cd51bbSwren romano LogicalResult ToCoordinatesBufferOp::verify() {
16675b729503SAart Bik   auto stt = getSparseTensorType(getTensor());
16685248a987SPeiming Liu   if (stt.getAoSCOOStart() >= stt.getLvlRank())
16699bde3d0cSbixia1     return emitError("expected sparse tensor with a COO region");
16709bde3d0cSbixia1   return success();
16719bde3d0cSbixia1 }
16729bde3d0cSbixia1 
16736bc7c9dfSPeiming Liu LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
16746bc7c9dfSPeiming Liu     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
16756bc7c9dfSPeiming Liu     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
16766bc7c9dfSPeiming Liu     SmallVectorImpl<mlir::Type> &ret) {
16776bc7c9dfSPeiming Liu   return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
16786bc7c9dfSPeiming Liu                                                       ret);
16796bc7c9dfSPeiming Liu }
16806bc7c9dfSPeiming Liu 
1681b98dc035SRiver Riddle LogicalResult ToValuesOp::verify() {
16825b729503SAart Bik   auto stt = getSparseTensorType(getTensor());
16839916ab03Swren romano   auto mtp = getMemRefType(getResult());
16845b729503SAart Bik   if (stt.getElementType() != mtp.getElementType())
1685b98dc035SRiver Riddle     return emitError("unexpected mismatch in element types");
168696a23911SAart Bik   return success();
168796a23911SAart Bik }
168896a23911SAart Bik 
16896bc7c9dfSPeiming Liu LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
16906bc7c9dfSPeiming Liu                                            std::optional<Location> loc,
16916bc7c9dfSPeiming Liu                                            ValueRange ops, DictionaryAttr attr,
16926bc7c9dfSPeiming Liu                                            OpaqueProperties prop,
16936bc7c9dfSPeiming Liu                                            RegionRange region,
16946bc7c9dfSPeiming Liu                                            SmallVectorImpl<mlir::Type> &ret) {
16956bc7c9dfSPeiming Liu   return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
16966bc7c9dfSPeiming Liu }
16976bc7c9dfSPeiming Liu 
1698c738b430SPeiming Liu LogicalResult ToSliceOffsetOp::verify() {
169977f8297cSMatthias Springer   auto rank = getSlice().getType().getRank();
1700c738b430SPeiming Liu   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1701c738b430SPeiming Liu     return emitError("requested dimension out of bound");
1702c738b430SPeiming Liu   return success();
1703c738b430SPeiming Liu }
1704c738b430SPeiming Liu 
1705c738b430SPeiming Liu LogicalResult ToSliceStrideOp::verify() {
170677f8297cSMatthias Springer   auto rank = getSlice().getType().getRank();
1707c738b430SPeiming Liu   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1708c738b430SPeiming Liu     return emitError("requested dimension out of bound");
1709c738b430SPeiming Liu   return success();
1710c738b430SPeiming Liu }
1711c738b430SPeiming Liu 
171271cc0f1cSPeiming Liu LogicalResult GetStorageSpecifierOp::verify() {
17132e2011daSAart Bik   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
17142e2011daSAart Bik                                       getSpecifier(), getOperation());
171571cc0f1cSPeiming Liu }
171671cc0f1cSPeiming Liu 
1717509974afSPeiming Liu template <typename SpecifierOp>
1718509974afSPeiming Liu static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1719509974afSPeiming Liu   return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1720509974afSPeiming Liu }
1721509974afSPeiming Liu 
17227df76121SMarkus Böck OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
172384cd51bbSwren romano   const StorageSpecifierKind kind = getSpecifierKind();
172484cd51bbSwren romano   const auto lvl = getLevel();
1725509974afSPeiming Liu   for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
172684cd51bbSwren romano     if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1727509974afSPeiming Liu       return op.getValue();
1728509974afSPeiming Liu   return {};
1729509974afSPeiming Liu }
1730509974afSPeiming Liu 
173171cc0f1cSPeiming Liu LogicalResult SetStorageSpecifierOp::verify() {
17322e2011daSAart Bik   return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
17332e2011daSAart Bik                                       getSpecifier(), getOperation());
173471cc0f1cSPeiming Liu }
173571cc0f1cSPeiming Liu 
1736414ed019SJim Kitchen template <class T>
1737414ed019SJim Kitchen static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1738414ed019SJim Kitchen                                         const char *regionName,
1739414ed019SJim Kitchen                                         TypeRange inputTypes, Type outputType) {
1740414ed019SJim Kitchen   unsigned numArgs = region.getNumArguments();
1741414ed019SJim Kitchen   unsigned expectedNum = inputTypes.size();
1742414ed019SJim Kitchen   if (numArgs != expectedNum)
1743414ed019SJim Kitchen     return op->emitError() << regionName << " region must have exactly "
1744414ed019SJim Kitchen                            << expectedNum << " arguments";
1745414ed019SJim Kitchen 
1746414ed019SJim Kitchen   for (unsigned i = 0; i < numArgs; i++) {
1747414ed019SJim Kitchen     Type typ = region.getArgument(i).getType();
1748414ed019SJim Kitchen     if (typ != inputTypes[i])
1749414ed019SJim Kitchen       return op->emitError() << regionName << " region argument " << (i + 1)
1750414ed019SJim Kitchen                              << " type mismatch";
1751414ed019SJim Kitchen   }
1752414ed019SJim Kitchen   Operation *term = region.front().getTerminator();
1753414ed019SJim Kitchen   YieldOp yield = dyn_cast<YieldOp>(term);
1754414ed019SJim Kitchen   if (!yield)
1755414ed019SJim Kitchen     return op->emitError() << regionName
1756414ed019SJim Kitchen                            << " region must end with sparse_tensor.yield";
1757a54930e6SPeiming Liu   if (!yield.hasSingleResult() ||
1758a54930e6SPeiming Liu       yield.getSingleResult().getType() != outputType)
1759414ed019SJim Kitchen     return op->emitError() << regionName << " region yield type mismatch";
1760414ed019SJim Kitchen 
1761414ed019SJim Kitchen   return success();
1762414ed019SJim Kitchen }
1763414ed019SJim Kitchen 
1764414ed019SJim Kitchen LogicalResult BinaryOp::verify() {
1765414ed019SJim Kitchen   NamedAttrList attrs = (*this)->getAttrs();
176604235d07SJacques Pienaar   Type leftType = getX().getType();
176704235d07SJacques Pienaar   Type rightType = getY().getType();
176804235d07SJacques Pienaar   Type outputType = getOutput().getType();
176904235d07SJacques Pienaar   Region &overlap = getOverlapRegion();
177004235d07SJacques Pienaar   Region &left = getLeftRegion();
177104235d07SJacques Pienaar   Region &right = getRightRegion();
1772414ed019SJim Kitchen 
1773414ed019SJim Kitchen   // Check correct number of block arguments and return type for each
1774414ed019SJim Kitchen   // non-empty region.
1775414ed019SJim Kitchen   if (!overlap.empty()) {
17762e2011daSAart Bik     if (failed(verifyNumBlockArgs(this, overlap, "overlap",
17772e2011daSAart Bik                                   TypeRange{leftType, rightType}, outputType)))
17782e2011daSAart Bik       return failure();
1779414ed019SJim Kitchen   }
1780414ed019SJim Kitchen   if (!left.empty()) {
17812e2011daSAart Bik     if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
17822e2011daSAart Bik                                   outputType)))
17832e2011daSAart Bik       return failure();
178404235d07SJacques Pienaar   } else if (getLeftIdentity()) {
1785414ed019SJim Kitchen     if (leftType != outputType)
1786414ed019SJim Kitchen       return emitError("left=identity requires first argument to have the same "
1787414ed019SJim Kitchen                        "type as the output");
1788414ed019SJim Kitchen   }
1789414ed019SJim Kitchen   if (!right.empty()) {
17902e2011daSAart Bik     if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
17912e2011daSAart Bik                                   outputType)))
17922e2011daSAart Bik       return failure();
179304235d07SJacques Pienaar   } else if (getRightIdentity()) {
1794414ed019SJim Kitchen     if (rightType != outputType)
1795414ed019SJim Kitchen       return emitError("right=identity requires second argument to have the "
1796414ed019SJim Kitchen                        "same type as the output");
1797414ed019SJim Kitchen   }
1798414ed019SJim Kitchen   return success();
1799414ed019SJim Kitchen }
1800414ed019SJim Kitchen 
1801414ed019SJim Kitchen LogicalResult UnaryOp::verify() {
180204235d07SJacques Pienaar   Type inputType = getX().getType();
180304235d07SJacques Pienaar   Type outputType = getOutput().getType();
1804414ed019SJim Kitchen 
1805414ed019SJim Kitchen   // Check correct number of block arguments and return type for each
1806414ed019SJim Kitchen   // non-empty region.
180704235d07SJacques Pienaar   Region &present = getPresentRegion();
1808414ed019SJim Kitchen   if (!present.empty()) {
18092e2011daSAart Bik     if (failed(verifyNumBlockArgs(this, present, "present",
18102e2011daSAart Bik                                   TypeRange{inputType}, outputType)))
18112e2011daSAart Bik       return failure();
1812414ed019SJim Kitchen   }
181304235d07SJacques Pienaar   Region &absent = getAbsentRegion();
1814414ed019SJim Kitchen   if (!absent.empty()) {
18152e2011daSAart Bik     if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
18162e2011daSAart Bik                                   outputType)))
18172e2011daSAart Bik       return failure();
18187e83a1afSAart Bik     // Absent branch can only yield invariant values.
18197e83a1afSAart Bik     Block *absentBlock = &absent.front();
18207e83a1afSAart Bik     Block *parent = getOperation()->getBlock();
1821a54930e6SPeiming Liu     Value absentVal =
1822a54930e6SPeiming Liu         cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
18237e83a1afSAart Bik     if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
18247e83a1afSAart Bik       if (arg.getOwner() == parent)
18257e83a1afSAart Bik         return emitError("absent region cannot yield linalg argument");
18267e83a1afSAart Bik     } else if (Operation *def = absentVal.getDefiningOp()) {
18277e83a1afSAart Bik       if (!isa<arith::ConstantOp>(def) &&
18287e83a1afSAart Bik           (def->getBlock() == absentBlock || def->getBlock() == parent))
18297e83a1afSAart Bik         return emitError("absent region cannot yield locally computed value");
18307e83a1afSAart Bik     }
1831414ed019SJim Kitchen   }
1832414ed019SJim Kitchen   return success();
1833414ed019SJim Kitchen }
1834414ed019SJim Kitchen 
1835761c9dd9SPeiming Liu bool ConcatenateOp::needsExtraSort() {
1836761c9dd9SPeiming Liu   SparseTensorType dstStt = getSparseTensorType(*this);
1837761c9dd9SPeiming Liu   if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1838761c9dd9SPeiming Liu     return false;
1839761c9dd9SPeiming Liu 
1840761c9dd9SPeiming Liu   bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1841761c9dd9SPeiming Liu     return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1842761c9dd9SPeiming Liu   });
1843761c9dd9SPeiming Liu   // TODO: When conDim != 0, as long as conDim corresponding to the first level
1844761c9dd9SPeiming Liu   // in all input/output buffers, and all input/output buffers have the same
1845761c9dd9SPeiming Liu   // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
1846761c9dd9SPeiming Liu   // CSC matrices along column).
1847761c9dd9SPeiming Liu   bool directLowerable =
1848761c9dd9SPeiming Liu       allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1849761c9dd9SPeiming Liu   return !directLowerable;
1850761c9dd9SPeiming Liu }
1851761c9dd9SPeiming Liu 
1852de907138SPeiming Liu LogicalResult ConcatenateOp::verify() {
1853f708a549Swren romano   const auto dstTp = getSparseTensorType(*this);
185484cd51bbSwren romano   const Dimension concatDim = getDimension();
1855f708a549Swren romano   const Dimension dimRank = dstTp.getDimRank();
1856de907138SPeiming Liu 
1857de907138SPeiming Liu   if (getInputs().size() <= 1)
1858de907138SPeiming Liu     return emitError("Need at least two tensors to concatenate.");
1859de907138SPeiming Liu 
1860f708a549Swren romano   if (concatDim >= dimRank)
1861de907138SPeiming Liu     return emitError(llvm::formatv(
1862f708a549Swren romano         "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1863f708a549Swren romano         concatDim, dimRank));
1864de907138SPeiming Liu 
1865f708a549Swren romano   for (const auto &it : llvm::enumerate(getInputs())) {
1866f708a549Swren romano     const auto i = it.index();
1867f708a549Swren romano     const auto srcTp = getSparseTensorType(it.value());
1868f708a549Swren romano     if (srcTp.hasDynamicDimShape())
1869f708a549Swren romano       return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1870f708a549Swren romano     const Dimension srcDimRank = srcTp.getDimRank();
1871f708a549Swren romano     if (srcDimRank != dimRank)
1872de907138SPeiming Liu       return emitError(
1873f708a549Swren romano           llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1874de907138SPeiming Liu                         "from the output tensor (rank={2}).",
1875f708a549Swren romano                         i, srcDimRank, dimRank));
1876de907138SPeiming Liu   }
1877de907138SPeiming Liu 
1878f708a549Swren romano   for (Dimension d = 0; d < dimRank; d++) {
187922212ca7SAart Bik     const Size dstSh = dstTp.getDimShape()[d];
1880f708a549Swren romano     if (d == concatDim) {
1881f708a549Swren romano       if (!ShapedType::isDynamic(dstSh)) {
1882f708a549Swren romano         // If we reach here, then all inputs have static shapes.  So we
1883f708a549Swren romano         // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1884f708a549Swren romano         // to avoid redundant assertions in the loop.
188522212ca7SAart Bik         Size sumSz = 0;
1886f708a549Swren romano         for (const auto src : getInputs())
1887f708a549Swren romano           sumSz += getSparseTensorType(src).getDimShape()[d];
1888de907138SPeiming Liu         // If all dimension are statically known, the sum of all the input
1889de907138SPeiming Liu         // dimensions should be equal to the output dimension.
1890f708a549Swren romano         if (sumSz != dstSh)
1891de907138SPeiming Liu           return emitError(
1892de907138SPeiming Liu               "The concatenation dimension of the output tensor should be the "
1893de907138SPeiming Liu               "sum of all the concatenation dimensions of the input tensors.");
1894de907138SPeiming Liu       }
1895de907138SPeiming Liu     } else {
189622212ca7SAart Bik       Size prev = dstSh;
1897f708a549Swren romano       for (const auto src : getInputs()) {
1898f708a549Swren romano         const auto sh = getSparseTensorType(src).getDimShape()[d];
1899f708a549Swren romano         if (!ShapedType::isDynamic(prev) && sh != prev)
1900de907138SPeiming Liu           return emitError("All dimensions (expect for the concatenating one) "
1901de907138SPeiming Liu                            "should be equal.");
1902f708a549Swren romano         prev = sh;
1903de907138SPeiming Liu       }
1904de907138SPeiming Liu     }
1905de907138SPeiming Liu   }
1906de907138SPeiming Liu 
1907de907138SPeiming Liu   return success();
1908de907138SPeiming Liu }
1909de907138SPeiming Liu 
191014504caeSbixia1 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1911988733c6SPeiming Liu                        Value curSize, Value inBuffer, Value value) {
1912988733c6SPeiming Liu   build(builder, result, curSize, inBuffer, value, Value());
191314504caeSbixia1 }
191414504caeSbixia1 
191514504caeSbixia1 LogicalResult PushBackOp::verify() {
1916743fbcb7Swren romano   if (Value n = getN()) {
1917cb7bda2aSMatthias Springer     std::optional<int64_t> nValue = getConstantIntValue(n);
191814504caeSbixia1     if (nValue && nValue.value() < 1)
191914504caeSbixia1       return emitOpError("n must be not less than 1");
192014504caeSbixia1   }
192114504caeSbixia1   return success();
192214504caeSbixia1 }
192314504caeSbixia1 
1924a3610359SAart Bik LogicalResult CompressOp::verify() {
192584cd51bbSwren romano   const auto stt = getSparseTensorType(getTensor());
192684cd51bbSwren romano   if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
192784cd51bbSwren romano     return emitOpError("incorrect number of coordinates");
1928a3610359SAart Bik   return success();
1929a3610359SAart Bik }
1930a3610359SAart Bik 
193100ad0655SPeiming Liu void ForeachOp::build(
193200ad0655SPeiming Liu     OpBuilder &builder, OperationState &result, Value tensor,
19339e8d9316SPeiming Liu     ValueRange initArgs, AffineMapAttr order,
19344fa00ce1SPeiming Liu     function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
19354fa00ce1SPeiming Liu         bodyBuilder) {
19369e8d9316SPeiming Liu   build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
19374fa00ce1SPeiming Liu   // Builds foreach body.
193800ad0655SPeiming Liu   if (!bodyBuilder)
193900ad0655SPeiming Liu     return;
1940f708a549Swren romano   const auto stt = getSparseTensorType(tensor);
1941f708a549Swren romano   const Dimension dimRank = stt.getDimRank();
194200ad0655SPeiming Liu 
194384cd51bbSwren romano   // Starts with `dimRank`-many coordinates.
1944f708a549Swren romano   SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
194500ad0655SPeiming Liu   // Followed by one value.
1946f708a549Swren romano   blockArgTypes.push_back(stt.getElementType());
1947f708a549Swren romano   // Followed by the reduction variables.
19484fa00ce1SPeiming Liu   blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
194900ad0655SPeiming Liu 
1950f708a549Swren romano   SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
195100ad0655SPeiming Liu 
195200ad0655SPeiming Liu   OpBuilder::InsertionGuard guard(builder);
195300ad0655SPeiming Liu   auto &region = *result.regions.front();
195400ad0655SPeiming Liu   Block *bodyBlock =
195500ad0655SPeiming Liu       builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
19564fa00ce1SPeiming Liu   bodyBuilder(builder, result.location,
1957f708a549Swren romano               bodyBlock->getArguments().slice(0, dimRank),
1958f708a549Swren romano               bodyBlock->getArguments()[dimRank],
1959f708a549Swren romano               bodyBlock->getArguments().drop_front(dimRank + 1));
196000ad0655SPeiming Liu }
196100ad0655SPeiming Liu 
1962e08865a1SPeiming Liu LogicalResult ForeachOp::verify() {
1963f708a549Swren romano   const auto t = getSparseTensorType(getTensor());
1964f708a549Swren romano   const Dimension dimRank = t.getDimRank();
1965f708a549Swren romano   const auto args = getBody()->getArguments();
1966e08865a1SPeiming Liu 
196753ffafb2SPeiming Liu   if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
196853ffafb2SPeiming Liu     return emitError("Level traverse order does not match tensor's level rank");
19699e8d9316SPeiming Liu 
197053ffafb2SPeiming Liu   if (dimRank + 1 + getInitArgs().size() != args.size())
1971e08865a1SPeiming Liu     return emitError("Unmatched number of arguments in the block");
1972e08865a1SPeiming Liu 
19734fa00ce1SPeiming Liu   if (getNumResults() != getInitArgs().size())
19744fa00ce1SPeiming Liu     return emitError("Mismatch in number of init arguments and results");
19754fa00ce1SPeiming Liu 
19764fa00ce1SPeiming Liu   if (getResultTypes() != getInitArgs().getTypes())
19774fa00ce1SPeiming Liu     return emitError("Mismatch in types of init arguments and results");
19784fa00ce1SPeiming Liu 
1979f708a549Swren romano   // Cannot mark this const, because the getters aren't.
19804fa00ce1SPeiming Liu   auto yield = cast<YieldOp>(getBody()->getTerminator());
19814fa00ce1SPeiming Liu   if (yield.getNumOperands() != getNumResults() ||
19824fa00ce1SPeiming Liu       yield.getOperands().getTypes() != getResultTypes())
19834fa00ce1SPeiming Liu     return emitError("Mismatch in types of yield values and results");
19844fa00ce1SPeiming Liu 
1985f708a549Swren romano   const auto iTp = IndexType::get(getContext());
1986f708a549Swren romano   for (Dimension d = 0; d < dimRank; d++)
1987f708a549Swren romano     if (args[d].getType() != iTp)
198877f8297cSMatthias Springer       return emitError(
1989f708a549Swren romano           llvm::formatv("Expecting Index type for argument at index {0}", d));
1990e08865a1SPeiming Liu 
1991f708a549Swren romano   const auto elemTp = t.getElementType();
1992f708a549Swren romano   const auto valueTp = args[dimRank].getType();
1993e08865a1SPeiming Liu   if (elemTp != valueTp)
199477f8297cSMatthias Springer     return emitError(
199577f8297cSMatthias Springer         llvm::formatv("Unmatched element type between input tensor and "
1996e08865a1SPeiming Liu                       "block argument, expected:{0}, got: {1}",
1997e08865a1SPeiming Liu                       elemTp, valueTp));
1998e08865a1SPeiming Liu   return success();
1999e08865a1SPeiming Liu }
2000e08865a1SPeiming Liu 
20010aacc213SPeiming Liu OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
20020aacc213SPeiming Liu   if (getSparseTensorEncoding(getInputCoo().getType()) ==
20030aacc213SPeiming Liu       getSparseTensorEncoding(getResultCoo().getType()))
20040aacc213SPeiming Liu     return getInputCoo();
20050aacc213SPeiming Liu 
20060aacc213SPeiming Liu   return {};
20070aacc213SPeiming Liu }
20080aacc213SPeiming Liu 
20090aacc213SPeiming Liu LogicalResult ReorderCOOOp::verify() {
20100aacc213SPeiming Liu   SparseTensorType srcStt = getSparseTensorType(getInputCoo());
20110aacc213SPeiming Liu   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
20120aacc213SPeiming Liu 
20135b729503SAart Bik   if (!srcStt.isCOOType() || !dstStt.isCOOType())
201477f8297cSMatthias Springer     return emitError("Expected COO sparse tensors only");
201598f8b1afSAart Bik 
20160aacc213SPeiming Liu   if (!srcStt.hasSameDimToLvl(dstStt))
201777f8297cSMatthias Springer     return emitError("Unmatched dim2lvl map between input and result COO");
20180aacc213SPeiming Liu 
20190aacc213SPeiming Liu   if (srcStt.getPosType() != dstStt.getPosType() ||
20200aacc213SPeiming Liu       srcStt.getCrdType() != dstStt.getCrdType() ||
202198f8b1afSAart Bik       srcStt.getElementType() != dstStt.getElementType())
202277f8297cSMatthias Springer     return emitError("Unmatched storage format between input and result COO");
202398f8b1afSAart Bik 
20240aacc213SPeiming Liu   return success();
20250aacc213SPeiming Liu }
20260aacc213SPeiming Liu 
20272b8a4d9cSJim Kitchen LogicalResult ReduceOp::verify() {
2028a1ec0d8bSJacques Pienaar   Type inputType = getX().getType();
2029a1ec0d8bSJacques Pienaar   Region &formula = getRegion();
20302e2011daSAart Bik   return verifyNumBlockArgs(this, formula, "reduce",
20312e2011daSAart Bik                             TypeRange{inputType, inputType}, inputType);
20322b8a4d9cSJim Kitchen }
20332b8a4d9cSJim Kitchen 
203407150fecSJim Kitchen LogicalResult SelectOp::verify() {
203507150fecSJim Kitchen   Builder b(getContext());
203607150fecSJim Kitchen   Type inputType = getX().getType();
203707150fecSJim Kitchen   Type boolType = b.getI1Type();
203807150fecSJim Kitchen   Region &formula = getRegion();
20392e2011daSAart Bik   return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
20402e2011daSAart Bik                             boolType);
20412b8a4d9cSJim Kitchen }
20422b8a4d9cSJim Kitchen 
20430083f833SPeiming Liu LogicalResult SortOp::verify() {
2044bfa3bc43SPeiming Liu   AffineMap xPerm = getPermMap();
2045bfa3bc43SPeiming Liu   uint64_t nx = xPerm.getNumDims();
2046bfa3bc43SPeiming Liu   if (nx < 1)
204777f8297cSMatthias Springer     return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
2048bfa3bc43SPeiming Liu 
2049bfa3bc43SPeiming Liu   if (!xPerm.isPermutation())
205077f8297cSMatthias Springer     return emitError(
205177f8297cSMatthias Springer         llvm::formatv("Expected a permutation map, got {0}", xPerm));
2052bfa3bc43SPeiming Liu 
2053cf24d49dSbixia1   // We can't check the size of the buffers when n or buffer dimensions aren't
2054cf24d49dSbixia1   // compile-time constants.
2055d2d29288SAart Bik   std::optional<int64_t> cn = getConstantIntValue(getN());
2056cf24d49dSbixia1   if (!cn)
2057cf24d49dSbixia1     return success();
2058cf24d49dSbixia1 
2059d2d29288SAart Bik   // Verify dimensions.
206077f8297cSMatthias Springer   const auto checkDim = [&](Value v, Size minSize,
206177f8297cSMatthias Springer                             const char *message) -> LogicalResult {
206222212ca7SAart Bik     const Size sh = getMemRefType(v).getShape()[0];
2063f708a549Swren romano     if (!ShapedType::isDynamic(sh) && sh < minSize)
206477f8297cSMatthias Springer       return emitError(
206577f8297cSMatthias Springer           llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
206677f8297cSMatthias Springer     return success();
2067cf24d49dSbixia1   };
2068d2d29288SAart Bik   uint64_t n = cn.value();
2069d2d29288SAart Bik   uint64_t ny = 0;
2070d2d29288SAart Bik   if (auto nyAttr = getNyAttr())
2071d2d29288SAart Bik     ny = nyAttr.getInt();
207277f8297cSMatthias Springer   if (failed(checkDim(getXy(), n * (nx + ny),
207377f8297cSMatthias Springer                       "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
207477f8297cSMatthias Springer     return failure();
2075d2d29288SAart Bik   for (Value opnd : getYs())
207677f8297cSMatthias Springer     if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
207777f8297cSMatthias Springer       return failure();
2078cf24d49dSbixia1 
2079cf24d49dSbixia1   return success();
2080cf24d49dSbixia1 }
2081cf24d49dSbixia1 
2082481bd5d4SPeiming Liu //===----------------------------------------------------------------------===//
2083481bd5d4SPeiming Liu // Sparse Tensor Iteration Operations.
2084481bd5d4SPeiming Liu //===----------------------------------------------------------------------===//
2085481bd5d4SPeiming Liu 
2086481bd5d4SPeiming Liu IterSpaceType IteratorType::getIterSpaceType() const {
2087481bd5d4SPeiming Liu   return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2088481bd5d4SPeiming Liu                             getHiLvl());
2089481bd5d4SPeiming Liu }
2090481bd5d4SPeiming Liu 
2091481bd5d4SPeiming Liu IteratorType IterSpaceType::getIteratorType() const {
2092481bd5d4SPeiming Liu   return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2093481bd5d4SPeiming Liu }
2094481bd5d4SPeiming Liu 
2095481bd5d4SPeiming Liu /// Parses a level range in the form "$lo `to` $hi"
2096481bd5d4SPeiming Liu /// or simply "$lo" if $hi - $lo = 1
2097481bd5d4SPeiming Liu static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2098481bd5d4SPeiming Liu                                    Level &lvlHi) {
2099481bd5d4SPeiming Liu   if (parser.parseInteger(lvlLo))
2100481bd5d4SPeiming Liu     return failure();
2101481bd5d4SPeiming Liu 
2102481bd5d4SPeiming Liu   if (succeeded(parser.parseOptionalKeyword("to"))) {
2103481bd5d4SPeiming Liu     if (parser.parseInteger(lvlHi))
2104481bd5d4SPeiming Liu       return failure();
2105481bd5d4SPeiming Liu   } else {
2106481bd5d4SPeiming Liu     lvlHi = lvlLo + 1;
2107481bd5d4SPeiming Liu   }
2108481bd5d4SPeiming Liu 
2109481bd5d4SPeiming Liu   if (lvlHi <= lvlLo)
211077f8297cSMatthias Springer     return parser.emitError(parser.getNameLoc(),
2111481bd5d4SPeiming Liu                             "expect larger level upper bound than lower bound");
2112481bd5d4SPeiming Liu 
2113481bd5d4SPeiming Liu   return success();
2114481bd5d4SPeiming Liu }
2115481bd5d4SPeiming Liu 
2116481bd5d4SPeiming Liu /// Parses a level range in the form "$lo `to` $hi"
2117481bd5d4SPeiming Liu /// or simply "$lo" if $hi - $lo = 1
2118481bd5d4SPeiming Liu static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2119481bd5d4SPeiming Liu                                    IntegerAttr &lvlHiAttr) {
2120481bd5d4SPeiming Liu   Level lvlLo, lvlHi;
2121481bd5d4SPeiming Liu   if (parseLevelRange(parser, lvlLo, lvlHi))
2122481bd5d4SPeiming Liu     return failure();
2123481bd5d4SPeiming Liu 
2124481bd5d4SPeiming Liu   lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2125481bd5d4SPeiming Liu   lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2126481bd5d4SPeiming Liu   return success();
2127481bd5d4SPeiming Liu }
2128481bd5d4SPeiming Liu 
2129481bd5d4SPeiming Liu /// Prints a level range in the form "$lo `to` $hi"
2130481bd5d4SPeiming Liu /// or simply "$lo" if $hi - $lo = 1
2131481bd5d4SPeiming Liu static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2132481bd5d4SPeiming Liu 
2133481bd5d4SPeiming Liu   if (lo + 1 == hi)
2134481bd5d4SPeiming Liu     p << lo;
2135481bd5d4SPeiming Liu   else
2136481bd5d4SPeiming Liu     p << lo << " to " << hi;
2137481bd5d4SPeiming Liu }
2138481bd5d4SPeiming Liu 
2139481bd5d4SPeiming Liu /// Prints a level range in the form "$lo `to` $hi"
2140481bd5d4SPeiming Liu /// or simply "$lo" if $hi - $lo = 1
2141481bd5d4SPeiming Liu static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2142481bd5d4SPeiming Liu                             IntegerAttr lvlHi) {
2143481bd5d4SPeiming Liu   unsigned lo = lvlLo.getValue().getZExtValue();
2144481bd5d4SPeiming Liu   unsigned hi = lvlHi.getValue().getZExtValue();
2145481bd5d4SPeiming Liu   printLevelRange(p, lo, hi);
2146481bd5d4SPeiming Liu }
2147481bd5d4SPeiming Liu 
2148785a24f1SPeiming Liu /// Parses a list of `optional` defined list in the form of
2149785a24f1SPeiming Liu /// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
2150785a24f1SPeiming Liu /// corresponding value is not defined (e.g., to represent an undefined
2151785a24f1SPeiming Liu /// coordinate in the sparse iteration space).
2152785a24f1SPeiming Liu static ParseResult parseOptionalDefinedList(
2153785a24f1SPeiming Liu     OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
2154785a24f1SPeiming Liu     SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
2155785a24f1SPeiming Liu     unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2156785a24f1SPeiming Liu     OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
2157785a24f1SPeiming Liu   unsigned cnt = 0;
2158785a24f1SPeiming Liu   ParseResult crdList =
2159785a24f1SPeiming Liu       parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
2160785a24f1SPeiming Liu         if (parser.parseOptionalKeyword("_")) {
2161785a24f1SPeiming Liu           if (parser.parseArgument(definedArgs.emplace_back()))
2162785a24f1SPeiming Liu             return failure();
2163785a24f1SPeiming Liu           definedSet.set(cnt);
2164785a24f1SPeiming Liu         }
2165785a24f1SPeiming Liu         cnt += 1;
2166785a24f1SPeiming Liu         return success();
2167785a24f1SPeiming Liu       });
2168785a24f1SPeiming Liu 
2169785a24f1SPeiming Liu   if (cnt > maxCnt)
2170785a24f1SPeiming Liu     return parser.emitError(parser.getNameLoc(),
2171785a24f1SPeiming Liu                             "parsed more value than expected.");
2172785a24f1SPeiming Liu 
2173785a24f1SPeiming Liu   if (failed(crdList)) {
2174785a24f1SPeiming Liu     return parser.emitError(
2175785a24f1SPeiming Liu         parser.getNameLoc(),
2176785a24f1SPeiming Liu         "expecting SSA value or \"_\" for level coordinates");
2177785a24f1SPeiming Liu   }
2178785a24f1SPeiming Liu   assert(definedArgs.size() == definedSet.count());
2179785a24f1SPeiming Liu   return success();
2180785a24f1SPeiming Liu }
2181785a24f1SPeiming Liu 
2182785a24f1SPeiming Liu static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2183785a24f1SPeiming Liu                                      Block::BlockArgListType blocksArgs,
2184785a24f1SPeiming Liu                                      I64BitSet definedSet) {
2185785a24f1SPeiming Liu   if (definedSet.empty())
2186785a24f1SPeiming Liu     return;
2187785a24f1SPeiming Liu 
2188785a24f1SPeiming Liu   for (unsigned i = 0; i < size; i++) {
2189785a24f1SPeiming Liu     if (definedSet[i]) {
2190785a24f1SPeiming Liu       p << blocksArgs.front();
2191785a24f1SPeiming Liu       blocksArgs = blocksArgs.drop_front();
2192785a24f1SPeiming Liu     } else {
2193785a24f1SPeiming Liu       p << "_";
2194785a24f1SPeiming Liu     }
2195785a24f1SPeiming Liu     if (i != size - 1)
2196785a24f1SPeiming Liu       p << ", ";
2197785a24f1SPeiming Liu   }
2198785a24f1SPeiming Liu   assert(blocksArgs.empty());
2199785a24f1SPeiming Liu }
2200785a24f1SPeiming Liu 
2201e276cf08SPeiming Liu static ParseResult
2202785a24f1SPeiming Liu parseUsedCoordList(OpAsmParser &parser, OperationState &state,
2203785a24f1SPeiming Liu                    SmallVectorImpl<OpAsmParser::Argument> &coords) {
2204785a24f1SPeiming Liu   // Parse "at(%crd0, _, ...)"
2205785a24f1SPeiming Liu   I64BitSet crdUsedLvlSet;
2206785a24f1SPeiming Liu   if (succeeded(parser.parseOptionalKeyword("at")) &&
2207785a24f1SPeiming Liu       failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
2208785a24f1SPeiming Liu     return failure();
2209785a24f1SPeiming Liu 
2210785a24f1SPeiming Liu   // Always use IndexType for the coordinate.
2211785a24f1SPeiming Liu   for (auto &coord : coords)
2212785a24f1SPeiming Liu     coord.type = parser.getBuilder().getIndexType();
2213785a24f1SPeiming Liu 
2214785a24f1SPeiming Liu   // Set the CrdUsedLvl bitset.
2215785a24f1SPeiming Liu   state.addAttribute("crdUsedLvls",
2216785a24f1SPeiming Liu                      parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2217785a24f1SPeiming Liu   return success();
2218785a24f1SPeiming Liu }
2219785a24f1SPeiming Liu 
2220785a24f1SPeiming Liu static ParseResult
2221785a24f1SPeiming Liu parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
2222e276cf08SPeiming Liu                        SmallVectorImpl<OpAsmParser::Argument> &iterators,
2223785a24f1SPeiming Liu                        SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2224e276cf08SPeiming Liu   SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2225e276cf08SPeiming Liu   SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2226e276cf08SPeiming Liu 
2227e276cf08SPeiming Liu   // Parse "%iters, ... in %spaces, ..."
2228e276cf08SPeiming Liu   if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
2229e276cf08SPeiming Liu       parser.parseOperandList(spaces))
2230e276cf08SPeiming Liu     return failure();
2231e276cf08SPeiming Liu 
2232e276cf08SPeiming Liu   if (iterators.size() != spaces.size())
2233e276cf08SPeiming Liu     return parser.emitError(
2234e276cf08SPeiming Liu         parser.getNameLoc(),
2235e276cf08SPeiming Liu         "mismatch in number of sparse iterators and sparse spaces");
2236e276cf08SPeiming Liu 
2237b48ef8d8SPeiming Liu   SmallVector<OpAsmParser::Argument> coords;
2238b48ef8d8SPeiming Liu   if (failed(parseUsedCoordList(parser, state, coords)))
2239e276cf08SPeiming Liu     return failure();
2240b48ef8d8SPeiming Liu   size_t numCrds = coords.size();
2241e276cf08SPeiming Liu 
2242e276cf08SPeiming Liu   // Parse "iter_args(%arg = %init, ...)"
2243e276cf08SPeiming Liu   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2244e276cf08SPeiming Liu   if (hasIterArgs)
2245785a24f1SPeiming Liu     if (parser.parseAssignmentList(blockArgs, initArgs))
2246e276cf08SPeiming Liu       return failure();
2247e276cf08SPeiming Liu 
2248b48ef8d8SPeiming Liu   blockArgs.append(coords);
2249b48ef8d8SPeiming Liu 
2250e276cf08SPeiming Liu   SmallVector<Type> iterSpaceTps;
2251e276cf08SPeiming Liu   // parse ": sparse_tensor.iter_space -> ret"
2252e276cf08SPeiming Liu   if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
2253e276cf08SPeiming Liu     return failure();
2254e276cf08SPeiming Liu   if (iterSpaceTps.size() != spaces.size())
2255e276cf08SPeiming Liu     return parser.emitError(parser.getNameLoc(),
2256e276cf08SPeiming Liu                             "mismatch in number of iteration space operands "
2257e276cf08SPeiming Liu                             "and iteration space types");
2258e276cf08SPeiming Liu 
2259e276cf08SPeiming Liu   for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2260e276cf08SPeiming Liu     IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2261e276cf08SPeiming Liu     if (!spaceTp)
2262e276cf08SPeiming Liu       return parser.emitError(parser.getNameLoc(),
2263e276cf08SPeiming Liu                               "expected sparse_tensor.iter_space type for "
2264e276cf08SPeiming Liu                               "iteration space operands");
2265e276cf08SPeiming Liu     it.type = spaceTp.getIteratorType();
2266e276cf08SPeiming Liu   }
2267e276cf08SPeiming Liu 
2268e276cf08SPeiming Liu   if (hasIterArgs)
2269e276cf08SPeiming Liu     if (parser.parseArrowTypeList(state.types))
2270e276cf08SPeiming Liu       return failure();
2271e276cf08SPeiming Liu 
2272e276cf08SPeiming Liu   // Resolves input operands.
2273e276cf08SPeiming Liu   if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2274e276cf08SPeiming Liu                              state.operands))
2275e276cf08SPeiming Liu     return failure();
2276e276cf08SPeiming Liu 
2277e276cf08SPeiming Liu   if (hasIterArgs) {
2278e276cf08SPeiming Liu     // Strip off leading args that used for coordinates.
2279b48ef8d8SPeiming Liu     MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2280785a24f1SPeiming Liu     if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2281785a24f1SPeiming Liu       return parser.emitError(
2282785a24f1SPeiming Liu           parser.getNameLoc(),
2283785a24f1SPeiming Liu           "mismatch in number of iteration arguments and return values");
2284785a24f1SPeiming Liu     }
2285785a24f1SPeiming Liu 
2286785a24f1SPeiming Liu     for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2287785a24f1SPeiming Liu       it.type = tp;
2288785a24f1SPeiming Liu       if (parser.resolveOperand(init, tp, state.operands))
2289785a24f1SPeiming Liu         return failure();
2290785a24f1SPeiming Liu     }
2291785a24f1SPeiming Liu   }
2292785a24f1SPeiming Liu   return success();
2293785a24f1SPeiming Liu }
2294785a24f1SPeiming Liu 
/// Parses the header of a sparse_tensor.coiterate loop:
///   (%space, ...) [at(%crd, _, ...)] [iter_args(%arg = %init, ...)]
///     : (<space types>) [-> <result types>]
/// Result types are parsed only when an iter_args clause is present.
/// The resolved iteration-space operands are returned in `spacesVals` and
/// also appended to `state.operands`; the init operands are appended after
/// them. On success, `blockArgs` holds the iter_args block arguments followed
/// by the used coordinates. `state` also receives the result types and the
/// "crdUsedLvls" attribute.
static ParseResult
parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
                         SmallVectorImpl<Value> &spacesVals,
                         SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {

  // Parse "(%spaces, ...)"
  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
  if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse the optional "at(...)" clause; remember how many coordinates were
  // defined so they can be excluded from the iter_args below.
  SmallVector<OpAsmParser::Argument> coords;
  if (failed(parseUsedCoordList(parser, state, coords)))
    return failure();
  size_t numCrds = coords.size();

  // Parse "iter_args(%arg = %init, ...)"
  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
  if (hasIterArgs)
    if (parser.parseAssignmentList(blockArgs, initArgs))
      return failure();
  // Coordinates are placed after the iter_args block arguments.
  blockArgs.append(coords);

  SmallVector<Type> iterSpaceTps;
  // parse ": (sparse_tensor.iter_space, ...) -> ret"
  if (parser.parseColon() || parser.parseLParen() ||
      parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
    return failure();

  if (iterSpaceTps.size() != spaces.size())
    return parser.emitError(parser.getNameLoc(),
                            "mismatch in number of iteration space operands "
                            "and iteration space types");

  // Result types are only present when there are loop-carried values.
  if (hasIterArgs)
    if (parser.parseArrowTypeList(state.types))
      return failure();

  // Resolves input sparse iteration spaces.
  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
                             spacesVals))
    return failure();
  state.operands.append(spacesVals);

  if (hasIterArgs) {
    // Strip off the trailing block arguments used for coordinates.
    MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
      return parser.emitError(
          parser.getNameLoc(),
          "mismatch in number of iteration arguments and return values");
    }

    // Each iter_args block argument takes the type of its result, and its
    // init value is resolved against that same type.
    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
      it.type = tp;
      if (parser.resolveOperand(init, tp, state.operands))
        return failure();
    }
  }
  return success();
}
2356e276cf08SPeiming Liu 
2357481bd5d4SPeiming Liu LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2358481bd5d4SPeiming Liu     MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2359481bd5d4SPeiming Liu     DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2360481bd5d4SPeiming Liu     SmallVectorImpl<mlir::Type> &ret) {
2361481bd5d4SPeiming Liu 
2362481bd5d4SPeiming Liu   ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2363481bd5d4SPeiming Liu   SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2364481bd5d4SPeiming Liu   ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2365481bd5d4SPeiming Liu                                    adaptor.getHiLvl()));
2366481bd5d4SPeiming Liu   return success();
2367481bd5d4SPeiming Liu }
2368481bd5d4SPeiming Liu 
2369481bd5d4SPeiming Liu LogicalResult ExtractIterSpaceOp::verify() {
2370481bd5d4SPeiming Liu   if (getLoLvl() >= getHiLvl())
2371481bd5d4SPeiming Liu     return emitOpError("expected smaller level low than level high");
2372481bd5d4SPeiming Liu 
2373481bd5d4SPeiming Liu   TypedValue<IteratorType> pIter = getParentIter();
2374481bd5d4SPeiming Liu   if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2375481bd5d4SPeiming Liu     return emitOpError(
2376481bd5d4SPeiming Liu         "parent iterator should be specified iff level lower bound equals 0");
2377481bd5d4SPeiming Liu   }
2378481bd5d4SPeiming Liu 
2379481bd5d4SPeiming Liu   if (pIter) {
2380e276cf08SPeiming Liu     IterSpaceType spaceTp = getExtractedSpace().getType();
2381481bd5d4SPeiming Liu     if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2382481bd5d4SPeiming Liu       return emitOpError(
2383481bd5d4SPeiming Liu           "mismatch in parent iterator encoding and iteration space encoding.");
2384481bd5d4SPeiming Liu 
2385481bd5d4SPeiming Liu     if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2386481bd5d4SPeiming Liu       return emitOpError("parent iterator should be used to extract an "
2387481bd5d4SPeiming Liu                          "iteration space from a consecutive level.");
2388481bd5d4SPeiming Liu   }
2389481bd5d4SPeiming Liu 
2390481bd5d4SPeiming Liu   return success();
2391481bd5d4SPeiming Liu }
2392481bd5d4SPeiming Liu 
239312189f80SPeiming Liu LogicalResult ExtractValOp::verify() {
239412189f80SPeiming Liu   auto stt = getSparseTensorType(getTensor());
239512189f80SPeiming Liu   auto itTp = getIterator().getType();
239612189f80SPeiming Liu 
239712189f80SPeiming Liu   if (stt.getEncoding() != itTp.getEncoding())
239812189f80SPeiming Liu     return emitOpError("mismatch in tensor encoding and iterator encoding.");
239912189f80SPeiming Liu 
240012189f80SPeiming Liu   if (stt.getLvlRank() != itTp.getHiLvl())
240112189f80SPeiming Liu     return emitOpError("must use last-level iterator to extract values. ");
240212189f80SPeiming Liu 
240312189f80SPeiming Liu   return success();
240412189f80SPeiming Liu }
240512189f80SPeiming Liu 
2406a43d79afSPeiming Liu struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2407a43d79afSPeiming Liu   using OpRewritePattern::OpRewritePattern;
2408a43d79afSPeiming Liu 
2409a43d79afSPeiming Liu   LogicalResult matchAndRewrite(IterateOp iterateOp,
2410a43d79afSPeiming Liu                                 PatternRewriter &rewriter) const override {
2411785a24f1SPeiming Liu     I64BitSet newUsedLvls(0);
2412a43d79afSPeiming Liu     llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2413a43d79afSPeiming Liu     for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2414a43d79afSPeiming Liu       if (auto crd = iterateOp.getLvlCrd(i)) {
2415a43d79afSPeiming Liu         if (crd->getUsers().empty())
2416a43d79afSPeiming Liu           toRemove.set(crd->getArgNumber());
2417a43d79afSPeiming Liu         else
2418a43d79afSPeiming Liu           newUsedLvls.set(i);
2419a43d79afSPeiming Liu       }
2420a43d79afSPeiming Liu     }
2421a43d79afSPeiming Liu 
2422a43d79afSPeiming Liu     // All coordinates are used.
2423a43d79afSPeiming Liu     if (toRemove.none())
2424a43d79afSPeiming Liu       return failure();
2425a43d79afSPeiming Liu 
2426a43d79afSPeiming Liu     rewriter.startOpModification(iterateOp);
2427a43d79afSPeiming Liu     iterateOp.setCrdUsedLvls(newUsedLvls);
2428a43d79afSPeiming Liu     iterateOp.getBody()->eraseArguments(toRemove);
2429a43d79afSPeiming Liu     rewriter.finalizeOpModification(iterateOp);
2430a43d79afSPeiming Liu     return success();
2431a43d79afSPeiming Liu   }
2432a43d79afSPeiming Liu };
2433a43d79afSPeiming Liu 
2434a43d79afSPeiming Liu void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2435a43d79afSPeiming Liu                                             mlir::MLIRContext *context) {
2436a43d79afSPeiming Liu   results.add<RemoveUnusedLvlCrds>(context);
2437a43d79afSPeiming Liu }
2438a43d79afSPeiming Liu 
2439a02010b3SPeiming Liu void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2440a02010b3SPeiming Liu                       Value iterSpace, ValueRange initArgs) {
2441a02010b3SPeiming Liu   unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2442a02010b3SPeiming Liu   // All ones.
2443785a24f1SPeiming Liu   I64BitSet set((1 << rank) - 1);
2444a02010b3SPeiming Liu   return build(builder, odsState, iterSpace, initArgs, set);
2445a02010b3SPeiming Liu }
2446a02010b3SPeiming Liu 
/// Builds a sparse_tensor.iterate over `iterSpace` carrying `initArgs`.
/// `crdUsedLvls` selects the levels whose coordinates are exposed to the body.
/// The entry block receives, in order: one argument per init value, one
/// `index` argument per used coordinate, and the sparse iterator.
void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                      Value iterSpace, ValueRange initArgs,
                      I64BitSet crdUsedLvls) {
  // createBlock below moves the builder's insertion point into the new block;
  // the guard restores the caller's insertion point on return.
  OpBuilder::InsertionGuard guard(builder);

  odsState.addOperands(iterSpace);
  odsState.addOperands(initArgs);
  odsState.getOrAddProperties<Properties>().crdUsedLvls =
      builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
  Region *bodyRegion = odsState.addRegion();
  // One result per loop-carried value.
  odsState.addTypes(initArgs.getTypes());
  Block *bodyBlock = builder.createBlock(bodyRegion);

  // Starts with a list of user-provided loop arguments.
  for (Value v : initArgs)
    bodyBlock->addArgument(v.getType(), v.getLoc());

  // Followed by a list of used coordinates.
  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
    bodyBlock->addArgument(builder.getIndexType(), odsState.location);

  // Ends with the sparse iterator.
  bodyBlock->addArgument(
      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
      odsState.location);
}
2473a02010b3SPeiming Liu 
2474e276cf08SPeiming Liu ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2475e276cf08SPeiming Liu   OpAsmParser::Argument iterator;
2476e276cf08SPeiming Liu   OpAsmParser::UnresolvedOperand iterSpace;
2477e276cf08SPeiming Liu 
2478e276cf08SPeiming Liu   SmallVector<OpAsmParser::Argument> iters, iterArgs;
2479785a24f1SPeiming Liu   if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2480e276cf08SPeiming Liu     return failure();
2481e276cf08SPeiming Liu   if (iters.size() != 1)
2482e276cf08SPeiming Liu     return parser.emitError(parser.getNameLoc(),
2483e276cf08SPeiming Liu                             "expected only one iterator/iteration space");
2484e276cf08SPeiming Liu 
2485b48ef8d8SPeiming Liu   iterArgs.append(iters);
2486e276cf08SPeiming Liu   Region *body = result.addRegion();
2487b48ef8d8SPeiming Liu   if (parser.parseRegion(*body, iterArgs))
2488e276cf08SPeiming Liu     return failure();
2489e276cf08SPeiming Liu 
2490e276cf08SPeiming Liu   IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2491e276cf08SPeiming Liu 
2492e276cf08SPeiming Liu   // Parse the optional attribute list.
2493e276cf08SPeiming Liu   if (parser.parseOptionalAttrDict(result.attributes))
2494e276cf08SPeiming Liu     return failure();
2495e276cf08SPeiming Liu 
2496e276cf08SPeiming Liu   return success();
2497e276cf08SPeiming Liu }
2498e276cf08SPeiming Liu 
2499e276cf08SPeiming Liu /// Prints the initialization list in the form of
2500e276cf08SPeiming Liu ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2501e276cf08SPeiming Liu /// where 'inner' values are assumed to be region arguments and 'outer' values
2502e276cf08SPeiming Liu /// are regular SSA values.
2503e276cf08SPeiming Liu static void printInitializationList(OpAsmPrinter &p,
2504e276cf08SPeiming Liu                                     Block::BlockArgListType blocksArgs,
2505e276cf08SPeiming Liu                                     ValueRange initializers,
2506e276cf08SPeiming Liu                                     StringRef prefix = "") {
2507e276cf08SPeiming Liu   assert(blocksArgs.size() == initializers.size() &&
2508e276cf08SPeiming Liu          "expected same length of arguments and initializers");
2509e276cf08SPeiming Liu   if (initializers.empty())
2510e276cf08SPeiming Liu     return;
2511e276cf08SPeiming Liu 
2512e276cf08SPeiming Liu   p << prefix << '(';
2513e276cf08SPeiming Liu   llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2514e276cf08SPeiming Liu     p << std::get<0>(it) << " = " << std::get<1>(it);
2515e276cf08SPeiming Liu   });
2516e276cf08SPeiming Liu   p << ")";
2517e276cf08SPeiming Liu }
2518e276cf08SPeiming Liu 
2519785a24f1SPeiming Liu template <typename SparseLoopOp>
2520785a24f1SPeiming Liu static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2521785a24f1SPeiming Liu   if (op.getInitArgs().size() != op.getNumResults()) {
2522785a24f1SPeiming Liu     return op.emitOpError(
2523785a24f1SPeiming Liu         "mismatch in number of loop-carried values and defined values");
2524785a24f1SPeiming Liu   }
2525785a24f1SPeiming Liu   if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2526785a24f1SPeiming Liu     return op.emitOpError("required out-of-bound coordinates");
2527e276cf08SPeiming Liu 
2528785a24f1SPeiming Liu   return success();
2529e276cf08SPeiming Liu }
2530785a24f1SPeiming Liu 
2531785a24f1SPeiming Liu LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
2532785a24f1SPeiming Liu LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2533e276cf08SPeiming Liu 
2534e276cf08SPeiming Liu void IterateOp::print(OpAsmPrinter &p) {
2535e276cf08SPeiming Liu   p << " " << getIterator() << " in " << getIterSpace();
2536785a24f1SPeiming Liu   if (!getCrdUsedLvls().empty()) {
2537785a24f1SPeiming Liu     p << " at(";
2538785a24f1SPeiming Liu     printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
2539785a24f1SPeiming Liu     p << ")";
2540785a24f1SPeiming Liu   }
2541e276cf08SPeiming Liu   printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
2542e276cf08SPeiming Liu 
2543e276cf08SPeiming Liu   p << " : " << getIterSpace().getType() << " ";
2544e276cf08SPeiming Liu   if (!getInitArgs().empty())
2545785a24f1SPeiming Liu     p.printArrowTypeList(getInitArgs().getTypes());
2546e276cf08SPeiming Liu 
2547785a24f1SPeiming Liu   p << " ";
2548e276cf08SPeiming Liu   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2549e276cf08SPeiming Liu                 /*printBlockTerminators=*/!getInitArgs().empty());
2550e276cf08SPeiming Liu }
2551e276cf08SPeiming Liu 
2552e276cf08SPeiming Liu LogicalResult IterateOp::verifyRegions() {
2553e276cf08SPeiming Liu   if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2554e276cf08SPeiming Liu     return emitOpError("mismatch in iterator and iteration space type");
2555e276cf08SPeiming Liu   if (getNumRegionIterArgs() != getNumResults())
2556e276cf08SPeiming Liu     return emitOpError(
2557e276cf08SPeiming Liu         "mismatch in number of basic block args and defined values");
2558e276cf08SPeiming Liu 
2559e276cf08SPeiming Liu   auto initArgs = getInitArgs();
2560e276cf08SPeiming Liu   auto iterArgs = getRegionIterArgs();
2561e276cf08SPeiming Liu   auto yieldVals = getYieldedValues();
2562e276cf08SPeiming Liu   auto opResults = getResults();
2563e276cf08SPeiming Liu   if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2564e276cf08SPeiming Liu                         opResults.size()})) {
2565e276cf08SPeiming Liu     return emitOpError() << "number mismatch between iter args and results.";
2566e276cf08SPeiming Liu   }
2567e276cf08SPeiming Liu 
2568e276cf08SPeiming Liu   for (auto [i, init, iter, yield, ret] :
2569e276cf08SPeiming Liu        llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2570e276cf08SPeiming Liu     if (init.getType() != ret.getType())
2571e276cf08SPeiming Liu       return emitOpError() << "types mismatch between " << i
2572e276cf08SPeiming Liu                            << "th iter operand and defined value";
2573e276cf08SPeiming Liu     if (iter.getType() != ret.getType())
2574e276cf08SPeiming Liu       return emitOpError() << "types mismatch between " << i
2575e276cf08SPeiming Liu                            << "th iter region arg and defined value";
2576e276cf08SPeiming Liu     if (yield.getType() != ret.getType())
2577e276cf08SPeiming Liu       return emitOpError() << "types mismatch between " << i
2578e276cf08SPeiming Liu                            << "th yield value and defined value";
2579e276cf08SPeiming Liu   }
2580e276cf08SPeiming Liu 
2581e276cf08SPeiming Liu   return success();
2582e276cf08SPeiming Liu }
2583e276cf08SPeiming Liu 
2584e276cf08SPeiming Liu /// OpInterfaces' methods implemented by IterateOp.
2585e276cf08SPeiming Liu SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2586e276cf08SPeiming Liu 
2587e276cf08SPeiming Liu MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2588e276cf08SPeiming Liu   return getInitArgsMutable();
2589e276cf08SPeiming Liu }
2590e276cf08SPeiming Liu 
2591e276cf08SPeiming Liu Block::BlockArgListType IterateOp::getRegionIterArgs() {
2592b48ef8d8SPeiming Liu   return getRegion().getArguments().take_front(getNumRegionIterArgs());
2593e276cf08SPeiming Liu }
2594e276cf08SPeiming Liu 
2595e276cf08SPeiming Liu std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2596e276cf08SPeiming Liu   return cast<sparse_tensor::YieldOp>(
2597e276cf08SPeiming Liu              getRegion().getBlocks().front().getTerminator())
2598e276cf08SPeiming Liu       .getResultsMutable();
2599e276cf08SPeiming Liu }
2600e276cf08SPeiming Liu 
2601e276cf08SPeiming Liu std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2602e276cf08SPeiming Liu 
2603e276cf08SPeiming Liu OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2604e276cf08SPeiming Liu   return getInitArgs();
2605e276cf08SPeiming Liu }
2606e276cf08SPeiming Liu 
2607e276cf08SPeiming Liu void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2608e276cf08SPeiming Liu                                     SmallVectorImpl<RegionSuccessor> &regions) {
2609785a24f1SPeiming Liu   // Both the operation itself and the region may be branching into the body
2610785a24f1SPeiming Liu   // or back into the operation itself.
2611e276cf08SPeiming Liu   regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2612e276cf08SPeiming Liu   // It is possible for loop not to enter the body.
2613e276cf08SPeiming Liu   regions.push_back(RegionSuccessor(getResults()));
2614e276cf08SPeiming Liu }
2615e276cf08SPeiming Liu 
2616c4420257SPeiming Liu void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2617c4420257SPeiming Liu                         ValueRange iterSpaces, ValueRange initArgs,
2618c4420257SPeiming Liu                         unsigned numCases) {
2619c4420257SPeiming Liu   unsigned rank =
2620c4420257SPeiming Liu       cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2621c4420257SPeiming Liu   // All ones.
2622c4420257SPeiming Liu   I64BitSet set((1 << rank) - 1);
2623c4420257SPeiming Liu   // Generates all-zero case bits (they only serve as placeholders), which are
2624c4420257SPeiming Liu   // supposed to be overriden later. We need to preallocate all the regions as
2625c4420257SPeiming Liu   // mlir::Region cannot be dynamically added later after the operation is
2626c4420257SPeiming Liu   // created.
2627c4420257SPeiming Liu   SmallVector<int64_t> caseBits(numCases, 0);
2628c4420257SPeiming Liu   ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2629c4420257SPeiming Liu   return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2630c4420257SPeiming Liu                             initArgs, set, cases,
2631c4420257SPeiming Liu                             /*caseRegionsCount=*/numCases);
2632c4420257SPeiming Liu }
2633c4420257SPeiming Liu 
2634785a24f1SPeiming Liu ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
2635785a24f1SPeiming Liu 
2636785a24f1SPeiming Liu   SmallVector<Value> spaces;
2637785a24f1SPeiming Liu   // The block argument list of each regions, it is arranged in the order of
2638785a24f1SPeiming Liu   // ([used coordinate list], [loop iterations args], [sparse iterator list]).
2639785a24f1SPeiming Liu   SmallVector<OpAsmParser::Argument> blockArgs;
2640785a24f1SPeiming Liu   if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
2641785a24f1SPeiming Liu     return failure();
2642785a24f1SPeiming Liu 
2643785a24f1SPeiming Liu   result.addAttribute("operandSegmentSizes",
2644785a24f1SPeiming Liu                       parser.getBuilder().getDenseI32ArrayAttr(
2645785a24f1SPeiming Liu                           {static_cast<int32_t>(spaces.size()),
2646785a24f1SPeiming Liu                            static_cast<int32_t>(result.types.size())}));
2647785a24f1SPeiming Liu 
2648785a24f1SPeiming Liu   SmallVector<Attribute> cases;
2649785a24f1SPeiming Liu   while (succeeded(parser.parseOptionalKeyword("case"))) {
2650785a24f1SPeiming Liu     // Parse one region per case.
2651785a24f1SPeiming Liu     I64BitSet definedItSet;
2652785a24f1SPeiming Liu     SmallVector<OpAsmParser::Argument> definedIts;
2653785a24f1SPeiming Liu     if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
2654785a24f1SPeiming Liu                                  spaces.size(), OpAsmParser::Delimiter::None))
2655785a24f1SPeiming Liu       return failure();
2656785a24f1SPeiming Liu 
2657785a24f1SPeiming Liu     cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
2658785a24f1SPeiming Liu 
2659785a24f1SPeiming Liu     for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
2660785a24f1SPeiming Liu       // Resolve the iterator type based on the iteration space type.
2661785a24f1SPeiming Liu       auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
2662785a24f1SPeiming Liu       definedIts[i].type = spaceTp.getIteratorType();
2663785a24f1SPeiming Liu     }
2664785a24f1SPeiming Liu     definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2665785a24f1SPeiming Liu     Region *body = result.addRegion();
2666785a24f1SPeiming Liu     if (parser.parseRegion(*body, definedIts))
2667785a24f1SPeiming Liu       return failure();
2668785a24f1SPeiming Liu 
2669785a24f1SPeiming Liu     CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2670785a24f1SPeiming Liu   }
2671785a24f1SPeiming Liu 
2672785a24f1SPeiming Liu   result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
2673785a24f1SPeiming Liu 
2674785a24f1SPeiming Liu   // Parse the optional attribute list.
2675785a24f1SPeiming Liu   if (parser.parseOptionalAttrDict(result.attributes))
2676785a24f1SPeiming Liu     return failure();
2677785a24f1SPeiming Liu 
2678785a24f1SPeiming Liu   return success();
2679785a24f1SPeiming Liu }
2680785a24f1SPeiming Liu 
2681785a24f1SPeiming Liu void CoIterateOp::print(OpAsmPrinter &p) {
2682785a24f1SPeiming Liu   p << " (";
2683785a24f1SPeiming Liu   llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2684785a24f1SPeiming Liu   p << ")";
2685785a24f1SPeiming Liu 
2686785a24f1SPeiming Liu   if (!getCrdUsedLvls().empty()) {
2687785a24f1SPeiming Liu     p << " at(";
2688785a24f1SPeiming Liu     printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
2689785a24f1SPeiming Liu     p << ")";
2690785a24f1SPeiming Liu   }
2691785a24f1SPeiming Liu 
2692785a24f1SPeiming Liu   printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
2693785a24f1SPeiming Liu 
2694785a24f1SPeiming Liu   p << " : (" << getIterSpaces().getTypes() << ")";
2695785a24f1SPeiming Liu   if (!getInitArgs().empty())
2696785a24f1SPeiming Liu     p.printArrowTypeList(getInitArgs().getTypes());
2697785a24f1SPeiming Liu 
2698785a24f1SPeiming Liu   for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2699785a24f1SPeiming Liu     p.printNewline();
2700785a24f1SPeiming Liu     p << "case ";
2701785a24f1SPeiming Liu     printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
2702785a24f1SPeiming Liu                              getRegionDefinedSpace(idx));
2703785a24f1SPeiming Liu     p << " ";
2704785a24f1SPeiming Liu     p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
2705785a24f1SPeiming Liu                   /*printBlockTerminators=*/!getInitArgs().empty());
2706785a24f1SPeiming Liu   }
2707785a24f1SPeiming Liu }
2708785a24f1SPeiming Liu 
2709785a24f1SPeiming Liu ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2710785a24f1SPeiming Liu   return cast<sparse_tensor::YieldOp>(
2711785a24f1SPeiming Liu              getRegion(regionIdx).getBlocks().front().getTerminator())
2712785a24f1SPeiming Liu       .getResults();
2713785a24f1SPeiming Liu }
2714785a24f1SPeiming Liu 
2715785a24f1SPeiming Liu LogicalResult CoIterateOp::verifyRegions() {
2716785a24f1SPeiming Liu   for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2717c4420257SPeiming Liu     if (getNumRegionIterArgs() != getNumResults())
2718785a24f1SPeiming Liu       return emitOpError(
2719785a24f1SPeiming Liu           "mismatch in number of basic block args and defined values");
2720785a24f1SPeiming Liu 
2721785a24f1SPeiming Liu     auto initArgs = getInitArgs();
2722785a24f1SPeiming Liu     auto iterArgs = getRegionIterArgs(r);
2723785a24f1SPeiming Liu     auto yieldVals = getYieldedValues(r);
2724785a24f1SPeiming Liu     auto opResults = getResults();
2725785a24f1SPeiming Liu     if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2726785a24f1SPeiming Liu                           opResults.size()})) {
2727785a24f1SPeiming Liu       return emitOpError()
2728785a24f1SPeiming Liu              << "number mismatch between iter args and results on " << r
2729785a24f1SPeiming Liu              << "th region";
2730785a24f1SPeiming Liu     }
2731785a24f1SPeiming Liu 
2732785a24f1SPeiming Liu     for (auto [i, init, iter, yield, ret] :
2733785a24f1SPeiming Liu          llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2734785a24f1SPeiming Liu       if (init.getType() != ret.getType())
2735785a24f1SPeiming Liu         return emitOpError()
2736785a24f1SPeiming Liu                << "types mismatch between " << i
2737785a24f1SPeiming Liu                << "th iter operand and defined value on " << r << "th region";
2738785a24f1SPeiming Liu       if (iter.getType() != ret.getType())
2739785a24f1SPeiming Liu         return emitOpError() << "types mismatch between " << i
2740785a24f1SPeiming Liu                              << "th iter region arg and defined value on " << r
2741785a24f1SPeiming Liu                              << "th region";
2742785a24f1SPeiming Liu       if (yield.getType() != ret.getType())
2743785a24f1SPeiming Liu         return emitOpError()
2744785a24f1SPeiming Liu                << "types mismatch between " << i
2745785a24f1SPeiming Liu                << "th yield value and defined value on " << r << "th region";
2746785a24f1SPeiming Liu     }
2747785a24f1SPeiming Liu   }
2748785a24f1SPeiming Liu 
2749785a24f1SPeiming Liu   auto cases = getRegionDefinedSpaces();
2750785a24f1SPeiming Liu   llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2751785a24f1SPeiming Liu   if (set.size() != getNumRegions())
2752785a24f1SPeiming Liu     return emitOpError("contains duplicated cases.");
2753785a24f1SPeiming Liu 
2754785a24f1SPeiming Liu   return success();
2755785a24f1SPeiming Liu }
2756785a24f1SPeiming Liu 
2757f607102aSPeiming Liu SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2758f607102aSPeiming Liu   SmallVector<Region *> ret;
2759f607102aSPeiming Liu   I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2760f607102aSPeiming Liu   for (Region &r : getCaseRegions())
2761f607102aSPeiming Liu     if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2762f607102aSPeiming Liu       ret.push_back(&r);
2763f607102aSPeiming Liu 
2764f607102aSPeiming Liu   return ret;
2765f607102aSPeiming Liu }
2766f607102aSPeiming Liu 
2767e276cf08SPeiming Liu //===----------------------------------------------------------------------===//
2768e276cf08SPeiming Liu // Sparse Tensor Dialect Setups.
2769e276cf08SPeiming Liu //===----------------------------------------------------------------------===//
2770e276cf08SPeiming Liu 
2771f0f5fdf7SPeiming Liu /// Materialize a single constant operation from a given attribute value with
2772f0f5fdf7SPeiming Liu /// the desired resultant type.
2773f0f5fdf7SPeiming Liu Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2774f0f5fdf7SPeiming Liu                                                     Attribute value, Type type,
2775f0f5fdf7SPeiming Liu                                                     Location loc) {
2776f0f5fdf7SPeiming Liu   if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2777f0f5fdf7SPeiming Liu     return op;
2778f0f5fdf7SPeiming Liu   return nullptr;
2779f0f5fdf7SPeiming Liu }
2780f0f5fdf7SPeiming Liu 
2781c5a67e16SYinying Li namespace {
2782c5a67e16SYinying Li struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
2783c5a67e16SYinying Li   using OpAsmDialectInterface::OpAsmDialectInterface;
2784c5a67e16SYinying Li 
2785c5a67e16SYinying Li   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
2786a5757c5bSChristian Sigg     if (isa<SparseTensorEncodingAttr>(attr)) {
2787c5a67e16SYinying Li       os << "sparse";
2788c5a67e16SYinying Li       return AliasResult::OverridableAlias;
2789c5a67e16SYinying Li     }
2790c5a67e16SYinying Li     return AliasResult::NoAlias;
2791c5a67e16SYinying Li   }
2792c5a67e16SYinying Li };
2793c5a67e16SYinying Li } // namespace
2794c5a67e16SYinying Li 
/// Registers what the sparse tensor dialect provides:
///  - the asm interface that gives encoding attributes the `sparse` alias;
///  - the attributes, types and operations listed in the TableGen-generated
///    .inc files;
///  - a promise of BufferizableOpInterface for a set of ops.
void SparseTensorDialect::initialize() {
  addInterface<SparseTensorAsmDialectInterface>();
  // The class lists below are expanded from the generated .inc files; each
  // GET_*_LIST macro selects the list section of the corresponding file.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  // Only a declaration: the BufferizableOpInterface implementations for these
  // ops are not defined here and presumably get attached separately (e.g. by a
  // dialect extension) -- NOTE(review): confirm where they are registered.
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}
2814319072f4SAart Bik 
2815319072f4SAart Bik #define GET_OP_CLASSES
2816319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
28175517208dSAart Bik 
28185517208dSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
2819