1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "Detail/DimLvlMapParser.h" 12 13 #include "mlir/Dialect/SparseTensor/IR/Enums.h" 14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" 16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 17 18 #include "mlir/Dialect/Arith/IR/Arith.h" 19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 20 #include "mlir/Dialect/Complex/IR/Complex.h" 21 #include "mlir/Dialect/Utils/StaticValueUtils.h" 22 #include "mlir/IR/Builders.h" 23 #include "mlir/IR/DialectImplementation.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "llvm/ADT/TypeSwitch.h" 28 #include "llvm/Support/FormatVariadic.h" 29 30 #define GET_ATTRDEF_CLASSES 31 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 32 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc" 33 34 // Forward declarations, following custom print/parsing methods are referenced 35 // by the generated code for SparseTensorTypes.td. 36 static mlir::ParseResult parseLevelRange(mlir::AsmParser &, 37 mlir::sparse_tensor::Level &, 38 mlir::sparse_tensor::Level &); 39 static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, 40 mlir::sparse_tensor::Level); 41 42 #define GET_TYPEDEF_CLASSES 43 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" 44 45 using namespace mlir; 46 using namespace mlir::sparse_tensor; 47 48 // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as 49 // well. 50 namespace mlir::sparse_tensor { 51 llvm::hash_code hash_value(LevelType lt) { 52 return llvm::hash_value(static_cast<uint64_t>(lt)); 53 } 54 } // namespace mlir::sparse_tensor 55 56 //===----------------------------------------------------------------------===// 57 // Local Convenience Methods. 58 //===----------------------------------------------------------------------===// 59 60 static constexpr bool acceptBitWidth(unsigned bitWidth) { 61 switch (bitWidth) { 62 case 0: 63 case 8: 64 case 16: 65 case 32: 66 case 64: 67 return true; 68 default: 69 return false; 70 } 71 } 72 73 static SmallVector<Size> 74 getSparseFieldShape(const SparseTensorEncodingAttr enc, 75 std::optional<ArrayRef<int64_t>> dimShape) { 76 assert(enc); 77 // With only encoding, we can not determine the static shape for leading 78 // batch levels, we therefore return a dynamic shape memref instead. 79 SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic); 80 if (dimShape.has_value()) { 81 // If the actual tensor shape is provided, we can then refine the leading 82 // batch dimension. 83 SmallVector<int64_t> lvlShape = 84 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl); 85 memrefShape.assign(lvlShape.begin(), 86 lvlShape.begin() + enc.getBatchLvlRank()); 87 } 88 // Another dynamic dimension to store the sparse level. 
89 memrefShape.push_back(ShapedType::kDynamic); 90 return memrefShape; 91 } 92 93 //===----------------------------------------------------------------------===// 94 // SparseTensorDialect StorageLayout. 95 //===----------------------------------------------------------------------===// 96 97 static constexpr Level kInvalidLevel = -1u; 98 static constexpr Level kInvalidFieldIndex = -1u; 99 static constexpr FieldIndex kDataFieldStartingIdx = 0; 100 101 void StorageLayout::foreachField( 102 llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level, 103 LevelType)> 104 callback) const { 105 const auto lvlTypes = enc.getLvlTypes(); 106 const Level lvlRank = enc.getLvlRank(); 107 SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments(); 108 FieldIndex fieldIdx = kDataFieldStartingIdx; 109 110 ArrayRef cooSegsRef = cooSegs; 111 // Per-level storage. 112 for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) { 113 const auto lt = lvlTypes[l]; 114 if (isWithPosLT(lt)) { 115 if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt))) 116 return; 117 } 118 if (isWithCrdLT(lt)) { 119 if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt))) 120 return; 121 } 122 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) { 123 if (!cooSegsRef.front().isSoA) { 124 // AoS COO, all singletons are fused into one memrefs. Skips the entire 125 // COO segement. 126 l = cooSegsRef.front().lvlRange.second; 127 } else { 128 // SoA COO, each singleton level has one memref. 129 l++; 130 } 131 // Expire handled COO segment. 132 cooSegsRef = cooSegsRef.drop_front(); 133 } else { 134 // Non COO levels. 135 l++; 136 } 137 } 138 // The values array. 139 if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel, 140 LevelFormat::Undef))) 141 return; 142 // Put metadata at the end. 143 if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel, 144 LevelFormat::Undef))) 145 return; 146 } 147 148 void sparse_tensor::foreachFieldAndTypeInSparseTensor( 149 SparseTensorType stt, 150 llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level, 151 LevelType)> 152 callback) { 153 assert(stt.hasEncoding()); 154 155 SmallVector<int64_t> memrefShape = 156 getSparseFieldShape(stt.getEncoding(), stt.getDimShape()); 157 158 const Type specType = StorageSpecifierType::get(stt.getEncoding()); 159 // memref<[batch] x ? x pos> positions 160 const Type posMemType = MemRefType::get(memrefShape, stt.getPosType()); 161 // memref<[batch] x ? x crd> coordinates 162 const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType()); 163 // memref<[batch] x ? 
x eltType> values 164 const Type valMemType = MemRefType::get(memrefShape, stt.getElementType()); 165 166 StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType, 167 callback](FieldIndex fieldIdx, 168 SparseTensorFieldKind fieldKind, 169 Level lvl, LevelType lt) -> bool { 170 switch (fieldKind) { 171 case SparseTensorFieldKind::StorageSpec: 172 return callback(specType, fieldIdx, fieldKind, lvl, lt); 173 case SparseTensorFieldKind::PosMemRef: 174 return callback(posMemType, fieldIdx, fieldKind, lvl, lt); 175 case SparseTensorFieldKind::CrdMemRef: 176 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt); 177 case SparseTensorFieldKind::ValMemRef: 178 return callback(valMemType, fieldIdx, fieldKind, lvl, lt); 179 }; 180 llvm_unreachable("unrecognized field kind"); 181 }); 182 } 183 184 unsigned StorageLayout::getNumFields() const { 185 unsigned numFields = 0; 186 foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level, 187 LevelType) -> bool { 188 numFields++; 189 return true; 190 }); 191 return numFields; 192 } 193 194 unsigned StorageLayout::getNumDataFields() const { 195 unsigned numFields = 0; // one value memref 196 foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level, 197 LevelType) -> bool { 198 if (fidx >= kDataFieldStartingIdx) 199 numFields++; 200 return true; 201 }); 202 numFields -= 1; // the last field is StorageSpecifier 203 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1); 204 return numFields; 205 } 206 207 std::pair<FieldIndex, unsigned> 208 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind, 209 std::optional<Level> lvl) const { 210 FieldIndex fieldIdx = kInvalidFieldIndex; 211 unsigned stride = 1; 212 if (kind == SparseTensorFieldKind::CrdMemRef) { 213 assert(lvl.has_value()); 214 const Level cooStart = SparseTensorType(enc).getAoSCOOStart(); 215 const Level lvlRank = enc.getLvlRank(); 216 if (lvl.value() >= cooStart && lvl.value() < lvlRank) { 217 lvl = cooStart; 218 stride = lvlRank - cooStart; 219 } 220 } 221 foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx, 222 SparseTensorFieldKind fKind, Level fLvl, 223 LevelType lt) -> bool { 224 if ((lvl && fLvl == lvl.value() && kind == fKind) || 225 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { 226 fieldIdx = fIdx; 227 // Returns false to break the iteration. 228 return false; 229 } 230 return true; 231 }); 232 assert(fieldIdx != kInvalidFieldIndex); 233 return std::pair<FieldIndex, unsigned>(fieldIdx, stride); 234 } 235 236 //===----------------------------------------------------------------------===// 237 // SparseTensorDialect Attribute Methods. 238 //===----------------------------------------------------------------------===// 239 240 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) { 241 return isDynamic(v) ? 
std::nullopt 242 : std::make_optional(static_cast<uint64_t>(v)); 243 } 244 245 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const { 246 return getStatic(getOffset()); 247 } 248 249 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const { 250 return getStatic(getStride()); 251 } 252 253 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const { 254 return getStatic(getSize()); 255 } 256 257 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const { 258 return isDynamic(getOffset()) && isDynamic(getStride()) && 259 isDynamic(getSize()); 260 } 261 262 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) { 263 return isDynamic(v) ? "?" : std::to_string(v); 264 } 265 266 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const { 267 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr"); 268 os << '('; 269 os << getStaticString(getOffset()); 270 os << ", "; 271 os << getStaticString(getSize()); 272 os << ", "; 273 os << getStaticString(getStride()); 274 os << ')'; 275 } 276 277 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { 278 print(printer.getStream()); 279 } 280 281 static ParseResult parseOptionalStaticSlice(int64_t &result, 282 AsmParser &parser) { 283 auto parseResult = parser.parseOptionalInteger(result); 284 if (parseResult.has_value()) { 285 if (parseResult.value().succeeded() && result < 0) { 286 parser.emitError( 287 parser.getCurrentLocation(), 288 "expect positive value or ? for slice offset/size/stride"); 289 return failure(); 290 } 291 return parseResult.value(); 292 } 293 294 // Else, and '?' which represented dynamic slice 295 result = SparseTensorDimSliceAttr::kDynamic; 296 return parser.parseQuestion(); 297 } 298 299 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { 300 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic; 301 302 if (failed(parser.parseLParen()) || 303 failed(parseOptionalStaticSlice(offset, parser)) || 304 failed(parser.parseComma()) || 305 failed(parseOptionalStaticSlice(size, parser)) || 306 failed(parser.parseComma()) || 307 failed(parseOptionalStaticSlice(stride, parser)) || 308 failed(parser.parseRParen())) 309 return {}; 310 311 return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(), 312 offset, size, stride); 313 } 314 315 LogicalResult 316 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError, 317 int64_t offset, int64_t size, int64_t stride) { 318 if (!isDynamic(offset) && offset < 0) 319 return emitError() << "expect non-negative value or ? for slice offset"; 320 if (!isDynamic(size) && size <= 0) 321 return emitError() << "expect positive value or ? for slice size"; 322 if (!isDynamic(stride) && stride <= 0) 323 return emitError() << "expect positive value or ? for slice stride"; 324 return success(); 325 } 326 327 SparseTensorEncodingAttr 328 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const { 329 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 330 return SparseTensorEncodingAttr::get( 331 getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(), 332 getCrdWidth(), getExplicitVal(), getImplicitVal()); 333 } 334 335 SparseTensorEncodingAttr 336 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const { 337 return withDimToLvl(enc ? 
enc.getDimToLvl() : AffineMap()); 338 } 339 340 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { 341 return withDimToLvl(AffineMap()); 342 } 343 344 SparseTensorEncodingAttr 345 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth, 346 unsigned crdWidth) const { 347 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 348 return SparseTensorEncodingAttr::get( 349 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth, 350 crdWidth, getExplicitVal(), getImplicitVal()); 351 } 352 353 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { 354 return withBitWidths(0, 0); 355 } 356 357 SparseTensorEncodingAttr 358 SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const { 359 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 360 return SparseTensorEncodingAttr::get( 361 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 362 getCrdWidth(), explicitVal, getImplicitVal()); 363 } 364 365 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const { 366 return withExplicitVal(Attribute()); 367 } 368 369 SparseTensorEncodingAttr 370 SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const { 371 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 372 return SparseTensorEncodingAttr::get( 373 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 374 getCrdWidth(), getExplicitVal(), implicitVal); 375 } 376 377 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const { 378 return withImplicitVal(Attribute()); 379 } 380 381 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( 382 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 383 return SparseTensorEncodingAttr::get( 384 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 385 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices); 386 } 387 388 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { 389 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{}); 390 } 391 392 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { 393 ArrayRef<LevelType> lvlTypes = getLvlTypes(); 394 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 395 return std::distance(lastBatch, lvlTypes.rend()); 396 } 397 398 bool SparseTensorEncodingAttr::isAllDense() const { 399 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT); 400 } 401 402 bool SparseTensorEncodingAttr::isAllOrdered() const { 403 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT); 404 } 405 406 Type SparseTensorEncodingAttr::getCrdElemType() const { 407 if (!getImpl()) 408 return nullptr; 409 if (getCrdWidth()) 410 return IntegerType::get(getContext(), getCrdWidth()); 411 return IndexType::get(getContext()); 412 } 413 414 Type SparseTensorEncodingAttr::getPosElemType() const { 415 if (!getImpl()) 416 return nullptr; 417 if (getPosWidth()) 418 return IntegerType::get(getContext(), getPosWidth()); 419 return IndexType::get(getContext()); 420 } 421 422 MemRefType SparseTensorEncodingAttr::getCrdMemRefType( 423 std::optional<ArrayRef<int64_t>> dimShape) const { 424 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 425 return MemRefType::get(shape, getCrdElemType()); 426 } 427 428 MemRefType SparseTensorEncodingAttr::getPosMemRefType( 429 std::optional<ArrayRef<int64_t>> dimShape) const { 430 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 431 return 
MemRefType::get(shape, getPosElemType()); 432 } 433 434 bool SparseTensorEncodingAttr::isIdentity() const { 435 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity(); 436 } 437 438 bool SparseTensorEncodingAttr::isPermutation() const { 439 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation(); 440 } 441 442 Dimension SparseTensorEncodingAttr::getDimRank() const { 443 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 444 const auto dimToLvl = getDimToLvl(); 445 return dimToLvl ? dimToLvl.getNumDims() : getLvlRank(); 446 } 447 448 Level SparseTensorEncodingAttr::getLvlRank() const { 449 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 450 return getLvlTypes().size(); 451 } 452 453 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { 454 if (!getImpl()) 455 return LevelFormat::Batch; 456 assert(l < getLvlRank() && "Level is out of bounds"); 457 return getLvlTypes()[l]; 458 } 459 460 bool SparseTensorEncodingAttr::isSlice() const { 461 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 462 return !getDimSlices().empty(); 463 } 464 465 SparseTensorDimSliceAttr 466 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { 467 assert(isSlice() && "Is not a slice"); 468 const auto dimSlices = getDimSlices(); 469 assert(dim < dimSlices.size() && "Dimension is out of bounds"); 470 return dimSlices[dim]; 471 } 472 473 std::optional<uint64_t> 474 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { 475 return getDimSlice(dim).getStaticOffset(); 476 } 477 478 std::optional<uint64_t> 479 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { 480 return getDimSlice(dim).getStaticStride(); 481 } 482 483 std::optional<uint64_t> 484 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { 485 return getStaticDimSliceOffset(toDim(*this, lvl)); 486 } 487 488 std::optional<uint64_t> 489 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { 490 return getStaticDimSliceStride(toDim(*this, lvl)); 491 } 492 493 SmallVector<int64_t> 494 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape, 495 CrdTransDirectionKind dir) const { 496 if (isIdentity()) 497 return SmallVector<int64_t>(srcShape); 498 499 SmallVector<int64_t> ret; 500 unsigned rank = 501 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank(); 502 ret.reserve(rank); 503 504 if (isPermutation()) { 505 for (unsigned r = 0; r < rank; r++) { 506 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) 507 : toLvl(*this, r); 508 ret.push_back(srcShape[trans]); 509 } 510 return ret; 511 } 512 513 // Handle non-permutation maps. 514 AffineMap transMap = 515 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim(); 516 517 SmallVector<AffineExpr> dimRep; 518 dimRep.reserve(srcShape.size()); 519 for (int64_t sz : srcShape) { 520 if (!ShapedType::isDynamic(sz)) { 521 // Push back the max coordinate for the given dimension/level size. 522 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); 523 } else { 524 // A dynamic size, use a AffineDimExpr to symbolize the value. 525 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext())); 526 } 527 }; 528 529 for (AffineExpr exp : transMap.getResults()) { 530 // Do constant propagation on the affine map. 
531 AffineExpr evalExp = 532 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0); 533 // use llvm namespace here to avoid ambiguity 534 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) { 535 ret.push_back(c.getValue() + 1); 536 } else { 537 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp); 538 mod && mod.getKind() == AffineExprKind::Mod) { 539 // We can still infer a static bound for expressions in form 540 // "d % constant" since d % constant \in [0, constant). 541 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) { 542 ret.push_back(bound.getValue()); 543 continue; 544 } 545 } 546 ret.push_back(ShapedType::kDynamic); 547 } 548 } 549 assert(ret.size() == rank); 550 return ret; 551 } 552 553 ValueRange 554 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, 555 ValueRange crds, 556 CrdTransDirectionKind dir) const { 557 if (!getImpl()) 558 return crds; 559 560 SmallVector<Type> retType( 561 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(), 562 builder.getIndexType()); 563 auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this); 564 return transOp.getOutCrds(); 565 } 566 567 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 568 // Open "<{" part. 569 if (failed(parser.parseLess())) 570 return {}; 571 if (failed(parser.parseLBrace())) 572 return {}; 573 574 // Process the data from the parsed dictionary value into struct-like data. 575 SmallVector<LevelType> lvlTypes; 576 SmallVector<SparseTensorDimSliceAttr> dimSlices; 577 AffineMap dimToLvl = {}; 578 AffineMap lvlToDim = {}; 579 unsigned posWidth = 0; 580 unsigned crdWidth = 0; 581 Attribute explicitVal; 582 Attribute implicitVal; 583 StringRef attrName; 584 SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth", 585 "explicitVal", "implicitVal"}; 586 while (succeeded(parser.parseOptionalKeyword(&attrName))) { 587 // Detect admissible keyword. 588 auto *it = find(keys, attrName); 589 if (it == keys.end()) { 590 parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; 591 return {}; 592 } 593 unsigned keyWordIndex = it - keys.begin(); 594 // Consume the `=` after keys 595 if (failed(parser.parseEqual())) 596 return {}; 597 // Dispatch on keyword. 598 switch (keyWordIndex) { 599 case 0: { // map 600 ir_detail::DimLvlMapParser cParser(parser); 601 auto res = cParser.parseDimLvlMap(); 602 if (failed(res)) 603 return {}; 604 const auto &dlm = *res; 605 606 const Level lvlRank = dlm.getLvlRank(); 607 for (Level lvl = 0; lvl < lvlRank; lvl++) 608 lvlTypes.push_back(dlm.getLvlType(lvl)); 609 610 const Dimension dimRank = dlm.getDimRank(); 611 for (Dimension dim = 0; dim < dimRank; dim++) 612 dimSlices.push_back(dlm.getDimSlice(dim)); 613 // NOTE: the old syntax requires an all-or-nothing approach to 614 // `dimSlices`; therefore, if any slice actually exists then we need 615 // to convert null-DSA into default/nop DSA. 
616 const auto isDefined = [](SparseTensorDimSliceAttr slice) { 617 return static_cast<bool>(slice.getImpl()); 618 }; 619 if (llvm::any_of(dimSlices, isDefined)) { 620 const auto defaultSlice = 621 SparseTensorDimSliceAttr::get(parser.getContext()); 622 for (Dimension dim = 0; dim < dimRank; dim++) 623 if (!isDefined(dimSlices[dim])) 624 dimSlices[dim] = defaultSlice; 625 } else { 626 dimSlices.clear(); 627 } 628 629 dimToLvl = dlm.getDimToLvlMap(parser.getContext()); 630 lvlToDim = dlm.getLvlToDimMap(parser.getContext()); 631 break; 632 } 633 case 1: { // posWidth 634 Attribute attr; 635 if (failed(parser.parseAttribute(attr))) 636 return {}; 637 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 638 if (!intAttr) { 639 parser.emitError(parser.getNameLoc(), 640 "expected an integral position bitwidth"); 641 return {}; 642 } 643 posWidth = intAttr.getInt(); 644 break; 645 } 646 case 2: { // crdWidth 647 Attribute attr; 648 if (failed(parser.parseAttribute(attr))) 649 return {}; 650 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 651 if (!intAttr) { 652 parser.emitError(parser.getNameLoc(), 653 "expected an integral index bitwidth"); 654 return {}; 655 } 656 crdWidth = intAttr.getInt(); 657 break; 658 } 659 case 3: { // explicitVal 660 Attribute attr; 661 if (failed(parser.parseAttribute(attr))) 662 return {}; 663 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { 664 explicitVal = result; 665 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { 666 explicitVal = result; 667 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { 668 explicitVal = result; 669 } else { 670 parser.emitError(parser.getNameLoc(), 671 "expected a numeric value for explicitVal"); 672 return {}; 673 } 674 break; 675 } 676 case 4: { // implicitVal 677 Attribute attr; 678 if (failed(parser.parseAttribute(attr))) 679 return {}; 680 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { 681 implicitVal = result; 682 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { 683 implicitVal = result; 684 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { 685 implicitVal = result; 686 } else { 687 parser.emitError(parser.getNameLoc(), 688 "expected a numeric value for implicitVal"); 689 return {}; 690 } 691 break; 692 } 693 } // switch 694 // Only last item can omit the comma. 695 if (parser.parseOptionalComma().failed()) 696 break; 697 } 698 699 // Close "}>" part. 700 if (failed(parser.parseRBrace())) 701 return {}; 702 if (failed(parser.parseGreater())) 703 return {}; 704 705 // Construct struct-like storage for attribute. 706 if (!lvlToDim || lvlToDim.isEmpty()) { 707 lvlToDim = inferLvlToDim(dimToLvl, parser.getContext()); 708 } 709 return parser.getChecked<SparseTensorEncodingAttr>( 710 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, 711 explicitVal, implicitVal, dimSlices); 712 } 713 714 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 715 auto map = static_cast<AffineMap>(getDimToLvl()); 716 // Empty affine map indicates identity map 717 if (!map) 718 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext()); 719 printer << "<{ map = "; 720 printSymbols(map, printer); 721 printer << '('; 722 printDimensions(map, printer, getDimSlices()); 723 printer << ") -> ("; 724 printLevels(map, printer, getLvlTypes()); 725 printer << ')'; 726 // Print remaining members only for non-default values. 
727 if (getPosWidth()) 728 printer << ", posWidth = " << getPosWidth(); 729 if (getCrdWidth()) 730 printer << ", crdWidth = " << getCrdWidth(); 731 if (getExplicitVal()) { 732 printer << ", explicitVal = " << getExplicitVal(); 733 } 734 if (getImplicitVal()) 735 printer << ", implicitVal = " << getImplicitVal(); 736 printer << " }>"; 737 } 738 739 void SparseTensorEncodingAttr::printSymbols(AffineMap &map, 740 AsmPrinter &printer) const { 741 if (map.getNumSymbols() == 0) 742 return; 743 printer << '['; 744 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++) 745 printer << 's' << i << ", "; 746 if (map.getNumSymbols() >= 1) 747 printer << 's' << map.getNumSymbols() - 1; 748 printer << ']'; 749 } 750 751 void SparseTensorEncodingAttr::printDimensions( 752 AffineMap &map, AsmPrinter &printer, 753 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 754 if (!dimSlices.empty()) { 755 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 756 printer << 'd' << i << " : " << dimSlices[i] << ", "; 757 if (map.getNumDims() >= 1) { 758 printer << 'd' << map.getNumDims() - 1 << " : " 759 << dimSlices[map.getNumDims() - 1]; 760 } 761 } else { 762 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 763 printer << 'd' << i << ", "; 764 if (map.getNumDims() >= 1) 765 printer << 'd' << map.getNumDims() - 1; 766 } 767 } 768 769 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, 770 ArrayRef<LevelType> lvlTypes) const { 771 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { 772 map.getResult(i).print(printer.getStream()); 773 printer << " : " << toMLIRString(lvlTypes[i]) << ", "; 774 } 775 if (map.getNumResults() >= 1) { 776 auto lastIndex = map.getNumResults() - 1; 777 map.getResult(lastIndex).print(printer.getStream()); 778 printer << " : " << toMLIRString(lvlTypes[lastIndex]); 779 } 780 } 781 782 LogicalResult SparseTensorEncodingAttr::verify( 783 function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes, 784 AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth, 785 unsigned crdWidth, Attribute explicitVal, Attribute implicitVal, 786 ArrayRef<SparseTensorDimSliceAttr> dimSlices) { 787 if (!acceptBitWidth(posWidth)) 788 return emitError() << "unexpected position bitwidth: " << posWidth; 789 if (!acceptBitWidth(crdWidth)) 790 return emitError() << "unexpected coordinate bitwidth: " << crdWidth; 791 if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); 792 it != std::end(lvlTypes)) { 793 if (it == lvlTypes.begin() || 794 (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1)))) 795 return emitError() << "expected compressed or loose_compressed level " 796 "before singleton level"; 797 if (!std::all_of(it, lvlTypes.end(), 798 [](LevelType i) { return isSingletonLT(i); })) 799 return emitError() << "expected all singleton lvlTypes " 800 "following a singleton level"; 801 // We can potentially support mixed SoA/AoS singleton levels. 802 if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) { 803 return it->isa<LevelPropNonDefault::SoA>() == 804 i.isa<LevelPropNonDefault::SoA>(); 805 })) { 806 return emitError() << "expected all singleton lvlTypes stored in the " 807 "same memory layout (SoA vs AoS)."; 808 } 809 } 810 811 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 812 if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT)) 813 return emitError() << "Batch lvlType can only be leading levels."; 814 815 // SoA property can only be applied on singleton level. 
816 auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) { 817 return lt.isa<LevelPropNonDefault::SoA>(); 818 }); 819 if (llvm::any_of(soaLvls, [](LevelType lt) { 820 return !lt.isa<LevelFormat::Singleton>(); 821 })) { 822 return emitError() << "SoA is only applicable to singleton lvlTypes."; 823 } 824 825 // TODO: audit formats that actually are supported by backend. 826 if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT); 827 it != std::end(lvlTypes)) { 828 if (it != lvlTypes.end() - 1) 829 return emitError() << "expected n_out_of_m to be the last level type"; 830 if (!std::all_of(lvlTypes.begin(), it, 831 [](LevelType i) { return isDenseLT(i); })) 832 return emitError() << "expected all dense lvlTypes " 833 "before a n_out_of_m level"; 834 if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) { 835 if (!isBlockSparsity(dimToLvl)) { 836 return emitError() 837 << "expected 1xm block structure for n_out_of_m level"; 838 } 839 auto sizes = getBlockSize(dimToLvl); 840 unsigned coefficient = 0; 841 for (const auto &elem : sizes) { 842 if (elem != 0) { 843 if (elem != coefficient && coefficient != 0) { 844 return emitError() << "expected only one blocked level " 845 "with the same coefficients"; 846 } 847 coefficient = elem; 848 } 849 } 850 if (coefficient != getM(*it)) { 851 return emitError() << "expected coeffiencts of Affine expressions " 852 "to be equal to m of n_out_of_m level"; 853 } 854 } 855 } 856 // Before we can check that the level-rank is consistent/coherent 857 // across all fields, we need to define it. The source-of-truth for 858 // the `getLvlRank` method is the length of the level-types array, 859 // since it must always be provided and have full rank; therefore we 860 // use that same source-of-truth here. 861 const Level lvlRank = lvlTypes.size(); 862 if (lvlRank == 0) 863 return emitError() << "expected a non-empty array for lvlTypes"; 864 // We save `dimRank` here because we'll also need it to verify `dimSlices`. 865 const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank; 866 if (dimToLvl) { 867 if (dimToLvl.getNumResults() != lvlRank) 868 return emitError() 869 << "level-rank mismatch between dimToLvl and lvlTypes: " 870 << dimToLvl.getNumResults() << " != " << lvlRank; 871 auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext()); 872 // Symbols can't be inferred but are acceptable. 873 if (!inferRes && dimToLvl.getNumSymbols() == 0) 874 return emitError() << "failed to infer lvlToDim from dimToLvl"; 875 if (lvlToDim && (inferRes != lvlToDim)) 876 return emitError() << "expected lvlToDim to be an inverse of dimToLvl"; 877 if (dimRank > lvlRank) 878 return emitError() << "unexpected dimToLvl mapping from " << dimRank 879 << " to " << lvlRank; 880 } 881 if (!dimSlices.empty()) { 882 if (dimSlices.size() != dimRank) 883 return emitError() 884 << "dimension-rank mismatch between dimSlices and dimToLvl: " 885 << dimSlices.size() << " != " << dimRank; 886 // Compiler support for `dimSlices` currently requires that the two 887 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.) 888 if (dimRank != lvlRank) 889 return emitError() 890 << "dimSlices expected dimension-rank to match level-rank: " 891 << dimRank << " != " << lvlRank; 892 } 893 return success(); 894 } 895 896 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 897 ArrayRef<Size> dimShape, Type elementType, 898 function_ref<InFlightDiagnostic()> emitError) const { 899 // Check structural integrity. 
In particular, this ensures that the 900 // level-rank is coherent across all the fields. 901 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(), 902 getPosWidth(), getCrdWidth(), getExplicitVal(), 903 getImplicitVal(), getDimSlices()))) 904 return failure(); 905 // Check integrity with tensor type specifics. In particular, we 906 // need only check that the dimension-rank of the tensor agrees with 907 // the dimension-rank of the encoding. 908 const Dimension dimRank = dimShape.size(); 909 if (dimRank == 0) 910 return emitError() << "expected non-scalar sparse tensor"; 911 if (getDimRank() != dimRank) 912 return emitError() 913 << "dimension-rank mismatch between encoding and tensor shape: " 914 << getDimRank() << " != " << dimRank; 915 return success(); 916 } 917 918 //===----------------------------------------------------------------------===// 919 // SparseTensorType Methods. 920 //===----------------------------------------------------------------------===// 921 922 bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, 923 bool isUnique) const { 924 if (!hasEncoding()) 925 return false; 926 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl)) 927 return false; 928 for (Level l = startLvl + 1; l < lvlRank; ++l) 929 if (!isSingletonLvl(l)) 930 return false; 931 // If isUnique is true, then make sure that the last level is unique, 932 // that is, when lvlRank == 1, the only compressed level is unique, 933 // and when lvlRank > 1, the last singleton is unique. 934 return !isUnique || isUniqueLvl(lvlRank - 1); 935 } 936 937 Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const { 938 SmallVector<COOSegment> coo = getCOOSegments(); 939 assert(coo.size() == 1 || coo.empty()); 940 if (!coo.empty() && coo.front().isAoS()) { 941 return coo.front().lvlRange.first; 942 } 943 return lvlRank; 944 } 945 946 SmallVector<COOSegment> 947 mlir::sparse_tensor::SparseTensorType::getCOOSegments() const { 948 SmallVector<COOSegment> ret; 949 if (!hasEncoding() || lvlRank <= 1) 950 return ret; 951 952 ArrayRef<LevelType> lts = getLvlTypes(); 953 Level l = 0; 954 while (l < lvlRank) { 955 auto lt = lts[l]; 956 if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) { 957 auto cur = lts.begin() + l; 958 auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) { 959 return !lt.isa<LevelFormat::Singleton>(); 960 }); 961 unsigned cooLen = std::distance(cur, end); 962 if (cooLen > 1) { 963 // To support mixed SoA/AoS COO, we should break the segment when the 964 // storage scheme changes, for now we faithfully assume that all 965 // consecutive singleton levels have the same storage format as verified 966 // STEA. 967 ret.push_back(COOSegment{std::make_pair(l, l + cooLen), 968 lts[l + 1].isa<LevelPropNonDefault::SoA>()}); 969 } 970 l += cooLen; 971 } else { 972 l++; 973 } 974 } 975 return ret; 976 } 977 978 RankedTensorType 979 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { 980 SmallVector<LevelType> lvlTypes; 981 lvlTypes.reserve(lvlRank); 982 // A non-unique compressed level at beginning (unless this is 983 // also the last level, then it is unique). 984 lvlTypes.push_back( 985 *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); 986 if (lvlRank > 1) { 987 // Followed by n-2 non-unique singleton levels. 988 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, 989 *buildLevelType(LevelFormat::Singleton, ordered, false)); 990 // Ends by a unique singleton level. 
991 lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true)); 992 } 993 auto enc = SparseTensorEncodingAttr::get( 994 getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(), 995 getCrdWidth(), getExplicitVal(), getImplicitVal()); 996 return RankedTensorType::get(getDimShape(), getElementType(), enc); 997 } 998 999 //===----------------------------------------------------------------------===// 1000 // Convenience Methods. 1001 //===----------------------------------------------------------------------===// 1002 1003 SparseTensorEncodingAttr 1004 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 1005 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type)) 1006 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding()); 1007 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type)) 1008 return mdtp.getEncoding(); 1009 return nullptr; 1010 } 1011 1012 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl, 1013 MLIRContext *context) { 1014 auto map = static_cast<AffineMap>(dimToLvl); 1015 AffineMap lvlToDim; 1016 // Return an empty lvlToDim when inference is not successful. 1017 if (!map || map.getNumSymbols() != 0) { 1018 lvlToDim = AffineMap(); 1019 } else if (map.isPermutation()) { 1020 lvlToDim = inversePermutation(map); 1021 } else if (isBlockSparsity(map)) { 1022 lvlToDim = inverseBlockSparsity(map, context); 1023 } 1024 return lvlToDim; 1025 } 1026 1027 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl, 1028 MLIRContext *context) { 1029 SmallVector<AffineExpr> lvlExprs; 1030 auto numLvls = dimToLvl.getNumResults(); 1031 lvlExprs.reserve(numLvls); 1032 // lvlExprComponents stores information of the floordiv and mod operations 1033 // applied to the same dimension, so as to build the lvlToDim map. 1034 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents; 1035 for (unsigned i = 0, n = numLvls; i < n; i++) { 1036 auto result = dimToLvl.getResult(i); 1037 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1038 if (result.getKind() == AffineExprKind::FloorDiv) { 1039 // Position of the dimension in dimToLvl. 1040 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 1041 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() && 1042 "expected only one floordiv for each dimension"); 1043 SmallVector<AffineExpr, 3> components; 1044 // Level variable for floordiv. 1045 components.push_back(getAffineDimExpr(i, context)); 1046 // Multiplier. 1047 components.push_back(binOp.getRHS()); 1048 // Map key is the position of the dimension. 1049 lvlExprComponents[pos] = components; 1050 } else if (result.getKind() == AffineExprKind::Mod) { 1051 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 1052 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() && 1053 "expected floordiv before mod"); 1054 // Add level variable for mod to the same vector 1055 // of the corresponding floordiv. 1056 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context)); 1057 } else { 1058 assert(false && "expected floordiv or mod"); 1059 } 1060 } else { 1061 lvlExprs.push_back(getAffineDimExpr(i, context)); 1062 } 1063 } 1064 // Build lvlExprs from lvlExprComponents. 1065 // For example, for il = i floordiv 2 and ii = i mod 2, the components 1066 // would be [il, 2, ii]. It could be used to build the AffineExpr 1067 // i = il * 2 + ii in lvlToDim. 
1068 for (auto &components : lvlExprComponents) { 1069 assert(components.second.size() == 3 && 1070 "expected 3 components to build lvlExprs"); 1071 auto mulOp = getAffineBinaryOpExpr( 1072 AffineExprKind::Mul, components.second[0], components.second[1]); 1073 auto addOp = 1074 getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]); 1075 lvlExprs.push_back(addOp); 1076 } 1077 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context); 1078 } 1079 1080 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { 1081 assert(isBlockSparsity(dimToLvl) && 1082 "expected dimToLvl to be block sparsity for calling getBlockSize"); 1083 SmallVector<unsigned> blockSize; 1084 for (auto result : dimToLvl.getResults()) { 1085 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1086 if (result.getKind() == AffineExprKind::Mod) { 1087 blockSize.push_back( 1088 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue()); 1089 } 1090 } else { 1091 blockSize.push_back(0); 1092 } 1093 } 1094 return blockSize; 1095 } 1096 1097 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { 1098 if (!dimToLvl) 1099 return false; 1100 std::map<unsigned, int64_t> coeffientMap; 1101 bool hasBlock = false; 1102 for (auto result : dimToLvl.getResults()) { 1103 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1104 // Check for "dim op const". 1105 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS()); 1106 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS()); 1107 if (!dimOp || !conOp || conOp.getValue() <= 0) 1108 return false; 1109 // Inspect "dim / const" or "dim % const". 1110 auto pos = dimOp.getPosition(); 1111 if (binOp.getKind() == AffineExprKind::FloorDiv) { 1112 // Expect only one floordiv for each dimension. 1113 if (coeffientMap.find(pos) != coeffientMap.end()) 1114 return false; 1115 // Record coefficient of the floordiv. 1116 coeffientMap[pos] = conOp.getValue(); 1117 } else if (binOp.getKind() == AffineExprKind::Mod) { 1118 // Expect floordiv before mod. 1119 if (coeffientMap.find(pos) == coeffientMap.end()) 1120 return false; 1121 // Expect mod to have the same coefficient as floordiv. 1122 if (conOp.getValue() != coeffientMap[pos]) 1123 return false; 1124 hasBlock = true; 1125 } else { 1126 return false; 1127 } 1128 } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) { 1129 auto pos = dimOp.getPosition(); 1130 // Expect dim to be unset. 
1131 if (coeffientMap.find(pos) != coeffientMap.end()) 1132 return false; 1133 coeffientMap[pos] = 0; 1134 } else { 1135 return false; 1136 } 1137 } 1138 return hasBlock; 1139 } 1140 1141 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) { 1142 auto hasNonIdentityMap = [](Value v) { 1143 auto stt = tryGetSparseTensorType(v); 1144 return stt && !stt->isIdentity(); 1145 }; 1146 1147 return llvm::any_of(op->getOperands(), hasNonIdentityMap) || 1148 llvm::any_of(op->getResults(), hasNonIdentityMap); 1149 } 1150 1151 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { 1152 if (enc) { 1153 assert(enc.isPermutation() && "Non permutation map not supported"); 1154 if (const auto dimToLvl = enc.getDimToLvl()) 1155 return dimToLvl.getDimPosition(l); 1156 } 1157 return l; 1158 } 1159 1160 Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { 1161 if (enc) { 1162 assert(enc.isPermutation() && "Non permutation map not supported"); 1163 if (const auto lvlToDim = enc.getLvlToDim()) 1164 return lvlToDim.getDimPosition(d); 1165 } 1166 return d; 1167 } 1168 1169 /// We normalized sparse tensor encoding attribute by always using 1170 /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well 1171 /// as other variants) lead to the same storage specifier type, and stripping 1172 /// irrelevant fields that do not alter the sparse tensor memory layout. 1173 static SparseTensorEncodingAttr 1174 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { 1175 SmallVector<LevelType> lts; 1176 for (auto lt : enc.getLvlTypes()) 1177 lts.push_back(lt.stripStorageIrrelevantProperties()); 1178 1179 return SparseTensorEncodingAttr::get( 1180 enc.getContext(), lts, 1181 AffineMap(), // dimToLvl (irrelevant to storage specifier) 1182 AffineMap(), // lvlToDim (irrelevant to storage specifier) 1183 // Always use `index` for memSize and lvlSize instead of reusing 1184 // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA 1185 // value for different bitwidth, it also avoids casting between index and 1186 // integer (returned by DimOp) 1187 0, 0, 1188 Attribute(), // explicitVal (irrelevant to storage specifier) 1189 Attribute(), // implicitVal (irrelevant to storage specifier) 1190 enc.getDimSlices()); 1191 } 1192 1193 StorageSpecifierType 1194 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) { 1195 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding)); 1196 } 1197 1198 //===----------------------------------------------------------------------===// 1199 // SparseTensorDialect Operations. 1200 //===----------------------------------------------------------------------===// 1201 1202 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) { 1203 return success(lvl < getSparseTensorType(tensor).getLvlRank()); 1204 } 1205 1206 static LogicalResult isMatchingWidth(Value mem, unsigned width) { 1207 const Type etp = getMemRefType(mem).getElementType(); 1208 return success(width == 0 ? 
etp.isIndex() : etp.isInteger(width)); 1209 } 1210 1211 static LogicalResult verifySparsifierGetterSetter( 1212 StorageSpecifierKind mdKind, std::optional<Level> lvl, 1213 TypedValue<StorageSpecifierType> md, Operation *op) { 1214 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) { 1215 return op->emitError( 1216 "redundant level argument for querying value memory size"); 1217 } 1218 1219 const auto enc = md.getType().getEncoding(); 1220 const Level lvlRank = enc.getLvlRank(); 1221 1222 if (mdKind == StorageSpecifierKind::DimOffset || 1223 mdKind == StorageSpecifierKind::DimStride) 1224 if (!enc.isSlice()) 1225 return op->emitError("requested slice data on non-slice tensor"); 1226 1227 if (mdKind != StorageSpecifierKind::ValMemSize) { 1228 if (!lvl) 1229 return op->emitError("missing level argument"); 1230 1231 const Level l = lvl.value(); 1232 if (l >= lvlRank) 1233 return op->emitError("requested level is out of bounds"); 1234 1235 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l)) 1236 return op->emitError( 1237 "requested position memory size on a singleton level"); 1238 } 1239 return success(); 1240 } 1241 1242 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) { 1243 switch (kind) { 1244 case SparseTensorFieldKind::CrdMemRef: 1245 return stt.getCrdType(); 1246 case SparseTensorFieldKind::PosMemRef: 1247 return stt.getPosType(); 1248 case SparseTensorFieldKind::ValMemRef: 1249 return stt.getElementType(); 1250 case SparseTensorFieldKind::StorageSpec: 1251 return nullptr; 1252 } 1253 llvm_unreachable("Unrecognizable FieldKind"); 1254 } 1255 1256 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, 1257 SparseTensorType stt, 1258 RankedTensorType valTp, 1259 TypeRange lvlTps) { 1260 if (requiresStaticShape && !stt.hasStaticDimShape()) 1261 return op->emitError("the sparse-tensor must have static shape"); 1262 if (!stt.hasEncoding()) 1263 return op->emitError("the sparse-tensor must have an encoding attribute"); 1264 1265 // Verifies the trailing COO. 1266 Level cooStartLvl = stt.getAoSCOOStart(); 1267 if (cooStartLvl < stt.getLvlRank()) { 1268 // We only supports trailing COO for now, must be the last input. 1269 auto cooTp = llvm::cast<ShapedType>(lvlTps.back()); 1270 // The coordinates should be in shape of <? x rank> 1271 unsigned expCOORank = stt.getLvlRank() - cooStartLvl; 1272 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) { 1273 op->emitError("input/output trailing COO level-ranks don't match"); 1274 } 1275 } 1276 1277 // Verifies that all types match. 1278 StorageLayout layout(stt.getEncoding()); 1279 if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref 1280 return op->emitError("inconsistent number of fields between input/output"); 1281 1282 unsigned idx = 0; 1283 bool misMatch = false; 1284 layout.foreachField([&idx, &misMatch, stt, valTp, 1285 lvlTps](FieldIndex fid, SparseTensorFieldKind fKind, 1286 Level lvl, LevelType lt) -> bool { 1287 if (fKind == SparseTensorFieldKind::StorageSpec) 1288 return true; 1289 1290 Type inputTp = nullptr; 1291 if (fKind == SparseTensorFieldKind::ValMemRef) { 1292 inputTp = valTp; 1293 } else { 1294 assert(fid == idx && stt.getLvlType(lvl) == lt); 1295 inputTp = lvlTps[idx++]; 1296 } 1297 // The input element type and expected element type should match. 
1298 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType(); 1299 Type expElemTp = getFieldElemType(stt, fKind); 1300 if (inpElemTp != expElemTp) { 1301 misMatch = true; 1302 return false; // to terminate the iteration 1303 } 1304 return true; 1305 }); 1306 1307 if (misMatch) 1308 return op->emitError("input/output element-types don't match"); 1309 return success(); 1310 } 1311 1312 LogicalResult AssembleOp::verify() { 1313 const auto valuesTp = getRankedTensorType(getValues()); 1314 const auto lvlsTp = getLevels().getTypes(); 1315 const auto resTp = getSparseTensorType(getResult()); 1316 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp); 1317 } 1318 1319 LogicalResult DisassembleOp::verify() { 1320 if (getOutValues().getType() != getRetValues().getType()) 1321 return emitError("output values and return value type mismatch"); 1322 1323 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels())) 1324 if (ot.getType() != rt.getType()) 1325 return emitError("output levels and return levels type mismatch"); 1326 1327 const auto valuesTp = getRankedTensorType(getRetValues()); 1328 const auto lvlsTp = getRetLevels().getTypes(); 1329 const auto srcTp = getSparseTensorType(getTensor()); 1330 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp); 1331 } 1332 1333 LogicalResult ConvertOp::verify() { 1334 if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) { 1335 if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) { 1336 if (tp1.getRank() != tp2.getRank()) 1337 return emitError("unexpected conversion mismatch in rank"); 1338 auto dstEnc = 1339 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding()); 1340 if (dstEnc && dstEnc.isSlice()) 1341 return emitError("cannot convert to a sparse tensor slice"); 1342 1343 auto shape1 = tp1.getShape(); 1344 auto shape2 = tp2.getShape(); 1345 // Accept size matches between the source and the destination type 1346 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or 1347 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 1348 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) 1349 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) 1350 return emitError("unexpected conversion mismatch in dimension ") << d; 1351 return success(); 1352 } 1353 } 1354 return emitError("unexpected type in convert"); 1355 } 1356 1357 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { 1358 if (getType() == getSource().getType()) 1359 return getSource(); 1360 return {}; 1361 } 1362 1363 bool ConvertOp::needsExtraSort() { 1364 SparseTensorType srcStt = getSparseTensorType(getSource()); 1365 SparseTensorType dstStt = getSparseTensorType(getDest()); 1366 1367 // We do not need an extra sort when returning unordered sparse tensors or 1368 // dense tensor since dense tensor support random access. 1369 if (dstStt.isAllDense() || !dstStt.isAllOrdered()) 1370 return false; 1371 1372 if (srcStt.isAllOrdered() && dstStt.isAllOrdered() && 1373 srcStt.hasSameDimToLvl(dstStt)) { 1374 return false; 1375 } 1376 1377 // Source and dest tensors are ordered in different ways. We only do direct 1378 // dense to sparse conversion when the dense input is defined by a sparse 1379 // constant. Note that we can theoretically always directly convert from dense 1380 // inputs by rotating dense loops but it leads to bad cache locality and hurt 1381 // performance. 
1382 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>()) 1383 if (isa<SparseElementsAttr>(constOp.getValue())) 1384 return false; 1385 1386 return true; 1387 } 1388 1389 LogicalResult CrdTranslateOp::verify() { 1390 uint64_t inRank = getEncoder().getLvlRank(); 1391 uint64_t outRank = getEncoder().getDimRank(); 1392 1393 if (getDirection() == CrdTransDirectionKind::dim2lvl) 1394 std::swap(inRank, outRank); 1395 1396 if (inRank != getInCrds().size() || outRank != getOutCrds().size()) 1397 return emitError("Coordinate rank mismatch with encoding"); 1398 1399 return success(); 1400 } 1401 1402 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor, 1403 SmallVectorImpl<OpFoldResult> &results) { 1404 if (getEncoder().isIdentity()) { 1405 results.assign(getInCrds().begin(), getInCrds().end()); 1406 return success(); 1407 } 1408 if (getEncoder().isPermutation()) { 1409 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl 1410 ? getEncoder().getDimToLvl() 1411 : getEncoder().getLvlToDim(); 1412 for (AffineExpr exp : perm.getResults()) 1413 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]); 1414 return success(); 1415 } 1416 1417 // Fuse dim2lvl/lvl2dim pairs. 1418 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>(); 1419 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) { 1420 return v.getDefiningOp() == def; 1421 }); 1422 if (!sameDef) 1423 return failure(); 1424 1425 bool oppositeDir = def.getDirection() != getDirection(); 1426 bool sameOracle = 1427 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl(); 1428 bool sameCount = def.getNumResults() == getInCrds().size(); 1429 if (!oppositeDir || !sameOracle || !sameCount) 1430 return failure(); 1431 1432 // The definition produces the coordinates in the same order as the input 1433 // coordinates. 
1434 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()), 1435 [](auto valuePair) { 1436 auto [lhs, rhs] = valuePair; 1437 return lhs == rhs; 1438 }); 1439 1440 if (!sameOrder) 1441 return failure(); 1442 // l1 = dim2lvl (lvl2dim l0) 1443 // ==> l0 1444 results.append(def.getInCrds().begin(), def.getInCrds().end()); 1445 return success(); 1446 } 1447 1448 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source, 1449 int64_t index) { 1450 Value val = builder.create<arith::ConstantIndexOp>(state.location, index); 1451 return build(builder, state, source, val); 1452 } 1453 1454 LogicalResult LvlOp::verify() { 1455 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) { 1456 auto stt = getSparseTensorType(getSource()); 1457 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank()) 1458 emitError("Level index exceeds the rank of the input sparse tensor"); 1459 } 1460 return success(); 1461 } 1462 1463 std::optional<uint64_t> LvlOp::getConstantLvlIndex() { 1464 return getConstantIntValue(getIndex()); 1465 } 1466 1467 Speculation::Speculatability LvlOp::getSpeculatability() { 1468 auto constantIndex = getConstantLvlIndex(); 1469 if (!constantIndex) 1470 return Speculation::NotSpeculatable; 1471 1472 assert(constantIndex < 1473 cast<RankedTensorType>(getSource().getType()).getRank()); 1474 return Speculation::Speculatable; 1475 } 1476 1477 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) { 1478 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex()); 1479 if (!lvlIndex) 1480 return {}; 1481 1482 Level lvl = lvlIndex.getAPSInt().getZExtValue(); 1483 auto stt = getSparseTensorType(getSource()); 1484 if (lvl >= stt.getLvlRank()) { 1485 // Follows the same convention used by tensor.dim operation. Out of bound 1486 // indices produce undefined behavior but are still valid IR. Don't choke on 1487 // them. 1488 return {}; 1489 } 1490 1491 // Helper lambda to build an IndexAttr. 
1492 auto getIndexAttr = [this](int64_t lvlSz) { 1493 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz)); 1494 }; 1495 1496 SmallVector<Size> lvlShape = stt.getLvlShape(); 1497 if (!ShapedType::isDynamic(lvlShape[lvl])) 1498 return getIndexAttr(lvlShape[lvl]); 1499 1500 return {}; 1501 } 1502 1503 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState, 1504 SparseTensorEncodingAttr dstEnc, Value source) { 1505 auto srcStt = getSparseTensorType(source); 1506 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape(); 1507 SmallVector<int64_t> dstDimShape = 1508 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim); 1509 auto dstTp = 1510 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc); 1511 return build(odsBuilder, odsState, dstTp, source); 1512 } 1513 1514 LogicalResult ReinterpretMapOp::verify() { 1515 auto srcStt = getSparseTensorType(getSource()); 1516 auto dstStt = getSparseTensorType(getDest()); 1517 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes(); 1518 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes(); 1519 1520 if (srcLvlTps.size() != dstLvlTps.size()) 1521 return emitError("Level rank mismatch between source/dest tensors"); 1522 1523 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps)) 1524 if (srcLvlTp != dstLvlTp) 1525 return emitError("Level type mismatch between source/dest tensors"); 1526 1527 if (srcStt.getPosWidth() != dstStt.getPosWidth() || 1528 srcStt.getCrdWidth() != dstStt.getCrdWidth()) { 1529 return emitError("Crd/Pos width mismatch between source/dest tensors"); 1530 } 1531 1532 if (srcStt.getElementType() != dstStt.getElementType()) 1533 return emitError("Element type mismatch between source/dest tensors"); 1534 1535 SmallVector<Size> srcLvlShape = srcStt.getLvlShape(); 1536 SmallVector<Size> dstLvlShape = dstStt.getLvlShape(); 1537 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) { 1538 if (srcLvlSz != dstLvlSz) { 1539 // Should we allow one side to be dynamic size, e.g., <?x?> should be 1540 // compatible to <3x4>? For now, we require all the level sizes to be 1541 // *exactly* matched for simplicity. 
OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
  if (getSource().getType() == getDest().getType())
    return getSource();

  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
    // A -> B, B -> A ==> A
    if (def.getSource().getType() == getDest().getType())
      return def.getSource();
  }
  return {};
}

template <typename ToBufferOp>
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {
  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
  Type elemTp = nullptr;
  bool withStride = false;
  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
    elemTp = stt.getPosType();
  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
                       std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
    elemTp = stt.getCrdType();
    if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
      withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
    elemTp = stt.getElementType();
  }

  assert(elemTp && "unhandled operation.");
  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
  bufShape.push_back(ShapedType::kDynamic);

  auto layout = withStride ? StridedLayoutAttr::get(stt.getContext(),
                                                    ShapedType::kDynamic,
                                                    {ShapedType::kDynamic})
                           : StridedLayoutAttr();
  ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
  return success();
}

LogicalResult ToPositionsOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
    return emitError("unexpected type for positions");
  return success();
}

LogicalResult
ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ValueRange ops, DictionaryAttr attr,
                                OpaqueProperties prop, RegionRange region,
                                SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
}

LogicalResult ToCoordinatesOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
    return emitError("unexpected type for coordinates");
  return success();
}

LogicalResult
ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                  ValueRange ops, DictionaryAttr attr,
                                  OpaqueProperties prop, RegionRange region,
                                  SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
}

LogicalResult ToCoordinatesBufferOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (stt.getAoSCOOStart() >= stt.getLvlRank())
    return emitError("expected sparse tensor with a COO region");
  return success();
}

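// Illustrative result types only (the encoding is assumed, not from this
// file): for a COO tensor with levels (compressed(nonunique), singleton) and
// index-width coordinates, sparse_tensor.positions at level 0 infers a plain
// memref<?xindex>, while sparse_tensor.coordinates at level 1 reads from the
// shared AoS coordinate buffer and therefore infers a strided view such as
// memref<?xindex, strided<[?], offset: ?>>, matching inferSparseBufferType
// above.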
LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
                                                      ret);
}

LogicalResult ToValuesOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  auto mtp = getMemRefType(getResult());
  if (stt.getElementType() != mtp.getElementType())
    return emitError("unexpected mismatch in element types");
  return success();
}

LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
                                           std::optional<Location> loc,
                                           ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
}

LogicalResult ToSliceOffsetOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");
  return success();
}

LogicalResult ToSliceStrideOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");
  return success();
}

LogicalResult GetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}

template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
}

OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
  const StorageSpecifierKind kind = getSpecifierKind();
  const auto lvl = getLevel();
  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
    if (kind == op.getSpecifierKind() && lvl == op.getLevel())
      return op.getValue();
  return {};
}

LogicalResult SetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}

template <class T>
static LogicalResult verifyNumBlockArgs(T *op, Region &region,
                                        const char *regionName,
                                        TypeRange inputTypes, Type outputType) {
  unsigned numArgs = region.getNumArguments();
  unsigned expectedNum = inputTypes.size();
  if (numArgs != expectedNum)
    return op->emitError() << regionName << " region must have exactly "
                           << expectedNum << " arguments";

  for (unsigned i = 0; i < numArgs; i++) {
    Type typ = region.getArgument(i).getType();
    if (typ != inputTypes[i])
      return op->emitError() << regionName << " region argument " << (i + 1)
                             << " type mismatch";
  }
  Operation *term = region.front().getTerminator();
  YieldOp yield = dyn_cast<YieldOp>(term);
  if (!yield)
    return op->emitError() << regionName
                           << " region must end with sparse_tensor.yield";
  if (!yield.hasSingleResult() ||
      yield.getSingleResult().getType() != outputType)
    return op->emitError() << regionName << " region yield type mismatch";

  return success();
}

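// A sketch of the region shapes these verifiers expect (illustrative IR, not
// from this file): for a sparse_tensor.binary over f64 operands producing
// f64, the overlap region takes both operand types and yields the output
// type, e.g.
//   overlap = {
//     ^bb0(%x: f64, %y: f64):
//       %s = arith.addf %x, %y : f64
//       sparse_tensor.yield %s : f64
//   }
// while an empty region combined with left=identity (or right=identity)
// additionally requires that operand type to equal the output type.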
LogicalResult BinaryOp::verify() {
  NamedAttrList attrs = (*this)->getAttrs();
  Type leftType = getX().getType();
  Type rightType = getY().getType();
  Type outputType = getOutput().getType();
  Region &overlap = getOverlapRegion();
  Region &left = getLeftRegion();
  Region &right = getRightRegion();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  if (!overlap.empty()) {
    if (failed(verifyNumBlockArgs(this, overlap, "overlap",
                                  TypeRange{leftType, rightType}, outputType)))
      return failure();
  }
  if (!left.empty()) {
    if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
                                  outputType)))
      return failure();
  } else if (getLeftIdentity()) {
    if (leftType != outputType)
      return emitError("left=identity requires first argument to have the same "
                       "type as the output");
  }
  if (!right.empty()) {
    if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
                                  outputType)))
      return failure();
  } else if (getRightIdentity()) {
    if (rightType != outputType)
      return emitError("right=identity requires second argument to have the "
                       "same type as the output");
  }
  return success();
}

LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(verifyNumBlockArgs(this, present, "present",
                                  TypeRange{inputType}, outputType)))
      return failure();
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
                                  outputType)))
      return failure();
    // Absent branch can only yield invariant values.
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal =
        cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
  return success();
}

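// An illustrative case for the check below (encodings assumed, not from this
// file): concatenating CSR matrices along dimension 0 into a CSR result with
// an identity map keeps coordinates in order, so no temporary COO staging is
// needed; concatenating the same inputs along dimension 1 may interleave
// coordinates within each row, so the result is staged through a COO buffer
// and sorted.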
bool ConcatenateOp::needsExtraSort() {
  SparseTensorType dstStt = getSparseTensorType(*this);
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
    return getSparseTensorType(op).hasSameDimToLvl(dstStt);
  });
  // TODO: When conDim != 0, as long as conDim corresponds to the first level
  // in all input/output buffers, and all input/output buffers have the same
  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
  // CSC matrices along the column dimension).
  bool directLowerable =
      allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
  return !directLowerable;
}

LogicalResult ConcatenateOp::verify() {
  const auto dstTp = getSparseTensorType(*this);
  const Dimension concatDim = getDimension();
  const Dimension dimRank = dstTp.getDimRank();

  if (getInputs().size() <= 1)
    return emitError("Need at least two tensors to concatenate.");

  if (concatDim >= dimRank)
    return emitError(llvm::formatv(
        "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
        concatDim, dimRank));

  for (const auto &it : llvm::enumerate(getInputs())) {
    const auto i = it.index();
    const auto srcTp = getSparseTensorType(it.value());
    if (srcTp.hasDynamicDimShape())
      return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
    const Dimension srcDimRank = srcTp.getDimRank();
    if (srcDimRank != dimRank)
      return emitError(
          llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
                        "from the output tensor (rank={2}).",
                        i, srcDimRank, dimRank));
  }

  for (Dimension d = 0; d < dimRank; d++) {
    const Size dstSh = dstTp.getDimShape()[d];
    if (d == concatDim) {
      if (!ShapedType::isDynamic(dstSh)) {
        // If we reach here, then all inputs have static shapes. So we
        // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
        // to avoid redundant assertions in the loop.
        Size sumSz = 0;
        for (const auto src : getInputs())
          sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimensions are statically known, the sum of all the input
        // dimensions should be equal to the output dimension.
        if (sumSz != dstSh)
          return emitError(
              "The concatenation dimension of the output tensor should be the "
              "sum of all the concatenation dimensions of the input tensors.");
      }
    } else {
      Size prev = dstSh;
      for (const auto src : getInputs()) {
        const auto sh = getSparseTensorType(src).getDimShape()[d];
        if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (except for the concatenating one) "
                           "should be equal.");
        prev = sh;
      }
    }
  }

  return success();
}

void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  build(builder, result, curSize, inBuffer, value, Value());
}

LogicalResult PushBackOp::verify() {
  if (Value n = getN()) {
    std::optional<int64_t> nValue = getConstantIntValue(n);
    if (nValue && nValue.value() < 1)
      return emitOpError("n must be not less than 1");
  }
  return success();
}

LogicalResult CompressOp::verify() {
  const auto stt = getSparseTensorType(getTensor());
  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
  return success();
}

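// A sketch of the op assembled by this builder (illustrative IR; the tensor
// type, #CSR encoding, and %zero are assumed, not from this file):
//   %sum = sparse_tensor.foreach in %A init(%zero)
//       : tensor<4x4xf64, #CSR>, f64 -> f64 do {
//     ^bb0(%i: index, %j: index, %v: f64, %acc: f64):
//       %new = arith.addf %acc, %v : f64
//       sparse_tensor.yield %new : f64
//   }
// The body block receives dim-rank coordinates, then the element value, then
// the reduction arguments, in the same order as the block arguments created
// below.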
void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}

LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError("Level traverse order does not match tensor's level rank");

  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      elemTp, valueTp));
  return success();
}

OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
  if (getSparseTensorEncoding(getInputCoo().getType()) ==
      getSparseTensorEncoding(getResultCoo().getType()))
    return getInputCoo();

  return {};
}

LogicalResult ReorderCOOOp::verify() {
  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
  SparseTensorType dstStt = getSparseTensorType(getResultCoo());

  if (!srcStt.isCOOType() || !dstStt.isCOOType())
    return emitError("Expected COO sparse tensors only");

  if (!srcStt.hasSameDimToLvl(dstStt))
    return emitError("Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType())
    return emitError("Unmatched storage format between input and result COO");

  return success();
}

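// A sketch of the custom reduction region verified below (illustrative IR;
// the operands are assumed, not from this file):
//   %r = sparse_tensor.reduce %a, %b, %identity : f64 {
//     ^bb0(%x: f64, %y: f64):
//       %p = arith.mulf %x, %y : f64
//       sparse_tensor.yield %p : f64
//   }
// The region must take two values of the operand type and yield a single
// value of that same type.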
LogicalResult ReduceOp::verify() {
  Type inputType = getX().getType();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "reduce",
                            TypeRange{inputType, inputType}, inputType);
}

LogicalResult SelectOp::verify() {
  Builder b(getContext());
  Type inputType = getX().getType();
  Type boolType = b.getI1Type();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
                            boolType);
}

LogicalResult SortOp::verify() {
  AffineMap xPerm = getPermMap();
  uint64_t nx = xPerm.getNumDims();
  if (nx < 1)
    return emitError(
        llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));

  if (!xPerm.isPermutation())
    return emitError(
        llvm::formatv("Expected a permutation map, got {0}", xPerm));

  // We can't check the size of the buffers when n or buffer dimensions aren't
  // compile-time constants.
  std::optional<int64_t> cn = getConstantIntValue(getN());
  if (!cn)
    return success();

  // Verify dimensions.
  const auto checkDim = [&](Value v, Size minSize,
                            const char *message) -> LogicalResult {
    const Size sh = getMemRefType(v).getShape()[0];
    if (!ShapedType::isDynamic(sh) && sh < minSize)
      return emitError(
          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
    return success();
  };
  uint64_t n = cn.value();
  uint64_t ny = 0;
  if (auto nyAttr = getNyAttr())
    ny = nyAttr.getInt();
  if (failed(checkDim(getXy(), n * (nx + ny),
                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
    return failure();
  for (Value opnd : getYs())
    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
      return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Operations.
//===----------------------------------------------------------------------===//

IterSpaceType IteratorType::getIterSpaceType() const {
  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
                            getHiLvl());
}

IteratorType IterSpaceType::getIteratorType() const {
  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
}

/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
                                   Level &lvlHi) {
  if (parser.parseInteger(lvlLo))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("to"))) {
    if (parser.parseInteger(lvlHi))
      return failure();
  } else {
    lvlHi = lvlLo + 1;
  }

  if (lvlHi <= lvlLo)
    return parser.emitError(parser.getNameLoc(),
                            "expect larger level upper bound than lower bound");

  return success();
}

/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
                                   IntegerAttr &lvlHiAttr) {
  Level lvlLo, lvlHi;
  if (parseLevelRange(parser, lvlLo, lvlHi))
    return failure();

  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
  return success();
}

/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
  if (lo + 1 == hi)
    p << lo;
  else
    p << lo << " to " << hi;
}

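// For example, the level range [2, 3) prints simply as "2", while [0, 2)
// prints as "0 to 2"; the parsers above accept both spellings.
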
/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
                            IntegerAttr lvlHi) {
  unsigned lo = lvlLo.getValue().getZExtValue();
  unsigned hi = lvlHi.getValue().getZExtValue();
  printLevelRange(p, lo, hi);
}

LogicalResult ExtractIterSpaceOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {
  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
  ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
                                   adaptor.getHiLvl()));
  return success();
}

LogicalResult ExtractIterSpaceOp::verify() {
  if (getLoLvl() >= getHiLvl())
    return emitOpError(
        "expected level lower bound to be smaller than level upper bound");

  TypedValue<IteratorType> pIter = getParentIter();
  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
    return emitOpError(
        "parent iterator should be specified iff level lower bound equals 0");
  }

  if (pIter) {
    IterSpaceType spaceTp = getResultSpace().getType();
    if (pIter.getType().getEncoding() != spaceTp.getEncoding())
      return emitOpError(
          "mismatch in parent iterator encoding and iteration space encoding.");

    if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
      return emitOpError("parent iterator should be used to extract an "
                         "iteration space from a consecutive level.");
  }

  return success();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
                                                    Attribute value, Type type,
                                                    Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  return nullptr;
}

namespace {
struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
  using OpAsmDialectInterface::OpAsmDialectInterface;

  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
    if (isa<SparseTensorEncodingAttr>(attr)) {
      os << "sparse";
      return AliasResult::OverridableAlias;
    }
    return AliasResult::NoAlias;
  }
};
} // namespace

void SparseTensorDialect::initialize() {
  addInterface<SparseTensorAsmDialectInterface>();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"