1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "Detail/DimLvlMapParser.h" 12 13 #include "mlir/Dialect/SparseTensor/IR/Enums.h" 14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" 16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 17 18 #include "mlir/Dialect/Arith/IR/Arith.h" 19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 20 #include "mlir/Dialect/Complex/IR/Complex.h" 21 #include "mlir/Dialect/Utils/StaticValueUtils.h" 22 #include "mlir/IR/Builders.h" 23 #include "mlir/IR/DialectImplementation.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "llvm/ADT/Bitset.h" 28 #include "llvm/ADT/TypeSwitch.h" 29 #include "llvm/Support/FormatVariadic.h" 30 31 #define GET_ATTRDEF_CLASSES 32 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 33 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc" 34 35 // Forward declarations, following custom print/parsing methods are referenced 36 // by the generated code for SparseTensorTypes.td. 37 static mlir::ParseResult parseLevelRange(mlir::AsmParser &, 38 mlir::sparse_tensor::Level &, 39 mlir::sparse_tensor::Level &); 40 static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, 41 mlir::sparse_tensor::Level); 42 43 #define GET_TYPEDEF_CLASSES 44 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" 45 46 using namespace mlir; 47 using namespace mlir::sparse_tensor; 48 49 // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as 50 // well. 51 namespace mlir::sparse_tensor { 52 llvm::hash_code hash_value(LevelType lt) { 53 return llvm::hash_value(static_cast<uint64_t>(lt)); 54 } 55 } // namespace mlir::sparse_tensor 56 57 //===----------------------------------------------------------------------===// 58 // Local Convenience Methods. 59 //===----------------------------------------------------------------------===// 60 61 static constexpr bool acceptBitWidth(unsigned bitWidth) { 62 switch (bitWidth) { 63 case 0: 64 case 8: 65 case 16: 66 case 32: 67 case 64: 68 return true; 69 default: 70 return false; 71 } 72 } 73 74 static SmallVector<Size> 75 getSparseFieldShape(const SparseTensorEncodingAttr enc, 76 std::optional<ArrayRef<int64_t>> dimShape) { 77 assert(enc); 78 // With only encoding, we can not determine the static shape for leading 79 // batch levels, we therefore return a dynamic shape memref instead. 80 SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic); 81 if (dimShape.has_value()) { 82 // If the actual tensor shape is provided, we can then refine the leading 83 // batch dimension. 84 SmallVector<int64_t> lvlShape = 85 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl); 86 memrefShape.assign(lvlShape.begin(), 87 lvlShape.begin() + enc.getBatchLvlRank()); 88 } 89 // Another dynamic dimension to store the sparse level. 
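  // E.g., with two leading batch levels that map identically to dimensions of
  // size 3 and 4, the final field shape is [3, 4, ?]; without a known dim
  // shape it is [?, ?, ?].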
  memrefShape.push_back(ShapedType::kDynamic);
  return memrefShape;
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO, all singletons are fused into one memref. Skip the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO, each singleton level has one memref.
        l++;
      }
      // Expire handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non-COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
}

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  assert(stt.hasEncoding());

  SmallVector<int64_t> memrefShape =
      getSparseFieldShape(stt.getEncoding(), stt.getDimShape());

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos> positions
  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
  // memref<[batch] x ? x crd> coordinates
  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
  // memref<[batch] x ?
x eltType> values 165 const Type valMemType = MemRefType::get(memrefShape, stt.getElementType()); 166 167 StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType, 168 callback](FieldIndex fieldIdx, 169 SparseTensorFieldKind fieldKind, 170 Level lvl, LevelType lt) -> bool { 171 switch (fieldKind) { 172 case SparseTensorFieldKind::StorageSpec: 173 return callback(specType, fieldIdx, fieldKind, lvl, lt); 174 case SparseTensorFieldKind::PosMemRef: 175 return callback(posMemType, fieldIdx, fieldKind, lvl, lt); 176 case SparseTensorFieldKind::CrdMemRef: 177 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt); 178 case SparseTensorFieldKind::ValMemRef: 179 return callback(valMemType, fieldIdx, fieldKind, lvl, lt); 180 }; 181 llvm_unreachable("unrecognized field kind"); 182 }); 183 } 184 185 unsigned StorageLayout::getNumFields() const { 186 unsigned numFields = 0; 187 foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level, 188 LevelType) -> bool { 189 numFields++; 190 return true; 191 }); 192 return numFields; 193 } 194 195 unsigned StorageLayout::getNumDataFields() const { 196 unsigned numFields = 0; // one value memref 197 foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level, 198 LevelType) -> bool { 199 if (fidx >= kDataFieldStartingIdx) 200 numFields++; 201 return true; 202 }); 203 numFields -= 1; // the last field is StorageSpecifier 204 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1); 205 return numFields; 206 } 207 208 std::pair<FieldIndex, unsigned> 209 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind, 210 std::optional<Level> lvl) const { 211 FieldIndex fieldIdx = kInvalidFieldIndex; 212 unsigned stride = 1; 213 if (kind == SparseTensorFieldKind::CrdMemRef) { 214 assert(lvl.has_value()); 215 const Level cooStart = enc.getAoSCOOStart(); 216 const Level lvlRank = enc.getLvlRank(); 217 if (lvl.value() >= cooStart && lvl.value() < lvlRank) { 218 lvl = cooStart; 219 stride = lvlRank - cooStart; 220 } 221 } 222 foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx, 223 SparseTensorFieldKind fKind, Level fLvl, 224 LevelType lt) -> bool { 225 if ((lvl && fLvl == lvl.value() && kind == fKind) || 226 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { 227 fieldIdx = fIdx; 228 // Returns false to break the iteration. 229 return false; 230 } 231 return true; 232 }); 233 assert(fieldIdx != kInvalidFieldIndex); 234 return std::pair<FieldIndex, unsigned>(fieldIdx, stride); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // SparseTensorDialect Attribute Methods. 239 //===----------------------------------------------------------------------===// 240 241 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) { 242 return isDynamic(v) ? std::nullopt 243 : std::make_optional(static_cast<uint64_t>(v)); 244 } 245 246 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const { 247 return getStatic(getOffset()); 248 } 249 250 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const { 251 return getStatic(getStride()); 252 } 253 254 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const { 255 return getStatic(getSize()); 256 } 257 258 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const { 259 return isDynamic(getOffset()) && isDynamic(getStride()) && 260 isDynamic(getSize()); 261 } 262 263 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) { 264 return isDynamic(v) ? "?" 
: std::to_string(v); 265 } 266 267 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const { 268 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr"); 269 os << '('; 270 os << getStaticString(getOffset()); 271 os << ", "; 272 os << getStaticString(getSize()); 273 os << ", "; 274 os << getStaticString(getStride()); 275 os << ')'; 276 } 277 278 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { 279 print(printer.getStream()); 280 } 281 282 static ParseResult parseOptionalStaticSlice(int64_t &result, 283 AsmParser &parser) { 284 auto parseResult = parser.parseOptionalInteger(result); 285 if (parseResult.has_value()) { 286 if (parseResult.value().succeeded() && result < 0) { 287 parser.emitError( 288 parser.getCurrentLocation(), 289 "expect positive value or ? for slice offset/size/stride"); 290 return failure(); 291 } 292 return parseResult.value(); 293 } 294 295 // Else, and '?' which represented dynamic slice 296 result = SparseTensorDimSliceAttr::kDynamic; 297 return parser.parseQuestion(); 298 } 299 300 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { 301 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic; 302 303 if (failed(parser.parseLParen()) || 304 failed(parseOptionalStaticSlice(offset, parser)) || 305 failed(parser.parseComma()) || 306 failed(parseOptionalStaticSlice(size, parser)) || 307 failed(parser.parseComma()) || 308 failed(parseOptionalStaticSlice(stride, parser)) || 309 failed(parser.parseRParen())) 310 return {}; 311 312 return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(), 313 offset, size, stride); 314 } 315 316 LogicalResult 317 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError, 318 int64_t offset, int64_t size, int64_t stride) { 319 if (!isDynamic(offset) && offset < 0) 320 return emitError() << "expect non-negative value or ? for slice offset"; 321 if (!isDynamic(size) && size <= 0) 322 return emitError() << "expect positive value or ? for slice size"; 323 if (!isDynamic(stride) && stride <= 0) 324 return emitError() << "expect positive value or ? for slice stride"; 325 return success(); 326 } 327 328 SparseTensorEncodingAttr 329 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const { 330 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 331 return SparseTensorEncodingAttr::get( 332 getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(), 333 getCrdWidth(), getExplicitVal(), getImplicitVal()); 334 } 335 336 SparseTensorEncodingAttr 337 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const { 338 return withDimToLvl(enc ? 
enc.getDimToLvl() : AffineMap()); 339 } 340 341 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { 342 return withDimToLvl(AffineMap()); 343 } 344 345 SparseTensorEncodingAttr 346 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth, 347 unsigned crdWidth) const { 348 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 349 return SparseTensorEncodingAttr::get( 350 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth, 351 crdWidth, getExplicitVal(), getImplicitVal()); 352 } 353 354 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { 355 return withBitWidths(0, 0); 356 } 357 358 SparseTensorEncodingAttr 359 SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const { 360 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 361 return SparseTensorEncodingAttr::get( 362 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 363 getCrdWidth(), explicitVal, getImplicitVal()); 364 } 365 366 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const { 367 return withExplicitVal(Attribute()); 368 } 369 370 SparseTensorEncodingAttr 371 SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const { 372 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 373 return SparseTensorEncodingAttr::get( 374 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 375 getCrdWidth(), getExplicitVal(), implicitVal); 376 } 377 378 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const { 379 return withImplicitVal(Attribute()); 380 } 381 382 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( 383 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 384 return SparseTensorEncodingAttr::get( 385 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 386 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices); 387 } 388 389 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { 390 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{}); 391 } 392 393 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { 394 ArrayRef<LevelType> lvlTypes = getLvlTypes(); 395 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 396 return std::distance(lastBatch, lvlTypes.rend()); 397 } 398 399 bool SparseTensorEncodingAttr::isAllDense() const { 400 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT); 401 } 402 403 bool SparseTensorEncodingAttr::isAllOrdered() const { 404 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT); 405 } 406 407 Type SparseTensorEncodingAttr::getCrdElemType() const { 408 if (!getImpl()) 409 return nullptr; 410 if (getCrdWidth()) 411 return IntegerType::get(getContext(), getCrdWidth()); 412 return IndexType::get(getContext()); 413 } 414 415 Type SparseTensorEncodingAttr::getPosElemType() const { 416 if (!getImpl()) 417 return nullptr; 418 if (getPosWidth()) 419 return IntegerType::get(getContext(), getPosWidth()); 420 return IndexType::get(getContext()); 421 } 422 423 MemRefType SparseTensorEncodingAttr::getCrdMemRefType( 424 std::optional<ArrayRef<int64_t>> dimShape) const { 425 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 426 return MemRefType::get(shape, getCrdElemType()); 427 } 428 429 MemRefType SparseTensorEncodingAttr::getPosMemRefType( 430 std::optional<ArrayRef<int64_t>> dimShape) const { 431 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 432 return 
MemRefType::get(shape, getPosElemType()); 433 } 434 435 bool SparseTensorEncodingAttr::isIdentity() const { 436 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity(); 437 } 438 439 bool SparseTensorEncodingAttr::isPermutation() const { 440 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation(); 441 } 442 443 Dimension SparseTensorEncodingAttr::getDimRank() const { 444 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 445 const auto dimToLvl = getDimToLvl(); 446 return dimToLvl ? dimToLvl.getNumDims() : getLvlRank(); 447 } 448 449 Level SparseTensorEncodingAttr::getLvlRank() const { 450 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 451 return getLvlTypes().size(); 452 } 453 454 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { 455 if (!getImpl()) 456 return LevelFormat::Batch; 457 assert(l < getLvlRank() && "Level is out of bounds"); 458 return getLvlTypes()[l]; 459 } 460 461 bool SparseTensorEncodingAttr::isSlice() const { 462 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 463 return !getDimSlices().empty(); 464 } 465 466 SparseTensorDimSliceAttr 467 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { 468 assert(isSlice() && "Is not a slice"); 469 const auto dimSlices = getDimSlices(); 470 assert(dim < dimSlices.size() && "Dimension is out of bounds"); 471 return dimSlices[dim]; 472 } 473 474 std::optional<uint64_t> 475 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { 476 return getDimSlice(dim).getStaticOffset(); 477 } 478 479 std::optional<uint64_t> 480 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { 481 return getDimSlice(dim).getStaticStride(); 482 } 483 484 std::optional<uint64_t> 485 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { 486 return getStaticDimSliceOffset(toDim(*this, lvl)); 487 } 488 489 std::optional<uint64_t> 490 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { 491 return getStaticDimSliceStride(toDim(*this, lvl)); 492 } 493 494 SmallVector<int64_t> 495 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape, 496 CrdTransDirectionKind dir) const { 497 if (isIdentity()) 498 return SmallVector<int64_t>(srcShape); 499 500 SmallVector<int64_t> ret; 501 unsigned rank = 502 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank(); 503 ret.reserve(rank); 504 505 if (isPermutation()) { 506 for (unsigned r = 0; r < rank; r++) { 507 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) 508 : toLvl(*this, r); 509 ret.push_back(srcShape[trans]); 510 } 511 return ret; 512 } 513 514 // Handle non-permutation maps. 515 AffineMap transMap = 516 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim(); 517 518 SmallVector<AffineExpr> dimRep; 519 dimRep.reserve(srcShape.size()); 520 for (int64_t sz : srcShape) { 521 if (!ShapedType::isDynamic(sz)) { 522 // Push back the max coordinate for the given dimension/level size. 523 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); 524 } else { 525 // A dynamic size, use a AffineDimExpr to symbolize the value. 526 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext())); 527 } 528 }; 529 530 for (AffineExpr exp : transMap.getResults()) { 531 // Do constant propagation on the affine map. 
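    // E.g., a static dimension size 10 flowing through `d0 floordiv 2`
    // evaluates to the constant 4 (the max coordinate), which yields level
    // size 5 below.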
532 AffineExpr evalExp = 533 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0); 534 // use llvm namespace here to avoid ambiguity 535 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) { 536 ret.push_back(c.getValue() + 1); 537 } else { 538 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp); 539 mod && mod.getKind() == AffineExprKind::Mod) { 540 // We can still infer a static bound for expressions in form 541 // "d % constant" since d % constant \in [0, constant). 542 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) { 543 ret.push_back(bound.getValue()); 544 continue; 545 } 546 } 547 ret.push_back(ShapedType::kDynamic); 548 } 549 } 550 assert(ret.size() == rank); 551 return ret; 552 } 553 554 ValueRange 555 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, 556 ValueRange crds, 557 CrdTransDirectionKind dir) const { 558 if (!getImpl()) 559 return crds; 560 561 SmallVector<Type> retType( 562 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(), 563 builder.getIndexType()); 564 auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this); 565 return transOp.getOutCrds(); 566 } 567 568 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 569 // Open "<{" part. 570 if (failed(parser.parseLess())) 571 return {}; 572 if (failed(parser.parseLBrace())) 573 return {}; 574 575 // Process the data from the parsed dictionary value into struct-like data. 576 SmallVector<LevelType> lvlTypes; 577 SmallVector<SparseTensorDimSliceAttr> dimSlices; 578 AffineMap dimToLvl = {}; 579 AffineMap lvlToDim = {}; 580 unsigned posWidth = 0; 581 unsigned crdWidth = 0; 582 Attribute explicitVal; 583 Attribute implicitVal; 584 StringRef attrName; 585 SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth", 586 "explicitVal", "implicitVal"}; 587 while (succeeded(parser.parseOptionalKeyword(&attrName))) { 588 // Detect admissible keyword. 589 auto *it = find(keys, attrName); 590 if (it == keys.end()) { 591 parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; 592 return {}; 593 } 594 unsigned keyWordIndex = it - keys.begin(); 595 // Consume the `=` after keys 596 if (failed(parser.parseEqual())) 597 return {}; 598 // Dispatch on keyword. 599 switch (keyWordIndex) { 600 case 0: { // map 601 ir_detail::DimLvlMapParser cParser(parser); 602 auto res = cParser.parseDimLvlMap(); 603 if (failed(res)) 604 return {}; 605 const auto &dlm = *res; 606 607 const Level lvlRank = dlm.getLvlRank(); 608 for (Level lvl = 0; lvl < lvlRank; lvl++) 609 lvlTypes.push_back(dlm.getLvlType(lvl)); 610 611 const Dimension dimRank = dlm.getDimRank(); 612 for (Dimension dim = 0; dim < dimRank; dim++) 613 dimSlices.push_back(dlm.getDimSlice(dim)); 614 // NOTE: the old syntax requires an all-or-nothing approach to 615 // `dimSlices`; therefore, if any slice actually exists then we need 616 // to convert null-DSA into default/nop DSA. 
617 const auto isDefined = [](SparseTensorDimSliceAttr slice) { 618 return static_cast<bool>(slice.getImpl()); 619 }; 620 if (llvm::any_of(dimSlices, isDefined)) { 621 const auto defaultSlice = 622 SparseTensorDimSliceAttr::get(parser.getContext()); 623 for (Dimension dim = 0; dim < dimRank; dim++) 624 if (!isDefined(dimSlices[dim])) 625 dimSlices[dim] = defaultSlice; 626 } else { 627 dimSlices.clear(); 628 } 629 630 dimToLvl = dlm.getDimToLvlMap(parser.getContext()); 631 lvlToDim = dlm.getLvlToDimMap(parser.getContext()); 632 break; 633 } 634 case 1: { // posWidth 635 Attribute attr; 636 if (failed(parser.parseAttribute(attr))) 637 return {}; 638 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 639 if (!intAttr) { 640 parser.emitError(parser.getNameLoc(), 641 "expected an integral position bitwidth"); 642 return {}; 643 } 644 posWidth = intAttr.getInt(); 645 break; 646 } 647 case 2: { // crdWidth 648 Attribute attr; 649 if (failed(parser.parseAttribute(attr))) 650 return {}; 651 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 652 if (!intAttr) { 653 parser.emitError(parser.getNameLoc(), 654 "expected an integral index bitwidth"); 655 return {}; 656 } 657 crdWidth = intAttr.getInt(); 658 break; 659 } 660 case 3: { // explicitVal 661 Attribute attr; 662 if (failed(parser.parseAttribute(attr))) 663 return {}; 664 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { 665 explicitVal = result; 666 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { 667 explicitVal = result; 668 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { 669 explicitVal = result; 670 } else { 671 parser.emitError(parser.getNameLoc(), 672 "expected a numeric value for explicitVal"); 673 return {}; 674 } 675 break; 676 } 677 case 4: { // implicitVal 678 Attribute attr; 679 if (failed(parser.parseAttribute(attr))) 680 return {}; 681 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { 682 implicitVal = result; 683 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { 684 implicitVal = result; 685 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { 686 implicitVal = result; 687 } else { 688 parser.emitError(parser.getNameLoc(), 689 "expected a numeric value for implicitVal"); 690 return {}; 691 } 692 break; 693 } 694 } // switch 695 // Only last item can omit the comma. 696 if (parser.parseOptionalComma().failed()) 697 break; 698 } 699 700 // Close "}>" part. 701 if (failed(parser.parseRBrace())) 702 return {}; 703 if (failed(parser.parseGreater())) 704 return {}; 705 706 // Construct struct-like storage for attribute. 707 if (!lvlToDim || lvlToDim.isEmpty()) { 708 lvlToDim = inferLvlToDim(dimToLvl, parser.getContext()); 709 } 710 return parser.getChecked<SparseTensorEncodingAttr>( 711 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, 712 explicitVal, implicitVal, dimSlices); 713 } 714 715 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 716 auto map = static_cast<AffineMap>(getDimToLvl()); 717 // Empty affine map indicates identity map 718 if (!map) 719 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext()); 720 printer << "<{ map = "; 721 printSymbols(map, printer); 722 printer << '('; 723 printDimensions(map, printer, getDimSlices()); 724 printer << ") -> ("; 725 printLevels(map, printer, getLvlTypes()); 726 printer << ')'; 727 // Print remaining members only for non-default values. 
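  // E.g., a plain CSR encoding with default widths and values prints as just
  // <{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>.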
728 if (getPosWidth()) 729 printer << ", posWidth = " << getPosWidth(); 730 if (getCrdWidth()) 731 printer << ", crdWidth = " << getCrdWidth(); 732 if (getExplicitVal()) { 733 printer << ", explicitVal = " << getExplicitVal(); 734 } 735 if (getImplicitVal()) 736 printer << ", implicitVal = " << getImplicitVal(); 737 printer << " }>"; 738 } 739 740 void SparseTensorEncodingAttr::printSymbols(AffineMap &map, 741 AsmPrinter &printer) const { 742 if (map.getNumSymbols() == 0) 743 return; 744 printer << '['; 745 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++) 746 printer << 's' << i << ", "; 747 if (map.getNumSymbols() >= 1) 748 printer << 's' << map.getNumSymbols() - 1; 749 printer << ']'; 750 } 751 752 void SparseTensorEncodingAttr::printDimensions( 753 AffineMap &map, AsmPrinter &printer, 754 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 755 if (!dimSlices.empty()) { 756 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 757 printer << 'd' << i << " : " << dimSlices[i] << ", "; 758 if (map.getNumDims() >= 1) { 759 printer << 'd' << map.getNumDims() - 1 << " : " 760 << dimSlices[map.getNumDims() - 1]; 761 } 762 } else { 763 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 764 printer << 'd' << i << ", "; 765 if (map.getNumDims() >= 1) 766 printer << 'd' << map.getNumDims() - 1; 767 } 768 } 769 770 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, 771 ArrayRef<LevelType> lvlTypes) const { 772 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { 773 map.getResult(i).print(printer.getStream()); 774 printer << " : " << toMLIRString(lvlTypes[i]) << ", "; 775 } 776 if (map.getNumResults() >= 1) { 777 auto lastIndex = map.getNumResults() - 1; 778 map.getResult(lastIndex).print(printer.getStream()); 779 printer << " : " << toMLIRString(lvlTypes[lastIndex]); 780 } 781 } 782 783 LogicalResult SparseTensorEncodingAttr::verify( 784 function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes, 785 AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth, 786 unsigned crdWidth, Attribute explicitVal, Attribute implicitVal, 787 ArrayRef<SparseTensorDimSliceAttr> dimSlices) { 788 if (!acceptBitWidth(posWidth)) 789 return emitError() << "unexpected position bitwidth: " << posWidth; 790 if (!acceptBitWidth(crdWidth)) 791 return emitError() << "unexpected coordinate bitwidth: " << crdWidth; 792 793 // Verify every COO segment. 794 auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); 795 while (it != lvlTypes.end()) { 796 if (it == lvlTypes.begin() || 797 !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) 798 return emitError() << "expected compressed or loose_compressed level " 799 "before singleton level"; 800 801 auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT); 802 if (!std::all_of(it, curCOOEnd, 803 [](LevelType i) { return isSingletonLT(i); })) 804 return emitError() << "expected all singleton lvlTypes " 805 "following a singleton level"; 806 // We can potentially support mixed SoA/AoS singleton levels. 
    if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
    it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
  }

  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlTypes can only be the leading levels.";

  // The SoA property can only be applied to singleton levels.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that are actually supported by the backend.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
879 if (!inferRes && dimToLvl.getNumSymbols() == 0) 880 return emitError() << "failed to infer lvlToDim from dimToLvl"; 881 if (lvlToDim && (inferRes != lvlToDim)) 882 return emitError() << "expected lvlToDim to be an inverse of dimToLvl"; 883 if (dimRank > lvlRank) 884 return emitError() << "unexpected dimToLvl mapping from " << dimRank 885 << " to " << lvlRank; 886 } 887 if (!dimSlices.empty()) { 888 if (dimSlices.size() != dimRank) 889 return emitError() 890 << "dimension-rank mismatch between dimSlices and dimToLvl: " 891 << dimSlices.size() << " != " << dimRank; 892 // Compiler support for `dimSlices` currently requires that the two 893 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.) 894 if (dimRank != lvlRank) 895 return emitError() 896 << "dimSlices expected dimension-rank to match level-rank: " 897 << dimRank << " != " << lvlRank; 898 } 899 return success(); 900 } 901 902 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 903 ArrayRef<Size> dimShape, Type elementType, 904 function_ref<InFlightDiagnostic()> emitError) const { 905 // Check structural integrity. In particular, this ensures that the 906 // level-rank is coherent across all the fields. 907 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(), 908 getPosWidth(), getCrdWidth(), getExplicitVal(), 909 getImplicitVal(), getDimSlices()))) 910 return failure(); 911 // Check integrity with tensor type specifics. In particular, we 912 // need only check that the dimension-rank of the tensor agrees with 913 // the dimension-rank of the encoding. 914 const Dimension dimRank = dimShape.size(); 915 if (dimRank == 0) 916 return emitError() << "expected non-scalar sparse tensor"; 917 if (getDimRank() != dimRank) 918 return emitError() 919 << "dimension-rank mismatch between encoding and tensor shape: " 920 << getDimRank() << " != " << dimRank; 921 if (auto expVal = getExplicitVal()) { 922 Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType(); 923 if (attrType != elementType) { 924 return emitError() << "explicit value type mismatch between encoding and " 925 << "tensor element type: " << attrType 926 << " != " << elementType; 927 } 928 } 929 if (auto impVal = getImplicitVal()) { 930 Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType(); 931 if (attrType != elementType) { 932 return emitError() << "implicit value type mismatch between encoding and " 933 << "tensor element type: " << attrType 934 << " != " << elementType; 935 } 936 // Currently, we only support zero as the implicit value. 
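    // That is, a zero FloatAttr, IntegerAttr, or complex NumberAttr passes;
    // any nonzero constant is rejected below.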
937 auto impFVal = llvm::dyn_cast<FloatAttr>(impVal); 938 auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal); 939 auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal); 940 if ((impFVal && impFVal.getValue().isNonZero()) || 941 (impIntVal && !impIntVal.getValue().isZero()) || 942 (impComplexVal && (impComplexVal.getImag().isNonZero() || 943 impComplexVal.getReal().isNonZero()))) { 944 return emitError() << "implicit value must be zero"; 945 } 946 } 947 return success(); 948 } 949 950 Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const { 951 SmallVector<COOSegment> coo = getCOOSegments(); 952 assert(coo.size() == 1 || coo.empty()); 953 if (!coo.empty() && coo.front().isAoS()) { 954 return coo.front().lvlRange.first; 955 } 956 return getLvlRank(); 957 } 958 959 SmallVector<COOSegment> 960 mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const { 961 SmallVector<COOSegment> ret; 962 if (getLvlRank() <= 1) 963 return ret; 964 965 ArrayRef<LevelType> lts = getLvlTypes(); 966 Level l = 0; 967 while (l < getLvlRank()) { 968 auto lt = lts[l]; 969 if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) { 970 auto cur = lts.begin() + l; 971 auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) { 972 return !lt.isa<LevelFormat::Singleton>(); 973 }); 974 unsigned cooLen = std::distance(cur, end); 975 if (cooLen > 1) { 976 // To support mixed SoA/AoS COO, we should break the segment when the 977 // storage scheme changes, for now we faithfully assume that all 978 // consecutive singleton levels have the same storage format as verified 979 // STEA. 980 ret.push_back(COOSegment{std::make_pair(l, l + cooLen), 981 lts[l + 1].isa<LevelPropNonDefault::SoA>()}); 982 } 983 l += cooLen; 984 } else { 985 l++; 986 } 987 } 988 return ret; 989 } 990 991 //===----------------------------------------------------------------------===// 992 // SparseTensorType Methods. 993 //===----------------------------------------------------------------------===// 994 995 bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, 996 bool isUnique) const { 997 if (!hasEncoding()) 998 return false; 999 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl)) 1000 return false; 1001 for (Level l = startLvl + 1; l < lvlRank; ++l) 1002 if (!isSingletonLvl(l)) 1003 return false; 1004 // If isUnique is true, then make sure that the last level is unique, 1005 // that is, when lvlRank == 1, the only compressed level is unique, 1006 // and when lvlRank > 1, the last singleton is unique. 1007 return !isUnique || isUniqueLvl(lvlRank - 1); 1008 } 1009 1010 RankedTensorType 1011 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { 1012 SmallVector<LevelType> lvlTypes; 1013 lvlTypes.reserve(lvlRank); 1014 // A non-unique compressed level at beginning (unless this is 1015 // also the last level, then it is unique). 1016 lvlTypes.push_back( 1017 *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); 1018 if (lvlRank > 1) { 1019 // Followed by n-2 non-unique singleton levels. 1020 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, 1021 *buildLevelType(LevelFormat::Singleton, ordered, false)); 1022 // Ends by a unique singleton level. 
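    // E.g., with ordered = true and lvlRank == 3 this produces
    // [compressed(nonunique), singleton(nonunique), singleton].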
1023 lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true)); 1024 } 1025 auto enc = SparseTensorEncodingAttr::get( 1026 getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(), 1027 getCrdWidth(), getExplicitVal(), getImplicitVal()); 1028 return RankedTensorType::get(getDimShape(), getElementType(), enc); 1029 } 1030 1031 //===----------------------------------------------------------------------===// 1032 // Convenience Methods. 1033 //===----------------------------------------------------------------------===// 1034 1035 SparseTensorEncodingAttr 1036 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 1037 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type)) 1038 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding()); 1039 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type)) 1040 return mdtp.getEncoding(); 1041 return nullptr; 1042 } 1043 1044 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl, 1045 MLIRContext *context) { 1046 auto map = static_cast<AffineMap>(dimToLvl); 1047 AffineMap lvlToDim; 1048 // Return an empty lvlToDim when inference is not successful. 1049 if (!map || map.getNumSymbols() != 0) { 1050 lvlToDim = AffineMap(); 1051 } else if (map.isPermutation()) { 1052 lvlToDim = inversePermutation(map); 1053 } else if (isBlockSparsity(map)) { 1054 lvlToDim = inverseBlockSparsity(map, context); 1055 } 1056 return lvlToDim; 1057 } 1058 1059 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl, 1060 MLIRContext *context) { 1061 SmallVector<AffineExpr> lvlExprs; 1062 auto numLvls = dimToLvl.getNumResults(); 1063 lvlExprs.reserve(numLvls); 1064 // lvlExprComponents stores information of the floordiv and mod operations 1065 // applied to the same dimension, so as to build the lvlToDim map. 1066 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents; 1067 for (unsigned i = 0, n = numLvls; i < n; i++) { 1068 auto result = dimToLvl.getResult(i); 1069 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1070 if (result.getKind() == AffineExprKind::FloorDiv) { 1071 // Position of the dimension in dimToLvl. 1072 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 1073 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() && 1074 "expected only one floordiv for each dimension"); 1075 SmallVector<AffineExpr, 3> components; 1076 // Level variable for floordiv. 1077 components.push_back(getAffineDimExpr(i, context)); 1078 // Multiplier. 1079 components.push_back(binOp.getRHS()); 1080 // Map key is the position of the dimension. 1081 lvlExprComponents[pos] = components; 1082 } else if (result.getKind() == AffineExprKind::Mod) { 1083 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 1084 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() && 1085 "expected floordiv before mod"); 1086 // Add level variable for mod to the same vector 1087 // of the corresponding floordiv. 1088 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context)); 1089 } else { 1090 assert(false && "expected floordiv or mod"); 1091 } 1092 } else { 1093 lvlExprs.push_back(getAffineDimExpr(i, context)); 1094 } 1095 } 1096 // Build lvlExprs from lvlExprComponents. 1097 // For example, for il = i floordiv 2 and ii = i mod 2, the components 1098 // would be [il, 2, ii]. It could be used to build the AffineExpr 1099 // i = il * 2 + ii in lvlToDim. 
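  // E.g., the 2x3 block map (d0, d1) -> (d0 floordiv 2, d1 floordiv 3,
  // d0 mod 2, d1 mod 3) thus inverts to
  // (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3).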
1100 for (auto &components : lvlExprComponents) { 1101 assert(components.second.size() == 3 && 1102 "expected 3 components to build lvlExprs"); 1103 auto mulOp = getAffineBinaryOpExpr( 1104 AffineExprKind::Mul, components.second[0], components.second[1]); 1105 auto addOp = 1106 getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]); 1107 lvlExprs.push_back(addOp); 1108 } 1109 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context); 1110 } 1111 1112 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { 1113 assert(isBlockSparsity(dimToLvl) && 1114 "expected dimToLvl to be block sparsity for calling getBlockSize"); 1115 SmallVector<unsigned> blockSize; 1116 for (auto result : dimToLvl.getResults()) { 1117 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1118 if (result.getKind() == AffineExprKind::Mod) { 1119 blockSize.push_back( 1120 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue()); 1121 } 1122 } else { 1123 blockSize.push_back(0); 1124 } 1125 } 1126 return blockSize; 1127 } 1128 1129 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { 1130 if (!dimToLvl) 1131 return false; 1132 std::map<unsigned, int64_t> coeffientMap; 1133 bool hasBlock = false; 1134 for (auto result : dimToLvl.getResults()) { 1135 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1136 // Check for "dim op const". 1137 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS()); 1138 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS()); 1139 if (!dimOp || !conOp || conOp.getValue() <= 0) 1140 return false; 1141 // Inspect "dim / const" or "dim % const". 1142 auto pos = dimOp.getPosition(); 1143 if (binOp.getKind() == AffineExprKind::FloorDiv) { 1144 // Expect only one floordiv for each dimension. 1145 if (coeffientMap.find(pos) != coeffientMap.end()) 1146 return false; 1147 // Record coefficient of the floordiv. 1148 coeffientMap[pos] = conOp.getValue(); 1149 } else if (binOp.getKind() == AffineExprKind::Mod) { 1150 // Expect floordiv before mod. 1151 if (coeffientMap.find(pos) == coeffientMap.end()) 1152 return false; 1153 // Expect mod to have the same coefficient as floordiv. 1154 if (conOp.getValue() != coeffientMap[pos]) 1155 return false; 1156 hasBlock = true; 1157 } else { 1158 return false; 1159 } 1160 } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) { 1161 auto pos = dimOp.getPosition(); 1162 // Expect dim to be unset. 
1163 if (coeffientMap.find(pos) != coeffientMap.end()) 1164 return false; 1165 coeffientMap[pos] = 0; 1166 } else { 1167 return false; 1168 } 1169 } 1170 return hasBlock; 1171 } 1172 1173 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) { 1174 auto hasNonIdentityMap = [](Value v) { 1175 auto stt = tryGetSparseTensorType(v); 1176 return stt && !stt->isIdentity(); 1177 }; 1178 1179 return llvm::any_of(op->getOperands(), hasNonIdentityMap) || 1180 llvm::any_of(op->getResults(), hasNonIdentityMap); 1181 } 1182 1183 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { 1184 if (enc) { 1185 assert(enc.isPermutation() && "Non permutation map not supported"); 1186 if (const auto dimToLvl = enc.getDimToLvl()) 1187 return dimToLvl.getDimPosition(l); 1188 } 1189 return l; 1190 } 1191 1192 Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { 1193 if (enc) { 1194 assert(enc.isPermutation() && "Non permutation map not supported"); 1195 if (const auto lvlToDim = enc.getLvlToDim()) 1196 return lvlToDim.getDimPosition(d); 1197 } 1198 return d; 1199 } 1200 1201 /// We normalized sparse tensor encoding attribute by always using 1202 /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well 1203 /// as other variants) lead to the same storage specifier type, and stripping 1204 /// irrelevant fields that do not alter the sparse tensor memory layout. 1205 static SparseTensorEncodingAttr 1206 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { 1207 SmallVector<LevelType> lts; 1208 for (auto lt : enc.getLvlTypes()) 1209 lts.push_back(lt.stripStorageIrrelevantProperties()); 1210 1211 return SparseTensorEncodingAttr::get( 1212 enc.getContext(), lts, 1213 AffineMap(), // dimToLvl (irrelevant to storage specifier) 1214 AffineMap(), // lvlToDim (irrelevant to storage specifier) 1215 // Always use `index` for memSize and lvlSize instead of reusing 1216 // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA 1217 // value for different bitwidth, it also avoids casting between index and 1218 // integer (returned by DimOp) 1219 0, 0, 1220 Attribute(), // explicitVal (irrelevant to storage specifier) 1221 Attribute(), // implicitVal (irrelevant to storage specifier) 1222 enc.getDimSlices()); 1223 } 1224 1225 StorageSpecifierType 1226 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) { 1227 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding)); 1228 } 1229 1230 //===----------------------------------------------------------------------===// 1231 // SparseTensorDialect Operations. 1232 //===----------------------------------------------------------------------===// 1233 1234 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) { 1235 return success(lvl < getSparseTensorType(tensor).getLvlRank()); 1236 } 1237 1238 static LogicalResult isMatchingWidth(Value mem, unsigned width) { 1239 const Type etp = getMemRefType(mem).getElementType(); 1240 return success(width == 0 ? 
                 etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verifies the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in the shape of <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult AssembleOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we could theoretically always convert directly from
  // dense inputs by rotating dense loops, but that leads to bad cache locality
  // and hurts performance.
1414 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>()) 1415 if (isa<SparseElementsAttr>(constOp.getValue())) 1416 return false; 1417 1418 return true; 1419 } 1420 1421 LogicalResult CrdTranslateOp::verify() { 1422 uint64_t inRank = getEncoder().getLvlRank(); 1423 uint64_t outRank = getEncoder().getDimRank(); 1424 1425 if (getDirection() == CrdTransDirectionKind::dim2lvl) 1426 std::swap(inRank, outRank); 1427 1428 if (inRank != getInCrds().size() || outRank != getOutCrds().size()) 1429 return emitError("Coordinate rank mismatch with encoding"); 1430 1431 return success(); 1432 } 1433 1434 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor, 1435 SmallVectorImpl<OpFoldResult> &results) { 1436 if (getEncoder().isIdentity()) { 1437 results.assign(getInCrds().begin(), getInCrds().end()); 1438 return success(); 1439 } 1440 if (getEncoder().isPermutation()) { 1441 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl 1442 ? getEncoder().getDimToLvl() 1443 : getEncoder().getLvlToDim(); 1444 for (AffineExpr exp : perm.getResults()) 1445 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]); 1446 return success(); 1447 } 1448 1449 // Fuse dim2lvl/lvl2dim pairs. 1450 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>(); 1451 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) { 1452 return v.getDefiningOp() == def; 1453 }); 1454 if (!sameDef) 1455 return failure(); 1456 1457 bool oppositeDir = def.getDirection() != getDirection(); 1458 bool sameOracle = 1459 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl(); 1460 bool sameCount = def.getNumResults() == getInCrds().size(); 1461 if (!oppositeDir || !sameOracle || !sameCount) 1462 return failure(); 1463 1464 // The definition produces the coordinates in the same order as the input 1465 // coordinates. 
1466 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()), 1467 [](auto valuePair) { 1468 auto [lhs, rhs] = valuePair; 1469 return lhs == rhs; 1470 }); 1471 1472 if (!sameOrder) 1473 return failure(); 1474 // l1 = dim2lvl (lvl2dim l0) 1475 // ==> l0 1476 results.append(def.getInCrds().begin(), def.getInCrds().end()); 1477 return success(); 1478 } 1479 1480 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source, 1481 int64_t index) { 1482 Value val = builder.create<arith::ConstantIndexOp>(state.location, index); 1483 return build(builder, state, source, val); 1484 } 1485 1486 LogicalResult LvlOp::verify() { 1487 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) { 1488 auto stt = getSparseTensorType(getSource()); 1489 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank()) 1490 emitError("Level index exceeds the rank of the input sparse tensor"); 1491 } 1492 return success(); 1493 } 1494 1495 std::optional<uint64_t> LvlOp::getConstantLvlIndex() { 1496 return getConstantIntValue(getIndex()); 1497 } 1498 1499 Speculation::Speculatability LvlOp::getSpeculatability() { 1500 auto constantIndex = getConstantLvlIndex(); 1501 if (!constantIndex) 1502 return Speculation::NotSpeculatable; 1503 1504 assert(constantIndex < 1505 cast<RankedTensorType>(getSource().getType()).getRank()); 1506 return Speculation::Speculatable; 1507 } 1508 1509 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) { 1510 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex()); 1511 if (!lvlIndex) 1512 return {}; 1513 1514 Level lvl = lvlIndex.getAPSInt().getZExtValue(); 1515 auto stt = getSparseTensorType(getSource()); 1516 if (lvl >= stt.getLvlRank()) { 1517 // Follows the same convention used by tensor.dim operation. Out of bound 1518 // indices produce undefined behavior but are still valid IR. Don't choke on 1519 // them. 1520 return {}; 1521 } 1522 1523 // Helper lambda to build an IndexAttr. 
1524 auto getIndexAttr = [this](int64_t lvlSz) { 1525 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz)); 1526 }; 1527 1528 SmallVector<Size> lvlShape = stt.getLvlShape(); 1529 if (!ShapedType::isDynamic(lvlShape[lvl])) 1530 return getIndexAttr(lvlShape[lvl]); 1531 1532 return {}; 1533 } 1534 1535 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState, 1536 SparseTensorEncodingAttr dstEnc, Value source) { 1537 auto srcStt = getSparseTensorType(source); 1538 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape(); 1539 SmallVector<int64_t> dstDimShape = 1540 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim); 1541 auto dstTp = 1542 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc); 1543 return build(odsBuilder, odsState, dstTp, source); 1544 } 1545 1546 LogicalResult ReinterpretMapOp::verify() { 1547 auto srcStt = getSparseTensorType(getSource()); 1548 auto dstStt = getSparseTensorType(getDest()); 1549 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes(); 1550 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes(); 1551 1552 if (srcLvlTps.size() != dstLvlTps.size()) 1553 return emitError("Level rank mismatch between source/dest tensors"); 1554 1555 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps)) 1556 if (srcLvlTp != dstLvlTp) 1557 return emitError("Level type mismatch between source/dest tensors"); 1558 1559 if (srcStt.getPosWidth() != dstStt.getPosWidth() || 1560 srcStt.getCrdWidth() != dstStt.getCrdWidth()) { 1561 return emitError("Crd/Pos width mismatch between source/dest tensors"); 1562 } 1563 1564 if (srcStt.getElementType() != dstStt.getElementType()) 1565 return emitError("Element type mismatch between source/dest tensors"); 1566 1567 SmallVector<Size> srcLvlShape = srcStt.getLvlShape(); 1568 SmallVector<Size> dstLvlShape = dstStt.getLvlShape(); 1569 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) { 1570 if (srcLvlSz != dstLvlSz) { 1571 // Should we allow one side to be dynamic size, e.g., <?x?> should be 1572 // compatible to <3x4>? For now, we require all the level sizes to be 1573 // *exactly* matched for simplicity. 
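      // Note that this also rejects pairing a dynamic size on one side with a
      // static size on the other.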
1574 return emitError("Level size mismatch between source/dest tensors"); 1575 } 1576 } 1577 1578 return success(); 1579 } 1580 1581 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) { 1582 if (getSource().getType() == getDest().getType()) 1583 return getSource(); 1584 1585 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) { 1586 // A -> B, B -> A ==> A 1587 if (def.getSource().getType() == getDest().getType()) 1588 return def.getSource(); 1589 } 1590 return {}; 1591 } 1592 1593 template <typename ToBufferOp> 1594 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, 1595 OpaqueProperties prop, 1596 RegionRange region, 1597 SmallVectorImpl<mlir::Type> &ret) { 1598 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region); 1599 SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); 1600 Type elemTp = nullptr; 1601 bool withStride = false; 1602 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) { 1603 elemTp = stt.getPosType(); 1604 } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> || 1605 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) { 1606 elemTp = stt.getCrdType(); 1607 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>) 1608 withStride = stt.getAoSCOOStart() <= adaptor.getLevel(); 1609 } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) { 1610 elemTp = stt.getElementType(); 1611 } 1612 1613 assert(elemTp && "unhandled operation."); 1614 SmallVector<int64_t> bufShape = stt.getBatchLvlShape(); 1615 bufShape.push_back(ShapedType::kDynamic); 1616 1617 auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get( 1618 stt.getContext(), ShapedType::kDynamic, 1619 {ShapedType::kDynamic}) 1620 : StridedLayoutAttr(); 1621 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout)); 1622 return success(); 1623 } 1624 1625 LogicalResult ToPositionsOp::verify() { 1626 auto stt = getSparseTensorType(getTensor()); 1627 if (failed(lvlIsInBounds(getLevel(), getTensor()))) 1628 return emitError("requested level is out of bounds"); 1629 if (failed(isMatchingWidth(getResult(), stt.getPosWidth()))) 1630 return emitError("unexpected type for positions"); 1631 return success(); 1632 } 1633 1634 LogicalResult 1635 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 1636 ValueRange ops, DictionaryAttr attr, 1637 OpaqueProperties prop, RegionRange region, 1638 SmallVectorImpl<mlir::Type> &ret) { 1639 return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret); 1640 } 1641 1642 LogicalResult ToCoordinatesOp::verify() { 1643 auto stt = getSparseTensorType(getTensor()); 1644 if (failed(lvlIsInBounds(getLevel(), getTensor()))) 1645 return emitError("requested level is out of bounds"); 1646 if (failed(isMatchingWidth(getResult(), stt.getCrdWidth()))) 1647 return emitError("unexpected type for coordinates"); 1648 return success(); 1649 } 1650 1651 LogicalResult 1652 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 1653 ValueRange ops, DictionaryAttr attr, 1654 OpaqueProperties prop, RegionRange region, 1655 SmallVectorImpl<mlir::Type> &ret) { 1656 return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret); 1657 } 1658 1659 LogicalResult ToCoordinatesBufferOp::verify() { 1660 auto stt = getSparseTensorType(getTensor()); 1661 if (stt.getAoSCOOStart() >= stt.getLvlRank()) 1662 return emitError("expected sparse tensor with a COO region"); 1663 return success(); 1664 } 1665 1666 LogicalResult 
ToCoordinatesBufferOp::inferReturnTypes( 1667 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, 1668 DictionaryAttr attr, OpaqueProperties prop, RegionRange region, 1669 SmallVectorImpl<mlir::Type> &ret) { 1670 return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region, 1671 ret); 1672 } 1673 1674 LogicalResult ToValuesOp::verify() { 1675 auto stt = getSparseTensorType(getTensor()); 1676 auto mtp = getMemRefType(getResult()); 1677 if (stt.getElementType() != mtp.getElementType()) 1678 return emitError("unexpected mismatch in element types"); 1679 return success(); 1680 } 1681 1682 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx, 1683 std::optional<Location> loc, 1684 ValueRange ops, DictionaryAttr attr, 1685 OpaqueProperties prop, 1686 RegionRange region, 1687 SmallVectorImpl<mlir::Type> &ret) { 1688 return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret); 1689 } 1690 1691 LogicalResult ToSliceOffsetOp::verify() { 1692 auto rank = getRankedTensorType(getSlice()).getRank(); 1693 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) 1694 return emitError("requested dimension out of bound"); 1695 return success(); 1696 } 1697 1698 LogicalResult ToSliceStrideOp::verify() { 1699 auto rank = getRankedTensorType(getSlice()).getRank(); 1700 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) 1701 return emitError("requested dimension out of bound"); 1702 return success(); 1703 } 1704 1705 LogicalResult GetStorageSpecifierOp::verify() { 1706 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), 1707 getSpecifier(), getOperation()); 1708 } 1709 1710 template <typename SpecifierOp> 1711 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) { 1712 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>(); 1713 } 1714 1715 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { 1716 const StorageSpecifierKind kind = getSpecifierKind(); 1717 const auto lvl = getLevel(); 1718 for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op)) 1719 if (kind == op.getSpecifierKind() && lvl == op.getLevel()) 1720 return op.getValue(); 1721 return {}; 1722 } 1723 1724 LogicalResult SetStorageSpecifierOp::verify() { 1725 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), 1726 getSpecifier(), getOperation()); 1727 } 1728 1729 template <class T> 1730 static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, 1731 const char *regionName, 1732 TypeRange inputTypes, Type outputType) { 1733 unsigned numArgs = region.getNumArguments(); 1734 unsigned expectedNum = inputTypes.size(); 1735 if (numArgs != expectedNum) 1736 return op->emitError() << regionName << " region must have exactly " 1737 << expectedNum << " arguments"; 1738 1739 for (unsigned i = 0; i < numArgs; i++) { 1740 Type typ = region.getArgument(i).getType(); 1741 if (typ != inputTypes[i]) 1742 return op->emitError() << regionName << " region argument " << (i + 1) 1743 << " type mismatch"; 1744 } 1745 Operation *term = region.front().getTerminator(); 1746 YieldOp yield = dyn_cast<YieldOp>(term); 1747 if (!yield) 1748 return op->emitError() << regionName 1749 << " region must end with sparse_tensor.yield"; 1750 if (!yield.hasSingleResult() || 1751 yield.getSingleResult().getType() != outputType) 1752 return op->emitError() << regionName << " region yield type mismatch"; 1753 1754 return success(); 1755 } 1756 1757 LogicalResult BinaryOp::verify() { 1758 NamedAttrList attrs = (*this)->getAttrs(); 
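  // The region checks below use these operand types as the expected block
  // argument types and the output type as the expected yield type.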
1759 Type leftType = getX().getType(); 1760 Type rightType = getY().getType(); 1761 Type outputType = getOutput().getType(); 1762 Region &overlap = getOverlapRegion(); 1763 Region &left = getLeftRegion(); 1764 Region &right = getRightRegion(); 1765 1766 // Check correct number of block arguments and return type for each 1767 // non-empty region. 1768 if (!overlap.empty()) { 1769 if (failed(verifyNumBlockArgs(this, overlap, "overlap", 1770 TypeRange{leftType, rightType}, outputType))) 1771 return failure(); 1772 } 1773 if (!left.empty()) { 1774 if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, 1775 outputType))) 1776 return failure(); 1777 } else if (getLeftIdentity()) { 1778 if (leftType != outputType) 1779 return emitError("left=identity requires first argument to have the same " 1780 "type as the output"); 1781 } 1782 if (!right.empty()) { 1783 if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType}, 1784 outputType))) 1785 return failure(); 1786 } else if (getRightIdentity()) { 1787 if (rightType != outputType) 1788 return emitError("right=identity requires second argument to have the " 1789 "same type as the output"); 1790 } 1791 return success(); 1792 } 1793 1794 LogicalResult UnaryOp::verify() { 1795 Type inputType = getX().getType(); 1796 Type outputType = getOutput().getType(); 1797 1798 // Check correct number of block arguments and return type for each 1799 // non-empty region. 1800 Region &present = getPresentRegion(); 1801 if (!present.empty()) { 1802 if (failed(verifyNumBlockArgs(this, present, "present", 1803 TypeRange{inputType}, outputType))) 1804 return failure(); 1805 } 1806 Region &absent = getAbsentRegion(); 1807 if (!absent.empty()) { 1808 if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{}, 1809 outputType))) 1810 return failure(); 1811 // Absent branch can only yield invariant values. 1812 Block *absentBlock = &absent.front(); 1813 Block *parent = getOperation()->getBlock(); 1814 Value absentVal = 1815 cast<YieldOp>(absentBlock->getTerminator()).getSingleResult(); 1816 if (auto arg = dyn_cast<BlockArgument>(absentVal)) { 1817 if (arg.getOwner() == parent) 1818 return emitError("absent region cannot yield linalg argument"); 1819 } else if (Operation *def = absentVal.getDefiningOp()) { 1820 if (!isa<arith::ConstantOp>(def) && 1821 (def->getBlock() == absentBlock || def->getBlock() == parent)) 1822 return emitError("absent region cannot yield locally computed value"); 1823 } 1824 } 1825 return success(); 1826 } 1827 1828 bool ConcatenateOp::needsExtraSort() { 1829 SparseTensorType dstStt = getSparseTensorType(*this); 1830 if (dstStt.isAllDense() || !dstStt.isAllOrdered()) 1831 return false; 1832 1833 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) { 1834 return getSparseTensorType(op).hasSameDimToLvl(dstStt); 1835 }); 1836 // TODO: When conDim != 0, as long as conDim corresponding to the first level 1837 // in all input/output buffers, and all input/output buffers have the same 1838 // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate 1839 // CSC matrices along column). 
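  // For now, the temporary COO buffer (and extra sort) is avoided only when
  // concatenating along dimension 0 with an identity dimToLvl that all
  // inputs share with the result.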
  bool directLowerable =
      allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
  return !directLowerable;
}

LogicalResult ConcatenateOp::verify() {
  const auto dstTp = getSparseTensorType(*this);
  const Dimension concatDim = getDimension();
  const Dimension dimRank = dstTp.getDimRank();

  if (getInputs().size() <= 1)
    return emitError("Need at least two tensors to concatenate.");

  if (concatDim >= dimRank)
    return emitError(llvm::formatv(
        "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
        concatDim, dimRank));

  for (const auto &it : llvm::enumerate(getInputs())) {
    const auto i = it.index();
    const auto srcTp = getSparseTensorType(it.value());
    if (srcTp.hasDynamicDimShape())
      return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
    const Dimension srcDimRank = srcTp.getDimRank();
    if (srcDimRank != dimRank)
      return emitError(
          llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
                        "from the output tensor (rank={2}).",
                        i, srcDimRank, dimRank));
  }

  for (Dimension d = 0; d < dimRank; d++) {
    const Size dstSh = dstTp.getDimShape()[d];
    if (d == concatDim) {
      if (!ShapedType::isDynamic(dstSh)) {
        // If we reach here, then all inputs have static shapes. So we
        // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
        // to avoid redundant assertions in the loop.
        Size sumSz = 0;
        for (const auto src : getInputs())
          sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimensions are statically known, the sum of all the input
        // dimensions should be equal to the output dimension.
        if (sumSz != dstSh)
          return emitError(
              "The concatenation dimension of the output tensor should be the "
              "sum of all the concatenation dimensions of the input tensors.");
      }
    } else {
      Size prev = dstSh;
      for (const auto src : getInputs()) {
        const auto sh = getSparseTensorType(src).getDimShape()[d];
        if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (except for the concatenating one) "
                           "should be equal.");
        prev = sh;
      }
    }
  }

  return success();
}

void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  build(builder, result, curSize, inBuffer, value, Value());
}

LogicalResult PushBackOp::verify() {
  if (Value n = getN()) {
    std::optional<int64_t> nValue = getConstantIntValue(n);
    if (nValue && nValue.value() < 1)
      return emitOpError("n must be not less than 1");
  }
  return success();
}

LogicalResult CompressOp::verify() {
  const auto stt = getSparseTensorType(getTensor());
  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
  return success();
}

void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds foreach body.
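  // Without a body builder, no body block is created and the caller is
  // expected to populate the region.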
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}

LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError("Level traverse order does not match tensor's level rank");

  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      elemTp, valueTp));
  return success();
}

OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
  if (getSparseTensorEncoding(getInputCoo().getType()) ==
      getSparseTensorEncoding(getResultCoo().getType()))
    return getInputCoo();

  return {};
}

LogicalResult ReorderCOOOp::verify() {
  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
  SparseTensorType dstStt = getSparseTensorType(getResultCoo());

  if (!srcStt.isCOOType() || !dstStt.isCOOType())
    return emitError("Expected COO sparse tensors only");

  if (!srcStt.hasSameDimToLvl(dstStt))
    return emitError("Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType())
    return emitError("Unmatched storage format between input and result COO");

  return success();
}

LogicalResult ReduceOp::verify() {
  Type inputType = getX().getType();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "reduce",
                            TypeRange{inputType, inputType}, inputType);
}

LogicalResult SelectOp::verify() {
  Builder b(getContext());
  Type inputType = getX().getType();
  Type boolType = b.getI1Type();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
                            boolType);
}

LogicalResult SortOp::verify() {
  AffineMap xPerm = getPermMap();
  uint64_t nx = xPerm.getNumDims();
  if (nx < 1)
    return emitError(
        llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));

  if (!xPerm.isPermutation())
    return emitError(
        llvm::formatv("Expected a permutation map, got {0}", xPerm));

  // We can't check the size of the buffers when n or buffer dimensions aren't
  // compile-time constants.
  std::optional<int64_t> cn = getConstantIntValue(getN());
  if (!cn)
    return success();

  // Verify dimensions.
  const auto checkDim = [&](Value v, Size minSize,
                            const char *message) -> LogicalResult {
    const Size sh = getMemRefType(v).getShape()[0];
    if (!ShapedType::isDynamic(sh) && sh < minSize)
      return emitError(
          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
    return success();
  };
  uint64_t n = cn.value();
  uint64_t ny = 0;
  if (auto nyAttr = getNyAttr())
    ny = nyAttr.getInt();
  if (failed(checkDim(getXy(), n * (nx + ny),
                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
    return failure();
  for (Value opnd : getYs())
    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
      return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Operations.
//===----------------------------------------------------------------------===//

IterSpaceType IteratorType::getIterSpaceType() const {
  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
                            getHiLvl());
}

IteratorType IterSpaceType::getIteratorType() const {
  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
}

/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
                                   Level &lvlHi) {
  if (parser.parseInteger(lvlLo))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("to"))) {
    if (parser.parseInteger(lvlHi))
      return failure();
  } else {
    lvlHi = lvlLo + 1;
  }

  if (lvlHi <= lvlLo)
    return parser.emitError(parser.getNameLoc(),
                            "expect larger level upper bound than lower bound");

  return success();
}

/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
                                   IntegerAttr &lvlHiAttr) {
  Level lvlLo, lvlHi;
  if (parseLevelRange(parser, lvlLo, lvlHi))
    return failure();

  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
  return success();
}

/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
  if (lo + 1 == hi)
    p << lo;
  else
    p << lo << " to " << hi;
}

/// Prints a
level range in the form "$lo `to` $hi" 2126 /// or simply "$lo" if $hi - $lo = 1 2127 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, 2128 IntegerAttr lvlHi) { 2129 unsigned lo = lvlLo.getValue().getZExtValue(); 2130 unsigned hi = lvlHi.getValue().getZExtValue(); 2131 printLevelRange(p, lo, hi); 2132 } 2133 2134 static ParseResult 2135 parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state, 2136 SmallVectorImpl<OpAsmParser::Argument> &iterators, 2137 SmallVectorImpl<OpAsmParser::Argument> &iterArgs) { 2138 SmallVector<OpAsmParser::UnresolvedOperand> spaces; 2139 SmallVector<OpAsmParser::UnresolvedOperand> initArgs; 2140 2141 // Parse "%iters, ... in %spaces, ..." 2142 if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") || 2143 parser.parseOperandList(spaces)) 2144 return failure(); 2145 2146 if (iterators.size() != spaces.size()) 2147 return parser.emitError( 2148 parser.getNameLoc(), 2149 "mismatch in number of sparse iterators and sparse spaces"); 2150 2151 // Parse "at(%crd0, _, ...)" 2152 LevelSet crdUsedLvlSet; 2153 bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at")); 2154 unsigned lvlCrdCnt = 0; 2155 if (hasUsedCrds) { 2156 ParseResult crdList = parser.parseCommaSeparatedList( 2157 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 2158 if (parser.parseOptionalKeyword("_")) { 2159 if (parser.parseArgument(iterArgs.emplace_back())) 2160 return failure(); 2161 // Always use IndexType for the coordinate. 2162 crdUsedLvlSet.set(lvlCrdCnt); 2163 iterArgs.back().type = parser.getBuilder().getIndexType(); 2164 } 2165 lvlCrdCnt += 1; 2166 return success(); 2167 }); 2168 if (failed(crdList)) { 2169 return parser.emitError( 2170 parser.getNameLoc(), 2171 "expecting SSA value or \"_\" for level coordinates"); 2172 } 2173 } 2174 // Set the CrdUsedLvl bitset. 2175 state.addAttribute("crdUsedLvls", 2176 parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet)); 2177 2178 // Parse "iter_args(%arg = %init, ...)" 2179 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); 2180 if (hasIterArgs) 2181 if (parser.parseAssignmentList(iterArgs, initArgs)) 2182 return failure(); 2183 2184 SmallVector<Type> iterSpaceTps; 2185 // parse ": sparse_tensor.iter_space -> ret" 2186 if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) 2187 return failure(); 2188 if (iterSpaceTps.size() != spaces.size()) 2189 return parser.emitError(parser.getNameLoc(), 2190 "mismatch in number of iteration space operands " 2191 "and iteration space types"); 2192 2193 for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) { 2194 IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp); 2195 if (!spaceTp) 2196 return parser.emitError(parser.getNameLoc(), 2197 "expected sparse_tensor.iter_space type for " 2198 "iteration space operands"); 2199 if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt) 2200 return parser.emitError(parser.getNameLoc(), 2201 "mismatch in number of iteration space dimension " 2202 "and specified coordinates"); 2203 it.type = spaceTp.getIteratorType(); 2204 } 2205 2206 if (hasIterArgs) 2207 if (parser.parseArrowTypeList(state.types)) 2208 return failure(); 2209 2210 // Resolves input operands. 2211 if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), 2212 state.operands)) 2213 return failure(); 2214 2215 if (hasIterArgs) { 2216 unsigned numCrds = crdUsedLvlSet.count(); 2217 // Strip off leading args that used for coordinates. 
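    // What remains are the loop-carried "iter_args"; their count must match
    // both the initializer list and the declared result types.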
2218 MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds); 2219 if (args.size() != initArgs.size() || args.size() != state.types.size()) { 2220 return parser.emitError( 2221 parser.getNameLoc(), 2222 "mismatch in number of iteration arguments and return values"); 2223 } 2224 2225 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { 2226 it.type = tp; 2227 if (parser.resolveOperand(init, tp, state.operands)) 2228 return failure(); 2229 } 2230 } 2231 return success(); 2232 } 2233 2234 LogicalResult ExtractIterSpaceOp::inferReturnTypes( 2235 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, 2236 DictionaryAttr attr, OpaqueProperties prop, RegionRange region, 2237 SmallVectorImpl<mlir::Type> &ret) { 2238 2239 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region); 2240 SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); 2241 ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(), 2242 adaptor.getHiLvl())); 2243 return success(); 2244 } 2245 2246 LogicalResult ExtractIterSpaceOp::verify() { 2247 if (getLoLvl() >= getHiLvl()) 2248 return emitOpError("expected smaller level low than level high"); 2249 2250 TypedValue<IteratorType> pIter = getParentIter(); 2251 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) { 2252 return emitOpError( 2253 "parent iterator should be specified iff level lower bound equals 0"); 2254 } 2255 2256 if (pIter) { 2257 IterSpaceType spaceTp = getExtractedSpace().getType(); 2258 if (pIter.getType().getEncoding() != spaceTp.getEncoding()) 2259 return emitOpError( 2260 "mismatch in parent iterator encoding and iteration space encoding."); 2261 2262 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl()) 2263 return emitOpError("parent iterator should be used to extract an " 2264 "iteration space from a consecutive level."); 2265 } 2266 2267 return success(); 2268 } 2269 2270 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> { 2271 using OpRewritePattern::OpRewritePattern; 2272 2273 LogicalResult matchAndRewrite(IterateOp iterateOp, 2274 PatternRewriter &rewriter) const override { 2275 LevelSet newUsedLvls(0); 2276 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments()); 2277 for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) { 2278 if (auto crd = iterateOp.getLvlCrd(i)) { 2279 if (crd->getUsers().empty()) 2280 toRemove.set(crd->getArgNumber()); 2281 else 2282 newUsedLvls.set(i); 2283 } 2284 } 2285 2286 // All coordinates are used. 2287 if (toRemove.none()) 2288 return failure(); 2289 2290 rewriter.startOpModification(iterateOp); 2291 iterateOp.setCrdUsedLvls(newUsedLvls); 2292 iterateOp.getBody()->eraseArguments(toRemove); 2293 rewriter.finalizeOpModification(iterateOp); 2294 return success(); 2295 } 2296 }; 2297 2298 void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, 2299 mlir::MLIRContext *context) { 2300 results.add<RemoveUnusedLvlCrds>(context); 2301 } 2302 2303 void IterateOp::build(OpBuilder &builder, OperationState &odsState, 2304 Value iterSpace, ValueRange initArgs) { 2305 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim(); 2306 // All ones. 
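  // That is, by default the loop body receives every level coordinate of the
  // iteration space.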
2307 LevelSet set((1 << rank) - 1); 2308 return build(builder, odsState, iterSpace, initArgs, set); 2309 } 2310 2311 void IterateOp::build(OpBuilder &builder, OperationState &odsState, 2312 Value iterSpace, ValueRange initArgs, 2313 LevelSet crdUsedLvls) { 2314 OpBuilder::InsertionGuard guard(builder); 2315 2316 odsState.addOperands(iterSpace); 2317 odsState.addOperands(initArgs); 2318 odsState.getOrAddProperties<Properties>().crdUsedLvls = 2319 builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls); 2320 Region *bodyRegion = odsState.addRegion(); 2321 odsState.addTypes(initArgs.getTypes()); 2322 Block *bodyBlock = builder.createBlock(bodyRegion); 2323 2324 // First argument, sparse iterator 2325 bodyBlock->addArgument( 2326 llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(), 2327 odsState.location); 2328 2329 // Followed by a list of used coordinates. 2330 for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++) 2331 bodyBlock->addArgument(builder.getIndexType(), odsState.location); 2332 2333 // Followed by a list of user-provided loop arguments. 2334 for (Value v : initArgs) 2335 bodyBlock->addArgument(v.getType(), v.getLoc()); 2336 } 2337 2338 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { 2339 OpAsmParser::Argument iterator; 2340 OpAsmParser::UnresolvedOperand iterSpace; 2341 2342 SmallVector<OpAsmParser::Argument> iters, iterArgs; 2343 if (parseSparseSpaceLoop(parser, result, iters, iterArgs)) 2344 return failure(); 2345 if (iters.size() != 1) 2346 return parser.emitError(parser.getNameLoc(), 2347 "expected only one iterator/iteration space"); 2348 2349 iters.append(iterArgs); 2350 Region *body = result.addRegion(); 2351 if (parser.parseRegion(*body, iters)) 2352 return failure(); 2353 2354 IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); 2355 2356 // Parse the optional attribute list. 2357 if (parser.parseOptionalAttrDict(result.attributes)) 2358 return failure(); 2359 2360 return success(); 2361 } 2362 2363 /// Prints the initialization list in the form of 2364 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>) 2365 /// where 'inner' values are assumed to be region arguments and 'outer' values 2366 /// are regular SSA values. 
2367 static void printInitializationList(OpAsmPrinter &p, 2368 Block::BlockArgListType blocksArgs, 2369 ValueRange initializers, 2370 StringRef prefix = "") { 2371 assert(blocksArgs.size() == initializers.size() && 2372 "expected same length of arguments and initializers"); 2373 if (initializers.empty()) 2374 return; 2375 2376 p << prefix << '('; 2377 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) { 2378 p << std::get<0>(it) << " = " << std::get<1>(it); 2379 }); 2380 p << ")"; 2381 } 2382 2383 static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim, 2384 Block::BlockArgListType blocksArgs, 2385 LevelSet crdUsedLvls) { 2386 if (crdUsedLvls.empty()) 2387 return; 2388 2389 p << " at("; 2390 for (unsigned i = 0; i < spaceDim; i++) { 2391 if (crdUsedLvls[i]) { 2392 p << blocksArgs.front(); 2393 blocksArgs = blocksArgs.drop_front(); 2394 } else { 2395 p << "_"; 2396 } 2397 if (i != spaceDim - 1) 2398 p << ", "; 2399 } 2400 assert(blocksArgs.empty()); 2401 p << ")"; 2402 } 2403 2404 void IterateOp::print(OpAsmPrinter &p) { 2405 p << " " << getIterator() << " in " << getIterSpace(); 2406 printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls()); 2407 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args"); 2408 2409 p << " : " << getIterSpace().getType() << " "; 2410 if (!getInitArgs().empty()) 2411 p << "-> (" << getInitArgs().getTypes() << ") "; 2412 2413 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 2414 /*printBlockTerminators=*/!getInitArgs().empty()); 2415 } 2416 2417 LogicalResult IterateOp::verify() { 2418 if (getInitArgs().size() != getNumResults()) { 2419 return emitOpError( 2420 "mismatch in number of loop-carried values and defined values"); 2421 } 2422 if (getCrdUsedLvls().max() > getSpaceDim()) 2423 return emitOpError("required out-of-bound coordinates"); 2424 2425 return success(); 2426 } 2427 2428 LogicalResult IterateOp::verifyRegions() { 2429 if (getIterator().getType() != getIterSpace().getType().getIteratorType()) 2430 return emitOpError("mismatch in iterator and iteration space type"); 2431 if (getNumRegionIterArgs() != getNumResults()) 2432 return emitOpError( 2433 "mismatch in number of basic block args and defined values"); 2434 2435 auto initArgs = getInitArgs(); 2436 auto iterArgs = getRegionIterArgs(); 2437 auto yieldVals = getYieldedValues(); 2438 auto opResults = getResults(); 2439 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), 2440 opResults.size()})) { 2441 return emitOpError() << "number mismatch between iter args and results."; 2442 } 2443 2444 for (auto [i, init, iter, yield, ret] : 2445 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { 2446 if (init.getType() != ret.getType()) 2447 return emitOpError() << "types mismatch between " << i 2448 << "th iter operand and defined value"; 2449 if (iter.getType() != ret.getType()) 2450 return emitOpError() << "types mismatch between " << i 2451 << "th iter region arg and defined value"; 2452 if (yield.getType() != ret.getType()) 2453 return emitOpError() << "types mismatch between " << i 2454 << "th yield value and defined value"; 2455 } 2456 2457 return success(); 2458 } 2459 2460 /// OpInterfaces' methods implemented by IterateOp. 
2461 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; } 2462 2463 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() { 2464 return getInitArgsMutable(); 2465 } 2466 2467 Block::BlockArgListType IterateOp::getRegionIterArgs() { 2468 return getRegion().getArguments().take_back(getNumRegionIterArgs()); 2469 } 2470 2471 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() { 2472 return cast<sparse_tensor::YieldOp>( 2473 getRegion().getBlocks().front().getTerminator()) 2474 .getResultsMutable(); 2475 } 2476 2477 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); } 2478 2479 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { 2480 return getInitArgs(); 2481 } 2482 2483 void IterateOp::getSuccessorRegions(RegionBranchPoint point, 2484 SmallVectorImpl<RegionSuccessor> ®ions) { 2485 // Both the operation itself and the region may be branching into the body or 2486 // back into the operation itself. 2487 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); 2488 // It is possible for loop not to enter the body. 2489 regions.push_back(RegionSuccessor(getResults())); 2490 } 2491 2492 //===----------------------------------------------------------------------===// 2493 // Sparse Tensor Dialect Setups. 2494 //===----------------------------------------------------------------------===// 2495 2496 /// Materialize a single constant operation from a given attribute value with 2497 /// the desired resultant type. 2498 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder, 2499 Attribute value, Type type, 2500 Location loc) { 2501 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc)) 2502 return op; 2503 return nullptr; 2504 } 2505 2506 namespace { 2507 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface { 2508 using OpAsmDialectInterface::OpAsmDialectInterface; 2509 2510 AliasResult getAlias(Attribute attr, raw_ostream &os) const override { 2511 if (isa<SparseTensorEncodingAttr>(attr)) { 2512 os << "sparse"; 2513 return AliasResult::OverridableAlias; 2514 } 2515 return AliasResult::NoAlias; 2516 } 2517 }; 2518 } // namespace 2519 2520 void SparseTensorDialect::initialize() { 2521 addInterface<SparseTensorAsmDialectInterface>(); 2522 addAttributes< 2523 #define GET_ATTRDEF_LIST 2524 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 2525 >(); 2526 addTypes< 2527 #define GET_TYPEDEF_LIST 2528 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" 2529 >(); 2530 addOperations< 2531 #define GET_OP_LIST 2532 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 2533 >(); 2534 declarePromisedInterfaces< 2535 bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp, 2536 NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp, 2537 ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>(); 2538 } 2539 2540 #define GET_OP_CLASSES 2541 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 2542 2543 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 2544