1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "Detail/DimLvlMapParser.h" 12 13 #include "mlir/Dialect/SparseTensor/IR/Enums.h" 14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" 16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 17 18 #include "mlir/Dialect/Arith/IR/Arith.h" 19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 20 #include "mlir/Dialect/Utils/StaticValueUtils.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/DialectImplementation.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/OpImplementation.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/FormatVariadic.h" 28 29 #define GET_ATTRDEF_CLASSES 30 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 31 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc" 32 33 // Forward declarations, following custom print/parsing methods are referenced 34 // by the generated code for SparseTensorTypes.td. 35 static mlir::ParseResult parseLevelRange(mlir::AsmParser &, 36 mlir::sparse_tensor::Level &, 37 mlir::sparse_tensor::Level &); 38 static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, 39 mlir::sparse_tensor::Level); 40 41 #define GET_TYPEDEF_CLASSES 42 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" 43 44 using namespace mlir; 45 using namespace mlir::sparse_tensor; 46 47 // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as 48 // well. 49 namespace mlir::sparse_tensor { 50 llvm::hash_code hash_value(LevelType lt) { 51 return llvm::hash_value(static_cast<uint64_t>(lt)); 52 } 53 } // namespace mlir::sparse_tensor 54 55 //===----------------------------------------------------------------------===// 56 // Local Convenience Methods. 57 //===----------------------------------------------------------------------===// 58 59 static constexpr bool acceptBitWidth(unsigned bitWidth) { 60 switch (bitWidth) { 61 case 0: 62 case 8: 63 case 16: 64 case 32: 65 case 64: 66 return true; 67 default: 68 return false; 69 } 70 } 71 72 static SmallVector<Size> 73 getSparseFieldShape(const SparseTensorEncodingAttr enc, 74 std::optional<ArrayRef<int64_t>> dimShape) { 75 assert(enc); 76 // With only encoding, we can not determine the static shape for leading 77 // batch levels, we therefore return a dynamic shape memref instead. 78 SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic); 79 if (dimShape.has_value()) { 80 // If the actual tensor shape is provided, we can then refine the leading 81 // batch dimension. 82 SmallVector<int64_t> lvlShape = 83 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl); 84 memrefShape.assign(lvlShape.begin(), 85 lvlShape.begin() + enc.getBatchLvlRank()); 86 } 87 // Another dynamic dimension to store the sparse level. 88 memrefShape.push_back(ShapedType::kDynamic); 89 return memrefShape; 90 } 91 92 //===----------------------------------------------------------------------===// 93 // SparseTensorDialect StorageLayout. 94 //===----------------------------------------------------------------------===// 95 96 static constexpr Level kInvalidLevel = -1u; 97 static constexpr Level kInvalidFieldIndex = -1u; 98 static constexpr FieldIndex kDataFieldStartingIdx = 0; 99 100 void StorageLayout::foreachField( 101 llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level, 102 LevelType)> 103 callback) const { 104 const auto lvlTypes = enc.getLvlTypes(); 105 const Level lvlRank = enc.getLvlRank(); 106 SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments(); 107 FieldIndex fieldIdx = kDataFieldStartingIdx; 108 109 ArrayRef cooSegsRef = cooSegs; 110 // Per-level storage. 111 for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) { 112 const auto lt = lvlTypes[l]; 113 if (isWithPosLT(lt)) { 114 if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt))) 115 return; 116 } 117 if (isWithCrdLT(lt)) { 118 if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt))) 119 return; 120 } 121 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) { 122 if (!cooSegsRef.front().isSoA) { 123 // AoS COO, all singletons are fused into one memrefs. Skips the entire 124 // COO segement. 125 l = cooSegsRef.front().lvlRange.second; 126 } else { 127 // SoA COO, each singleton level has one memref. 128 l++; 129 } 130 // Expire handled COO segment. 131 cooSegsRef = cooSegsRef.drop_front(); 132 } else { 133 // Non COO levels. 134 l++; 135 } 136 } 137 // The values array. 138 if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel, 139 LevelFormat::Undef))) 140 return; 141 // Put metadata at the end. 142 if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel, 143 LevelFormat::Undef))) 144 return; 145 } 146 147 void sparse_tensor::foreachFieldAndTypeInSparseTensor( 148 SparseTensorType stt, 149 llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level, 150 LevelType)> 151 callback) { 152 assert(stt.hasEncoding()); 153 154 SmallVector<int64_t> memrefShape = 155 getSparseFieldShape(stt.getEncoding(), stt.getDimShape()); 156 157 const Type specType = StorageSpecifierType::get(stt.getEncoding()); 158 // memref<[batch] x ? x pos> positions 159 const Type posMemType = MemRefType::get(memrefShape, stt.getPosType()); 160 // memref<[batch] x ? x crd> coordinates 161 const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType()); 162 // memref<[batch] x ? x eltType> values 163 const Type valMemType = MemRefType::get(memrefShape, stt.getElementType()); 164 165 StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType, 166 callback](FieldIndex fieldIdx, 167 SparseTensorFieldKind fieldKind, 168 Level lvl, LevelType lt) -> bool { 169 switch (fieldKind) { 170 case SparseTensorFieldKind::StorageSpec: 171 return callback(specType, fieldIdx, fieldKind, lvl, lt); 172 case SparseTensorFieldKind::PosMemRef: 173 return callback(posMemType, fieldIdx, fieldKind, lvl, lt); 174 case SparseTensorFieldKind::CrdMemRef: 175 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt); 176 case SparseTensorFieldKind::ValMemRef: 177 return callback(valMemType, fieldIdx, fieldKind, lvl, lt); 178 }; 179 llvm_unreachable("unrecognized field kind"); 180 }); 181 } 182 183 unsigned StorageLayout::getNumFields() const { 184 unsigned numFields = 0; 185 foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level, 186 LevelType) -> bool { 187 numFields++; 188 return true; 189 }); 190 return numFields; 191 } 192 193 unsigned StorageLayout::getNumDataFields() const { 194 unsigned numFields = 0; // one value memref 195 foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level, 196 LevelType) -> bool { 197 if (fidx >= kDataFieldStartingIdx) 198 numFields++; 199 return true; 200 }); 201 numFields -= 1; // the last field is StorageSpecifier 202 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1); 203 return numFields; 204 } 205 206 std::pair<FieldIndex, unsigned> 207 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind, 208 std::optional<Level> lvl) const { 209 FieldIndex fieldIdx = kInvalidFieldIndex; 210 unsigned stride = 1; 211 if (kind == SparseTensorFieldKind::CrdMemRef) { 212 assert(lvl.has_value()); 213 const Level cooStart = SparseTensorType(enc).getAoSCOOStart(); 214 const Level lvlRank = enc.getLvlRank(); 215 if (lvl.value() >= cooStart && lvl.value() < lvlRank) { 216 lvl = cooStart; 217 stride = lvlRank - cooStart; 218 } 219 } 220 foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx, 221 SparseTensorFieldKind fKind, Level fLvl, 222 LevelType lt) -> bool { 223 if ((lvl && fLvl == lvl.value() && kind == fKind) || 224 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { 225 fieldIdx = fIdx; 226 // Returns false to break the iteration. 227 return false; 228 } 229 return true; 230 }); 231 assert(fieldIdx != kInvalidFieldIndex); 232 return std::pair<FieldIndex, unsigned>(fieldIdx, stride); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // SparseTensorDialect Attribute Methods. 237 //===----------------------------------------------------------------------===// 238 239 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) { 240 return isDynamic(v) ? std::nullopt 241 : std::make_optional(static_cast<uint64_t>(v)); 242 } 243 244 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const { 245 return getStatic(getOffset()); 246 } 247 248 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const { 249 return getStatic(getStride()); 250 } 251 252 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const { 253 return getStatic(getSize()); 254 } 255 256 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const { 257 return isDynamic(getOffset()) && isDynamic(getStride()) && 258 isDynamic(getSize()); 259 } 260 261 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) { 262 return isDynamic(v) ? "?" : std::to_string(v); 263 } 264 265 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const { 266 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr"); 267 os << '('; 268 os << getStaticString(getOffset()); 269 os << ", "; 270 os << getStaticString(getSize()); 271 os << ", "; 272 os << getStaticString(getStride()); 273 os << ')'; 274 } 275 276 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { 277 print(printer.getStream()); 278 } 279 280 static ParseResult parseOptionalStaticSlice(int64_t &result, 281 AsmParser &parser) { 282 auto parseResult = parser.parseOptionalInteger(result); 283 if (parseResult.has_value()) { 284 if (parseResult.value().succeeded() && result < 0) { 285 parser.emitError( 286 parser.getCurrentLocation(), 287 "expect positive value or ? for slice offset/size/stride"); 288 return failure(); 289 } 290 return parseResult.value(); 291 } 292 293 // Else, and '?' which represented dynamic slice 294 result = SparseTensorDimSliceAttr::kDynamic; 295 return parser.parseQuestion(); 296 } 297 298 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { 299 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic; 300 301 if (failed(parser.parseLParen()) || 302 failed(parseOptionalStaticSlice(offset, parser)) || 303 failed(parser.parseComma()) || 304 failed(parseOptionalStaticSlice(size, parser)) || 305 failed(parser.parseComma()) || 306 failed(parseOptionalStaticSlice(stride, parser)) || 307 failed(parser.parseRParen())) 308 return {}; 309 310 return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(), 311 offset, size, stride); 312 } 313 314 LogicalResult 315 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError, 316 int64_t offset, int64_t size, int64_t stride) { 317 if (!isDynamic(offset) && offset < 0) 318 return emitError() << "expect non-negative value or ? for slice offset"; 319 if (!isDynamic(size) && size <= 0) 320 return emitError() << "expect positive value or ? for slice size"; 321 if (!isDynamic(stride) && stride <= 0) 322 return emitError() << "expect positive value or ? for slice stride"; 323 return success(); 324 } 325 326 SparseTensorEncodingAttr 327 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const { 328 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 329 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl, 330 AffineMap(), getPosWidth(), 331 getCrdWidth()); 332 } 333 334 SparseTensorEncodingAttr 335 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const { 336 return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap()); 337 } 338 339 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { 340 return withDimToLvl(AffineMap()); 341 } 342 343 SparseTensorEncodingAttr 344 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth, 345 unsigned crdWidth) const { 346 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 347 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), 348 getDimToLvl(), getLvlToDim(), posWidth, 349 crdWidth); 350 } 351 352 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { 353 return withBitWidths(0, 0); 354 } 355 356 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( 357 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 358 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), 359 getDimToLvl(), getLvlToDim(), 360 getPosWidth(), getCrdWidth(), dimSlices); 361 } 362 363 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { 364 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{}); 365 } 366 367 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { 368 ArrayRef<LevelType> lvlTypes = getLvlTypes(); 369 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 370 return std::distance(lastBatch, lvlTypes.rend()); 371 } 372 373 bool SparseTensorEncodingAttr::isAllDense() const { 374 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT); 375 } 376 377 bool SparseTensorEncodingAttr::isAllOrdered() const { 378 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT); 379 } 380 381 Type SparseTensorEncodingAttr::getCrdElemType() const { 382 if (!getImpl()) 383 return nullptr; 384 if (getCrdWidth()) 385 return IntegerType::get(getContext(), getCrdWidth()); 386 return IndexType::get(getContext()); 387 } 388 389 Type SparseTensorEncodingAttr::getPosElemType() const { 390 if (!getImpl()) 391 return nullptr; 392 if (getPosWidth()) 393 return IntegerType::get(getContext(), getPosWidth()); 394 return IndexType::get(getContext()); 395 } 396 397 MemRefType SparseTensorEncodingAttr::getCrdMemRefType( 398 std::optional<ArrayRef<int64_t>> dimShape) const { 399 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 400 return MemRefType::get(shape, getCrdElemType()); 401 } 402 403 MemRefType SparseTensorEncodingAttr::getPosMemRefType( 404 std::optional<ArrayRef<int64_t>> dimShape) const { 405 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 406 return MemRefType::get(shape, getPosElemType()); 407 } 408 409 bool SparseTensorEncodingAttr::isIdentity() const { 410 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity(); 411 } 412 413 bool SparseTensorEncodingAttr::isPermutation() const { 414 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation(); 415 } 416 417 Dimension SparseTensorEncodingAttr::getDimRank() const { 418 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 419 const auto dimToLvl = getDimToLvl(); 420 return dimToLvl ? dimToLvl.getNumDims() : getLvlRank(); 421 } 422 423 Level SparseTensorEncodingAttr::getLvlRank() const { 424 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 425 return getLvlTypes().size(); 426 } 427 428 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { 429 if (!getImpl()) 430 return LevelFormat::Batch; 431 assert(l < getLvlRank() && "Level is out of bounds"); 432 return getLvlTypes()[l]; 433 } 434 435 bool SparseTensorEncodingAttr::isSlice() const { 436 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 437 return !getDimSlices().empty(); 438 } 439 440 SparseTensorDimSliceAttr 441 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { 442 assert(isSlice() && "Is not a slice"); 443 const auto dimSlices = getDimSlices(); 444 assert(dim < dimSlices.size() && "Dimension is out of bounds"); 445 return dimSlices[dim]; 446 } 447 448 std::optional<uint64_t> 449 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { 450 return getDimSlice(dim).getStaticOffset(); 451 } 452 453 std::optional<uint64_t> 454 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { 455 return getDimSlice(dim).getStaticStride(); 456 } 457 458 std::optional<uint64_t> 459 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { 460 return getStaticDimSliceOffset(toDim(*this, lvl)); 461 } 462 463 std::optional<uint64_t> 464 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { 465 return getStaticDimSliceStride(toDim(*this, lvl)); 466 } 467 468 SmallVector<int64_t> 469 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape, 470 CrdTransDirectionKind dir) const { 471 if (isIdentity()) 472 return SmallVector<int64_t>(srcShape); 473 474 SmallVector<int64_t> ret; 475 unsigned rank = 476 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank(); 477 ret.reserve(rank); 478 479 if (isPermutation()) { 480 for (unsigned r = 0; r < rank; r++) { 481 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) 482 : toLvl(*this, r); 483 ret.push_back(srcShape[trans]); 484 } 485 return ret; 486 } 487 488 // Handle non-permutation maps. 489 AffineMap transMap = 490 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim(); 491 492 SmallVector<AffineExpr> dimRep; 493 dimRep.reserve(srcShape.size()); 494 for (int64_t sz : srcShape) { 495 if (!ShapedType::isDynamic(sz)) { 496 // Push back the max coordinate for the given dimension/level size. 497 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); 498 } else { 499 // A dynamic size, use a AffineDimExpr to symbolize the value. 500 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext())); 501 } 502 }; 503 504 for (AffineExpr exp : transMap.getResults()) { 505 // Do constant propagation on the affine map. 506 AffineExpr evalExp = 507 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0); 508 // use llvm namespace here to avoid ambiguity 509 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) { 510 ret.push_back(c.getValue() + 1); 511 } else { 512 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp); 513 mod && mod.getKind() == AffineExprKind::Mod) { 514 // We can still infer a static bound for expressions in form 515 // "d % constant" since d % constant \in [0, constant). 516 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) { 517 ret.push_back(bound.getValue()); 518 continue; 519 } 520 } 521 ret.push_back(ShapedType::kDynamic); 522 } 523 } 524 assert(ret.size() == rank); 525 return ret; 526 } 527 528 ValueRange 529 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, 530 ValueRange crds, 531 CrdTransDirectionKind dir) const { 532 if (!getImpl()) 533 return crds; 534 535 SmallVector<Type> retType( 536 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(), 537 builder.getIndexType()); 538 auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this); 539 return transOp.getOutCrds(); 540 } 541 542 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 543 // Open "<{" part. 544 if (failed(parser.parseLess())) 545 return {}; 546 if (failed(parser.parseLBrace())) 547 return {}; 548 549 // Process the data from the parsed dictionary value into struct-like data. 550 SmallVector<LevelType> lvlTypes; 551 SmallVector<SparseTensorDimSliceAttr> dimSlices; 552 AffineMap dimToLvl = {}; 553 AffineMap lvlToDim = {}; 554 unsigned posWidth = 0; 555 unsigned crdWidth = 0; 556 StringRef attrName; 557 SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"}; 558 while (succeeded(parser.parseOptionalKeyword(&attrName))) { 559 // Detect admissible keyword. 560 auto *it = find(keys, attrName); 561 if (it == keys.end()) { 562 parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; 563 return {}; 564 } 565 unsigned keyWordIndex = it - keys.begin(); 566 // Consume the `=` after keys 567 if (failed(parser.parseEqual())) 568 return {}; 569 // Dispatch on keyword. 570 switch (keyWordIndex) { 571 case 0: { // map 572 ir_detail::DimLvlMapParser cParser(parser); 573 auto res = cParser.parseDimLvlMap(); 574 if (failed(res)) 575 return {}; 576 const auto &dlm = *res; 577 578 const Level lvlRank = dlm.getLvlRank(); 579 for (Level lvl = 0; lvl < lvlRank; lvl++) 580 lvlTypes.push_back(dlm.getLvlType(lvl)); 581 582 const Dimension dimRank = dlm.getDimRank(); 583 for (Dimension dim = 0; dim < dimRank; dim++) 584 dimSlices.push_back(dlm.getDimSlice(dim)); 585 // NOTE: the old syntax requires an all-or-nothing approach to 586 // `dimSlices`; therefore, if any slice actually exists then we need 587 // to convert null-DSA into default/nop DSA. 588 const auto isDefined = [](SparseTensorDimSliceAttr slice) { 589 return static_cast<bool>(slice.getImpl()); 590 }; 591 if (llvm::any_of(dimSlices, isDefined)) { 592 const auto defaultSlice = 593 SparseTensorDimSliceAttr::get(parser.getContext()); 594 for (Dimension dim = 0; dim < dimRank; dim++) 595 if (!isDefined(dimSlices[dim])) 596 dimSlices[dim] = defaultSlice; 597 } else { 598 dimSlices.clear(); 599 } 600 601 dimToLvl = dlm.getDimToLvlMap(parser.getContext()); 602 lvlToDim = dlm.getLvlToDimMap(parser.getContext()); 603 break; 604 } 605 case 1: { // posWidth 606 Attribute attr; 607 if (failed(parser.parseAttribute(attr))) 608 return {}; 609 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 610 if (!intAttr) { 611 parser.emitError(parser.getNameLoc(), 612 "expected an integral position bitwidth"); 613 return {}; 614 } 615 posWidth = intAttr.getInt(); 616 break; 617 } 618 case 2: { // crdWidth 619 Attribute attr; 620 if (failed(parser.parseAttribute(attr))) 621 return {}; 622 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 623 if (!intAttr) { 624 parser.emitError(parser.getNameLoc(), 625 "expected an integral index bitwidth"); 626 return {}; 627 } 628 crdWidth = intAttr.getInt(); 629 break; 630 } 631 } // switch 632 // Only last item can omit the comma. 633 if (parser.parseOptionalComma().failed()) 634 break; 635 } 636 637 // Close "}>" part. 638 if (failed(parser.parseRBrace())) 639 return {}; 640 if (failed(parser.parseGreater())) 641 return {}; 642 643 // Construct struct-like storage for attribute. 644 if (!lvlToDim || lvlToDim.isEmpty()) { 645 lvlToDim = inferLvlToDim(dimToLvl, parser.getContext()); 646 } 647 return parser.getChecked<SparseTensorEncodingAttr>( 648 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, 649 dimSlices); 650 } 651 652 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 653 auto map = static_cast<AffineMap>(getDimToLvl()); 654 // Empty affine map indicates identity map 655 if (!map) 656 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext()); 657 printer << "<{ map = "; 658 printSymbols(map, printer); 659 printer << '('; 660 printDimensions(map, printer, getDimSlices()); 661 printer << ") -> ("; 662 printLevels(map, printer, getLvlTypes()); 663 printer << ')'; 664 // Print remaining members only for non-default values. 665 if (getPosWidth()) 666 printer << ", posWidth = " << getPosWidth(); 667 if (getCrdWidth()) 668 printer << ", crdWidth = " << getCrdWidth(); 669 printer << " }>"; 670 } 671 672 void SparseTensorEncodingAttr::printSymbols(AffineMap &map, 673 AsmPrinter &printer) const { 674 if (map.getNumSymbols() == 0) 675 return; 676 printer << '['; 677 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++) 678 printer << 's' << i << ", "; 679 if (map.getNumSymbols() >= 1) 680 printer << 's' << map.getNumSymbols() - 1; 681 printer << ']'; 682 } 683 684 void SparseTensorEncodingAttr::printDimensions( 685 AffineMap &map, AsmPrinter &printer, 686 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 687 if (!dimSlices.empty()) { 688 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 689 printer << 'd' << i << " : " << dimSlices[i] << ", "; 690 if (map.getNumDims() >= 1) { 691 printer << 'd' << map.getNumDims() - 1 << " : " 692 << dimSlices[map.getNumDims() - 1]; 693 } 694 } else { 695 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 696 printer << 'd' << i << ", "; 697 if (map.getNumDims() >= 1) 698 printer << 'd' << map.getNumDims() - 1; 699 } 700 } 701 702 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, 703 ArrayRef<LevelType> lvlTypes) const { 704 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { 705 map.getResult(i).print(printer.getStream()); 706 printer << " : " << toMLIRString(lvlTypes[i]) << ", "; 707 } 708 if (map.getNumResults() >= 1) { 709 auto lastIndex = map.getNumResults() - 1; 710 map.getResult(lastIndex).print(printer.getStream()); 711 printer << " : " << toMLIRString(lvlTypes[lastIndex]); 712 } 713 } 714 715 LogicalResult SparseTensorEncodingAttr::verify( 716 function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes, 717 AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth, 718 unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) { 719 if (!acceptBitWidth(posWidth)) 720 return emitError() << "unexpected position bitwidth: " << posWidth; 721 if (!acceptBitWidth(crdWidth)) 722 return emitError() << "unexpected coordinate bitwidth: " << crdWidth; 723 if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); 724 it != std::end(lvlTypes)) { 725 if (it == lvlTypes.begin() || 726 (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1)))) 727 return emitError() << "expected compressed or loose_compressed level " 728 "before singleton level"; 729 if (!std::all_of(it, lvlTypes.end(), 730 [](LevelType i) { return isSingletonLT(i); })) 731 return emitError() << "expected all singleton lvlTypes " 732 "following a singleton level"; 733 // We can potentially support mixed SoA/AoS singleton levels. 734 if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) { 735 return it->isa<LevelPropNonDefault::SoA>() == 736 i.isa<LevelPropNonDefault::SoA>(); 737 })) { 738 return emitError() << "expected all singleton lvlTypes stored in the " 739 "same memory layout (SoA vs AoS)."; 740 } 741 } 742 743 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 744 if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT)) 745 return emitError() << "Batch lvlType can only be leading levels."; 746 747 // SoA property can only be applied on singleton level. 748 auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) { 749 return lt.isa<LevelPropNonDefault::SoA>(); 750 }); 751 if (llvm::any_of(soaLvls, [](LevelType lt) { 752 return !lt.isa<LevelFormat::Singleton>(); 753 })) { 754 return emitError() << "SoA is only applicable to singleton lvlTypes."; 755 } 756 757 // TODO: audit formats that actually are supported by backend. 758 if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT); 759 it != std::end(lvlTypes)) { 760 if (it != lvlTypes.end() - 1) 761 return emitError() << "expected n_out_of_m to be the last level type"; 762 if (!std::all_of(lvlTypes.begin(), it, 763 [](LevelType i) { return isDenseLT(i); })) 764 return emitError() << "expected all dense lvlTypes " 765 "before a n_out_of_m level"; 766 if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) { 767 if (!isBlockSparsity(dimToLvl)) { 768 return emitError() 769 << "expected 1xm block structure for n_out_of_m level"; 770 } 771 auto sizes = getBlockSize(dimToLvl); 772 unsigned coefficient = 0; 773 for (const auto &elem : sizes) { 774 if (elem != 0) { 775 if (elem != coefficient && coefficient != 0) { 776 return emitError() << "expected only one blocked level " 777 "with the same coefficients"; 778 } 779 coefficient = elem; 780 } 781 } 782 if (coefficient != getM(*it)) { 783 return emitError() << "expected coeffiencts of Affine expressions " 784 "to be equal to m of n_out_of_m level"; 785 } 786 } 787 } 788 // Before we can check that the level-rank is consistent/coherent 789 // across all fields, we need to define it. The source-of-truth for 790 // the `getLvlRank` method is the length of the level-types array, 791 // since it must always be provided and have full rank; therefore we 792 // use that same source-of-truth here. 793 const Level lvlRank = lvlTypes.size(); 794 if (lvlRank == 0) 795 return emitError() << "expected a non-empty array for lvlTypes"; 796 // We save `dimRank` here because we'll also need it to verify `dimSlices`. 797 const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank; 798 if (dimToLvl) { 799 if (dimToLvl.getNumResults() != lvlRank) 800 return emitError() 801 << "level-rank mismatch between dimToLvl and lvlTypes: " 802 << dimToLvl.getNumResults() << " != " << lvlRank; 803 auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext()); 804 // Symbols can't be inferred but are acceptable. 805 if (!inferRes && dimToLvl.getNumSymbols() == 0) 806 return emitError() << "failed to infer lvlToDim from dimToLvl"; 807 if (lvlToDim && (inferRes != lvlToDim)) 808 return emitError() << "expected lvlToDim to be an inverse of dimToLvl"; 809 if (dimRank > lvlRank) 810 return emitError() << "unexpected dimToLvl mapping from " << dimRank 811 << " to " << lvlRank; 812 } 813 if (!dimSlices.empty()) { 814 if (dimSlices.size() != dimRank) 815 return emitError() 816 << "dimension-rank mismatch between dimSlices and dimToLvl: " 817 << dimSlices.size() << " != " << dimRank; 818 // Compiler support for `dimSlices` currently requires that the two 819 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.) 820 if (dimRank != lvlRank) 821 return emitError() 822 << "dimSlices expected dimension-rank to match level-rank: " 823 << dimRank << " != " << lvlRank; 824 } 825 return success(); 826 } 827 828 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 829 ArrayRef<Size> dimShape, Type elementType, 830 function_ref<InFlightDiagnostic()> emitError) const { 831 // Check structural integrity. In particular, this ensures that the 832 // level-rank is coherent across all the fields. 833 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(), 834 getPosWidth(), getCrdWidth(), getDimSlices()))) 835 return failure(); 836 // Check integrity with tensor type specifics. In particular, we 837 // need only check that the dimension-rank of the tensor agrees with 838 // the dimension-rank of the encoding. 839 const Dimension dimRank = dimShape.size(); 840 if (dimRank == 0) 841 return emitError() << "expected non-scalar sparse tensor"; 842 if (getDimRank() != dimRank) 843 return emitError() 844 << "dimension-rank mismatch between encoding and tensor shape: " 845 << getDimRank() << " != " << dimRank; 846 return success(); 847 } 848 849 //===----------------------------------------------------------------------===// 850 // SparseTensorType Methods. 851 //===----------------------------------------------------------------------===// 852 853 bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, 854 bool isUnique) const { 855 if (!hasEncoding()) 856 return false; 857 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl)) 858 return false; 859 for (Level l = startLvl + 1; l < lvlRank; ++l) 860 if (!isSingletonLvl(l)) 861 return false; 862 // If isUnique is true, then make sure that the last level is unique, 863 // that is, when lvlRank == 1, the only compressed level is unique, 864 // and when lvlRank > 1, the last singleton is unique. 865 return !isUnique || isUniqueLvl(lvlRank - 1); 866 } 867 868 Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const { 869 SmallVector<COOSegment> coo = getCOOSegments(); 870 assert(coo.size() == 1 || coo.empty()); 871 if (!coo.empty() && coo.front().isAoS()) { 872 return coo.front().lvlRange.first; 873 } 874 return lvlRank; 875 } 876 877 SmallVector<COOSegment> 878 mlir::sparse_tensor::SparseTensorType::getCOOSegments() const { 879 SmallVector<COOSegment> ret; 880 if (!hasEncoding() || lvlRank <= 1) 881 return ret; 882 883 ArrayRef<LevelType> lts = getLvlTypes(); 884 Level l = 0; 885 while (l < lvlRank) { 886 auto lt = lts[l]; 887 if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) { 888 auto cur = lts.begin() + l; 889 auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) { 890 return !lt.isa<LevelFormat::Singleton>(); 891 }); 892 unsigned cooLen = std::distance(cur, end); 893 if (cooLen > 1) { 894 // To support mixed SoA/AoS COO, we should break the segment when the 895 // storage scheme changes, for now we faithfully assume that all 896 // consecutive singleton levels have the same storage format as verified 897 // STEA. 898 ret.push_back(COOSegment{std::make_pair(l, l + cooLen), 899 lts[l + 1].isa<LevelPropNonDefault::SoA>()}); 900 } 901 l += cooLen; 902 } else { 903 l++; 904 } 905 } 906 return ret; 907 } 908 909 RankedTensorType 910 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { 911 SmallVector<LevelType> lvlTypes; 912 lvlTypes.reserve(lvlRank); 913 // A non-unique compressed level at beginning (unless this is 914 // also the last level, then it is unique). 915 lvlTypes.push_back( 916 *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); 917 if (lvlRank > 1) { 918 // Followed by n-2 non-unique singleton levels. 919 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, 920 *buildLevelType(LevelFormat::Singleton, ordered, false)); 921 // Ends by a unique singleton level. 922 lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true)); 923 } 924 auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes, 925 getDimToLvl(), getLvlToDim(), 926 getPosWidth(), getCrdWidth()); 927 return RankedTensorType::get(getDimShape(), getElementType(), enc); 928 } 929 930 //===----------------------------------------------------------------------===// 931 // Convenience Methods. 932 //===----------------------------------------------------------------------===// 933 934 SparseTensorEncodingAttr 935 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 936 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type)) 937 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding()); 938 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type)) 939 return mdtp.getEncoding(); 940 return nullptr; 941 } 942 943 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl, 944 MLIRContext *context) { 945 auto map = static_cast<AffineMap>(dimToLvl); 946 AffineMap lvlToDim; 947 // Return an empty lvlToDim when inference is not successful. 948 if (!map || map.getNumSymbols() != 0) { 949 lvlToDim = AffineMap(); 950 } else if (map.isPermutation()) { 951 lvlToDim = inversePermutation(map); 952 } else if (isBlockSparsity(map)) { 953 lvlToDim = inverseBlockSparsity(map, context); 954 } 955 return lvlToDim; 956 } 957 958 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl, 959 MLIRContext *context) { 960 SmallVector<AffineExpr> lvlExprs; 961 auto numLvls = dimToLvl.getNumResults(); 962 lvlExprs.reserve(numLvls); 963 // lvlExprComponents stores information of the floordiv and mod operations 964 // applied to the same dimension, so as to build the lvlToDim map. 965 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents; 966 for (unsigned i = 0, n = numLvls; i < n; i++) { 967 auto result = dimToLvl.getResult(i); 968 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 969 if (result.getKind() == AffineExprKind::FloorDiv) { 970 // Position of the dimension in dimToLvl. 971 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 972 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() && 973 "expected only one floordiv for each dimension"); 974 SmallVector<AffineExpr, 3> components; 975 // Level variable for floordiv. 976 components.push_back(getAffineDimExpr(i, context)); 977 // Multiplier. 978 components.push_back(binOp.getRHS()); 979 // Map key is the position of the dimension. 980 lvlExprComponents[pos] = components; 981 } else if (result.getKind() == AffineExprKind::Mod) { 982 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 983 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() && 984 "expected floordiv before mod"); 985 // Add level variable for mod to the same vector 986 // of the corresponding floordiv. 987 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context)); 988 } else { 989 assert(false && "expected floordiv or mod"); 990 } 991 } else { 992 lvlExprs.push_back(getAffineDimExpr(i, context)); 993 } 994 } 995 // Build lvlExprs from lvlExprComponents. 996 // For example, for il = i floordiv 2 and ii = i mod 2, the components 997 // would be [il, 2, ii]. It could be used to build the AffineExpr 998 // i = il * 2 + ii in lvlToDim. 999 for (auto &components : lvlExprComponents) { 1000 assert(components.second.size() == 3 && 1001 "expected 3 components to build lvlExprs"); 1002 auto mulOp = getAffineBinaryOpExpr( 1003 AffineExprKind::Mul, components.second[0], components.second[1]); 1004 auto addOp = 1005 getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]); 1006 lvlExprs.push_back(addOp); 1007 } 1008 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context); 1009 } 1010 1011 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { 1012 assert(isBlockSparsity(dimToLvl) && 1013 "expected dimToLvl to be block sparsity for calling getBlockSize"); 1014 SmallVector<unsigned> blockSize; 1015 for (auto result : dimToLvl.getResults()) { 1016 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1017 if (result.getKind() == AffineExprKind::Mod) { 1018 blockSize.push_back( 1019 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue()); 1020 } 1021 } else { 1022 blockSize.push_back(0); 1023 } 1024 } 1025 return blockSize; 1026 } 1027 1028 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { 1029 if (!dimToLvl) 1030 return false; 1031 std::map<unsigned, int64_t> coeffientMap; 1032 bool hasBlock = false; 1033 for (auto result : dimToLvl.getResults()) { 1034 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1035 // Check for "dim op const". 1036 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS()); 1037 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS()); 1038 if (!dimOp || !conOp || conOp.getValue() <= 0) 1039 return false; 1040 // Inspect "dim / const" or "dim % const". 1041 auto pos = dimOp.getPosition(); 1042 if (binOp.getKind() == AffineExprKind::FloorDiv) { 1043 // Expect only one floordiv for each dimension. 1044 if (coeffientMap.find(pos) != coeffientMap.end()) 1045 return false; 1046 // Record coefficient of the floordiv. 1047 coeffientMap[pos] = conOp.getValue(); 1048 } else if (binOp.getKind() == AffineExprKind::Mod) { 1049 // Expect floordiv before mod. 1050 if (coeffientMap.find(pos) == coeffientMap.end()) 1051 return false; 1052 // Expect mod to have the same coefficient as floordiv. 1053 if (conOp.getValue() != coeffientMap[pos]) 1054 return false; 1055 hasBlock = true; 1056 } else { 1057 return false; 1058 } 1059 } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) { 1060 auto pos = dimOp.getPosition(); 1061 // Expect dim to be unset. 1062 if (coeffientMap.find(pos) != coeffientMap.end()) 1063 return false; 1064 coeffientMap[pos] = 0; 1065 } else { 1066 return false; 1067 } 1068 } 1069 return hasBlock; 1070 } 1071 1072 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) { 1073 auto hasNonIdentityMap = [](Value v) { 1074 auto stt = tryGetSparseTensorType(v); 1075 return stt && !stt->isIdentity(); 1076 }; 1077 1078 return llvm::any_of(op->getOperands(), hasNonIdentityMap) || 1079 llvm::any_of(op->getResults(), hasNonIdentityMap); 1080 } 1081 1082 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { 1083 if (enc) { 1084 assert(enc.isPermutation() && "Non permutation map not supported"); 1085 if (const auto dimToLvl = enc.getDimToLvl()) 1086 return dimToLvl.getDimPosition(l); 1087 } 1088 return l; 1089 } 1090 1091 Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { 1092 if (enc) { 1093 assert(enc.isPermutation() && "Non permutation map not supported"); 1094 if (const auto lvlToDim = enc.getLvlToDim()) 1095 return lvlToDim.getDimPosition(d); 1096 } 1097 return d; 1098 } 1099 1100 /// We normalized sparse tensor encoding attribute by always using 1101 /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well 1102 /// as other variants) lead to the same storage specifier type, and stripping 1103 /// irrelevant fields that do not alter the sparse tensor memory layout. 1104 static SparseTensorEncodingAttr 1105 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { 1106 SmallVector<LevelType> lts; 1107 for (auto lt : enc.getLvlTypes()) 1108 lts.push_back(lt.stripStorageIrrelevantProperties()); 1109 1110 return SparseTensorEncodingAttr::get( 1111 enc.getContext(), lts, 1112 AffineMap(), // dimToLvl (irrelevant to storage specifier) 1113 AffineMap(), // lvlToDim (irrelevant to storage specifier) 1114 // Always use `index` for memSize and lvlSize instead of reusing 1115 // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA 1116 // value for different bitwidth, it also avoids casting between index and 1117 // integer (returned by DimOp) 1118 0, 0, enc.getDimSlices()); 1119 } 1120 1121 StorageSpecifierType 1122 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) { 1123 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding)); 1124 } 1125 1126 //===----------------------------------------------------------------------===// 1127 // SparseTensorDialect Operations. 1128 //===----------------------------------------------------------------------===// 1129 1130 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) { 1131 return success(lvl < getSparseTensorType(tensor).getLvlRank()); 1132 } 1133 1134 static LogicalResult isMatchingWidth(Value mem, unsigned width) { 1135 const Type etp = getMemRefType(mem).getElementType(); 1136 return success(width == 0 ? etp.isIndex() : etp.isInteger(width)); 1137 } 1138 1139 static LogicalResult verifySparsifierGetterSetter( 1140 StorageSpecifierKind mdKind, std::optional<Level> lvl, 1141 TypedValue<StorageSpecifierType> md, Operation *op) { 1142 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) { 1143 return op->emitError( 1144 "redundant level argument for querying value memory size"); 1145 } 1146 1147 const auto enc = md.getType().getEncoding(); 1148 const Level lvlRank = enc.getLvlRank(); 1149 1150 if (mdKind == StorageSpecifierKind::DimOffset || 1151 mdKind == StorageSpecifierKind::DimStride) 1152 if (!enc.isSlice()) 1153 return op->emitError("requested slice data on non-slice tensor"); 1154 1155 if (mdKind != StorageSpecifierKind::ValMemSize) { 1156 if (!lvl) 1157 return op->emitError("missing level argument"); 1158 1159 const Level l = lvl.value(); 1160 if (l >= lvlRank) 1161 return op->emitError("requested level is out of bounds"); 1162 1163 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l)) 1164 return op->emitError( 1165 "requested position memory size on a singleton level"); 1166 } 1167 return success(); 1168 } 1169 1170 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) { 1171 switch (kind) { 1172 case SparseTensorFieldKind::CrdMemRef: 1173 return stt.getCrdType(); 1174 case SparseTensorFieldKind::PosMemRef: 1175 return stt.getPosType(); 1176 case SparseTensorFieldKind::ValMemRef: 1177 return stt.getElementType(); 1178 case SparseTensorFieldKind::StorageSpec: 1179 return nullptr; 1180 } 1181 llvm_unreachable("Unrecognizable FieldKind"); 1182 } 1183 1184 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, 1185 SparseTensorType stt, 1186 RankedTensorType valTp, 1187 TypeRange lvlTps) { 1188 if (requiresStaticShape && !stt.hasStaticDimShape()) 1189 return op->emitError("the sparse-tensor must have static shape"); 1190 if (!stt.hasEncoding()) 1191 return op->emitError("the sparse-tensor must have an encoding attribute"); 1192 1193 // Verifies the trailing COO. 1194 Level cooStartLvl = stt.getAoSCOOStart(); 1195 if (cooStartLvl < stt.getLvlRank()) { 1196 // We only supports trailing COO for now, must be the last input. 1197 auto cooTp = llvm::cast<ShapedType>(lvlTps.back()); 1198 // The coordinates should be in shape of <? x rank> 1199 unsigned expCOORank = stt.getLvlRank() - cooStartLvl; 1200 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) { 1201 op->emitError("input/output trailing COO level-ranks don't match"); 1202 } 1203 } 1204 1205 // Verifies that all types match. 1206 StorageLayout layout(stt.getEncoding()); 1207 if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref 1208 return op->emitError("inconsistent number of fields between input/output"); 1209 1210 unsigned idx = 0; 1211 bool misMatch = false; 1212 layout.foreachField([&idx, &misMatch, stt, valTp, 1213 lvlTps](FieldIndex fid, SparseTensorFieldKind fKind, 1214 Level lvl, LevelType lt) -> bool { 1215 if (fKind == SparseTensorFieldKind::StorageSpec) 1216 return true; 1217 1218 Type inputTp = nullptr; 1219 if (fKind == SparseTensorFieldKind::ValMemRef) { 1220 inputTp = valTp; 1221 } else { 1222 assert(fid == idx && stt.getLvlType(lvl) == lt); 1223 inputTp = lvlTps[idx++]; 1224 } 1225 // The input element type and expected element type should match. 1226 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType(); 1227 Type expElemTp = getFieldElemType(stt, fKind); 1228 if (inpElemTp != expElemTp) { 1229 misMatch = true; 1230 return false; // to terminate the iteration 1231 } 1232 return true; 1233 }); 1234 1235 if (misMatch) 1236 return op->emitError("input/output element-types don't match"); 1237 return success(); 1238 } 1239 1240 LogicalResult AssembleOp::verify() { 1241 const auto valuesTp = getRankedTensorType(getValues()); 1242 const auto lvlsTp = getLevels().getTypes(); 1243 const auto resTp = getSparseTensorType(getResult()); 1244 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp); 1245 } 1246 1247 LogicalResult DisassembleOp::verify() { 1248 if (getOutValues().getType() != getRetValues().getType()) 1249 return emitError("output values and return value type mismatch"); 1250 1251 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels())) 1252 if (ot.getType() != rt.getType()) 1253 return emitError("output levels and return levels type mismatch"); 1254 1255 const auto valuesTp = getRankedTensorType(getRetValues()); 1256 const auto lvlsTp = getRetLevels().getTypes(); 1257 const auto srcTp = getSparseTensorType(getTensor()); 1258 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp); 1259 } 1260 1261 LogicalResult ConvertOp::verify() { 1262 if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) { 1263 if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) { 1264 if (tp1.getRank() != tp2.getRank()) 1265 return emitError("unexpected conversion mismatch in rank"); 1266 auto dstEnc = 1267 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding()); 1268 if (dstEnc && dstEnc.isSlice()) 1269 return emitError("cannot convert to a sparse tensor slice"); 1270 1271 auto shape1 = tp1.getShape(); 1272 auto shape2 = tp2.getShape(); 1273 // Accept size matches between the source and the destination type 1274 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or 1275 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 1276 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) 1277 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) 1278 return emitError("unexpected conversion mismatch in dimension ") << d; 1279 return success(); 1280 } 1281 } 1282 return emitError("unexpected type in convert"); 1283 } 1284 1285 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { 1286 if (getType() == getSource().getType()) 1287 return getSource(); 1288 return {}; 1289 } 1290 1291 bool ConvertOp::needsExtraSort() { 1292 SparseTensorType srcStt = getSparseTensorType(getSource()); 1293 SparseTensorType dstStt = getSparseTensorType(getDest()); 1294 1295 // We do not need an extra sort when returning unordered sparse tensors or 1296 // dense tensor since dense tensor support random access. 1297 if (dstStt.isAllDense() || !dstStt.isAllOrdered()) 1298 return false; 1299 1300 if (srcStt.isAllOrdered() && dstStt.isAllOrdered() && 1301 srcStt.hasSameDimToLvl(dstStt)) { 1302 return false; 1303 } 1304 1305 // Source and dest tensors are ordered in different ways. We only do direct 1306 // dense to sparse conversion when the dense input is defined by a sparse 1307 // constant. Note that we can theoretically always directly convert from dense 1308 // inputs by rotating dense loops but it leads to bad cache locality and hurt 1309 // performance. 1310 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>()) 1311 if (isa<SparseElementsAttr>(constOp.getValue())) 1312 return false; 1313 1314 return true; 1315 } 1316 1317 LogicalResult CrdTranslateOp::verify() { 1318 uint64_t inRank = getEncoder().getLvlRank(); 1319 uint64_t outRank = getEncoder().getDimRank(); 1320 1321 if (getDirection() == CrdTransDirectionKind::dim2lvl) 1322 std::swap(inRank, outRank); 1323 1324 if (inRank != getInCrds().size() || outRank != getOutCrds().size()) 1325 return emitError("Coordinate rank mismatch with encoding"); 1326 1327 return success(); 1328 } 1329 1330 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor, 1331 SmallVectorImpl<OpFoldResult> &results) { 1332 if (getEncoder().isIdentity()) { 1333 results.assign(getInCrds().begin(), getInCrds().end()); 1334 return success(); 1335 } 1336 if (getEncoder().isPermutation()) { 1337 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl 1338 ? getEncoder().getDimToLvl() 1339 : getEncoder().getLvlToDim(); 1340 for (AffineExpr exp : perm.getResults()) 1341 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]); 1342 return success(); 1343 } 1344 1345 // Fuse dim2lvl/lvl2dim pairs. 1346 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>(); 1347 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) { 1348 return v.getDefiningOp() == def; 1349 }); 1350 if (!sameDef) 1351 return failure(); 1352 1353 bool oppositeDir = def.getDirection() != getDirection(); 1354 bool sameOracle = 1355 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl(); 1356 bool sameCount = def.getNumResults() == getInCrds().size(); 1357 if (!oppositeDir || !sameOracle || !sameCount) 1358 return failure(); 1359 1360 // The definition produces the coordinates in the same order as the input 1361 // coordinates. 1362 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()), 1363 [](auto valuePair) { 1364 auto [lhs, rhs] = valuePair; 1365 return lhs == rhs; 1366 }); 1367 1368 if (!sameOrder) 1369 return failure(); 1370 // l1 = dim2lvl (lvl2dim l0) 1371 // ==> l0 1372 results.append(def.getInCrds().begin(), def.getInCrds().end()); 1373 return success(); 1374 } 1375 1376 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source, 1377 int64_t index) { 1378 Value val = builder.create<arith::ConstantIndexOp>(state.location, index); 1379 return build(builder, state, source, val); 1380 } 1381 1382 LogicalResult LvlOp::verify() { 1383 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) { 1384 auto stt = getSparseTensorType(getSource()); 1385 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank()) 1386 emitError("Level index exceeds the rank of the input sparse tensor"); 1387 } 1388 return success(); 1389 } 1390 1391 std::optional<uint64_t> LvlOp::getConstantLvlIndex() { 1392 return getConstantIntValue(getIndex()); 1393 } 1394 1395 Speculation::Speculatability LvlOp::getSpeculatability() { 1396 auto constantIndex = getConstantLvlIndex(); 1397 if (!constantIndex) 1398 return Speculation::NotSpeculatable; 1399 1400 assert(constantIndex < 1401 cast<RankedTensorType>(getSource().getType()).getRank()); 1402 return Speculation::Speculatable; 1403 } 1404 1405 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) { 1406 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex()); 1407 if (!lvlIndex) 1408 return {}; 1409 1410 Level lvl = lvlIndex.getAPSInt().getZExtValue(); 1411 auto stt = getSparseTensorType(getSource()); 1412 if (lvl >= stt.getLvlRank()) { 1413 // Follows the same convention used by tensor.dim operation. Out of bound 1414 // indices produce undefined behavior but are still valid IR. Don't choke on 1415 // them. 1416 return {}; 1417 } 1418 1419 // Helper lambda to build an IndexAttr. 1420 auto getIndexAttr = [this](int64_t lvlSz) { 1421 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz)); 1422 }; 1423 1424 SmallVector<Size> lvlShape = stt.getLvlShape(); 1425 if (!ShapedType::isDynamic(lvlShape[lvl])) 1426 return getIndexAttr(lvlShape[lvl]); 1427 1428 return {}; 1429 } 1430 1431 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState, 1432 SparseTensorEncodingAttr dstEnc, Value source) { 1433 auto srcStt = getSparseTensorType(source); 1434 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape(); 1435 SmallVector<int64_t> dstDimShape = 1436 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim); 1437 auto dstTp = 1438 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc); 1439 return build(odsBuilder, odsState, dstTp, source); 1440 } 1441 1442 LogicalResult ReinterpretMapOp::verify() { 1443 auto srcStt = getSparseTensorType(getSource()); 1444 auto dstStt = getSparseTensorType(getDest()); 1445 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes(); 1446 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes(); 1447 1448 if (srcLvlTps.size() != dstLvlTps.size()) 1449 return emitError("Level rank mismatch between source/dest tensors"); 1450 1451 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps)) 1452 if (srcLvlTp != dstLvlTp) 1453 return emitError("Level type mismatch between source/dest tensors"); 1454 1455 if (srcStt.getPosWidth() != dstStt.getPosWidth() || 1456 srcStt.getCrdWidth() != dstStt.getCrdWidth()) { 1457 return emitError("Crd/Pos width mismatch between source/dest tensors"); 1458 } 1459 1460 if (srcStt.getElementType() != dstStt.getElementType()) 1461 return emitError("Element type mismatch between source/dest tensors"); 1462 1463 SmallVector<Size> srcLvlShape = srcStt.getLvlShape(); 1464 SmallVector<Size> dstLvlShape = dstStt.getLvlShape(); 1465 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) { 1466 if (srcLvlSz != dstLvlSz) { 1467 // Should we allow one side to be dynamic size, e.g., <?x?> should be 1468 // compatible to <3x4>? For now, we require all the level sizes to be 1469 // *exactly* matched for simplicity. 1470 return emitError("Level size mismatch between source/dest tensors"); 1471 } 1472 } 1473 1474 return success(); 1475 } 1476 1477 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) { 1478 if (getSource().getType() == getDest().getType()) 1479 return getSource(); 1480 1481 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) { 1482 // A -> B, B -> A ==> A 1483 if (def.getSource().getType() == getDest().getType()) 1484 return def.getSource(); 1485 } 1486 return {}; 1487 } 1488 1489 template <typename ToBufferOp> 1490 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, 1491 OpaqueProperties prop, 1492 RegionRange region, 1493 SmallVectorImpl<mlir::Type> &ret) { 1494 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region); 1495 SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); 1496 Type elemTp = nullptr; 1497 bool withStride = false; 1498 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) { 1499 elemTp = stt.getPosType(); 1500 } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> || 1501 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) { 1502 elemTp = stt.getCrdType(); 1503 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>) 1504 withStride = stt.getAoSCOOStart() <= adaptor.getLevel(); 1505 } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) { 1506 elemTp = stt.getElementType(); 1507 } 1508 1509 assert(elemTp && "unhandled operation."); 1510 SmallVector<int64_t> bufShape = stt.getBatchLvlShape(); 1511 bufShape.push_back(ShapedType::kDynamic); 1512 1513 auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get( 1514 stt.getContext(), ShapedType::kDynamic, 1515 {ShapedType::kDynamic}) 1516 : StridedLayoutAttr(); 1517 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout)); 1518 return success(); 1519 } 1520 1521 LogicalResult ToPositionsOp::verify() { 1522 auto stt = getSparseTensorType(getTensor()); 1523 if (failed(lvlIsInBounds(getLevel(), getTensor()))) 1524 return emitError("requested level is out of bounds"); 1525 if (failed(isMatchingWidth(getResult(), stt.getPosWidth()))) 1526 return emitError("unexpected type for positions"); 1527 return success(); 1528 } 1529 1530 LogicalResult 1531 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 1532 ValueRange ops, DictionaryAttr attr, 1533 OpaqueProperties prop, RegionRange region, 1534 SmallVectorImpl<mlir::Type> &ret) { 1535 return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret); 1536 } 1537 1538 LogicalResult ToCoordinatesOp::verify() { 1539 auto stt = getSparseTensorType(getTensor()); 1540 if (failed(lvlIsInBounds(getLevel(), getTensor()))) 1541 return emitError("requested level is out of bounds"); 1542 if (failed(isMatchingWidth(getResult(), stt.getCrdWidth()))) 1543 return emitError("unexpected type for coordinates"); 1544 return success(); 1545 } 1546 1547 LogicalResult 1548 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 1549 ValueRange ops, DictionaryAttr attr, 1550 OpaqueProperties prop, RegionRange region, 1551 SmallVectorImpl<mlir::Type> &ret) { 1552 return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret); 1553 } 1554 1555 LogicalResult ToCoordinatesBufferOp::verify() { 1556 auto stt = getSparseTensorType(getTensor()); 1557 if (stt.getAoSCOOStart() >= stt.getLvlRank()) 1558 return emitError("expected sparse tensor with a COO region"); 1559 return success(); 1560 } 1561 1562 LogicalResult ToCoordinatesBufferOp::inferReturnTypes( 1563 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, 1564 DictionaryAttr attr, OpaqueProperties prop, RegionRange region, 1565 SmallVectorImpl<mlir::Type> &ret) { 1566 return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region, 1567 ret); 1568 } 1569 1570 LogicalResult ToValuesOp::verify() { 1571 auto stt = getSparseTensorType(getTensor()); 1572 auto mtp = getMemRefType(getResult()); 1573 if (stt.getElementType() != mtp.getElementType()) 1574 return emitError("unexpected mismatch in element types"); 1575 return success(); 1576 } 1577 1578 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx, 1579 std::optional<Location> loc, 1580 ValueRange ops, DictionaryAttr attr, 1581 OpaqueProperties prop, 1582 RegionRange region, 1583 SmallVectorImpl<mlir::Type> &ret) { 1584 return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret); 1585 } 1586 1587 LogicalResult ToSliceOffsetOp::verify() { 1588 auto rank = getRankedTensorType(getSlice()).getRank(); 1589 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) 1590 return emitError("requested dimension out of bound"); 1591 return success(); 1592 } 1593 1594 LogicalResult ToSliceStrideOp::verify() { 1595 auto rank = getRankedTensorType(getSlice()).getRank(); 1596 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) 1597 return emitError("requested dimension out of bound"); 1598 return success(); 1599 } 1600 1601 LogicalResult GetStorageSpecifierOp::verify() { 1602 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), 1603 getSpecifier(), getOperation()); 1604 } 1605 1606 template <typename SpecifierOp> 1607 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) { 1608 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>(); 1609 } 1610 1611 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { 1612 const StorageSpecifierKind kind = getSpecifierKind(); 1613 const auto lvl = getLevel(); 1614 for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op)) 1615 if (kind == op.getSpecifierKind() && lvl == op.getLevel()) 1616 return op.getValue(); 1617 return {}; 1618 } 1619 1620 LogicalResult SetStorageSpecifierOp::verify() { 1621 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), 1622 getSpecifier(), getOperation()); 1623 } 1624 1625 template <class T> 1626 static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, 1627 const char *regionName, 1628 TypeRange inputTypes, Type outputType) { 1629 unsigned numArgs = region.getNumArguments(); 1630 unsigned expectedNum = inputTypes.size(); 1631 if (numArgs != expectedNum) 1632 return op->emitError() << regionName << " region must have exactly " 1633 << expectedNum << " arguments"; 1634 1635 for (unsigned i = 0; i < numArgs; i++) { 1636 Type typ = region.getArgument(i).getType(); 1637 if (typ != inputTypes[i]) 1638 return op->emitError() << regionName << " region argument " << (i + 1) 1639 << " type mismatch"; 1640 } 1641 Operation *term = region.front().getTerminator(); 1642 YieldOp yield = dyn_cast<YieldOp>(term); 1643 if (!yield) 1644 return op->emitError() << regionName 1645 << " region must end with sparse_tensor.yield"; 1646 if (!yield.hasSingleResult() || 1647 yield.getSingleResult().getType() != outputType) 1648 return op->emitError() << regionName << " region yield type mismatch"; 1649 1650 return success(); 1651 } 1652 1653 LogicalResult BinaryOp::verify() { 1654 NamedAttrList attrs = (*this)->getAttrs(); 1655 Type leftType = getX().getType(); 1656 Type rightType = getY().getType(); 1657 Type outputType = getOutput().getType(); 1658 Region &overlap = getOverlapRegion(); 1659 Region &left = getLeftRegion(); 1660 Region &right = getRightRegion(); 1661 1662 // Check correct number of block arguments and return type for each 1663 // non-empty region. 1664 if (!overlap.empty()) { 1665 if (failed(verifyNumBlockArgs(this, overlap, "overlap", 1666 TypeRange{leftType, rightType}, outputType))) 1667 return failure(); 1668 } 1669 if (!left.empty()) { 1670 if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, 1671 outputType))) 1672 return failure(); 1673 } else if (getLeftIdentity()) { 1674 if (leftType != outputType) 1675 return emitError("left=identity requires first argument to have the same " 1676 "type as the output"); 1677 } 1678 if (!right.empty()) { 1679 if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType}, 1680 outputType))) 1681 return failure(); 1682 } else if (getRightIdentity()) { 1683 if (rightType != outputType) 1684 return emitError("right=identity requires second argument to have the " 1685 "same type as the output"); 1686 } 1687 return success(); 1688 } 1689 1690 LogicalResult UnaryOp::verify() { 1691 Type inputType = getX().getType(); 1692 Type outputType = getOutput().getType(); 1693 1694 // Check correct number of block arguments and return type for each 1695 // non-empty region. 1696 Region &present = getPresentRegion(); 1697 if (!present.empty()) { 1698 if (failed(verifyNumBlockArgs(this, present, "present", 1699 TypeRange{inputType}, outputType))) 1700 return failure(); 1701 } 1702 Region &absent = getAbsentRegion(); 1703 if (!absent.empty()) { 1704 if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{}, 1705 outputType))) 1706 return failure(); 1707 // Absent branch can only yield invariant values. 1708 Block *absentBlock = &absent.front(); 1709 Block *parent = getOperation()->getBlock(); 1710 Value absentVal = 1711 cast<YieldOp>(absentBlock->getTerminator()).getSingleResult(); 1712 if (auto arg = dyn_cast<BlockArgument>(absentVal)) { 1713 if (arg.getOwner() == parent) 1714 return emitError("absent region cannot yield linalg argument"); 1715 } else if (Operation *def = absentVal.getDefiningOp()) { 1716 if (!isa<arith::ConstantOp>(def) && 1717 (def->getBlock() == absentBlock || def->getBlock() == parent)) 1718 return emitError("absent region cannot yield locally computed value"); 1719 } 1720 } 1721 return success(); 1722 } 1723 1724 bool ConcatenateOp::needsExtraSort() { 1725 SparseTensorType dstStt = getSparseTensorType(*this); 1726 if (dstStt.isAllDense() || !dstStt.isAllOrdered()) 1727 return false; 1728 1729 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) { 1730 return getSparseTensorType(op).hasSameDimToLvl(dstStt); 1731 }); 1732 // TODO: When conDim != 0, as long as conDim corresponding to the first level 1733 // in all input/output buffers, and all input/output buffers have the same 1734 // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate 1735 // CSC matrices along column). 1736 bool directLowerable = 1737 allSameOrdered && getDimension() == 0 && dstStt.isIdentity(); 1738 return !directLowerable; 1739 } 1740 1741 LogicalResult ConcatenateOp::verify() { 1742 const auto dstTp = getSparseTensorType(*this); 1743 const Dimension concatDim = getDimension(); 1744 const Dimension dimRank = dstTp.getDimRank(); 1745 1746 if (getInputs().size() <= 1) 1747 return emitError("Need at least two tensors to concatenate."); 1748 1749 if (concatDim >= dimRank) 1750 return emitError(llvm::formatv( 1751 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})", 1752 concatDim, dimRank)); 1753 1754 for (const auto &it : llvm::enumerate(getInputs())) { 1755 const auto i = it.index(); 1756 const auto srcTp = getSparseTensorType(it.value()); 1757 if (srcTp.hasDynamicDimShape()) 1758 return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i)); 1759 const Dimension srcDimRank = srcTp.getDimRank(); 1760 if (srcDimRank != dimRank) 1761 return emitError( 1762 llvm::formatv("Input tensor ${0} has a different rank (rank={1}) " 1763 "from the output tensor (rank={2}).", 1764 i, srcDimRank, dimRank)); 1765 } 1766 1767 for (Dimension d = 0; d < dimRank; d++) { 1768 const Size dstSh = dstTp.getDimShape()[d]; 1769 if (d == concatDim) { 1770 if (!ShapedType::isDynamic(dstSh)) { 1771 // If we reach here, then all inputs have static shapes. So we 1772 // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)` 1773 // to avoid redundant assertions in the loop. 1774 Size sumSz = 0; 1775 for (const auto src : getInputs()) 1776 sumSz += getSparseTensorType(src).getDimShape()[d]; 1777 // If all dimension are statically known, the sum of all the input 1778 // dimensions should be equal to the output dimension. 1779 if (sumSz != dstSh) 1780 return emitError( 1781 "The concatenation dimension of the output tensor should be the " 1782 "sum of all the concatenation dimensions of the input tensors."); 1783 } 1784 } else { 1785 Size prev = dstSh; 1786 for (const auto src : getInputs()) { 1787 const auto sh = getSparseTensorType(src).getDimShape()[d]; 1788 if (!ShapedType::isDynamic(prev) && sh != prev) 1789 return emitError("All dimensions (expect for the concatenating one) " 1790 "should be equal."); 1791 prev = sh; 1792 } 1793 } 1794 } 1795 1796 return success(); 1797 } 1798 1799 void PushBackOp::build(OpBuilder &builder, OperationState &result, 1800 Value curSize, Value inBuffer, Value value) { 1801 build(builder, result, curSize, inBuffer, value, Value()); 1802 } 1803 1804 LogicalResult PushBackOp::verify() { 1805 if (Value n = getN()) { 1806 std::optional<int64_t> nValue = getConstantIntValue(n); 1807 if (nValue && nValue.value() < 1) 1808 return emitOpError("n must be not less than 1"); 1809 } 1810 return success(); 1811 } 1812 1813 LogicalResult CompressOp::verify() { 1814 const auto stt = getSparseTensorType(getTensor()); 1815 if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size())) 1816 return emitOpError("incorrect number of coordinates"); 1817 return success(); 1818 } 1819 1820 void ForeachOp::build( 1821 OpBuilder &builder, OperationState &result, Value tensor, 1822 ValueRange initArgs, AffineMapAttr order, 1823 function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)> 1824 bodyBuilder) { 1825 build(builder, result, initArgs.getTypes(), tensor, initArgs, order); 1826 // Builds foreach body. 1827 if (!bodyBuilder) 1828 return; 1829 const auto stt = getSparseTensorType(tensor); 1830 const Dimension dimRank = stt.getDimRank(); 1831 1832 // Starts with `dimRank`-many coordinates. 1833 SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType()); 1834 // Followed by one value. 1835 blockArgTypes.push_back(stt.getElementType()); 1836 // Followed by the reduction variables. 1837 blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end()); 1838 1839 SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc()); 1840 1841 OpBuilder::InsertionGuard guard(builder); 1842 auto ®ion = *result.regions.front(); 1843 Block *bodyBlock = 1844 builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); 1845 bodyBuilder(builder, result.location, 1846 bodyBlock->getArguments().slice(0, dimRank), 1847 bodyBlock->getArguments()[dimRank], 1848 bodyBlock->getArguments().drop_front(dimRank + 1)); 1849 } 1850 1851 LogicalResult ForeachOp::verify() { 1852 const auto t = getSparseTensorType(getTensor()); 1853 const Dimension dimRank = t.getDimRank(); 1854 const auto args = getBody()->getArguments(); 1855 1856 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank()) 1857 return emitError("Level traverse order does not match tensor's level rank"); 1858 1859 if (dimRank + 1 + getInitArgs().size() != args.size()) 1860 return emitError("Unmatched number of arguments in the block"); 1861 1862 if (getNumResults() != getInitArgs().size()) 1863 return emitError("Mismatch in number of init arguments and results"); 1864 1865 if (getResultTypes() != getInitArgs().getTypes()) 1866 return emitError("Mismatch in types of init arguments and results"); 1867 1868 // Cannot mark this const, because the getters aren't. 1869 auto yield = cast<YieldOp>(getBody()->getTerminator()); 1870 if (yield.getNumOperands() != getNumResults() || 1871 yield.getOperands().getTypes() != getResultTypes()) 1872 return emitError("Mismatch in types of yield values and results"); 1873 1874 const auto iTp = IndexType::get(getContext()); 1875 for (Dimension d = 0; d < dimRank; d++) 1876 if (args[d].getType() != iTp) 1877 emitError( 1878 llvm::formatv("Expecting Index type for argument at index {0}", d)); 1879 1880 const auto elemTp = t.getElementType(); 1881 const auto valueTp = args[dimRank].getType(); 1882 if (elemTp != valueTp) 1883 emitError(llvm::formatv("Unmatched element type between input tensor and " 1884 "block argument, expected:{0}, got: {1}", 1885 elemTp, valueTp)); 1886 return success(); 1887 } 1888 1889 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) { 1890 if (getSparseTensorEncoding(getInputCoo().getType()) == 1891 getSparseTensorEncoding(getResultCoo().getType())) 1892 return getInputCoo(); 1893 1894 return {}; 1895 } 1896 1897 LogicalResult ReorderCOOOp::verify() { 1898 SparseTensorType srcStt = getSparseTensorType(getInputCoo()); 1899 SparseTensorType dstStt = getSparseTensorType(getResultCoo()); 1900 1901 if (!srcStt.isCOOType() || !dstStt.isCOOType()) 1902 emitError("Expected COO sparse tensors only"); 1903 1904 if (!srcStt.hasSameDimToLvl(dstStt)) 1905 emitError("Unmatched dim2lvl map between input and result COO"); 1906 1907 if (srcStt.getPosType() != dstStt.getPosType() || 1908 srcStt.getCrdType() != dstStt.getCrdType() || 1909 srcStt.getElementType() != dstStt.getElementType()) 1910 emitError("Unmatched storage format between input and result COO"); 1911 1912 return success(); 1913 } 1914 1915 LogicalResult ReduceOp::verify() { 1916 Type inputType = getX().getType(); 1917 Region &formula = getRegion(); 1918 return verifyNumBlockArgs(this, formula, "reduce", 1919 TypeRange{inputType, inputType}, inputType); 1920 } 1921 1922 LogicalResult SelectOp::verify() { 1923 Builder b(getContext()); 1924 Type inputType = getX().getType(); 1925 Type boolType = b.getI1Type(); 1926 Region &formula = getRegion(); 1927 return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType}, 1928 boolType); 1929 } 1930 1931 LogicalResult SortOp::verify() { 1932 AffineMap xPerm = getPermMap(); 1933 uint64_t nx = xPerm.getNumDims(); 1934 if (nx < 1) 1935 emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx)); 1936 1937 if (!xPerm.isPermutation()) 1938 emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm)); 1939 1940 // We can't check the size of the buffers when n or buffer dimensions aren't 1941 // compile-time constants. 1942 std::optional<int64_t> cn = getConstantIntValue(getN()); 1943 if (!cn) 1944 return success(); 1945 1946 // Verify dimensions. 1947 const auto checkDim = [&](Value v, Size minSize, const char *message) { 1948 const Size sh = getMemRefType(v).getShape()[0]; 1949 if (!ShapedType::isDynamic(sh) && sh < minSize) 1950 emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize)); 1951 }; 1952 uint64_t n = cn.value(); 1953 uint64_t ny = 0; 1954 if (auto nyAttr = getNyAttr()) 1955 ny = nyAttr.getInt(); 1956 checkDim(getXy(), n * (nx + ny), 1957 "Expected dimension(xy) >= n * (rank(perm_map) + ny)"); 1958 for (Value opnd : getYs()) 1959 checkDim(opnd, n, "Expected dimension(y) >= n"); 1960 1961 return success(); 1962 } 1963 1964 //===----------------------------------------------------------------------===// 1965 // Sparse Tensor Iteration Operations. 1966 //===----------------------------------------------------------------------===// 1967 1968 IterSpaceType IteratorType::getIterSpaceType() const { 1969 return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(), 1970 getHiLvl()); 1971 } 1972 1973 IteratorType IterSpaceType::getIteratorType() const { 1974 return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl()); 1975 } 1976 1977 /// Parses a level range in the form "$lo `to` $hi" 1978 /// or simply "$lo" if $hi - $lo = 1 1979 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo, 1980 Level &lvlHi) { 1981 if (parser.parseInteger(lvlLo)) 1982 return failure(); 1983 1984 if (succeeded(parser.parseOptionalKeyword("to"))) { 1985 if (parser.parseInteger(lvlHi)) 1986 return failure(); 1987 } else { 1988 lvlHi = lvlLo + 1; 1989 } 1990 1991 if (lvlHi <= lvlLo) 1992 parser.emitError(parser.getNameLoc(), 1993 "expect larger level upper bound than lower bound"); 1994 1995 return success(); 1996 } 1997 1998 /// Parses a level range in the form "$lo `to` $hi" 1999 /// or simply "$lo" if $hi - $lo = 1 2000 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr, 2001 IntegerAttr &lvlHiAttr) { 2002 Level lvlLo, lvlHi; 2003 if (parseLevelRange(parser, lvlLo, lvlHi)) 2004 return failure(); 2005 2006 lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo); 2007 lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi); 2008 return success(); 2009 } 2010 2011 /// Prints a level range in the form "$lo `to` $hi" 2012 /// or simply "$lo" if $hi - $lo = 1 2013 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) { 2014 2015 if (lo + 1 == hi) 2016 p << lo; 2017 else 2018 p << lo << " to " << hi; 2019 } 2020 2021 /// Prints a level range in the form "$lo `to` $hi" 2022 /// or simply "$lo" if $hi - $lo = 1 2023 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, 2024 IntegerAttr lvlHi) { 2025 unsigned lo = lvlLo.getValue().getZExtValue(); 2026 unsigned hi = lvlHi.getValue().getZExtValue(); 2027 printLevelRange(p, lo, hi); 2028 } 2029 2030 static ParseResult 2031 parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state, 2032 SmallVectorImpl<OpAsmParser::Argument> &iterators, 2033 SmallVectorImpl<OpAsmParser::Argument> &iterArgs) { 2034 SmallVector<OpAsmParser::UnresolvedOperand> spaces; 2035 SmallVector<OpAsmParser::UnresolvedOperand> initArgs; 2036 2037 // Parses "%iters, ... in %spaces, ..." 2038 if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") || 2039 parser.parseOperandList(spaces)) 2040 return failure(); 2041 2042 if (iterators.size() != spaces.size()) 2043 return parser.emitError( 2044 parser.getNameLoc(), 2045 "mismatch in number of sparse iterators and sparse spaces"); 2046 2047 // Parse "at(%crd0, _, ...)" 2048 LevelSet crdUsedLvlSet; 2049 bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at")); 2050 unsigned lvlCrdCnt = 0; 2051 if (hasUsedCrds) { 2052 ParseResult crdList = parser.parseCommaSeparatedList( 2053 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 2054 if (parser.parseOptionalKeyword("_")) { 2055 if (parser.parseArgument(iterArgs.emplace_back())) 2056 return failure(); 2057 // Always use IndexType for the coordinate. 2058 crdUsedLvlSet.set(lvlCrdCnt); 2059 iterArgs.back().type = parser.getBuilder().getIndexType(); 2060 } 2061 lvlCrdCnt += 1; 2062 return success(); 2063 }); 2064 if (failed(crdList)) { 2065 return parser.emitError( 2066 parser.getNameLoc(), 2067 "expecting SSA value or \"_\" for level coordinates"); 2068 } 2069 } 2070 // Set the CrdUsedLvl bitset. 2071 state.addAttribute("crdUsedLvls", 2072 parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet)); 2073 2074 // Parse "iter_args(%arg = %init, ...)" 2075 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); 2076 if (hasIterArgs) 2077 if (parser.parseAssignmentList(iterArgs, initArgs)) 2078 return failure(); 2079 2080 SmallVector<Type> iterSpaceTps; 2081 // parse ": sparse_tensor.iter_space -> ret" 2082 if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) 2083 return failure(); 2084 if (iterSpaceTps.size() != spaces.size()) 2085 return parser.emitError(parser.getNameLoc(), 2086 "mismatch in number of iteration space operands " 2087 "and iteration space types"); 2088 2089 for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) { 2090 IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp); 2091 if (!spaceTp) 2092 return parser.emitError(parser.getNameLoc(), 2093 "expected sparse_tensor.iter_space type for " 2094 "iteration space operands"); 2095 if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt) 2096 return parser.emitError(parser.getNameLoc(), 2097 "mismatch in number of iteration space dimension " 2098 "and specified coordinates"); 2099 it.type = spaceTp.getIteratorType(); 2100 } 2101 2102 if (hasIterArgs) 2103 if (parser.parseArrowTypeList(state.types)) 2104 return failure(); 2105 2106 // Resolves input operands. 2107 if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), 2108 state.operands)) 2109 return failure(); 2110 2111 if (hasIterArgs) { 2112 unsigned numCrds = crdUsedLvlSet.count(); 2113 // Strip off leading args that used for coordinates. 2114 MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds); 2115 if (args.size() != initArgs.size() || args.size() != state.types.size()) { 2116 return parser.emitError( 2117 parser.getNameLoc(), 2118 "mismatch in number of iteration arguments and return values"); 2119 } 2120 2121 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { 2122 it.type = tp; 2123 if (parser.resolveOperand(init, tp, state.operands)) 2124 return failure(); 2125 } 2126 } 2127 return success(); 2128 } 2129 2130 LogicalResult ExtractIterSpaceOp::inferReturnTypes( 2131 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, 2132 DictionaryAttr attr, OpaqueProperties prop, RegionRange region, 2133 SmallVectorImpl<mlir::Type> &ret) { 2134 2135 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region); 2136 SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); 2137 ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(), 2138 adaptor.getHiLvl())); 2139 return success(); 2140 } 2141 2142 LogicalResult ExtractIterSpaceOp::verify() { 2143 if (getLoLvl() >= getHiLvl()) 2144 return emitOpError("expected smaller level low than level high"); 2145 2146 TypedValue<IteratorType> pIter = getParentIter(); 2147 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) { 2148 return emitOpError( 2149 "parent iterator should be specified iff level lower bound equals 0"); 2150 } 2151 2152 if (pIter) { 2153 IterSpaceType spaceTp = getResultSpace().getType(); 2154 if (pIter.getType().getEncoding() != spaceTp.getEncoding()) 2155 return emitOpError( 2156 "mismatch in parent iterator encoding and iteration space encoding."); 2157 2158 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl()) 2159 return emitOpError("parent iterator should be used to extract an " 2160 "iteration space from a consecutive level."); 2161 } 2162 2163 return success(); 2164 } 2165 2166 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { 2167 OpAsmParser::Argument iterator; 2168 OpAsmParser::UnresolvedOperand iterSpace; 2169 2170 SmallVector<OpAsmParser::Argument> iters, iterArgs; 2171 if (parseSparseSpaceLoop(parser, result, iters, iterArgs)) 2172 return failure(); 2173 if (iters.size() != 1) 2174 return parser.emitError(parser.getNameLoc(), 2175 "expected only one iterator/iteration space"); 2176 2177 iters.append(iterArgs); 2178 Region *body = result.addRegion(); 2179 if (parser.parseRegion(*body, iters)) 2180 return failure(); 2181 2182 IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); 2183 2184 // Parse the optional attribute list. 2185 if (parser.parseOptionalAttrDict(result.attributes)) 2186 return failure(); 2187 2188 return success(); 2189 } 2190 2191 /// Prints the initialization list in the form of 2192 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>) 2193 /// where 'inner' values are assumed to be region arguments and 'outer' values 2194 /// are regular SSA values. 2195 static void printInitializationList(OpAsmPrinter &p, 2196 Block::BlockArgListType blocksArgs, 2197 ValueRange initializers, 2198 StringRef prefix = "") { 2199 assert(blocksArgs.size() == initializers.size() && 2200 "expected same length of arguments and initializers"); 2201 if (initializers.empty()) 2202 return; 2203 2204 p << prefix << '('; 2205 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) { 2206 p << std::get<0>(it) << " = " << std::get<1>(it); 2207 }); 2208 p << ")"; 2209 } 2210 2211 static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim, 2212 Block::BlockArgListType blocksArgs, 2213 LevelSet crdUsedLvls) { 2214 if (crdUsedLvls.empty()) 2215 return; 2216 2217 p << " at("; 2218 for (unsigned i = 0; i < spaceDim; i++) { 2219 if (crdUsedLvls[i]) { 2220 p << blocksArgs.front(); 2221 blocksArgs = blocksArgs.drop_front(); 2222 } else { 2223 p << "_"; 2224 } 2225 if (i != spaceDim - 1) 2226 p << ", "; 2227 } 2228 assert(blocksArgs.empty()); 2229 p << ")"; 2230 } 2231 2232 void IterateOp::print(OpAsmPrinter &p) { 2233 p << " " << getIterator() << " in " << getIterSpace(); 2234 printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls()); 2235 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args"); 2236 2237 p << " : " << getIterSpace().getType() << " "; 2238 if (!getInitArgs().empty()) 2239 p << "-> (" << getInitArgs().getTypes() << ") "; 2240 2241 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 2242 /*printBlockTerminators=*/!getInitArgs().empty()); 2243 } 2244 2245 LogicalResult IterateOp::verify() { 2246 if (getInitArgs().size() != getNumResults()) { 2247 return emitOpError( 2248 "mismatch in number of loop-carried values and defined values"); 2249 } 2250 return success(); 2251 } 2252 2253 LogicalResult IterateOp::verifyRegions() { 2254 if (getIterator().getType() != getIterSpace().getType().getIteratorType()) 2255 return emitOpError("mismatch in iterator and iteration space type"); 2256 if (getNumRegionIterArgs() != getNumResults()) 2257 return emitOpError( 2258 "mismatch in number of basic block args and defined values"); 2259 2260 auto initArgs = getInitArgs(); 2261 auto iterArgs = getRegionIterArgs(); 2262 auto yieldVals = getYieldedValues(); 2263 auto opResults = getResults(); 2264 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), 2265 opResults.size()})) { 2266 return emitOpError() << "number mismatch between iter args and results."; 2267 } 2268 2269 for (auto [i, init, iter, yield, ret] : 2270 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { 2271 if (init.getType() != ret.getType()) 2272 return emitOpError() << "types mismatch between " << i 2273 << "th iter operand and defined value"; 2274 if (iter.getType() != ret.getType()) 2275 return emitOpError() << "types mismatch between " << i 2276 << "th iter region arg and defined value"; 2277 if (yield.getType() != ret.getType()) 2278 return emitOpError() << "types mismatch between " << i 2279 << "th yield value and defined value"; 2280 } 2281 2282 return success(); 2283 } 2284 2285 /// IterateOp implemented OpInterfaces' methods. 2286 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; } 2287 2288 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() { 2289 return getInitArgsMutable(); 2290 } 2291 2292 Block::BlockArgListType IterateOp::getRegionIterArgs() { 2293 return getRegion().getArguments().take_back(getNumRegionIterArgs()); 2294 } 2295 2296 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() { 2297 return cast<sparse_tensor::YieldOp>( 2298 getRegion().getBlocks().front().getTerminator()) 2299 .getResultsMutable(); 2300 } 2301 2302 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); } 2303 2304 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { 2305 return getInitArgs(); 2306 } 2307 2308 void IterateOp::getSuccessorRegions(RegionBranchPoint point, 2309 SmallVectorImpl<RegionSuccessor> ®ions) { 2310 // Both the operation itself and the region may be branching into the body or 2311 // back into the operation itself. 2312 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); 2313 // It is possible for loop not to enter the body. 2314 regions.push_back(RegionSuccessor(getResults())); 2315 } 2316 2317 //===----------------------------------------------------------------------===// 2318 // Sparse Tensor Dialect Setups. 2319 //===----------------------------------------------------------------------===// 2320 2321 /// Materialize a single constant operation from a given attribute value with 2322 /// the desired resultant type. 2323 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder, 2324 Attribute value, Type type, 2325 Location loc) { 2326 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc)) 2327 return op; 2328 return nullptr; 2329 } 2330 2331 namespace { 2332 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface { 2333 using OpAsmDialectInterface::OpAsmDialectInterface; 2334 2335 AliasResult getAlias(Attribute attr, raw_ostream &os) const override { 2336 if (attr.isa<SparseTensorEncodingAttr>()) { 2337 os << "sparse"; 2338 return AliasResult::OverridableAlias; 2339 } 2340 return AliasResult::NoAlias; 2341 } 2342 }; 2343 } // namespace 2344 2345 void SparseTensorDialect::initialize() { 2346 addInterface<SparseTensorAsmDialectInterface>(); 2347 addAttributes< 2348 #define GET_ATTRDEF_LIST 2349 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 2350 >(); 2351 addTypes< 2352 #define GET_TYPEDEF_LIST 2353 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" 2354 >(); 2355 addOperations< 2356 #define GET_OP_LIST 2357 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 2358 >(); 2359 declarePromisedInterfaces< 2360 bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp, 2361 NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp, 2362 ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>(); 2363 } 2364 2365 #define GET_OP_CLASSES 2366 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 2367 2368 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 2369