1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "Detail/DimLvlMapParser.h" 12 13 #include "mlir/Dialect/SparseTensor/IR/Enums.h" 14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" 16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 17 18 #include "mlir/Dialect/Arith/IR/Arith.h" 19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 20 #include "mlir/Dialect/Complex/IR/Complex.h" 21 #include "mlir/Dialect/Utils/StaticValueUtils.h" 22 #include "mlir/IR/Builders.h" 23 #include "mlir/IR/DialectImplementation.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "llvm/ADT/Bitset.h" 28 #include "llvm/ADT/TypeSwitch.h" 29 #include "llvm/Support/FormatVariadic.h" 30 31 #define GET_ATTRDEF_CLASSES 32 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 33 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc" 34 35 // Forward declarations, following custom print/parsing methods are referenced 36 // by the generated code for SparseTensorTypes.td. 
37 static mlir::ParseResult parseLevelRange(mlir::AsmParser &, 38 mlir::sparse_tensor::Level &, 39 mlir::sparse_tensor::Level &); 40 static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, 41 mlir::sparse_tensor::Level); 42 43 #define GET_TYPEDEF_CLASSES 44 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" 45 46 using namespace mlir; 47 using namespace mlir::sparse_tensor; 48 49 // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as 50 // well. 51 namespace mlir::sparse_tensor { 52 llvm::hash_code hash_value(LevelType lt) { 53 return llvm::hash_value(static_cast<uint64_t>(lt)); 54 } 55 } // namespace mlir::sparse_tensor 56 57 //===----------------------------------------------------------------------===// 58 // Local Convenience Methods. 59 //===----------------------------------------------------------------------===// 60 61 static constexpr bool acceptBitWidth(unsigned bitWidth) { 62 switch (bitWidth) { 63 case 0: 64 case 8: 65 case 16: 66 case 32: 67 case 64: 68 return true; 69 default: 70 return false; 71 } 72 } 73 74 static SmallVector<Size> 75 getSparseFieldShape(const SparseTensorEncodingAttr enc, 76 std::optional<ArrayRef<int64_t>> dimShape) { 77 assert(enc); 78 // With only encoding, we can not determine the static shape for leading 79 // batch levels, we therefore return a dynamic shape memref instead. 80 SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic); 81 if (dimShape.has_value()) { 82 // If the actual tensor shape is provided, we can then refine the leading 83 // batch dimension. 84 SmallVector<int64_t> lvlShape = 85 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl); 86 memrefShape.assign(lvlShape.begin(), 87 lvlShape.begin() + enc.getBatchLvlRank()); 88 } 89 // Another dynamic dimension to store the sparse level. 
90 memrefShape.push_back(ShapedType::kDynamic); 91 return memrefShape; 92 } 93 94 //===----------------------------------------------------------------------===// 95 // SparseTensorDialect StorageLayout. 96 //===----------------------------------------------------------------------===// 97 98 static constexpr Level kInvalidLevel = -1u; 99 static constexpr Level kInvalidFieldIndex = -1u; 100 static constexpr FieldIndex kDataFieldStartingIdx = 0; 101 102 void StorageLayout::foreachField( 103 llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level, 104 LevelType)> 105 callback) const { 106 const auto lvlTypes = enc.getLvlTypes(); 107 const Level lvlRank = enc.getLvlRank(); 108 SmallVector<COOSegment> cooSegs = enc.getCOOSegments(); 109 FieldIndex fieldIdx = kDataFieldStartingIdx; 110 111 ArrayRef cooSegsRef = cooSegs; 112 // Per-level storage. 113 for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) { 114 const auto lt = lvlTypes[l]; 115 if (isWithPosLT(lt)) { 116 if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt))) 117 return; 118 } 119 if (isWithCrdLT(lt)) { 120 if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt))) 121 return; 122 } 123 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) { 124 if (!cooSegsRef.front().isSoA) { 125 // AoS COO, all singletons are fused into one memrefs. Skips the entire 126 // COO segement. 127 l = cooSegsRef.front().lvlRange.second; 128 } else { 129 // SoA COO, each singleton level has one memref. 130 l++; 131 } 132 // Expire handled COO segment. 133 cooSegsRef = cooSegsRef.drop_front(); 134 } else { 135 // Non COO levels. 136 l++; 137 } 138 } 139 // The values array. 140 if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel, 141 LevelFormat::Undef))) 142 return; 143 // Put metadata at the end. 
144 if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel, 145 LevelFormat::Undef))) 146 return; 147 } 148 149 void sparse_tensor::foreachFieldAndTypeInSparseTensor( 150 SparseTensorType stt, 151 llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level, 152 LevelType)> 153 callback) { 154 assert(stt.hasEncoding()); 155 156 SmallVector<int64_t> memrefShape = 157 getSparseFieldShape(stt.getEncoding(), stt.getDimShape()); 158 159 const Type specType = StorageSpecifierType::get(stt.getEncoding()); 160 // memref<[batch] x ? x pos> positions 161 const Type posMemType = MemRefType::get(memrefShape, stt.getPosType()); 162 // memref<[batch] x ? x crd> coordinates 163 const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType()); 164 // memref<[batch] x ? x eltType> values 165 const Type valMemType = MemRefType::get(memrefShape, stt.getElementType()); 166 167 StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType, 168 callback](FieldIndex fieldIdx, 169 SparseTensorFieldKind fieldKind, 170 Level lvl, LevelType lt) -> bool { 171 switch (fieldKind) { 172 case SparseTensorFieldKind::StorageSpec: 173 return callback(specType, fieldIdx, fieldKind, lvl, lt); 174 case SparseTensorFieldKind::PosMemRef: 175 return callback(posMemType, fieldIdx, fieldKind, lvl, lt); 176 case SparseTensorFieldKind::CrdMemRef: 177 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt); 178 case SparseTensorFieldKind::ValMemRef: 179 return callback(valMemType, fieldIdx, fieldKind, lvl, lt); 180 }; 181 llvm_unreachable("unrecognized field kind"); 182 }); 183 } 184 185 unsigned StorageLayout::getNumFields() const { 186 unsigned numFields = 0; 187 foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level, 188 LevelType) -> bool { 189 numFields++; 190 return true; 191 }); 192 return numFields; 193 } 194 195 unsigned StorageLayout::getNumDataFields() const { 196 unsigned numFields = 0; // one value memref 197 
foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level, 198 LevelType) -> bool { 199 if (fidx >= kDataFieldStartingIdx) 200 numFields++; 201 return true; 202 }); 203 numFields -= 1; // the last field is StorageSpecifier 204 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1); 205 return numFields; 206 } 207 208 std::pair<FieldIndex, unsigned> 209 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind, 210 std::optional<Level> lvl) const { 211 FieldIndex fieldIdx = kInvalidFieldIndex; 212 unsigned stride = 1; 213 if (kind == SparseTensorFieldKind::CrdMemRef) { 214 assert(lvl.has_value()); 215 const Level cooStart = enc.getAoSCOOStart(); 216 const Level lvlRank = enc.getLvlRank(); 217 if (lvl.value() >= cooStart && lvl.value() < lvlRank) { 218 lvl = cooStart; 219 stride = lvlRank - cooStart; 220 } 221 } 222 foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx, 223 SparseTensorFieldKind fKind, Level fLvl, 224 LevelType lt) -> bool { 225 if ((lvl && fLvl == lvl.value() && kind == fKind) || 226 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { 227 fieldIdx = fIdx; 228 // Returns false to break the iteration. 229 return false; 230 } 231 return true; 232 }); 233 assert(fieldIdx != kInvalidFieldIndex); 234 return std::pair<FieldIndex, unsigned>(fieldIdx, stride); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // SparseTensorDialect Attribute Methods. 239 //===----------------------------------------------------------------------===// 240 241 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) { 242 return isDynamic(v) ? 
std::nullopt 243 : std::make_optional(static_cast<uint64_t>(v)); 244 } 245 246 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const { 247 return getStatic(getOffset()); 248 } 249 250 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const { 251 return getStatic(getStride()); 252 } 253 254 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const { 255 return getStatic(getSize()); 256 } 257 258 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const { 259 return isDynamic(getOffset()) && isDynamic(getStride()) && 260 isDynamic(getSize()); 261 } 262 263 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) { 264 return isDynamic(v) ? "?" : std::to_string(v); 265 } 266 267 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const { 268 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr"); 269 os << '('; 270 os << getStaticString(getOffset()); 271 os << ", "; 272 os << getStaticString(getSize()); 273 os << ", "; 274 os << getStaticString(getStride()); 275 os << ')'; 276 } 277 278 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { 279 print(printer.getStream()); 280 } 281 282 static ParseResult parseOptionalStaticSlice(int64_t &result, 283 AsmParser &parser) { 284 auto parseResult = parser.parseOptionalInteger(result); 285 if (parseResult.has_value()) { 286 if (parseResult.value().succeeded() && result < 0) { 287 parser.emitError( 288 parser.getCurrentLocation(), 289 "expect positive value or ? for slice offset/size/stride"); 290 return failure(); 291 } 292 return parseResult.value(); 293 } 294 295 // Else, and '?' 
which represented dynamic slice 296 result = SparseTensorDimSliceAttr::kDynamic; 297 return parser.parseQuestion(); 298 } 299 300 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { 301 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic; 302 303 if (failed(parser.parseLParen()) || 304 failed(parseOptionalStaticSlice(offset, parser)) || 305 failed(parser.parseComma()) || 306 failed(parseOptionalStaticSlice(size, parser)) || 307 failed(parser.parseComma()) || 308 failed(parseOptionalStaticSlice(stride, parser)) || 309 failed(parser.parseRParen())) 310 return {}; 311 312 return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(), 313 offset, size, stride); 314 } 315 316 LogicalResult 317 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError, 318 int64_t offset, int64_t size, int64_t stride) { 319 if (!isDynamic(offset) && offset < 0) 320 return emitError() << "expect non-negative value or ? for slice offset"; 321 if (!isDynamic(size) && size <= 0) 322 return emitError() << "expect positive value or ? for slice size"; 323 if (!isDynamic(stride) && stride <= 0) 324 return emitError() << "expect positive value or ? for slice stride"; 325 return success(); 326 } 327 328 SparseTensorEncodingAttr 329 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const { 330 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 331 return SparseTensorEncodingAttr::get( 332 getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(), 333 getCrdWidth(), getExplicitVal(), getImplicitVal()); 334 } 335 336 SparseTensorEncodingAttr 337 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const { 338 return withDimToLvl(enc ? 
enc.getDimToLvl() : AffineMap()); 339 } 340 341 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { 342 return withDimToLvl(AffineMap()); 343 } 344 345 SparseTensorEncodingAttr 346 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth, 347 unsigned crdWidth) const { 348 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 349 return SparseTensorEncodingAttr::get( 350 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth, 351 crdWidth, getExplicitVal(), getImplicitVal()); 352 } 353 354 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { 355 return withBitWidths(0, 0); 356 } 357 358 SparseTensorEncodingAttr 359 SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const { 360 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 361 return SparseTensorEncodingAttr::get( 362 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 363 getCrdWidth(), explicitVal, getImplicitVal()); 364 } 365 366 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const { 367 return withExplicitVal(Attribute()); 368 } 369 370 SparseTensorEncodingAttr 371 SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const { 372 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 373 return SparseTensorEncodingAttr::get( 374 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 375 getCrdWidth(), getExplicitVal(), implicitVal); 376 } 377 378 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const { 379 return withImplicitVal(Attribute()); 380 } 381 382 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( 383 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 384 return SparseTensorEncodingAttr::get( 385 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 386 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices); 387 } 388 389 SparseTensorEncodingAttr 
SparseTensorEncodingAttr::withoutDimSlices() const { 390 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{}); 391 } 392 393 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { 394 ArrayRef<LevelType> lvlTypes = getLvlTypes(); 395 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 396 return std::distance(lastBatch, lvlTypes.rend()); 397 } 398 399 bool SparseTensorEncodingAttr::isAllDense() const { 400 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT); 401 } 402 403 bool SparseTensorEncodingAttr::isAllOrdered() const { 404 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT); 405 } 406 407 Type SparseTensorEncodingAttr::getCrdElemType() const { 408 if (!getImpl()) 409 return nullptr; 410 if (getCrdWidth()) 411 return IntegerType::get(getContext(), getCrdWidth()); 412 return IndexType::get(getContext()); 413 } 414 415 Type SparseTensorEncodingAttr::getPosElemType() const { 416 if (!getImpl()) 417 return nullptr; 418 if (getPosWidth()) 419 return IntegerType::get(getContext(), getPosWidth()); 420 return IndexType::get(getContext()); 421 } 422 423 MemRefType SparseTensorEncodingAttr::getCrdMemRefType( 424 std::optional<ArrayRef<int64_t>> dimShape) const { 425 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 426 return MemRefType::get(shape, getCrdElemType()); 427 } 428 429 MemRefType SparseTensorEncodingAttr::getPosMemRefType( 430 std::optional<ArrayRef<int64_t>> dimShape) const { 431 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 432 return MemRefType::get(shape, getPosElemType()); 433 } 434 435 bool SparseTensorEncodingAttr::isIdentity() const { 436 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity(); 437 } 438 439 bool SparseTensorEncodingAttr::isPermutation() const { 440 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation(); 441 } 442 443 Dimension SparseTensorEncodingAttr::getDimRank() const { 444 assert(getImpl() && "Uninitialized 
SparseTensorEncodingAttr"); 445 const auto dimToLvl = getDimToLvl(); 446 return dimToLvl ? dimToLvl.getNumDims() : getLvlRank(); 447 } 448 449 Level SparseTensorEncodingAttr::getLvlRank() const { 450 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 451 return getLvlTypes().size(); 452 } 453 454 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { 455 if (!getImpl()) 456 return LevelFormat::Batch; 457 assert(l < getLvlRank() && "Level is out of bounds"); 458 return getLvlTypes()[l]; 459 } 460 461 bool SparseTensorEncodingAttr::isSlice() const { 462 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 463 return !getDimSlices().empty(); 464 } 465 466 SparseTensorDimSliceAttr 467 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { 468 assert(isSlice() && "Is not a slice"); 469 const auto dimSlices = getDimSlices(); 470 assert(dim < dimSlices.size() && "Dimension is out of bounds"); 471 return dimSlices[dim]; 472 } 473 474 std::optional<uint64_t> 475 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { 476 return getDimSlice(dim).getStaticOffset(); 477 } 478 479 std::optional<uint64_t> 480 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { 481 return getDimSlice(dim).getStaticStride(); 482 } 483 484 std::optional<uint64_t> 485 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { 486 return getStaticDimSliceOffset(toDim(*this, lvl)); 487 } 488 489 std::optional<uint64_t> 490 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { 491 return getStaticDimSliceStride(toDim(*this, lvl)); 492 } 493 494 SmallVector<int64_t> 495 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape, 496 CrdTransDirectionKind dir) const { 497 if (isIdentity()) 498 return SmallVector<int64_t>(srcShape); 499 500 SmallVector<int64_t> ret; 501 unsigned rank = 502 dir == CrdTransDirectionKind::dim2lvl ? 
getLvlRank() : getDimRank(); 503 ret.reserve(rank); 504 505 if (isPermutation()) { 506 for (unsigned r = 0; r < rank; r++) { 507 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) 508 : toLvl(*this, r); 509 ret.push_back(srcShape[trans]); 510 } 511 return ret; 512 } 513 514 // Handle non-permutation maps. 515 AffineMap transMap = 516 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim(); 517 518 SmallVector<AffineExpr> dimRep; 519 dimRep.reserve(srcShape.size()); 520 for (int64_t sz : srcShape) { 521 if (!ShapedType::isDynamic(sz)) { 522 // Push back the max coordinate for the given dimension/level size. 523 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); 524 } else { 525 // A dynamic size, use a AffineDimExpr to symbolize the value. 526 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext())); 527 } 528 }; 529 530 for (AffineExpr exp : transMap.getResults()) { 531 // Do constant propagation on the affine map. 532 AffineExpr evalExp = 533 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0); 534 // use llvm namespace here to avoid ambiguity 535 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) { 536 ret.push_back(c.getValue() + 1); 537 } else { 538 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp); 539 mod && mod.getKind() == AffineExprKind::Mod) { 540 // We can still infer a static bound for expressions in form 541 // "d % constant" since d % constant \in [0, constant). 542 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) { 543 ret.push_back(bound.getValue()); 544 continue; 545 } 546 } 547 ret.push_back(ShapedType::kDynamic); 548 } 549 } 550 assert(ret.size() == rank); 551 return ret; 552 } 553 554 ValueRange 555 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, 556 ValueRange crds, 557 CrdTransDirectionKind dir) const { 558 if (!getImpl()) 559 return crds; 560 561 SmallVector<Type> retType( 562 dir == CrdTransDirectionKind::lvl2dim ? 
getDimRank() : getLvlRank(), 563 builder.getIndexType()); 564 auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this); 565 return transOp.getOutCrds(); 566 } 567 568 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 569 // Open "<{" part. 570 if (failed(parser.parseLess())) 571 return {}; 572 if (failed(parser.parseLBrace())) 573 return {}; 574 575 // Process the data from the parsed dictionary value into struct-like data. 576 SmallVector<LevelType> lvlTypes; 577 SmallVector<SparseTensorDimSliceAttr> dimSlices; 578 AffineMap dimToLvl = {}; 579 AffineMap lvlToDim = {}; 580 unsigned posWidth = 0; 581 unsigned crdWidth = 0; 582 Attribute explicitVal; 583 Attribute implicitVal; 584 StringRef attrName; 585 SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth", 586 "explicitVal", "implicitVal"}; 587 while (succeeded(parser.parseOptionalKeyword(&attrName))) { 588 // Detect admissible keyword. 589 auto *it = find(keys, attrName); 590 if (it == keys.end()) { 591 parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; 592 return {}; 593 } 594 unsigned keyWordIndex = it - keys.begin(); 595 // Consume the `=` after keys 596 if (failed(parser.parseEqual())) 597 return {}; 598 // Dispatch on keyword. 599 switch (keyWordIndex) { 600 case 0: { // map 601 ir_detail::DimLvlMapParser cParser(parser); 602 auto res = cParser.parseDimLvlMap(); 603 if (failed(res)) 604 return {}; 605 const auto &dlm = *res; 606 607 const Level lvlRank = dlm.getLvlRank(); 608 for (Level lvl = 0; lvl < lvlRank; lvl++) 609 lvlTypes.push_back(dlm.getLvlType(lvl)); 610 611 const Dimension dimRank = dlm.getDimRank(); 612 for (Dimension dim = 0; dim < dimRank; dim++) 613 dimSlices.push_back(dlm.getDimSlice(dim)); 614 // NOTE: the old syntax requires an all-or-nothing approach to 615 // `dimSlices`; therefore, if any slice actually exists then we need 616 // to convert null-DSA into default/nop DSA. 
617 const auto isDefined = [](SparseTensorDimSliceAttr slice) { 618 return static_cast<bool>(slice.getImpl()); 619 }; 620 if (llvm::any_of(dimSlices, isDefined)) { 621 const auto defaultSlice = 622 SparseTensorDimSliceAttr::get(parser.getContext()); 623 for (Dimension dim = 0; dim < dimRank; dim++) 624 if (!isDefined(dimSlices[dim])) 625 dimSlices[dim] = defaultSlice; 626 } else { 627 dimSlices.clear(); 628 } 629 630 dimToLvl = dlm.getDimToLvlMap(parser.getContext()); 631 lvlToDim = dlm.getLvlToDimMap(parser.getContext()); 632 break; 633 } 634 case 1: { // posWidth 635 Attribute attr; 636 if (failed(parser.parseAttribute(attr))) 637 return {}; 638 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 639 if (!intAttr) { 640 parser.emitError(parser.getNameLoc(), 641 "expected an integral position bitwidth"); 642 return {}; 643 } 644 posWidth = intAttr.getInt(); 645 break; 646 } 647 case 2: { // crdWidth 648 Attribute attr; 649 if (failed(parser.parseAttribute(attr))) 650 return {}; 651 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 652 if (!intAttr) { 653 parser.emitError(parser.getNameLoc(), 654 "expected an integral index bitwidth"); 655 return {}; 656 } 657 crdWidth = intAttr.getInt(); 658 break; 659 } 660 case 3: { // explicitVal 661 Attribute attr; 662 if (failed(parser.parseAttribute(attr))) 663 return {}; 664 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { 665 explicitVal = result; 666 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { 667 explicitVal = result; 668 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { 669 explicitVal = result; 670 } else { 671 parser.emitError(parser.getNameLoc(), 672 "expected a numeric value for explicitVal"); 673 return {}; 674 } 675 break; 676 } 677 case 4: { // implicitVal 678 Attribute attr; 679 if (failed(parser.parseAttribute(attr))) 680 return {}; 681 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { 682 implicitVal = result; 683 } else if (auto result = 
llvm::dyn_cast<IntegerAttr>(attr)) { 684 implicitVal = result; 685 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { 686 implicitVal = result; 687 } else { 688 parser.emitError(parser.getNameLoc(), 689 "expected a numeric value for implicitVal"); 690 return {}; 691 } 692 break; 693 } 694 } // switch 695 // Only last item can omit the comma. 696 if (parser.parseOptionalComma().failed()) 697 break; 698 } 699 700 // Close "}>" part. 701 if (failed(parser.parseRBrace())) 702 return {}; 703 if (failed(parser.parseGreater())) 704 return {}; 705 706 // Construct struct-like storage for attribute. 707 if (!lvlToDim || lvlToDim.isEmpty()) { 708 lvlToDim = inferLvlToDim(dimToLvl, parser.getContext()); 709 } 710 return parser.getChecked<SparseTensorEncodingAttr>( 711 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, 712 explicitVal, implicitVal, dimSlices); 713 } 714 715 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 716 auto map = static_cast<AffineMap>(getDimToLvl()); 717 // Empty affine map indicates identity map 718 if (!map) 719 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext()); 720 printer << "<{ map = "; 721 printSymbols(map, printer); 722 printer << '('; 723 printDimensions(map, printer, getDimSlices()); 724 printer << ") -> ("; 725 printLevels(map, printer, getLvlTypes()); 726 printer << ')'; 727 // Print remaining members only for non-default values. 
728 if (getPosWidth()) 729 printer << ", posWidth = " << getPosWidth(); 730 if (getCrdWidth()) 731 printer << ", crdWidth = " << getCrdWidth(); 732 if (getExplicitVal()) { 733 printer << ", explicitVal = " << getExplicitVal(); 734 } 735 if (getImplicitVal()) 736 printer << ", implicitVal = " << getImplicitVal(); 737 printer << " }>"; 738 } 739 740 void SparseTensorEncodingAttr::printSymbols(AffineMap &map, 741 AsmPrinter &printer) const { 742 if (map.getNumSymbols() == 0) 743 return; 744 printer << '['; 745 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++) 746 printer << 's' << i << ", "; 747 if (map.getNumSymbols() >= 1) 748 printer << 's' << map.getNumSymbols() - 1; 749 printer << ']'; 750 } 751 752 void SparseTensorEncodingAttr::printDimensions( 753 AffineMap &map, AsmPrinter &printer, 754 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 755 if (!dimSlices.empty()) { 756 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 757 printer << 'd' << i << " : " << dimSlices[i] << ", "; 758 if (map.getNumDims() >= 1) { 759 printer << 'd' << map.getNumDims() - 1 << " : " 760 << dimSlices[map.getNumDims() - 1]; 761 } 762 } else { 763 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 764 printer << 'd' << i << ", "; 765 if (map.getNumDims() >= 1) 766 printer << 'd' << map.getNumDims() - 1; 767 } 768 } 769 770 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, 771 ArrayRef<LevelType> lvlTypes) const { 772 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { 773 map.getResult(i).print(printer.getStream()); 774 printer << " : " << toMLIRString(lvlTypes[i]) << ", "; 775 } 776 if (map.getNumResults() >= 1) { 777 auto lastIndex = map.getNumResults() - 1; 778 map.getResult(lastIndex).print(printer.getStream()); 779 printer << " : " << toMLIRString(lvlTypes[lastIndex]); 780 } 781 } 782 783 LogicalResult SparseTensorEncodingAttr::verify( 784 function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> 
lvlTypes, 785 AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth, 786 unsigned crdWidth, Attribute explicitVal, Attribute implicitVal, 787 ArrayRef<SparseTensorDimSliceAttr> dimSlices) { 788 if (!acceptBitWidth(posWidth)) 789 return emitError() << "unexpected position bitwidth: " << posWidth; 790 if (!acceptBitWidth(crdWidth)) 791 return emitError() << "unexpected coordinate bitwidth: " << crdWidth; 792 793 // Verify every COO segment. 794 auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); 795 while (it != lvlTypes.end()) { 796 if (it == lvlTypes.begin() || 797 !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) 798 return emitError() << "expected compressed or loose_compressed level " 799 "before singleton level"; 800 801 auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT); 802 if (!std::all_of(it, curCOOEnd, 803 [](LevelType i) { return isSingletonLT(i); })) 804 return emitError() << "expected all singleton lvlTypes " 805 "following a singleton level"; 806 // We can potentially support mixed SoA/AoS singleton levels. 807 if (!std::all_of(it, curCOOEnd, [it](LevelType i) { 808 return it->isa<LevelPropNonDefault::SoA>() == 809 i.isa<LevelPropNonDefault::SoA>(); 810 })) { 811 return emitError() << "expected all singleton lvlTypes stored in the " 812 "same memory layout (SoA vs AoS)."; 813 } 814 it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT); 815 } 816 817 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 818 if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT)) 819 return emitError() << "Batch lvlType can only be leading levels."; 820 821 // SoA property can only be applied on singleton level. 
822 auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) { 823 return lt.isa<LevelPropNonDefault::SoA>(); 824 }); 825 if (llvm::any_of(soaLvls, [](LevelType lt) { 826 return !lt.isa<LevelFormat::Singleton>(); 827 })) { 828 return emitError() << "SoA is only applicable to singleton lvlTypes."; 829 } 830 831 // TODO: audit formats that actually are supported by backend. 832 if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT); 833 it != std::end(lvlTypes)) { 834 if (it != lvlTypes.end() - 1) 835 return emitError() << "expected n_out_of_m to be the last level type"; 836 if (!std::all_of(lvlTypes.begin(), it, 837 [](LevelType i) { return isDenseLT(i); })) 838 return emitError() << "expected all dense lvlTypes " 839 "before a n_out_of_m level"; 840 if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) { 841 if (!isBlockSparsity(dimToLvl)) { 842 return emitError() 843 << "expected 1xm block structure for n_out_of_m level"; 844 } 845 auto sizes = getBlockSize(dimToLvl); 846 unsigned coefficient = 0; 847 for (const auto &elem : sizes) { 848 if (elem != 0) { 849 if (elem != coefficient && coefficient != 0) { 850 return emitError() << "expected only one blocked level " 851 "with the same coefficients"; 852 } 853 coefficient = elem; 854 } 855 } 856 if (coefficient != getM(*it)) { 857 return emitError() << "expected coeffiencts of Affine expressions " 858 "to be equal to m of n_out_of_m level"; 859 } 860 } 861 } 862 // Before we can check that the level-rank is consistent/coherent 863 // across all fields, we need to define it. The source-of-truth for 864 // the `getLvlRank` method is the length of the level-types array, 865 // since it must always be provided and have full rank; therefore we 866 // use that same source-of-truth here. 
867 const Level lvlRank = lvlTypes.size(); 868 if (lvlRank == 0) 869 return emitError() << "expected a non-empty array for lvlTypes"; 870 // We save `dimRank` here because we'll also need it to verify `dimSlices`. 871 const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank; 872 if (dimToLvl) { 873 if (dimToLvl.getNumResults() != lvlRank) 874 return emitError() 875 << "level-rank mismatch between dimToLvl and lvlTypes: " 876 << dimToLvl.getNumResults() << " != " << lvlRank; 877 auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext()); 878 // Symbols can't be inferred but are acceptable. 879 if (!inferRes && dimToLvl.getNumSymbols() == 0) 880 return emitError() << "failed to infer lvlToDim from dimToLvl"; 881 if (lvlToDim && (inferRes != lvlToDim)) 882 return emitError() << "expected lvlToDim to be an inverse of dimToLvl"; 883 if (dimRank > lvlRank) 884 return emitError() << "unexpected dimToLvl mapping from " << dimRank 885 << " to " << lvlRank; 886 } 887 if (!dimSlices.empty()) { 888 if (dimSlices.size() != dimRank) 889 return emitError() 890 << "dimension-rank mismatch between dimSlices and dimToLvl: " 891 << dimSlices.size() << " != " << dimRank; 892 // Compiler support for `dimSlices` currently requires that the two 893 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.) 894 if (dimRank != lvlRank) 895 return emitError() 896 << "dimSlices expected dimension-rank to match level-rank: " 897 << dimRank << " != " << lvlRank; 898 } 899 return success(); 900 } 901 902 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 903 ArrayRef<Size> dimShape, Type elementType, 904 function_ref<InFlightDiagnostic()> emitError) const { 905 // Check structural integrity. In particular, this ensures that the 906 // level-rank is coherent across all the fields. 
907 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(), 908 getPosWidth(), getCrdWidth(), getExplicitVal(), 909 getImplicitVal(), getDimSlices()))) 910 return failure(); 911 // Check integrity with tensor type specifics. In particular, we 912 // need only check that the dimension-rank of the tensor agrees with 913 // the dimension-rank of the encoding. 914 const Dimension dimRank = dimShape.size(); 915 if (dimRank == 0) 916 return emitError() << "expected non-scalar sparse tensor"; 917 if (getDimRank() != dimRank) 918 return emitError() 919 << "dimension-rank mismatch between encoding and tensor shape: " 920 << getDimRank() << " != " << dimRank; 921 if (auto expVal = getExplicitVal()) { 922 Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType(); 923 if (attrType != elementType) { 924 return emitError() << "explicit value type mismatch between encoding and " 925 << "tensor element type: " << attrType 926 << " != " << elementType; 927 } 928 } 929 if (auto impVal = getImplicitVal()) { 930 Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType(); 931 if (attrType != elementType) { 932 return emitError() << "implicit value type mismatch between encoding and " 933 << "tensor element type: " << attrType 934 << " != " << elementType; 935 } 936 // Currently, we only support zero as the implicit value. 
937 auto impFVal = llvm::dyn_cast<FloatAttr>(impVal); 938 auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal); 939 auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal); 940 if ((impFVal && impFVal.getValue().isNonZero()) || 941 (impIntVal && !impIntVal.getValue().isZero()) || 942 (impComplexVal && (impComplexVal.getImag().isNonZero() || 943 impComplexVal.getReal().isNonZero()))) { 944 return emitError() << "implicit value must be zero"; 945 } 946 } 947 return success(); 948 } 949 950 Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const { 951 SmallVector<COOSegment> coo = getCOOSegments(); 952 assert(coo.size() == 1 || coo.empty()); 953 if (!coo.empty() && coo.front().isAoS()) { 954 return coo.front().lvlRange.first; 955 } 956 return getLvlRank(); 957 } 958 959 SmallVector<COOSegment> 960 mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const { 961 SmallVector<COOSegment> ret; 962 if (getLvlRank() <= 1) 963 return ret; 964 965 ArrayRef<LevelType> lts = getLvlTypes(); 966 Level l = 0; 967 while (l < getLvlRank()) { 968 auto lt = lts[l]; 969 if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) { 970 auto cur = lts.begin() + l; 971 auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) { 972 return !lt.isa<LevelFormat::Singleton>(); 973 }); 974 unsigned cooLen = std::distance(cur, end); 975 if (cooLen > 1) { 976 // To support mixed SoA/AoS COO, we should break the segment when the 977 // storage scheme changes, for now we faithfully assume that all 978 // consecutive singleton levels have the same storage format as verified 979 // STEA. 980 ret.push_back(COOSegment{std::make_pair(l, l + cooLen), 981 lts[l + 1].isa<LevelPropNonDefault::SoA>()}); 982 } 983 l += cooLen; 984 } else { 985 l++; 986 } 987 } 988 return ret; 989 } 990 991 //===----------------------------------------------------------------------===// 992 // SparseTensorType Methods. 
993 //===----------------------------------------------------------------------===// 994 995 bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, 996 bool isUnique) const { 997 if (!hasEncoding()) 998 return false; 999 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl)) 1000 return false; 1001 for (Level l = startLvl + 1; l < lvlRank; ++l) 1002 if (!isSingletonLvl(l)) 1003 return false; 1004 // If isUnique is true, then make sure that the last level is unique, 1005 // that is, when lvlRank == 1, the only compressed level is unique, 1006 // and when lvlRank > 1, the last singleton is unique. 1007 return !isUnique || isUniqueLvl(lvlRank - 1); 1008 } 1009 1010 RankedTensorType 1011 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { 1012 SmallVector<LevelType> lvlTypes; 1013 lvlTypes.reserve(lvlRank); 1014 // A non-unique compressed level at beginning (unless this is 1015 // also the last level, then it is unique). 1016 lvlTypes.push_back( 1017 *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); 1018 if (lvlRank > 1) { 1019 // Followed by n-2 non-unique singleton levels. 1020 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, 1021 *buildLevelType(LevelFormat::Singleton, ordered, false)); 1022 // Ends by a unique singleton level. 1023 lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true)); 1024 } 1025 auto enc = SparseTensorEncodingAttr::get( 1026 getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(), 1027 getCrdWidth(), getExplicitVal(), getImplicitVal()); 1028 return RankedTensorType::get(getDimShape(), getElementType(), enc); 1029 } 1030 1031 //===----------------------------------------------------------------------===// 1032 // Convenience Methods. 
1033 //===----------------------------------------------------------------------===// 1034 1035 SparseTensorEncodingAttr 1036 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 1037 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type)) 1038 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding()); 1039 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type)) 1040 return mdtp.getEncoding(); 1041 return nullptr; 1042 } 1043 1044 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl, 1045 MLIRContext *context) { 1046 auto map = static_cast<AffineMap>(dimToLvl); 1047 AffineMap lvlToDim; 1048 // Return an empty lvlToDim when inference is not successful. 1049 if (!map || map.getNumSymbols() != 0) { 1050 lvlToDim = AffineMap(); 1051 } else if (map.isPermutation()) { 1052 lvlToDim = inversePermutation(map); 1053 } else if (isBlockSparsity(map)) { 1054 lvlToDim = inverseBlockSparsity(map, context); 1055 } 1056 return lvlToDim; 1057 } 1058 1059 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl, 1060 MLIRContext *context) { 1061 SmallVector<AffineExpr> lvlExprs; 1062 auto numLvls = dimToLvl.getNumResults(); 1063 lvlExprs.reserve(numLvls); 1064 // lvlExprComponents stores information of the floordiv and mod operations 1065 // applied to the same dimension, so as to build the lvlToDim map. 1066 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents; 1067 for (unsigned i = 0, n = numLvls; i < n; i++) { 1068 auto result = dimToLvl.getResult(i); 1069 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1070 if (result.getKind() == AffineExprKind::FloorDiv) { 1071 // Position of the dimension in dimToLvl. 1072 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 1073 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() && 1074 "expected only one floordiv for each dimension"); 1075 SmallVector<AffineExpr, 3> components; 1076 // Level variable for floordiv. 
1077 components.push_back(getAffineDimExpr(i, context)); 1078 // Multiplier. 1079 components.push_back(binOp.getRHS()); 1080 // Map key is the position of the dimension. 1081 lvlExprComponents[pos] = components; 1082 } else if (result.getKind() == AffineExprKind::Mod) { 1083 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 1084 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() && 1085 "expected floordiv before mod"); 1086 // Add level variable for mod to the same vector 1087 // of the corresponding floordiv. 1088 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context)); 1089 } else { 1090 assert(false && "expected floordiv or mod"); 1091 } 1092 } else { 1093 lvlExprs.push_back(getAffineDimExpr(i, context)); 1094 } 1095 } 1096 // Build lvlExprs from lvlExprComponents. 1097 // For example, for il = i floordiv 2 and ii = i mod 2, the components 1098 // would be [il, 2, ii]. It could be used to build the AffineExpr 1099 // i = il * 2 + ii in lvlToDim. 1100 for (auto &components : lvlExprComponents) { 1101 assert(components.second.size() == 3 && 1102 "expected 3 components to build lvlExprs"); 1103 auto mulOp = getAffineBinaryOpExpr( 1104 AffineExprKind::Mul, components.second[0], components.second[1]); 1105 auto addOp = 1106 getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]); 1107 lvlExprs.push_back(addOp); 1108 } 1109 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context); 1110 } 1111 1112 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { 1113 assert(isBlockSparsity(dimToLvl) && 1114 "expected dimToLvl to be block sparsity for calling getBlockSize"); 1115 SmallVector<unsigned> blockSize; 1116 for (auto result : dimToLvl.getResults()) { 1117 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1118 if (result.getKind() == AffineExprKind::Mod) { 1119 blockSize.push_back( 1120 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue()); 1121 } 1122 } else { 1123 
blockSize.push_back(0); 1124 } 1125 } 1126 return blockSize; 1127 } 1128 1129 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { 1130 if (!dimToLvl) 1131 return false; 1132 std::map<unsigned, int64_t> coeffientMap; 1133 bool hasBlock = false; 1134 for (auto result : dimToLvl.getResults()) { 1135 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1136 // Check for "dim op const". 1137 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS()); 1138 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS()); 1139 if (!dimOp || !conOp || conOp.getValue() <= 0) 1140 return false; 1141 // Inspect "dim / const" or "dim % const". 1142 auto pos = dimOp.getPosition(); 1143 if (binOp.getKind() == AffineExprKind::FloorDiv) { 1144 // Expect only one floordiv for each dimension. 1145 auto [it, inserted] = coeffientMap.try_emplace(pos); 1146 if (!inserted) 1147 return false; 1148 // Record coefficient of the floordiv. 1149 it->second = conOp.getValue(); 1150 } else if (binOp.getKind() == AffineExprKind::Mod) { 1151 // Expect floordiv before mod. 1152 auto it = coeffientMap.find(pos); 1153 if (it == coeffientMap.end()) 1154 return false; 1155 // Expect mod to have the same coefficient as floordiv. 1156 if (conOp.getValue() != it->second) 1157 return false; 1158 hasBlock = true; 1159 } else { 1160 return false; 1161 } 1162 } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) { 1163 auto pos = dimOp.getPosition(); 1164 // Expect dim to be unset. 
1165 if (!coeffientMap.try_emplace(pos, 0).second) 1166 return false; 1167 } else { 1168 return false; 1169 } 1170 } 1171 return hasBlock; 1172 } 1173 1174 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) { 1175 auto hasNonIdentityMap = [](Value v) { 1176 auto stt = tryGetSparseTensorType(v); 1177 return stt && !stt->isIdentity(); 1178 }; 1179 1180 return llvm::any_of(op->getOperands(), hasNonIdentityMap) || 1181 llvm::any_of(op->getResults(), hasNonIdentityMap); 1182 } 1183 1184 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { 1185 if (enc) { 1186 assert(enc.isPermutation() && "Non permutation map not supported"); 1187 if (const auto dimToLvl = enc.getDimToLvl()) 1188 return dimToLvl.getDimPosition(l); 1189 } 1190 return l; 1191 } 1192 1193 Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { 1194 if (enc) { 1195 assert(enc.isPermutation() && "Non permutation map not supported"); 1196 if (const auto lvlToDim = enc.getLvlToDim()) 1197 return lvlToDim.getDimPosition(d); 1198 } 1199 return d; 1200 } 1201 1202 /// We normalized sparse tensor encoding attribute by always using 1203 /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well 1204 /// as other variants) lead to the same storage specifier type, and stripping 1205 /// irrelevant fields that do not alter the sparse tensor memory layout. 1206 static SparseTensorEncodingAttr 1207 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { 1208 SmallVector<LevelType> lts; 1209 for (auto lt : enc.getLvlTypes()) 1210 lts.push_back(lt.stripStorageIrrelevantProperties()); 1211 1212 return SparseTensorEncodingAttr::get( 1213 enc.getContext(), lts, 1214 AffineMap(), // dimToLvl (irrelevant to storage specifier) 1215 AffineMap(), // lvlToDim (irrelevant to storage specifier) 1216 // Always use `index` for memSize and lvlSize instead of reusing 1217 // `getPosWidth` and `getCrdWidth`. 
It allows us to reuse the same SSA 1218 // value for different bitwidth, it also avoids casting between index and 1219 // integer (returned by DimOp) 1220 0, 0, 1221 Attribute(), // explicitVal (irrelevant to storage specifier) 1222 Attribute(), // implicitVal (irrelevant to storage specifier) 1223 enc.getDimSlices()); 1224 } 1225 1226 StorageSpecifierType 1227 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) { 1228 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding)); 1229 } 1230 1231 StorageSpecifierType 1232 StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError, 1233 MLIRContext *ctx, 1234 SparseTensorEncodingAttr encoding) { 1235 return Base::getChecked(emitError, ctx, 1236 getNormalizedEncodingForSpecifier(encoding)); 1237 } 1238 1239 //===----------------------------------------------------------------------===// 1240 // SparseTensorDialect Operations. 1241 //===----------------------------------------------------------------------===// 1242 1243 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) { 1244 return success(lvl < getSparseTensorType(tensor).getLvlRank()); 1245 } 1246 1247 static LogicalResult isMatchingWidth(Value mem, unsigned width) { 1248 const Type etp = getMemRefType(mem).getElementType(); 1249 return success(width == 0 ? 
etp.isIndex() : etp.isInteger(width)); 1250 } 1251 1252 static LogicalResult verifySparsifierGetterSetter( 1253 StorageSpecifierKind mdKind, std::optional<Level> lvl, 1254 TypedValue<StorageSpecifierType> md, Operation *op) { 1255 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) { 1256 return op->emitError( 1257 "redundant level argument for querying value memory size"); 1258 } 1259 1260 const auto enc = md.getType().getEncoding(); 1261 const Level lvlRank = enc.getLvlRank(); 1262 1263 if (mdKind == StorageSpecifierKind::DimOffset || 1264 mdKind == StorageSpecifierKind::DimStride) 1265 if (!enc.isSlice()) 1266 return op->emitError("requested slice data on non-slice tensor"); 1267 1268 if (mdKind != StorageSpecifierKind::ValMemSize) { 1269 if (!lvl) 1270 return op->emitError("missing level argument"); 1271 1272 const Level l = lvl.value(); 1273 if (l >= lvlRank) 1274 return op->emitError("requested level is out of bounds"); 1275 1276 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l)) 1277 return op->emitError( 1278 "requested position memory size on a singleton level"); 1279 } 1280 return success(); 1281 } 1282 1283 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) { 1284 switch (kind) { 1285 case SparseTensorFieldKind::CrdMemRef: 1286 return stt.getCrdType(); 1287 case SparseTensorFieldKind::PosMemRef: 1288 return stt.getPosType(); 1289 case SparseTensorFieldKind::ValMemRef: 1290 return stt.getElementType(); 1291 case SparseTensorFieldKind::StorageSpec: 1292 return nullptr; 1293 } 1294 llvm_unreachable("Unrecognizable FieldKind"); 1295 } 1296 1297 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, 1298 SparseTensorType stt, 1299 RankedTensorType valTp, 1300 TypeRange lvlTps) { 1301 if (requiresStaticShape && !stt.hasStaticDimShape()) 1302 return op->emitError("the sparse-tensor must have static shape"); 1303 if (!stt.hasEncoding()) 1304 return op->emitError("the sparse-tensor must 
have an encoding attribute"); 1305 1306 // Verifies the trailing COO. 1307 Level cooStartLvl = stt.getAoSCOOStart(); 1308 if (cooStartLvl < stt.getLvlRank()) { 1309 // We only supports trailing COO for now, must be the last input. 1310 auto cooTp = llvm::cast<ShapedType>(lvlTps.back()); 1311 // The coordinates should be in shape of <? x rank> 1312 unsigned expCOORank = stt.getLvlRank() - cooStartLvl; 1313 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) { 1314 return op->emitError("input/output trailing COO level-ranks don't match"); 1315 } 1316 } 1317 1318 // Verifies that all types match. 1319 StorageLayout layout(stt.getEncoding()); 1320 if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref 1321 return op->emitError("inconsistent number of fields between input/output"); 1322 1323 unsigned idx = 0; 1324 bool misMatch = false; 1325 layout.foreachField([&idx, &misMatch, stt, valTp, 1326 lvlTps](FieldIndex fid, SparseTensorFieldKind fKind, 1327 Level lvl, LevelType lt) -> bool { 1328 if (fKind == SparseTensorFieldKind::StorageSpec) 1329 return true; 1330 1331 Type inputTp = nullptr; 1332 if (fKind == SparseTensorFieldKind::ValMemRef) { 1333 inputTp = valTp; 1334 } else { 1335 assert(fid == idx && stt.getLvlType(lvl) == lt); 1336 inputTp = lvlTps[idx++]; 1337 } 1338 // The input element type and expected element type should match. 
1339 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType(); 1340 Type expElemTp = getFieldElemType(stt, fKind); 1341 if (inpElemTp != expElemTp) { 1342 misMatch = true; 1343 return false; // to terminate the iteration 1344 } 1345 return true; 1346 }); 1347 1348 if (misMatch) 1349 return op->emitError("input/output element-types don't match"); 1350 return success(); 1351 } 1352 1353 LogicalResult AssembleOp::verify() { 1354 RankedTensorType valuesTp = getValues().getType(); 1355 const auto lvlsTp = getLevels().getTypes(); 1356 const auto resTp = getSparseTensorType(getResult()); 1357 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp); 1358 } 1359 1360 LogicalResult DisassembleOp::verify() { 1361 if (getOutValues().getType() != getRetValues().getType()) 1362 return emitError("output values and return value type mismatch"); 1363 1364 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels())) 1365 if (ot.getType() != rt.getType()) 1366 return emitError("output levels and return levels type mismatch"); 1367 1368 RankedTensorType valuesTp = getRetValues().getType(); 1369 const auto lvlsTp = getRetLevels().getTypes(); 1370 const auto srcTp = getSparseTensorType(getTensor()); 1371 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp); 1372 } 1373 1374 LogicalResult ConvertOp::verify() { 1375 RankedTensorType tp1 = getSource().getType(); 1376 RankedTensorType tp2 = getDest().getType(); 1377 if (tp1.getRank() != tp2.getRank()) 1378 return emitError("unexpected conversion mismatch in rank"); 1379 auto dstEnc = 1380 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding()); 1381 if (dstEnc && dstEnc.isSlice()) 1382 return emitError("cannot convert to a sparse tensor slice"); 1383 1384 auto shape1 = tp1.getShape(); 1385 auto shape2 = tp2.getShape(); 1386 // Accept size matches between the source and the destination type 1387 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. 
?), but reject direct mismatches or 1388 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 1389 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) 1390 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) 1391 return emitError("unexpected conversion mismatch in dimension ") << d; 1392 return success(); 1393 } 1394 1395 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { 1396 if (getType() == getSource().getType()) 1397 return getSource(); 1398 return {}; 1399 } 1400 1401 bool ConvertOp::needsExtraSort() { 1402 SparseTensorType srcStt = getSparseTensorType(getSource()); 1403 SparseTensorType dstStt = getSparseTensorType(getDest()); 1404 1405 // We do not need an extra sort when returning unordered sparse tensors or 1406 // dense tensor since dense tensor support random access. 1407 if (dstStt.isAllDense() || !dstStt.isAllOrdered()) 1408 return false; 1409 1410 if (srcStt.isAllOrdered() && dstStt.isAllOrdered() && 1411 srcStt.hasSameDimToLvl(dstStt)) { 1412 return false; 1413 } 1414 1415 // Source and dest tensors are ordered in different ways. We only do direct 1416 // dense to sparse conversion when the dense input is defined by a sparse 1417 // constant. Note that we can theoretically always directly convert from dense 1418 // inputs by rotating dense loops but it leads to bad cache locality and hurt 1419 // performance. 
1420 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>()) 1421 if (isa<SparseElementsAttr>(constOp.getValue())) 1422 return false; 1423 1424 return true; 1425 } 1426 1427 LogicalResult CrdTranslateOp::verify() { 1428 uint64_t inRank = getEncoder().getLvlRank(); 1429 uint64_t outRank = getEncoder().getDimRank(); 1430 1431 if (getDirection() == CrdTransDirectionKind::dim2lvl) 1432 std::swap(inRank, outRank); 1433 1434 if (inRank != getInCrds().size() || outRank != getOutCrds().size()) 1435 return emitError("Coordinate rank mismatch with encoding"); 1436 1437 return success(); 1438 } 1439 1440 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor, 1441 SmallVectorImpl<OpFoldResult> &results) { 1442 if (getEncoder().isIdentity()) { 1443 results.assign(getInCrds().begin(), getInCrds().end()); 1444 return success(); 1445 } 1446 if (getEncoder().isPermutation()) { 1447 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl 1448 ? getEncoder().getDimToLvl() 1449 : getEncoder().getLvlToDim(); 1450 for (AffineExpr exp : perm.getResults()) 1451 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]); 1452 return success(); 1453 } 1454 1455 // Fuse dim2lvl/lvl2dim pairs. 1456 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>(); 1457 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) { 1458 return v.getDefiningOp() == def; 1459 }); 1460 if (!sameDef) 1461 return failure(); 1462 1463 bool oppositeDir = def.getDirection() != getDirection(); 1464 bool sameOracle = 1465 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl(); 1466 bool sameCount = def.getNumResults() == getInCrds().size(); 1467 if (!oppositeDir || !sameOracle || !sameCount) 1468 return failure(); 1469 1470 // The definition produces the coordinates in the same order as the input 1471 // coordinates. 
1472 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()), 1473 [](auto valuePair) { 1474 auto [lhs, rhs] = valuePair; 1475 return lhs == rhs; 1476 }); 1477 1478 if (!sameOrder) 1479 return failure(); 1480 // l1 = dim2lvl (lvl2dim l0) 1481 // ==> l0 1482 results.append(def.getInCrds().begin(), def.getInCrds().end()); 1483 return success(); 1484 } 1485 1486 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source, 1487 int64_t index) { 1488 Value val = builder.create<arith::ConstantIndexOp>(state.location, index); 1489 return build(builder, state, source, val); 1490 } 1491 1492 LogicalResult LvlOp::verify() { 1493 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) { 1494 auto stt = getSparseTensorType(getSource()); 1495 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank()) 1496 return emitError( 1497 "Level index exceeds the rank of the input sparse tensor"); 1498 } 1499 return success(); 1500 } 1501 1502 std::optional<uint64_t> LvlOp::getConstantLvlIndex() { 1503 return getConstantIntValue(getIndex()); 1504 } 1505 1506 Speculation::Speculatability LvlOp::getSpeculatability() { 1507 auto constantIndex = getConstantLvlIndex(); 1508 if (!constantIndex) 1509 return Speculation::NotSpeculatable; 1510 1511 assert(constantIndex < 1512 cast<RankedTensorType>(getSource().getType()).getRank()); 1513 return Speculation::Speculatable; 1514 } 1515 1516 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) { 1517 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex()); 1518 if (!lvlIndex) 1519 return {}; 1520 1521 Level lvl = lvlIndex.getAPSInt().getZExtValue(); 1522 auto stt = getSparseTensorType(getSource()); 1523 if (lvl >= stt.getLvlRank()) { 1524 // Follows the same convention used by tensor.dim operation. Out of bound 1525 // indices produce undefined behavior but are still valid IR. Don't choke on 1526 // them. 1527 return {}; 1528 } 1529 1530 // Helper lambda to build an IndexAttr. 
1531 auto getIndexAttr = [this](int64_t lvlSz) { 1532 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz)); 1533 }; 1534 1535 SmallVector<Size> lvlShape = stt.getLvlShape(); 1536 if (!ShapedType::isDynamic(lvlShape[lvl])) 1537 return getIndexAttr(lvlShape[lvl]); 1538 1539 return {}; 1540 } 1541 1542 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState, 1543 SparseTensorEncodingAttr dstEnc, Value source) { 1544 auto srcStt = getSparseTensorType(source); 1545 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape(); 1546 SmallVector<int64_t> dstDimShape = 1547 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim); 1548 auto dstTp = 1549 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc); 1550 return build(odsBuilder, odsState, dstTp, source); 1551 } 1552 1553 LogicalResult ReinterpretMapOp::verify() { 1554 auto srcStt = getSparseTensorType(getSource()); 1555 auto dstStt = getSparseTensorType(getDest()); 1556 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes(); 1557 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes(); 1558 1559 if (srcLvlTps.size() != dstLvlTps.size()) 1560 return emitError("Level rank mismatch between source/dest tensors"); 1561 1562 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps)) 1563 if (srcLvlTp != dstLvlTp) 1564 return emitError("Level type mismatch between source/dest tensors"); 1565 1566 if (srcStt.getPosWidth() != dstStt.getPosWidth() || 1567 srcStt.getCrdWidth() != dstStt.getCrdWidth()) { 1568 return emitError("Crd/Pos width mismatch between source/dest tensors"); 1569 } 1570 1571 if (srcStt.getElementType() != dstStt.getElementType()) 1572 return emitError("Element type mismatch between source/dest tensors"); 1573 1574 SmallVector<Size> srcLvlShape = srcStt.getLvlShape(); 1575 SmallVector<Size> dstLvlShape = dstStt.getLvlShape(); 1576 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) { 1577 if (srcLvlSz != dstLvlSz) { 1578 // 
Should we allow one side to be dynamic size, e.g., <?x?> should be 1579 // compatible to <3x4>? For now, we require all the level sizes to be 1580 // *exactly* matched for simplicity. 1581 return emitError("Level size mismatch between source/dest tensors"); 1582 } 1583 } 1584 1585 return success(); 1586 } 1587 1588 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) { 1589 if (getSource().getType() == getDest().getType()) 1590 return getSource(); 1591 1592 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) { 1593 // A -> B, B -> A ==> A 1594 if (def.getSource().getType() == getDest().getType()) 1595 return def.getSource(); 1596 } 1597 return {}; 1598 } 1599 1600 template <typename ToBufferOp> 1601 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, 1602 OpaqueProperties prop, 1603 RegionRange region, 1604 SmallVectorImpl<mlir::Type> &ret) { 1605 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region); 1606 SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); 1607 Type elemTp = nullptr; 1608 bool withStride = false; 1609 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) { 1610 elemTp = stt.getPosType(); 1611 } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> || 1612 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) { 1613 elemTp = stt.getCrdType(); 1614 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>) 1615 withStride = stt.getAoSCOOStart() <= adaptor.getLevel(); 1616 } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) { 1617 elemTp = stt.getElementType(); 1618 } 1619 1620 assert(elemTp && "unhandled operation."); 1621 SmallVector<int64_t> bufShape = stt.getBatchLvlShape(); 1622 bufShape.push_back(ShapedType::kDynamic); 1623 1624 auto layout = withStride ? 
StridedLayoutAttr::StridedLayoutAttr::get( 1625 stt.getContext(), ShapedType::kDynamic, 1626 {ShapedType::kDynamic}) 1627 : StridedLayoutAttr(); 1628 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout)); 1629 return success(); 1630 } 1631 1632 LogicalResult ToPositionsOp::verify() { 1633 auto stt = getSparseTensorType(getTensor()); 1634 if (failed(lvlIsInBounds(getLevel(), getTensor()))) 1635 return emitError("requested level is out of bounds"); 1636 if (failed(isMatchingWidth(getResult(), stt.getPosWidth()))) 1637 return emitError("unexpected type for positions"); 1638 return success(); 1639 } 1640 1641 LogicalResult 1642 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 1643 ValueRange ops, DictionaryAttr attr, 1644 OpaqueProperties prop, RegionRange region, 1645 SmallVectorImpl<mlir::Type> &ret) { 1646 return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret); 1647 } 1648 1649 LogicalResult ToCoordinatesOp::verify() { 1650 auto stt = getSparseTensorType(getTensor()); 1651 if (failed(lvlIsInBounds(getLevel(), getTensor()))) 1652 return emitError("requested level is out of bounds"); 1653 if (failed(isMatchingWidth(getResult(), stt.getCrdWidth()))) 1654 return emitError("unexpected type for coordinates"); 1655 return success(); 1656 } 1657 1658 LogicalResult 1659 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 1660 ValueRange ops, DictionaryAttr attr, 1661 OpaqueProperties prop, RegionRange region, 1662 SmallVectorImpl<mlir::Type> &ret) { 1663 return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret); 1664 } 1665 1666 LogicalResult ToCoordinatesBufferOp::verify() { 1667 auto stt = getSparseTensorType(getTensor()); 1668 if (stt.getAoSCOOStart() >= stt.getLvlRank()) 1669 return emitError("expected sparse tensor with a COO region"); 1670 return success(); 1671 } 1672 1673 LogicalResult ToCoordinatesBufferOp::inferReturnTypes( 1674 MLIRContext *ctx, 
std::optional<Location> loc, ValueRange ops, 1675 DictionaryAttr attr, OpaqueProperties prop, RegionRange region, 1676 SmallVectorImpl<mlir::Type> &ret) { 1677 return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region, 1678 ret); 1679 } 1680 1681 LogicalResult ToValuesOp::verify() { 1682 auto stt = getSparseTensorType(getTensor()); 1683 auto mtp = getMemRefType(getResult()); 1684 if (stt.getElementType() != mtp.getElementType()) 1685 return emitError("unexpected mismatch in element types"); 1686 return success(); 1687 } 1688 1689 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx, 1690 std::optional<Location> loc, 1691 ValueRange ops, DictionaryAttr attr, 1692 OpaqueProperties prop, 1693 RegionRange region, 1694 SmallVectorImpl<mlir::Type> &ret) { 1695 return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret); 1696 } 1697 1698 LogicalResult ToSliceOffsetOp::verify() { 1699 auto rank = getSlice().getType().getRank(); 1700 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) 1701 return emitError("requested dimension out of bound"); 1702 return success(); 1703 } 1704 1705 LogicalResult ToSliceStrideOp::verify() { 1706 auto rank = getSlice().getType().getRank(); 1707 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) 1708 return emitError("requested dimension out of bound"); 1709 return success(); 1710 } 1711 1712 LogicalResult GetStorageSpecifierOp::verify() { 1713 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), 1714 getSpecifier(), getOperation()); 1715 } 1716 1717 template <typename SpecifierOp> 1718 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) { 1719 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>(); 1720 } 1721 1722 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { 1723 const StorageSpecifierKind kind = getSpecifierKind(); 1724 const auto lvl = getLevel(); 1725 for (auto op = getSpecifierSetDef(*this); op; op = 
getSpecifierSetDef(op)) 1726 if (kind == op.getSpecifierKind() && lvl == op.getLevel()) 1727 return op.getValue(); 1728 return {}; 1729 } 1730 1731 LogicalResult SetStorageSpecifierOp::verify() { 1732 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), 1733 getSpecifier(), getOperation()); 1734 } 1735 1736 template <class T> 1737 static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, 1738 const char *regionName, 1739 TypeRange inputTypes, Type outputType) { 1740 unsigned numArgs = region.getNumArguments(); 1741 unsigned expectedNum = inputTypes.size(); 1742 if (numArgs != expectedNum) 1743 return op->emitError() << regionName << " region must have exactly " 1744 << expectedNum << " arguments"; 1745 1746 for (unsigned i = 0; i < numArgs; i++) { 1747 Type typ = region.getArgument(i).getType(); 1748 if (typ != inputTypes[i]) 1749 return op->emitError() << regionName << " region argument " << (i + 1) 1750 << " type mismatch"; 1751 } 1752 Operation *term = region.front().getTerminator(); 1753 YieldOp yield = dyn_cast<YieldOp>(term); 1754 if (!yield) 1755 return op->emitError() << regionName 1756 << " region must end with sparse_tensor.yield"; 1757 if (!yield.hasSingleResult() || 1758 yield.getSingleResult().getType() != outputType) 1759 return op->emitError() << regionName << " region yield type mismatch"; 1760 1761 return success(); 1762 } 1763 1764 LogicalResult BinaryOp::verify() { 1765 NamedAttrList attrs = (*this)->getAttrs(); 1766 Type leftType = getX().getType(); 1767 Type rightType = getY().getType(); 1768 Type outputType = getOutput().getType(); 1769 Region &overlap = getOverlapRegion(); 1770 Region &left = getLeftRegion(); 1771 Region &right = getRightRegion(); 1772 1773 // Check correct number of block arguments and return type for each 1774 // non-empty region. 
1775 if (!overlap.empty()) { 1776 if (failed(verifyNumBlockArgs(this, overlap, "overlap", 1777 TypeRange{leftType, rightType}, outputType))) 1778 return failure(); 1779 } 1780 if (!left.empty()) { 1781 if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, 1782 outputType))) 1783 return failure(); 1784 } else if (getLeftIdentity()) { 1785 if (leftType != outputType) 1786 return emitError("left=identity requires first argument to have the same " 1787 "type as the output"); 1788 } 1789 if (!right.empty()) { 1790 if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType}, 1791 outputType))) 1792 return failure(); 1793 } else if (getRightIdentity()) { 1794 if (rightType != outputType) 1795 return emitError("right=identity requires second argument to have the " 1796 "same type as the output"); 1797 } 1798 return success(); 1799 } 1800 1801 LogicalResult UnaryOp::verify() { 1802 Type inputType = getX().getType(); 1803 Type outputType = getOutput().getType(); 1804 1805 // Check correct number of block arguments and return type for each 1806 // non-empty region. 1807 Region &present = getPresentRegion(); 1808 if (!present.empty()) { 1809 if (failed(verifyNumBlockArgs(this, present, "present", 1810 TypeRange{inputType}, outputType))) 1811 return failure(); 1812 } 1813 Region &absent = getAbsentRegion(); 1814 if (!absent.empty()) { 1815 if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{}, 1816 outputType))) 1817 return failure(); 1818 // Absent branch can only yield invariant values. 
1819 Block *absentBlock = &absent.front(); 1820 Block *parent = getOperation()->getBlock(); 1821 Value absentVal = 1822 cast<YieldOp>(absentBlock->getTerminator()).getSingleResult(); 1823 if (auto arg = dyn_cast<BlockArgument>(absentVal)) { 1824 if (arg.getOwner() == parent) 1825 return emitError("absent region cannot yield linalg argument"); 1826 } else if (Operation *def = absentVal.getDefiningOp()) { 1827 if (!isa<arith::ConstantOp>(def) && 1828 (def->getBlock() == absentBlock || def->getBlock() == parent)) 1829 return emitError("absent region cannot yield locally computed value"); 1830 } 1831 } 1832 return success(); 1833 } 1834 1835 bool ConcatenateOp::needsExtraSort() { 1836 SparseTensorType dstStt = getSparseTensorType(*this); 1837 if (dstStt.isAllDense() || !dstStt.isAllOrdered()) 1838 return false; 1839 1840 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) { 1841 return getSparseTensorType(op).hasSameDimToLvl(dstStt); 1842 }); 1843 // TODO: When conDim != 0, as long as conDim corresponding to the first level 1844 // in all input/output buffers, and all input/output buffers have the same 1845 // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate 1846 // CSC matrices along column). 
1847 bool directLowerable = 1848 allSameOrdered && getDimension() == 0 && dstStt.isIdentity(); 1849 return !directLowerable; 1850 } 1851 1852 LogicalResult ConcatenateOp::verify() { 1853 const auto dstTp = getSparseTensorType(*this); 1854 const Dimension concatDim = getDimension(); 1855 const Dimension dimRank = dstTp.getDimRank(); 1856 1857 if (getInputs().size() <= 1) 1858 return emitError("Need at least two tensors to concatenate."); 1859 1860 if (concatDim >= dimRank) 1861 return emitError(llvm::formatv( 1862 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})", 1863 concatDim, dimRank)); 1864 1865 for (const auto &it : llvm::enumerate(getInputs())) { 1866 const auto i = it.index(); 1867 const auto srcTp = getSparseTensorType(it.value()); 1868 if (srcTp.hasDynamicDimShape()) 1869 return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i)); 1870 const Dimension srcDimRank = srcTp.getDimRank(); 1871 if (srcDimRank != dimRank) 1872 return emitError( 1873 llvm::formatv("Input tensor ${0} has a different rank (rank={1}) " 1874 "from the output tensor (rank={2}).", 1875 i, srcDimRank, dimRank)); 1876 } 1877 1878 for (Dimension d = 0; d < dimRank; d++) { 1879 const Size dstSh = dstTp.getDimShape()[d]; 1880 if (d == concatDim) { 1881 if (!ShapedType::isDynamic(dstSh)) { 1882 // If we reach here, then all inputs have static shapes. So we 1883 // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)` 1884 // to avoid redundant assertions in the loop. 1885 Size sumSz = 0; 1886 for (const auto src : getInputs()) 1887 sumSz += getSparseTensorType(src).getDimShape()[d]; 1888 // If all dimension are statically known, the sum of all the input 1889 // dimensions should be equal to the output dimension. 
1890 if (sumSz != dstSh) 1891 return emitError( 1892 "The concatenation dimension of the output tensor should be the " 1893 "sum of all the concatenation dimensions of the input tensors."); 1894 } 1895 } else { 1896 Size prev = dstSh; 1897 for (const auto src : getInputs()) { 1898 const auto sh = getSparseTensorType(src).getDimShape()[d]; 1899 if (!ShapedType::isDynamic(prev) && sh != prev) 1900 return emitError("All dimensions (expect for the concatenating one) " 1901 "should be equal."); 1902 prev = sh; 1903 } 1904 } 1905 } 1906 1907 return success(); 1908 } 1909 1910 void PushBackOp::build(OpBuilder &builder, OperationState &result, 1911 Value curSize, Value inBuffer, Value value) { 1912 build(builder, result, curSize, inBuffer, value, Value()); 1913 } 1914 1915 LogicalResult PushBackOp::verify() { 1916 if (Value n = getN()) { 1917 std::optional<int64_t> nValue = getConstantIntValue(n); 1918 if (nValue && nValue.value() < 1) 1919 return emitOpError("n must be not less than 1"); 1920 } 1921 return success(); 1922 } 1923 1924 LogicalResult CompressOp::verify() { 1925 const auto stt = getSparseTensorType(getTensor()); 1926 if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size())) 1927 return emitOpError("incorrect number of coordinates"); 1928 return success(); 1929 } 1930 1931 void ForeachOp::build( 1932 OpBuilder &builder, OperationState &result, Value tensor, 1933 ValueRange initArgs, AffineMapAttr order, 1934 function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)> 1935 bodyBuilder) { 1936 build(builder, result, initArgs.getTypes(), tensor, initArgs, order); 1937 // Builds foreach body. 1938 if (!bodyBuilder) 1939 return; 1940 const auto stt = getSparseTensorType(tensor); 1941 const Dimension dimRank = stt.getDimRank(); 1942 1943 // Starts with `dimRank`-many coordinates. 1944 SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType()); 1945 // Followed by one value. 
1946 blockArgTypes.push_back(stt.getElementType()); 1947 // Followed by the reduction variables. 1948 blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end()); 1949 1950 SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc()); 1951 1952 OpBuilder::InsertionGuard guard(builder); 1953 auto ®ion = *result.regions.front(); 1954 Block *bodyBlock = 1955 builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); 1956 bodyBuilder(builder, result.location, 1957 bodyBlock->getArguments().slice(0, dimRank), 1958 bodyBlock->getArguments()[dimRank], 1959 bodyBlock->getArguments().drop_front(dimRank + 1)); 1960 } 1961 1962 LogicalResult ForeachOp::verify() { 1963 const auto t = getSparseTensorType(getTensor()); 1964 const Dimension dimRank = t.getDimRank(); 1965 const auto args = getBody()->getArguments(); 1966 1967 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank()) 1968 return emitError("Level traverse order does not match tensor's level rank"); 1969 1970 if (dimRank + 1 + getInitArgs().size() != args.size()) 1971 return emitError("Unmatched number of arguments in the block"); 1972 1973 if (getNumResults() != getInitArgs().size()) 1974 return emitError("Mismatch in number of init arguments and results"); 1975 1976 if (getResultTypes() != getInitArgs().getTypes()) 1977 return emitError("Mismatch in types of init arguments and results"); 1978 1979 // Cannot mark this const, because the getters aren't. 
1980 auto yield = cast<YieldOp>(getBody()->getTerminator()); 1981 if (yield.getNumOperands() != getNumResults() || 1982 yield.getOperands().getTypes() != getResultTypes()) 1983 return emitError("Mismatch in types of yield values and results"); 1984 1985 const auto iTp = IndexType::get(getContext()); 1986 for (Dimension d = 0; d < dimRank; d++) 1987 if (args[d].getType() != iTp) 1988 return emitError( 1989 llvm::formatv("Expecting Index type for argument at index {0}", d)); 1990 1991 const auto elemTp = t.getElementType(); 1992 const auto valueTp = args[dimRank].getType(); 1993 if (elemTp != valueTp) 1994 return emitError( 1995 llvm::formatv("Unmatched element type between input tensor and " 1996 "block argument, expected:{0}, got: {1}", 1997 elemTp, valueTp)); 1998 return success(); 1999 } 2000 2001 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) { 2002 if (getSparseTensorEncoding(getInputCoo().getType()) == 2003 getSparseTensorEncoding(getResultCoo().getType())) 2004 return getInputCoo(); 2005 2006 return {}; 2007 } 2008 2009 LogicalResult ReorderCOOOp::verify() { 2010 SparseTensorType srcStt = getSparseTensorType(getInputCoo()); 2011 SparseTensorType dstStt = getSparseTensorType(getResultCoo()); 2012 2013 if (!srcStt.isCOOType() || !dstStt.isCOOType()) 2014 return emitError("Expected COO sparse tensors only"); 2015 2016 if (!srcStt.hasSameDimToLvl(dstStt)) 2017 return emitError("Unmatched dim2lvl map between input and result COO"); 2018 2019 if (srcStt.getPosType() != dstStt.getPosType() || 2020 srcStt.getCrdType() != dstStt.getCrdType() || 2021 srcStt.getElementType() != dstStt.getElementType()) 2022 return emitError("Unmatched storage format between input and result COO"); 2023 2024 return success(); 2025 } 2026 2027 LogicalResult ReduceOp::verify() { 2028 Type inputType = getX().getType(); 2029 Region &formula = getRegion(); 2030 return verifyNumBlockArgs(this, formula, "reduce", 2031 TypeRange{inputType, inputType}, inputType); 2032 } 2033 2034 
/// Verifies that the select region takes one argument of the input type and
/// yields a single i1 value.
LogicalResult SelectOp::verify() {
  Region &body = getRegion();
  Type operandTp = getX().getType();
  Type i1Tp = Builder(getContext()).getI1Type();
  return verifyNumBlockArgs(this, body, "select", TypeRange{operandTp}, i1Tp);
}

