1319072f4SAart Bik //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2319072f4SAart Bik // 3319072f4SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4319072f4SAart Bik // See https://llvm.org/LICENSE.txt for license information. 5319072f4SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6319072f4SAart Bik // 7319072f4SAart Bik //===----------------------------------------------------------------------===// 8319072f4SAart Bik 9239737edSMehdi Amini #include <utility> 10239737edSMehdi Amini 116b88c852SAart Bik #include "Detail/DimLvlMapParser.h" 126b88c852SAart Bik 13ee3ee131SAart Bik #include "mlir/Dialect/SparseTensor/IR/Enums.h" 14319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 15afe78db7SPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" 16f708a549Swren romano #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 175517208dSAart Bik 18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 19513cdb82SJustin Fargnoli #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 20e71eacc5SYinying Li #include "mlir/Dialect/Complex/IR/Complex.h" 21cb7bda2aSMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h" 22319072f4SAart Bik #include "mlir/IR/Builders.h" 230a292199SAart Bik #include "mlir/IR/DialectImplementation.h" 2465e7cd13SRiver Riddle #include "mlir/IR/Matchers.h" 25319072f4SAart Bik #include "mlir/IR/OpImplementation.h" 261f07853fSPeiming Liu #include "mlir/IR/PatternMatch.h" 27a43d79afSPeiming Liu #include "llvm/ADT/Bitset.h" 280a292199SAart Bik #include "llvm/ADT/TypeSwitch.h" 29de907138SPeiming Liu #include "llvm/Support/FormatVariadic.h" 30319072f4SAart Bik 3171cc0f1cSPeiming Liu #define GET_ATTRDEF_CLASSES 3271cc0f1cSPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 3371cc0f1cSPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc" 3471cc0f1cSPeiming Liu 35481bd5d4SPeiming Liu // Forward declarations, following custom print/parsing methods are referenced 36481bd5d4SPeiming Liu // by the generated code for SparseTensorTypes.td. 37481bd5d4SPeiming Liu static mlir::ParseResult parseLevelRange(mlir::AsmParser &, 38481bd5d4SPeiming Liu mlir::sparse_tensor::Level &, 39481bd5d4SPeiming Liu mlir::sparse_tensor::Level &); 40481bd5d4SPeiming Liu static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, 41481bd5d4SPeiming Liu mlir::sparse_tensor::Level); 42481bd5d4SPeiming Liu 4371cc0f1cSPeiming Liu #define GET_TYPEDEF_CLASSES 4471cc0f1cSPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" 4571cc0f1cSPeiming Liu 46319072f4SAart Bik using namespace mlir; 47319072f4SAart Bik using namespace mlir::sparse_tensor; 48319072f4SAart Bik 49aaf91645SPeiming Liu // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as 50aaf91645SPeiming Liu // well. 51aaf91645SPeiming Liu namespace mlir::sparse_tensor { 52aaf91645SPeiming Liu llvm::hash_code hash_value(LevelType lt) { 53aaf91645SPeiming Liu return llvm::hash_value(static_cast<uint64_t>(lt)); 54aaf91645SPeiming Liu } 55aaf91645SPeiming Liu } // namespace mlir::sparse_tensor 56aaf91645SPeiming Liu 570a292199SAart Bik //===----------------------------------------------------------------------===// 5845288085SAart Bik // Local Convenience Methods. 597a1077baSwren romano //===----------------------------------------------------------------------===// 607a1077baSwren romano 617a1077baSwren romano static constexpr bool acceptBitWidth(unsigned bitWidth) { 627a1077baSwren romano switch (bitWidth) { 637a1077baSwren romano case 0: 647a1077baSwren romano case 8: 657a1077baSwren romano case 16: 667a1077baSwren romano case 32: 677a1077baSwren romano case 64: 687a1077baSwren romano return true; 697a1077baSwren romano default: 707a1077baSwren romano return false; 717a1077baSwren romano } 727a1077baSwren romano } 737a1077baSwren romano 7462fa12adSPeiming Liu static SmallVector<Size> 7562fa12adSPeiming Liu getSparseFieldShape(const SparseTensorEncodingAttr enc, 7662fa12adSPeiming Liu std::optional<ArrayRef<int64_t>> dimShape) { 7762fa12adSPeiming Liu assert(enc); 7862fa12adSPeiming Liu // With only encoding, we can not determine the static shape for leading 7962fa12adSPeiming Liu // batch levels, we therefore return a dynamic shape memref instead. 8062fa12adSPeiming Liu SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic); 8162fa12adSPeiming Liu if (dimShape.has_value()) { 8262fa12adSPeiming Liu // If the actual tensor shape is provided, we can then refine the leading 8362fa12adSPeiming Liu // batch dimension. 8462fa12adSPeiming Liu SmallVector<int64_t> lvlShape = 8562fa12adSPeiming Liu enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl); 8662fa12adSPeiming Liu memrefShape.assign(lvlShape.begin(), 8762fa12adSPeiming Liu lvlShape.begin() + enc.getBatchLvlRank()); 8862fa12adSPeiming Liu } 8962fa12adSPeiming Liu // Another dynamic dimension to store the sparse level. 9062fa12adSPeiming Liu memrefShape.push_back(ShapedType::kDynamic); 9162fa12adSPeiming Liu return memrefShape; 9262fa12adSPeiming Liu } 9362fa12adSPeiming Liu 947a1077baSwren romano //===----------------------------------------------------------------------===// 957e83a1afSAart Bik // SparseTensorDialect StorageLayout. 96afe78db7SPeiming Liu //===----------------------------------------------------------------------===// 97afe78db7SPeiming Liu 98afe78db7SPeiming Liu static constexpr Level kInvalidLevel = -1u; 99afe78db7SPeiming Liu static constexpr Level kInvalidFieldIndex = -1u; 100afe78db7SPeiming Liu static constexpr FieldIndex kDataFieldStartingIdx = 0; 101afe78db7SPeiming Liu 102afe78db7SPeiming Liu void StorageLayout::foreachField( 103afe78db7SPeiming Liu llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level, 1041944c4f7SAart Bik LevelType)> 105afe78db7SPeiming Liu callback) const { 106afe78db7SPeiming Liu const auto lvlTypes = enc.getLvlTypes(); 107afe78db7SPeiming Liu const Level lvlRank = enc.getLvlRank(); 10883f3b1cbSYinying Li SmallVector<COOSegment> cooSegs = enc.getCOOSegments(); 109afe78db7SPeiming Liu FieldIndex fieldIdx = kDataFieldStartingIdx; 110f740366fSPeiming Liu 111f740366fSPeiming Liu ArrayRef cooSegsRef = cooSegs; 112afe78db7SPeiming Liu // Per-level storage. 113f740366fSPeiming Liu for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) { 1141dd387e1SAart Bik const auto lt = lvlTypes[l]; 1151dd387e1SAart Bik if (isWithPosLT(lt)) { 1161dd387e1SAart Bik if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt))) 1172e2011daSAart Bik return; 118de560888SPeiming Liu } 1191dd387e1SAart Bik if (isWithCrdLT(lt)) { 1201dd387e1SAart Bik if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt))) 1212e2011daSAart Bik return; 122afe78db7SPeiming Liu } 123f740366fSPeiming Liu if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) { 124f740366fSPeiming Liu if (!cooSegsRef.front().isSoA) { 125f740366fSPeiming Liu // AoS COO, all singletons are fused into one memrefs. Skips the entire 126f740366fSPeiming Liu // COO segement. 127f740366fSPeiming Liu l = cooSegsRef.front().lvlRange.second; 128f740366fSPeiming Liu } else { 129f740366fSPeiming Liu // SoA COO, each singleton level has one memref. 130f740366fSPeiming Liu l++; 131f740366fSPeiming Liu } 132f740366fSPeiming Liu // Expire handled COO segment. 133f740366fSPeiming Liu cooSegsRef = cooSegsRef.drop_front(); 134f740366fSPeiming Liu } else { 135f740366fSPeiming Liu // Non COO levels. 136f740366fSPeiming Liu l++; 137f740366fSPeiming Liu } 138afe78db7SPeiming Liu } 139afe78db7SPeiming Liu // The values array. 1402e2011daSAart Bik if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel, 141aaf91645SPeiming Liu LevelFormat::Undef))) 1422e2011daSAart Bik return; 143afe78db7SPeiming Liu // Put metadata at the end. 1442e2011daSAart Bik if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel, 145aaf91645SPeiming Liu LevelFormat::Undef))) 1462e2011daSAart Bik return; 147afe78db7SPeiming Liu } 148afe78db7SPeiming Liu 149afe78db7SPeiming Liu void sparse_tensor::foreachFieldAndTypeInSparseTensor( 150afe78db7SPeiming Liu SparseTensorType stt, 151afe78db7SPeiming Liu llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level, 1521944c4f7SAart Bik LevelType)> 153afe78db7SPeiming Liu callback) { 154afe78db7SPeiming Liu assert(stt.hasEncoding()); 155afe78db7SPeiming Liu 15662fa12adSPeiming Liu SmallVector<int64_t> memrefShape = 15762fa12adSPeiming Liu getSparseFieldShape(stt.getEncoding(), stt.getDimShape()); 1580d1f9576SPeiming Liu 159afe78db7SPeiming Liu const Type specType = StorageSpecifierType::get(stt.getEncoding()); 1600d1f9576SPeiming Liu // memref<[batch] x ? x pos> positions 16162fa12adSPeiming Liu const Type posMemType = MemRefType::get(memrefShape, stt.getPosType()); 1620d1f9576SPeiming Liu // memref<[batch] x ? x crd> coordinates 16362fa12adSPeiming Liu const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType()); 1640d1f9576SPeiming Liu // memref<[batch] x ? x eltType> values 16562fa12adSPeiming Liu const Type valMemType = MemRefType::get(memrefShape, stt.getElementType()); 166afe78db7SPeiming Liu 1671944c4f7SAart Bik StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType, 1681944c4f7SAart Bik callback](FieldIndex fieldIdx, 1691944c4f7SAart Bik SparseTensorFieldKind fieldKind, 1701944c4f7SAart Bik Level lvl, LevelType lt) -> bool { 171afe78db7SPeiming Liu switch (fieldKind) { 172afe78db7SPeiming Liu case SparseTensorFieldKind::StorageSpec: 1731dd387e1SAart Bik return callback(specType, fieldIdx, fieldKind, lvl, lt); 174afe78db7SPeiming Liu case SparseTensorFieldKind::PosMemRef: 1751dd387e1SAart Bik return callback(posMemType, fieldIdx, fieldKind, lvl, lt); 176afe78db7SPeiming Liu case SparseTensorFieldKind::CrdMemRef: 1771dd387e1SAart Bik return callback(crdMemType, fieldIdx, fieldKind, lvl, lt); 178afe78db7SPeiming Liu case SparseTensorFieldKind::ValMemRef: 1791dd387e1SAart Bik return callback(valMemType, fieldIdx, fieldKind, lvl, lt); 180afe78db7SPeiming Liu }; 181afe78db7SPeiming Liu llvm_unreachable("unrecognized field kind"); 182afe78db7SPeiming Liu }); 183afe78db7SPeiming Liu } 184afe78db7SPeiming Liu 185afe78db7SPeiming Liu unsigned StorageLayout::getNumFields() const { 186afe78db7SPeiming Liu unsigned numFields = 0; 187afe78db7SPeiming Liu foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level, 1881944c4f7SAart Bik LevelType) -> bool { 189afe78db7SPeiming Liu numFields++; 190afe78db7SPeiming Liu return true; 191afe78db7SPeiming Liu }); 192afe78db7SPeiming Liu return numFields; 193afe78db7SPeiming Liu } 194afe78db7SPeiming Liu 195afe78db7SPeiming Liu unsigned StorageLayout::getNumDataFields() const { 196afe78db7SPeiming Liu unsigned numFields = 0; // one value memref 197afe78db7SPeiming Liu foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level, 1981944c4f7SAart Bik LevelType) -> bool { 199afe78db7SPeiming Liu if (fidx >= kDataFieldStartingIdx) 200afe78db7SPeiming Liu numFields++; 201afe78db7SPeiming Liu return true; 202afe78db7SPeiming Liu }); 203afe78db7SPeiming Liu numFields -= 1; // the last field is StorageSpecifier 204afe78db7SPeiming Liu assert(numFields == getNumFields() - kDataFieldStartingIdx - 1); 205afe78db7SPeiming Liu return numFields; 206afe78db7SPeiming Liu } 207afe78db7SPeiming Liu 208afe78db7SPeiming Liu std::pair<FieldIndex, unsigned> 209afe78db7SPeiming Liu StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind, 210afe78db7SPeiming Liu std::optional<Level> lvl) const { 211afe78db7SPeiming Liu FieldIndex fieldIdx = kInvalidFieldIndex; 212afe78db7SPeiming Liu unsigned stride = 1; 213afe78db7SPeiming Liu if (kind == SparseTensorFieldKind::CrdMemRef) { 214afe78db7SPeiming Liu assert(lvl.has_value()); 21583f3b1cbSYinying Li const Level cooStart = enc.getAoSCOOStart(); 216afe78db7SPeiming Liu const Level lvlRank = enc.getLvlRank(); 217afe78db7SPeiming Liu if (lvl.value() >= cooStart && lvl.value() < lvlRank) { 218afe78db7SPeiming Liu lvl = cooStart; 219afe78db7SPeiming Liu stride = lvlRank - cooStart; 220afe78db7SPeiming Liu } 221afe78db7SPeiming Liu } 222afe78db7SPeiming Liu foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx, 223afe78db7SPeiming Liu SparseTensorFieldKind fKind, Level fLvl, 2241944c4f7SAart Bik LevelType lt) -> bool { 225afe78db7SPeiming Liu if ((lvl && fLvl == lvl.value() && kind == fKind) || 226afe78db7SPeiming Liu (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { 227afe78db7SPeiming Liu fieldIdx = fIdx; 228afe78db7SPeiming Liu // Returns false to break the iteration. 229afe78db7SPeiming Liu return false; 230afe78db7SPeiming Liu } 231afe78db7SPeiming Liu return true; 232afe78db7SPeiming Liu }); 233afe78db7SPeiming Liu assert(fieldIdx != kInvalidFieldIndex); 234afe78db7SPeiming Liu return std::pair<FieldIndex, unsigned>(fieldIdx, stride); 235afe78db7SPeiming Liu } 236afe78db7SPeiming Liu 237afe78db7SPeiming Liu //===----------------------------------------------------------------------===// 2387e83a1afSAart Bik // SparseTensorDialect Attribute Methods. 2390a292199SAart Bik //===----------------------------------------------------------------------===// 2400a292199SAart Bik 2417a1077baSwren romano std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) { 2427a1077baSwren romano return isDynamic(v) ? std::nullopt 2437a1077baSwren romano : std::make_optional(static_cast<uint64_t>(v)); 2440a292199SAart Bik } 2457a1077baSwren romano 2467a1077baSwren romano std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const { 2477a1077baSwren romano return getStatic(getOffset()); 2487a1077baSwren romano } 2497a1077baSwren romano 2507a1077baSwren romano std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const { 2517a1077baSwren romano return getStatic(getStride()); 2527a1077baSwren romano } 2537a1077baSwren romano 2547a1077baSwren romano std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const { 2557a1077baSwren romano return getStatic(getSize()); 2567a1077baSwren romano } 2577a1077baSwren romano 2587a1077baSwren romano bool SparseTensorDimSliceAttr::isCompletelyDynamic() const { 2597a1077baSwren romano return isDynamic(getOffset()) && isDynamic(getStride()) && 2607a1077baSwren romano isDynamic(getSize()); 2617a1077baSwren romano } 2627a1077baSwren romano 2637a1077baSwren romano std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) { 2647a1077baSwren romano return isDynamic(v) ? "?" : std::to_string(v); 2650a292199SAart Bik } 2660a292199SAart Bik 26711a4f5bdSAart Bik void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const { 268cad46467Swren romano assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr"); 26911a4f5bdSAart Bik os << '('; 27011a4f5bdSAart Bik os << getStaticString(getOffset()); 27111a4f5bdSAart Bik os << ", "; 27211a4f5bdSAart Bik os << getStaticString(getSize()); 27311a4f5bdSAart Bik os << ", "; 27411a4f5bdSAart Bik os << getStaticString(getStride()); 27511a4f5bdSAart Bik os << ')'; 27611a4f5bdSAart Bik } 27711a4f5bdSAart Bik 278885a1f43SPeiming Liu void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { 27911a4f5bdSAart Bik print(printer.getStream()); 280885a1f43SPeiming Liu } 281885a1f43SPeiming Liu 282885a1f43SPeiming Liu static ParseResult parseOptionalStaticSlice(int64_t &result, 283885a1f43SPeiming Liu AsmParser &parser) { 284885a1f43SPeiming Liu auto parseResult = parser.parseOptionalInteger(result); 285885a1f43SPeiming Liu if (parseResult.has_value()) { 286885a1f43SPeiming Liu if (parseResult.value().succeeded() && result < 0) { 287885a1f43SPeiming Liu parser.emitError( 288885a1f43SPeiming Liu parser.getCurrentLocation(), 289885a1f43SPeiming Liu "expect positive value or ? for slice offset/size/stride"); 290885a1f43SPeiming Liu return failure(); 291885a1f43SPeiming Liu } 292885a1f43SPeiming Liu return parseResult.value(); 293885a1f43SPeiming Liu } 294885a1f43SPeiming Liu 295885a1f43SPeiming Liu // Else, and '?' which represented dynamic slice 296885a1f43SPeiming Liu result = SparseTensorDimSliceAttr::kDynamic; 297885a1f43SPeiming Liu return parser.parseQuestion(); 298885a1f43SPeiming Liu } 299885a1f43SPeiming Liu 300885a1f43SPeiming Liu Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { 3017a1077baSwren romano int64_t offset = kDynamic, size = kDynamic, stride = kDynamic; 302885a1f43SPeiming Liu 303885a1f43SPeiming Liu if (failed(parser.parseLParen()) || 304885a1f43SPeiming Liu failed(parseOptionalStaticSlice(offset, parser)) || 305885a1f43SPeiming Liu failed(parser.parseComma()) || 306885a1f43SPeiming Liu failed(parseOptionalStaticSlice(size, parser)) || 307885a1f43SPeiming Liu failed(parser.parseComma()) || 308885a1f43SPeiming Liu failed(parseOptionalStaticSlice(stride, parser)) || 309885a1f43SPeiming Liu failed(parser.parseRParen())) 310885a1f43SPeiming Liu return {}; 311885a1f43SPeiming Liu 312885a1f43SPeiming Liu return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(), 313885a1f43SPeiming Liu offset, size, stride); 314885a1f43SPeiming Liu } 315885a1f43SPeiming Liu 316885a1f43SPeiming Liu LogicalResult 317885a1f43SPeiming Liu SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError, 318885a1f43SPeiming Liu int64_t offset, int64_t size, int64_t stride) { 3197a1077baSwren romano if (!isDynamic(offset) && offset < 0) 3207a1077baSwren romano return emitError() << "expect non-negative value or ? for slice offset"; 3217a1077baSwren romano if (!isDynamic(size) && size <= 0) 3227a1077baSwren romano return emitError() << "expect positive value or ? for slice size"; 3237a1077baSwren romano if (!isDynamic(stride) && stride <= 0) 3247a1077baSwren romano return emitError() << "expect positive value or ? for slice stride"; 325885a1f43SPeiming Liu return success(); 326885a1f43SPeiming Liu } 327885a1f43SPeiming Liu 32876647fceSwren romano SparseTensorEncodingAttr 32976647fceSwren romano SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const { 33076647fceSwren romano assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 331a10d67f9SYinying Li return SparseTensorEncodingAttr::get( 332a10d67f9SYinying Li getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(), 333a10d67f9SYinying Li getCrdWidth(), getExplicitVal(), getImplicitVal()); 33476647fceSwren romano } 33576647fceSwren romano 33676647fceSwren romano SparseTensorEncodingAttr 33776647fceSwren romano SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const { 33876647fceSwren romano return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap()); 33976647fceSwren romano } 34076647fceSwren romano 34176647fceSwren romano SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { 34276647fceSwren romano return withDimToLvl(AffineMap()); 34376647fceSwren romano } 34476647fceSwren romano 34576647fceSwren romano SparseTensorEncodingAttr 34676647fceSwren romano SparseTensorEncodingAttr::withBitWidths(unsigned posWidth, 34776647fceSwren romano unsigned crdWidth) const { 34876647fceSwren romano assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 349a10d67f9SYinying Li return SparseTensorEncodingAttr::get( 350a10d67f9SYinying Li getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth, 351a10d67f9SYinying Li crdWidth, getExplicitVal(), getImplicitVal()); 35296fef4dcSwren romano } 35396fef4dcSwren romano 35485dbb3fcSPeiming Liu SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { 35576647fceSwren romano return withBitWidths(0, 0); 35685dbb3fcSPeiming Liu } 35785dbb3fcSPeiming Liu 358a10d67f9SYinying Li SparseTensorEncodingAttr 359a10d67f9SYinying Li SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const { 360a10d67f9SYinying Li assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 361a10d67f9SYinying Li return SparseTensorEncodingAttr::get( 362a10d67f9SYinying Li getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 363a10d67f9SYinying Li getCrdWidth(), explicitVal, getImplicitVal()); 364a10d67f9SYinying Li } 365a10d67f9SYinying Li 366a10d67f9SYinying Li SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const { 367a10d67f9SYinying Li return withExplicitVal(Attribute()); 368a10d67f9SYinying Li } 369a10d67f9SYinying Li 370a10d67f9SYinying Li SparseTensorEncodingAttr 371a10d67f9SYinying Li SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const { 372a10d67f9SYinying Li assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 373a10d67f9SYinying Li return SparseTensorEncodingAttr::get( 374a10d67f9SYinying Li getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 375a10d67f9SYinying Li getCrdWidth(), getExplicitVal(), implicitVal); 376a10d67f9SYinying Li } 377a10d67f9SYinying Li 378a10d67f9SYinying Li SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const { 379a10d67f9SYinying Li return withImplicitVal(Attribute()); 380a10d67f9SYinying Li } 381a10d67f9SYinying Li 382af2bec7cSwren romano SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( 383af2bec7cSwren romano ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 384a10d67f9SYinying Li return SparseTensorEncodingAttr::get( 385a10d67f9SYinying Li getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 386a10d67f9SYinying Li getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices); 387af2bec7cSwren romano } 388af2bec7cSwren romano 389af2bec7cSwren romano SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { 390af2bec7cSwren romano return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{}); 391af2bec7cSwren romano } 392af2bec7cSwren romano 3930d1f9576SPeiming Liu uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { 3940d1f9576SPeiming Liu ArrayRef<LevelType> lvlTypes = getLvlTypes(); 3950d1f9576SPeiming Liu auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 3960d1f9576SPeiming Liu return std::distance(lastBatch, lvlTypes.rend()); 3970d1f9576SPeiming Liu } 3980d1f9576SPeiming Liu 399a3672addSPeiming Liu bool SparseTensorEncodingAttr::isAllDense() const { 4001dd387e1SAart Bik return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT); 401f708a549Swren romano } 402f708a549Swren romano 403f708a549Swren romano bool SparseTensorEncodingAttr::isAllOrdered() const { 4041dd387e1SAart Bik return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT); 405a3672addSPeiming Liu } 406a3672addSPeiming Liu 40762fa12adSPeiming Liu Type SparseTensorEncodingAttr::getCrdElemType() const { 40862fa12adSPeiming Liu if (!getImpl()) 40962fa12adSPeiming Liu return nullptr; 41062fa12adSPeiming Liu if (getCrdWidth()) 41162fa12adSPeiming Liu return IntegerType::get(getContext(), getCrdWidth()); 41262fa12adSPeiming Liu return IndexType::get(getContext()); 41362fa12adSPeiming Liu } 41462fa12adSPeiming Liu 41562fa12adSPeiming Liu Type SparseTensorEncodingAttr::getPosElemType() const { 41662fa12adSPeiming Liu if (!getImpl()) 41762fa12adSPeiming Liu return nullptr; 41862fa12adSPeiming Liu if (getPosWidth()) 41962fa12adSPeiming Liu return IntegerType::get(getContext(), getPosWidth()); 42062fa12adSPeiming Liu return IndexType::get(getContext()); 42162fa12adSPeiming Liu } 42262fa12adSPeiming Liu 42362fa12adSPeiming Liu MemRefType SparseTensorEncodingAttr::getCrdMemRefType( 42462fa12adSPeiming Liu std::optional<ArrayRef<int64_t>> dimShape) const { 42562fa12adSPeiming Liu SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 42662fa12adSPeiming Liu return MemRefType::get(shape, getCrdElemType()); 42762fa12adSPeiming Liu } 42862fa12adSPeiming Liu 42962fa12adSPeiming Liu MemRefType SparseTensorEncodingAttr::getPosMemRefType( 43062fa12adSPeiming Liu std::optional<ArrayRef<int64_t>> dimShape) const { 43162fa12adSPeiming Liu SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 43262fa12adSPeiming Liu return MemRefType::get(shape, getPosElemType()); 43362fa12adSPeiming Liu } 43462fa12adSPeiming Liu 43576647fceSwren romano bool SparseTensorEncodingAttr::isIdentity() const { 43676647fceSwren romano return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity(); 43776647fceSwren romano } 43876647fceSwren romano 43976647fceSwren romano bool SparseTensorEncodingAttr::isPermutation() const { 44076647fceSwren romano return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation(); 44176647fceSwren romano } 44276647fceSwren romano 44376647fceSwren romano Dimension SparseTensorEncodingAttr::getDimRank() const { 44476647fceSwren romano assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 44576647fceSwren romano const auto dimToLvl = getDimToLvl(); 44676647fceSwren romano return dimToLvl ? dimToLvl.getNumDims() : getLvlRank(); 447f708a549Swren romano } 448f708a549Swren romano 449f708a549Swren romano Level SparseTensorEncodingAttr::getLvlRank() const { 450f708a549Swren romano assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 451a0615d02Swren romano return getLvlTypes().size(); 452f708a549Swren romano } 453f708a549Swren romano 4541944c4f7SAart Bik LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { 455f708a549Swren romano if (!getImpl()) 45652b69aa3SPeiming Liu return LevelFormat::Batch; 457f708a549Swren romano assert(l < getLvlRank() && "Level is out of bounds"); 458a0615d02Swren romano return getLvlTypes()[l]; 459a3672addSPeiming Liu } 460a3672addSPeiming Liu 461867e1964Swren romano bool SparseTensorEncodingAttr::isSlice() const { 462867e1964Swren romano assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 463867e1964Swren romano return !getDimSlices().empty(); 464867e1964Swren romano } 465867e1964Swren romano 466867e1964Swren romano SparseTensorDimSliceAttr 467867e1964Swren romano SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { 468867e1964Swren romano assert(isSlice() && "Is not a slice"); 469867e1964Swren romano const auto dimSlices = getDimSlices(); 470867e1964Swren romano assert(dim < dimSlices.size() && "Dimension is out of bounds"); 471867e1964Swren romano return dimSlices[dim]; 472867e1964Swren romano } 473867e1964Swren romano 474885a1f43SPeiming Liu std::optional<uint64_t> 475f708a549Swren romano SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { 476867e1964Swren romano return getDimSlice(dim).getStaticOffset(); 477885a1f43SPeiming Liu } 478885a1f43SPeiming Liu 479885a1f43SPeiming Liu std::optional<uint64_t> 480f708a549Swren romano SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { 481867e1964Swren romano return getDimSlice(dim).getStaticStride(); 482885a1f43SPeiming Liu } 483885a1f43SPeiming Liu 484885a1f43SPeiming Liu std::optional<uint64_t> 485f708a549Swren romano SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { 4864e2f1521SPeiming Liu return getStaticDimSliceOffset(toDim(*this, lvl)); 487885a1f43SPeiming Liu } 488885a1f43SPeiming Liu 489885a1f43SPeiming Liu std::optional<uint64_t> 490f708a549Swren romano SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { 4914e2f1521SPeiming Liu return getStaticDimSliceStride(toDim(*this, lvl)); 492885a1f43SPeiming Liu } 493885a1f43SPeiming Liu 494d808d922SPeiming Liu SmallVector<int64_t> 495e10dc60aSAart Bik SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape, 496d808d922SPeiming Liu CrdTransDirectionKind dir) const { 497d808d922SPeiming Liu if (isIdentity()) 498d808d922SPeiming Liu return SmallVector<int64_t>(srcShape); 499d808d922SPeiming Liu 500d808d922SPeiming Liu SmallVector<int64_t> ret; 501d808d922SPeiming Liu unsigned rank = 502d808d922SPeiming Liu dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank(); 503d808d922SPeiming Liu ret.reserve(rank); 504d808d922SPeiming Liu 505d808d922SPeiming Liu if (isPermutation()) { 506d808d922SPeiming Liu for (unsigned r = 0; r < rank; r++) { 5074e2f1521SPeiming Liu unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) 5084e2f1521SPeiming Liu : toLvl(*this, r); 509d808d922SPeiming Liu ret.push_back(srcShape[trans]); 510d808d922SPeiming Liu } 511d808d922SPeiming Liu return ret; 512d808d922SPeiming Liu } 513d808d922SPeiming Liu 514d808d922SPeiming Liu // Handle non-permutation maps. 515d808d922SPeiming Liu AffineMap transMap = 516d808d922SPeiming Liu dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim(); 517d808d922SPeiming Liu 518d808d922SPeiming Liu SmallVector<AffineExpr> dimRep; 519d808d922SPeiming Liu dimRep.reserve(srcShape.size()); 520d808d922SPeiming Liu for (int64_t sz : srcShape) { 521d808d922SPeiming Liu if (!ShapedType::isDynamic(sz)) { 522d808d922SPeiming Liu // Push back the max coordinate for the given dimension/level size. 523d808d922SPeiming Liu dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); 524d808d922SPeiming Liu } else { 525d808d922SPeiming Liu // A dynamic size, use a AffineDimExpr to symbolize the value. 526d808d922SPeiming Liu dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext())); 527d808d922SPeiming Liu } 528d808d922SPeiming Liu }; 529d808d922SPeiming Liu 530d808d922SPeiming Liu for (AffineExpr exp : transMap.getResults()) { 531d808d922SPeiming Liu // Do constant propagation on the affine map. 532d808d922SPeiming Liu AffineExpr evalExp = 533d808d922SPeiming Liu simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0); 5341609f1c2Slong.chen // use llvm namespace here to avoid ambiguity 5351609f1c2Slong.chen if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) { 536d808d922SPeiming Liu ret.push_back(c.getValue() + 1); 537ef222988SPeiming Liu } else { 5381609f1c2Slong.chen if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp); 539ef222988SPeiming Liu mod && mod.getKind() == AffineExprKind::Mod) { 540ef222988SPeiming Liu // We can still infer a static bound for expressions in form 541ef222988SPeiming Liu // "d % constant" since d % constant \in [0, constant). 5421609f1c2Slong.chen if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) { 543ef222988SPeiming Liu ret.push_back(bound.getValue()); 544ef222988SPeiming Liu continue; 545ef222988SPeiming Liu } 546ef222988SPeiming Liu } 547d808d922SPeiming Liu ret.push_back(ShapedType::kDynamic); 548d808d922SPeiming Liu } 549ef222988SPeiming Liu } 550ef222988SPeiming Liu assert(ret.size() == rank); 551d808d922SPeiming Liu return ret; 552d808d922SPeiming Liu } 553d808d922SPeiming Liu 5546456e0bbSPeiming Liu ValueRange 5556456e0bbSPeiming Liu SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, 5566456e0bbSPeiming Liu ValueRange crds, 5576456e0bbSPeiming Liu CrdTransDirectionKind dir) const { 5586456e0bbSPeiming Liu if (!getImpl()) 5596456e0bbSPeiming Liu return crds; 5606456e0bbSPeiming Liu 5616456e0bbSPeiming Liu SmallVector<Type> retType( 5626456e0bbSPeiming Liu dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(), 5636456e0bbSPeiming Liu builder.getIndexType()); 5646456e0bbSPeiming Liu auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this); 5656456e0bbSPeiming Liu return transOp.getOutCrds(); 5666456e0bbSPeiming Liu } 5676456e0bbSPeiming Liu 568f97e72aaSMehdi Amini Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 5692e2011daSAart Bik // Open "<{" part. 5702e2011daSAart Bik if (failed(parser.parseLess())) 5712e2011daSAart Bik return {}; 5722e2011daSAart Bik if (failed(parser.parseLBrace())) 5732e2011daSAart Bik return {}; 574885a1f43SPeiming Liu 5750a292199SAart Bik // Process the data from the parsed dictionary value into struct-like data. 5761944c4f7SAart Bik SmallVector<LevelType> lvlTypes; 577540d5e0cSwren romano SmallVector<SparseTensorDimSliceAttr> dimSlices; 57876647fceSwren romano AffineMap dimToLvl = {}; 5797b9fb1c2SYinying Li AffineMap lvlToDim = {}; 58084cd51bbSwren romano unsigned posWidth = 0; 58184cd51bbSwren romano unsigned crdWidth = 0; 582a10d67f9SYinying Li Attribute explicitVal; 583a10d67f9SYinying Li Attribute implicitVal; 584885a1f43SPeiming Liu StringRef attrName; 585a10d67f9SYinying Li SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth", 586a10d67f9SYinying Li "explicitVal", "implicitVal"}; 587885a1f43SPeiming Liu while (succeeded(parser.parseOptionalKeyword(&attrName))) { 588bb44a6b7SAart Bik // Detect admissible keyword. 589bb44a6b7SAart Bik auto *it = find(keys, attrName); 590bb44a6b7SAart Bik if (it == keys.end()) { 591885a1f43SPeiming Liu parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; 592885a1f43SPeiming Liu return {}; 593885a1f43SPeiming Liu } 594bb44a6b7SAart Bik unsigned keyWordIndex = it - keys.begin(); 595885a1f43SPeiming Liu // Consume the `=` after keys 5962e2011daSAart Bik if (failed(parser.parseEqual())) 5972e2011daSAart Bik return {}; 598bb44a6b7SAart Bik // Dispatch on keyword. 599bb44a6b7SAart Bik switch (keyWordIndex) { 600fb5047f5SYinying Li case 0: { // map 6016b88c852SAart Bik ir_detail::DimLvlMapParser cParser(parser); 6026b88c852SAart Bik auto res = cParser.parseDimLvlMap(); 6032e2011daSAart Bik if (failed(res)) 6042e2011daSAart Bik return {}; 605cad46467Swren romano const auto &dlm = *res; 606cad46467Swren romano 607cad46467Swren romano const Level lvlRank = dlm.getLvlRank(); 608cad46467Swren romano for (Level lvl = 0; lvl < lvlRank; lvl++) 609cad46467Swren romano lvlTypes.push_back(dlm.getLvlType(lvl)); 610cad46467Swren romano 611cad46467Swren romano const Dimension dimRank = dlm.getDimRank(); 612cad46467Swren romano for (Dimension dim = 0; dim < dimRank; dim++) 613cad46467Swren romano dimSlices.push_back(dlm.getDimSlice(dim)); 614cad46467Swren romano // NOTE: the old syntax requires an all-or-nothing approach to 615cad46467Swren romano // `dimSlices`; therefore, if any slice actually exists then we need 616cad46467Swren romano // to convert null-DSA into default/nop DSA. 617cad46467Swren romano const auto isDefined = [](SparseTensorDimSliceAttr slice) { 618cad46467Swren romano return static_cast<bool>(slice.getImpl()); 619cad46467Swren romano }; 620cad46467Swren romano if (llvm::any_of(dimSlices, isDefined)) { 621cad46467Swren romano const auto defaultSlice = 622cad46467Swren romano SparseTensorDimSliceAttr::get(parser.getContext()); 623cad46467Swren romano for (Dimension dim = 0; dim < dimRank; dim++) 624cad46467Swren romano if (!isDefined(dimSlices[dim])) 625cad46467Swren romano dimSlices[dim] = defaultSlice; 626cad46467Swren romano } else { 627cad46467Swren romano dimSlices.clear(); 628cad46467Swren romano } 629cad46467Swren romano 630cad46467Swren romano dimToLvl = dlm.getDimToLvlMap(parser.getContext()); 6317b9fb1c2SYinying Li lvlToDim = dlm.getLvlToDimMap(parser.getContext()); 632bb44a6b7SAart Bik break; 6330a292199SAart Bik } 634fb5047f5SYinying Li case 1: { // posWidth 635fb5047f5SYinying Li Attribute attr; 6362e2011daSAart Bik if (failed(parser.parseAttribute(attr))) 6372e2011daSAart Bik return {}; 638fb5047f5SYinying Li auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 6392e2011daSAart Bik if (!intAttr) { 6402e2011daSAart Bik parser.emitError(parser.getNameLoc(), 6412e2011daSAart Bik "expected an integral position bitwidth"); 6422e2011daSAart Bik return {}; 6432e2011daSAart Bik } 644fb5047f5SYinying Li posWidth = intAttr.getInt(); 645fb5047f5SYinying Li break; 646fb5047f5SYinying Li } 647fb5047f5SYinying Li case 2: { // crdWidth 648fb5047f5SYinying Li Attribute attr; 6492e2011daSAart Bik if (failed(parser.parseAttribute(attr))) 6502e2011daSAart Bik return {}; 651fb5047f5SYinying Li auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 6522e2011daSAart Bik if (!intAttr) { 6532e2011daSAart Bik parser.emitError(parser.getNameLoc(), 6542e2011daSAart Bik "expected an integral index bitwidth"); 6552e2011daSAart Bik return {}; 6562e2011daSAart Bik } 657fb5047f5SYinying Li crdWidth = intAttr.getInt(); 658fb5047f5SYinying Li break; 659fb5047f5SYinying Li } 660a10d67f9SYinying Li case 3: { // explicitVal 661a10d67f9SYinying Li Attribute attr; 662a10d67f9SYinying Li if (failed(parser.parseAttribute(attr))) 663a10d67f9SYinying Li return {}; 664a10d67f9SYinying Li if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { 665a10d67f9SYinying Li explicitVal = result; 666a10d67f9SYinying Li } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { 667a10d67f9SYinying Li explicitVal = result; 668e71eacc5SYinying Li } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { 669e71eacc5SYinying Li explicitVal = result; 670a10d67f9SYinying Li } else { 671a10d67f9SYinying Li parser.emitError(parser.getNameLoc(), 672a10d67f9SYinying Li "expected a numeric value for explicitVal"); 673a10d67f9SYinying Li return {}; 674a10d67f9SYinying Li } 675a10d67f9SYinying Li break; 676a10d67f9SYinying Li } 677a10d67f9SYinying Li case 4: { // implicitVal 678a10d67f9SYinying Li Attribute attr; 679a10d67f9SYinying Li if (failed(parser.parseAttribute(attr))) 680a10d67f9SYinying Li return {}; 681a10d67f9SYinying Li if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { 682a10d67f9SYinying Li implicitVal = result; 683a10d67f9SYinying Li } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { 684a10d67f9SYinying Li implicitVal = result; 685e71eacc5SYinying Li } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { 686e71eacc5SYinying Li implicitVal = result; 687a10d67f9SYinying Li } else { 688a10d67f9SYinying Li parser.emitError(parser.getNameLoc(), 689a10d67f9SYinying Li "expected a numeric value for implicitVal"); 690a10d67f9SYinying Li return {}; 691a10d67f9SYinying Li } 692a10d67f9SYinying Li break; 693a10d67f9SYinying Li } 694bb44a6b7SAart Bik } // switch 695bb44a6b7SAart Bik // Only last item can omit the comma. 696885a1f43SPeiming Liu if (parser.parseOptionalComma().failed()) 697885a1f43SPeiming Liu break; 6980a292199SAart Bik } 699885a1f43SPeiming Liu 7002e2011daSAart Bik // Close "}>" part. 7012e2011daSAart Bik if (failed(parser.parseRBrace())) 7022e2011daSAart Bik return {}; 7032e2011daSAart Bik if (failed(parser.parseGreater())) 7042e2011daSAart Bik return {}; 705885a1f43SPeiming Liu 7060a292199SAart Bik // Construct struct-like storage for attribute. 7077b9fb1c2SYinying Li if (!lvlToDim || lvlToDim.isEmpty()) { 7087b9fb1c2SYinying Li lvlToDim = inferLvlToDim(dimToLvl, parser.getContext()); 7097b9fb1c2SYinying Li } 710c48e9087SAart Bik return parser.getChecked<SparseTensorEncodingAttr>( 711836411b9SAart Bik parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, 712a10d67f9SYinying Li explicitVal, implicitVal, dimSlices); 7130a292199SAart Bik } 7140a292199SAart Bik 715f97e72aaSMehdi Amini void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 7166280e231SYinying Li auto map = static_cast<AffineMap>(getDimToLvl()); 7176280e231SYinying Li // Empty affine map indicates identity map 7186280e231SYinying Li if (!map) 7196280e231SYinying Li map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext()); 7206280e231SYinying Li printer << "<{ map = "; 7216280e231SYinying Li printSymbols(map, printer); 7226280e231SYinying Li printer << '('; 7236280e231SYinying Li printDimensions(map, printer, getDimSlices()); 7246280e231SYinying Li printer << ") -> ("; 7256280e231SYinying Li printLevels(map, printer, getLvlTypes()); 7266280e231SYinying Li printer << ')'; 727e3d64ccfSAart Bik // Print remaining members only for non-default values. 72884cd51bbSwren romano if (getPosWidth()) 72984cd51bbSwren romano printer << ", posWidth = " << getPosWidth(); 73084cd51bbSwren romano if (getCrdWidth()) 73184cd51bbSwren romano printer << ", crdWidth = " << getCrdWidth(); 732a10d67f9SYinying Li if (getExplicitVal()) { 733a10d67f9SYinying Li printer << ", explicitVal = " << getExplicitVal(); 734a10d67f9SYinying Li } 735a10d67f9SYinying Li if (getImplicitVal()) 736a10d67f9SYinying Li printer << ", implicitVal = " << getImplicitVal(); 7376280e231SYinying Li printer << " }>"; 738885a1f43SPeiming Liu } 739885a1f43SPeiming Liu 7406280e231SYinying Li void SparseTensorEncodingAttr::printSymbols(AffineMap &map, 7416280e231SYinying Li AsmPrinter &printer) const { 7426280e231SYinying Li if (map.getNumSymbols() == 0) 7436280e231SYinying Li return; 7446280e231SYinying Li printer << '['; 7456280e231SYinying Li for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++) 7466280e231SYinying Li printer << 's' << i << ", "; 7476280e231SYinying Li if (map.getNumSymbols() >= 1) 7486280e231SYinying Li printer << 's' << map.getNumSymbols() - 1; 7496280e231SYinying Li printer << ']'; 7506280e231SYinying Li } 7516280e231SYinying Li 7526280e231SYinying Li void SparseTensorEncodingAttr::printDimensions( 7536280e231SYinying Li AffineMap &map, AsmPrinter &printer, 7546280e231SYinying Li ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 7556280e231SYinying Li if (!dimSlices.empty()) { 7566280e231SYinying Li for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 7576280e231SYinying Li printer << 'd' << i << " : " << dimSlices[i] << ", "; 7586280e231SYinying Li if (map.getNumDims() >= 1) { 7596280e231SYinying Li printer << 'd' << map.getNumDims() - 1 << " : " 7606280e231SYinying Li << dimSlices[map.getNumDims() - 1]; 7616280e231SYinying Li } 7626280e231SYinying Li } else { 7636280e231SYinying Li for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 7646280e231SYinying Li printer << 'd' << i << ", "; 7656280e231SYinying Li if (map.getNumDims() >= 1) 7666280e231SYinying Li printer << 'd' << map.getNumDims() - 1; 7676280e231SYinying Li } 7686280e231SYinying Li } 7696280e231SYinying Li 7701944c4f7SAart Bik void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, 7711944c4f7SAart Bik ArrayRef<LevelType> lvlTypes) const { 7726280e231SYinying Li for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { 7736280e231SYinying Li map.getResult(i).print(printer.getStream()); 774ced1fac8SYinying Li printer << " : " << toMLIRString(lvlTypes[i]) << ", "; 7756280e231SYinying Li } 7766280e231SYinying Li if (map.getNumResults() >= 1) { 7776280e231SYinying Li auto lastIndex = map.getNumResults() - 1; 7786280e231SYinying Li map.getResult(lastIndex).print(printer.getStream()); 779ced1fac8SYinying Li printer << " : " << toMLIRString(lvlTypes[lastIndex]); 7806280e231SYinying Li } 7810a292199SAart Bik } 7820a292199SAart Bik 7831944c4f7SAart Bik LogicalResult SparseTensorEncodingAttr::verify( 7841944c4f7SAart Bik function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes, 7851944c4f7SAart Bik AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth, 786a10d67f9SYinying Li unsigned crdWidth, Attribute explicitVal, Attribute implicitVal, 787a10d67f9SYinying Li ArrayRef<SparseTensorDimSliceAttr> dimSlices) { 78884cd51bbSwren romano if (!acceptBitWidth(posWidth)) 78984cd51bbSwren romano return emitError() << "unexpected position bitwidth: " << posWidth; 79084cd51bbSwren romano if (!acceptBitWidth(crdWidth)) 79184cd51bbSwren romano return emitError() << "unexpected coordinate bitwidth: " << crdWidth; 79213af97a7SPeiming Liu 79313af97a7SPeiming Liu // Verify every COO segment. 79413af97a7SPeiming Liu auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); 79513af97a7SPeiming Liu while (it != lvlTypes.end()) { 7967b9fb1c2SYinying Li if (it == lvlTypes.begin() || 79713af97a7SPeiming Liu !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) 7987b9fb1c2SYinying Li return emitError() << "expected compressed or loose_compressed level " 7997b9fb1c2SYinying Li "before singleton level"; 80013af97a7SPeiming Liu 80113af97a7SPeiming Liu auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT); 80213af97a7SPeiming Liu if (!std::all_of(it, curCOOEnd, 8031944c4f7SAart Bik [](LevelType i) { return isSingletonLT(i); })) 8047b9fb1c2SYinying Li return emitError() << "expected all singleton lvlTypes " 8057b9fb1c2SYinying Li "following a singleton level"; 806088c7ce4SPeiming Liu // We can potentially support mixed SoA/AoS singleton levels. 80713af97a7SPeiming Liu if (!std::all_of(it, curCOOEnd, [it](LevelType i) { 808088c7ce4SPeiming Liu return it->isa<LevelPropNonDefault::SoA>() == 809088c7ce4SPeiming Liu i.isa<LevelPropNonDefault::SoA>(); 810088c7ce4SPeiming Liu })) { 811088c7ce4SPeiming Liu return emitError() << "expected all singleton lvlTypes stored in the " 812088c7ce4SPeiming Liu "same memory layout (SoA vs AoS)."; 8137b9fb1c2SYinying Li } 81413af97a7SPeiming Liu it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT); 815088c7ce4SPeiming Liu } 816088c7ce4SPeiming Liu 81756d58295SPeiming Liu auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 81856d58295SPeiming Liu if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT)) 81956d58295SPeiming Liu return emitError() << "Batch lvlType can only be leading levels."; 82056d58295SPeiming Liu 821088c7ce4SPeiming Liu // SoA property can only be applied on singleton level. 822088c7ce4SPeiming Liu auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) { 823088c7ce4SPeiming Liu return lt.isa<LevelPropNonDefault::SoA>(); 824088c7ce4SPeiming Liu }); 825088c7ce4SPeiming Liu if (llvm::any_of(soaLvls, [](LevelType lt) { 826088c7ce4SPeiming Liu return !lt.isa<LevelFormat::Singleton>(); 827088c7ce4SPeiming Liu })) { 828088c7ce4SPeiming Liu return emitError() << "SoA is only applicable to singleton lvlTypes."; 829088c7ce4SPeiming Liu } 830088c7ce4SPeiming Liu 8312a6b521bSYinying Li // TODO: audit formats that actually are supported by backend. 8322a6b521bSYinying Li if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT); 8332a6b521bSYinying Li it != std::end(lvlTypes)) { 8342a6b521bSYinying Li if (it != lvlTypes.end() - 1) 8352a6b521bSYinying Li return emitError() << "expected n_out_of_m to be the last level type"; 8362a6b521bSYinying Li if (!std::all_of(lvlTypes.begin(), it, 8372a6b521bSYinying Li [](LevelType i) { return isDenseLT(i); })) 8382a6b521bSYinying Li return emitError() << "expected all dense lvlTypes " 8392a6b521bSYinying Li "before a n_out_of_m level"; 8402a6b521bSYinying Li if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) { 8412a6b521bSYinying Li if (!isBlockSparsity(dimToLvl)) { 8422a6b521bSYinying Li return emitError() 8432a6b521bSYinying Li << "expected 1xm block structure for n_out_of_m level"; 8442a6b521bSYinying Li } 8452a6b521bSYinying Li auto sizes = getBlockSize(dimToLvl); 8462a6b521bSYinying Li unsigned coefficient = 0; 8472a6b521bSYinying Li for (const auto &elem : sizes) { 8482a6b521bSYinying Li if (elem != 0) { 8492a6b521bSYinying Li if (elem != coefficient && coefficient != 0) { 8502a6b521bSYinying Li return emitError() << "expected only one blocked level " 8512a6b521bSYinying Li "with the same coefficients"; 8522a6b521bSYinying Li } 8532a6b521bSYinying Li coefficient = elem; 8542a6b521bSYinying Li } 8552a6b521bSYinying Li } 8562a6b521bSYinying Li if (coefficient != getM(*it)) { 8572a6b521bSYinying Li return emitError() << "expected coeffiencts of Affine expressions " 8582a6b521bSYinying Li "to be equal to m of n_out_of_m level"; 8592a6b521bSYinying Li } 8602a6b521bSYinying Li } 8612a6b521bSYinying Li } 862f708a549Swren romano // Before we can check that the level-rank is consistent/coherent 863f708a549Swren romano // across all fields, we need to define it. The source-of-truth for 864f708a549Swren romano // the `getLvlRank` method is the length of the level-types array, 865f708a549Swren romano // since it must always be provided and have full rank; therefore we 866f708a549Swren romano // use that same source-of-truth here. 867a0615d02Swren romano const Level lvlRank = lvlTypes.size(); 868f708a549Swren romano if (lvlRank == 0) 869d9a2f89bSwren romano return emitError() << "expected a non-empty array for lvlTypes"; 87076647fceSwren romano // We save `dimRank` here because we'll also need it to verify `dimSlices`. 87176647fceSwren romano const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank; 87276647fceSwren romano if (dimToLvl) { 87376647fceSwren romano if (dimToLvl.getNumResults() != lvlRank) 8740a292199SAart Bik return emitError() 87576647fceSwren romano << "level-rank mismatch between dimToLvl and lvlTypes: " 87676647fceSwren romano << dimToLvl.getNumResults() << " != " << lvlRank; 8777b9fb1c2SYinying Li auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext()); 8787b9fb1c2SYinying Li // Symbols can't be inferred but are acceptable. 8797b9fb1c2SYinying Li if (!inferRes && dimToLvl.getNumSymbols() == 0) 8807b9fb1c2SYinying Li return emitError() << "failed to infer lvlToDim from dimToLvl"; 8817b9fb1c2SYinying Li if (lvlToDim && (inferRes != lvlToDim)) 8827b9fb1c2SYinying Li return emitError() << "expected lvlToDim to be an inverse of dimToLvl"; 88376647fceSwren romano if (dimRank > lvlRank) 88476647fceSwren romano return emitError() << "unexpected dimToLvl mapping from " << dimRank 88576647fceSwren romano << " to " << lvlRank; 8860a292199SAart Bik } 88776647fceSwren romano if (!dimSlices.empty()) { 88876647fceSwren romano if (dimSlices.size() != dimRank) 889d9a2f89bSwren romano return emitError() 89076647fceSwren romano << "dimension-rank mismatch between dimSlices and dimToLvl: " 89176647fceSwren romano << dimSlices.size() << " != " << dimRank; 89276647fceSwren romano // Compiler support for `dimSlices` currently requires that the two 89376647fceSwren romano // ranks agree. (However, it does allow `dimToLvl` to be a permutation.) 89476647fceSwren romano if (dimRank != lvlRank) 89576647fceSwren romano return emitError() 89676647fceSwren romano << "dimSlices expected dimension-rank to match level-rank: " 89776647fceSwren romano << dimRank << " != " << lvlRank; 898885a1f43SPeiming Liu } 8990a292199SAart Bik return success(); 9000a292199SAart Bik } 9010a292199SAart Bik 9020a292199SAart Bik LogicalResult SparseTensorEncodingAttr::verifyEncoding( 90322212ca7SAart Bik ArrayRef<Size> dimShape, Type elementType, 9040a292199SAart Bik function_ref<InFlightDiagnostic()> emitError) const { 905f708a549Swren romano // Check structural integrity. In particular, this ensures that the 906f708a549Swren romano // level-rank is coherent across all the fields. 9072e2011daSAart Bik if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(), 908a10d67f9SYinying Li getPosWidth(), getCrdWidth(), getExplicitVal(), 909a10d67f9SYinying Li getImplicitVal(), getDimSlices()))) 9102e2011daSAart Bik return failure(); 911f708a549Swren romano // Check integrity with tensor type specifics. In particular, we 912f708a549Swren romano // need only check that the dimension-rank of the tensor agrees with 913f708a549Swren romano // the dimension-rank of the encoding. 914f708a549Swren romano const Dimension dimRank = dimShape.size(); 915f708a549Swren romano if (dimRank == 0) 9164aa9b398SAart Bik return emitError() << "expected non-scalar sparse tensor"; 91776647fceSwren romano if (getDimRank() != dimRank) 91876647fceSwren romano return emitError() 91976647fceSwren romano << "dimension-rank mismatch between encoding and tensor shape: " 92076647fceSwren romano << getDimRank() << " != " << dimRank; 92183f3b1cbSYinying Li if (auto expVal = getExplicitVal()) { 92283f3b1cbSYinying Li Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType(); 92383f3b1cbSYinying Li if (attrType != elementType) { 92483f3b1cbSYinying Li return emitError() << "explicit value type mismatch between encoding and " 92583f3b1cbSYinying Li << "tensor element type: " << attrType 92683f3b1cbSYinying Li << " != " << elementType; 92783f3b1cbSYinying Li } 92883f3b1cbSYinying Li } 92983f3b1cbSYinying Li if (auto impVal = getImplicitVal()) { 93083f3b1cbSYinying Li Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType(); 93183f3b1cbSYinying Li if (attrType != elementType) { 93283f3b1cbSYinying Li return emitError() << "implicit value type mismatch between encoding and " 93383f3b1cbSYinying Li << "tensor element type: " << attrType 93483f3b1cbSYinying Li << " != " << elementType; 93583f3b1cbSYinying Li } 93683f3b1cbSYinying Li // Currently, we only support zero as the implicit value. 93783f3b1cbSYinying Li auto impFVal = llvm::dyn_cast<FloatAttr>(impVal); 93883f3b1cbSYinying Li auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal); 93983f3b1cbSYinying Li auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal); 94083f3b1cbSYinying Li if ((impFVal && impFVal.getValue().isNonZero()) || 94183f3b1cbSYinying Li (impIntVal && !impIntVal.getValue().isZero()) || 94283f3b1cbSYinying Li (impComplexVal && (impComplexVal.getImag().isNonZero() || 94383f3b1cbSYinying Li impComplexVal.getReal().isNonZero()))) { 94483f3b1cbSYinying Li return emitError() << "implicit value must be zero"; 94583f3b1cbSYinying Li } 94683f3b1cbSYinying Li } 9470a292199SAart Bik return success(); 9480a292199SAart Bik } 9490a292199SAart Bik 95083f3b1cbSYinying Li Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const { 951f740366fSPeiming Liu SmallVector<COOSegment> coo = getCOOSegments(); 9525248a987SPeiming Liu assert(coo.size() == 1 || coo.empty()); 9535248a987SPeiming Liu if (!coo.empty() && coo.front().isAoS()) { 954f740366fSPeiming Liu return coo.front().lvlRange.first; 955f740366fSPeiming Liu } 95683f3b1cbSYinying Li return getLvlRank(); 9575b729503SAart Bik } 9585b729503SAart Bik 959f740366fSPeiming Liu SmallVector<COOSegment> 96083f3b1cbSYinying Li mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const { 961f740366fSPeiming Liu SmallVector<COOSegment> ret; 96283f3b1cbSYinying Li if (getLvlRank() <= 1) 963f740366fSPeiming Liu return ret; 964f740366fSPeiming Liu 965f740366fSPeiming Liu ArrayRef<LevelType> lts = getLvlTypes(); 966f740366fSPeiming Liu Level l = 0; 96783f3b1cbSYinying Li while (l < getLvlRank()) { 968f740366fSPeiming Liu auto lt = lts[l]; 969f740366fSPeiming Liu if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) { 970f740366fSPeiming Liu auto cur = lts.begin() + l; 971f740366fSPeiming Liu auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) { 972f740366fSPeiming Liu return !lt.isa<LevelFormat::Singleton>(); 973f740366fSPeiming Liu }); 974f740366fSPeiming Liu unsigned cooLen = std::distance(cur, end); 975f740366fSPeiming Liu if (cooLen > 1) { 976f740366fSPeiming Liu // To support mixed SoA/AoS COO, we should break the segment when the 977f740366fSPeiming Liu // storage scheme changes, for now we faithfully assume that all 978f740366fSPeiming Liu // consecutive singleton levels have the same storage format as verified 979f740366fSPeiming Liu // STEA. 980f740366fSPeiming Liu ret.push_back(COOSegment{std::make_pair(l, l + cooLen), 981f740366fSPeiming Liu lts[l + 1].isa<LevelPropNonDefault::SoA>()}); 982f740366fSPeiming Liu } 983f740366fSPeiming Liu l += cooLen; 984f740366fSPeiming Liu } else { 985f740366fSPeiming Liu l++; 986f740366fSPeiming Liu } 987f740366fSPeiming Liu } 988f740366fSPeiming Liu return ret; 989f740366fSPeiming Liu } 990f740366fSPeiming Liu 99183f3b1cbSYinying Li //===----------------------------------------------------------------------===// 99283f3b1cbSYinying Li // SparseTensorType Methods. 99383f3b1cbSYinying Li //===----------------------------------------------------------------------===// 99483f3b1cbSYinying Li 99583f3b1cbSYinying Li bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, 99683f3b1cbSYinying Li bool isUnique) const { 99783f3b1cbSYinying Li if (!hasEncoding()) 99883f3b1cbSYinying Li return false; 99983f3b1cbSYinying Li if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl)) 100083f3b1cbSYinying Li return false; 100183f3b1cbSYinying Li for (Level l = startLvl + 1; l < lvlRank; ++l) 100283f3b1cbSYinying Li if (!isSingletonLvl(l)) 100383f3b1cbSYinying Li return false; 100483f3b1cbSYinying Li // If isUnique is true, then make sure that the last level is unique, 100583f3b1cbSYinying Li // that is, when lvlRank == 1, the only compressed level is unique, 100683f3b1cbSYinying Li // and when lvlRank > 1, the last singleton is unique. 100783f3b1cbSYinying Li return !isUnique || isUniqueLvl(lvlRank - 1); 100883f3b1cbSYinying Li } 100983f3b1cbSYinying Li 101045288085SAart Bik RankedTensorType 101145288085SAart Bik mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { 101245288085SAart Bik SmallVector<LevelType> lvlTypes; 101345288085SAart Bik lvlTypes.reserve(lvlRank); 101445288085SAart Bik // A non-unique compressed level at beginning (unless this is 101545288085SAart Bik // also the last level, then it is unique). 101645288085SAart Bik lvlTypes.push_back( 101745288085SAart Bik *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); 101845288085SAart Bik if (lvlRank > 1) { 101945288085SAart Bik // Followed by n-2 non-unique singleton levels. 102045288085SAart Bik std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, 102145288085SAart Bik *buildLevelType(LevelFormat::Singleton, ordered, false)); 102245288085SAart Bik // Ends by a unique singleton level. 102345288085SAart Bik lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true)); 102445288085SAart Bik } 1025a10d67f9SYinying Li auto enc = SparseTensorEncodingAttr::get( 1026a10d67f9SYinying Li getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(), 1027a10d67f9SYinying Li getCrdWidth(), getExplicitVal(), getImplicitVal()); 102845288085SAart Bik return RankedTensorType::get(getDimShape(), getElementType(), enc); 102945288085SAart Bik } 103045288085SAart Bik 103145288085SAart Bik //===----------------------------------------------------------------------===// 103245288085SAart Bik // Convenience Methods. 10334d068619SAart Bik //===----------------------------------------------------------------------===// 10344d068619SAart Bik 103596a23911SAart Bik SparseTensorEncodingAttr 103696a23911SAart Bik mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 1037c1fa60b4STres Popp if (auto ttp = llvm::dyn_cast<RankedTensorType>(type)) 1038c1fa60b4STres Popp return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding()); 1039c1fa60b4STres Popp if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type)) 104071cc0f1cSPeiming Liu return mdtp.getEncoding(); 104196a23911SAart Bik return nullptr; 104296a23911SAart Bik } 104396a23911SAart Bik 1044d4088e7dSYinying Li AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl, 1045d4088e7dSYinying Li MLIRContext *context) { 1046d4088e7dSYinying Li auto map = static_cast<AffineMap>(dimToLvl); 1047d4088e7dSYinying Li AffineMap lvlToDim; 1048d4088e7dSYinying Li // Return an empty lvlToDim when inference is not successful. 1049d4088e7dSYinying Li if (!map || map.getNumSymbols() != 0) { 1050d4088e7dSYinying Li lvlToDim = AffineMap(); 1051d4088e7dSYinying Li } else if (map.isPermutation()) { 1052d4088e7dSYinying Li lvlToDim = inversePermutation(map); 10537b9fb1c2SYinying Li } else if (isBlockSparsity(map)) { 1054d4088e7dSYinying Li lvlToDim = inverseBlockSparsity(map, context); 1055d4088e7dSYinying Li } 1056d4088e7dSYinying Li return lvlToDim; 1057d4088e7dSYinying Li } 1058d4088e7dSYinying Li 1059d4088e7dSYinying Li AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl, 1060d4088e7dSYinying Li MLIRContext *context) { 1061d4088e7dSYinying Li SmallVector<AffineExpr> lvlExprs; 1062d4088e7dSYinying Li auto numLvls = dimToLvl.getNumResults(); 1063d4088e7dSYinying Li lvlExprs.reserve(numLvls); 1064d4088e7dSYinying Li // lvlExprComponents stores information of the floordiv and mod operations 1065d4088e7dSYinying Li // applied to the same dimension, so as to build the lvlToDim map. 1066d4088e7dSYinying Li std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents; 1067d4088e7dSYinying Li for (unsigned i = 0, n = numLvls; i < n; i++) { 1068d4088e7dSYinying Li auto result = dimToLvl.getResult(i); 10691609f1c2Slong.chen if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1070d4088e7dSYinying Li if (result.getKind() == AffineExprKind::FloorDiv) { 1071d4088e7dSYinying Li // Position of the dimension in dimToLvl. 10721609f1c2Slong.chen auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 1073d4088e7dSYinying Li assert(lvlExprComponents.find(pos) == lvlExprComponents.end() && 1074d4088e7dSYinying Li "expected only one floordiv for each dimension"); 1075d4088e7dSYinying Li SmallVector<AffineExpr, 3> components; 1076d4088e7dSYinying Li // Level variable for floordiv. 1077d4088e7dSYinying Li components.push_back(getAffineDimExpr(i, context)); 1078d4088e7dSYinying Li // Multiplier. 1079d4088e7dSYinying Li components.push_back(binOp.getRHS()); 1080d4088e7dSYinying Li // Map key is the position of the dimension. 1081d4088e7dSYinying Li lvlExprComponents[pos] = components; 1082d4088e7dSYinying Li } else if (result.getKind() == AffineExprKind::Mod) { 10831609f1c2Slong.chen auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 1084d4088e7dSYinying Li assert(lvlExprComponents.find(pos) != lvlExprComponents.end() && 1085d4088e7dSYinying Li "expected floordiv before mod"); 1086d4088e7dSYinying Li // Add level variable for mod to the same vector 1087d4088e7dSYinying Li // of the corresponding floordiv. 1088d4088e7dSYinying Li lvlExprComponents[pos].push_back(getAffineDimExpr(i, context)); 1089d4088e7dSYinying Li } else { 1090d4088e7dSYinying Li assert(false && "expected floordiv or mod"); 1091d4088e7dSYinying Li } 1092d4088e7dSYinying Li } else { 1093d4088e7dSYinying Li lvlExprs.push_back(getAffineDimExpr(i, context)); 1094d4088e7dSYinying Li } 1095d4088e7dSYinying Li } 1096d4088e7dSYinying Li // Build lvlExprs from lvlExprComponents. 1097d4088e7dSYinying Li // For example, for il = i floordiv 2 and ii = i mod 2, the components 1098d4088e7dSYinying Li // would be [il, 2, ii]. It could be used to build the AffineExpr 1099d4088e7dSYinying Li // i = il * 2 + ii in lvlToDim. 1100d4088e7dSYinying Li for (auto &components : lvlExprComponents) { 1101d4088e7dSYinying Li assert(components.second.size() == 3 && 1102d4088e7dSYinying Li "expected 3 components to build lvlExprs"); 1103d4088e7dSYinying Li auto mulOp = getAffineBinaryOpExpr( 1104d4088e7dSYinying Li AffineExprKind::Mul, components.second[0], components.second[1]); 1105d4088e7dSYinying Li auto addOp = 1106d4088e7dSYinying Li getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]); 1107d4088e7dSYinying Li lvlExprs.push_back(addOp); 1108d4088e7dSYinying Li } 1109d4088e7dSYinying Li return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context); 1110d4088e7dSYinying Li } 1111d4088e7dSYinying Li 11127b9fb1c2SYinying Li SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { 11137b9fb1c2SYinying Li assert(isBlockSparsity(dimToLvl) && 11147b9fb1c2SYinying Li "expected dimToLvl to be block sparsity for calling getBlockSize"); 11157b9fb1c2SYinying Li SmallVector<unsigned> blockSize; 11167b9fb1c2SYinying Li for (auto result : dimToLvl.getResults()) { 11171609f1c2Slong.chen if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 11187b9fb1c2SYinying Li if (result.getKind() == AffineExprKind::Mod) { 11197b9fb1c2SYinying Li blockSize.push_back( 11201609f1c2Slong.chen dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue()); 11217b9fb1c2SYinying Li } 11227b9fb1c2SYinying Li } else { 11237b9fb1c2SYinying Li blockSize.push_back(0); 11247b9fb1c2SYinying Li } 11257b9fb1c2SYinying Li } 11267b9fb1c2SYinying Li return blockSize; 11277b9fb1c2SYinying Li } 11287b9fb1c2SYinying Li 11297b9fb1c2SYinying Li bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { 11307b9fb1c2SYinying Li if (!dimToLvl) 11317b9fb1c2SYinying Li return false; 11327b9fb1c2SYinying Li std::map<unsigned, int64_t> coeffientMap; 113331b72b07SYinying Li bool hasBlock = false; 11347b9fb1c2SYinying Li for (auto result : dimToLvl.getResults()) { 11351609f1c2Slong.chen if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 11363d3e46ccSAart Bik // Check for "dim op const". 11373d3e46ccSAart Bik auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS()); 11383d3e46ccSAart Bik auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS()); 113931b72b07SYinying Li if (!dimOp || !conOp || conOp.getValue() <= 0) 11403d3e46ccSAart Bik return false; 11413d3e46ccSAart Bik // Inspect "dim / const" or "dim % const". 11423d3e46ccSAart Bik auto pos = dimOp.getPosition(); 11433d3e46ccSAart Bik if (binOp.getKind() == AffineExprKind::FloorDiv) { 11447b9fb1c2SYinying Li // Expect only one floordiv for each dimension. 1145*af6e1881SKazu Hirata auto [it, inserted] = coeffientMap.try_emplace(pos); 1146*af6e1881SKazu Hirata if (!inserted) 11477b9fb1c2SYinying Li return false; 11483d3e46ccSAart Bik // Record coefficient of the floordiv. 1149*af6e1881SKazu Hirata it->second = conOp.getValue(); 11503d3e46ccSAart Bik } else if (binOp.getKind() == AffineExprKind::Mod) { 11517b9fb1c2SYinying Li // Expect floordiv before mod. 1152*af6e1881SKazu Hirata auto it = coeffientMap.find(pos); 1153*af6e1881SKazu Hirata if (it == coeffientMap.end()) 11547b9fb1c2SYinying Li return false; 11557b9fb1c2SYinying Li // Expect mod to have the same coefficient as floordiv. 1156*af6e1881SKazu Hirata if (conOp.getValue() != it->second) 11577b9fb1c2SYinying Li return false; 115831b72b07SYinying Li hasBlock = true; 115931b72b07SYinying Li } else { 116031b72b07SYinying Li return false; 116131b72b07SYinying Li } 116231b72b07SYinying Li } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) { 116331b72b07SYinying Li auto pos = dimOp.getPosition(); 116431b72b07SYinying Li // Expect dim to be unset. 11652077fb80SKazu Hirata if (!coeffientMap.try_emplace(pos, 0).second) 116631b72b07SYinying Li return false; 11677b9fb1c2SYinying Li } else { 11687b9fb1c2SYinying Li return false; 11697b9fb1c2SYinying Li } 11707b9fb1c2SYinying Li } 117131b72b07SYinying Li return hasBlock; 11727b9fb1c2SYinying Li } 11737b9fb1c2SYinying Li 117406a65ce5SPeiming Liu bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) { 117506a65ce5SPeiming Liu auto hasNonIdentityMap = [](Value v) { 117606a65ce5SPeiming Liu auto stt = tryGetSparseTensorType(v); 117706a65ce5SPeiming Liu return stt && !stt->isIdentity(); 117806a65ce5SPeiming Liu }; 117906a65ce5SPeiming Liu 118006a65ce5SPeiming Liu return llvm::any_of(op->getOperands(), hasNonIdentityMap) || 118106a65ce5SPeiming Liu llvm::any_of(op->getResults(), hasNonIdentityMap); 118206a65ce5SPeiming Liu } 118306a65ce5SPeiming Liu 11844e2f1521SPeiming Liu Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { 1185f231821eSAart Bik if (enc) { 11864e2f1521SPeiming Liu assert(enc.isPermutation() && "Non permutation map not supported"); 11874e2f1521SPeiming Liu if (const auto dimToLvl = enc.getDimToLvl()) 118876647fceSwren romano return dimToLvl.getDimPosition(l); 1189f231821eSAart Bik } 1190f708a549Swren romano return l; 1191f231821eSAart Bik } 1192f231821eSAart Bik 11934e2f1521SPeiming Liu Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { 1194f231821eSAart Bik if (enc) { 11954e2f1521SPeiming Liu assert(enc.isPermutation() && "Non permutation map not supported"); 11964e2f1521SPeiming Liu if (const auto lvlToDim = enc.getLvlToDim()) 11974e2f1521SPeiming Liu return lvlToDim.getDimPosition(d); 1198f231821eSAart Bik } 1199f231821eSAart Bik return d; 1200f231821eSAart Bik } 1201f231821eSAart Bik 1202083ddffeSPeiming Liu /// We normalized sparse tensor encoding attribute by always using 12031dd387e1SAart Bik /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well 1204083ddffeSPeiming Liu /// as other variants) lead to the same storage specifier type, and stripping 12056db397a8SPeiming Liu /// irrelevant fields that do not alter the sparse tensor memory layout. 1206083ddffeSPeiming Liu static SparseTensorEncodingAttr 1207083ddffeSPeiming Liu getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { 12081944c4f7SAart Bik SmallVector<LevelType> lts; 12091dd387e1SAart Bik for (auto lt : enc.getLvlTypes()) 12105248a987SPeiming Liu lts.push_back(lt.stripStorageIrrelevantProperties()); 1211083ddffeSPeiming Liu 1212083ddffeSPeiming Liu return SparseTensorEncodingAttr::get( 12131dd387e1SAart Bik enc.getContext(), lts, 121476647fceSwren romano AffineMap(), // dimToLvl (irrelevant to storage specifier) 1215836411b9SAart Bik AffineMap(), // lvlToDim (irrelevant to storage specifier) 121684cd51bbSwren romano // Always use `index` for memSize and lvlSize instead of reusing 12176db397a8SPeiming Liu // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA 12186db397a8SPeiming Liu // value for different bitwidth, it also avoids casting between index and 12196db397a8SPeiming Liu // integer (returned by DimOp) 1220a10d67f9SYinying Li 0, 0, 1221a10d67f9SYinying Li Attribute(), // explicitVal (irrelevant to storage specifier) 1222a10d67f9SYinying Li Attribute(), // implicitVal (irrelevant to storage specifier) 1223a10d67f9SYinying Li enc.getDimSlices()); 1224083ddffeSPeiming Liu } 1225083ddffeSPeiming Liu 1226083ddffeSPeiming Liu StorageSpecifierType 1227083ddffeSPeiming Liu StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) { 1228083ddffeSPeiming Liu return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding)); 1229083ddffeSPeiming Liu } 1230083ddffeSPeiming Liu 12317359a6b7SMatthias Springer StorageSpecifierType 12327359a6b7SMatthias Springer StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError, 12337359a6b7SMatthias Springer MLIRContext *ctx, 12347359a6b7SMatthias Springer SparseTensorEncodingAttr encoding) { 12357359a6b7SMatthias Springer return Base::getChecked(emitError, ctx, 12367359a6b7SMatthias Springer getNormalizedEncodingForSpecifier(encoding)); 12377359a6b7SMatthias Springer } 12387359a6b7SMatthias Springer 123971cc0f1cSPeiming Liu //===----------------------------------------------------------------------===// 124071cc0f1cSPeiming Liu // SparseTensorDialect Operations. 124196a23911SAart Bik //===----------------------------------------------------------------------===// 124296a23911SAart Bik 124384cd51bbSwren romano static LogicalResult lvlIsInBounds(Level lvl, Value tensor) { 124484cd51bbSwren romano return success(lvl < getSparseTensorType(tensor).getLvlRank()); 124596a23911SAart Bik } 124696a23911SAart Bik 124784cd51bbSwren romano static LogicalResult isMatchingWidth(Value mem, unsigned width) { 124884cd51bbSwren romano const Type etp = getMemRefType(mem).getElementType(); 1249743fbcb7Swren romano return success(width == 0 ? etp.isIndex() : etp.isInteger(width)); 125096a23911SAart Bik } 125196a23911SAart Bik 125222426110SRamkumar Ramachandra static LogicalResult verifySparsifierGetterSetter( 125384cd51bbSwren romano StorageSpecifierKind mdKind, std::optional<Level> lvl, 125422426110SRamkumar Ramachandra TypedValue<StorageSpecifierType> md, Operation *op) { 1255f708a549Swren romano if (mdKind == StorageSpecifierKind::ValMemSize && lvl) { 125671cc0f1cSPeiming Liu return op->emitError( 1257f708a549Swren romano "redundant level argument for querying value memory size"); 125871cc0f1cSPeiming Liu } 125971cc0f1cSPeiming Liu 1260f708a549Swren romano const auto enc = md.getType().getEncoding(); 1261f708a549Swren romano const Level lvlRank = enc.getLvlRank(); 126271cc0f1cSPeiming Liu 12636db397a8SPeiming Liu if (mdKind == StorageSpecifierKind::DimOffset || 12646db397a8SPeiming Liu mdKind == StorageSpecifierKind::DimStride) 12656db397a8SPeiming Liu if (!enc.isSlice()) 12666db397a8SPeiming Liu return op->emitError("requested slice data on non-slice tensor"); 12678237cac6SPeiming Liu 126871cc0f1cSPeiming Liu if (mdKind != StorageSpecifierKind::ValMemSize) { 1269f708a549Swren romano if (!lvl) 1270f708a549Swren romano return op->emitError("missing level argument"); 127171cc0f1cSPeiming Liu 127284cd51bbSwren romano const Level l = lvl.value(); 1273f708a549Swren romano if (l >= lvlRank) 127484cd51bbSwren romano return op->emitError("requested level is out of bounds"); 127571cc0f1cSPeiming Liu 127684cd51bbSwren romano if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l)) 127771cc0f1cSPeiming Liu return op->emitError( 127884cd51bbSwren romano "requested position memory size on a singleton level"); 127971cc0f1cSPeiming Liu } 128071cc0f1cSPeiming Liu return success(); 128171cc0f1cSPeiming Liu } 128271cc0f1cSPeiming Liu 1283de560888SPeiming Liu static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) { 1284de560888SPeiming Liu switch (kind) { 1285de560888SPeiming Liu case SparseTensorFieldKind::CrdMemRef: 1286de560888SPeiming Liu return stt.getCrdType(); 1287de560888SPeiming Liu case SparseTensorFieldKind::PosMemRef: 1288de560888SPeiming Liu return stt.getPosType(); 1289de560888SPeiming Liu case SparseTensorFieldKind::ValMemRef: 1290de560888SPeiming Liu return stt.getElementType(); 1291de560888SPeiming Liu case SparseTensorFieldKind::StorageSpec: 1292de560888SPeiming Liu return nullptr; 1293de560888SPeiming Liu } 1294de560888SPeiming Liu llvm_unreachable("Unrecognizable FieldKind"); 12957864d736SPeiming Liu } 12967864d736SPeiming Liu 1297de560888SPeiming Liu static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, 1298de560888SPeiming Liu SparseTensorType stt, 1299de560888SPeiming Liu RankedTensorType valTp, 1300de560888SPeiming Liu TypeRange lvlTps) { 1301de560888SPeiming Liu if (requiresStaticShape && !stt.hasStaticDimShape()) 1302de560888SPeiming Liu return op->emitError("the sparse-tensor must have static shape"); 1303de560888SPeiming Liu if (!stt.hasEncoding()) 1304de560888SPeiming Liu return op->emitError("the sparse-tensor must have an encoding attribute"); 1305de560888SPeiming Liu 1306de560888SPeiming Liu // Verifies the trailing COO. 13075248a987SPeiming Liu Level cooStartLvl = stt.getAoSCOOStart(); 1308de560888SPeiming Liu if (cooStartLvl < stt.getLvlRank()) { 1309de560888SPeiming Liu // We only supports trailing COO for now, must be the last input. 131068f58812STres Popp auto cooTp = llvm::cast<ShapedType>(lvlTps.back()); 1311de560888SPeiming Liu // The coordinates should be in shape of <? x rank> 1312de560888SPeiming Liu unsigned expCOORank = stt.getLvlRank() - cooStartLvl; 1313de560888SPeiming Liu if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) { 131477f8297cSMatthias Springer return op->emitError("input/output trailing COO level-ranks don't match"); 1315de560888SPeiming Liu } 1316de560888SPeiming Liu } 1317de560888SPeiming Liu 1318de560888SPeiming Liu // Verifies that all types match. 1319de560888SPeiming Liu StorageLayout layout(stt.getEncoding()); 1320de560888SPeiming Liu if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref 1321de560888SPeiming Liu return op->emitError("inconsistent number of fields between input/output"); 1322de560888SPeiming Liu 1323de560888SPeiming Liu unsigned idx = 0; 1324de560888SPeiming Liu bool misMatch = false; 1325de560888SPeiming Liu layout.foreachField([&idx, &misMatch, stt, valTp, 1326de560888SPeiming Liu lvlTps](FieldIndex fid, SparseTensorFieldKind fKind, 13271944c4f7SAart Bik Level lvl, LevelType lt) -> bool { 1328de560888SPeiming Liu if (fKind == SparseTensorFieldKind::StorageSpec) 1329de560888SPeiming Liu return true; 1330de560888SPeiming Liu 1331de560888SPeiming Liu Type inputTp = nullptr; 1332de560888SPeiming Liu if (fKind == SparseTensorFieldKind::ValMemRef) { 1333de560888SPeiming Liu inputTp = valTp; 1334de560888SPeiming Liu } else { 13351dd387e1SAart Bik assert(fid == idx && stt.getLvlType(lvl) == lt); 1336de560888SPeiming Liu inputTp = lvlTps[idx++]; 1337de560888SPeiming Liu } 1338de560888SPeiming Liu // The input element type and expected element type should match. 133968f58812STres Popp Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType(); 1340de560888SPeiming Liu Type expElemTp = getFieldElemType(stt, fKind); 1341de560888SPeiming Liu if (inpElemTp != expElemTp) { 1342de560888SPeiming Liu misMatch = true; 1343de560888SPeiming Liu return false; // to terminate the iteration 1344de560888SPeiming Liu } 1345de560888SPeiming Liu return true; 1346de560888SPeiming Liu }); 1347de560888SPeiming Liu 1348de560888SPeiming Liu if (misMatch) 1349de560888SPeiming Liu return op->emitError("input/output element-types don't match"); 1350de560888SPeiming Liu return success(); 1351de560888SPeiming Liu } 1352de560888SPeiming Liu 13536ca47eb4SPeiming Liu LogicalResult AssembleOp::verify() { 135477f8297cSMatthias Springer RankedTensorType valuesTp = getValues().getType(); 1355de560888SPeiming Liu const auto lvlsTp = getLevels().getTypes(); 1356de560888SPeiming Liu const auto resTp = getSparseTensorType(getResult()); 1357de560888SPeiming Liu return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp); 13586dbca86dSPeiming Liu } 13596dbca86dSPeiming Liu 13606ca47eb4SPeiming Liu LogicalResult DisassembleOp::verify() { 1361b2e6b735SPeiming Liu if (getOutValues().getType() != getRetValues().getType()) 1362b2e6b735SPeiming Liu return emitError("output values and return value type mismatch"); 1363d4db5289SPeiming Liu 1364b2e6b735SPeiming Liu for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels())) 1365b2e6b735SPeiming Liu if (ot.getType() != rt.getType()) 1366b2e6b735SPeiming Liu return emitError("output levels and return levels type mismatch"); 1367b2e6b735SPeiming Liu 136877f8297cSMatthias Springer RankedTensorType valuesTp = getRetValues().getType(); 1369b2e6b735SPeiming Liu const auto lvlsTp = getRetLevels().getTypes(); 1370b2e6b735SPeiming Liu const auto srcTp = getSparseTensorType(getTensor()); 1371b2e6b735SPeiming Liu return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp); 13726dbca86dSPeiming Liu } 13736dbca86dSPeiming Liu 1374b98dc035SRiver Riddle LogicalResult ConvertOp::verify() { 137577f8297cSMatthias Springer RankedTensorType tp1 = getSource().getType(); 137677f8297cSMatthias Springer RankedTensorType tp2 = getDest().getType(); 13771e6ef0cfSAart Bik if (tp1.getRank() != tp2.getRank()) 1378b98dc035SRiver Riddle return emitError("unexpected conversion mismatch in rank"); 137933267f40SPeiming Liu auto dstEnc = 1380c1fa60b4STres Popp llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding()); 138133267f40SPeiming Liu if (dstEnc && dstEnc.isSlice()) 138233267f40SPeiming Liu return emitError("cannot convert to a sparse tensor slice"); 138333267f40SPeiming Liu 1384697ea09dSAart Bik auto shape1 = tp1.getShape(); 1385697ea09dSAart Bik auto shape2 = tp2.getShape(); 13869d1db3d4SAart Bik // Accept size matches between the source and the destination type 13879d1db3d4SAart Bik // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or 13889d1db3d4SAart Bik // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 1389f708a549Swren romano for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) 1390399638f9SAliia Khasanova if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) 1391b98dc035SRiver Riddle return emitError("unexpected conversion mismatch in dimension ") << d; 1392697ea09dSAart Bik return success(); 1393697ea09dSAart Bik } 1394697ea09dSAart Bik 13957df76121SMarkus Böck OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { 1396f248d0b2SPeiming Liu if (getType() == getSource().getType()) 13970128f801Sbixia1 return getSource(); 13980128f801Sbixia1 return {}; 13990128f801Sbixia1 } 14000128f801Sbixia1 1401761c9dd9SPeiming Liu bool ConvertOp::needsExtraSort() { 1402dda3dc5eSPeiming Liu SparseTensorType srcStt = getSparseTensorType(getSource()); 1403dda3dc5eSPeiming Liu SparseTensorType dstStt = getSparseTensorType(getDest()); 1404dda3dc5eSPeiming Liu 1405761c9dd9SPeiming Liu // We do not need an extra sort when returning unordered sparse tensors or 1406761c9dd9SPeiming Liu // dense tensor since dense tensor support random access. 1407dda3dc5eSPeiming Liu if (dstStt.isAllDense() || !dstStt.isAllOrdered()) 1408761c9dd9SPeiming Liu return false; 1409dda3dc5eSPeiming Liu 1410dda3dc5eSPeiming Liu if (srcStt.isAllOrdered() && dstStt.isAllOrdered() && 1411dda3dc5eSPeiming Liu srcStt.hasSameDimToLvl(dstStt)) { 1412761c9dd9SPeiming Liu return false; 1413dda3dc5eSPeiming Liu } 1414dda3dc5eSPeiming Liu 1415dda3dc5eSPeiming Liu // Source and dest tensors are ordered in different ways. We only do direct 1416dda3dc5eSPeiming Liu // dense to sparse conversion when the dense input is defined by a sparse 1417dda3dc5eSPeiming Liu // constant. Note that we can theoretically always directly convert from dense 1418dda3dc5eSPeiming Liu // inputs by rotating dense loops but it leads to bad cache locality and hurt 1419dda3dc5eSPeiming Liu // performance. 1420dda3dc5eSPeiming Liu if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>()) 1421dda3dc5eSPeiming Liu if (isa<SparseElementsAttr>(constOp.getValue())) 1422dda3dc5eSPeiming Liu return false; 1423761c9dd9SPeiming Liu 1424761c9dd9SPeiming Liu return true; 1425dda3dc5eSPeiming Liu } 1426dda3dc5eSPeiming Liu 1427ff21a90eSPeiming Liu LogicalResult CrdTranslateOp::verify() { 1428ff21a90eSPeiming Liu uint64_t inRank = getEncoder().getLvlRank(); 1429ff21a90eSPeiming Liu uint64_t outRank = getEncoder().getDimRank(); 1430ff21a90eSPeiming Liu 1431ff21a90eSPeiming Liu if (getDirection() == CrdTransDirectionKind::dim2lvl) 1432ff21a90eSPeiming Liu std::swap(inRank, outRank); 1433ff21a90eSPeiming Liu 1434ff21a90eSPeiming Liu if (inRank != getInCrds().size() || outRank != getOutCrds().size()) 1435ff21a90eSPeiming Liu return emitError("Coordinate rank mismatch with encoding"); 1436ff21a90eSPeiming Liu 1437ff21a90eSPeiming Liu return success(); 1438ff21a90eSPeiming Liu } 1439ff21a90eSPeiming Liu 1440ff21a90eSPeiming Liu LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor, 1441ff21a90eSPeiming Liu SmallVectorImpl<OpFoldResult> &results) { 14426456e0bbSPeiming Liu if (getEncoder().isIdentity()) { 14436456e0bbSPeiming Liu results.assign(getInCrds().begin(), getInCrds().end()); 14446456e0bbSPeiming Liu return success(); 14456456e0bbSPeiming Liu } 1446ff21a90eSPeiming Liu if (getEncoder().isPermutation()) { 1447ff21a90eSPeiming Liu AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl 1448ff21a90eSPeiming Liu ? getEncoder().getDimToLvl() 1449ff21a90eSPeiming Liu : getEncoder().getLvlToDim(); 1450ff21a90eSPeiming Liu for (AffineExpr exp : perm.getResults()) 14511609f1c2Slong.chen results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]); 1452ff21a90eSPeiming Liu return success(); 1453ff21a90eSPeiming Liu } 1454ff21a90eSPeiming Liu 1455ff21a90eSPeiming Liu // Fuse dim2lvl/lvl2dim pairs. 1456ff21a90eSPeiming Liu auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>(); 1457ff21a90eSPeiming Liu bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) { 1458ff21a90eSPeiming Liu return v.getDefiningOp() == def; 1459ff21a90eSPeiming Liu }); 1460ff21a90eSPeiming Liu if (!sameDef) 1461ff21a90eSPeiming Liu return failure(); 1462ff21a90eSPeiming Liu 1463ff21a90eSPeiming Liu bool oppositeDir = def.getDirection() != getDirection(); 1464ff21a90eSPeiming Liu bool sameOracle = 1465ff21a90eSPeiming Liu def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl(); 1466ff21a90eSPeiming Liu bool sameCount = def.getNumResults() == getInCrds().size(); 1467ff21a90eSPeiming Liu if (!oppositeDir || !sameOracle || !sameCount) 1468ff21a90eSPeiming Liu return failure(); 1469ff21a90eSPeiming Liu 1470ff21a90eSPeiming Liu // The definition produces the coordinates in the same order as the input 1471ff21a90eSPeiming Liu // coordinates. 1472ff21a90eSPeiming Liu bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()), 1473ff21a90eSPeiming Liu [](auto valuePair) { 1474ff21a90eSPeiming Liu auto [lhs, rhs] = valuePair; 1475ff21a90eSPeiming Liu return lhs == rhs; 1476ff21a90eSPeiming Liu }); 1477ff21a90eSPeiming Liu 1478ff21a90eSPeiming Liu if (!sameOrder) 1479ff21a90eSPeiming Liu return failure(); 1480ff21a90eSPeiming Liu // l1 = dim2lvl (lvl2dim l0) 1481ff21a90eSPeiming Liu // ==> l0 1482ff21a90eSPeiming Liu results.append(def.getInCrds().begin(), def.getInCrds().end()); 1483ff21a90eSPeiming Liu return success(); 1484ff21a90eSPeiming Liu } 1485ff21a90eSPeiming Liu 1486c780352dSPeiming Liu void LvlOp::build(OpBuilder &builder, OperationState &state, Value source, 1487c780352dSPeiming Liu int64_t index) { 1488c780352dSPeiming Liu Value val = builder.create<arith::ConstantIndexOp>(state.location, index); 1489c780352dSPeiming Liu return build(builder, state, source, val); 1490c780352dSPeiming Liu } 1491c780352dSPeiming Liu 1492f0f5fdf7SPeiming Liu LogicalResult LvlOp::verify() { 1493f0f5fdf7SPeiming Liu if (std::optional<uint64_t> lvl = getConstantLvlIndex()) { 1494f0f5fdf7SPeiming Liu auto stt = getSparseTensorType(getSource()); 1495f0f5fdf7SPeiming Liu if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank()) 149677f8297cSMatthias Springer return emitError( 149777f8297cSMatthias Springer "Level index exceeds the rank of the input sparse tensor"); 1498f0f5fdf7SPeiming Liu } 1499f0f5fdf7SPeiming Liu return success(); 1500f0f5fdf7SPeiming Liu } 1501f0f5fdf7SPeiming Liu 1502f0f5fdf7SPeiming Liu std::optional<uint64_t> LvlOp::getConstantLvlIndex() { 1503f0f5fdf7SPeiming Liu return getConstantIntValue(getIndex()); 1504f0f5fdf7SPeiming Liu } 1505f0f5fdf7SPeiming Liu 1506f0f5fdf7SPeiming Liu Speculation::Speculatability LvlOp::getSpeculatability() { 1507f0f5fdf7SPeiming Liu auto constantIndex = getConstantLvlIndex(); 1508f0f5fdf7SPeiming Liu if (!constantIndex) 1509f0f5fdf7SPeiming Liu return Speculation::NotSpeculatable; 1510f0f5fdf7SPeiming Liu 1511f0f5fdf7SPeiming Liu assert(constantIndex < 1512f0f5fdf7SPeiming Liu cast<RankedTensorType>(getSource().getType()).getRank()); 1513f0f5fdf7SPeiming Liu return Speculation::Speculatable; 1514f0f5fdf7SPeiming Liu } 1515f0f5fdf7SPeiming Liu 1516f0f5fdf7SPeiming Liu OpFoldResult LvlOp::fold(FoldAdaptor adaptor) { 1517f0f5fdf7SPeiming Liu auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex()); 1518f0f5fdf7SPeiming Liu if (!lvlIndex) 1519f0f5fdf7SPeiming Liu return {}; 1520f0f5fdf7SPeiming Liu 1521f0f5fdf7SPeiming Liu Level lvl = lvlIndex.getAPSInt().getZExtValue(); 1522f0f5fdf7SPeiming Liu auto stt = getSparseTensorType(getSource()); 1523f0f5fdf7SPeiming Liu if (lvl >= stt.getLvlRank()) { 1524f0f5fdf7SPeiming Liu // Follows the same convention used by tensor.dim operation. Out of bound 1525f0f5fdf7SPeiming Liu // indices produce undefined behavior but are still valid IR. Don't choke on 1526f0f5fdf7SPeiming Liu // them. 1527f0f5fdf7SPeiming Liu return {}; 1528f0f5fdf7SPeiming Liu } 1529f0f5fdf7SPeiming Liu 1530f0f5fdf7SPeiming Liu // Helper lambda to build an IndexAttr. 1531f0f5fdf7SPeiming Liu auto getIndexAttr = [this](int64_t lvlSz) { 1532f0f5fdf7SPeiming Liu return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz)); 1533f0f5fdf7SPeiming Liu }; 1534f0f5fdf7SPeiming Liu 1535deedf554SPeiming Liu SmallVector<Size> lvlShape = stt.getLvlShape(); 1536deedf554SPeiming Liu if (!ShapedType::isDynamic(lvlShape[lvl])) 1537deedf554SPeiming Liu return getIndexAttr(lvlShape[lvl]); 1538f0f5fdf7SPeiming Liu 1539f0f5fdf7SPeiming Liu return {}; 1540f0f5fdf7SPeiming Liu } 1541f0f5fdf7SPeiming Liu 1542d808d922SPeiming Liu void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState, 1543d808d922SPeiming Liu SparseTensorEncodingAttr dstEnc, Value source) { 1544d808d922SPeiming Liu auto srcStt = getSparseTensorType(source); 1545d808d922SPeiming Liu SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape(); 1546d808d922SPeiming Liu SmallVector<int64_t> dstDimShape = 1547e10dc60aSAart Bik dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim); 1548d808d922SPeiming Liu auto dstTp = 1549d808d922SPeiming Liu RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc); 1550d808d922SPeiming Liu return build(odsBuilder, odsState, dstTp, source); 1551d808d922SPeiming Liu } 1552d808d922SPeiming Liu 1553d808d922SPeiming Liu LogicalResult ReinterpretMapOp::verify() { 1554d808d922SPeiming Liu auto srcStt = getSparseTensorType(getSource()); 1555d808d922SPeiming Liu auto dstStt = getSparseTensorType(getDest()); 15561944c4f7SAart Bik ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes(); 15571944c4f7SAart Bik ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes(); 1558d808d922SPeiming Liu 1559d808d922SPeiming Liu if (srcLvlTps.size() != dstLvlTps.size()) 1560d808d922SPeiming Liu return emitError("Level rank mismatch between source/dest tensors"); 1561d808d922SPeiming Liu 1562d808d922SPeiming Liu for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps)) 1563d808d922SPeiming Liu if (srcLvlTp != dstLvlTp) 1564d808d922SPeiming Liu return emitError("Level type mismatch between source/dest tensors"); 1565d808d922SPeiming Liu 1566d808d922SPeiming Liu if (srcStt.getPosWidth() != dstStt.getPosWidth() || 1567d808d922SPeiming Liu srcStt.getCrdWidth() != dstStt.getCrdWidth()) { 1568d808d922SPeiming Liu return emitError("Crd/Pos width mismatch between source/dest tensors"); 1569d808d922SPeiming Liu } 1570d808d922SPeiming Liu 1571d808d922SPeiming Liu if (srcStt.getElementType() != dstStt.getElementType()) 1572d808d922SPeiming Liu return emitError("Element type mismatch between source/dest tensors"); 1573d808d922SPeiming Liu 157422212ca7SAart Bik SmallVector<Size> srcLvlShape = srcStt.getLvlShape(); 157522212ca7SAart Bik SmallVector<Size> dstLvlShape = dstStt.getLvlShape(); 1576d808d922SPeiming Liu for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) { 1577d808d922SPeiming Liu if (srcLvlSz != dstLvlSz) { 1578d808d922SPeiming Liu // Should we allow one side to be dynamic size, e.g., <?x?> should be 1579d808d922SPeiming Liu // compatible to <3x4>? For now, we require all the level sizes to be 1580d808d922SPeiming Liu // *exactly* matched for simplicity. 1581d808d922SPeiming Liu return emitError("Level size mismatch between source/dest tensors"); 1582d808d922SPeiming Liu } 1583d808d922SPeiming Liu } 1584d808d922SPeiming Liu 1585d808d922SPeiming Liu return success(); 1586d808d922SPeiming Liu } 1587d808d922SPeiming Liu 1588d808d922SPeiming Liu OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) { 1589d808d922SPeiming Liu if (getSource().getType() == getDest().getType()) 1590d808d922SPeiming Liu return getSource(); 1591d808d922SPeiming Liu 1592d808d922SPeiming Liu if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) { 1593d808d922SPeiming Liu // A -> B, B -> A ==> A 1594d808d922SPeiming Liu if (def.getSource().getType() == getDest().getType()) 1595d808d922SPeiming Liu return def.getSource(); 1596d808d922SPeiming Liu } 1597d808d922SPeiming Liu return {}; 1598d808d922SPeiming Liu } 1599d808d922SPeiming Liu 16006bc7c9dfSPeiming Liu template <typename ToBufferOp> 16016bc7c9dfSPeiming Liu static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, 16026bc7c9dfSPeiming Liu OpaqueProperties prop, 16036bc7c9dfSPeiming Liu RegionRange region, 16046bc7c9dfSPeiming Liu SmallVectorImpl<mlir::Type> &ret) { 16056bc7c9dfSPeiming Liu typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region); 16066bc7c9dfSPeiming Liu SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); 16076bc7c9dfSPeiming Liu Type elemTp = nullptr; 16086bc7c9dfSPeiming Liu bool withStride = false; 16096bc7c9dfSPeiming Liu if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) { 16106bc7c9dfSPeiming Liu elemTp = stt.getPosType(); 16116bc7c9dfSPeiming Liu } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> || 16126bc7c9dfSPeiming Liu std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) { 16136bc7c9dfSPeiming Liu elemTp = stt.getCrdType(); 16146bc7c9dfSPeiming Liu if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>) 16156bc7c9dfSPeiming Liu withStride = stt.getAoSCOOStart() <= adaptor.getLevel(); 16166bc7c9dfSPeiming Liu } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) { 16176bc7c9dfSPeiming Liu elemTp = stt.getElementType(); 16186bc7c9dfSPeiming Liu } 16196bc7c9dfSPeiming Liu 16206bc7c9dfSPeiming Liu assert(elemTp && "unhandled operation."); 16216bc7c9dfSPeiming Liu SmallVector<int64_t> bufShape = stt.getBatchLvlShape(); 16226bc7c9dfSPeiming Liu bufShape.push_back(ShapedType::kDynamic); 16236bc7c9dfSPeiming Liu 16246bc7c9dfSPeiming Liu auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get( 16256bc7c9dfSPeiming Liu stt.getContext(), ShapedType::kDynamic, 16266bc7c9dfSPeiming Liu {ShapedType::kDynamic}) 16276bc7c9dfSPeiming Liu : StridedLayoutAttr(); 16286bc7c9dfSPeiming Liu ret.emplace_back(MemRefType::get(bufShape, elemTp, layout)); 16296bc7c9dfSPeiming Liu return success(); 16306bc7c9dfSPeiming Liu } 16316bc7c9dfSPeiming Liu 163284cd51bbSwren romano LogicalResult ToPositionsOp::verify() { 16335b729503SAart Bik auto stt = getSparseTensorType(getTensor()); 163484cd51bbSwren romano if (failed(lvlIsInBounds(getLevel(), getTensor()))) 163584cd51bbSwren romano return emitError("requested level is out of bounds"); 16365b729503SAart Bik if (failed(isMatchingWidth(getResult(), stt.getPosWidth()))) 163784cd51bbSwren romano return emitError("unexpected type for positions"); 163896a23911SAart Bik return success(); 163996a23911SAart Bik } 164096a23911SAart Bik 16416bc7c9dfSPeiming Liu LogicalResult 16426bc7c9dfSPeiming Liu ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 16436bc7c9dfSPeiming Liu ValueRange ops, DictionaryAttr attr, 16446bc7c9dfSPeiming Liu OpaqueProperties prop, RegionRange region, 16456bc7c9dfSPeiming Liu SmallVectorImpl<mlir::Type> &ret) { 16466bc7c9dfSPeiming Liu return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret); 16476bc7c9dfSPeiming Liu } 16486bc7c9dfSPeiming Liu 164984cd51bbSwren romano LogicalResult ToCoordinatesOp::verify() { 16505b729503SAart Bik auto stt = getSparseTensorType(getTensor()); 165184cd51bbSwren romano if (failed(lvlIsInBounds(getLevel(), getTensor()))) 165284cd51bbSwren romano return emitError("requested level is out of bounds"); 16535b729503SAart Bik if (failed(isMatchingWidth(getResult(), stt.getCrdWidth()))) 165484cd51bbSwren romano return emitError("unexpected type for coordinates"); 165596a23911SAart Bik return success(); 165696a23911SAart Bik } 165796a23911SAart Bik 16586bc7c9dfSPeiming Liu LogicalResult 16596bc7c9dfSPeiming Liu ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 16606bc7c9dfSPeiming Liu ValueRange ops, DictionaryAttr attr, 16616bc7c9dfSPeiming Liu OpaqueProperties prop, RegionRange region, 16626bc7c9dfSPeiming Liu SmallVectorImpl<mlir::Type> &ret) { 16636bc7c9dfSPeiming Liu return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret); 16646bc7c9dfSPeiming Liu } 16656bc7c9dfSPeiming Liu 166684cd51bbSwren romano LogicalResult ToCoordinatesBufferOp::verify() { 16675b729503SAart Bik auto stt = getSparseTensorType(getTensor()); 16685248a987SPeiming Liu if (stt.getAoSCOOStart() >= stt.getLvlRank()) 16699bde3d0cSbixia1 return emitError("expected sparse tensor with a COO region"); 16709bde3d0cSbixia1 return success(); 16719bde3d0cSbixia1 } 16729bde3d0cSbixia1 16736bc7c9dfSPeiming Liu LogicalResult ToCoordinatesBufferOp::inferReturnTypes( 16746bc7c9dfSPeiming Liu MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, 16756bc7c9dfSPeiming Liu DictionaryAttr attr, OpaqueProperties prop, RegionRange region, 16766bc7c9dfSPeiming Liu SmallVectorImpl<mlir::Type> &ret) { 16776bc7c9dfSPeiming Liu return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region, 16786bc7c9dfSPeiming Liu ret); 16796bc7c9dfSPeiming Liu } 16806bc7c9dfSPeiming Liu 1681b98dc035SRiver Riddle LogicalResult ToValuesOp::verify() { 16825b729503SAart Bik auto stt = getSparseTensorType(getTensor()); 16839916ab03Swren romano auto mtp = getMemRefType(getResult()); 16845b729503SAart Bik if (stt.getElementType() != mtp.getElementType()) 1685b98dc035SRiver Riddle return emitError("unexpected mismatch in element types"); 168696a23911SAart Bik return success(); 168796a23911SAart Bik } 168896a23911SAart Bik 16896bc7c9dfSPeiming Liu LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx, 16906bc7c9dfSPeiming Liu std::optional<Location> loc, 16916bc7c9dfSPeiming Liu ValueRange ops, DictionaryAttr attr, 16926bc7c9dfSPeiming Liu OpaqueProperties prop, 16936bc7c9dfSPeiming Liu RegionRange region, 16946bc7c9dfSPeiming Liu SmallVectorImpl<mlir::Type> &ret) { 16956bc7c9dfSPeiming Liu return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret); 16966bc7c9dfSPeiming Liu } 16976bc7c9dfSPeiming Liu 1698c738b430SPeiming Liu LogicalResult ToSliceOffsetOp::verify() { 169977f8297cSMatthias Springer auto rank = getSlice().getType().getRank(); 1700c738b430SPeiming Liu if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) 1701c738b430SPeiming Liu return emitError("requested dimension out of bound"); 1702c738b430SPeiming Liu return success(); 1703c738b430SPeiming Liu } 1704c738b430SPeiming Liu 1705c738b430SPeiming Liu LogicalResult ToSliceStrideOp::verify() { 170677f8297cSMatthias Springer auto rank = getSlice().getType().getRank(); 1707c738b430SPeiming Liu if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) 1708c738b430SPeiming Liu return emitError("requested dimension out of bound"); 1709c738b430SPeiming Liu return success(); 1710c738b430SPeiming Liu } 1711c738b430SPeiming Liu 171271cc0f1cSPeiming Liu LogicalResult GetStorageSpecifierOp::verify() { 17132e2011daSAart Bik return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), 17142e2011daSAart Bik getSpecifier(), getOperation()); 171571cc0f1cSPeiming Liu } 171671cc0f1cSPeiming Liu 1717509974afSPeiming Liu template <typename SpecifierOp> 1718509974afSPeiming Liu static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) { 1719509974afSPeiming Liu return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>(); 1720509974afSPeiming Liu } 1721509974afSPeiming Liu 17227df76121SMarkus Böck OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { 172384cd51bbSwren romano const StorageSpecifierKind kind = getSpecifierKind(); 172484cd51bbSwren romano const auto lvl = getLevel(); 1725509974afSPeiming Liu for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op)) 172684cd51bbSwren romano if (kind == op.getSpecifierKind() && lvl == op.getLevel()) 1727509974afSPeiming Liu return op.getValue(); 1728509974afSPeiming Liu return {}; 1729509974afSPeiming Liu } 1730509974afSPeiming Liu 173171cc0f1cSPeiming Liu LogicalResult SetStorageSpecifierOp::verify() { 17322e2011daSAart Bik return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), 17332e2011daSAart Bik getSpecifier(), getOperation()); 173471cc0f1cSPeiming Liu } 173571cc0f1cSPeiming Liu 1736414ed019SJim Kitchen template <class T> 1737414ed019SJim Kitchen static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, 1738414ed019SJim Kitchen const char *regionName, 1739414ed019SJim Kitchen TypeRange inputTypes, Type outputType) { 1740414ed019SJim Kitchen unsigned numArgs = region.getNumArguments(); 1741414ed019SJim Kitchen unsigned expectedNum = inputTypes.size(); 1742414ed019SJim Kitchen if (numArgs != expectedNum) 1743414ed019SJim Kitchen return op->emitError() << regionName << " region must have exactly " 1744414ed019SJim Kitchen << expectedNum << " arguments"; 1745414ed019SJim Kitchen 1746414ed019SJim Kitchen for (unsigned i = 0; i < numArgs; i++) { 1747414ed019SJim Kitchen Type typ = region.getArgument(i).getType(); 1748414ed019SJim Kitchen if (typ != inputTypes[i]) 1749414ed019SJim Kitchen return op->emitError() << regionName << " region argument " << (i + 1) 1750414ed019SJim Kitchen << " type mismatch"; 1751414ed019SJim Kitchen } 1752414ed019SJim Kitchen Operation *term = region.front().getTerminator(); 1753414ed019SJim Kitchen YieldOp yield = dyn_cast<YieldOp>(term); 1754414ed019SJim Kitchen if (!yield) 1755414ed019SJim Kitchen return op->emitError() << regionName 1756414ed019SJim Kitchen << " region must end with sparse_tensor.yield"; 1757a54930e6SPeiming Liu if (!yield.hasSingleResult() || 1758a54930e6SPeiming Liu yield.getSingleResult().getType() != outputType) 1759414ed019SJim Kitchen return op->emitError() << regionName << " region yield type mismatch"; 1760414ed019SJim Kitchen 1761414ed019SJim Kitchen return success(); 1762414ed019SJim Kitchen } 1763414ed019SJim Kitchen 1764414ed019SJim Kitchen LogicalResult BinaryOp::verify() { 1765414ed019SJim Kitchen NamedAttrList attrs = (*this)->getAttrs(); 176604235d07SJacques Pienaar Type leftType = getX().getType(); 176704235d07SJacques Pienaar Type rightType = getY().getType(); 176804235d07SJacques Pienaar Type outputType = getOutput().getType(); 176904235d07SJacques Pienaar Region &overlap = getOverlapRegion(); 177004235d07SJacques Pienaar Region &left = getLeftRegion(); 177104235d07SJacques Pienaar Region &right = getRightRegion(); 1772414ed019SJim Kitchen 1773414ed019SJim Kitchen // Check correct number of block arguments and return type for each 1774414ed019SJim Kitchen // non-empty region. 1775414ed019SJim Kitchen if (!overlap.empty()) { 17762e2011daSAart Bik if (failed(verifyNumBlockArgs(this, overlap, "overlap", 17772e2011daSAart Bik TypeRange{leftType, rightType}, outputType))) 17782e2011daSAart Bik return failure(); 1779414ed019SJim Kitchen } 1780414ed019SJim Kitchen if (!left.empty()) { 17812e2011daSAart Bik if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, 17822e2011daSAart Bik outputType))) 17832e2011daSAart Bik return failure(); 178404235d07SJacques Pienaar } else if (getLeftIdentity()) { 1785414ed019SJim Kitchen if (leftType != outputType) 1786414ed019SJim Kitchen return emitError("left=identity requires first argument to have the same " 1787414ed019SJim Kitchen "type as the output"); 1788414ed019SJim Kitchen } 1789414ed019SJim Kitchen if (!right.empty()) { 17902e2011daSAart Bik if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType}, 17912e2011daSAart Bik outputType))) 17922e2011daSAart Bik return failure(); 179304235d07SJacques Pienaar } else if (getRightIdentity()) { 1794414ed019SJim Kitchen if (rightType != outputType) 1795414ed019SJim Kitchen return emitError("right=identity requires second argument to have the " 1796414ed019SJim Kitchen "same type as the output"); 1797414ed019SJim Kitchen } 1798414ed019SJim Kitchen return success(); 1799414ed019SJim Kitchen } 1800414ed019SJim Kitchen 1801414ed019SJim Kitchen LogicalResult UnaryOp::verify() { 180204235d07SJacques Pienaar Type inputType = getX().getType(); 180304235d07SJacques Pienaar Type outputType = getOutput().getType(); 1804414ed019SJim Kitchen 1805414ed019SJim Kitchen // Check correct number of block arguments and return type for each 1806414ed019SJim Kitchen // non-empty region. 180704235d07SJacques Pienaar Region &present = getPresentRegion(); 1808414ed019SJim Kitchen if (!present.empty()) { 18092e2011daSAart Bik if (failed(verifyNumBlockArgs(this, present, "present", 18102e2011daSAart Bik TypeRange{inputType}, outputType))) 18112e2011daSAart Bik return failure(); 1812414ed019SJim Kitchen } 181304235d07SJacques Pienaar Region &absent = getAbsentRegion(); 1814414ed019SJim Kitchen if (!absent.empty()) { 18152e2011daSAart Bik if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{}, 18162e2011daSAart Bik outputType))) 18172e2011daSAart Bik return failure(); 18187e83a1afSAart Bik // Absent branch can only yield invariant values. 18197e83a1afSAart Bik Block *absentBlock = &absent.front(); 18207e83a1afSAart Bik Block *parent = getOperation()->getBlock(); 1821a54930e6SPeiming Liu Value absentVal = 1822a54930e6SPeiming Liu cast<YieldOp>(absentBlock->getTerminator()).getSingleResult(); 18237e83a1afSAart Bik if (auto arg = dyn_cast<BlockArgument>(absentVal)) { 18247e83a1afSAart Bik if (arg.getOwner() == parent) 18257e83a1afSAart Bik return emitError("absent region cannot yield linalg argument"); 18267e83a1afSAart Bik } else if (Operation *def = absentVal.getDefiningOp()) { 18277e83a1afSAart Bik if (!isa<arith::ConstantOp>(def) && 18287e83a1afSAart Bik (def->getBlock() == absentBlock || def->getBlock() == parent)) 18297e83a1afSAart Bik return emitError("absent region cannot yield locally computed value"); 18307e83a1afSAart Bik } 1831414ed019SJim Kitchen } 1832414ed019SJim Kitchen return success(); 1833414ed019SJim Kitchen } 1834414ed019SJim Kitchen 1835761c9dd9SPeiming Liu bool ConcatenateOp::needsExtraSort() { 1836761c9dd9SPeiming Liu SparseTensorType dstStt = getSparseTensorType(*this); 1837761c9dd9SPeiming Liu if (dstStt.isAllDense() || !dstStt.isAllOrdered()) 1838761c9dd9SPeiming Liu return false; 1839761c9dd9SPeiming Liu 1840761c9dd9SPeiming Liu bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) { 1841761c9dd9SPeiming Liu return getSparseTensorType(op).hasSameDimToLvl(dstStt); 1842761c9dd9SPeiming Liu }); 1843761c9dd9SPeiming Liu // TODO: When conDim != 0, as long as conDim corresponding to the first level 1844761c9dd9SPeiming Liu // in all input/output buffers, and all input/output buffers have the same 1845761c9dd9SPeiming Liu // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate 1846761c9dd9SPeiming Liu // CSC matrices along column). 1847761c9dd9SPeiming Liu bool directLowerable = 1848761c9dd9SPeiming Liu allSameOrdered && getDimension() == 0 && dstStt.isIdentity(); 1849761c9dd9SPeiming Liu return !directLowerable; 1850761c9dd9SPeiming Liu } 1851761c9dd9SPeiming Liu 1852de907138SPeiming Liu LogicalResult ConcatenateOp::verify() { 1853f708a549Swren romano const auto dstTp = getSparseTensorType(*this); 185484cd51bbSwren romano const Dimension concatDim = getDimension(); 1855f708a549Swren romano const Dimension dimRank = dstTp.getDimRank(); 1856de907138SPeiming Liu 1857de907138SPeiming Liu if (getInputs().size() <= 1) 1858de907138SPeiming Liu return emitError("Need at least two tensors to concatenate."); 1859de907138SPeiming Liu 1860f708a549Swren romano if (concatDim >= dimRank) 1861de907138SPeiming Liu return emitError(llvm::formatv( 1862f708a549Swren romano "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})", 1863f708a549Swren romano concatDim, dimRank)); 1864de907138SPeiming Liu 1865f708a549Swren romano for (const auto &it : llvm::enumerate(getInputs())) { 1866f708a549Swren romano const auto i = it.index(); 1867f708a549Swren romano const auto srcTp = getSparseTensorType(it.value()); 1868f708a549Swren romano if (srcTp.hasDynamicDimShape()) 1869f708a549Swren romano return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i)); 1870f708a549Swren romano const Dimension srcDimRank = srcTp.getDimRank(); 1871f708a549Swren romano if (srcDimRank != dimRank) 1872de907138SPeiming Liu return emitError( 1873f708a549Swren romano llvm::formatv("Input tensor ${0} has a different rank (rank={1}) " 1874de907138SPeiming Liu "from the output tensor (rank={2}).", 1875f708a549Swren romano i, srcDimRank, dimRank)); 1876de907138SPeiming Liu } 1877de907138SPeiming Liu 1878f708a549Swren romano for (Dimension d = 0; d < dimRank; d++) { 187922212ca7SAart Bik const Size dstSh = dstTp.getDimShape()[d]; 1880f708a549Swren romano if (d == concatDim) { 1881f708a549Swren romano if (!ShapedType::isDynamic(dstSh)) { 1882f708a549Swren romano // If we reach here, then all inputs have static shapes. So we 1883f708a549Swren romano // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)` 1884f708a549Swren romano // to avoid redundant assertions in the loop. 188522212ca7SAart Bik Size sumSz = 0; 1886f708a549Swren romano for (const auto src : getInputs()) 1887f708a549Swren romano sumSz += getSparseTensorType(src).getDimShape()[d]; 1888de907138SPeiming Liu // If all dimension are statically known, the sum of all the input 1889de907138SPeiming Liu // dimensions should be equal to the output dimension. 1890f708a549Swren romano if (sumSz != dstSh) 1891de907138SPeiming Liu return emitError( 1892de907138SPeiming Liu "The concatenation dimension of the output tensor should be the " 1893de907138SPeiming Liu "sum of all the concatenation dimensions of the input tensors."); 1894de907138SPeiming Liu } 1895de907138SPeiming Liu } else { 189622212ca7SAart Bik Size prev = dstSh; 1897f708a549Swren romano for (const auto src : getInputs()) { 1898f708a549Swren romano const auto sh = getSparseTensorType(src).getDimShape()[d]; 1899f708a549Swren romano if (!ShapedType::isDynamic(prev) && sh != prev) 1900de907138SPeiming Liu return emitError("All dimensions (expect for the concatenating one) " 1901de907138SPeiming Liu "should be equal."); 1902f708a549Swren romano prev = sh; 1903de907138SPeiming Liu } 1904de907138SPeiming Liu } 1905de907138SPeiming Liu } 1906de907138SPeiming Liu 1907de907138SPeiming Liu return success(); 1908de907138SPeiming Liu } 1909de907138SPeiming Liu 191014504caeSbixia1 void PushBackOp::build(OpBuilder &builder, OperationState &result, 1911988733c6SPeiming Liu Value curSize, Value inBuffer, Value value) { 1912988733c6SPeiming Liu build(builder, result, curSize, inBuffer, value, Value()); 191314504caeSbixia1 } 191414504caeSbixia1 191514504caeSbixia1 LogicalResult PushBackOp::verify() { 1916743fbcb7Swren romano if (Value n = getN()) { 1917cb7bda2aSMatthias Springer std::optional<int64_t> nValue = getConstantIntValue(n); 191814504caeSbixia1 if (nValue && nValue.value() < 1) 191914504caeSbixia1 return emitOpError("n must be not less than 1"); 192014504caeSbixia1 } 192114504caeSbixia1 return success(); 192214504caeSbixia1 } 192314504caeSbixia1 1924a3610359SAart Bik LogicalResult CompressOp::verify() { 192584cd51bbSwren romano const auto stt = getSparseTensorType(getTensor()); 192684cd51bbSwren romano if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size())) 192784cd51bbSwren romano return emitOpError("incorrect number of coordinates"); 1928a3610359SAart Bik return success(); 1929a3610359SAart Bik } 1930a3610359SAart Bik 193100ad0655SPeiming Liu void ForeachOp::build( 193200ad0655SPeiming Liu OpBuilder &builder, OperationState &result, Value tensor, 19339e8d9316SPeiming Liu ValueRange initArgs, AffineMapAttr order, 19344fa00ce1SPeiming Liu function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)> 19354fa00ce1SPeiming Liu bodyBuilder) { 19369e8d9316SPeiming Liu build(builder, result, initArgs.getTypes(), tensor, initArgs, order); 19374fa00ce1SPeiming Liu // Builds foreach body. 193800ad0655SPeiming Liu if (!bodyBuilder) 193900ad0655SPeiming Liu return; 1940f708a549Swren romano const auto stt = getSparseTensorType(tensor); 1941f708a549Swren romano const Dimension dimRank = stt.getDimRank(); 194200ad0655SPeiming Liu 194384cd51bbSwren romano // Starts with `dimRank`-many coordinates. 1944f708a549Swren romano SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType()); 194500ad0655SPeiming Liu // Followed by one value. 1946f708a549Swren romano blockArgTypes.push_back(stt.getElementType()); 1947f708a549Swren romano // Followed by the reduction variables. 19484fa00ce1SPeiming Liu blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end()); 194900ad0655SPeiming Liu 1950f708a549Swren romano SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc()); 195100ad0655SPeiming Liu 195200ad0655SPeiming Liu OpBuilder::InsertionGuard guard(builder); 195300ad0655SPeiming Liu auto ®ion = *result.regions.front(); 195400ad0655SPeiming Liu Block *bodyBlock = 195500ad0655SPeiming Liu builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); 19564fa00ce1SPeiming Liu bodyBuilder(builder, result.location, 1957f708a549Swren romano bodyBlock->getArguments().slice(0, dimRank), 1958f708a549Swren romano bodyBlock->getArguments()[dimRank], 1959f708a549Swren romano bodyBlock->getArguments().drop_front(dimRank + 1)); 196000ad0655SPeiming Liu } 196100ad0655SPeiming Liu 1962e08865a1SPeiming Liu LogicalResult ForeachOp::verify() { 1963f708a549Swren romano const auto t = getSparseTensorType(getTensor()); 1964f708a549Swren romano const Dimension dimRank = t.getDimRank(); 1965f708a549Swren romano const auto args = getBody()->getArguments(); 1966e08865a1SPeiming Liu 196753ffafb2SPeiming Liu if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank()) 196853ffafb2SPeiming Liu return emitError("Level traverse order does not match tensor's level rank"); 19699e8d9316SPeiming Liu 197053ffafb2SPeiming Liu if (dimRank + 1 + getInitArgs().size() != args.size()) 1971e08865a1SPeiming Liu return emitError("Unmatched number of arguments in the block"); 1972e08865a1SPeiming Liu 19734fa00ce1SPeiming Liu if (getNumResults() != getInitArgs().size()) 19744fa00ce1SPeiming Liu return emitError("Mismatch in number of init arguments and results"); 19754fa00ce1SPeiming Liu 19764fa00ce1SPeiming Liu if (getResultTypes() != getInitArgs().getTypes()) 19774fa00ce1SPeiming Liu return emitError("Mismatch in types of init arguments and results"); 19784fa00ce1SPeiming Liu 1979f708a549Swren romano // Cannot mark this const, because the getters aren't. 19804fa00ce1SPeiming Liu auto yield = cast<YieldOp>(getBody()->getTerminator()); 19814fa00ce1SPeiming Liu if (yield.getNumOperands() != getNumResults() || 19824fa00ce1SPeiming Liu yield.getOperands().getTypes() != getResultTypes()) 19834fa00ce1SPeiming Liu return emitError("Mismatch in types of yield values and results"); 19844fa00ce1SPeiming Liu 1985f708a549Swren romano const auto iTp = IndexType::get(getContext()); 1986f708a549Swren romano for (Dimension d = 0; d < dimRank; d++) 1987f708a549Swren romano if (args[d].getType() != iTp) 198877f8297cSMatthias Springer return emitError( 1989f708a549Swren romano llvm::formatv("Expecting Index type for argument at index {0}", d)); 1990e08865a1SPeiming Liu 1991f708a549Swren romano const auto elemTp = t.getElementType(); 1992f708a549Swren romano const auto valueTp = args[dimRank].getType(); 1993e08865a1SPeiming Liu if (elemTp != valueTp) 199477f8297cSMatthias Springer return emitError( 199577f8297cSMatthias Springer llvm::formatv("Unmatched element type between input tensor and " 1996e08865a1SPeiming Liu "block argument, expected:{0}, got: {1}", 1997e08865a1SPeiming Liu elemTp, valueTp)); 1998e08865a1SPeiming Liu return success(); 1999e08865a1SPeiming Liu } 2000e08865a1SPeiming Liu 20010aacc213SPeiming Liu OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) { 20020aacc213SPeiming Liu if (getSparseTensorEncoding(getInputCoo().getType()) == 20030aacc213SPeiming Liu getSparseTensorEncoding(getResultCoo().getType())) 20040aacc213SPeiming Liu return getInputCoo(); 20050aacc213SPeiming Liu 20060aacc213SPeiming Liu return {}; 20070aacc213SPeiming Liu } 20080aacc213SPeiming Liu 20090aacc213SPeiming Liu LogicalResult ReorderCOOOp::verify() { 20100aacc213SPeiming Liu SparseTensorType srcStt = getSparseTensorType(getInputCoo()); 20110aacc213SPeiming Liu SparseTensorType dstStt = getSparseTensorType(getResultCoo()); 20120aacc213SPeiming Liu 20135b729503SAart Bik if (!srcStt.isCOOType() || !dstStt.isCOOType()) 201477f8297cSMatthias Springer return emitError("Expected COO sparse tensors only"); 201598f8b1afSAart Bik 20160aacc213SPeiming Liu if (!srcStt.hasSameDimToLvl(dstStt)) 201777f8297cSMatthias Springer return emitError("Unmatched dim2lvl map between input and result COO"); 20180aacc213SPeiming Liu 20190aacc213SPeiming Liu if (srcStt.getPosType() != dstStt.getPosType() || 20200aacc213SPeiming Liu srcStt.getCrdType() != dstStt.getCrdType() || 202198f8b1afSAart Bik srcStt.getElementType() != dstStt.getElementType()) 202277f8297cSMatthias Springer return emitError("Unmatched storage format between input and result COO"); 202398f8b1afSAart Bik 20240aacc213SPeiming Liu return success(); 20250aacc213SPeiming Liu } 20260aacc213SPeiming Liu 20272b8a4d9cSJim Kitchen LogicalResult ReduceOp::verify() { 2028a1ec0d8bSJacques Pienaar Type inputType = getX().getType(); 2029a1ec0d8bSJacques Pienaar Region &formula = getRegion(); 20302e2011daSAart Bik return verifyNumBlockArgs(this, formula, "reduce", 20312e2011daSAart Bik TypeRange{inputType, inputType}, inputType); 20322b8a4d9cSJim Kitchen } 20332b8a4d9cSJim Kitchen 203407150fecSJim Kitchen LogicalResult SelectOp::verify() { 203507150fecSJim Kitchen Builder b(getContext()); 203607150fecSJim Kitchen Type inputType = getX().getType(); 203707150fecSJim Kitchen Type boolType = b.getI1Type(); 203807150fecSJim Kitchen Region &formula = getRegion(); 20392e2011daSAart Bik return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType}, 20402e2011daSAart Bik boolType); 20412b8a4d9cSJim Kitchen } 20422b8a4d9cSJim Kitchen 20430083f833SPeiming Liu LogicalResult SortOp::verify() { 2044bfa3bc43SPeiming Liu AffineMap xPerm = getPermMap(); 2045bfa3bc43SPeiming Liu uint64_t nx = xPerm.getNumDims(); 2046bfa3bc43SPeiming Liu if (nx < 1) 204777f8297cSMatthias Springer return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx)); 2048bfa3bc43SPeiming Liu 2049bfa3bc43SPeiming Liu if (!xPerm.isPermutation()) 205077f8297cSMatthias Springer return emitError( 205177f8297cSMatthias Springer llvm::formatv("Expected a permutation map, got {0}", xPerm)); 2052bfa3bc43SPeiming Liu 2053cf24d49dSbixia1 // We can't check the size of the buffers when n or buffer dimensions aren't 2054cf24d49dSbixia1 // compile-time constants. 2055d2d29288SAart Bik std::optional<int64_t> cn = getConstantIntValue(getN()); 2056cf24d49dSbixia1 if (!cn) 2057cf24d49dSbixia1 return success(); 2058cf24d49dSbixia1 2059d2d29288SAart Bik // Verify dimensions. 206077f8297cSMatthias Springer const auto checkDim = [&](Value v, Size minSize, 206177f8297cSMatthias Springer const char *message) -> LogicalResult { 206222212ca7SAart Bik const Size sh = getMemRefType(v).getShape()[0]; 2063f708a549Swren romano if (!ShapedType::isDynamic(sh) && sh < minSize) 206477f8297cSMatthias Springer return emitError( 206577f8297cSMatthias Springer llvm::formatv("{0} got {1} < {2}", message, sh, minSize)); 206677f8297cSMatthias Springer return success(); 2067cf24d49dSbixia1 }; 2068d2d29288SAart Bik uint64_t n = cn.value(); 2069d2d29288SAart Bik uint64_t ny = 0; 2070d2d29288SAart Bik if (auto nyAttr = getNyAttr()) 2071d2d29288SAart Bik ny = nyAttr.getInt(); 207277f8297cSMatthias Springer if (failed(checkDim(getXy(), n * (nx + ny), 207377f8297cSMatthias Springer "Expected dimension(xy) >= n * (rank(perm_map) + ny)"))) 207477f8297cSMatthias Springer return failure(); 2075d2d29288SAart Bik for (Value opnd : getYs()) 207677f8297cSMatthias Springer if (failed(checkDim(opnd, n, "Expected dimension(y) >= n"))) 207777f8297cSMatthias Springer return failure(); 2078cf24d49dSbixia1 2079cf24d49dSbixia1 return success(); 2080cf24d49dSbixia1 } 2081cf24d49dSbixia1 2082481bd5d4SPeiming Liu //===----------------------------------------------------------------------===// 2083481bd5d4SPeiming Liu // Sparse Tensor Iteration Operations. 2084481bd5d4SPeiming Liu //===----------------------------------------------------------------------===// 2085481bd5d4SPeiming Liu 2086481bd5d4SPeiming Liu IterSpaceType IteratorType::getIterSpaceType() const { 2087481bd5d4SPeiming Liu return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(), 2088481bd5d4SPeiming Liu getHiLvl()); 2089481bd5d4SPeiming Liu } 2090481bd5d4SPeiming Liu 2091481bd5d4SPeiming Liu IteratorType IterSpaceType::getIteratorType() const { 2092481bd5d4SPeiming Liu return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl()); 2093481bd5d4SPeiming Liu } 2094481bd5d4SPeiming Liu 2095481bd5d4SPeiming Liu /// Parses a level range in the form "$lo `to` $hi" 2096481bd5d4SPeiming Liu /// or simply "$lo" if $hi - $lo = 1 2097481bd5d4SPeiming Liu static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo, 2098481bd5d4SPeiming Liu Level &lvlHi) { 2099481bd5d4SPeiming Liu if (parser.parseInteger(lvlLo)) 2100481bd5d4SPeiming Liu return failure(); 2101481bd5d4SPeiming Liu 2102481bd5d4SPeiming Liu if (succeeded(parser.parseOptionalKeyword("to"))) { 2103481bd5d4SPeiming Liu if (parser.parseInteger(lvlHi)) 2104481bd5d4SPeiming Liu return failure(); 2105481bd5d4SPeiming Liu } else { 2106481bd5d4SPeiming Liu lvlHi = lvlLo + 1; 2107481bd5d4SPeiming Liu } 2108481bd5d4SPeiming Liu 2109481bd5d4SPeiming Liu if (lvlHi <= lvlLo) 211077f8297cSMatthias Springer return parser.emitError(parser.getNameLoc(), 2111481bd5d4SPeiming Liu "expect larger level upper bound than lower bound"); 2112481bd5d4SPeiming Liu 2113481bd5d4SPeiming Liu return success(); 2114481bd5d4SPeiming Liu } 2115481bd5d4SPeiming Liu 2116481bd5d4SPeiming Liu /// Parses a level range in the form "$lo `to` $hi" 2117481bd5d4SPeiming Liu /// or simply "$lo" if $hi - $lo = 1 2118481bd5d4SPeiming Liu static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr, 2119481bd5d4SPeiming Liu IntegerAttr &lvlHiAttr) { 2120481bd5d4SPeiming Liu Level lvlLo, lvlHi; 2121481bd5d4SPeiming Liu if (parseLevelRange(parser, lvlLo, lvlHi)) 2122481bd5d4SPeiming Liu return failure(); 2123481bd5d4SPeiming Liu 2124481bd5d4SPeiming Liu lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo); 2125481bd5d4SPeiming Liu lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi); 2126481bd5d4SPeiming Liu return success(); 2127481bd5d4SPeiming Liu } 2128481bd5d4SPeiming Liu 2129481bd5d4SPeiming Liu /// Prints a level range in the form "$lo `to` $hi" 2130481bd5d4SPeiming Liu /// or simply "$lo" if $hi - $lo = 1 2131481bd5d4SPeiming Liu static void printLevelRange(AsmPrinter &p, Level lo, Level hi) { 2132481bd5d4SPeiming Liu 2133481bd5d4SPeiming Liu if (lo + 1 == hi) 2134481bd5d4SPeiming Liu p << lo; 2135481bd5d4SPeiming Liu else 2136481bd5d4SPeiming Liu p << lo << " to " << hi; 2137481bd5d4SPeiming Liu } 2138481bd5d4SPeiming Liu 2139481bd5d4SPeiming Liu /// Prints a level range in the form "$lo `to` $hi" 2140481bd5d4SPeiming Liu /// or simply "$lo" if $hi - $lo = 1 2141481bd5d4SPeiming Liu static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, 2142481bd5d4SPeiming Liu IntegerAttr lvlHi) { 2143481bd5d4SPeiming Liu unsigned lo = lvlLo.getValue().getZExtValue(); 2144481bd5d4SPeiming Liu unsigned hi = lvlHi.getValue().getZExtValue(); 2145481bd5d4SPeiming Liu printLevelRange(p, lo, hi); 2146481bd5d4SPeiming Liu } 2147481bd5d4SPeiming Liu 2148785a24f1SPeiming Liu /// Parses a list of `optional` defined list in the form of 2149785a24f1SPeiming Liu /// "(%val0, _, %val1, ...)", where `_` is used to annotate that the 2150785a24f1SPeiming Liu /// corresponding value is not defined (e.g., to represent an undefined 2151785a24f1SPeiming Liu /// coordinate in the sparse iteration space). 2152785a24f1SPeiming Liu static ParseResult parseOptionalDefinedList( 2153785a24f1SPeiming Liu OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, 2154785a24f1SPeiming Liu SmallVectorImpl<OpAsmParser::Argument> &definedArgs, 2155785a24f1SPeiming Liu unsigned maxCnt = std::numeric_limits<unsigned>::max(), 2156785a24f1SPeiming Liu OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) { 2157785a24f1SPeiming Liu unsigned cnt = 0; 2158785a24f1SPeiming Liu ParseResult crdList = 2159785a24f1SPeiming Liu parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult { 2160785a24f1SPeiming Liu if (parser.parseOptionalKeyword("_")) { 2161785a24f1SPeiming Liu if (parser.parseArgument(definedArgs.emplace_back())) 2162785a24f1SPeiming Liu return failure(); 2163785a24f1SPeiming Liu definedSet.set(cnt); 2164785a24f1SPeiming Liu } 2165785a24f1SPeiming Liu cnt += 1; 2166785a24f1SPeiming Liu return success(); 2167785a24f1SPeiming Liu }); 2168785a24f1SPeiming Liu 2169785a24f1SPeiming Liu if (cnt > maxCnt) 2170785a24f1SPeiming Liu return parser.emitError(parser.getNameLoc(), 2171785a24f1SPeiming Liu "parsed more value than expected."); 2172785a24f1SPeiming Liu 2173785a24f1SPeiming Liu if (failed(crdList)) { 2174785a24f1SPeiming Liu return parser.emitError( 2175785a24f1SPeiming Liu parser.getNameLoc(), 2176785a24f1SPeiming Liu "expecting SSA value or \"_\" for level coordinates"); 2177785a24f1SPeiming Liu } 2178785a24f1SPeiming Liu assert(definedArgs.size() == definedSet.count()); 2179785a24f1SPeiming Liu return success(); 2180785a24f1SPeiming Liu } 2181785a24f1SPeiming Liu 2182785a24f1SPeiming Liu static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, 2183785a24f1SPeiming Liu Block::BlockArgListType blocksArgs, 2184785a24f1SPeiming Liu I64BitSet definedSet) { 2185785a24f1SPeiming Liu if (definedSet.empty()) 2186785a24f1SPeiming Liu return; 2187785a24f1SPeiming Liu 2188785a24f1SPeiming Liu for (unsigned i = 0; i < size; i++) { 2189785a24f1SPeiming Liu if (definedSet[i]) { 2190785a24f1SPeiming Liu p << blocksArgs.front(); 2191785a24f1SPeiming Liu blocksArgs = blocksArgs.drop_front(); 2192785a24f1SPeiming Liu } else { 2193785a24f1SPeiming Liu p << "_"; 2194785a24f1SPeiming Liu } 2195785a24f1SPeiming Liu if (i != size - 1) 2196785a24f1SPeiming Liu p << ", "; 2197785a24f1SPeiming Liu } 2198785a24f1SPeiming Liu assert(blocksArgs.empty()); 2199785a24f1SPeiming Liu } 2200785a24f1SPeiming Liu 2201e276cf08SPeiming Liu static ParseResult 2202785a24f1SPeiming Liu parseUsedCoordList(OpAsmParser &parser, OperationState &state, 2203785a24f1SPeiming Liu SmallVectorImpl<OpAsmParser::Argument> &coords) { 2204785a24f1SPeiming Liu // Parse "at(%crd0, _, ...)" 2205785a24f1SPeiming Liu I64BitSet crdUsedLvlSet; 2206785a24f1SPeiming Liu if (succeeded(parser.parseOptionalKeyword("at")) && 2207785a24f1SPeiming Liu failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords))) 2208785a24f1SPeiming Liu return failure(); 2209785a24f1SPeiming Liu 2210785a24f1SPeiming Liu // Always use IndexType for the coordinate. 2211785a24f1SPeiming Liu for (auto &coord : coords) 2212785a24f1SPeiming Liu coord.type = parser.getBuilder().getIndexType(); 2213785a24f1SPeiming Liu 2214785a24f1SPeiming Liu // Set the CrdUsedLvl bitset. 2215785a24f1SPeiming Liu state.addAttribute("crdUsedLvls", 2216785a24f1SPeiming Liu parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet)); 2217785a24f1SPeiming Liu return success(); 2218785a24f1SPeiming Liu } 2219785a24f1SPeiming Liu 2220785a24f1SPeiming Liu static ParseResult 2221785a24f1SPeiming Liu parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, 2222e276cf08SPeiming Liu SmallVectorImpl<OpAsmParser::Argument> &iterators, 2223785a24f1SPeiming Liu SmallVectorImpl<OpAsmParser::Argument> &blockArgs) { 2224e276cf08SPeiming Liu SmallVector<OpAsmParser::UnresolvedOperand> spaces; 2225e276cf08SPeiming Liu SmallVector<OpAsmParser::UnresolvedOperand> initArgs; 2226e276cf08SPeiming Liu 2227e276cf08SPeiming Liu // Parse "%iters, ... in %spaces, ..." 2228e276cf08SPeiming Liu if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") || 2229e276cf08SPeiming Liu parser.parseOperandList(spaces)) 2230e276cf08SPeiming Liu return failure(); 2231e276cf08SPeiming Liu 2232e276cf08SPeiming Liu if (iterators.size() != spaces.size()) 2233e276cf08SPeiming Liu return parser.emitError( 2234e276cf08SPeiming Liu parser.getNameLoc(), 2235e276cf08SPeiming Liu "mismatch in number of sparse iterators and sparse spaces"); 2236e276cf08SPeiming Liu 2237b48ef8d8SPeiming Liu SmallVector<OpAsmParser::Argument> coords; 2238b48ef8d8SPeiming Liu if (failed(parseUsedCoordList(parser, state, coords))) 2239e276cf08SPeiming Liu return failure(); 2240b48ef8d8SPeiming Liu size_t numCrds = coords.size(); 2241e276cf08SPeiming Liu 2242e276cf08SPeiming Liu // Parse "iter_args(%arg = %init, ...)" 2243e276cf08SPeiming Liu bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); 2244e276cf08SPeiming Liu if (hasIterArgs) 2245785a24f1SPeiming Liu if (parser.parseAssignmentList(blockArgs, initArgs)) 2246e276cf08SPeiming Liu return failure(); 2247e276cf08SPeiming Liu 2248b48ef8d8SPeiming Liu blockArgs.append(coords); 2249b48ef8d8SPeiming Liu 2250e276cf08SPeiming Liu SmallVector<Type> iterSpaceTps; 2251e276cf08SPeiming Liu // parse ": sparse_tensor.iter_space -> ret" 2252e276cf08SPeiming Liu if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) 2253e276cf08SPeiming Liu return failure(); 2254e276cf08SPeiming Liu if (iterSpaceTps.size() != spaces.size()) 2255e276cf08SPeiming Liu return parser.emitError(parser.getNameLoc(), 2256e276cf08SPeiming Liu "mismatch in number of iteration space operands " 2257e276cf08SPeiming Liu "and iteration space types"); 2258e276cf08SPeiming Liu 2259e276cf08SPeiming Liu for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) { 2260e276cf08SPeiming Liu IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp); 2261e276cf08SPeiming Liu if (!spaceTp) 2262e276cf08SPeiming Liu return parser.emitError(parser.getNameLoc(), 2263e276cf08SPeiming Liu "expected sparse_tensor.iter_space type for " 2264e276cf08SPeiming Liu "iteration space operands"); 2265e276cf08SPeiming Liu it.type = spaceTp.getIteratorType(); 2266e276cf08SPeiming Liu } 2267e276cf08SPeiming Liu 2268e276cf08SPeiming Liu if (hasIterArgs) 2269e276cf08SPeiming Liu if (parser.parseArrowTypeList(state.types)) 2270e276cf08SPeiming Liu return failure(); 2271e276cf08SPeiming Liu 2272e276cf08SPeiming Liu // Resolves input operands. 2273e276cf08SPeiming Liu if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), 2274e276cf08SPeiming Liu state.operands)) 2275e276cf08SPeiming Liu return failure(); 2276e276cf08SPeiming Liu 2277e276cf08SPeiming Liu if (hasIterArgs) { 2278e276cf08SPeiming Liu // Strip off leading args that used for coordinates. 2279b48ef8d8SPeiming Liu MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); 2280785a24f1SPeiming Liu if (args.size() != initArgs.size() || args.size() != state.types.size()) { 2281785a24f1SPeiming Liu return parser.emitError( 2282785a24f1SPeiming Liu parser.getNameLoc(), 2283785a24f1SPeiming Liu "mismatch in number of iteration arguments and return values"); 2284785a24f1SPeiming Liu } 2285785a24f1SPeiming Liu 2286785a24f1SPeiming Liu for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { 2287785a24f1SPeiming Liu it.type = tp; 2288785a24f1SPeiming Liu if (parser.resolveOperand(init, tp, state.operands)) 2289785a24f1SPeiming Liu return failure(); 2290785a24f1SPeiming Liu } 2291785a24f1SPeiming Liu } 2292785a24f1SPeiming Liu return success(); 2293785a24f1SPeiming Liu } 2294785a24f1SPeiming Liu 2295785a24f1SPeiming Liu static ParseResult 2296785a24f1SPeiming Liu parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, 2297785a24f1SPeiming Liu SmallVectorImpl<Value> &spacesVals, 2298785a24f1SPeiming Liu SmallVectorImpl<OpAsmParser::Argument> &blockArgs) { 2299785a24f1SPeiming Liu 2300785a24f1SPeiming Liu // Parse "(%spaces, ...)" 2301785a24f1SPeiming Liu SmallVector<OpAsmParser::UnresolvedOperand> spaces; 2302785a24f1SPeiming Liu if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren)) 2303785a24f1SPeiming Liu return failure(); 2304785a24f1SPeiming Liu 2305c4420257SPeiming Liu SmallVector<OpAsmParser::Argument> coords; 2306c4420257SPeiming Liu if (failed(parseUsedCoordList(parser, state, coords))) 2307785a24f1SPeiming Liu return failure(); 2308c4420257SPeiming Liu size_t numCrds = coords.size(); 2309785a24f1SPeiming Liu 2310785a24f1SPeiming Liu // Parse "iter_args(%arg = %init, ...)" 2311785a24f1SPeiming Liu SmallVector<OpAsmParser::UnresolvedOperand> initArgs; 2312785a24f1SPeiming Liu bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); 2313785a24f1SPeiming Liu if (hasIterArgs) 2314785a24f1SPeiming Liu if (parser.parseAssignmentList(blockArgs, initArgs)) 2315785a24f1SPeiming Liu return failure(); 2316c4420257SPeiming Liu blockArgs.append(coords); 2317785a24f1SPeiming Liu 2318785a24f1SPeiming Liu SmallVector<Type> iterSpaceTps; 2319785a24f1SPeiming Liu // parse ": (sparse_tensor.iter_space, ...) -> ret" 2320785a24f1SPeiming Liu if (parser.parseColon() || parser.parseLParen() || 2321785a24f1SPeiming Liu parser.parseTypeList(iterSpaceTps) || parser.parseRParen()) 2322785a24f1SPeiming Liu return failure(); 2323785a24f1SPeiming Liu 2324785a24f1SPeiming Liu if (iterSpaceTps.size() != spaces.size()) 2325785a24f1SPeiming Liu return parser.emitError(parser.getNameLoc(), 2326785a24f1SPeiming Liu "mismatch in number of iteration space operands " 2327785a24f1SPeiming Liu "and iteration space types"); 2328785a24f1SPeiming Liu 2329785a24f1SPeiming Liu if (hasIterArgs) 2330785a24f1SPeiming Liu if (parser.parseArrowTypeList(state.types)) 2331785a24f1SPeiming Liu return failure(); 2332785a24f1SPeiming Liu 2333785a24f1SPeiming Liu // Resolves input sparse iteration spaces. 2334785a24f1SPeiming Liu if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), 2335785a24f1SPeiming Liu spacesVals)) 2336785a24f1SPeiming Liu return failure(); 2337785a24f1SPeiming Liu state.operands.append(spacesVals); 2338785a24f1SPeiming Liu 2339785a24f1SPeiming Liu if (hasIterArgs) { 2340c4420257SPeiming Liu // Strip off trailing args that used for coordinates. 2341c4420257SPeiming Liu MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); 2342e276cf08SPeiming Liu if (args.size() != initArgs.size() || args.size() != state.types.size()) { 2343e276cf08SPeiming Liu return parser.emitError( 2344e276cf08SPeiming Liu parser.getNameLoc(), 2345e276cf08SPeiming Liu "mismatch in number of iteration arguments and return values"); 2346e276cf08SPeiming Liu } 2347e276cf08SPeiming Liu 2348e276cf08SPeiming Liu for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { 2349e276cf08SPeiming Liu it.type = tp; 2350e276cf08SPeiming Liu if (parser.resolveOperand(init, tp, state.operands)) 2351e276cf08SPeiming Liu return failure(); 2352e276cf08SPeiming Liu } 2353e276cf08SPeiming Liu } 2354e276cf08SPeiming Liu return success(); 2355e276cf08SPeiming Liu } 2356e276cf08SPeiming Liu 2357481bd5d4SPeiming Liu LogicalResult ExtractIterSpaceOp::inferReturnTypes( 2358481bd5d4SPeiming Liu MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, 2359481bd5d4SPeiming Liu DictionaryAttr attr, OpaqueProperties prop, RegionRange region, 2360481bd5d4SPeiming Liu SmallVectorImpl<mlir::Type> &ret) { 2361481bd5d4SPeiming Liu 2362481bd5d4SPeiming Liu ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region); 2363481bd5d4SPeiming Liu SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); 2364481bd5d4SPeiming Liu ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(), 2365481bd5d4SPeiming Liu adaptor.getHiLvl())); 2366481bd5d4SPeiming Liu return success(); 2367481bd5d4SPeiming Liu } 2368481bd5d4SPeiming Liu 2369481bd5d4SPeiming Liu LogicalResult ExtractIterSpaceOp::verify() { 2370481bd5d4SPeiming Liu if (getLoLvl() >= getHiLvl()) 2371481bd5d4SPeiming Liu return emitOpError("expected smaller level low than level high"); 2372481bd5d4SPeiming Liu 2373481bd5d4SPeiming Liu TypedValue<IteratorType> pIter = getParentIter(); 2374481bd5d4SPeiming Liu if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) { 2375481bd5d4SPeiming Liu return emitOpError( 2376481bd5d4SPeiming Liu "parent iterator should be specified iff level lower bound equals 0"); 2377481bd5d4SPeiming Liu } 2378481bd5d4SPeiming Liu 2379481bd5d4SPeiming Liu if (pIter) { 2380e276cf08SPeiming Liu IterSpaceType spaceTp = getExtractedSpace().getType(); 2381481bd5d4SPeiming Liu if (pIter.getType().getEncoding() != spaceTp.getEncoding()) 2382481bd5d4SPeiming Liu return emitOpError( 2383481bd5d4SPeiming Liu "mismatch in parent iterator encoding and iteration space encoding."); 2384481bd5d4SPeiming Liu 2385481bd5d4SPeiming Liu if (spaceTp.getLoLvl() != pIter.getType().getHiLvl()) 2386481bd5d4SPeiming Liu return emitOpError("parent iterator should be used to extract an " 2387481bd5d4SPeiming Liu "iteration space from a consecutive level."); 2388481bd5d4SPeiming Liu } 2389481bd5d4SPeiming Liu 2390481bd5d4SPeiming Liu return success(); 2391481bd5d4SPeiming Liu } 2392481bd5d4SPeiming Liu 239312189f80SPeiming Liu LogicalResult ExtractValOp::verify() { 239412189f80SPeiming Liu auto stt = getSparseTensorType(getTensor()); 239512189f80SPeiming Liu auto itTp = getIterator().getType(); 239612189f80SPeiming Liu 239712189f80SPeiming Liu if (stt.getEncoding() != itTp.getEncoding()) 239812189f80SPeiming Liu return emitOpError("mismatch in tensor encoding and iterator encoding."); 239912189f80SPeiming Liu 240012189f80SPeiming Liu if (stt.getLvlRank() != itTp.getHiLvl()) 240112189f80SPeiming Liu return emitOpError("must use last-level iterator to extract values. "); 240212189f80SPeiming Liu 240312189f80SPeiming Liu return success(); 240412189f80SPeiming Liu } 240512189f80SPeiming Liu 2406a43d79afSPeiming Liu struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> { 2407a43d79afSPeiming Liu using OpRewritePattern::OpRewritePattern; 2408a43d79afSPeiming Liu 2409a43d79afSPeiming Liu LogicalResult matchAndRewrite(IterateOp iterateOp, 2410a43d79afSPeiming Liu PatternRewriter &rewriter) const override { 2411785a24f1SPeiming Liu I64BitSet newUsedLvls(0); 2412a43d79afSPeiming Liu llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments()); 2413a43d79afSPeiming Liu for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) { 2414a43d79afSPeiming Liu if (auto crd = iterateOp.getLvlCrd(i)) { 2415a43d79afSPeiming Liu if (crd->getUsers().empty()) 2416a43d79afSPeiming Liu toRemove.set(crd->getArgNumber()); 2417a43d79afSPeiming Liu else 2418a43d79afSPeiming Liu newUsedLvls.set(i); 2419a43d79afSPeiming Liu } 2420a43d79afSPeiming Liu } 2421a43d79afSPeiming Liu 2422a43d79afSPeiming Liu // All coordinates are used. 2423a43d79afSPeiming Liu if (toRemove.none()) 2424a43d79afSPeiming Liu return failure(); 2425a43d79afSPeiming Liu 2426a43d79afSPeiming Liu rewriter.startOpModification(iterateOp); 2427a43d79afSPeiming Liu iterateOp.setCrdUsedLvls(newUsedLvls); 2428a43d79afSPeiming Liu iterateOp.getBody()->eraseArguments(toRemove); 2429a43d79afSPeiming Liu rewriter.finalizeOpModification(iterateOp); 2430a43d79afSPeiming Liu return success(); 2431a43d79afSPeiming Liu } 2432a43d79afSPeiming Liu }; 2433a43d79afSPeiming Liu 2434a43d79afSPeiming Liu void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, 2435a43d79afSPeiming Liu mlir::MLIRContext *context) { 2436a43d79afSPeiming Liu results.add<RemoveUnusedLvlCrds>(context); 2437a43d79afSPeiming Liu } 2438a43d79afSPeiming Liu 2439a02010b3SPeiming Liu void IterateOp::build(OpBuilder &builder, OperationState &odsState, 2440a02010b3SPeiming Liu Value iterSpace, ValueRange initArgs) { 2441a02010b3SPeiming Liu unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim(); 2442a02010b3SPeiming Liu // All ones. 2443785a24f1SPeiming Liu I64BitSet set((1 << rank) - 1); 2444a02010b3SPeiming Liu return build(builder, odsState, iterSpace, initArgs, set); 2445a02010b3SPeiming Liu } 2446a02010b3SPeiming Liu 2447a02010b3SPeiming Liu void IterateOp::build(OpBuilder &builder, OperationState &odsState, 2448a02010b3SPeiming Liu Value iterSpace, ValueRange initArgs, 2449785a24f1SPeiming Liu I64BitSet crdUsedLvls) { 2450a02010b3SPeiming Liu OpBuilder::InsertionGuard guard(builder); 2451a02010b3SPeiming Liu 2452a02010b3SPeiming Liu odsState.addOperands(iterSpace); 2453a02010b3SPeiming Liu odsState.addOperands(initArgs); 2454a02010b3SPeiming Liu odsState.getOrAddProperties<Properties>().crdUsedLvls = 2455a02010b3SPeiming Liu builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls); 2456a02010b3SPeiming Liu Region *bodyRegion = odsState.addRegion(); 2457a02010b3SPeiming Liu odsState.addTypes(initArgs.getTypes()); 2458a02010b3SPeiming Liu Block *bodyBlock = builder.createBlock(bodyRegion); 2459a02010b3SPeiming Liu 2460b48ef8d8SPeiming Liu // Starts with a list of user-provided loop arguments. 2461b48ef8d8SPeiming Liu for (Value v : initArgs) 2462b48ef8d8SPeiming Liu bodyBlock->addArgument(v.getType(), v.getLoc()); 2463a02010b3SPeiming Liu 2464b48ef8d8SPeiming Liu // Follows by a list of used coordinates. 2465a02010b3SPeiming Liu for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++) 2466a02010b3SPeiming Liu bodyBlock->addArgument(builder.getIndexType(), odsState.location); 2467a02010b3SPeiming Liu 2468b48ef8d8SPeiming Liu // Ends with sparse iterator 2469b48ef8d8SPeiming Liu bodyBlock->addArgument( 2470b48ef8d8SPeiming Liu llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(), 2471b48ef8d8SPeiming Liu odsState.location); 2472a02010b3SPeiming Liu } 2473a02010b3SPeiming Liu 2474e276cf08SPeiming Liu ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { 2475e276cf08SPeiming Liu OpAsmParser::Argument iterator; 2476e276cf08SPeiming Liu OpAsmParser::UnresolvedOperand iterSpace; 2477e276cf08SPeiming Liu 2478e276cf08SPeiming Liu SmallVector<OpAsmParser::Argument> iters, iterArgs; 2479785a24f1SPeiming Liu if (parseSparseIterateLoop(parser, result, iters, iterArgs)) 2480e276cf08SPeiming Liu return failure(); 2481e276cf08SPeiming Liu if (iters.size() != 1) 2482e276cf08SPeiming Liu return parser.emitError(parser.getNameLoc(), 2483e276cf08SPeiming Liu "expected only one iterator/iteration space"); 2484e276cf08SPeiming Liu 2485b48ef8d8SPeiming Liu iterArgs.append(iters); 2486e276cf08SPeiming Liu Region *body = result.addRegion(); 2487b48ef8d8SPeiming Liu if (parser.parseRegion(*body, iterArgs)) 2488e276cf08SPeiming Liu return failure(); 2489e276cf08SPeiming Liu 2490e276cf08SPeiming Liu IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); 2491e276cf08SPeiming Liu 2492e276cf08SPeiming Liu // Parse the optional attribute list. 2493e276cf08SPeiming Liu if (parser.parseOptionalAttrDict(result.attributes)) 2494e276cf08SPeiming Liu return failure(); 2495e276cf08SPeiming Liu 2496e276cf08SPeiming Liu return success(); 2497e276cf08SPeiming Liu } 2498e276cf08SPeiming Liu 2499e276cf08SPeiming Liu /// Prints the initialization list in the form of 2500e276cf08SPeiming Liu /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>) 2501e276cf08SPeiming Liu /// where 'inner' values are assumed to be region arguments and 'outer' values 2502e276cf08SPeiming Liu /// are regular SSA values. 2503e276cf08SPeiming Liu static void printInitializationList(OpAsmPrinter &p, 2504e276cf08SPeiming Liu Block::BlockArgListType blocksArgs, 2505e276cf08SPeiming Liu ValueRange initializers, 2506e276cf08SPeiming Liu StringRef prefix = "") { 2507e276cf08SPeiming Liu assert(blocksArgs.size() == initializers.size() && 2508e276cf08SPeiming Liu "expected same length of arguments and initializers"); 2509e276cf08SPeiming Liu if (initializers.empty()) 2510e276cf08SPeiming Liu return; 2511e276cf08SPeiming Liu 2512e276cf08SPeiming Liu p << prefix << '('; 2513e276cf08SPeiming Liu llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) { 2514e276cf08SPeiming Liu p << std::get<0>(it) << " = " << std::get<1>(it); 2515e276cf08SPeiming Liu }); 2516e276cf08SPeiming Liu p << ")"; 2517e276cf08SPeiming Liu } 2518e276cf08SPeiming Liu 2519785a24f1SPeiming Liu template <typename SparseLoopOp> 2520785a24f1SPeiming Liu static LogicalResult verifySparseLoopOp(SparseLoopOp op) { 2521785a24f1SPeiming Liu if (op.getInitArgs().size() != op.getNumResults()) { 2522785a24f1SPeiming Liu return op.emitOpError( 2523785a24f1SPeiming Liu "mismatch in number of loop-carried values and defined values"); 2524785a24f1SPeiming Liu } 2525785a24f1SPeiming Liu if (op.getCrdUsedLvls().max() > op.getSpaceDim()) 2526785a24f1SPeiming Liu return op.emitOpError("required out-of-bound coordinates"); 2527e276cf08SPeiming Liu 2528785a24f1SPeiming Liu return success(); 2529e276cf08SPeiming Liu } 2530785a24f1SPeiming Liu 2531785a24f1SPeiming Liu LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); } 2532785a24f1SPeiming Liu LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); } 2533e276cf08SPeiming Liu 2534e276cf08SPeiming Liu void IterateOp::print(OpAsmPrinter &p) { 2535e276cf08SPeiming Liu p << " " << getIterator() << " in " << getIterSpace(); 2536785a24f1SPeiming Liu if (!getCrdUsedLvls().empty()) { 2537785a24f1SPeiming Liu p << " at("; 2538785a24f1SPeiming Liu printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls()); 2539785a24f1SPeiming Liu p << ")"; 2540785a24f1SPeiming Liu } 2541e276cf08SPeiming Liu printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args"); 2542e276cf08SPeiming Liu 2543e276cf08SPeiming Liu p << " : " << getIterSpace().getType() << " "; 2544e276cf08SPeiming Liu if (!getInitArgs().empty()) 2545785a24f1SPeiming Liu p.printArrowTypeList(getInitArgs().getTypes()); 2546e276cf08SPeiming Liu 2547785a24f1SPeiming Liu p << " "; 2548e276cf08SPeiming Liu p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 2549e276cf08SPeiming Liu /*printBlockTerminators=*/!getInitArgs().empty()); 2550e276cf08SPeiming Liu } 2551e276cf08SPeiming Liu 2552e276cf08SPeiming Liu LogicalResult IterateOp::verifyRegions() { 2553e276cf08SPeiming Liu if (getIterator().getType() != getIterSpace().getType().getIteratorType()) 2554e276cf08SPeiming Liu return emitOpError("mismatch in iterator and iteration space type"); 2555e276cf08SPeiming Liu if (getNumRegionIterArgs() != getNumResults()) 2556e276cf08SPeiming Liu return emitOpError( 2557e276cf08SPeiming Liu "mismatch in number of basic block args and defined values"); 2558e276cf08SPeiming Liu 2559e276cf08SPeiming Liu auto initArgs = getInitArgs(); 2560e276cf08SPeiming Liu auto iterArgs = getRegionIterArgs(); 2561e276cf08SPeiming Liu auto yieldVals = getYieldedValues(); 2562e276cf08SPeiming Liu auto opResults = getResults(); 2563e276cf08SPeiming Liu if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), 2564e276cf08SPeiming Liu opResults.size()})) { 2565e276cf08SPeiming Liu return emitOpError() << "number mismatch between iter args and results."; 2566e276cf08SPeiming Liu } 2567e276cf08SPeiming Liu 2568e276cf08SPeiming Liu for (auto [i, init, iter, yield, ret] : 2569e276cf08SPeiming Liu llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { 2570e276cf08SPeiming Liu if (init.getType() != ret.getType()) 2571e276cf08SPeiming Liu return emitOpError() << "types mismatch between " << i 2572e276cf08SPeiming Liu << "th iter operand and defined value"; 2573e276cf08SPeiming Liu if (iter.getType() != ret.getType()) 2574e276cf08SPeiming Liu return emitOpError() << "types mismatch between " << i 2575e276cf08SPeiming Liu << "th iter region arg and defined value"; 2576e276cf08SPeiming Liu if (yield.getType() != ret.getType()) 2577e276cf08SPeiming Liu return emitOpError() << "types mismatch between " << i 2578e276cf08SPeiming Liu << "th yield value and defined value"; 2579e276cf08SPeiming Liu } 2580e276cf08SPeiming Liu 2581e276cf08SPeiming Liu return success(); 2582e276cf08SPeiming Liu } 2583e276cf08SPeiming Liu 2584e276cf08SPeiming Liu /// OpInterfaces' methods implemented by IterateOp. 2585e276cf08SPeiming Liu SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; } 2586e276cf08SPeiming Liu 2587e276cf08SPeiming Liu MutableArrayRef<OpOperand> IterateOp::getInitsMutable() { 2588e276cf08SPeiming Liu return getInitArgsMutable(); 2589e276cf08SPeiming Liu } 2590e276cf08SPeiming Liu 2591e276cf08SPeiming Liu Block::BlockArgListType IterateOp::getRegionIterArgs() { 2592b48ef8d8SPeiming Liu return getRegion().getArguments().take_front(getNumRegionIterArgs()); 2593e276cf08SPeiming Liu } 2594e276cf08SPeiming Liu 2595e276cf08SPeiming Liu std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() { 2596e276cf08SPeiming Liu return cast<sparse_tensor::YieldOp>( 2597e276cf08SPeiming Liu getRegion().getBlocks().front().getTerminator()) 2598e276cf08SPeiming Liu .getResultsMutable(); 2599e276cf08SPeiming Liu } 2600e276cf08SPeiming Liu 2601e276cf08SPeiming Liu std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); } 2602e276cf08SPeiming Liu 2603e276cf08SPeiming Liu OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { 2604e276cf08SPeiming Liu return getInitArgs(); 2605e276cf08SPeiming Liu } 2606e276cf08SPeiming Liu 2607e276cf08SPeiming Liu void IterateOp::getSuccessorRegions(RegionBranchPoint point, 2608e276cf08SPeiming Liu SmallVectorImpl<RegionSuccessor> ®ions) { 2609785a24f1SPeiming Liu // Both the operation itself and the region may be branching into the body 2610785a24f1SPeiming Liu // or back into the operation itself. 2611e276cf08SPeiming Liu regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); 2612e276cf08SPeiming Liu // It is possible for loop not to enter the body. 2613e276cf08SPeiming Liu regions.push_back(RegionSuccessor(getResults())); 2614e276cf08SPeiming Liu } 2615e276cf08SPeiming Liu 2616c4420257SPeiming Liu void CoIterateOp::build(OpBuilder &builder, OperationState &odsState, 2617c4420257SPeiming Liu ValueRange iterSpaces, ValueRange initArgs, 2618c4420257SPeiming Liu unsigned numCases) { 2619c4420257SPeiming Liu unsigned rank = 2620c4420257SPeiming Liu cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim(); 2621c4420257SPeiming Liu // All ones. 2622c4420257SPeiming Liu I64BitSet set((1 << rank) - 1); 2623c4420257SPeiming Liu // Generates all-zero case bits (they only serve as placeholders), which are 2624c4420257SPeiming Liu // supposed to be overriden later. We need to preallocate all the regions as 2625c4420257SPeiming Liu // mlir::Region cannot be dynamically added later after the operation is 2626c4420257SPeiming Liu // created. 2627c4420257SPeiming Liu SmallVector<int64_t> caseBits(numCases, 0); 2628c4420257SPeiming Liu ArrayAttr cases = builder.getI64ArrayAttr(caseBits); 2629c4420257SPeiming Liu return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces, 2630c4420257SPeiming Liu initArgs, set, cases, 2631c4420257SPeiming Liu /*caseRegionsCount=*/numCases); 2632c4420257SPeiming Liu } 2633c4420257SPeiming Liu 2634785a24f1SPeiming Liu ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) { 2635785a24f1SPeiming Liu 2636785a24f1SPeiming Liu SmallVector<Value> spaces; 2637785a24f1SPeiming Liu // The block argument list of each regions, it is arranged in the order of 2638785a24f1SPeiming Liu // ([used coordinate list], [loop iterations args], [sparse iterator list]). 2639785a24f1SPeiming Liu SmallVector<OpAsmParser::Argument> blockArgs; 2640785a24f1SPeiming Liu if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs)) 2641785a24f1SPeiming Liu return failure(); 2642785a24f1SPeiming Liu 2643785a24f1SPeiming Liu result.addAttribute("operandSegmentSizes", 2644785a24f1SPeiming Liu parser.getBuilder().getDenseI32ArrayAttr( 2645785a24f1SPeiming Liu {static_cast<int32_t>(spaces.size()), 2646785a24f1SPeiming Liu static_cast<int32_t>(result.types.size())})); 2647785a24f1SPeiming Liu 2648785a24f1SPeiming Liu SmallVector<Attribute> cases; 2649785a24f1SPeiming Liu while (succeeded(parser.parseOptionalKeyword("case"))) { 2650785a24f1SPeiming Liu // Parse one region per case. 2651785a24f1SPeiming Liu I64BitSet definedItSet; 2652785a24f1SPeiming Liu SmallVector<OpAsmParser::Argument> definedIts; 2653785a24f1SPeiming Liu if (parseOptionalDefinedList(parser, result, definedItSet, definedIts, 2654785a24f1SPeiming Liu spaces.size(), OpAsmParser::Delimiter::None)) 2655785a24f1SPeiming Liu return failure(); 2656785a24f1SPeiming Liu 2657785a24f1SPeiming Liu cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet)); 2658785a24f1SPeiming Liu 2659785a24f1SPeiming Liu for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) { 2660785a24f1SPeiming Liu // Resolve the iterator type based on the iteration space type. 2661785a24f1SPeiming Liu auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType()); 2662785a24f1SPeiming Liu definedIts[i].type = spaceTp.getIteratorType(); 2663785a24f1SPeiming Liu } 2664785a24f1SPeiming Liu definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end()); 2665785a24f1SPeiming Liu Region *body = result.addRegion(); 2666785a24f1SPeiming Liu if (parser.parseRegion(*body, definedIts)) 2667785a24f1SPeiming Liu return failure(); 2668785a24f1SPeiming Liu 2669785a24f1SPeiming Liu CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); 2670785a24f1SPeiming Liu } 2671785a24f1SPeiming Liu 2672785a24f1SPeiming Liu result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases)); 2673785a24f1SPeiming Liu 2674785a24f1SPeiming Liu // Parse the optional attribute list. 2675785a24f1SPeiming Liu if (parser.parseOptionalAttrDict(result.attributes)) 2676785a24f1SPeiming Liu return failure(); 2677785a24f1SPeiming Liu 2678785a24f1SPeiming Liu return success(); 2679785a24f1SPeiming Liu } 2680785a24f1SPeiming Liu 2681785a24f1SPeiming Liu void CoIterateOp::print(OpAsmPrinter &p) { 2682785a24f1SPeiming Liu p << " ("; 2683785a24f1SPeiming Liu llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; }); 2684785a24f1SPeiming Liu p << ")"; 2685785a24f1SPeiming Liu 2686785a24f1SPeiming Liu if (!getCrdUsedLvls().empty()) { 2687785a24f1SPeiming Liu p << " at("; 2688785a24f1SPeiming Liu printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls()); 2689785a24f1SPeiming Liu p << ")"; 2690785a24f1SPeiming Liu } 2691785a24f1SPeiming Liu 2692785a24f1SPeiming Liu printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args"); 2693785a24f1SPeiming Liu 2694785a24f1SPeiming Liu p << " : (" << getIterSpaces().getTypes() << ")"; 2695785a24f1SPeiming Liu if (!getInitArgs().empty()) 2696785a24f1SPeiming Liu p.printArrowTypeList(getInitArgs().getTypes()); 2697785a24f1SPeiming Liu 2698785a24f1SPeiming Liu for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) { 2699785a24f1SPeiming Liu p.printNewline(); 2700785a24f1SPeiming Liu p << "case "; 2701785a24f1SPeiming Liu printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx), 2702785a24f1SPeiming Liu getRegionDefinedSpace(idx)); 2703785a24f1SPeiming Liu p << " "; 2704785a24f1SPeiming Liu p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false, 2705785a24f1SPeiming Liu /*printBlockTerminators=*/!getInitArgs().empty()); 2706785a24f1SPeiming Liu } 2707785a24f1SPeiming Liu } 2708785a24f1SPeiming Liu 2709785a24f1SPeiming Liu ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) { 2710785a24f1SPeiming Liu return cast<sparse_tensor::YieldOp>( 2711785a24f1SPeiming Liu getRegion(regionIdx).getBlocks().front().getTerminator()) 2712785a24f1SPeiming Liu .getResults(); 2713785a24f1SPeiming Liu } 2714785a24f1SPeiming Liu 2715785a24f1SPeiming Liu LogicalResult CoIterateOp::verifyRegions() { 2716785a24f1SPeiming Liu for (unsigned r = 0, e = getNumRegions(); r < e; r++) { 2717c4420257SPeiming Liu if (getNumRegionIterArgs() != getNumResults()) 2718785a24f1SPeiming Liu return emitOpError( 2719785a24f1SPeiming Liu "mismatch in number of basic block args and defined values"); 2720785a24f1SPeiming Liu 2721785a24f1SPeiming Liu auto initArgs = getInitArgs(); 2722785a24f1SPeiming Liu auto iterArgs = getRegionIterArgs(r); 2723785a24f1SPeiming Liu auto yieldVals = getYieldedValues(r); 2724785a24f1SPeiming Liu auto opResults = getResults(); 2725785a24f1SPeiming Liu if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), 2726785a24f1SPeiming Liu opResults.size()})) { 2727785a24f1SPeiming Liu return emitOpError() 2728785a24f1SPeiming Liu << "number mismatch between iter args and results on " << r 2729785a24f1SPeiming Liu << "th region"; 2730785a24f1SPeiming Liu } 2731785a24f1SPeiming Liu 2732785a24f1SPeiming Liu for (auto [i, init, iter, yield, ret] : 2733785a24f1SPeiming Liu llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { 2734785a24f1SPeiming Liu if (init.getType() != ret.getType()) 2735785a24f1SPeiming Liu return emitOpError() 2736785a24f1SPeiming Liu << "types mismatch between " << i 2737785a24f1SPeiming Liu << "th iter operand and defined value on " << r << "th region"; 2738785a24f1SPeiming Liu if (iter.getType() != ret.getType()) 2739785a24f1SPeiming Liu return emitOpError() << "types mismatch between " << i 2740785a24f1SPeiming Liu << "th iter region arg and defined value on " << r 2741785a24f1SPeiming Liu << "th region"; 2742785a24f1SPeiming Liu if (yield.getType() != ret.getType()) 2743785a24f1SPeiming Liu return emitOpError() 2744785a24f1SPeiming Liu << "types mismatch between " << i 2745785a24f1SPeiming Liu << "th yield value and defined value on " << r << "th region"; 2746785a24f1SPeiming Liu } 2747785a24f1SPeiming Liu } 2748785a24f1SPeiming Liu 2749785a24f1SPeiming Liu auto cases = getRegionDefinedSpaces(); 2750785a24f1SPeiming Liu llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end()); 2751785a24f1SPeiming Liu if (set.size() != getNumRegions()) 2752785a24f1SPeiming Liu return emitOpError("contains duplicated cases."); 2753785a24f1SPeiming Liu 2754785a24f1SPeiming Liu return success(); 2755785a24f1SPeiming Liu } 2756785a24f1SPeiming Liu 2757f607102aSPeiming Liu SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) { 2758f607102aSPeiming Liu SmallVector<Region *> ret; 2759f607102aSPeiming Liu I64BitSet caseBit = getRegionDefinedSpace(regionIdx); 2760f607102aSPeiming Liu for (Region &r : getCaseRegions()) 2761f607102aSPeiming Liu if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit)) 2762f607102aSPeiming Liu ret.push_back(&r); 2763f607102aSPeiming Liu 2764f607102aSPeiming Liu return ret; 2765f607102aSPeiming Liu } 2766f607102aSPeiming Liu 2767e276cf08SPeiming Liu //===----------------------------------------------------------------------===// 2768e276cf08SPeiming Liu // Sparse Tensor Dialect Setups. 2769e276cf08SPeiming Liu //===----------------------------------------------------------------------===// 2770e276cf08SPeiming Liu 2771f0f5fdf7SPeiming Liu /// Materialize a single constant operation from a given attribute value with 2772f0f5fdf7SPeiming Liu /// the desired resultant type. 2773f0f5fdf7SPeiming Liu Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder, 2774f0f5fdf7SPeiming Liu Attribute value, Type type, 2775f0f5fdf7SPeiming Liu Location loc) { 2776f0f5fdf7SPeiming Liu if (auto op = arith::ConstantOp::materialize(builder, value, type, loc)) 2777f0f5fdf7SPeiming Liu return op; 2778f0f5fdf7SPeiming Liu return nullptr; 2779f0f5fdf7SPeiming Liu } 2780f0f5fdf7SPeiming Liu 2781c5a67e16SYinying Li namespace { 2782c5a67e16SYinying Li struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface { 2783c5a67e16SYinying Li using OpAsmDialectInterface::OpAsmDialectInterface; 2784c5a67e16SYinying Li 2785c5a67e16SYinying Li AliasResult getAlias(Attribute attr, raw_ostream &os) const override { 2786a5757c5bSChristian Sigg if (isa<SparseTensorEncodingAttr>(attr)) { 2787c5a67e16SYinying Li os << "sparse"; 2788c5a67e16SYinying Li return AliasResult::OverridableAlias; 2789c5a67e16SYinying Li } 2790c5a67e16SYinying Li return AliasResult::NoAlias; 2791c5a67e16SYinying Li } 2792c5a67e16SYinying Li }; 2793c5a67e16SYinying Li } // namespace 2794c5a67e16SYinying Li 2795319072f4SAart Bik void SparseTensorDialect::initialize() { 2796c5a67e16SYinying Li addInterface<SparseTensorAsmDialectInterface>(); 27970a292199SAart Bik addAttributes< 27980a292199SAart Bik #define GET_ATTRDEF_LIST 27990a292199SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 28000a292199SAart Bik >(); 280171cc0f1cSPeiming Liu addTypes< 280271cc0f1cSPeiming Liu #define GET_TYPEDEF_LIST 280371cc0f1cSPeiming Liu #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" 280471cc0f1cSPeiming Liu >(); 2805319072f4SAart Bik addOperations< 2806319072f4SAart Bik #define GET_OP_LIST 2807319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 2808319072f4SAart Bik >(); 2809513cdb82SJustin Fargnoli declarePromisedInterfaces< 2810513cdb82SJustin Fargnoli bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp, 2811513cdb82SJustin Fargnoli NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp, 2812513cdb82SJustin Fargnoli ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>(); 2813319072f4SAart Bik } 2814319072f4SAart Bik 2815319072f4SAart Bik #define GET_OP_CLASSES 2816319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 28175517208dSAart Bik 28185517208dSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 2819