1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "Detail/DimLvlMapParser.h" 12 13 #include "mlir/Dialect/SparseTensor/IR/Enums.h" 14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" 16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 17 18 #include "mlir/Dialect/Arith/IR/Arith.h" 19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 20 #include "mlir/Dialect/Utils/StaticValueUtils.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/DialectImplementation.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/OpImplementation.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FormatVariadic.h" 28 29 #define GET_ATTRDEF_CLASSES 30 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 31 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc" 32 33 #define GET_TYPEDEF_CLASSES 34 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" 35 36 using namespace mlir; 37 using namespace mlir::sparse_tensor; 38 39 // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as 40 // well. 41 namespace mlir::sparse_tensor { 42 llvm::hash_code hash_value(LevelType lt) { 43 return llvm::hash_value(static_cast<uint64_t>(lt)); 44 } 45 } // namespace mlir::sparse_tensor 46 47 //===----------------------------------------------------------------------===// 48 // Local Convenience Methods. 49 //===----------------------------------------------------------------------===// 50 51 static constexpr bool acceptBitWidth(unsigned bitWidth) { 52 switch (bitWidth) { 53 case 0: 54 case 8: 55 case 16: 56 case 32: 57 case 64: 58 return true; 59 default: 60 return false; 61 } 62 } 63 64 static SmallVector<Size> 65 getSparseFieldShape(const SparseTensorEncodingAttr enc, 66 std::optional<ArrayRef<int64_t>> dimShape) { 67 assert(enc); 68 // With only encoding, we can not determine the static shape for leading 69 // batch levels, we therefore return a dynamic shape memref instead. 70 SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic); 71 if (dimShape.has_value()) { 72 // If the actual tensor shape is provided, we can then refine the leading 73 // batch dimension. 74 SmallVector<int64_t> lvlShape = 75 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl); 76 memrefShape.assign(lvlShape.begin(), 77 lvlShape.begin() + enc.getBatchLvlRank()); 78 } 79 // Another dynamic dimension to store the sparse level. 80 memrefShape.push_back(ShapedType::kDynamic); 81 return memrefShape; 82 } 83 84 //===----------------------------------------------------------------------===// 85 // SparseTensorDialect StorageLayout. 
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO, all singletons are fused into one memref. Skip the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO, each singleton level has one memref.
        l++;
      }
      // Expire handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non-COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
}

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  assert(stt.hasEncoding());

  SmallVector<int64_t> memrefShape =
      getSparseFieldShape(stt.getEncoding(), stt.getDimShape());

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos> positions
  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
  // memref<[batch] x ? x crd> coordinates
  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
  // memref<[batch] x ?
x eltType> values 155 const Type valMemType = MemRefType::get(memrefShape, stt.getElementType()); 156 157 StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType, 158 callback](FieldIndex fieldIdx, 159 SparseTensorFieldKind fieldKind, 160 Level lvl, LevelType lt) -> bool { 161 switch (fieldKind) { 162 case SparseTensorFieldKind::StorageSpec: 163 return callback(specType, fieldIdx, fieldKind, lvl, lt); 164 case SparseTensorFieldKind::PosMemRef: 165 return callback(posMemType, fieldIdx, fieldKind, lvl, lt); 166 case SparseTensorFieldKind::CrdMemRef: 167 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt); 168 case SparseTensorFieldKind::ValMemRef: 169 return callback(valMemType, fieldIdx, fieldKind, lvl, lt); 170 }; 171 llvm_unreachable("unrecognized field kind"); 172 }); 173 } 174 175 unsigned StorageLayout::getNumFields() const { 176 unsigned numFields = 0; 177 foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level, 178 LevelType) -> bool { 179 numFields++; 180 return true; 181 }); 182 return numFields; 183 } 184 185 unsigned StorageLayout::getNumDataFields() const { 186 unsigned numFields = 0; // one value memref 187 foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level, 188 LevelType) -> bool { 189 if (fidx >= kDataFieldStartingIdx) 190 numFields++; 191 return true; 192 }); 193 numFields -= 1; // the last field is StorageSpecifier 194 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1); 195 return numFields; 196 } 197 198 std::pair<FieldIndex, unsigned> 199 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind, 200 std::optional<Level> lvl) const { 201 FieldIndex fieldIdx = kInvalidFieldIndex; 202 unsigned stride = 1; 203 if (kind == SparseTensorFieldKind::CrdMemRef) { 204 assert(lvl.has_value()); 205 const Level cooStart = SparseTensorType(enc).getAoSCOOStart(); 206 const Level lvlRank = enc.getLvlRank(); 207 if (lvl.value() >= cooStart && lvl.value() < lvlRank) { 208 lvl = cooStart; 209 stride = lvlRank - cooStart; 210 } 211 } 212 foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx, 213 SparseTensorFieldKind fKind, Level fLvl, 214 LevelType lt) -> bool { 215 if ((lvl && fLvl == lvl.value() && kind == fKind) || 216 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { 217 fieldIdx = fIdx; 218 // Returns false to break the iteration. 219 return false; 220 } 221 return true; 222 }); 223 assert(fieldIdx != kInvalidFieldIndex); 224 return std::pair<FieldIndex, unsigned>(fieldIdx, stride); 225 } 226 227 //===----------------------------------------------------------------------===// 228 // SparseTensorDialect Attribute Methods. 229 //===----------------------------------------------------------------------===// 230 231 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) { 232 return isDynamic(v) ? 
std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}

static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, a '?', which represents a dynamic slice.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}

Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}

LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
                                       AffineMap(), getPosWidth(),
                                       getCrdWidth());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ?
enc.getDimToLvl() : AffineMap()); 329 } 330 331 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { 332 return withDimToLvl(AffineMap()); 333 } 334 335 SparseTensorEncodingAttr 336 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth, 337 unsigned crdWidth) const { 338 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 339 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), 340 getDimToLvl(), getLvlToDim(), posWidth, 341 crdWidth); 342 } 343 344 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { 345 return withBitWidths(0, 0); 346 } 347 348 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( 349 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 350 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), 351 getDimToLvl(), getLvlToDim(), 352 getPosWidth(), getCrdWidth(), dimSlices); 353 } 354 355 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { 356 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{}); 357 } 358 359 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { 360 ArrayRef<LevelType> lvlTypes = getLvlTypes(); 361 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 362 return std::distance(lastBatch, lvlTypes.rend()); 363 } 364 365 bool SparseTensorEncodingAttr::isAllDense() const { 366 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT); 367 } 368 369 bool SparseTensorEncodingAttr::isAllOrdered() const { 370 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT); 371 } 372 373 Type SparseTensorEncodingAttr::getCrdElemType() const { 374 if (!getImpl()) 375 return nullptr; 376 if (getCrdWidth()) 377 return IntegerType::get(getContext(), getCrdWidth()); 378 return IndexType::get(getContext()); 379 } 380 381 Type SparseTensorEncodingAttr::getPosElemType() const { 382 if (!getImpl()) 383 return nullptr; 384 if (getPosWidth()) 385 return IntegerType::get(getContext(), getPosWidth()); 386 return IndexType::get(getContext()); 387 } 388 389 MemRefType SparseTensorEncodingAttr::getCrdMemRefType( 390 std::optional<ArrayRef<int64_t>> dimShape) const { 391 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 392 return MemRefType::get(shape, getCrdElemType()); 393 } 394 395 MemRefType SparseTensorEncodingAttr::getPosMemRefType( 396 std::optional<ArrayRef<int64_t>> dimShape) const { 397 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 398 return MemRefType::get(shape, getPosElemType()); 399 } 400 401 bool SparseTensorEncodingAttr::isIdentity() const { 402 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity(); 403 } 404 405 bool SparseTensorEncodingAttr::isPermutation() const { 406 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation(); 407 } 408 409 Dimension SparseTensorEncodingAttr::getDimRank() const { 410 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 411 const auto dimToLvl = getDimToLvl(); 412 return dimToLvl ? 
dimToLvl.getNumDims() : getLvlRank(); 413 } 414 415 Level SparseTensorEncodingAttr::getLvlRank() const { 416 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 417 return getLvlTypes().size(); 418 } 419 420 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { 421 if (!getImpl()) 422 return LevelFormat::Batch; 423 assert(l < getLvlRank() && "Level is out of bounds"); 424 return getLvlTypes()[l]; 425 } 426 427 bool SparseTensorEncodingAttr::isSlice() const { 428 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 429 return !getDimSlices().empty(); 430 } 431 432 SparseTensorDimSliceAttr 433 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { 434 assert(isSlice() && "Is not a slice"); 435 const auto dimSlices = getDimSlices(); 436 assert(dim < dimSlices.size() && "Dimension is out of bounds"); 437 return dimSlices[dim]; 438 } 439 440 std::optional<uint64_t> 441 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { 442 return getDimSlice(dim).getStaticOffset(); 443 } 444 445 std::optional<uint64_t> 446 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { 447 return getDimSlice(dim).getStaticStride(); 448 } 449 450 std::optional<uint64_t> 451 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { 452 return getStaticDimSliceOffset(toDim(*this, lvl)); 453 } 454 455 std::optional<uint64_t> 456 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { 457 return getStaticDimSliceStride(toDim(*this, lvl)); 458 } 459 460 SmallVector<int64_t> 461 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape, 462 CrdTransDirectionKind dir) const { 463 if (isIdentity()) 464 return SmallVector<int64_t>(srcShape); 465 466 SmallVector<int64_t> ret; 467 unsigned rank = 468 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank(); 469 ret.reserve(rank); 470 471 if (isPermutation()) { 472 for (unsigned r = 0; r < rank; r++) { 473 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) 474 : toLvl(*this, r); 475 ret.push_back(srcShape[trans]); 476 } 477 return ret; 478 } 479 480 // Handle non-permutation maps. 481 AffineMap transMap = 482 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim(); 483 484 SmallVector<AffineExpr> dimRep; 485 dimRep.reserve(srcShape.size()); 486 for (int64_t sz : srcShape) { 487 if (!ShapedType::isDynamic(sz)) { 488 // Push back the max coordinate for the given dimension/level size. 489 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); 490 } else { 491 // A dynamic size, use a AffineDimExpr to symbolize the value. 492 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext())); 493 } 494 }; 495 496 for (AffineExpr exp : transMap.getResults()) { 497 // Do constant propagation on the affine map. 498 AffineExpr evalExp = 499 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0); 500 // use llvm namespace here to avoid ambiguity 501 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) { 502 ret.push_back(c.getValue() + 1); 503 } else { 504 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp); 505 mod && mod.getKind() == AffineExprKind::Mod) { 506 // We can still infer a static bound for expressions in form 507 // "d % constant" since d % constant \in [0, constant). 
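        // As an illustrative sketch (the map below is not from this file):
        // for a block map (d0, d1) -> (d0 floordiv 2, d1 floordiv 3,
        // d0 mod 2, d1 mod 3) with a dynamic d0, the "d0 mod 2" and
        // "d1 mod 3" level sizes are still inferred statically as 2 and 3,
        // while the "d0 floordiv 2" level remains dynamic.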
508 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) { 509 ret.push_back(bound.getValue()); 510 continue; 511 } 512 } 513 ret.push_back(ShapedType::kDynamic); 514 } 515 } 516 assert(ret.size() == rank); 517 return ret; 518 } 519 520 ValueRange 521 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, 522 ValueRange crds, 523 CrdTransDirectionKind dir) const { 524 if (!getImpl()) 525 return crds; 526 527 SmallVector<Type> retType( 528 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(), 529 builder.getIndexType()); 530 auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this); 531 return transOp.getOutCrds(); 532 } 533 534 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 535 // Open "<{" part. 536 if (failed(parser.parseLess())) 537 return {}; 538 if (failed(parser.parseLBrace())) 539 return {}; 540 541 // Process the data from the parsed dictionary value into struct-like data. 542 SmallVector<LevelType> lvlTypes; 543 SmallVector<SparseTensorDimSliceAttr> dimSlices; 544 AffineMap dimToLvl = {}; 545 AffineMap lvlToDim = {}; 546 unsigned posWidth = 0; 547 unsigned crdWidth = 0; 548 StringRef attrName; 549 SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"}; 550 while (succeeded(parser.parseOptionalKeyword(&attrName))) { 551 // Detect admissible keyword. 552 auto *it = find(keys, attrName); 553 if (it == keys.end()) { 554 parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; 555 return {}; 556 } 557 unsigned keyWordIndex = it - keys.begin(); 558 // Consume the `=` after keys 559 if (failed(parser.parseEqual())) 560 return {}; 561 // Dispatch on keyword. 562 switch (keyWordIndex) { 563 case 0: { // map 564 ir_detail::DimLvlMapParser cParser(parser); 565 auto res = cParser.parseDimLvlMap(); 566 if (failed(res)) 567 return {}; 568 const auto &dlm = *res; 569 570 const Level lvlRank = dlm.getLvlRank(); 571 for (Level lvl = 0; lvl < lvlRank; lvl++) 572 lvlTypes.push_back(dlm.getLvlType(lvl)); 573 574 const Dimension dimRank = dlm.getDimRank(); 575 for (Dimension dim = 0; dim < dimRank; dim++) 576 dimSlices.push_back(dlm.getDimSlice(dim)); 577 // NOTE: the old syntax requires an all-or-nothing approach to 578 // `dimSlices`; therefore, if any slice actually exists then we need 579 // to convert null-DSA into default/nop DSA. 
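      // For example (hypothetical input, not taken from a test): if only d0
      // carries a slice annotation, say
      //   (d0 : #sparse_tensor<slice(0, 4, 1)>, d1) -> (d0 : dense, d1 : compressed)
      // then d1's null slice is replaced by the default/nop slice below, so
      // that `dimSlices` ends up either empty or fully populated.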
580 const auto isDefined = [](SparseTensorDimSliceAttr slice) { 581 return static_cast<bool>(slice.getImpl()); 582 }; 583 if (llvm::any_of(dimSlices, isDefined)) { 584 const auto defaultSlice = 585 SparseTensorDimSliceAttr::get(parser.getContext()); 586 for (Dimension dim = 0; dim < dimRank; dim++) 587 if (!isDefined(dimSlices[dim])) 588 dimSlices[dim] = defaultSlice; 589 } else { 590 dimSlices.clear(); 591 } 592 593 dimToLvl = dlm.getDimToLvlMap(parser.getContext()); 594 lvlToDim = dlm.getLvlToDimMap(parser.getContext()); 595 break; 596 } 597 case 1: { // posWidth 598 Attribute attr; 599 if (failed(parser.parseAttribute(attr))) 600 return {}; 601 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 602 if (!intAttr) { 603 parser.emitError(parser.getNameLoc(), 604 "expected an integral position bitwidth"); 605 return {}; 606 } 607 posWidth = intAttr.getInt(); 608 break; 609 } 610 case 2: { // crdWidth 611 Attribute attr; 612 if (failed(parser.parseAttribute(attr))) 613 return {}; 614 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 615 if (!intAttr) { 616 parser.emitError(parser.getNameLoc(), 617 "expected an integral index bitwidth"); 618 return {}; 619 } 620 crdWidth = intAttr.getInt(); 621 break; 622 } 623 } // switch 624 // Only last item can omit the comma. 625 if (parser.parseOptionalComma().failed()) 626 break; 627 } 628 629 // Close "}>" part. 630 if (failed(parser.parseRBrace())) 631 return {}; 632 if (failed(parser.parseGreater())) 633 return {}; 634 635 // Construct struct-like storage for attribute. 636 if (!lvlToDim || lvlToDim.isEmpty()) { 637 lvlToDim = inferLvlToDim(dimToLvl, parser.getContext()); 638 } 639 return parser.getChecked<SparseTensorEncodingAttr>( 640 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, 641 dimSlices); 642 } 643 644 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 645 auto map = static_cast<AffineMap>(getDimToLvl()); 646 // Empty affine map indicates identity map 647 if (!map) 648 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext()); 649 printer << "<{ map = "; 650 printSymbols(map, printer); 651 printer << '('; 652 printDimensions(map, printer, getDimSlices()); 653 printer << ") -> ("; 654 printLevels(map, printer, getLvlTypes()); 655 printer << ')'; 656 // Print remaining members only for non-default values. 
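  // For reference (illustrative; 32-bit widths assumed), a CSR-like encoding
  // prints in the IR as:
  //   #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed),
  //                             posWidth = 32, crdWidth = 32 }>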
657 if (getPosWidth()) 658 printer << ", posWidth = " << getPosWidth(); 659 if (getCrdWidth()) 660 printer << ", crdWidth = " << getCrdWidth(); 661 printer << " }>"; 662 } 663 664 void SparseTensorEncodingAttr::printSymbols(AffineMap &map, 665 AsmPrinter &printer) const { 666 if (map.getNumSymbols() == 0) 667 return; 668 printer << '['; 669 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++) 670 printer << 's' << i << ", "; 671 if (map.getNumSymbols() >= 1) 672 printer << 's' << map.getNumSymbols() - 1; 673 printer << ']'; 674 } 675 676 void SparseTensorEncodingAttr::printDimensions( 677 AffineMap &map, AsmPrinter &printer, 678 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 679 if (!dimSlices.empty()) { 680 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 681 printer << 'd' << i << " : " << dimSlices[i] << ", "; 682 if (map.getNumDims() >= 1) { 683 printer << 'd' << map.getNumDims() - 1 << " : " 684 << dimSlices[map.getNumDims() - 1]; 685 } 686 } else { 687 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 688 printer << 'd' << i << ", "; 689 if (map.getNumDims() >= 1) 690 printer << 'd' << map.getNumDims() - 1; 691 } 692 } 693 694 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, 695 ArrayRef<LevelType> lvlTypes) const { 696 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { 697 map.getResult(i).print(printer.getStream()); 698 printer << " : " << toMLIRString(lvlTypes[i]) << ", "; 699 } 700 if (map.getNumResults() >= 1) { 701 auto lastIndex = map.getNumResults() - 1; 702 map.getResult(lastIndex).print(printer.getStream()); 703 printer << " : " << toMLIRString(lvlTypes[lastIndex]); 704 } 705 } 706 707 LogicalResult SparseTensorEncodingAttr::verify( 708 function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes, 709 AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth, 710 unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) { 711 if (!acceptBitWidth(posWidth)) 712 return emitError() << "unexpected position bitwidth: " << posWidth; 713 if (!acceptBitWidth(crdWidth)) 714 return emitError() << "unexpected coordinate bitwidth: " << crdWidth; 715 if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); 716 it != std::end(lvlTypes)) { 717 if (it == lvlTypes.begin() || 718 (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1)))) 719 return emitError() << "expected compressed or loose_compressed level " 720 "before singleton level"; 721 if (!std::all_of(it, lvlTypes.end(), 722 [](LevelType i) { return isSingletonLT(i); })) 723 return emitError() << "expected all singleton lvlTypes " 724 "following a singleton level"; 725 // We can potentially support mixed SoA/AoS singleton levels. 726 if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) { 727 return it->isa<LevelPropNonDefault::SoA>() == 728 i.isa<LevelPropNonDefault::SoA>(); 729 })) { 730 return emitError() << "expected all singleton lvlTypes stored in the " 731 "same memory layout (SoA vs AoS)."; 732 } 733 } 734 735 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 736 if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT)) 737 return emitError() << "Batch lvlType can only be leading levels."; 738 739 // SoA property can only be applied on singleton level. 
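  // E.g., a level such as "singleton(soa)" is accepted here, while an SoA
  // property on any non-singleton level is rejected (illustrative spellings).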
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by backend.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of Affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity.
In particular, this ensures that the 824 // level-rank is coherent across all the fields. 825 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(), 826 getPosWidth(), getCrdWidth(), getDimSlices()))) 827 return failure(); 828 // Check integrity with tensor type specifics. In particular, we 829 // need only check that the dimension-rank of the tensor agrees with 830 // the dimension-rank of the encoding. 831 const Dimension dimRank = dimShape.size(); 832 if (dimRank == 0) 833 return emitError() << "expected non-scalar sparse tensor"; 834 if (getDimRank() != dimRank) 835 return emitError() 836 << "dimension-rank mismatch between encoding and tensor shape: " 837 << getDimRank() << " != " << dimRank; 838 return success(); 839 } 840 841 //===----------------------------------------------------------------------===// 842 // SparseTensorType Methods. 843 //===----------------------------------------------------------------------===// 844 845 bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, 846 bool isUnique) const { 847 if (!hasEncoding()) 848 return false; 849 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl)) 850 return false; 851 for (Level l = startLvl + 1; l < lvlRank; ++l) 852 if (!isSingletonLvl(l)) 853 return false; 854 // If isUnique is true, then make sure that the last level is unique, 855 // that is, when lvlRank == 1, the only compressed level is unique, 856 // and when lvlRank > 1, the last singleton is unique. 857 return !isUnique || isUniqueLvl(lvlRank - 1); 858 } 859 860 Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const { 861 SmallVector<COOSegment> coo = getCOOSegments(); 862 assert(coo.size() == 1 || coo.empty()); 863 if (!coo.empty() && coo.front().isAoS()) { 864 return coo.front().lvlRange.first; 865 } 866 return lvlRank; 867 } 868 869 SmallVector<COOSegment> 870 mlir::sparse_tensor::SparseTensorType::getCOOSegments() const { 871 SmallVector<COOSegment> ret; 872 if (!hasEncoding() || lvlRank <= 1) 873 return ret; 874 875 ArrayRef<LevelType> lts = getLvlTypes(); 876 Level l = 0; 877 while (l < lvlRank) { 878 auto lt = lts[l]; 879 if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) { 880 auto cur = lts.begin() + l; 881 auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) { 882 return !lt.isa<LevelFormat::Singleton>(); 883 }); 884 unsigned cooLen = std::distance(cur, end); 885 if (cooLen > 1) { 886 // To support mixed SoA/AoS COO, we should break the segment when the 887 // storage scheme changes, for now we faithfully assume that all 888 // consecutive singleton levels have the same storage format as verified 889 // STEA. 890 ret.push_back(COOSegment{std::make_pair(l, l + cooLen), 891 lts[l + 1].isa<LevelPropNonDefault::SoA>()}); 892 } 893 l += cooLen; 894 } else { 895 l++; 896 } 897 } 898 return ret; 899 } 900 901 RankedTensorType 902 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { 903 SmallVector<LevelType> lvlTypes; 904 lvlTypes.reserve(lvlRank); 905 // A non-unique compressed level at beginning (unless this is 906 // also the last level, then it is unique). 907 lvlTypes.push_back( 908 *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); 909 if (lvlRank > 1) { 910 // Followed by n-2 non-unique singleton levels. 911 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, 912 *buildLevelType(LevelFormat::Singleton, ordered, false)); 913 // Ends by a unique singleton level. 
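    // Illustrative: for lvlRank == 3 the resulting level types are
    // compressed(nonunique), singleton(nonunique), singleton.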
914 lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true)); 915 } 916 auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes, 917 getDimToLvl(), getLvlToDim(), 918 getPosWidth(), getCrdWidth()); 919 return RankedTensorType::get(getDimShape(), getElementType(), enc); 920 } 921 922 //===----------------------------------------------------------------------===// 923 // Convenience Methods. 924 //===----------------------------------------------------------------------===// 925 926 SparseTensorEncodingAttr 927 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 928 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type)) 929 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding()); 930 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type)) 931 return mdtp.getEncoding(); 932 return nullptr; 933 } 934 935 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl, 936 MLIRContext *context) { 937 auto map = static_cast<AffineMap>(dimToLvl); 938 AffineMap lvlToDim; 939 // Return an empty lvlToDim when inference is not successful. 940 if (!map || map.getNumSymbols() != 0) { 941 lvlToDim = AffineMap(); 942 } else if (map.isPermutation()) { 943 lvlToDim = inversePermutation(map); 944 } else if (isBlockSparsity(map)) { 945 lvlToDim = inverseBlockSparsity(map, context); 946 } 947 return lvlToDim; 948 } 949 950 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl, 951 MLIRContext *context) { 952 SmallVector<AffineExpr> lvlExprs; 953 auto numLvls = dimToLvl.getNumResults(); 954 lvlExprs.reserve(numLvls); 955 // lvlExprComponents stores information of the floordiv and mod operations 956 // applied to the same dimension, so as to build the lvlToDim map. 957 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents; 958 for (unsigned i = 0, n = numLvls; i < n; i++) { 959 auto result = dimToLvl.getResult(i); 960 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 961 if (result.getKind() == AffineExprKind::FloorDiv) { 962 // Position of the dimension in dimToLvl. 963 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 964 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() && 965 "expected only one floordiv for each dimension"); 966 SmallVector<AffineExpr, 3> components; 967 // Level variable for floordiv. 968 components.push_back(getAffineDimExpr(i, context)); 969 // Multiplier. 970 components.push_back(binOp.getRHS()); 971 // Map key is the position of the dimension. 972 lvlExprComponents[pos] = components; 973 } else if (result.getKind() == AffineExprKind::Mod) { 974 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 975 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() && 976 "expected floordiv before mod"); 977 // Add level variable for mod to the same vector 978 // of the corresponding floordiv. 979 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context)); 980 } else { 981 assert(false && "expected floordiv or mod"); 982 } 983 } else { 984 lvlExprs.push_back(getAffineDimExpr(i, context)); 985 } 986 } 987 // Build lvlExprs from lvlExprComponents. 988 // For example, for il = i floordiv 2 and ii = i mod 2, the components 989 // would be [il, 2, ii]. It could be used to build the AffineExpr 990 // i = il * 2 + ii in lvlToDim. 
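  // Illustrative end-to-end sketch: for a 2x2 BSR map
  //   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
  // the loop above records [l0, 2, l2] for d0 and [l1, 2, l3] for d1, and the
  // loop below assembles the inverse
  //   (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3).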
991 for (auto &components : lvlExprComponents) { 992 assert(components.second.size() == 3 && 993 "expected 3 components to build lvlExprs"); 994 auto mulOp = getAffineBinaryOpExpr( 995 AffineExprKind::Mul, components.second[0], components.second[1]); 996 auto addOp = 997 getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]); 998 lvlExprs.push_back(addOp); 999 } 1000 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context); 1001 } 1002 1003 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { 1004 assert(isBlockSparsity(dimToLvl) && 1005 "expected dimToLvl to be block sparsity for calling getBlockSize"); 1006 SmallVector<unsigned> blockSize; 1007 for (auto result : dimToLvl.getResults()) { 1008 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1009 if (result.getKind() == AffineExprKind::Mod) { 1010 blockSize.push_back( 1011 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue()); 1012 } 1013 } else { 1014 blockSize.push_back(0); 1015 } 1016 } 1017 return blockSize; 1018 } 1019 1020 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { 1021 if (!dimToLvl) 1022 return false; 1023 std::map<unsigned, int64_t> coeffientMap; 1024 bool hasBlock = false; 1025 for (auto result : dimToLvl.getResults()) { 1026 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1027 // Check for "dim op const". 1028 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS()); 1029 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS()); 1030 if (!dimOp || !conOp || conOp.getValue() <= 0) 1031 return false; 1032 // Inspect "dim / const" or "dim % const". 1033 auto pos = dimOp.getPosition(); 1034 if (binOp.getKind() == AffineExprKind::FloorDiv) { 1035 // Expect only one floordiv for each dimension. 1036 if (coeffientMap.find(pos) != coeffientMap.end()) 1037 return false; 1038 // Record coefficient of the floordiv. 1039 coeffientMap[pos] = conOp.getValue(); 1040 } else if (binOp.getKind() == AffineExprKind::Mod) { 1041 // Expect floordiv before mod. 1042 if (coeffientMap.find(pos) == coeffientMap.end()) 1043 return false; 1044 // Expect mod to have the same coefficient as floordiv. 1045 if (conOp.getValue() != coeffientMap[pos]) 1046 return false; 1047 hasBlock = true; 1048 } else { 1049 return false; 1050 } 1051 } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) { 1052 auto pos = dimOp.getPosition(); 1053 // Expect dim to be unset. 
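      // Illustrative: in (d0, d1 floordiv 3, d1 mod 3) the bare d0 lands here
      // and records coefficient 0, while d1 is recognized as a blocked
      // dimension with coefficient 3 by the branches above.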
1054 if (coeffientMap.find(pos) != coeffientMap.end()) 1055 return false; 1056 coeffientMap[pos] = 0; 1057 } else { 1058 return false; 1059 } 1060 } 1061 return hasBlock; 1062 } 1063 1064 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) { 1065 auto hasNonIdentityMap = [](Value v) { 1066 auto stt = tryGetSparseTensorType(v); 1067 return stt && !stt->isIdentity(); 1068 }; 1069 1070 return llvm::any_of(op->getOperands(), hasNonIdentityMap) || 1071 llvm::any_of(op->getResults(), hasNonIdentityMap); 1072 } 1073 1074 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { 1075 if (enc) { 1076 assert(enc.isPermutation() && "Non permutation map not supported"); 1077 if (const auto dimToLvl = enc.getDimToLvl()) 1078 return dimToLvl.getDimPosition(l); 1079 } 1080 return l; 1081 } 1082 1083 Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { 1084 if (enc) { 1085 assert(enc.isPermutation() && "Non permutation map not supported"); 1086 if (const auto lvlToDim = enc.getLvlToDim()) 1087 return lvlToDim.getDimPosition(d); 1088 } 1089 return d; 1090 } 1091 1092 /// We normalized sparse tensor encoding attribute by always using 1093 /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well 1094 /// as other variants) lead to the same storage specifier type, and stripping 1095 /// irrelevant fields that do not alter the sparse tensor memory layout. 1096 static SparseTensorEncodingAttr 1097 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { 1098 SmallVector<LevelType> lts; 1099 for (auto lt : enc.getLvlTypes()) 1100 lts.push_back(lt.stripStorageIrrelevantProperties()); 1101 1102 return SparseTensorEncodingAttr::get( 1103 enc.getContext(), lts, 1104 AffineMap(), // dimToLvl (irrelevant to storage specifier) 1105 AffineMap(), // lvlToDim (irrelevant to storage specifier) 1106 // Always use `index` for memSize and lvlSize instead of reusing 1107 // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA 1108 // value for different bitwidth, it also avoids casting between index and 1109 // integer (returned by DimOp) 1110 0, 0, enc.getDimSlices()); 1111 } 1112 1113 StorageSpecifierType 1114 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) { 1115 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding)); 1116 } 1117 1118 //===----------------------------------------------------------------------===// 1119 // SparseTensorDialect Operations. 1120 //===----------------------------------------------------------------------===// 1121 1122 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) { 1123 return success(lvl < getSparseTensorType(tensor).getLvlRank()); 1124 } 1125 1126 static LogicalResult isMatchingWidth(Value mem, unsigned width) { 1127 const Type etp = getMemRefType(mem).getElementType(); 1128 return success(width == 0 ? 
etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verifies the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in the shape of <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult AssembleOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from
  // dense inputs by rotating dense loops, but it leads to bad cache locality
  // and hurts performance.
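  // Illustrative: converting a sorted COO tensor to CSR with the same dimToLvl
  // needs no extra sort (handled above), whereas CSR to CSC does, unless the
  // source is a sparse constant (checked below).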
1302 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>()) 1303 if (isa<SparseElementsAttr>(constOp.getValue())) 1304 return false; 1305 1306 return true; 1307 } 1308 1309 LogicalResult CrdTranslateOp::verify() { 1310 uint64_t inRank = getEncoder().getLvlRank(); 1311 uint64_t outRank = getEncoder().getDimRank(); 1312 1313 if (getDirection() == CrdTransDirectionKind::dim2lvl) 1314 std::swap(inRank, outRank); 1315 1316 if (inRank != getInCrds().size() || outRank != getOutCrds().size()) 1317 return emitError("Coordinate rank mismatch with encoding"); 1318 1319 return success(); 1320 } 1321 1322 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor, 1323 SmallVectorImpl<OpFoldResult> &results) { 1324 if (getEncoder().isIdentity()) { 1325 results.assign(getInCrds().begin(), getInCrds().end()); 1326 return success(); 1327 } 1328 if (getEncoder().isPermutation()) { 1329 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl 1330 ? getEncoder().getDimToLvl() 1331 : getEncoder().getLvlToDim(); 1332 for (AffineExpr exp : perm.getResults()) 1333 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]); 1334 return success(); 1335 } 1336 1337 // Fuse dim2lvl/lvl2dim pairs. 1338 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>(); 1339 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) { 1340 return v.getDefiningOp() == def; 1341 }); 1342 if (!sameDef) 1343 return failure(); 1344 1345 bool oppositeDir = def.getDirection() != getDirection(); 1346 bool sameOracle = 1347 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl(); 1348 bool sameCount = def.getNumResults() == getInCrds().size(); 1349 if (!oppositeDir || !sameOracle || !sameCount) 1350 return failure(); 1351 1352 // The definition produces the coordinates in the same order as the input 1353 // coordinates. 
1354 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()), 1355 [](auto valuePair) { 1356 auto [lhs, rhs] = valuePair; 1357 return lhs == rhs; 1358 }); 1359 1360 if (!sameOrder) 1361 return failure(); 1362 // l1 = dim2lvl (lvl2dim l0) 1363 // ==> l0 1364 results.append(def.getInCrds().begin(), def.getInCrds().end()); 1365 return success(); 1366 } 1367 1368 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source, 1369 int64_t index) { 1370 Value val = builder.create<arith::ConstantIndexOp>(state.location, index); 1371 return build(builder, state, source, val); 1372 } 1373 1374 LogicalResult LvlOp::verify() { 1375 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) { 1376 auto stt = getSparseTensorType(getSource()); 1377 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank()) 1378 emitError("Level index exceeds the rank of the input sparse tensor"); 1379 } 1380 return success(); 1381 } 1382 1383 std::optional<uint64_t> LvlOp::getConstantLvlIndex() { 1384 return getConstantIntValue(getIndex()); 1385 } 1386 1387 Speculation::Speculatability LvlOp::getSpeculatability() { 1388 auto constantIndex = getConstantLvlIndex(); 1389 if (!constantIndex) 1390 return Speculation::NotSpeculatable; 1391 1392 assert(constantIndex < 1393 cast<RankedTensorType>(getSource().getType()).getRank()); 1394 return Speculation::Speculatable; 1395 } 1396 1397 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) { 1398 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex()); 1399 if (!lvlIndex) 1400 return {}; 1401 1402 Level lvl = lvlIndex.getAPSInt().getZExtValue(); 1403 auto stt = getSparseTensorType(getSource()); 1404 if (lvl >= stt.getLvlRank()) { 1405 // Follows the same convention used by tensor.dim operation. Out of bound 1406 // indices produce undefined behavior but are still valid IR. Don't choke on 1407 // them. 1408 return {}; 1409 } 1410 1411 // Helper lambda to build an IndexAttr. 
1412 auto getIndexAttr = [this](int64_t lvlSz) { 1413 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz)); 1414 }; 1415 1416 SmallVector<Size> lvlShape = stt.getLvlShape(); 1417 if (!ShapedType::isDynamic(lvlShape[lvl])) 1418 return getIndexAttr(lvlShape[lvl]); 1419 1420 return {}; 1421 } 1422 1423 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState, 1424 SparseTensorEncodingAttr dstEnc, Value source) { 1425 auto srcStt = getSparseTensorType(source); 1426 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape(); 1427 SmallVector<int64_t> dstDimShape = 1428 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim); 1429 auto dstTp = 1430 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc); 1431 return build(odsBuilder, odsState, dstTp, source); 1432 } 1433 1434 LogicalResult ReinterpretMapOp::verify() { 1435 auto srcStt = getSparseTensorType(getSource()); 1436 auto dstStt = getSparseTensorType(getDest()); 1437 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes(); 1438 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes(); 1439 1440 if (srcLvlTps.size() != dstLvlTps.size()) 1441 return emitError("Level rank mismatch between source/dest tensors"); 1442 1443 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps)) 1444 if (srcLvlTp != dstLvlTp) 1445 return emitError("Level type mismatch between source/dest tensors"); 1446 1447 if (srcStt.getPosWidth() != dstStt.getPosWidth() || 1448 srcStt.getCrdWidth() != dstStt.getCrdWidth()) { 1449 return emitError("Crd/Pos width mismatch between source/dest tensors"); 1450 } 1451 1452 if (srcStt.getElementType() != dstStt.getElementType()) 1453 return emitError("Element type mismatch between source/dest tensors"); 1454 1455 SmallVector<Size> srcLvlShape = srcStt.getLvlShape(); 1456 SmallVector<Size> dstLvlShape = dstStt.getLvlShape(); 1457 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) { 1458 if (srcLvlSz != dstLvlSz) { 1459 // Should we allow one side to be dynamic size, e.g., <?x?> should be 1460 // compatible to <3x4>? For now, we require all the level sizes to be 1461 // *exactly* matched for simplicity. 
1462 return emitError("Level size mismatch between source/dest tensors"); 1463 } 1464 } 1465 1466 return success(); 1467 } 1468 1469 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) { 1470 if (getSource().getType() == getDest().getType()) 1471 return getSource(); 1472 1473 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) { 1474 // A -> B, B -> A ==> A 1475 if (def.getSource().getType() == getDest().getType()) 1476 return def.getSource(); 1477 } 1478 return {}; 1479 } 1480 1481 template <typename ToBufferOp> 1482 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, 1483 OpaqueProperties prop, 1484 RegionRange region, 1485 SmallVectorImpl<mlir::Type> &ret) { 1486 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region); 1487 SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); 1488 Type elemTp = nullptr; 1489 bool withStride = false; 1490 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) { 1491 elemTp = stt.getPosType(); 1492 } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> || 1493 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) { 1494 elemTp = stt.getCrdType(); 1495 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>) 1496 withStride = stt.getAoSCOOStart() <= adaptor.getLevel(); 1497 } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) { 1498 elemTp = stt.getElementType(); 1499 } 1500 1501 assert(elemTp && "unhandled operation."); 1502 SmallVector<int64_t> bufShape = stt.getBatchLvlShape(); 1503 bufShape.push_back(ShapedType::kDynamic); 1504 1505 auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get( 1506 stt.getContext(), ShapedType::kDynamic, 1507 {ShapedType::kDynamic}) 1508 : StridedLayoutAttr(); 1509 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout)); 1510 return success(); 1511 } 1512 1513 LogicalResult ToPositionsOp::verify() { 1514 auto stt = getSparseTensorType(getTensor()); 1515 if (failed(lvlIsInBounds(getLevel(), getTensor()))) 1516 return emitError("requested level is out of bounds"); 1517 if (failed(isMatchingWidth(getResult(), stt.getPosWidth()))) 1518 return emitError("unexpected type for positions"); 1519 return success(); 1520 } 1521 1522 LogicalResult 1523 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 1524 ValueRange ops, DictionaryAttr attr, 1525 OpaqueProperties prop, RegionRange region, 1526 SmallVectorImpl<mlir::Type> &ret) { 1527 return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret); 1528 } 1529 1530 LogicalResult ToCoordinatesOp::verify() { 1531 auto stt = getSparseTensorType(getTensor()); 1532 if (failed(lvlIsInBounds(getLevel(), getTensor()))) 1533 return emitError("requested level is out of bounds"); 1534 if (failed(isMatchingWidth(getResult(), stt.getCrdWidth()))) 1535 return emitError("unexpected type for coordinates"); 1536 return success(); 1537 } 1538 1539 LogicalResult 1540 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 1541 ValueRange ops, DictionaryAttr attr, 1542 OpaqueProperties prop, RegionRange region, 1543 SmallVectorImpl<mlir::Type> &ret) { 1544 return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret); 1545 } 1546 1547 LogicalResult ToCoordinatesBufferOp::verify() { 1548 auto stt = getSparseTensorType(getTensor()); 1549 if (stt.getAoSCOOStart() >= stt.getLvlRank()) 1550 return emitError("expected sparse tensor with a COO region"); 1551 return success(); 1552 } 1553 1554 LogicalResult 
LogicalResult ToPositionsOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
    return emitError("unexpected type for positions");
  return success();
}

LogicalResult
ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ValueRange ops, DictionaryAttr attr,
                                OpaqueProperties prop, RegionRange region,
                                SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
}

LogicalResult ToCoordinatesOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
    return emitError("unexpected type for coordinates");
  return success();
}

LogicalResult
ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                  ValueRange ops, DictionaryAttr attr,
                                  OpaqueProperties prop, RegionRange region,
                                  SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
}

LogicalResult ToCoordinatesBufferOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (stt.getAoSCOOStart() >= stt.getLvlRank())
    return emitError("expected sparse tensor with a COO region");
  return success();
}

LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
                                                      ret);
}

LogicalResult ToValuesOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  auto mtp = getMemRefType(getResult());
  if (stt.getElementType() != mtp.getElementType())
    return emitError("unexpected mismatch in element types");
  return success();
}

LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
                                           std::optional<Location> loc,
                                           ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
}

LogicalResult ToSliceOffsetOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");
  return success();
}

LogicalResult ToSliceStrideOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");
  return success();
}

LogicalResult GetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}

template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
}

OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
  const StorageSpecifierKind kind = getSpecifierKind();
  const auto lvl = getLevel();
  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
    if (kind == op.getSpecifierKind() && lvl == op.getLevel())
      return op.getValue();
  return {};
}

LogicalResult SetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}

template <class T>
static LogicalResult verifyNumBlockArgs(T *op, Region &region,
                                        const char *regionName,
                                        TypeRange inputTypes, Type outputType) {
  unsigned numArgs = region.getNumArguments();
  unsigned expectedNum = inputTypes.size();
  if (numArgs != expectedNum)
    return op->emitError() << regionName << " region must have exactly "
                           << expectedNum << " arguments";

  for (unsigned i = 0; i < numArgs; i++) {
    Type typ = region.getArgument(i).getType();
    if (typ != inputTypes[i])
      return op->emitError() << regionName << " region argument " << (i + 1)
                             << " type mismatch";
  }
  Operation *term = region.front().getTerminator();
  YieldOp yield = dyn_cast<YieldOp>(term);
  if (!yield)
    return op->emitError() << regionName
                           << " region must end with sparse_tensor.yield";
  if (!yield.hasSingleResult() ||
      yield.getSingleResult().getType() != outputType)
    return op->emitError() << regionName << " region yield type mismatch";

  return success();
}
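// Sketch of the region contract enforced below (illustrative assembly only;
// see SparseTensorOps.td for the authoritative format): an overlap region of
// a binary op over f64 operands takes one block argument per input type and
// must terminate with a sparse_tensor.yield whose operand has the output type:
//
//   %r = sparse_tensor.binary %a, %b : f64, f64 to f64
//     overlap = {
//       ^bb0(%x: f64, %y: f64):
//         %sum = arith.addf %x, %y : f64
//         sparse_tensor.yield %sum : f64
//     }
//     left = identity
//     right = identity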
LogicalResult BinaryOp::verify() {
  NamedAttrList attrs = (*this)->getAttrs();
  Type leftType = getX().getType();
  Type rightType = getY().getType();
  Type outputType = getOutput().getType();
  Region &overlap = getOverlapRegion();
  Region &left = getLeftRegion();
  Region &right = getRightRegion();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  if (!overlap.empty()) {
    if (failed(verifyNumBlockArgs(this, overlap, "overlap",
                                  TypeRange{leftType, rightType}, outputType)))
      return failure();
  }
  if (!left.empty()) {
    if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
                                  outputType)))
      return failure();
  } else if (getLeftIdentity()) {
    if (leftType != outputType)
      return emitError("left=identity requires first argument to have the same "
                       "type as the output");
  }
  if (!right.empty()) {
    if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
                                  outputType)))
      return failure();
  } else if (getRightIdentity()) {
    if (rightType != outputType)
      return emitError("right=identity requires second argument to have the "
                       "same type as the output");
  }
  return success();
}

LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(verifyNumBlockArgs(this, present, "present",
                                  TypeRange{inputType}, outputType)))
      return failure();
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
                                  outputType)))
      return failure();
    // Absent branch can only yield invariant values.
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal =
        cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
  return success();
}

bool ConcatenateOp::needsExtraSort() {
  SparseTensorType dstStt = getSparseTensorType(*this);
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
    return getSparseTensorType(op).hasSameDimToLvl(dstStt);
  });
  // TODO: When conDim != 0, as long as conDim corresponds to the first level
  // in all input/output buffers and all input/output buffers have the same
  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
  // CSC matrices along columns).
  bool directLowerable =
      allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
  return !directLowerable;
}
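// Note on needsExtraSort() above (illustrative): concatenating CSR matrices
// along dimension 0 into a CSR result with the same (identity) dim-to-lvl map
// can be lowered directly, since rows are simply appended in order; a
// concatenation along dimension 1, or into a destination with a different
// dim-to-lvl map, currently goes through a temporary COO buffer that must be
// sorted before conversion into the destination format.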
LogicalResult ConcatenateOp::verify() {
  const auto dstTp = getSparseTensorType(*this);
  const Dimension concatDim = getDimension();
  const Dimension dimRank = dstTp.getDimRank();

  if (getInputs().size() <= 1)
    return emitError("Need at least two tensors to concatenate.");

  if (concatDim >= dimRank)
    return emitError(llvm::formatv(
        "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
        concatDim, dimRank));

  for (const auto &it : llvm::enumerate(getInputs())) {
    const auto i = it.index();
    const auto srcTp = getSparseTensorType(it.value());
    if (srcTp.hasDynamicDimShape())
      return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
    const Dimension srcDimRank = srcTp.getDimRank();
    if (srcDimRank != dimRank)
      return emitError(
          llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
                        "from the output tensor (rank={2}).",
                        i, srcDimRank, dimRank));
  }

  for (Dimension d = 0; d < dimRank; d++) {
    const Size dstSh = dstTp.getDimShape()[d];
    if (d == concatDim) {
      if (!ShapedType::isDynamic(dstSh)) {
        // If we reach here, then all inputs have static shapes. So we
        // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
        // to avoid redundant assertions in the loop.
        Size sumSz = 0;
        for (const auto src : getInputs())
          sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimensions are statically known, the sum of all the input
        // dimensions should be equal to the output dimension.
        if (sumSz != dstSh)
          return emitError(
              "The concatenation dimension of the output tensor should be the "
              "sum of all the concatenation dimensions of the input tensors.");
      }
    } else {
      Size prev = dstSh;
      for (const auto src : getInputs()) {
        const auto sh = getSparseTensorType(src).getDimShape()[d];
        if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (except for the concatenating one) "
                           "should be equal.");
        prev = sh;
      }
    }
  }

  return success();
}

void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  build(builder, result, curSize, inBuffer, value, Value());
}

LogicalResult PushBackOp::verify() {
  if (Value n = getN()) {
    std::optional<int64_t> nValue = getConstantIntValue(n);
    if (nValue && nValue.value() < 1)
      return emitOpError("n must be not less than 1");
  }
  return success();
}

LogicalResult CompressOp::verify() {
  const auto stt = getSparseTensorType(getTensor());
  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
  return success();
}

void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds the foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}
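// Usage sketch for the builder above (hypothetical caller; `builder`, `loc`,
// `tensor`, and `init` are assumed to be in scope). The callback receives the
// dimension coordinates, the element value, and the iteration arguments, and
// is responsible for creating the terminating sparse_tensor.yield:
//
//   auto foreach = builder.create<ForeachOp>(
//       loc, tensor, ValueRange{init}, AffineMapAttr(),
//       [](OpBuilder &b, Location l, ValueRange coords, Value val,
//          ValueRange reduc) {
//         // ... combine `val` into the running reduction, then:
//         b.create<sparse_tensor::YieldOp>(l, reduc.front());
//       });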
LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError("Level traverse order does not match tensor's level rank");

  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      elemTp, valueTp));
  return success();
}

OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
  if (getSparseTensorEncoding(getInputCoo().getType()) ==
      getSparseTensorEncoding(getResultCoo().getType()))
    return getInputCoo();

  return {};
}

LogicalResult ReorderCOOOp::verify() {
  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
  SparseTensorType dstStt = getSparseTensorType(getResultCoo());

  if (!srcStt.isCOOType() || !dstStt.isCOOType())
    return emitError("Expected COO sparse tensors only");

  if (!srcStt.hasSameDimToLvl(dstStt))
    return emitError("Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType())
    return emitError("Unmatched storage format between input and result COO");

  return success();
}
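// Sketch of the reduce region contract checked below (illustrative assembly
// only; see SparseTensorOps.td for the authoritative format): the region takes
// two block arguments of the operand type and yields one value of that type:
//
//   %r = sparse_tensor.reduce %x, %y, %identity : f64 {
//     ^bb0(%a: f64, %b: f64):
//       %p = arith.mulf %a, %b : f64
//       sparse_tensor.yield %p : f64
//   }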
LogicalResult ReduceOp::verify() {
  Type inputType = getX().getType();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "reduce",
                            TypeRange{inputType, inputType}, inputType);
}

LogicalResult SelectOp::verify() {
  Builder b(getContext());
  Type inputType = getX().getType();
  Type boolType = b.getI1Type();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
                            boolType);
}

LogicalResult SortOp::verify() {
  AffineMap xPerm = getPermMap();
  uint64_t nx = xPerm.getNumDims();
  if (nx < 1)
    return emitError(
        llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));

  if (!xPerm.isPermutation())
    return emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));

  // We can't check the size of the buffers when n or buffer dimensions aren't
  // compile-time constants.
  std::optional<int64_t> cn = getConstantIntValue(getN());
  if (!cn)
    return success();

  // Verify dimensions.
  const auto checkDim = [&](Value v, Size minSize,
                            const char *message) -> LogicalResult {
    const Size sh = getMemRefType(v).getShape()[0];
    if (!ShapedType::isDynamic(sh) && sh < minSize)
      return emitError(
          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
    return success();
  };
  uint64_t n = cn.value();
  uint64_t ny = 0;
  if (auto nyAttr = getNyAttr())
    ny = nyAttr.getInt();
  if (failed(checkDim(getXy(), n * (nx + ny),
                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
    return failure();
  for (Value opnd : getYs())
    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
      return failure();

  return success();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
                                                    Attribute value, Type type,
                                                    Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  return nullptr;
}

namespace {
struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
  using OpAsmDialectInterface::OpAsmDialectInterface;

  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
    if (isa<SparseTensorEncodingAttr>(attr)) {
      os << "sparse";
      return AliasResult::OverridableAlias;
    }
    return AliasResult::NoAlias;
  }
};
} // namespace

void SparseTensorDialect::initialize() {
  addInterface<SparseTensorAsmDialectInterface>();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"