/// Verifies a sort op: the permutation map must have at least one dim and be
/// a permutation. When `n` is a compile-time constant, the statically known
/// leading extent of the xy buffer must be at least n * (rank(perm_map) + ny)
/// and that of every y buffer at least n.
LogicalResult SortOp::verify() {
  AffineMap permMap = getPermMap();
  uint64_t permRank = permMap.getNumDims();
  // Only a map without dims is rejected here, although the diagnostic text
  // reads "> 1".
  if (permRank < 1)
    return emitError(
        llvm::formatv("Expected rank(perm_map) > 1, got {0}", permRank));

  if (!permMap.isPermutation())
    return emitError(
        llvm::formatv("Expected a permutation map, got {0}", permMap));

  // Buffer sizes can only be checked against a compile-time constant n.
  const std::optional<int64_t> constN = getConstantIntValue(getN());
  if (!constN.has_value())
    return success();

  // Checks the leading extent of `buffer` against `required`; dynamic extents
  // are accepted unconditionally.
  auto verifyLeadingExtent = [&](Value buffer, Size required,
                                 const char *what) -> LogicalResult {
    const Size extent = getMemRefType(buffer).getShape()[0];
    if (ShapedType::isDynamic(extent) || extent >= required)
      return success();
    return emitError(
        llvm::formatv("{0} got {1} < {2}", what, extent, required));
  };

  const uint64_t n = *constN;
  auto nyAttr = getNyAttr();
  const uint64_t ny = nyAttr ? nyAttr.getInt() : 0;
  const uint64_t xyRequired = n * (permRank + ny);
  if (failed(verifyLeadingExtent(
          getXy(), xyRequired,
          "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
    return failure();

  // Stops at (and reports) the first y buffer that is too small.
  const bool allYsFit = llvm::all_of(getYs(), [&](Value y) {
    return succeeded(verifyLeadingExtent(y, n, "Expected dimension(y) >= n"));
  });
  return success(allYsFit);
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Operations.
2084 //===----------------------------------------------------------------------===// 2085 2086 IterSpaceType IteratorType::getIterSpaceType() const { 2087 return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(), 2088 getHiLvl()); 2089 } 2090 2091 IteratorType IterSpaceType::getIteratorType() const { 2092 return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl()); 2093 } 2094 2095 /// Parses a level range in the form "$lo `to` $hi" 2096 /// or simply "$lo" if $hi - $lo = 1 2097 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo, 2098 Level &lvlHi) { 2099 if (parser.parseInteger(lvlLo)) 2100 return failure(); 2101 2102 if (succeeded(parser.parseOptionalKeyword("to"))) { 2103 if (parser.parseInteger(lvlHi)) 2104 return failure(); 2105 } else { 2106 lvlHi = lvlLo + 1; 2107 } 2108 2109 if (lvlHi <= lvlLo) 2110 return parser.emitError(parser.getNameLoc(), 2111 "expect larger level upper bound than lower bound"); 2112 2113 return success(); 2114 } 2115 2116 /// Parses a level range in the form "$lo `to` $hi" 2117 /// or simply "$lo" if $hi - $lo = 1 2118 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr, 2119 IntegerAttr &lvlHiAttr) { 2120 Level lvlLo, lvlHi; 2121 if (parseLevelRange(parser, lvlLo, lvlHi)) 2122 return failure(); 2123 2124 lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo); 2125 lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi); 2126 return success(); 2127 } 2128 2129 /// Prints a level range in the form "$lo `to` $hi" 2130 /// or simply "$lo" if $hi - $lo = 1 2131 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) { 2132 2133 if (lo + 1 == hi) 2134 p << lo; 2135 else 2136 p << lo << " to " << hi; 2137 } 2138 2139 /// Prints a level range in the form "$lo `to` $hi" 2140 /// or simply "$lo" if $hi - $lo = 1 2141 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, 2142 IntegerAttr lvlHi) { 2143 unsigned lo 
= lvlLo.getValue().getZExtValue(); 2144 unsigned hi = lvlHi.getValue().getZExtValue(); 2145 printLevelRange(p, lo, hi); 2146 } 2147 2148 /// Parses a list of `optional` defined list in the form of 2149 /// "(%val0, _, %val1, ...)", where `_` is used to annotate that the 2150 /// corresponding value is not defined (e.g., to represent an undefined 2151 /// coordinate in the sparse iteration space). 2152 static ParseResult parseOptionalDefinedList( 2153 OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, 2154 SmallVectorImpl<OpAsmParser::Argument> &definedArgs, 2155 unsigned maxCnt = std::numeric_limits<unsigned>::max(), 2156 OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) { 2157 unsigned cnt = 0; 2158 ParseResult crdList = 2159 parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult { 2160 if (parser.parseOptionalKeyword("_")) { 2161 if (parser.parseArgument(definedArgs.emplace_back())) 2162 return failure(); 2163 definedSet.set(cnt); 2164 } 2165 cnt += 1; 2166 return success(); 2167 }); 2168 2169 if (cnt > maxCnt) 2170 return parser.emitError(parser.getNameLoc(), 2171 "parsed more value than expected."); 2172 2173 if (failed(crdList)) { 2174 return parser.emitError( 2175 parser.getNameLoc(), 2176 "expecting SSA value or \"_\" for level coordinates"); 2177 } 2178 assert(definedArgs.size() == definedSet.count()); 2179 return success(); 2180 } 2181 2182 static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, 2183 Block::BlockArgListType blocksArgs, 2184 I64BitSet definedSet) { 2185 if (definedSet.empty()) 2186 return; 2187 2188 for (unsigned i = 0; i < size; i++) { 2189 if (definedSet[i]) { 2190 p << blocksArgs.front(); 2191 blocksArgs = blocksArgs.drop_front(); 2192 } else { 2193 p << "_"; 2194 } 2195 if (i != size - 1) 2196 p << ", "; 2197 } 2198 assert(blocksArgs.empty()); 2199 } 2200 2201 static ParseResult 2202 parseUsedCoordList(OpAsmParser &parser, OperationState &state, 2203 
SmallVectorImpl<OpAsmParser::Argument> &coords) { 2204 // Parse "at(%crd0, _, ...)" 2205 I64BitSet crdUsedLvlSet; 2206 if (succeeded(parser.parseOptionalKeyword("at")) && 2207 failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords))) 2208 return failure(); 2209 2210 // Always use IndexType for the coordinate. 2211 for (auto &coord : coords) 2212 coord.type = parser.getBuilder().getIndexType(); 2213 2214 // Set the CrdUsedLvl bitset. 2215 state.addAttribute("crdUsedLvls", 2216 parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet)); 2217 return success(); 2218 } 2219 2220 static ParseResult 2221 parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, 2222 SmallVectorImpl<OpAsmParser::Argument> &iterators, 2223 SmallVectorImpl<OpAsmParser::Argument> &blockArgs) { 2224 SmallVector<OpAsmParser::UnresolvedOperand> spaces; 2225 SmallVector<OpAsmParser::UnresolvedOperand> initArgs; 2226 2227 // Parse "%iters, ... in %spaces, ..." 2228 if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") || 2229 parser.parseOperandList(spaces)) 2230 return failure(); 2231 2232 if (iterators.size() != spaces.size()) 2233 return parser.emitError( 2234 parser.getNameLoc(), 2235 "mismatch in number of sparse iterators and sparse spaces"); 2236 2237 SmallVector<OpAsmParser::Argument> coords; 2238 if (failed(parseUsedCoordList(parser, state, coords))) 2239 return failure(); 2240 size_t numCrds = coords.size(); 2241 2242 // Parse "iter_args(%arg = %init, ...)" 2243 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); 2244 if (hasIterArgs) 2245 if (parser.parseAssignmentList(blockArgs, initArgs)) 2246 return failure(); 2247 2248 blockArgs.append(coords); 2249 2250 SmallVector<Type> iterSpaceTps; 2251 // parse ": sparse_tensor.iter_space -> ret" 2252 if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) 2253 return failure(); 2254 if (iterSpaceTps.size() != spaces.size()) 2255 return parser.emitError(parser.getNameLoc(), 2256 
"mismatch in number of iteration space operands " 2257 "and iteration space types"); 2258 2259 for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) { 2260 IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp); 2261 if (!spaceTp) 2262 return parser.emitError(parser.getNameLoc(), 2263 "expected sparse_tensor.iter_space type for " 2264 "iteration space operands"); 2265 it.type = spaceTp.getIteratorType(); 2266 } 2267 2268 if (hasIterArgs) 2269 if (parser.parseArrowTypeList(state.types)) 2270 return failure(); 2271 2272 // Resolves input operands. 2273 if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), 2274 state.operands)) 2275 return failure(); 2276 2277 if (hasIterArgs) { 2278 // Strip off leading args that used for coordinates. 2279 MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); 2280 if (args.size() != initArgs.size() || args.size() != state.types.size()) { 2281 return parser.emitError( 2282 parser.getNameLoc(), 2283 "mismatch in number of iteration arguments and return values"); 2284 } 2285 2286 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { 2287 it.type = tp; 2288 if (parser.resolveOperand(init, tp, state.operands)) 2289 return failure(); 2290 } 2291 } 2292 return success(); 2293 } 2294 2295 static ParseResult 2296 parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, 2297 SmallVectorImpl<Value> &spacesVals, 2298 SmallVectorImpl<OpAsmParser::Argument> &blockArgs) { 2299 2300 // Parse "(%spaces, ...)" 2301 SmallVector<OpAsmParser::UnresolvedOperand> spaces; 2302 if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren)) 2303 return failure(); 2304 2305 SmallVector<OpAsmParser::Argument> coords; 2306 if (failed(parseUsedCoordList(parser, state, coords))) 2307 return failure(); 2308 size_t numCrds = coords.size(); 2309 2310 // Parse "iter_args(%arg = %init, ...)" 2311 SmallVector<OpAsmParser::UnresolvedOperand> initArgs; 2312 bool hasIterArgs = 
succeeded(parser.parseOptionalKeyword("iter_args")); 2313 if (hasIterArgs) 2314 if (parser.parseAssignmentList(blockArgs, initArgs)) 2315 return failure(); 2316 blockArgs.append(coords); 2317 2318 SmallVector<Type> iterSpaceTps; 2319 // parse ": (sparse_tensor.iter_space, ...) -> ret" 2320 if (parser.parseColon() || parser.parseLParen() || 2321 parser.parseTypeList(iterSpaceTps) || parser.parseRParen()) 2322 return failure(); 2323 2324 if (iterSpaceTps.size() != spaces.size()) 2325 return parser.emitError(parser.getNameLoc(), 2326 "mismatch in number of iteration space operands " 2327 "and iteration space types"); 2328 2329 if (hasIterArgs) 2330 if (parser.parseArrowTypeList(state.types)) 2331 return failure(); 2332 2333 // Resolves input sparse iteration spaces. 2334 if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), 2335 spacesVals)) 2336 return failure(); 2337 state.operands.append(spacesVals); 2338 2339 if (hasIterArgs) { 2340 // Strip off trailing args that used for coordinates. 
2341 MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); 2342 if (args.size() != initArgs.size() || args.size() != state.types.size()) { 2343 return parser.emitError( 2344 parser.getNameLoc(), 2345 "mismatch in number of iteration arguments and return values"); 2346 } 2347 2348 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { 2349 it.type = tp; 2350 if (parser.resolveOperand(init, tp, state.operands)) 2351 return failure(); 2352 } 2353 } 2354 return success(); 2355 } 2356 2357 LogicalResult ExtractIterSpaceOp::inferReturnTypes( 2358 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, 2359 DictionaryAttr attr, OpaqueProperties prop, RegionRange region, 2360 SmallVectorImpl<mlir::Type> &ret) { 2361 2362 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region); 2363 SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); 2364 ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(), 2365 adaptor.getHiLvl())); 2366 return success(); 2367 } 2368 2369 LogicalResult ExtractIterSpaceOp::verify() { 2370 if (getLoLvl() >= getHiLvl()) 2371 return emitOpError("expected smaller level low than level high"); 2372 2373 TypedValue<IteratorType> pIter = getParentIter(); 2374 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) { 2375 return emitOpError( 2376 "parent iterator should be specified iff level lower bound equals 0"); 2377 } 2378 2379 if (pIter) { 2380 IterSpaceType spaceTp = getExtractedSpace().getType(); 2381 if (pIter.getType().getEncoding() != spaceTp.getEncoding()) 2382 return emitOpError( 2383 "mismatch in parent iterator encoding and iteration space encoding."); 2384 2385 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl()) 2386 return emitOpError("parent iterator should be used to extract an " 2387 "iteration space from a consecutive level."); 2388 } 2389 2390 return success(); 2391 } 2392 2393 LogicalResult ExtractValOp::verify() { 2394 auto stt = 
getSparseTensorType(getTensor()); 2395 auto itTp = getIterator().getType(); 2396 2397 if (stt.getEncoding() != itTp.getEncoding()) 2398 return emitOpError("mismatch in tensor encoding and iterator encoding."); 2399 2400 if (stt.getLvlRank() != itTp.getHiLvl()) 2401 return emitOpError("must use last-level iterator to extract values. "); 2402 2403 return success(); 2404 } 2405 2406 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> { 2407 using OpRewritePattern::OpRewritePattern; 2408 2409 LogicalResult matchAndRewrite(IterateOp iterateOp, 2410 PatternRewriter &rewriter) const override { 2411 I64BitSet newUsedLvls(0); 2412 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments()); 2413 for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) { 2414 if (auto crd = iterateOp.getLvlCrd(i)) { 2415 if (crd->getUsers().empty()) 2416 toRemove.set(crd->getArgNumber()); 2417 else 2418 newUsedLvls.set(i); 2419 } 2420 } 2421 2422 // All coordinates are used. 2423 if (toRemove.none()) 2424 return failure(); 2425 2426 rewriter.startOpModification(iterateOp); 2427 iterateOp.setCrdUsedLvls(newUsedLvls); 2428 iterateOp.getBody()->eraseArguments(toRemove); 2429 rewriter.finalizeOpModification(iterateOp); 2430 return success(); 2431 } 2432 }; 2433 2434 void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, 2435 mlir::MLIRContext *context) { 2436 results.add<RemoveUnusedLvlCrds>(context); 2437 } 2438 2439 void IterateOp::build(OpBuilder &builder, OperationState &odsState, 2440 Value iterSpace, ValueRange initArgs) { 2441 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim(); 2442 // All ones. 
2443 I64BitSet set((1 << rank) - 1); 2444 return build(builder, odsState, iterSpace, initArgs, set); 2445 } 2446 2447 void IterateOp::build(OpBuilder &builder, OperationState &odsState, 2448 Value iterSpace, ValueRange initArgs, 2449 I64BitSet crdUsedLvls) { 2450 OpBuilder::InsertionGuard guard(builder); 2451 2452 odsState.addOperands(iterSpace); 2453 odsState.addOperands(initArgs); 2454 odsState.getOrAddProperties<Properties>().crdUsedLvls = 2455 builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls); 2456 Region *bodyRegion = odsState.addRegion(); 2457 odsState.addTypes(initArgs.getTypes()); 2458 Block *bodyBlock = builder.createBlock(bodyRegion); 2459 2460 // Starts with a list of user-provided loop arguments. 2461 for (Value v : initArgs) 2462 bodyBlock->addArgument(v.getType(), v.getLoc()); 2463 2464 // Follows by a list of used coordinates. 2465 for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++) 2466 bodyBlock->addArgument(builder.getIndexType(), odsState.location); 2467 2468 // Ends with sparse iterator 2469 bodyBlock->addArgument( 2470 llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(), 2471 odsState.location); 2472 } 2473 2474 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { 2475 OpAsmParser::Argument iterator; 2476 OpAsmParser::UnresolvedOperand iterSpace; 2477 2478 SmallVector<OpAsmParser::Argument> iters, iterArgs; 2479 if (parseSparseIterateLoop(parser, result, iters, iterArgs)) 2480 return failure(); 2481 if (iters.size() != 1) 2482 return parser.emitError(parser.getNameLoc(), 2483 "expected only one iterator/iteration space"); 2484 2485 iterArgs.append(iters); 2486 Region *body = result.addRegion(); 2487 if (parser.parseRegion(*body, iterArgs)) 2488 return failure(); 2489 2490 IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); 2491 2492 // Parse the optional attribute list. 
2493 if (parser.parseOptionalAttrDict(result.attributes)) 2494 return failure(); 2495 2496 return success(); 2497 } 2498 2499 /// Prints the initialization list in the form of 2500 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>) 2501 /// where 'inner' values are assumed to be region arguments and 'outer' values 2502 /// are regular SSA values. 2503 static void printInitializationList(OpAsmPrinter &p, 2504 Block::BlockArgListType blocksArgs, 2505 ValueRange initializers, 2506 StringRef prefix = "") { 2507 assert(blocksArgs.size() == initializers.size() && 2508 "expected same length of arguments and initializers"); 2509 if (initializers.empty()) 2510 return; 2511 2512 p << prefix << '('; 2513 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) { 2514 p << std::get<0>(it) << " = " << std::get<1>(it); 2515 }); 2516 p << ")"; 2517 } 2518 2519 template <typename SparseLoopOp> 2520 static LogicalResult verifySparseLoopOp(SparseLoopOp op) { 2521 if (op.getInitArgs().size() != op.getNumResults()) { 2522 return op.emitOpError( 2523 "mismatch in number of loop-carried values and defined values"); 2524 } 2525 if (op.getCrdUsedLvls().max() > op.getSpaceDim()) 2526 return op.emitOpError("required out-of-bound coordinates"); 2527 2528 return success(); 2529 } 2530 2531 LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); } 2532 LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); } 2533 2534 void IterateOp::print(OpAsmPrinter &p) { 2535 p << " " << getIterator() << " in " << getIterSpace(); 2536 if (!getCrdUsedLvls().empty()) { 2537 p << " at("; 2538 printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls()); 2539 p << ")"; 2540 } 2541 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args"); 2542 2543 p << " : " << getIterSpace().getType() << " "; 2544 if (!getInitArgs().empty()) 2545 p.printArrowTypeList(getInitArgs().getTypes()); 2546 2547 p << " "; 2548 
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 2549 /*printBlockTerminators=*/!getInitArgs().empty()); 2550 } 2551 2552 LogicalResult IterateOp::verifyRegions() { 2553 if (getIterator().getType() != getIterSpace().getType().getIteratorType()) 2554 return emitOpError("mismatch in iterator and iteration space type"); 2555 if (getNumRegionIterArgs() != getNumResults()) 2556 return emitOpError( 2557 "mismatch in number of basic block args and defined values"); 2558 2559 auto initArgs = getInitArgs(); 2560 auto iterArgs = getRegionIterArgs(); 2561 auto yieldVals = getYieldedValues(); 2562 auto opResults = getResults(); 2563 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), 2564 opResults.size()})) { 2565 return emitOpError() << "number mismatch between iter args and results."; 2566 } 2567 2568 for (auto [i, init, iter, yield, ret] : 2569 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { 2570 if (init.getType() != ret.getType()) 2571 return emitOpError() << "types mismatch between " << i 2572 << "th iter operand and defined value"; 2573 if (iter.getType() != ret.getType()) 2574 return emitOpError() << "types mismatch between " << i 2575 << "th iter region arg and defined value"; 2576 if (yield.getType() != ret.getType()) 2577 return emitOpError() << "types mismatch between " << i 2578 << "th yield value and defined value"; 2579 } 2580 2581 return success(); 2582 } 2583 2584 /// OpInterfaces' methods implemented by IterateOp. 
2585 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; } 2586 2587 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() { 2588 return getInitArgsMutable(); 2589 } 2590 2591 Block::BlockArgListType IterateOp::getRegionIterArgs() { 2592 return getRegion().getArguments().take_front(getNumRegionIterArgs()); 2593 } 2594 2595 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() { 2596 return cast<sparse_tensor::YieldOp>( 2597 getRegion().getBlocks().front().getTerminator()) 2598 .getResultsMutable(); 2599 } 2600 2601 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); } 2602 2603 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { 2604 return getInitArgs(); 2605 } 2606 2607 void IterateOp::getSuccessorRegions(RegionBranchPoint point, 2608 SmallVectorImpl<RegionSuccessor> ®ions) { 2609 // Both the operation itself and the region may be branching into the body 2610 // or back into the operation itself. 2611 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); 2612 // It is possible for loop not to enter the body. 2613 regions.push_back(RegionSuccessor(getResults())); 2614 } 2615 2616 void CoIterateOp::build(OpBuilder &builder, OperationState &odsState, 2617 ValueRange iterSpaces, ValueRange initArgs, 2618 unsigned numCases) { 2619 unsigned rank = 2620 cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim(); 2621 // All ones. 2622 I64BitSet set((1 << rank) - 1); 2623 // Generates all-zero case bits (they only serve as placeholders), which are 2624 // supposed to be overriden later. We need to preallocate all the regions as 2625 // mlir::Region cannot be dynamically added later after the operation is 2626 // created. 
2627 SmallVector<int64_t> caseBits(numCases, 0); 2628 ArrayAttr cases = builder.getI64ArrayAttr(caseBits); 2629 return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces, 2630 initArgs, set, cases, 2631 /*caseRegionsCount=*/numCases); 2632 } 2633 2634 ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) { 2635 2636 SmallVector<Value> spaces; 2637 // The block argument list of each regions, it is arranged in the order of 2638 // ([used coordinate list], [loop iterations args], [sparse iterator list]). 2639 SmallVector<OpAsmParser::Argument> blockArgs; 2640 if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs)) 2641 return failure(); 2642 2643 result.addAttribute("operandSegmentSizes", 2644 parser.getBuilder().getDenseI32ArrayAttr( 2645 {static_cast<int32_t>(spaces.size()), 2646 static_cast<int32_t>(result.types.size())})); 2647 2648 SmallVector<Attribute> cases; 2649 while (succeeded(parser.parseOptionalKeyword("case"))) { 2650 // Parse one region per case. 2651 I64BitSet definedItSet; 2652 SmallVector<OpAsmParser::Argument> definedIts; 2653 if (parseOptionalDefinedList(parser, result, definedItSet, definedIts, 2654 spaces.size(), OpAsmParser::Delimiter::None)) 2655 return failure(); 2656 2657 cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet)); 2658 2659 for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) { 2660 // Resolve the iterator type based on the iteration space type. 
2661 auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType()); 2662 definedIts[i].type = spaceTp.getIteratorType(); 2663 } 2664 definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end()); 2665 Region *body = result.addRegion(); 2666 if (parser.parseRegion(*body, definedIts)) 2667 return failure(); 2668 2669 CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); 2670 } 2671 2672 result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases)); 2673 2674 // Parse the optional attribute list. 2675 if (parser.parseOptionalAttrDict(result.attributes)) 2676 return failure(); 2677 2678 return success(); 2679 } 2680 2681 void CoIterateOp::print(OpAsmPrinter &p) { 2682 p << " ("; 2683 llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; }); 2684 p << ")"; 2685 2686 if (!getCrdUsedLvls().empty()) { 2687 p << " at("; 2688 printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls()); 2689 p << ")"; 2690 } 2691 2692 printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args"); 2693 2694 p << " : (" << getIterSpaces().getTypes() << ")"; 2695 if (!getInitArgs().empty()) 2696 p.printArrowTypeList(getInitArgs().getTypes()); 2697 2698 for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) { 2699 p.printNewline(); 2700 p << "case "; 2701 printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx), 2702 getRegionDefinedSpace(idx)); 2703 p << " "; 2704 p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false, 2705 /*printBlockTerminators=*/!getInitArgs().empty()); 2706 } 2707 } 2708 2709 ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) { 2710 return cast<sparse_tensor::YieldOp>( 2711 getRegion(regionIdx).getBlocks().front().getTerminator()) 2712 .getResults(); 2713 } 2714 2715 LogicalResult CoIterateOp::verifyRegions() { 2716 for (unsigned r = 0, e = getNumRegions(); r < e; r++) { 2717 if (getNumRegionIterArgs() != getNumResults()) 2718 return 
emitOpError( 2719 "mismatch in number of basic block args and defined values"); 2720 2721 auto initArgs = getInitArgs(); 2722 auto iterArgs = getRegionIterArgs(r); 2723 auto yieldVals = getYieldedValues(r); 2724 auto opResults = getResults(); 2725 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), 2726 opResults.size()})) { 2727 return emitOpError() 2728 << "number mismatch between iter args and results on " << r 2729 << "th region"; 2730 } 2731 2732 for (auto [i, init, iter, yield, ret] : 2733 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { 2734 if (init.getType() != ret.getType()) 2735 return emitOpError() 2736 << "types mismatch between " << i 2737 << "th iter operand and defined value on " << r << "th region"; 2738 if (iter.getType() != ret.getType()) 2739 return emitOpError() << "types mismatch between " << i 2740 << "th iter region arg and defined value on " << r 2741 << "th region"; 2742 if (yield.getType() != ret.getType()) 2743 return emitOpError() 2744 << "types mismatch between " << i 2745 << "th yield value and defined value on " << r << "th region"; 2746 } 2747 } 2748 2749 auto cases = getRegionDefinedSpaces(); 2750 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end()); 2751 if (set.size() != getNumRegions()) 2752 return emitOpError("contains duplicated cases."); 2753 2754 return success(); 2755 } 2756 2757 SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) { 2758 SmallVector<Region *> ret; 2759 I64BitSet caseBit = getRegionDefinedSpace(regionIdx); 2760 for (Region &r : getCaseRegions()) 2761 if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit)) 2762 ret.push_back(&r); 2763 2764 return ret; 2765 } 2766 2767 //===----------------------------------------------------------------------===// 2768 // Sparse Tensor Dialect Setups. 
2769 //===----------------------------------------------------------------------===// 2770 2771 /// Materialize a single constant operation from a given attribute value with 2772 /// the desired resultant type. 2773 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder, 2774 Attribute value, Type type, 2775 Location loc) { 2776 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc)) 2777 return op; 2778 return nullptr; 2779 } 2780 2781 namespace { 2782 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface { 2783 using OpAsmDialectInterface::OpAsmDialectInterface; 2784 2785 AliasResult getAlias(Attribute attr, raw_ostream &os) const override { 2786 if (isa<SparseTensorEncodingAttr>(attr)) { 2787 os << "sparse"; 2788 return AliasResult::OverridableAlias; 2789 } 2790 return AliasResult::NoAlias; 2791 } 2792 }; 2793 } // namespace 2794 2795 void SparseTensorDialect::initialize() { 2796 addInterface<SparseTensorAsmDialectInterface>(); 2797 addAttributes< 2798 #define GET_ATTRDEF_LIST 2799 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 2800 >(); 2801 addTypes< 2802 #define GET_TYPEDEF_LIST 2803 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" 2804 >(); 2805 addOperations< 2806 #define GET_OP_LIST 2807 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 2808 >(); 2809 declarePromisedInterfaces< 2810 bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp, 2811 NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp, 2812 ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>(); 2813 } 2814 2815 #define GET_OP_CLASSES 2816 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 2817 2818 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 2819