//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/Bitset.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

// Forward declarations: the following custom print/parse methods are
// referenced by the generated code for SparseTensorTypes.td.
static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
                                         mlir::sparse_tensor::Level &,
                                         mlir::sparse_tensor::Level &);
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
                            mlir::sparse_tensor::Level);

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed
// as well.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

static SmallVector<Size>
getSparseFieldShape(const SparseTensorEncodingAttr enc,
                    std::optional<ArrayRef<Size>> dimShape) {
  assert(enc);
  // With only the encoding, we cannot determine the static shape of leading
  // batch levels; we therefore return a dynamically shaped memref instead.
  SmallVector<Size> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
  if (dimShape.has_value()) {
    // If the actual tensor shape is provided, we can then refine the leading
    // batch dimension.
    SmallVector<Size> lvlShape =
        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
    memrefShape.assign(lvlShape.begin(),
                       lvlShape.begin() + enc.getBatchLvlRank());
  }
  // Another dynamic dimension to store the sparse level.
  memrefShape.push_back(ShapedType::kDynamic);
  return memrefShape;
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef<COOSegment> cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO, all singletons are fused into one memref. Skip the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO, each singleton level has one memref.
        l++;
      }
      // Expire handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non-COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 LevelFormat::Undef)))
    return;

  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
}

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  assert(stt.hasEncoding());

  SmallVector<Size> memrefShape =
      getSparseFieldShape(stt.getEncoding(), stt.getDimShape());

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos>  positions
  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
  // memref<[batch] x ? x crd>  coordinates
  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
  // memref<[batch] x ? x eltType>  values
  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());

  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
    switch (fieldKind) {
    case SparseTensorFieldKind::StorageSpec:
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::PosMemRef:
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::CrdMemRef:
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::ValMemRef:
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
    };
    llvm_unreachable("unrecognized field kind");
  });
}

unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}

std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = enc.getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Returns false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair(fieldIdx, stride);
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?"
: std::to_string(v); } void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const { assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr"); os << '('; os << getStaticString(getOffset()); os << ", "; os << getStaticString(getSize()); os << ", "; os << getStaticString(getStride()); os << ')'; } void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { print(printer.getStream()); } static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser) { auto parseResult = parser.parseOptionalInteger(result); if (parseResult.has_value()) { if (parseResult.value().succeeded() && result < 0) { parser.emitError( parser.getCurrentLocation(), "expect positive value or ? for slice offset/size/stride"); return failure(); } return parseResult.value(); } // Else, and '?' which represented dynamic slice result = SparseTensorDimSliceAttr::kDynamic; return parser.parseQuestion(); } Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { int64_t offset = kDynamic, size = kDynamic, stride = kDynamic; if (failed(parser.parseLParen()) || failed(parseOptionalStaticSlice(offset, parser)) || failed(parser.parseComma()) || failed(parseOptionalStaticSlice(size, parser)) || failed(parser.parseComma()) || failed(parseOptionalStaticSlice(stride, parser)) || failed(parser.parseRParen())) return {}; return parser.getChecked(parser.getContext(), offset, size, stride); } LogicalResult SparseTensorDimSliceAttr::verify(function_ref emitError, int64_t offset, int64_t size, int64_t stride) { if (!isDynamic(offset) && offset < 0) return emitError() << "expect non-negative value or ? for slice offset"; if (!isDynamic(size) && size <= 0) return emitError() << "expect positive value or ? for slice size"; if (!isDynamic(stride) && stride <= 0) return emitError() << "expect positive value or ? for slice stride"; return success(); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const { assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); return SparseTensorEncodingAttr::get( getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(), getCrdWidth(), getExplicitVal(), getImplicitVal()); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const { return withDimToLvl(enc ? 
enc.getDimToLvl() : AffineMap()); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { return withDimToLvl(AffineMap()); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withBitWidths(unsigned posWidth, unsigned crdWidth) const { assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); return SparseTensorEncodingAttr::get( getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth, crdWidth, getExplicitVal(), getImplicitVal()); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { return withBitWidths(0, 0); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const { assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); return SparseTensorEncodingAttr::get( getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), getCrdWidth(), explicitVal, getImplicitVal()); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const { return withExplicitVal(Attribute()); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const { assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); return SparseTensorEncodingAttr::get( getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), getCrdWidth(), getExplicitVal(), implicitVal); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const { return withImplicitVal(Attribute()); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( ArrayRef dimSlices) const { return SparseTensorEncodingAttr::get( getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices); } SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { return withDimSlices(ArrayRef{}); } uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { ArrayRef lvlTypes = getLvlTypes(); auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); return std::distance(lastBatch, lvlTypes.rend()); } bool SparseTensorEncodingAttr::isAllDense() const { return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT); } bool SparseTensorEncodingAttr::isAllOrdered() const { return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT); } Type SparseTensorEncodingAttr::getCrdElemType() const { if (!getImpl()) return nullptr; if (getCrdWidth()) return IntegerType::get(getContext(), getCrdWidth()); return IndexType::get(getContext()); } Type SparseTensorEncodingAttr::getPosElemType() const { if (!getImpl()) return nullptr; if (getPosWidth()) return IntegerType::get(getContext(), getPosWidth()); return IndexType::get(getContext()); } MemRefType SparseTensorEncodingAttr::getCrdMemRefType( std::optional> dimShape) const { SmallVector shape = getSparseFieldShape(*this, dimShape); return MemRefType::get(shape, getCrdElemType()); } MemRefType SparseTensorEncodingAttr::getPosMemRefType( std::optional> dimShape) const { SmallVector shape = getSparseFieldShape(*this, dimShape); return MemRefType::get(shape, getPosElemType()); } bool SparseTensorEncodingAttr::isIdentity() const { return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity(); } bool SparseTensorEncodingAttr::isPermutation() const { return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation(); } Dimension SparseTensorEncodingAttr::getDimRank() const { assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); const auto dimToLvl = getDimToLvl(); return dimToLvl ? 
dimToLvl.getNumDims() : getLvlRank(); } Level SparseTensorEncodingAttr::getLvlRank() const { assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); return getLvlTypes().size(); } LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { if (!getImpl()) return LevelFormat::Batch; assert(l < getLvlRank() && "Level is out of bounds"); return getLvlTypes()[l]; } bool SparseTensorEncodingAttr::isSlice() const { assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); return !getDimSlices().empty(); } SparseTensorDimSliceAttr SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { assert(isSlice() && "Is not a slice"); const auto dimSlices = getDimSlices(); assert(dim < dimSlices.size() && "Dimension is out of bounds"); return dimSlices[dim]; } std::optional SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { return getDimSlice(dim).getStaticOffset(); } std::optional SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { return getDimSlice(dim).getStaticStride(); } std::optional SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { return getStaticDimSliceOffset(toDim(*this, lvl)); } std::optional SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { return getStaticDimSliceStride(toDim(*this, lvl)); } SmallVector SparseTensorEncodingAttr::translateShape(ArrayRef srcShape, CrdTransDirectionKind dir) const { if (isIdentity()) return SmallVector(srcShape); SmallVector ret; unsigned rank = dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank(); ret.reserve(rank); if (isPermutation()) { for (unsigned r = 0; r < rank; r++) { unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) : toLvl(*this, r); ret.push_back(srcShape[trans]); } return ret; } // Handle non-permutation maps. AffineMap transMap = dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim(); SmallVector dimRep; dimRep.reserve(srcShape.size()); for (int64_t sz : srcShape) { if (!ShapedType::isDynamic(sz)) { // Push back the max coordinate for the given dimension/level size. dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); } else { // A dynamic size, use a AffineDimExpr to symbolize the value. dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext())); } }; for (AffineExpr exp : transMap.getResults()) { // Do constant propagation on the affine map. AffineExpr evalExp = simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0); // use llvm namespace here to avoid ambiguity if (auto c = llvm::dyn_cast(evalExp)) { ret.push_back(c.getValue() + 1); } else { if (auto mod = llvm::dyn_cast(evalExp); mod && mod.getKind() == AffineExprKind::Mod) { // We can still infer a static bound for expressions in form // "d % constant" since d % constant \in [0, constant). if (auto bound = llvm::dyn_cast(mod.getRHS())) { ret.push_back(bound.getValue()); continue; } } ret.push_back(ShapedType::kDynamic); } } assert(ret.size() == rank); return ret; } ValueRange SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, ValueRange crds, CrdTransDirectionKind dir) const { if (!getImpl()) return crds; SmallVector retType( dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(), builder.getIndexType()); auto transOp = builder.create(loc, retType, crds, dir, *this); return transOp.getOutCrds(); } Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { // Open "<{" part. 
if (failed(parser.parseLess())) return {}; if (failed(parser.parseLBrace())) return {}; // Process the data from the parsed dictionary value into struct-like data. SmallVector lvlTypes; SmallVector dimSlices; AffineMap dimToLvl = {}; AffineMap lvlToDim = {}; unsigned posWidth = 0; unsigned crdWidth = 0; Attribute explicitVal; Attribute implicitVal; StringRef attrName; SmallVector keys = {"map", "posWidth", "crdWidth", "explicitVal", "implicitVal"}; while (succeeded(parser.parseOptionalKeyword(&attrName))) { // Detect admissible keyword. auto *it = find(keys, attrName); if (it == keys.end()) { parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; return {}; } unsigned keyWordIndex = it - keys.begin(); // Consume the `=` after keys if (failed(parser.parseEqual())) return {}; // Dispatch on keyword. switch (keyWordIndex) { case 0: { // map ir_detail::DimLvlMapParser cParser(parser); auto res = cParser.parseDimLvlMap(); if (failed(res)) return {}; const auto &dlm = *res; const Level lvlRank = dlm.getLvlRank(); for (Level lvl = 0; lvl < lvlRank; lvl++) lvlTypes.push_back(dlm.getLvlType(lvl)); const Dimension dimRank = dlm.getDimRank(); for (Dimension dim = 0; dim < dimRank; dim++) dimSlices.push_back(dlm.getDimSlice(dim)); // NOTE: the old syntax requires an all-or-nothing approach to // `dimSlices`; therefore, if any slice actually exists then we need // to convert null-DSA into default/nop DSA. const auto isDefined = [](SparseTensorDimSliceAttr slice) { return static_cast(slice.getImpl()); }; if (llvm::any_of(dimSlices, isDefined)) { const auto defaultSlice = SparseTensorDimSliceAttr::get(parser.getContext()); for (Dimension dim = 0; dim < dimRank; dim++) if (!isDefined(dimSlices[dim])) dimSlices[dim] = defaultSlice; } else { dimSlices.clear(); } dimToLvl = dlm.getDimToLvlMap(parser.getContext()); lvlToDim = dlm.getLvlToDimMap(parser.getContext()); break; } case 1: { // posWidth Attribute attr; if (failed(parser.parseAttribute(attr))) return {}; auto intAttr = llvm::dyn_cast(attr); if (!intAttr) { parser.emitError(parser.getNameLoc(), "expected an integral position bitwidth"); return {}; } posWidth = intAttr.getInt(); break; } case 2: { // crdWidth Attribute attr; if (failed(parser.parseAttribute(attr))) return {}; auto intAttr = llvm::dyn_cast(attr); if (!intAttr) { parser.emitError(parser.getNameLoc(), "expected an integral index bitwidth"); return {}; } crdWidth = intAttr.getInt(); break; } case 3: { // explicitVal Attribute attr; if (failed(parser.parseAttribute(attr))) return {}; if (auto result = llvm::dyn_cast(attr)) { explicitVal = result; } else if (auto result = llvm::dyn_cast(attr)) { explicitVal = result; } else if (auto result = llvm::dyn_cast(attr)) { explicitVal = result; } else { parser.emitError(parser.getNameLoc(), "expected a numeric value for explicitVal"); return {}; } break; } case 4: { // implicitVal Attribute attr; if (failed(parser.parseAttribute(attr))) return {}; if (auto result = llvm::dyn_cast(attr)) { implicitVal = result; } else if (auto result = llvm::dyn_cast(attr)) { implicitVal = result; } else if (auto result = llvm::dyn_cast(attr)) { implicitVal = result; } else { parser.emitError(parser.getNameLoc(), "expected a numeric value for implicitVal"); return {}; } break; } } // switch // Only last item can omit the comma. if (parser.parseOptionalComma().failed()) break; } // Close "}>" part. if (failed(parser.parseRBrace())) return {}; if (failed(parser.parseGreater())) return {}; // Construct struct-like storage for attribute. 
if (!lvlToDim || lvlToDim.isEmpty()) { lvlToDim = inferLvlToDim(dimToLvl, parser.getContext()); } return parser.getChecked( parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, explicitVal, implicitVal, dimSlices); } void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { auto map = static_cast(getDimToLvl()); // Empty affine map indicates identity map if (!map) map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext()); printer << "<{ map = "; printSymbols(map, printer); printer << '('; printDimensions(map, printer, getDimSlices()); printer << ") -> ("; printLevels(map, printer, getLvlTypes()); printer << ')'; // Print remaining members only for non-default values. if (getPosWidth()) printer << ", posWidth = " << getPosWidth(); if (getCrdWidth()) printer << ", crdWidth = " << getCrdWidth(); if (getExplicitVal()) { printer << ", explicitVal = " << getExplicitVal(); } if (getImplicitVal()) printer << ", implicitVal = " << getImplicitVal(); printer << " }>"; } void SparseTensorEncodingAttr::printSymbols(AffineMap &map, AsmPrinter &printer) const { if (map.getNumSymbols() == 0) return; printer << '['; for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++) printer << 's' << i << ", "; if (map.getNumSymbols() >= 1) printer << 's' << map.getNumSymbols() - 1; printer << ']'; } void SparseTensorEncodingAttr::printDimensions( AffineMap &map, AsmPrinter &printer, ArrayRef dimSlices) const { if (!dimSlices.empty()) { for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) printer << 'd' << i << " : " << dimSlices[i] << ", "; if (map.getNumDims() >= 1) { printer << 'd' << map.getNumDims() - 1 << " : " << dimSlices[map.getNumDims() - 1]; } } else { for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) printer << 'd' << i << ", "; if (map.getNumDims() >= 1) printer << 'd' << map.getNumDims() - 1; } } void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef lvlTypes) const { for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { map.getResult(i).print(printer.getStream()); printer << " : " << toMLIRString(lvlTypes[i]) << ", "; } if (map.getNumResults() >= 1) { auto lastIndex = map.getNumResults() - 1; map.getResult(lastIndex).print(printer.getStream()); printer << " : " << toMLIRString(lvlTypes[lastIndex]); } } LogicalResult SparseTensorEncodingAttr::verify( function_ref emitError, ArrayRef lvlTypes, AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth, unsigned crdWidth, Attribute explicitVal, Attribute implicitVal, ArrayRef dimSlices) { if (!acceptBitWidth(posWidth)) return emitError() << "unexpected position bitwidth: " << posWidth; if (!acceptBitWidth(crdWidth)) return emitError() << "unexpected coordinate bitwidth: " << crdWidth; // Verify every COO segment. auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); while (it != lvlTypes.end()) { if (it == lvlTypes.begin() || !(it - 1)->isa()) return emitError() << "expected compressed or loose_compressed level " "before singleton level"; auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT); if (!std::all_of(it, curCOOEnd, [](LevelType i) { return isSingletonLT(i); })) return emitError() << "expected all singleton lvlTypes " "following a singleton level"; // We can potentially support mixed SoA/AoS singleton levels. 
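    // Until then, all singleton levels within one COO segment must agree on a
    // single storage layout.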
    if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
    it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
  }

  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // SoA property can only be applied on singleton level.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by backend.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of Affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }

  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";

  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
if (dimRank != lvlRank) return emitError() << "dimSlices expected dimension-rank to match level-rank: " << dimRank << " != " << lvlRank; } return success(); } LogicalResult SparseTensorEncodingAttr::verifyEncoding( ArrayRef dimShape, Type elementType, function_ref emitError) const { // Check structural integrity. In particular, this ensures that the // level-rank is coherent across all the fields. if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), getCrdWidth(), getExplicitVal(), getImplicitVal(), getDimSlices()))) return failure(); // Check integrity with tensor type specifics. In particular, we // need only check that the dimension-rank of the tensor agrees with // the dimension-rank of the encoding. const Dimension dimRank = dimShape.size(); if (dimRank == 0) return emitError() << "expected non-scalar sparse tensor"; if (getDimRank() != dimRank) return emitError() << "dimension-rank mismatch between encoding and tensor shape: " << getDimRank() << " != " << dimRank; if (auto expVal = getExplicitVal()) { Type attrType = llvm::dyn_cast(expVal).getType(); if (attrType != elementType) { return emitError() << "explicit value type mismatch between encoding and " << "tensor element type: " << attrType << " != " << elementType; } } if (auto impVal = getImplicitVal()) { Type attrType = llvm::dyn_cast(impVal).getType(); if (attrType != elementType) { return emitError() << "implicit value type mismatch between encoding and " << "tensor element type: " << attrType << " != " << elementType; } // Currently, we only support zero as the implicit value. auto impFVal = llvm::dyn_cast(impVal); auto impIntVal = llvm::dyn_cast(impVal); auto impComplexVal = llvm::dyn_cast(impVal); if ((impFVal && impFVal.getValue().isNonZero()) || (impIntVal && !impIntVal.getValue().isZero()) || (impComplexVal && (impComplexVal.getImag().isNonZero() || impComplexVal.getReal().isNonZero()))) { return emitError() << "implicit value must be zero"; } } return success(); } Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const { SmallVector coo = getCOOSegments(); assert(coo.size() == 1 || coo.empty()); if (!coo.empty() && coo.front().isAoS()) { return coo.front().lvlRange.first; } return getLvlRank(); } SmallVector mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const { SmallVector ret; if (getLvlRank() <= 1) return ret; ArrayRef lts = getLvlTypes(); Level l = 0; while (l < getLvlRank()) { auto lt = lts[l]; if (lt.isa()) { auto cur = lts.begin() + l; auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) { return !lt.isa(); }); unsigned cooLen = std::distance(cur, end); if (cooLen > 1) { // To support mixed SoA/AoS COO, we should break the segment when the // storage scheme changes, for now we faithfully assume that all // consecutive singleton levels have the same storage format as verified // STEA. ret.push_back(COOSegment{std::make_pair(l, l + cooLen), lts[l + 1].isa()}); } l += cooLen; } else { l++; } } return ret; } //===----------------------------------------------------------------------===// // SparseTensorType Methods. 
//===----------------------------------------------------------------------===// bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, bool isUnique) const { if (!hasEncoding()) return false; if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl)) return false; for (Level l = startLvl + 1; l < lvlRank; ++l) if (!isSingletonLvl(l)) return false; // If isUnique is true, then make sure that the last level is unique, // that is, when lvlRank == 1, the only compressed level is unique, // and when lvlRank > 1, the last singleton is unique. return !isUnique || isUniqueLvl(lvlRank - 1); } RankedTensorType mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { SmallVector lvlTypes; lvlTypes.reserve(lvlRank); // A non-unique compressed level at beginning (unless this is // also the last level, then it is unique). lvlTypes.push_back( *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); if (lvlRank > 1) { // Followed by n-2 non-unique singleton levels. std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, *buildLevelType(LevelFormat::Singleton, ordered, false)); // Ends by a unique singleton level. lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true)); } auto enc = SparseTensorEncodingAttr::get( getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(), getCrdWidth(), getExplicitVal(), getImplicitVal()); return RankedTensorType::get(getDimShape(), getElementType(), enc); } //===----------------------------------------------------------------------===// // Convenience Methods. //===----------------------------------------------------------------------===// SparseTensorEncodingAttr mlir::sparse_tensor::getSparseTensorEncoding(Type type) { if (auto ttp = llvm::dyn_cast(type)) return llvm::dyn_cast_or_null(ttp.getEncoding()); if (auto mdtp = llvm::dyn_cast(type)) return mdtp.getEncoding(); return nullptr; } AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl, MLIRContext *context) { auto map = static_cast(dimToLvl); AffineMap lvlToDim; // Return an empty lvlToDim when inference is not successful. if (!map || map.getNumSymbols() != 0) { lvlToDim = AffineMap(); } else if (map.isPermutation()) { lvlToDim = inversePermutation(map); } else if (isBlockSparsity(map)) { lvlToDim = inverseBlockSparsity(map, context); } return lvlToDim; } AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context) { SmallVector lvlExprs; auto numLvls = dimToLvl.getNumResults(); lvlExprs.reserve(numLvls); // lvlExprComponents stores information of the floordiv and mod operations // applied to the same dimension, so as to build the lvlToDim map. std::map> lvlExprComponents; for (unsigned i = 0, n = numLvls; i < n; i++) { auto result = dimToLvl.getResult(i); if (auto binOp = dyn_cast(result)) { if (result.getKind() == AffineExprKind::FloorDiv) { // Position of the dimension in dimToLvl. auto pos = dyn_cast(binOp.getLHS()).getPosition(); assert(lvlExprComponents.find(pos) == lvlExprComponents.end() && "expected only one floordiv for each dimension"); SmallVector components; // Level variable for floordiv. components.push_back(getAffineDimExpr(i, context)); // Multiplier. components.push_back(binOp.getRHS()); // Map key is the position of the dimension. 
lvlExprComponents[pos] = components; } else if (result.getKind() == AffineExprKind::Mod) { auto pos = dyn_cast(binOp.getLHS()).getPosition(); assert(lvlExprComponents.find(pos) != lvlExprComponents.end() && "expected floordiv before mod"); // Add level variable for mod to the same vector // of the corresponding floordiv. lvlExprComponents[pos].push_back(getAffineDimExpr(i, context)); } else { assert(false && "expected floordiv or mod"); } } else { lvlExprs.push_back(getAffineDimExpr(i, context)); } } // Build lvlExprs from lvlExprComponents. // For example, for il = i floordiv 2 and ii = i mod 2, the components // would be [il, 2, ii]. It could be used to build the AffineExpr // i = il * 2 + ii in lvlToDim. for (auto &components : lvlExprComponents) { assert(components.second.size() == 3 && "expected 3 components to build lvlExprs"); auto mulOp = getAffineBinaryOpExpr( AffineExprKind::Mul, components.second[0], components.second[1]); auto addOp = getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]); lvlExprs.push_back(addOp); } return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context); } SmallVector mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { assert(isBlockSparsity(dimToLvl) && "expected dimToLvl to be block sparsity for calling getBlockSize"); SmallVector blockSize; for (auto result : dimToLvl.getResults()) { if (auto binOp = dyn_cast(result)) { if (result.getKind() == AffineExprKind::Mod) { blockSize.push_back( dyn_cast(binOp.getRHS()).getValue()); } } else { blockSize.push_back(0); } } return blockSize; } bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { if (!dimToLvl) return false; std::map coeffientMap; bool hasBlock = false; for (auto result : dimToLvl.getResults()) { if (auto binOp = dyn_cast(result)) { // Check for "dim op const". auto dimOp = dyn_cast(binOp.getLHS()); auto conOp = dyn_cast(binOp.getRHS()); if (!dimOp || !conOp || conOp.getValue() <= 0) return false; // Inspect "dim / const" or "dim % const". auto pos = dimOp.getPosition(); if (binOp.getKind() == AffineExprKind::FloorDiv) { // Expect only one floordiv for each dimension. auto [it, inserted] = coeffientMap.try_emplace(pos); if (!inserted) return false; // Record coefficient of the floordiv. it->second = conOp.getValue(); } else if (binOp.getKind() == AffineExprKind::Mod) { // Expect floordiv before mod. auto it = coeffientMap.find(pos); if (it == coeffientMap.end()) return false; // Expect mod to have the same coefficient as floordiv. if (conOp.getValue() != it->second) return false; hasBlock = true; } else { return false; } } else if (auto dimOp = dyn_cast(result)) { auto pos = dimOp.getPosition(); // Expect dim to be unset. 
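      // (A dimension that also appears inside a floordiv/mod pair, or appears
      //  as a plain result twice, cannot describe a valid block structure.)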
if (!coeffientMap.try_emplace(pos, 0).second) return false; } else { return false; } } return hasBlock; } bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) { auto hasNonIdentityMap = [](Value v) { auto stt = tryGetSparseTensorType(v); return stt && !stt->isIdentity(); }; return llvm::any_of(op->getOperands(), hasNonIdentityMap) || llvm::any_of(op->getResults(), hasNonIdentityMap); } Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { if (enc) { assert(enc.isPermutation() && "Non permutation map not supported"); if (const auto dimToLvl = enc.getDimToLvl()) return dimToLvl.getDimPosition(l); } return l; } Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { if (enc) { assert(enc.isPermutation() && "Non permutation map not supported"); if (const auto lvlToDim = enc.getLvlToDim()) return lvlToDim.getDimPosition(d); } return d; } /// We normalized sparse tensor encoding attribute by always using /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well /// as other variants) lead to the same storage specifier type, and stripping /// irrelevant fields that do not alter the sparse tensor memory layout. static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { SmallVector lts; for (auto lt : enc.getLvlTypes()) lts.push_back(lt.stripStorageIrrelevantProperties()); return SparseTensorEncodingAttr::get( enc.getContext(), lts, AffineMap(), // dimToLvl (irrelevant to storage specifier) AffineMap(), // lvlToDim (irrelevant to storage specifier) // Always use `index` for memSize and lvlSize instead of reusing // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA // value for different bitwidth, it also avoids casting between index and // integer (returned by DimOp) 0, 0, Attribute(), // explicitVal (irrelevant to storage specifier) Attribute(), // implicitVal (irrelevant to storage specifier) enc.getDimSlices()); } StorageSpecifierType StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) { return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding)); } StorageSpecifierType StorageSpecifierType::getChecked(function_ref emitError, MLIRContext *ctx, SparseTensorEncodingAttr encoding) { return Base::getChecked(emitError, ctx, getNormalizedEncodingForSpecifier(encoding)); } //===----------------------------------------------------------------------===// // SparseTensorDialect Operations. //===----------------------------------------------------------------------===// static LogicalResult lvlIsInBounds(Level lvl, Value tensor) { return success(lvl < getSparseTensorType(tensor).getLvlRank()); } static LogicalResult isMatchingWidth(Value mem, unsigned width) { const Type etp = getMemRefType(mem).getElementType(); return success(width == 0 ? 
etp.isIndex() : etp.isInteger(width)); } static LogicalResult verifySparsifierGetterSetter( StorageSpecifierKind mdKind, std::optional lvl, TypedValue md, Operation *op) { if (mdKind == StorageSpecifierKind::ValMemSize && lvl) { return op->emitError( "redundant level argument for querying value memory size"); } const auto enc = md.getType().getEncoding(); const Level lvlRank = enc.getLvlRank(); if (mdKind == StorageSpecifierKind::DimOffset || mdKind == StorageSpecifierKind::DimStride) if (!enc.isSlice()) return op->emitError("requested slice data on non-slice tensor"); if (mdKind != StorageSpecifierKind::ValMemSize) { if (!lvl) return op->emitError("missing level argument"); const Level l = lvl.value(); if (l >= lvlRank) return op->emitError("requested level is out of bounds"); if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l)) return op->emitError( "requested position memory size on a singleton level"); } return success(); } static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) { switch (kind) { case SparseTensorFieldKind::CrdMemRef: return stt.getCrdType(); case SparseTensorFieldKind::PosMemRef: return stt.getPosType(); case SparseTensorFieldKind::ValMemRef: return stt.getElementType(); case SparseTensorFieldKind::StorageSpec: return nullptr; } llvm_unreachable("Unrecognizable FieldKind"); } static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps) { if (requiresStaticShape && !stt.hasStaticDimShape()) return op->emitError("the sparse-tensor must have static shape"); if (!stt.hasEncoding()) return op->emitError("the sparse-tensor must have an encoding attribute"); // Verifies the trailing COO. Level cooStartLvl = stt.getAoSCOOStart(); if (cooStartLvl < stt.getLvlRank()) { // We only supports trailing COO for now, must be the last input. auto cooTp = llvm::cast(lvlTps.back()); // The coordinates should be in shape of unsigned expCOORank = stt.getLvlRank() - cooStartLvl; if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) { return op->emitError("input/output trailing COO level-ranks don't match"); } } // Verifies that all types match. StorageLayout layout(stt.getEncoding()); if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref return op->emitError("inconsistent number of fields between input/output"); unsigned idx = 0; bool misMatch = false; layout.foreachField([&idx, &misMatch, stt, valTp, lvlTps](FieldIndex fid, SparseTensorFieldKind fKind, Level lvl, LevelType lt) -> bool { if (fKind == SparseTensorFieldKind::StorageSpec) return true; Type inputTp = nullptr; if (fKind == SparseTensorFieldKind::ValMemRef) { inputTp = valTp; } else { assert(fid == idx && stt.getLvlType(lvl) == lt); inputTp = lvlTps[idx++]; } // The input element type and expected element type should match. 
Type inpElemTp = llvm::cast(inputTp).getElementType(); Type expElemTp = getFieldElemType(stt, fKind); if (inpElemTp != expElemTp) { misMatch = true; return false; // to terminate the iteration } return true; }); if (misMatch) return op->emitError("input/output element-types don't match"); return success(); } LogicalResult AssembleOp::verify() { RankedTensorType valuesTp = getValues().getType(); const auto lvlsTp = getLevels().getTypes(); const auto resTp = getSparseTensorType(getResult()); return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp); } LogicalResult DisassembleOp::verify() { if (getOutValues().getType() != getRetValues().getType()) return emitError("output values and return value type mismatch"); for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels())) if (ot.getType() != rt.getType()) return emitError("output levels and return levels type mismatch"); RankedTensorType valuesTp = getRetValues().getType(); const auto lvlsTp = getRetLevels().getTypes(); const auto srcTp = getSparseTensorType(getTensor()); return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp); } LogicalResult ConvertOp::verify() { RankedTensorType tp1 = getSource().getType(); RankedTensorType tp2 = getDest().getType(); if (tp1.getRank() != tp2.getRank()) return emitError("unexpected conversion mismatch in rank"); auto dstEnc = llvm::dyn_cast_or_null(tp2.getEncoding()); if (dstEnc && dstEnc.isSlice()) return emitError("cannot convert to a sparse tensor slice"); auto shape1 = tp1.getShape(); auto shape2 = tp2.getShape(); // Accept size matches between the source and the destination type // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) return emitError("unexpected conversion mismatch in dimension ") << d; return success(); } OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { if (getType() == getSource().getType()) return getSource(); return {}; } bool ConvertOp::needsExtraSort() { SparseTensorType srcStt = getSparseTensorType(getSource()); SparseTensorType dstStt = getSparseTensorType(getDest()); // We do not need an extra sort when returning unordered sparse tensors or // dense tensor since dense tensor support random access. if (dstStt.isAllDense() || !dstStt.isAllOrdered()) return false; if (srcStt.isAllOrdered() && dstStt.isAllOrdered() && srcStt.hasSameDimToLvl(dstStt)) { return false; } // Source and dest tensors are ordered in different ways. We only do direct // dense to sparse conversion when the dense input is defined by a sparse // constant. Note that we can theoretically always directly convert from dense // inputs by rotating dense loops but it leads to bad cache locality and hurt // performance. 
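  // (A sparse constant, i.e. an arith.constant carrying a SparseElementsAttr,
  //  already provides its nonzero values together with their coordinates, so
  //  the conversion can consume them directly.)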
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

  return true;
}

LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");

  return success();
}

LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
                   return v.getDefiningOp() == def;
                 });
  if (!sameDef)
    return failure();

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });
  if (!sameOrder)
    return failure();

  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}

void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {
  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
  return build(builder, state, source, val);
}

LogicalResult LvlOp::verify() {
  if (std::optional<int64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      return emitError(
          "Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}

std::optional<int64_t> LvlOp::getConstantLvlIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability LvlOp::getSpeculatability() {
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;
  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
  return Speculation::Speculatable;
}

OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by the tensor.dim operation. Out of
    // bound indices produce undefined behavior but are still valid IR. Don't
    // choke on them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
auto getIndexAttr = [this](int64_t lvlSz) { return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz)); }; SmallVector lvlShape = stt.getLvlShape(); if (!ShapedType::isDynamic(lvlShape[lvl])) return getIndexAttr(lvlShape[lvl]); return {}; } void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState, SparseTensorEncodingAttr dstEnc, Value source) { auto srcStt = getSparseTensorType(source); SmallVector srcLvlShape = srcStt.getLvlShape(); SmallVector dstDimShape = dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim); auto dstTp = RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc); return build(odsBuilder, odsState, dstTp, source); } LogicalResult ReinterpretMapOp::verify() { auto srcStt = getSparseTensorType(getSource()); auto dstStt = getSparseTensorType(getDest()); ArrayRef srcLvlTps = srcStt.getLvlTypes(); ArrayRef dstLvlTps = dstStt.getLvlTypes(); if (srcLvlTps.size() != dstLvlTps.size()) return emitError("Level rank mismatch between source/dest tensors"); for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps)) if (srcLvlTp != dstLvlTp) return emitError("Level type mismatch between source/dest tensors"); if (srcStt.getPosWidth() != dstStt.getPosWidth() || srcStt.getCrdWidth() != dstStt.getCrdWidth()) { return emitError("Crd/Pos width mismatch between source/dest tensors"); } if (srcStt.getElementType() != dstStt.getElementType()) return emitError("Element type mismatch between source/dest tensors"); SmallVector srcLvlShape = srcStt.getLvlShape(); SmallVector dstLvlShape = dstStt.getLvlShape(); for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) { if (srcLvlSz != dstLvlSz) { // Should we allow one side to be dynamic size, e.g., should be // compatible to <3x4>? For now, we require all the level sizes to be // *exactly* matched for simplicity. return emitError("Level size mismatch between source/dest tensors"); } } return success(); } OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) { if (getSource().getType() == getDest().getType()) return getSource(); if (auto def = getSource().getDefiningOp()) { // A -> B, B -> A ==> A if (def.getSource().getType() == getDest().getType()) return def.getSource(); } return {}; } template static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl &ret) { typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region); SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); Type elemTp = nullptr; bool withStride = false; if constexpr (std::is_same_v) { elemTp = stt.getPosType(); } else if constexpr (std::is_same_v || std::is_same_v) { elemTp = stt.getCrdType(); if constexpr (std::is_same_v) withStride = stt.getAoSCOOStart() <= adaptor.getLevel(); } else if constexpr (std::is_same_v) { elemTp = stt.getElementType(); } assert(elemTp && "unhandled operation."); SmallVector bufShape = stt.getBatchLvlShape(); bufShape.push_back(ShapedType::kDynamic); auto layout = withStride ? 
StridedLayoutAttr::StridedLayoutAttr::get( stt.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic}) : StridedLayoutAttr(); ret.emplace_back(MemRefType::get(bufShape, elemTp, layout)); return success(); } LogicalResult ToPositionsOp::verify() { auto stt = getSparseTensorType(getTensor()); if (failed(lvlIsInBounds(getLevel(), getTensor()))) return emitError("requested level is out of bounds"); if (failed(isMatchingWidth(getResult(), stt.getPosWidth()))) return emitError("unexpected type for positions"); return success(); } LogicalResult ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional loc, ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl &ret) { return inferSparseBufferType(ops, attr, prop, region, ret); } LogicalResult ToCoordinatesOp::verify() { auto stt = getSparseTensorType(getTensor()); if (failed(lvlIsInBounds(getLevel(), getTensor()))) return emitError("requested level is out of bounds"); if (failed(isMatchingWidth(getResult(), stt.getCrdWidth()))) return emitError("unexpected type for coordinates"); return success(); } LogicalResult ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional loc, ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl &ret) { return inferSparseBufferType(ops, attr, prop, region, ret); } LogicalResult ToCoordinatesBufferOp::verify() { auto stt = getSparseTensorType(getTensor()); if (stt.getAoSCOOStart() >= stt.getLvlRank()) return emitError("expected sparse tensor with a COO region"); return success(); } LogicalResult ToCoordinatesBufferOp::inferReturnTypes( MLIRContext *ctx, std::optional loc, ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl &ret) { return inferSparseBufferType(ops, attr, prop, region, ret); } LogicalResult ToValuesOp::verify() { auto stt = getSparseTensorType(getTensor()); auto mtp = getMemRefType(getResult()); if (stt.getElementType() != mtp.getElementType()) return emitError("unexpected mismatch in element types"); return success(); } LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx, std::optional loc, ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl &ret) { return inferSparseBufferType(ops, attr, prop, region, ret); } LogicalResult ToSliceOffsetOp::verify() { auto rank = getSlice().getType().getRank(); if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) return emitError("requested dimension out of bound"); return success(); } LogicalResult ToSliceStrideOp::verify() { auto rank = getSlice().getType().getRank(); if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) return emitError("requested dimension out of bound"); return success(); } LogicalResult GetStorageSpecifierOp::verify() { return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), getSpecifier(), getOperation()); } template static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) { return op.getSpecifier().template getDefiningOp(); } OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { const StorageSpecifierKind kind = getSpecifierKind(); const auto lvl = getLevel(); for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op)) if (kind == op.getSpecifierKind() && lvl == op.getLevel()) return op.getValue(); return {}; } LogicalResult SetStorageSpecifierOp::verify() { return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), getSpecifier(), getOperation()); } template static 
LogicalResult verifyNumBlockArgs(T *op, Region ®ion, const char *regionName, TypeRange inputTypes, Type outputType) { unsigned numArgs = region.getNumArguments(); unsigned expectedNum = inputTypes.size(); if (numArgs != expectedNum) return op->emitError() << regionName << " region must have exactly " << expectedNum << " arguments"; for (unsigned i = 0; i < numArgs; i++) { Type typ = region.getArgument(i).getType(); if (typ != inputTypes[i]) return op->emitError() << regionName << " region argument " << (i + 1) << " type mismatch"; } Operation *term = region.front().getTerminator(); YieldOp yield = dyn_cast(term); if (!yield) return op->emitError() << regionName << " region must end with sparse_tensor.yield"; if (!yield.hasSingleResult() || yield.getSingleResult().getType() != outputType) return op->emitError() << regionName << " region yield type mismatch"; return success(); } LogicalResult BinaryOp::verify() { NamedAttrList attrs = (*this)->getAttrs(); Type leftType = getX().getType(); Type rightType = getY().getType(); Type outputType = getOutput().getType(); Region &overlap = getOverlapRegion(); Region &left = getLeftRegion(); Region &right = getRightRegion(); // Check correct number of block arguments and return type for each // non-empty region. if (!overlap.empty()) { if (failed(verifyNumBlockArgs(this, overlap, "overlap", TypeRange{leftType, rightType}, outputType))) return failure(); } if (!left.empty()) { if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType))) return failure(); } else if (getLeftIdentity()) { if (leftType != outputType) return emitError("left=identity requires first argument to have the same " "type as the output"); } if (!right.empty()) { if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType}, outputType))) return failure(); } else if (getRightIdentity()) { if (rightType != outputType) return emitError("right=identity requires second argument to have the " "same type as the output"); } return success(); } LogicalResult UnaryOp::verify() { Type inputType = getX().getType(); Type outputType = getOutput().getType(); // Check correct number of block arguments and return type for each // non-empty region. Region &present = getPresentRegion(); if (!present.empty()) { if (failed(verifyNumBlockArgs(this, present, "present", TypeRange{inputType}, outputType))) return failure(); } Region &absent = getAbsentRegion(); if (!absent.empty()) { if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType))) return failure(); // Absent branch can only yield invariant values. 
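    // That is, the absent region may only yield a constant or a value defined
    // outside the surrounding kernel.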
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal =
        cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
  return success();
}

bool ConcatenateOp::needsExtraSort() {
  SparseTensorType dstStt = getSparseTensorType(*this);
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
    return getSparseTensorType(op).hasSameDimToLvl(dstStt);
  });
  // TODO: When concatDim != 0, as long as concatDim corresponds to the first
  // level in all input/output buffers, and all input/output buffers have the
  // same dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenate
  // CSC matrices along column).
  bool directLowerable =
      allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
  return !directLowerable;
}

LogicalResult ConcatenateOp::verify() {
  const auto dstTp = getSparseTensorType(*this);
  const Dimension concatDim = getDimension();
  const Dimension dimRank = dstTp.getDimRank();

  if (getInputs().size() <= 1)
    return emitError("Need at least two tensors to concatenate.");

  if (concatDim >= dimRank)
    return emitError(llvm::formatv(
        "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
        concatDim, dimRank));

  for (const auto &it : llvm::enumerate(getInputs())) {
    const auto i = it.index();
    const auto srcTp = getSparseTensorType(it.value());
    if (srcTp.hasDynamicDimShape())
      return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
    const Dimension srcDimRank = srcTp.getDimRank();
    if (srcDimRank != dimRank)
      return emitError(
          llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
                        "from the output tensor (rank={2}).",
                        i, srcDimRank, dimRank));
  }

  for (Dimension d = 0; d < dimRank; d++) {
    const Size dstSh = dstTp.getDimShape()[d];
    if (d == concatDim) {
      if (!ShapedType::isDynamic(dstSh)) {
        // If we reach here, then all inputs have static shapes. So we
        // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
        // to avoid redundant assertions in the loop.
        Size sumSz = 0;
        for (const auto src : getInputs())
          sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimensions are statically known, the sum of all the input
        // dimensions should be equal to the output dimension.
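        // E.g., concatenating tensors of shape 2x4 and 3x4 along dimension 0
        // must produce a 5x4 result.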
        if (sumSz != dstSh)
          return emitError(
              "The concatenation dimension of the output tensor should be the "
              "sum of all the concatenation dimensions of the input tensors.");
      }
    } else {
      Size prev = dstSh;
      for (const auto src : getInputs()) {
        const auto sh = getSparseTensorType(src).getDimShape()[d];
        if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (except for the concatenating one) "
                           "should be equal.");
        prev = sh;
      }
    }
  }
  return success();
}

void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  build(builder, result, curSize, inBuffer, value, Value());
}

LogicalResult PushBackOp::verify() {
  if (Value n = getN()) {
    std::optional<int64_t> nValue = getConstantIntValue(n);
    if (nValue && nValue.value() < 1)
      return emitOpError("n must be not less than 1");
  }
  return success();
}

LogicalResult CompressOp::verify() {
  const auto stt = getSparseTensorType(getTensor());
  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
  return success();
}

void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}

LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError("Level traverse order does not match tensor's level rank");

  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
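  // The yield of the body must match the op results in both count and type.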
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      elemTp, valueTp));
  return success();
}

OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
  if (getSparseTensorEncoding(getInputCoo().getType()) ==
      getSparseTensorEncoding(getResultCoo().getType()))
    return getInputCoo();
  return {};
}

LogicalResult ReorderCOOOp::verify() {
  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
  SparseTensorType dstStt = getSparseTensorType(getResultCoo());

  if (!srcStt.isCOOType() || !dstStt.isCOOType())
    return emitError("Expected COO sparse tensors only");

  if (!srcStt.hasSameDimToLvl(dstStt))
    return emitError("Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType())
    return emitError("Unmatched storage format between input and result COO");

  return success();
}

LogicalResult ReduceOp::verify() {
  Type inputType = getX().getType();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "reduce",
                            TypeRange{inputType, inputType}, inputType);
}

LogicalResult SelectOp::verify() {
  Builder b(getContext());
  Type inputType = getX().getType();
  Type boolType = b.getI1Type();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
                            boolType);
}

LogicalResult SortOp::verify() {
  AffineMap xPerm = getPermMap();
  uint64_t nx = xPerm.getNumDims();
  if (nx < 1)
    return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));

  if (!xPerm.isPermutation())
    return emitError(
        llvm::formatv("Expected a permutation map, got {0}", xPerm));

  // We can't check the size of the buffers when n or buffer dimensions aren't
  // compile-time constants.
  std::optional<int64_t> cn = getConstantIntValue(getN());
  if (!cn)
    return success();

  // Verify dimensions.
  const auto checkDim = [&](Value v, Size minSize,
                            const char *message) -> LogicalResult {
    const Size sh = getMemRefType(v).getShape()[0];
    if (!ShapedType::isDynamic(sh) && sh < minSize)
      return emitError(
          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
    return success();
  };
  uint64_t n = cn.value();
  uint64_t ny = 0;
  if (auto nyAttr = getNyAttr())
    ny = nyAttr.getInt();
  if (failed(checkDim(getXy(), n * (nx + ny),
                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
    return failure();
  for (Value opnd : getYs())
    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
      return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Operations.
//===----------------------------------------------------------------------===//

IterSpaceType IteratorType::getIterSpaceType() const {
  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
                            getHiLvl());
}

IteratorType IterSpaceType::getIteratorType() const {
  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
}

/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1.
static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
                                   Level &lvlHi) {
  if (parser.parseInteger(lvlLo))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("to"))) {
    if (parser.parseInteger(lvlHi))
      return failure();
  } else {
    lvlHi = lvlLo + 1;
  }

  if (lvlHi <= lvlLo)
    return parser.emitError(parser.getNameLoc(),
                            "expect larger level upper bound than lower bound");

  return success();
}

/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1.
static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
                                   IntegerAttr &lvlHiAttr) {
  Level lvlLo, lvlHi;
  if (parseLevelRange(parser, lvlLo, lvlHi))
    return failure();

  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
  return success();
}

/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1.
static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
  if (lo + 1 == hi)
    p << lo;
  else
    p << lo << " to " << hi;
}

/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1.
static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
                            IntegerAttr lvlHi) {
  unsigned lo = lvlLo.getValue().getZExtValue();
  unsigned hi = lvlHi.getValue().getZExtValue();
  printLevelRange(p, lo, hi);
}

/// Parses a list of optionally defined values in the form of
/// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
/// corresponding value is not defined (e.g., to represent an undefined
/// coordinate in the sparse iteration space).
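/// For example, `(%v0, _, %v1)` marks positions 0 and 2 as defined (setting
/// the corresponding bits in the defined set) while position 1 stays
/// undefined.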
static ParseResult parseOptionalDefinedList( OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl &definedArgs, unsigned maxCnt = std::numeric_limits::max(), OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) { unsigned cnt = 0; ParseResult crdList = parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult { if (parser.parseOptionalKeyword("_")) { if (parser.parseArgument(definedArgs.emplace_back())) return failure(); definedSet.set(cnt); } cnt += 1; return success(); }); if (cnt > maxCnt) return parser.emitError(parser.getNameLoc(), "parsed more value than expected."); if (failed(crdList)) { return parser.emitError( parser.getNameLoc(), "expecting SSA value or \"_\" for level coordinates"); } assert(definedArgs.size() == definedSet.count()); return success(); } static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet) { if (definedSet.empty()) return; for (unsigned i = 0; i < size; i++) { if (definedSet[i]) { p << blocksArgs.front(); blocksArgs = blocksArgs.drop_front(); } else { p << "_"; } if (i != size - 1) p << ", "; } assert(blocksArgs.empty()); } static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl &coords) { // Parse "at(%crd0, _, ...)" I64BitSet crdUsedLvlSet; if (succeeded(parser.parseOptionalKeyword("at")) && failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords))) return failure(); // Always use IndexType for the coordinate. for (auto &coord : coords) coord.type = parser.getBuilder().getIndexType(); // Set the CrdUsedLvl bitset. state.addAttribute("crdUsedLvls", parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet)); return success(); } static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl &iterators, SmallVectorImpl &blockArgs) { SmallVector spaces; SmallVector initArgs; // Parse "%iters, ... in %spaces, ..." if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") || parser.parseOperandList(spaces)) return failure(); if (iterators.size() != spaces.size()) return parser.emitError( parser.getNameLoc(), "mismatch in number of sparse iterators and sparse spaces"); SmallVector coords; if (failed(parseUsedCoordList(parser, state, coords))) return failure(); size_t numCrds = coords.size(); // Parse "iter_args(%arg = %init, ...)" bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); if (hasIterArgs) if (parser.parseAssignmentList(blockArgs, initArgs)) return failure(); blockArgs.append(coords); SmallVector iterSpaceTps; // parse ": sparse_tensor.iter_space -> ret" if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) return failure(); if (iterSpaceTps.size() != spaces.size()) return parser.emitError(parser.getNameLoc(), "mismatch in number of iteration space operands " "and iteration space types"); for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) { IterSpaceType spaceTp = llvm::dyn_cast(tp); if (!spaceTp) return parser.emitError(parser.getNameLoc(), "expected sparse_tensor.iter_space type for " "iteration space operands"); it.type = spaceTp.getIteratorType(); } if (hasIterArgs) if (parser.parseArrowTypeList(state.types)) return failure(); // Resolves input operands. if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), state.operands)) return failure(); if (hasIterArgs) { // Strip off leading args that used for coordinates. 
MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); if (args.size() != initArgs.size() || args.size() != state.types.size()) { return parser.emitError( parser.getNameLoc(), "mismatch in number of iteration arguments and return values"); } for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { it.type = tp; if (parser.resolveOperand(init, tp, state.operands)) return failure(); } } return success(); } static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl &spacesVals, SmallVectorImpl &blockArgs) { // Parse "(%spaces, ...)" SmallVector spaces; if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren)) return failure(); SmallVector coords; if (failed(parseUsedCoordList(parser, state, coords))) return failure(); size_t numCrds = coords.size(); // Parse "iter_args(%arg = %init, ...)" SmallVector initArgs; bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); if (hasIterArgs) if (parser.parseAssignmentList(blockArgs, initArgs)) return failure(); blockArgs.append(coords); SmallVector iterSpaceTps; // parse ": (sparse_tensor.iter_space, ...) -> ret" if (parser.parseColon() || parser.parseLParen() || parser.parseTypeList(iterSpaceTps) || parser.parseRParen()) return failure(); if (iterSpaceTps.size() != spaces.size()) return parser.emitError(parser.getNameLoc(), "mismatch in number of iteration space operands " "and iteration space types"); if (hasIterArgs) if (parser.parseArrowTypeList(state.types)) return failure(); // Resolves input sparse iteration spaces. if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), spacesVals)) return failure(); state.operands.append(spacesVals); if (hasIterArgs) { // Strip off trailing args that used for coordinates. 
MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); if (args.size() != initArgs.size() || args.size() != state.types.size()) { return parser.emitError( parser.getNameLoc(), "mismatch in number of iteration arguments and return values"); } for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { it.type = tp; if (parser.resolveOperand(init, tp, state.operands)) return failure(); } } return success(); } LogicalResult ExtractIterSpaceOp::inferReturnTypes( MLIRContext *ctx, std::optional loc, ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl &ret) { ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region); SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(), adaptor.getHiLvl())); return success(); } LogicalResult ExtractIterSpaceOp::verify() { if (getLoLvl() >= getHiLvl()) return emitOpError("expected smaller level low than level high"); TypedValue pIter = getParentIter(); if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) { return emitOpError( "parent iterator should be specified iff level lower bound equals 0"); } if (pIter) { IterSpaceType spaceTp = getExtractedSpace().getType(); if (pIter.getType().getEncoding() != spaceTp.getEncoding()) return emitOpError( "mismatch in parent iterator encoding and iteration space encoding."); if (spaceTp.getLoLvl() != pIter.getType().getHiLvl()) return emitOpError("parent iterator should be used to extract an " "iteration space from a consecutive level."); } return success(); } LogicalResult ExtractValOp::verify() { auto stt = getSparseTensorType(getTensor()); auto itTp = getIterator().getType(); if (stt.getEncoding() != itTp.getEncoding()) return emitOpError("mismatch in tensor encoding and iterator encoding."); if (stt.getLvlRank() != itTp.getHiLvl()) return emitOpError("must use last-level iterator to extract values. "); return success(); } struct RemoveUnusedLvlCrds : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override { I64BitSet newUsedLvls(0); llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments()); for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) { if (auto crd = iterateOp.getLvlCrd(i)) { if (crd->getUsers().empty()) toRemove.set(crd->getArgNumber()); else newUsedLvls.set(i); } } // All coordinates are used. if (toRemove.none()) return failure(); rewriter.startOpModification(iterateOp); iterateOp.setCrdUsedLvls(newUsedLvls); iterateOp.getBody()->eraseArguments(toRemove); rewriter.finalizeOpModification(iterateOp); return success(); } }; void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, mlir::MLIRContext *context) { results.add(context); } void IterateOp::build(OpBuilder &builder, OperationState &odsState, Value iterSpace, ValueRange initArgs) { unsigned rank = llvm::cast(iterSpace.getType()).getSpaceDim(); // All ones. 
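// For a space of rank 3, for instance, this builds the bitset 0b111, i.e.,
// the loop is created with every level coordinate marked as used.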
  I64BitSet set((1 << rank) - 1);
  return build(builder, odsState, iterSpace, initArgs, set);
}

void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                      Value iterSpace, ValueRange initArgs,
                      I64BitSet crdUsedLvls) {
  OpBuilder::InsertionGuard guard(builder);

  odsState.addOperands(iterSpace);
  odsState.addOperands(initArgs);
  odsState.getOrAddProperties<Properties>().crdUsedLvls =
      builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
  Region *bodyRegion = odsState.addRegion();
  odsState.addTypes(initArgs.getTypes());
  Block *bodyBlock = builder.createBlock(bodyRegion);

  // Starts with a list of user-provided loop arguments.
  for (Value v : initArgs)
    bodyBlock->addArgument(v.getType(), v.getLoc());

  // Follows with a list of used coordinates.
  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
    bodyBlock->addArgument(builder.getIndexType(), odsState.location);

  // Ends with the sparse iterator.
  bodyBlock->addArgument(
      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
      odsState.location);
}

ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::Argument iterator;
  OpAsmParser::UnresolvedOperand iterSpace;

  SmallVector<OpAsmParser::Argument> iters, iterArgs;
  if (parseSparseIterateLoop(parser, result, iters, iterArgs))
    return failure();
  if (iters.size() != 1)
    return parser.emitError(parser.getNameLoc(),
                            "expected only one iterator/iteration space");

  iterArgs.append(iters);
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, iterArgs))
    return failure();

  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

/// Prints the initialization list in the form of
/// (%inner = %outer, %inner2 = %outer2, <...>)
/// where 'inner' values are assumed to be region arguments and 'outer' values
/// are regular SSA values.
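/// For example, when called with prefix " iter_args" this prints
/// " iter_args(%arg = %init)"; nothing is printed when there are no
/// initializers.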
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix = "") { assert(blocksArgs.size() == initializers.size() && "expected same length of arguments and initializers"); if (initializers.empty()) return; p << prefix << '('; llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); p << ")"; } template static LogicalResult verifySparseLoopOp(SparseLoopOp op) { if (op.getInitArgs().size() != op.getNumResults()) { return op.emitOpError( "mismatch in number of loop-carried values and defined values"); } if (op.getCrdUsedLvls().max() > op.getSpaceDim()) return op.emitOpError("required out-of-bound coordinates"); return success(); } LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); } LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); } void IterateOp::print(OpAsmPrinter &p) { p << " " << getIterator() << " in " << getIterSpace(); if (!getCrdUsedLvls().empty()) { p << " at("; printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls()); p << ")"; } printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args"); p << " : " << getIterSpace().getType() << " "; if (!getInitArgs().empty()) p.printArrowTypeList(getInitArgs().getTypes()); p << " "; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/!getInitArgs().empty()); } LogicalResult IterateOp::verifyRegions() { if (getIterator().getType() != getIterSpace().getType().getIteratorType()) return emitOpError("mismatch in iterator and iteration space type"); if (getNumRegionIterArgs() != getNumResults()) return emitOpError( "mismatch in number of basic block args and defined values"); auto initArgs = getInitArgs(); auto iterArgs = getRegionIterArgs(); auto yieldVals = getYieldedValues(); auto opResults = getResults(); if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), opResults.size()})) { return emitOpError() << "number mismatch between iter args and results."; } for (auto [i, init, iter, yield, ret] : llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { if (init.getType() != ret.getType()) return emitOpError() << "types mismatch between " << i << "th iter operand and defined value"; if (iter.getType() != ret.getType()) return emitOpError() << "types mismatch between " << i << "th iter region arg and defined value"; if (yield.getType() != ret.getType()) return emitOpError() << "types mismatch between " << i << "th yield value and defined value"; } return success(); } /// OpInterfaces' methods implemented by IterateOp. SmallVector IterateOp::getLoopRegions() { return {&getRegion()}; } MutableArrayRef IterateOp::getInitsMutable() { return getInitArgsMutable(); } Block::BlockArgListType IterateOp::getRegionIterArgs() { return getRegion().getArguments().take_front(getNumRegionIterArgs()); } std::optional> IterateOp::getYieldedValuesMutable() { return cast( getRegion().getBlocks().front().getTerminator()) .getResultsMutable(); } std::optional IterateOp::getLoopResults() { return getResults(); } OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { return getInitArgs(); } void IterateOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // Both the operation itself and the region may be branching into the body // or back into the operation itself. 
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); // It is possible for loop not to enter the body. regions.push_back(RegionSuccessor(getResults())); } void CoIterateOp::build(OpBuilder &builder, OperationState &odsState, ValueRange iterSpaces, ValueRange initArgs, unsigned numCases) { unsigned rank = cast(iterSpaces.front().getType()).getSpaceDim(); // All ones. I64BitSet set((1 << rank) - 1); // Generates all-zero case bits (they only serve as placeholders), which are // supposed to be overriden later. We need to preallocate all the regions as // mlir::Region cannot be dynamically added later after the operation is // created. SmallVector caseBits(numCases, 0); ArrayAttr cases = builder.getI64ArrayAttr(caseBits); return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces, initArgs, set, cases, /*caseRegionsCount=*/numCases); } ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector spaces; // The block argument list of each regions, it is arranged in the order of // ([used coordinate list], [loop iterations args], [sparse iterator list]). SmallVector blockArgs; if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs)) return failure(); result.addAttribute("operandSegmentSizes", parser.getBuilder().getDenseI32ArrayAttr( {static_cast(spaces.size()), static_cast(result.types.size())})); SmallVector cases; while (succeeded(parser.parseOptionalKeyword("case"))) { // Parse one region per case. I64BitSet definedItSet; SmallVector definedIts; if (parseOptionalDefinedList(parser, result, definedItSet, definedIts, spaces.size(), OpAsmParser::Delimiter::None)) return failure(); cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet)); for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) { // Resolve the iterator type based on the iteration space type. auto spaceTp = llvm::cast(spaces[definedIdx].getType()); definedIts[i].type = spaceTp.getIteratorType(); } definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end()); Region *body = result.addRegion(); if (parser.parseRegion(*body, definedIts)) return failure(); CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); } result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases)); // Parse the optional attribute list. 
if (parser.parseOptionalAttrDict(result.attributes)) return failure(); return success(); } void CoIterateOp::print(OpAsmPrinter &p) { p << " ("; llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; }); p << ")"; if (!getCrdUsedLvls().empty()) { p << " at("; printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls()); p << ")"; } printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args"); p << " : (" << getIterSpaces().getTypes() << ")"; if (!getInitArgs().empty()) p.printArrowTypeList(getInitArgs().getTypes()); for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) { p.printNewline(); p << "case "; printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx), getRegionDefinedSpace(idx)); p << " "; p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/!getInitArgs().empty()); } } ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) { return cast( getRegion(regionIdx).getBlocks().front().getTerminator()) .getResults(); } LogicalResult CoIterateOp::verifyRegions() { for (unsigned r = 0, e = getNumRegions(); r < e; r++) { if (getNumRegionIterArgs() != getNumResults()) return emitOpError( "mismatch in number of basic block args and defined values"); auto initArgs = getInitArgs(); auto iterArgs = getRegionIterArgs(r); auto yieldVals = getYieldedValues(r); auto opResults = getResults(); if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), opResults.size()})) { return emitOpError() << "number mismatch between iter args and results on " << r << "th region"; } for (auto [i, init, iter, yield, ret] : llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { if (init.getType() != ret.getType()) return emitOpError() << "types mismatch between " << i << "th iter operand and defined value on " << r << "th region"; if (iter.getType() != ret.getType()) return emitOpError() << "types mismatch between " << i << "th iter region arg and defined value on " << r << "th region"; if (yield.getType() != ret.getType()) return emitOpError() << "types mismatch between " << i << "th yield value and defined value on " << r << "th region"; } } auto cases = getRegionDefinedSpaces(); llvm::SmallSetVector set(cases.begin(), cases.end()); if (set.size() != getNumRegions()) return emitOpError("contains duplicated cases."); return success(); } SmallVector CoIterateOp::getSubCasesOf(unsigned regionIdx) { SmallVector ret; I64BitSet caseBit = getRegionDefinedSpace(regionIdx); for (Region &r : getCaseRegions()) if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit)) ret.push_back(&r); return ret; } //===----------------------------------------------------------------------===// // Sparse Tensor Dialect Setups. //===----------------------------------------------------------------------===// /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. 
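/// Returns a null pointer when the attribute cannot be materialized as an
/// arith constant of the requested type.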
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
                                                    Attribute value, Type type,
                                                    Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  return nullptr;
}

namespace {
struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
  using OpAsmDialectInterface::OpAsmDialectInterface;

  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
    if (isa<SparseTensorEncodingAttr>(attr)) {
      os << "sparse";
      return AliasResult::OverridableAlias;
    }
    return AliasResult::NoAlias;
  }
};
} // namespace

void SparseTensorDialect::initialize() {
  addInterface<SparseTensorAsmDialectInterface>();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"