//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed
// as well.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO, all singletons are fused into one memref. Skip the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO, each singleton level has one memref.
99 l++; 100 } 101 // Expire handled COO segment. 102 cooSegsRef = cooSegsRef.drop_front(); 103 } else { 104 // Non COO levels. 105 l++; 106 } 107 } 108 // The values array. 109 if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel, 110 LevelFormat::Undef))) 111 return; 112 // Put metadata at the end. 113 if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel, 114 LevelFormat::Undef))) 115 return; 116 } 117 118 void sparse_tensor::foreachFieldAndTypeInSparseTensor( 119 SparseTensorType stt, 120 llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level, 121 LevelType)> 122 callback) { 123 assert(stt.hasEncoding()); 124 // Construct the basic types. 125 const Type crdType = stt.getCrdType(); 126 const Type posType = stt.getPosType(); 127 const Type eltType = stt.getElementType(); 128 129 SmallVector<int64_t> memrefShape = stt.getBatchLvlShape(); 130 memrefShape.push_back(ShapedType::kDynamic); 131 132 const Type specType = StorageSpecifierType::get(stt.getEncoding()); 133 // memref<[batch] x ? x pos> positions 134 const Type posMemType = MemRefType::get(memrefShape, posType); 135 // memref<[batch] x ? x crd> coordinates 136 const Type crdMemType = MemRefType::get(memrefShape, crdType); 137 // memref<[batch] x ? x eltType> values 138 const Type valMemType = MemRefType::get(memrefShape, eltType); 139 140 StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType, 141 callback](FieldIndex fieldIdx, 142 SparseTensorFieldKind fieldKind, 143 Level lvl, LevelType lt) -> bool { 144 switch (fieldKind) { 145 case SparseTensorFieldKind::StorageSpec: 146 return callback(specType, fieldIdx, fieldKind, lvl, lt); 147 case SparseTensorFieldKind::PosMemRef: 148 return callback(posMemType, fieldIdx, fieldKind, lvl, lt); 149 case SparseTensorFieldKind::CrdMemRef: 150 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt); 151 case SparseTensorFieldKind::ValMemRef: 152 return callback(valMemType, fieldIdx, fieldKind, lvl, lt); 153 }; 154 llvm_unreachable("unrecognized field kind"); 155 }); 156 } 157 158 unsigned StorageLayout::getNumFields() const { 159 unsigned numFields = 0; 160 foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level, 161 LevelType) -> bool { 162 numFields++; 163 return true; 164 }); 165 return numFields; 166 } 167 168 unsigned StorageLayout::getNumDataFields() const { 169 unsigned numFields = 0; // one value memref 170 foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level, 171 LevelType) -> bool { 172 if (fidx >= kDataFieldStartingIdx) 173 numFields++; 174 return true; 175 }); 176 numFields -= 1; // the last field is StorageSpecifier 177 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1); 178 return numFields; 179 } 180 181 std::pair<FieldIndex, unsigned> 182 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind, 183 std::optional<Level> lvl) const { 184 FieldIndex fieldIdx = kInvalidFieldIndex; 185 unsigned stride = 1; 186 if (kind == SparseTensorFieldKind::CrdMemRef) { 187 assert(lvl.has_value()); 188 const Level cooStart = SparseTensorType(enc).getAoSCOOStart(); 189 const Level lvlRank = enc.getLvlRank(); 190 if (lvl.value() >= cooStart && lvl.value() < lvlRank) { 191 lvl = cooStart; 192 stride = lvlRank - cooStart; 193 } 194 } 195 foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx, 196 SparseTensorFieldKind fKind, Level fLvl, 197 LevelType lt) -> bool { 198 if ((lvl && fLvl == lvl.value() && kind == fKind) || 199 (kind == fKind && fKind == 
SparseTensorFieldKind::ValMemRef)) { 200 fieldIdx = fIdx; 201 // Returns false to break the iteration. 202 return false; 203 } 204 return true; 205 }); 206 assert(fieldIdx != kInvalidFieldIndex); 207 return std::pair<FieldIndex, unsigned>(fieldIdx, stride); 208 } 209 210 //===----------------------------------------------------------------------===// 211 // SparseTensorDialect Attribute Methods. 212 //===----------------------------------------------------------------------===// 213 214 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) { 215 return isDynamic(v) ? std::nullopt 216 : std::make_optional(static_cast<uint64_t>(v)); 217 } 218 219 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const { 220 return getStatic(getOffset()); 221 } 222 223 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const { 224 return getStatic(getStride()); 225 } 226 227 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const { 228 return getStatic(getSize()); 229 } 230 231 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const { 232 return isDynamic(getOffset()) && isDynamic(getStride()) && 233 isDynamic(getSize()); 234 } 235 236 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) { 237 return isDynamic(v) ? "?" : std::to_string(v); 238 } 239 240 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const { 241 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr"); 242 os << '('; 243 os << getStaticString(getOffset()); 244 os << ", "; 245 os << getStaticString(getSize()); 246 os << ", "; 247 os << getStaticString(getStride()); 248 os << ')'; 249 } 250 251 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { 252 print(printer.getStream()); 253 } 254 255 static ParseResult parseOptionalStaticSlice(int64_t &result, 256 AsmParser &parser) { 257 auto parseResult = parser.parseOptionalInteger(result); 258 if (parseResult.has_value()) { 259 if (parseResult.value().succeeded() && result < 0) { 260 parser.emitError( 261 parser.getCurrentLocation(), 262 "expect positive value or ? for slice offset/size/stride"); 263 return failure(); 264 } 265 return parseResult.value(); 266 } 267 268 // Else, and '?' which represented dynamic slice 269 result = SparseTensorDimSliceAttr::kDynamic; 270 return parser.parseQuestion(); 271 } 272 273 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { 274 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic; 275 276 if (failed(parser.parseLParen()) || 277 failed(parseOptionalStaticSlice(offset, parser)) || 278 failed(parser.parseComma()) || 279 failed(parseOptionalStaticSlice(size, parser)) || 280 failed(parser.parseComma()) || 281 failed(parseOptionalStaticSlice(stride, parser)) || 282 failed(parser.parseRParen())) 283 return {}; 284 285 return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(), 286 offset, size, stride); 287 } 288 289 LogicalResult 290 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError, 291 int64_t offset, int64_t size, int64_t stride) { 292 if (!isDynamic(offset) && offset < 0) 293 return emitError() << "expect non-negative value or ? for slice offset"; 294 if (!isDynamic(size) && size <= 0) 295 return emitError() << "expect positive value or ? for slice size"; 296 if (!isDynamic(stride) && stride <= 0) 297 return emitError() << "expect positive value or ? 
for slice stride"; 298 return success(); 299 } 300 301 SparseTensorEncodingAttr 302 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const { 303 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 304 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl, 305 AffineMap(), getPosWidth(), 306 getCrdWidth()); 307 } 308 309 SparseTensorEncodingAttr 310 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const { 311 return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap()); 312 } 313 314 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { 315 return withDimToLvl(AffineMap()); 316 } 317 318 SparseTensorEncodingAttr 319 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth, 320 unsigned crdWidth) const { 321 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 322 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), 323 getDimToLvl(), getLvlToDim(), posWidth, 324 crdWidth); 325 } 326 327 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { 328 return withBitWidths(0, 0); 329 } 330 331 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( 332 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 333 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), 334 getDimToLvl(), getLvlToDim(), 335 getPosWidth(), getCrdWidth(), dimSlices); 336 } 337 338 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { 339 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{}); 340 } 341 342 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { 343 ArrayRef<LevelType> lvlTypes = getLvlTypes(); 344 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 345 return std::distance(lastBatch, lvlTypes.rend()); 346 } 347 348 bool SparseTensorEncodingAttr::isAllDense() const { 349 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT); 350 } 351 352 bool SparseTensorEncodingAttr::isAllOrdered() const { 353 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT); 354 } 355 356 bool SparseTensorEncodingAttr::isIdentity() const { 357 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity(); 358 } 359 360 bool SparseTensorEncodingAttr::isPermutation() const { 361 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation(); 362 } 363 364 Dimension SparseTensorEncodingAttr::getDimRank() const { 365 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 366 const auto dimToLvl = getDimToLvl(); 367 return dimToLvl ? 
dimToLvl.getNumDims() : getLvlRank(); 368 } 369 370 Level SparseTensorEncodingAttr::getLvlRank() const { 371 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 372 return getLvlTypes().size(); 373 } 374 375 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { 376 if (!getImpl()) 377 return LevelFormat::Batch; 378 assert(l < getLvlRank() && "Level is out of bounds"); 379 return getLvlTypes()[l]; 380 } 381 382 bool SparseTensorEncodingAttr::isSlice() const { 383 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 384 return !getDimSlices().empty(); 385 } 386 387 SparseTensorDimSliceAttr 388 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { 389 assert(isSlice() && "Is not a slice"); 390 const auto dimSlices = getDimSlices(); 391 assert(dim < dimSlices.size() && "Dimension is out of bounds"); 392 return dimSlices[dim]; 393 } 394 395 std::optional<uint64_t> 396 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { 397 return getDimSlice(dim).getStaticOffset(); 398 } 399 400 std::optional<uint64_t> 401 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { 402 return getDimSlice(dim).getStaticStride(); 403 } 404 405 std::optional<uint64_t> 406 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { 407 return getStaticDimSliceOffset(toDim(*this, lvl)); 408 } 409 410 std::optional<uint64_t> 411 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { 412 return getStaticDimSliceStride(toDim(*this, lvl)); 413 } 414 415 SmallVector<int64_t> 416 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape, 417 CrdTransDirectionKind dir) const { 418 if (isIdentity()) 419 return SmallVector<int64_t>(srcShape); 420 421 SmallVector<int64_t> ret; 422 unsigned rank = 423 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank(); 424 ret.reserve(rank); 425 426 if (isPermutation()) { 427 for (unsigned r = 0; r < rank; r++) { 428 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) 429 : toLvl(*this, r); 430 ret.push_back(srcShape[trans]); 431 } 432 return ret; 433 } 434 435 // Handle non-permutation maps. 436 AffineMap transMap = 437 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim(); 438 439 SmallVector<AffineExpr> dimRep; 440 dimRep.reserve(srcShape.size()); 441 for (int64_t sz : srcShape) { 442 if (!ShapedType::isDynamic(sz)) { 443 // Push back the max coordinate for the given dimension/level size. 444 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); 445 } else { 446 // A dynamic size, use a AffineDimExpr to symbolize the value. 447 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext())); 448 } 449 }; 450 451 for (AffineExpr exp : transMap.getResults()) { 452 // Do constant propagation on the affine map. 453 AffineExpr evalExp = 454 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0); 455 // use llvm namespace here to avoid ambiguity 456 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) { 457 ret.push_back(c.getValue() + 1); 458 } else { 459 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp); 460 mod && mod.getKind() == AffineExprKind::Mod) { 461 // We can still infer a static bound for expressions in form 462 // "d % constant" since d % constant \in [0, constant). 
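        // For example, translating the static dimension shape (8) through the
        // map (d0) -> (d0 floordiv 4, d0 mod 4) yields level shape (2, 4),
        // while a dynamic dimension under the same map yields (?, 4), since
        // only the mod term has a provable static bound.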
463 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) { 464 ret.push_back(bound.getValue()); 465 continue; 466 } 467 } 468 ret.push_back(ShapedType::kDynamic); 469 } 470 } 471 assert(ret.size() == rank); 472 return ret; 473 } 474 475 ValueRange 476 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, 477 ValueRange crds, 478 CrdTransDirectionKind dir) const { 479 if (!getImpl()) 480 return crds; 481 482 SmallVector<Type> retType( 483 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(), 484 builder.getIndexType()); 485 auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this); 486 return transOp.getOutCrds(); 487 } 488 489 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 490 // Open "<{" part. 491 if (failed(parser.parseLess())) 492 return {}; 493 if (failed(parser.parseLBrace())) 494 return {}; 495 496 // Process the data from the parsed dictionary value into struct-like data. 497 SmallVector<LevelType> lvlTypes; 498 SmallVector<SparseTensorDimSliceAttr> dimSlices; 499 AffineMap dimToLvl = {}; 500 AffineMap lvlToDim = {}; 501 unsigned posWidth = 0; 502 unsigned crdWidth = 0; 503 StringRef attrName; 504 SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"}; 505 while (succeeded(parser.parseOptionalKeyword(&attrName))) { 506 // Detect admissible keyword. 507 auto *it = find(keys, attrName); 508 if (it == keys.end()) { 509 parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; 510 return {}; 511 } 512 unsigned keyWordIndex = it - keys.begin(); 513 // Consume the `=` after keys 514 if (failed(parser.parseEqual())) 515 return {}; 516 // Dispatch on keyword. 517 switch (keyWordIndex) { 518 case 0: { // map 519 ir_detail::DimLvlMapParser cParser(parser); 520 auto res = cParser.parseDimLvlMap(); 521 if (failed(res)) 522 return {}; 523 const auto &dlm = *res; 524 525 const Level lvlRank = dlm.getLvlRank(); 526 for (Level lvl = 0; lvl < lvlRank; lvl++) 527 lvlTypes.push_back(dlm.getLvlType(lvl)); 528 529 const Dimension dimRank = dlm.getDimRank(); 530 for (Dimension dim = 0; dim < dimRank; dim++) 531 dimSlices.push_back(dlm.getDimSlice(dim)); 532 // NOTE: the old syntax requires an all-or-nothing approach to 533 // `dimSlices`; therefore, if any slice actually exists then we need 534 // to convert null-DSA into default/nop DSA. 
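      // For example, if only d0 declares a slice, every other dimension is
      // given an explicit default slice attribute below, so `dimSlices` ends
      // up either empty or covering all dimensions.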
535 const auto isDefined = [](SparseTensorDimSliceAttr slice) { 536 return static_cast<bool>(slice.getImpl()); 537 }; 538 if (llvm::any_of(dimSlices, isDefined)) { 539 const auto defaultSlice = 540 SparseTensorDimSliceAttr::get(parser.getContext()); 541 for (Dimension dim = 0; dim < dimRank; dim++) 542 if (!isDefined(dimSlices[dim])) 543 dimSlices[dim] = defaultSlice; 544 } else { 545 dimSlices.clear(); 546 } 547 548 dimToLvl = dlm.getDimToLvlMap(parser.getContext()); 549 lvlToDim = dlm.getLvlToDimMap(parser.getContext()); 550 break; 551 } 552 case 1: { // posWidth 553 Attribute attr; 554 if (failed(parser.parseAttribute(attr))) 555 return {}; 556 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 557 if (!intAttr) { 558 parser.emitError(parser.getNameLoc(), 559 "expected an integral position bitwidth"); 560 return {}; 561 } 562 posWidth = intAttr.getInt(); 563 break; 564 } 565 case 2: { // crdWidth 566 Attribute attr; 567 if (failed(parser.parseAttribute(attr))) 568 return {}; 569 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 570 if (!intAttr) { 571 parser.emitError(parser.getNameLoc(), 572 "expected an integral index bitwidth"); 573 return {}; 574 } 575 crdWidth = intAttr.getInt(); 576 break; 577 } 578 } // switch 579 // Only last item can omit the comma. 580 if (parser.parseOptionalComma().failed()) 581 break; 582 } 583 584 // Close "}>" part. 585 if (failed(parser.parseRBrace())) 586 return {}; 587 if (failed(parser.parseGreater())) 588 return {}; 589 590 // Construct struct-like storage for attribute. 591 if (!lvlToDim || lvlToDim.isEmpty()) { 592 lvlToDim = inferLvlToDim(dimToLvl, parser.getContext()); 593 } 594 return parser.getChecked<SparseTensorEncodingAttr>( 595 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, 596 dimSlices); 597 } 598 599 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 600 auto map = static_cast<AffineMap>(getDimToLvl()); 601 // Empty affine map indicates identity map 602 if (!map) 603 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext()); 604 printer << "<{ map = "; 605 printSymbols(map, printer); 606 printer << '('; 607 printDimensions(map, printer, getDimSlices()); 608 printer << ") -> ("; 609 printLevels(map, printer, getLvlTypes()); 610 printer << ')'; 611 // Print remaining members only for non-default values. 
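  // E.g., with the default posWidth/crdWidth of 0 the printed form is just
  // "<{ map = ... }>", while nonzero widths append ", posWidth = ..." and
  // ", crdWidth = ...".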
612 if (getPosWidth()) 613 printer << ", posWidth = " << getPosWidth(); 614 if (getCrdWidth()) 615 printer << ", crdWidth = " << getCrdWidth(); 616 printer << " }>"; 617 } 618 619 void SparseTensorEncodingAttr::printSymbols(AffineMap &map, 620 AsmPrinter &printer) const { 621 if (map.getNumSymbols() == 0) 622 return; 623 printer << '['; 624 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++) 625 printer << 's' << i << ", "; 626 if (map.getNumSymbols() >= 1) 627 printer << 's' << map.getNumSymbols() - 1; 628 printer << ']'; 629 } 630 631 void SparseTensorEncodingAttr::printDimensions( 632 AffineMap &map, AsmPrinter &printer, 633 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 634 if (!dimSlices.empty()) { 635 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 636 printer << 'd' << i << " : " << dimSlices[i] << ", "; 637 if (map.getNumDims() >= 1) { 638 printer << 'd' << map.getNumDims() - 1 << " : " 639 << dimSlices[map.getNumDims() - 1]; 640 } 641 } else { 642 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 643 printer << 'd' << i << ", "; 644 if (map.getNumDims() >= 1) 645 printer << 'd' << map.getNumDims() - 1; 646 } 647 } 648 649 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, 650 ArrayRef<LevelType> lvlTypes) const { 651 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { 652 map.getResult(i).print(printer.getStream()); 653 printer << " : " << toMLIRString(lvlTypes[i]) << ", "; 654 } 655 if (map.getNumResults() >= 1) { 656 auto lastIndex = map.getNumResults() - 1; 657 map.getResult(lastIndex).print(printer.getStream()); 658 printer << " : " << toMLIRString(lvlTypes[lastIndex]); 659 } 660 } 661 662 LogicalResult SparseTensorEncodingAttr::verify( 663 function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes, 664 AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth, 665 unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) { 666 if (!acceptBitWidth(posWidth)) 667 return emitError() << "unexpected position bitwidth: " << posWidth; 668 if (!acceptBitWidth(crdWidth)) 669 return emitError() << "unexpected coordinate bitwidth: " << crdWidth; 670 if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); 671 it != std::end(lvlTypes)) { 672 if (it == lvlTypes.begin() || 673 (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1)))) 674 return emitError() << "expected compressed or loose_compressed level " 675 "before singleton level"; 676 if (!std::all_of(it, lvlTypes.end(), 677 [](LevelType i) { return isSingletonLT(i); })) 678 return emitError() << "expected all singleton lvlTypes " 679 "following a singleton level"; 680 // We can potentially support mixed SoA/AoS singleton levels. 681 if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) { 682 return it->isa<LevelPropNonDefault::SoA>() == 683 i.isa<LevelPropNonDefault::SoA>(); 684 })) { 685 return emitError() << "expected all singleton lvlTypes stored in the " 686 "same memory layout (SoA vs AoS)."; 687 } 688 } 689 690 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 691 if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT)) 692 return emitError() << "Batch lvlType can only be leading levels."; 693 694 // SoA property can only be applied on singleton level. 
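  // E.g., a compressed level carrying the `soa` property is rejected by the
  // check below; only singleton levels may be marked SoA.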
695 auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) { 696 return lt.isa<LevelPropNonDefault::SoA>(); 697 }); 698 if (llvm::any_of(soaLvls, [](LevelType lt) { 699 return !lt.isa<LevelFormat::Singleton>(); 700 })) { 701 return emitError() << "SoA is only applicable to singleton lvlTypes."; 702 } 703 704 // TODO: audit formats that actually are supported by backend. 705 if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT); 706 it != std::end(lvlTypes)) { 707 if (it != lvlTypes.end() - 1) 708 return emitError() << "expected n_out_of_m to be the last level type"; 709 if (!std::all_of(lvlTypes.begin(), it, 710 [](LevelType i) { return isDenseLT(i); })) 711 return emitError() << "expected all dense lvlTypes " 712 "before a n_out_of_m level"; 713 if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) { 714 if (!isBlockSparsity(dimToLvl)) { 715 return emitError() 716 << "expected 1xm block structure for n_out_of_m level"; 717 } 718 auto sizes = getBlockSize(dimToLvl); 719 unsigned coefficient = 0; 720 for (const auto &elem : sizes) { 721 if (elem != 0) { 722 if (elem != coefficient && coefficient != 0) { 723 return emitError() << "expected only one blocked level " 724 "with the same coefficients"; 725 } 726 coefficient = elem; 727 } 728 } 729 if (coefficient != getM(*it)) { 730 return emitError() << "expected coeffiencts of Affine expressions " 731 "to be equal to m of n_out_of_m level"; 732 } 733 } 734 } 735 // Before we can check that the level-rank is consistent/coherent 736 // across all fields, we need to define it. The source-of-truth for 737 // the `getLvlRank` method is the length of the level-types array, 738 // since it must always be provided and have full rank; therefore we 739 // use that same source-of-truth here. 740 const Level lvlRank = lvlTypes.size(); 741 if (lvlRank == 0) 742 return emitError() << "expected a non-empty array for lvlTypes"; 743 // We save `dimRank` here because we'll also need it to verify `dimSlices`. 744 const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank; 745 if (dimToLvl) { 746 if (dimToLvl.getNumResults() != lvlRank) 747 return emitError() 748 << "level-rank mismatch between dimToLvl and lvlTypes: " 749 << dimToLvl.getNumResults() << " != " << lvlRank; 750 auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext()); 751 // Symbols can't be inferred but are acceptable. 752 if (!inferRes && dimToLvl.getNumSymbols() == 0) 753 return emitError() << "failed to infer lvlToDim from dimToLvl"; 754 if (lvlToDim && (inferRes != lvlToDim)) 755 return emitError() << "expected lvlToDim to be an inverse of dimToLvl"; 756 if (dimRank > lvlRank) 757 return emitError() << "unexpected dimToLvl mapping from " << dimRank 758 << " to " << lvlRank; 759 } 760 if (!dimSlices.empty()) { 761 if (dimSlices.size() != dimRank) 762 return emitError() 763 << "dimension-rank mismatch between dimSlices and dimToLvl: " 764 << dimSlices.size() << " != " << dimRank; 765 // Compiler support for `dimSlices` currently requires that the two 766 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.) 767 if (dimRank != lvlRank) 768 return emitError() 769 << "dimSlices expected dimension-rank to match level-rank: " 770 << dimRank << " != " << lvlRank; 771 } 772 return success(); 773 } 774 775 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 776 ArrayRef<Size> dimShape, Type elementType, 777 function_ref<InFlightDiagnostic()> emitError) const { 778 // Check structural integrity. 
In particular, this ensures that the 779 // level-rank is coherent across all the fields. 780 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(), 781 getPosWidth(), getCrdWidth(), getDimSlices()))) 782 return failure(); 783 // Check integrity with tensor type specifics. In particular, we 784 // need only check that the dimension-rank of the tensor agrees with 785 // the dimension-rank of the encoding. 786 const Dimension dimRank = dimShape.size(); 787 if (dimRank == 0) 788 return emitError() << "expected non-scalar sparse tensor"; 789 if (getDimRank() != dimRank) 790 return emitError() 791 << "dimension-rank mismatch between encoding and tensor shape: " 792 << getDimRank() << " != " << dimRank; 793 return success(); 794 } 795 796 //===----------------------------------------------------------------------===// 797 // SparseTensorType Methods. 798 //===----------------------------------------------------------------------===// 799 800 bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, 801 bool isUnique) const { 802 if (!hasEncoding()) 803 return false; 804 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl)) 805 return false; 806 for (Level l = startLvl + 1; l < lvlRank; ++l) 807 if (!isSingletonLvl(l)) 808 return false; 809 // If isUnique is true, then make sure that the last level is unique, 810 // that is, when lvlRank == 1, the only compressed level is unique, 811 // and when lvlRank > 1, the last singleton is unique. 812 return !isUnique || isUniqueLvl(lvlRank - 1); 813 } 814 815 Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const { 816 SmallVector<COOSegment> coo = getCOOSegments(); 817 assert(coo.size() == 1 || coo.empty()); 818 if (!coo.empty() && coo.front().isAoS()) { 819 return coo.front().lvlRange.first; 820 } 821 return lvlRank; 822 } 823 824 SmallVector<COOSegment> 825 mlir::sparse_tensor::SparseTensorType::getCOOSegments() const { 826 SmallVector<COOSegment> ret; 827 if (!hasEncoding() || lvlRank <= 1) 828 return ret; 829 830 ArrayRef<LevelType> lts = getLvlTypes(); 831 Level l = 0; 832 while (l < lvlRank) { 833 auto lt = lts[l]; 834 if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) { 835 auto cur = lts.begin() + l; 836 auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) { 837 return !lt.isa<LevelFormat::Singleton>(); 838 }); 839 unsigned cooLen = std::distance(cur, end); 840 if (cooLen > 1) { 841 // To support mixed SoA/AoS COO, we should break the segment when the 842 // storage scheme changes, for now we faithfully assume that all 843 // consecutive singleton levels have the same storage format as verified 844 // STEA. 845 ret.push_back(COOSegment{std::make_pair(l, l + cooLen), 846 lts[l + 1].isa<LevelPropNonDefault::SoA>()}); 847 } 848 l += cooLen; 849 } else { 850 l++; 851 } 852 } 853 return ret; 854 } 855 856 RankedTensorType 857 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { 858 SmallVector<LevelType> lvlTypes; 859 lvlTypes.reserve(lvlRank); 860 // A non-unique compressed level at beginning (unless this is 861 // also the last level, then it is unique). 862 lvlTypes.push_back( 863 *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); 864 if (lvlRank > 1) { 865 // Followed by n-2 non-unique singleton levels. 866 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, 867 *buildLevelType(LevelFormat::Singleton, ordered, false)); 868 // Ends by a unique singleton level. 
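    // E.g., for lvlRank == 3 the generated level types are
    // (compressed(nonunique), singleton(nonunique), singleton).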
869 lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true)); 870 } 871 auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes, 872 getDimToLvl(), getLvlToDim(), 873 getPosWidth(), getCrdWidth()); 874 return RankedTensorType::get(getDimShape(), getElementType(), enc); 875 } 876 877 //===----------------------------------------------------------------------===// 878 // Convenience Methods. 879 //===----------------------------------------------------------------------===// 880 881 SparseTensorEncodingAttr 882 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 883 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type)) 884 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding()); 885 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type)) 886 return mdtp.getEncoding(); 887 return nullptr; 888 } 889 890 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl, 891 MLIRContext *context) { 892 auto map = static_cast<AffineMap>(dimToLvl); 893 AffineMap lvlToDim; 894 // Return an empty lvlToDim when inference is not successful. 895 if (!map || map.getNumSymbols() != 0) { 896 lvlToDim = AffineMap(); 897 } else if (map.isPermutation()) { 898 lvlToDim = inversePermutation(map); 899 } else if (isBlockSparsity(map)) { 900 lvlToDim = inverseBlockSparsity(map, context); 901 } 902 return lvlToDim; 903 } 904 905 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl, 906 MLIRContext *context) { 907 SmallVector<AffineExpr> lvlExprs; 908 auto numLvls = dimToLvl.getNumResults(); 909 lvlExprs.reserve(numLvls); 910 // lvlExprComponents stores information of the floordiv and mod operations 911 // applied to the same dimension, so as to build the lvlToDim map. 912 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents; 913 for (unsigned i = 0, n = numLvls; i < n; i++) { 914 auto result = dimToLvl.getResult(i); 915 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 916 if (result.getKind() == AffineExprKind::FloorDiv) { 917 // Position of the dimension in dimToLvl. 918 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 919 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() && 920 "expected only one floordiv for each dimension"); 921 SmallVector<AffineExpr, 3> components; 922 // Level variable for floordiv. 923 components.push_back(getAffineDimExpr(i, context)); 924 // Multiplier. 925 components.push_back(binOp.getRHS()); 926 // Map key is the position of the dimension. 927 lvlExprComponents[pos] = components; 928 } else if (result.getKind() == AffineExprKind::Mod) { 929 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 930 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() && 931 "expected floordiv before mod"); 932 // Add level variable for mod to the same vector 933 // of the corresponding floordiv. 934 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context)); 935 } else { 936 assert(false && "expected floordiv or mod"); 937 } 938 } else { 939 lvlExprs.push_back(getAffineDimExpr(i, context)); 940 } 941 } 942 // Build lvlExprs from lvlExprComponents. 943 // For example, for il = i floordiv 2 and ii = i mod 2, the components 944 // would be [il, 2, ii]. It could be used to build the AffineExpr 945 // i = il * 2 + ii in lvlToDim. 
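  // E.g., for the 2x3 block map (d0, d1) -> (d0 floordiv 2, d1 floordiv 3,
  // d0 mod 2, d1 mod 3), the reconstructed inverse is
  // (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 3 + d3).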
946 for (auto &components : lvlExprComponents) { 947 assert(components.second.size() == 3 && 948 "expected 3 components to build lvlExprs"); 949 auto mulOp = getAffineBinaryOpExpr( 950 AffineExprKind::Mul, components.second[0], components.second[1]); 951 auto addOp = 952 getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]); 953 lvlExprs.push_back(addOp); 954 } 955 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context); 956 } 957 958 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { 959 assert(isBlockSparsity(dimToLvl) && 960 "expected dimToLvl to be block sparsity for calling getBlockSize"); 961 SmallVector<unsigned> blockSize; 962 for (auto result : dimToLvl.getResults()) { 963 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 964 if (result.getKind() == AffineExprKind::Mod) { 965 blockSize.push_back( 966 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue()); 967 } 968 } else { 969 blockSize.push_back(0); 970 } 971 } 972 return blockSize; 973 } 974 975 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { 976 if (!dimToLvl) 977 return false; 978 std::map<unsigned, int64_t> coeffientMap; 979 bool hasBlock = false; 980 for (auto result : dimToLvl.getResults()) { 981 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 982 // Check for "dim op const". 983 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS()); 984 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS()); 985 if (!dimOp || !conOp || conOp.getValue() <= 0) 986 return false; 987 // Inspect "dim / const" or "dim % const". 988 auto pos = dimOp.getPosition(); 989 if (binOp.getKind() == AffineExprKind::FloorDiv) { 990 // Expect only one floordiv for each dimension. 991 if (coeffientMap.find(pos) != coeffientMap.end()) 992 return false; 993 // Record coefficient of the floordiv. 994 coeffientMap[pos] = conOp.getValue(); 995 } else if (binOp.getKind() == AffineExprKind::Mod) { 996 // Expect floordiv before mod. 997 if (coeffientMap.find(pos) == coeffientMap.end()) 998 return false; 999 // Expect mod to have the same coefficient as floordiv. 1000 if (conOp.getValue() != coeffientMap[pos]) 1001 return false; 1002 hasBlock = true; 1003 } else { 1004 return false; 1005 } 1006 } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) { 1007 auto pos = dimOp.getPosition(); 1008 // Expect dim to be unset. 
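      // That is, a dimension may appear either as a plain dim expression or
      // inside a floordiv/mod pair, but not both.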
      if (coeffientMap.find(pos) != coeffientMap.end())
        return false;
      coeffientMap[pos] = 0;
    } else {
      return false;
    }
  }
  return hasBlock;
}

bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto dimToLvl = enc.getDimToLvl())
      return dimToLvl.getDimPosition(l);
  }
  return l;
}

Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto lvlToDim = enc.getLvlToDim())
      return lvlToDim.getDimPosition(d);
  }
  return d;
}

/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by
/// stripping irrelevant fields that do not alter the sparse tensor memory
/// layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<LevelType> lts;
  for (auto lt : enc.getLvlTypes())
    lts.push_back(lt.stripStorageIrrelevantProperties());

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
      // values for different bitwidths, and it avoids casting between index
      // and integer (returned by DimOp).
      0, 0, enc.getDimSlices());
}

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verifies the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in the shape of <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back())
      return op->emitError(
          "input/output trailing COO level-ranks don't match");
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult AssembleOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from
  // dense inputs by rotating dense loops, but it leads to bad cache locality
  // and hurts performance.
1257 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>()) 1258 if (isa<SparseElementsAttr>(constOp.getValue())) 1259 return false; 1260 1261 return true; 1262 } 1263 1264 LogicalResult CrdTranslateOp::verify() { 1265 uint64_t inRank = getEncoder().getLvlRank(); 1266 uint64_t outRank = getEncoder().getDimRank(); 1267 1268 if (getDirection() == CrdTransDirectionKind::dim2lvl) 1269 std::swap(inRank, outRank); 1270 1271 if (inRank != getInCrds().size() || outRank != getOutCrds().size()) 1272 return emitError("Coordinate rank mismatch with encoding"); 1273 1274 return success(); 1275 } 1276 1277 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor, 1278 SmallVectorImpl<OpFoldResult> &results) { 1279 if (getEncoder().isIdentity()) { 1280 results.assign(getInCrds().begin(), getInCrds().end()); 1281 return success(); 1282 } 1283 if (getEncoder().isPermutation()) { 1284 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl 1285 ? getEncoder().getDimToLvl() 1286 : getEncoder().getLvlToDim(); 1287 for (AffineExpr exp : perm.getResults()) 1288 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]); 1289 return success(); 1290 } 1291 1292 // Fuse dim2lvl/lvl2dim pairs. 1293 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>(); 1294 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) { 1295 return v.getDefiningOp() == def; 1296 }); 1297 if (!sameDef) 1298 return failure(); 1299 1300 bool oppositeDir = def.getDirection() != getDirection(); 1301 bool sameOracle = 1302 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl(); 1303 bool sameCount = def.getNumResults() == getInCrds().size(); 1304 if (!oppositeDir || !sameOracle || !sameCount) 1305 return failure(); 1306 1307 // The definition produces the coordinates in the same order as the input 1308 // coordinates. 
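  // Otherwise the fold would silently permute coordinates, so bail out when
  // the order differs.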
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}

void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {
  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
  return build(builder, state, source, val);
}

LogicalResult LvlOp::verify() {
  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      return emitError(
          "Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability LvlOp::getSpeculatability() {
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
  return Speculation::Speculatable;
}

OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by the tensor.dim operation.
    // Out-of-bound indices produce undefined behavior but are still valid IR.
    // Don't choke on them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
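  // (Only used when the queried level size below turns out to be static;
  // dynamic level sizes leave the op in place.)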
1367 auto getIndexAttr = [this](int64_t lvlSz) { 1368 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz)); 1369 }; 1370 1371 SmallVector<Size> lvlShape = stt.getLvlShape(); 1372 if (!ShapedType::isDynamic(lvlShape[lvl])) 1373 return getIndexAttr(lvlShape[lvl]); 1374 1375 return {}; 1376 } 1377 1378 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState, 1379 SparseTensorEncodingAttr dstEnc, Value source) { 1380 auto srcStt = getSparseTensorType(source); 1381 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape(); 1382 SmallVector<int64_t> dstDimShape = 1383 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim); 1384 auto dstTp = 1385 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc); 1386 return build(odsBuilder, odsState, dstTp, source); 1387 } 1388 1389 LogicalResult ReinterpretMapOp::verify() { 1390 auto srcStt = getSparseTensorType(getSource()); 1391 auto dstStt = getSparseTensorType(getDest()); 1392 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes(); 1393 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes(); 1394 1395 if (srcLvlTps.size() != dstLvlTps.size()) 1396 return emitError("Level rank mismatch between source/dest tensors"); 1397 1398 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps)) 1399 if (srcLvlTp != dstLvlTp) 1400 return emitError("Level type mismatch between source/dest tensors"); 1401 1402 if (srcStt.getPosWidth() != dstStt.getPosWidth() || 1403 srcStt.getCrdWidth() != dstStt.getCrdWidth()) { 1404 return emitError("Crd/Pos width mismatch between source/dest tensors"); 1405 } 1406 1407 if (srcStt.getElementType() != dstStt.getElementType()) 1408 return emitError("Element type mismatch between source/dest tensors"); 1409 1410 SmallVector<Size> srcLvlShape = srcStt.getLvlShape(); 1411 SmallVector<Size> dstLvlShape = dstStt.getLvlShape(); 1412 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) { 1413 if (srcLvlSz != dstLvlSz) { 1414 // Should we allow one side to be dynamic size, e.g., <?x?> should be 1415 // compatible to <3x4>? For now, we require all the level sizes to be 1416 // *exactly* matched for simplicity. 
1417 return emitError("Level size mismatch between source/dest tensors"); 1418 } 1419 } 1420 1421 return success(); 1422 } 1423 1424 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) { 1425 if (getSource().getType() == getDest().getType()) 1426 return getSource(); 1427 1428 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) { 1429 // A -> B, B -> A ==> A 1430 if (def.getSource().getType() == getDest().getType()) 1431 return def.getSource(); 1432 } 1433 return {}; 1434 } 1435 1436 template <typename ToBufferOp> 1437 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, 1438 OpaqueProperties prop, 1439 RegionRange region, 1440 SmallVectorImpl<mlir::Type> &ret) { 1441 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region); 1442 SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); 1443 Type elemTp = nullptr; 1444 bool withStride = false; 1445 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) { 1446 elemTp = stt.getPosType(); 1447 } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> || 1448 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) { 1449 elemTp = stt.getCrdType(); 1450 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>) 1451 withStride = stt.getAoSCOOStart() <= adaptor.getLevel(); 1452 } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) { 1453 elemTp = stt.getElementType(); 1454 } 1455 1456 assert(elemTp && "unhandled operation."); 1457 SmallVector<int64_t> bufShape = stt.getBatchLvlShape(); 1458 bufShape.push_back(ShapedType::kDynamic); 1459 1460 auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get( 1461 stt.getContext(), ShapedType::kDynamic, 1462 {ShapedType::kDynamic}) 1463 : StridedLayoutAttr(); 1464 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout)); 1465 return success(); 1466 } 1467 1468 LogicalResult ToPositionsOp::verify() { 1469 auto stt = getSparseTensorType(getTensor()); 1470 if (failed(lvlIsInBounds(getLevel(), getTensor()))) 1471 return emitError("requested level is out of bounds"); 1472 if (failed(isMatchingWidth(getResult(), stt.getPosWidth()))) 1473 return emitError("unexpected type for positions"); 1474 return success(); 1475 } 1476 1477 LogicalResult 1478 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 1479 ValueRange ops, DictionaryAttr attr, 1480 OpaqueProperties prop, RegionRange region, 1481 SmallVectorImpl<mlir::Type> &ret) { 1482 return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret); 1483 } 1484 1485 LogicalResult ToCoordinatesOp::verify() { 1486 auto stt = getSparseTensorType(getTensor()); 1487 if (failed(lvlIsInBounds(getLevel(), getTensor()))) 1488 return emitError("requested level is out of bounds"); 1489 if (failed(isMatchingWidth(getResult(), stt.getCrdWidth()))) 1490 return emitError("unexpected type for coordinates"); 1491 return success(); 1492 } 1493 1494 LogicalResult 1495 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 1496 ValueRange ops, DictionaryAttr attr, 1497 OpaqueProperties prop, RegionRange region, 1498 SmallVectorImpl<mlir::Type> &ret) { 1499 return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret); 1500 } 1501 1502 LogicalResult ToCoordinatesBufferOp::verify() { 1503 auto stt = getSparseTensorType(getTensor()); 1504 if (stt.getAoSCOOStart() >= stt.getLvlRank()) 1505 return emitError("expected sparse tensor with a COO region"); 1506 return success(); 1507 } 1508 1509 LogicalResult 
ToCoordinatesBufferOp::inferReturnTypes( 1510 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, 1511 DictionaryAttr attr, OpaqueProperties prop, RegionRange region, 1512 SmallVectorImpl<mlir::Type> &ret) { 1513 return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region, 1514 ret); 1515 } 1516 1517 LogicalResult ToValuesOp::verify() { 1518 auto stt = getSparseTensorType(getTensor()); 1519 auto mtp = getMemRefType(getResult()); 1520 if (stt.getElementType() != mtp.getElementType()) 1521 return emitError("unexpected mismatch in element types"); 1522 return success(); 1523 } 1524 1525 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx, 1526 std::optional<Location> loc, 1527 ValueRange ops, DictionaryAttr attr, 1528 OpaqueProperties prop, 1529 RegionRange region, 1530 SmallVectorImpl<mlir::Type> &ret) { 1531 return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret); 1532 } 1533 1534 LogicalResult ToSliceOffsetOp::verify() { 1535 auto rank = getRankedTensorType(getSlice()).getRank(); 1536 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) 1537 return emitError("requested dimension out of bound"); 1538 return success(); 1539 } 1540 1541 LogicalResult ToSliceStrideOp::verify() { 1542 auto rank = getRankedTensorType(getSlice()).getRank(); 1543 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) 1544 return emitError("requested dimension out of bound"); 1545 return success(); 1546 } 1547 1548 LogicalResult GetStorageSpecifierOp::verify() { 1549 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), 1550 getSpecifier(), getOperation()); 1551 } 1552 1553 template <typename SpecifierOp> 1554 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) { 1555 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>(); 1556 } 1557 1558 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { 1559 const StorageSpecifierKind kind = getSpecifierKind(); 1560 const auto lvl = getLevel(); 1561 for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op)) 1562 if (kind == op.getSpecifierKind() && lvl == op.getLevel()) 1563 return op.getValue(); 1564 return {}; 1565 } 1566 1567 LogicalResult SetStorageSpecifierOp::verify() { 1568 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), 1569 getSpecifier(), getOperation()); 1570 } 1571 1572 template <class T> 1573 static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, 1574 const char *regionName, 1575 TypeRange inputTypes, Type outputType) { 1576 unsigned numArgs = region.getNumArguments(); 1577 unsigned expectedNum = inputTypes.size(); 1578 if (numArgs != expectedNum) 1579 return op->emitError() << regionName << " region must have exactly " 1580 << expectedNum << " arguments"; 1581 1582 for (unsigned i = 0; i < numArgs; i++) { 1583 Type typ = region.getArgument(i).getType(); 1584 if (typ != inputTypes[i]) 1585 return op->emitError() << regionName << " region argument " << (i + 1) 1586 << " type mismatch"; 1587 } 1588 Operation *term = region.front().getTerminator(); 1589 YieldOp yield = dyn_cast<YieldOp>(term); 1590 if (!yield) 1591 return op->emitError() << regionName 1592 << " region must end with sparse_tensor.yield"; 1593 if (!yield.getResult() || yield.getResult().getType() != outputType) 1594 return op->emitError() << regionName << " region yield type mismatch"; 1595 1596 return success(); 1597 } 1598 1599 LogicalResult BinaryOp::verify() { 1600 NamedAttrList attrs = (*this)->getAttrs(); 1601 Type leftType = 
template <class T>
static LogicalResult verifyNumBlockArgs(T *op, Region &region,
                                        const char *regionName,
                                        TypeRange inputTypes, Type outputType) {
  unsigned numArgs = region.getNumArguments();
  unsigned expectedNum = inputTypes.size();
  if (numArgs != expectedNum)
    return op->emitError() << regionName << " region must have exactly "
                           << expectedNum << " arguments";

  for (unsigned i = 0; i < numArgs; i++) {
    Type typ = region.getArgument(i).getType();
    if (typ != inputTypes[i])
      return op->emitError() << regionName << " region argument " << (i + 1)
                             << " type mismatch";
  }
  Operation *term = region.front().getTerminator();
  YieldOp yield = dyn_cast<YieldOp>(term);
  if (!yield)
    return op->emitError() << regionName
                           << " region must end with sparse_tensor.yield";
  if (!yield.getResult() || yield.getResult().getType() != outputType)
    return op->emitError() << regionName << " region yield type mismatch";

  return success();
}

LogicalResult BinaryOp::verify() {
  NamedAttrList attrs = (*this)->getAttrs();
  Type leftType = getX().getType();
  Type rightType = getY().getType();
  Type outputType = getOutput().getType();
  Region &overlap = getOverlapRegion();
  Region &left = getLeftRegion();
  Region &right = getRightRegion();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  if (!overlap.empty()) {
    if (failed(verifyNumBlockArgs(this, overlap, "overlap",
                                  TypeRange{leftType, rightType}, outputType)))
      return failure();
  }
  if (!left.empty()) {
    if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
                                  outputType)))
      return failure();
  } else if (getLeftIdentity()) {
    if (leftType != outputType)
      return emitError("left=identity requires first argument to have the same "
                       "type as the output");
  }
  if (!right.empty()) {
    if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
                                  outputType)))
      return failure();
  } else if (getRightIdentity()) {
    if (rightType != outputType)
      return emitError("right=identity requires second argument to have the "
                       "same type as the output");
  }
  return success();
}

LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(verifyNumBlockArgs(this, present, "present",
                                  TypeRange{inputType}, outputType)))
      return failure();
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
                                  outputType)))
      return failure();
    // Absent branch can only yield invariant values.
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
  return success();
}
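
/// Returns true when lowering this concatenation requires materializing a
/// temporary COO buffer that must be sorted afterwards; returns false when the
/// inputs can be appended to the output directly. No extra sort is needed when
/// the destination is all-dense or unordered, or when all inputs share the
/// destination's dimToLvl mapping, the destination uses an identity mapping,
/// and the concatenation runs along dimension 0 (e.g., stacking CSR matrices
/// row-wise).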
bool ConcatenateOp::needsExtraSort() {
  SparseTensorType dstStt = getSparseTensorType(*this);
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
    return getSparseTensorType(op).hasSameDimToLvl(dstStt);
  });
  // TODO: When conDim != 0, as long as conDim corresponds to the first level
  // in all input/output buffers, and all input/output buffers have the same
  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
  // CSC matrices along the column dimension).
  bool directLowerable =
      allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
  return !directLowerable;
}

LogicalResult ConcatenateOp::verify() {
  const auto dstTp = getSparseTensorType(*this);
  const Dimension concatDim = getDimension();
  const Dimension dimRank = dstTp.getDimRank();

  if (getInputs().size() <= 1)
    return emitError("Need at least two tensors to concatenate.");

  if (concatDim >= dimRank)
    return emitError(llvm::formatv(
        "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
        concatDim, dimRank));

  for (const auto &it : llvm::enumerate(getInputs())) {
    const auto i = it.index();
    const auto srcTp = getSparseTensorType(it.value());
    if (srcTp.hasDynamicDimShape())
      return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
    const Dimension srcDimRank = srcTp.getDimRank();
    if (srcDimRank != dimRank)
      return emitError(
          llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
                        "from the output tensor (rank={2}).",
                        i, srcDimRank, dimRank));
  }

  for (Dimension d = 0; d < dimRank; d++) {
    const Size dstSh = dstTp.getDimShape()[d];
    if (d == concatDim) {
      if (!ShapedType::isDynamic(dstSh)) {
        // If we reach here, then all inputs have static shapes. So we
        // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
        // to avoid redundant assertions in the loop.
        Size sumSz = 0;
        for (const auto src : getInputs())
          sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimensions are statically known, the sum of all the input
        // dimensions should be equal to the output dimension.
        if (sumSz != dstSh)
          return emitError(
              "The concatenation dimension of the output tensor should be the "
              "sum of all the concatenation dimensions of the input tensors.");
      }
    } else {
      Size prev = dstSh;
      for (const auto src : getInputs()) {
        const auto sh = getSparseTensorType(src).getDimShape()[d];
        if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (except for the concatenating one) "
                           "should be equal.");
        prev = sh;
      }
    }
  }

  return success();
}

void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  build(builder, result, curSize, inBuffer, value, Value());
}

LogicalResult PushBackOp::verify() {
  if (Value n = getN()) {
    std::optional<int64_t> nValue = getConstantIntValue(n);
    if (nValue && nValue.value() < 1)
      return emitOpError("n must be not less than 1");
  }
  return success();
}

LogicalResult CompressOp::verify() {
  const auto stt = getSparseTensorType(getTensor());
  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
  return success();
}
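
/// The builder below creates the foreach body block with `dimRank` coordinate
/// arguments of index type, followed by one element value, followed by the
/// user-supplied reduction arguments. As an illustration, for a 2-D f64 tensor
/// with a single f64 init argument, the generated block signature is
///   ^bb0(%d0: index, %d1: index, %value: f64, %acc: f64)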
void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}

LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError("Level traverse order does not match tensor's level rank");

  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      elemTp, valueTp));
  return success();
}

OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
  if (getSparseTensorEncoding(getInputCoo().getType()) ==
      getSparseTensorEncoding(getResultCoo().getType()))
    return getInputCoo();

  return {};
}

LogicalResult ReorderCOOOp::verify() {
  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
  SparseTensorType dstStt = getSparseTensorType(getResultCoo());

  if (!srcStt.isCOOType() || !dstStt.isCOOType())
    return emitError("Expected COO sparse tensors only");

  if (!srcStt.hasSameDimToLvl(dstStt))
    return emitError("Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType())
    return emitError("Unmatched storage format between input and result COO");

  return success();
}
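
/// Both the reduce and select verifiers below delegate to verifyNumBlockArgs:
/// a reduce region takes two arguments of the operand's element type and must
/// yield that same type, while a select region takes one such argument and
/// must yield an i1. As a rough illustration (the authoritative assembly
/// format lives in SparseTensorOps.td), a reduce looks like:
///
///   %r = sparse_tensor.reduce %x, %y, %identity : f64 {
///     ^bb0(%a: f64, %b: f64):
///       %0 = arith.mulf %a, %b : f64
///       sparse_tensor.yield %0 : f64
///   }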
LogicalResult ReduceOp::verify() {
  Type inputType = getX().getType();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "reduce",
                            TypeRange{inputType, inputType}, inputType);
}

LogicalResult SelectOp::verify() {
  Builder b(getContext());
  Type inputType = getX().getType();
  Type boolType = b.getI1Type();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
                            boolType);
}

LogicalResult SortOp::verify() {
  AffineMap xPerm = getPermMap();
  uint64_t nx = xPerm.getNumDims();
  if (nx < 1)
    return emitError(
        llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));

  if (!xPerm.isPermutation())
    return emitError(
        llvm::formatv("Expected a permutation map, got {0}", xPerm));

  // We can't check the size of the buffers when n or buffer dimensions aren't
  // compile-time constants.
  std::optional<int64_t> cn = getConstantIntValue(getN());
  if (!cn)
    return success();

  // Verify dimensions.
  const auto checkDim = [&](Value v, Size minSize,
                            const char *message) -> LogicalResult {
    const Size sh = getMemRefType(v).getShape()[0];
    if (!ShapedType::isDynamic(sh) && sh < minSize)
      return emitError(
          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
    return success();
  };
  uint64_t n = cn.value();
  uint64_t ny = 0;
  if (auto nyAttr = getNyAttr())
    ny = nyAttr.getInt();
  if (failed(checkDim(getXy(), n * (nx + ny),
                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
    return failure();
  for (Value opnd : getYs())
    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
      return failure();

  return success();
}

LogicalResult YieldOp::verify() {
  // Check for compatible parent.
  auto *parentOp = (*this)->getParentOp();
  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
      isa<ForeachOp>(parentOp))
    return success();

  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
                     "reduce, select or foreach");
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
                                                    Attribute value, Type type,
                                                    Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  return nullptr;
}

namespace {
struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
  using OpAsmDialectInterface::OpAsmDialectInterface;

  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
    if (isa<SparseTensorEncodingAttr>(attr)) {
      os << "sparse";
      return AliasResult::OverridableAlias;
    }
    return AliasResult::NoAlias;
  }
};
} // namespace

void SparseTensorDialect::initialize() {
  addInterface<SparseTensorAsmDialectInterface>();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"