1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "Detail/DimLvlMapParser.h" 12 13 #include "mlir/Dialect/SparseTensor/IR/Enums.h" 14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" 16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" 17 18 #include "mlir/Dialect/Arith/IR/Arith.h" 19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 20 #include "mlir/Dialect/Complex/IR/Complex.h" 21 #include "mlir/Dialect/Utils/StaticValueUtils.h" 22 #include "mlir/IR/Builders.h" 23 #include "mlir/IR/DialectImplementation.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "llvm/ADT/Bitset.h" 28 #include "llvm/ADT/TypeSwitch.h" 29 #include "llvm/Support/FormatVariadic.h" 30 31 #define GET_ATTRDEF_CLASSES 32 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 33 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc" 34 35 // Forward declarations, following custom print/parsing methods are referenced 36 // by the generated code for SparseTensorTypes.td. 37 static mlir::ParseResult parseLevelRange(mlir::AsmParser &, 38 mlir::sparse_tensor::Level &, 39 mlir::sparse_tensor::Level &); 40 static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, 41 mlir::sparse_tensor::Level); 42 43 #define GET_TYPEDEF_CLASSES 44 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" 45 46 using namespace mlir; 47 using namespace mlir::sparse_tensor; 48 49 // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as 50 // well. 51 namespace mlir::sparse_tensor { 52 llvm::hash_code hash_value(LevelType lt) { 53 return llvm::hash_value(static_cast<uint64_t>(lt)); 54 } 55 } // namespace mlir::sparse_tensor 56 57 //===----------------------------------------------------------------------===// 58 // Local Convenience Methods. 59 //===----------------------------------------------------------------------===// 60 61 static constexpr bool acceptBitWidth(unsigned bitWidth) { 62 switch (bitWidth) { 63 case 0: 64 case 8: 65 case 16: 66 case 32: 67 case 64: 68 return true; 69 default: 70 return false; 71 } 72 } 73 74 static SmallVector<Size> 75 getSparseFieldShape(const SparseTensorEncodingAttr enc, 76 std::optional<ArrayRef<int64_t>> dimShape) { 77 assert(enc); 78 // With only encoding, we can not determine the static shape for leading 79 // batch levels, we therefore return a dynamic shape memref instead. 80 SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic); 81 if (dimShape.has_value()) { 82 // If the actual tensor shape is provided, we can then refine the leading 83 // batch dimension. 84 SmallVector<int64_t> lvlShape = 85 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl); 86 memrefShape.assign(lvlShape.begin(), 87 lvlShape.begin() + enc.getBatchLvlRank()); 88 } 89 // Another dynamic dimension to store the sparse level. 90 memrefShape.push_back(ShapedType::kDynamic); 91 return memrefShape; 92 } 93 94 //===----------------------------------------------------------------------===// 95 // SparseTensorDialect StorageLayout. 96 //===----------------------------------------------------------------------===// 97 98 static constexpr Level kInvalidLevel = -1u; 99 static constexpr Level kInvalidFieldIndex = -1u; 100 static constexpr FieldIndex kDataFieldStartingIdx = 0; 101 102 void StorageLayout::foreachField( 103 llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level, 104 LevelType)> 105 callback) const { 106 const auto lvlTypes = enc.getLvlTypes(); 107 const Level lvlRank = enc.getLvlRank(); 108 SmallVector<COOSegment> cooSegs = enc.getCOOSegments(); 109 FieldIndex fieldIdx = kDataFieldStartingIdx; 110 111 ArrayRef cooSegsRef = cooSegs; 112 // Per-level storage. 113 for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) { 114 const auto lt = lvlTypes[l]; 115 if (isWithPosLT(lt)) { 116 if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt))) 117 return; 118 } 119 if (isWithCrdLT(lt)) { 120 if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt))) 121 return; 122 } 123 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) { 124 if (!cooSegsRef.front().isSoA) { 125 // AoS COO, all singletons are fused into one memrefs. Skips the entire 126 // COO segement. 127 l = cooSegsRef.front().lvlRange.second; 128 } else { 129 // SoA COO, each singleton level has one memref. 130 l++; 131 } 132 // Expire handled COO segment. 133 cooSegsRef = cooSegsRef.drop_front(); 134 } else { 135 // Non COO levels. 136 l++; 137 } 138 } 139 // The values array. 140 if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel, 141 LevelFormat::Undef))) 142 return; 143 // Put metadata at the end. 144 if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel, 145 LevelFormat::Undef))) 146 return; 147 } 148 149 void sparse_tensor::foreachFieldAndTypeInSparseTensor( 150 SparseTensorType stt, 151 llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level, 152 LevelType)> 153 callback) { 154 assert(stt.hasEncoding()); 155 156 SmallVector<int64_t> memrefShape = 157 getSparseFieldShape(stt.getEncoding(), stt.getDimShape()); 158 159 const Type specType = StorageSpecifierType::get(stt.getEncoding()); 160 // memref<[batch] x ? x pos> positions 161 const Type posMemType = MemRefType::get(memrefShape, stt.getPosType()); 162 // memref<[batch] x ? x crd> coordinates 163 const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType()); 164 // memref<[batch] x ? x eltType> values 165 const Type valMemType = MemRefType::get(memrefShape, stt.getElementType()); 166 167 StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType, 168 callback](FieldIndex fieldIdx, 169 SparseTensorFieldKind fieldKind, 170 Level lvl, LevelType lt) -> bool { 171 switch (fieldKind) { 172 case SparseTensorFieldKind::StorageSpec: 173 return callback(specType, fieldIdx, fieldKind, lvl, lt); 174 case SparseTensorFieldKind::PosMemRef: 175 return callback(posMemType, fieldIdx, fieldKind, lvl, lt); 176 case SparseTensorFieldKind::CrdMemRef: 177 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt); 178 case SparseTensorFieldKind::ValMemRef: 179 return callback(valMemType, fieldIdx, fieldKind, lvl, lt); 180 }; 181 llvm_unreachable("unrecognized field kind"); 182 }); 183 } 184 185 unsigned StorageLayout::getNumFields() const { 186 unsigned numFields = 0; 187 foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level, 188 LevelType) -> bool { 189 numFields++; 190 return true; 191 }); 192 return numFields; 193 } 194 195 unsigned StorageLayout::getNumDataFields() const { 196 unsigned numFields = 0; // one value memref 197 foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level, 198 LevelType) -> bool { 199 if (fidx >= kDataFieldStartingIdx) 200 numFields++; 201 return true; 202 }); 203 numFields -= 1; // the last field is StorageSpecifier 204 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1); 205 return numFields; 206 } 207 208 std::pair<FieldIndex, unsigned> 209 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind, 210 std::optional<Level> lvl) const { 211 FieldIndex fieldIdx = kInvalidFieldIndex; 212 unsigned stride = 1; 213 if (kind == SparseTensorFieldKind::CrdMemRef) { 214 assert(lvl.has_value()); 215 const Level cooStart = enc.getAoSCOOStart(); 216 const Level lvlRank = enc.getLvlRank(); 217 if (lvl.value() >= cooStart && lvl.value() < lvlRank) { 218 lvl = cooStart; 219 stride = lvlRank - cooStart; 220 } 221 } 222 foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx, 223 SparseTensorFieldKind fKind, Level fLvl, 224 LevelType lt) -> bool { 225 if ((lvl && fLvl == lvl.value() && kind == fKind) || 226 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { 227 fieldIdx = fIdx; 228 // Returns false to break the iteration. 229 return false; 230 } 231 return true; 232 }); 233 assert(fieldIdx != kInvalidFieldIndex); 234 return std::pair<FieldIndex, unsigned>(fieldIdx, stride); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // SparseTensorDialect Attribute Methods. 239 //===----------------------------------------------------------------------===// 240 241 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) { 242 return isDynamic(v) ? std::nullopt 243 : std::make_optional(static_cast<uint64_t>(v)); 244 } 245 246 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const { 247 return getStatic(getOffset()); 248 } 249 250 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const { 251 return getStatic(getStride()); 252 } 253 254 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const { 255 return getStatic(getSize()); 256 } 257 258 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const { 259 return isDynamic(getOffset()) && isDynamic(getStride()) && 260 isDynamic(getSize()); 261 } 262 263 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) { 264 return isDynamic(v) ? "?" : std::to_string(v); 265 } 266 267 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const { 268 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr"); 269 os << '('; 270 os << getStaticString(getOffset()); 271 os << ", "; 272 os << getStaticString(getSize()); 273 os << ", "; 274 os << getStaticString(getStride()); 275 os << ')'; 276 } 277 278 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { 279 print(printer.getStream()); 280 } 281 282 static ParseResult parseOptionalStaticSlice(int64_t &result, 283 AsmParser &parser) { 284 auto parseResult = parser.parseOptionalInteger(result); 285 if (parseResult.has_value()) { 286 if (parseResult.value().succeeded() && result < 0) { 287 parser.emitError( 288 parser.getCurrentLocation(), 289 "expect positive value or ? for slice offset/size/stride"); 290 return failure(); 291 } 292 return parseResult.value(); 293 } 294 295 // Else, and '?' which represented dynamic slice 296 result = SparseTensorDimSliceAttr::kDynamic; 297 return parser.parseQuestion(); 298 } 299 300 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { 301 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic; 302 303 if (failed(parser.parseLParen()) || 304 failed(parseOptionalStaticSlice(offset, parser)) || 305 failed(parser.parseComma()) || 306 failed(parseOptionalStaticSlice(size, parser)) || 307 failed(parser.parseComma()) || 308 failed(parseOptionalStaticSlice(stride, parser)) || 309 failed(parser.parseRParen())) 310 return {}; 311 312 return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(), 313 offset, size, stride); 314 } 315 316 LogicalResult 317 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError, 318 int64_t offset, int64_t size, int64_t stride) { 319 if (!isDynamic(offset) && offset < 0) 320 return emitError() << "expect non-negative value or ? for slice offset"; 321 if (!isDynamic(size) && size <= 0) 322 return emitError() << "expect positive value or ? for slice size"; 323 if (!isDynamic(stride) && stride <= 0) 324 return emitError() << "expect positive value or ? for slice stride"; 325 return success(); 326 } 327 328 SparseTensorEncodingAttr 329 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const { 330 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 331 return SparseTensorEncodingAttr::get( 332 getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(), 333 getCrdWidth(), getExplicitVal(), getImplicitVal()); 334 } 335 336 SparseTensorEncodingAttr 337 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const { 338 return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap()); 339 } 340 341 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { 342 return withDimToLvl(AffineMap()); 343 } 344 345 SparseTensorEncodingAttr 346 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth, 347 unsigned crdWidth) const { 348 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 349 return SparseTensorEncodingAttr::get( 350 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth, 351 crdWidth, getExplicitVal(), getImplicitVal()); 352 } 353 354 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { 355 return withBitWidths(0, 0); 356 } 357 358 SparseTensorEncodingAttr 359 SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const { 360 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 361 return SparseTensorEncodingAttr::get( 362 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 363 getCrdWidth(), explicitVal, getImplicitVal()); 364 } 365 366 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const { 367 return withExplicitVal(Attribute()); 368 } 369 370 SparseTensorEncodingAttr 371 SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const { 372 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 373 return SparseTensorEncodingAttr::get( 374 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 375 getCrdWidth(), getExplicitVal(), implicitVal); 376 } 377 378 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const { 379 return withImplicitVal(Attribute()); 380 } 381 382 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( 383 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 384 return SparseTensorEncodingAttr::get( 385 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), 386 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices); 387 } 388 389 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { 390 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{}); 391 } 392 393 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { 394 ArrayRef<LevelType> lvlTypes = getLvlTypes(); 395 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 396 return std::distance(lastBatch, lvlTypes.rend()); 397 } 398 399 bool SparseTensorEncodingAttr::isAllDense() const { 400 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT); 401 } 402 403 bool SparseTensorEncodingAttr::isAllOrdered() const { 404 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT); 405 } 406 407 Type SparseTensorEncodingAttr::getCrdElemType() const { 408 if (!getImpl()) 409 return nullptr; 410 if (getCrdWidth()) 411 return IntegerType::get(getContext(), getCrdWidth()); 412 return IndexType::get(getContext()); 413 } 414 415 Type SparseTensorEncodingAttr::getPosElemType() const { 416 if (!getImpl()) 417 return nullptr; 418 if (getPosWidth()) 419 return IntegerType::get(getContext(), getPosWidth()); 420 return IndexType::get(getContext()); 421 } 422 423 MemRefType SparseTensorEncodingAttr::getCrdMemRefType( 424 std::optional<ArrayRef<int64_t>> dimShape) const { 425 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 426 return MemRefType::get(shape, getCrdElemType()); 427 } 428 429 MemRefType SparseTensorEncodingAttr::getPosMemRefType( 430 std::optional<ArrayRef<int64_t>> dimShape) const { 431 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); 432 return MemRefType::get(shape, getPosElemType()); 433 } 434 435 bool SparseTensorEncodingAttr::isIdentity() const { 436 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity(); 437 } 438 439 bool SparseTensorEncodingAttr::isPermutation() const { 440 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation(); 441 } 442 443 Dimension SparseTensorEncodingAttr::getDimRank() const { 444 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 445 const auto dimToLvl = getDimToLvl(); 446 return dimToLvl ? dimToLvl.getNumDims() : getLvlRank(); 447 } 448 449 Level SparseTensorEncodingAttr::getLvlRank() const { 450 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 451 return getLvlTypes().size(); 452 } 453 454 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { 455 if (!getImpl()) 456 return LevelFormat::Batch; 457 assert(l < getLvlRank() && "Level is out of bounds"); 458 return getLvlTypes()[l]; 459 } 460 461 bool SparseTensorEncodingAttr::isSlice() const { 462 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); 463 return !getDimSlices().empty(); 464 } 465 466 SparseTensorDimSliceAttr 467 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { 468 assert(isSlice() && "Is not a slice"); 469 const auto dimSlices = getDimSlices(); 470 assert(dim < dimSlices.size() && "Dimension is out of bounds"); 471 return dimSlices[dim]; 472 } 473 474 std::optional<uint64_t> 475 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { 476 return getDimSlice(dim).getStaticOffset(); 477 } 478 479 std::optional<uint64_t> 480 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { 481 return getDimSlice(dim).getStaticStride(); 482 } 483 484 std::optional<uint64_t> 485 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { 486 return getStaticDimSliceOffset(toDim(*this, lvl)); 487 } 488 489 std::optional<uint64_t> 490 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { 491 return getStaticDimSliceStride(toDim(*this, lvl)); 492 } 493 494 SmallVector<int64_t> 495 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape, 496 CrdTransDirectionKind dir) const { 497 if (isIdentity()) 498 return SmallVector<int64_t>(srcShape); 499 500 SmallVector<int64_t> ret; 501 unsigned rank = 502 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank(); 503 ret.reserve(rank); 504 505 if (isPermutation()) { 506 for (unsigned r = 0; r < rank; r++) { 507 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) 508 : toLvl(*this, r); 509 ret.push_back(srcShape[trans]); 510 } 511 return ret; 512 } 513 514 // Handle non-permutation maps. 515 AffineMap transMap = 516 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim(); 517 518 SmallVector<AffineExpr> dimRep; 519 dimRep.reserve(srcShape.size()); 520 for (int64_t sz : srcShape) { 521 if (!ShapedType::isDynamic(sz)) { 522 // Push back the max coordinate for the given dimension/level size. 523 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); 524 } else { 525 // A dynamic size, use a AffineDimExpr to symbolize the value. 526 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext())); 527 } 528 }; 529 530 for (AffineExpr exp : transMap.getResults()) { 531 // Do constant propagation on the affine map. 532 AffineExpr evalExp = 533 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0); 534 // use llvm namespace here to avoid ambiguity 535 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) { 536 ret.push_back(c.getValue() + 1); 537 } else { 538 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp); 539 mod && mod.getKind() == AffineExprKind::Mod) { 540 // We can still infer a static bound for expressions in form 541 // "d % constant" since d % constant \in [0, constant). 542 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) { 543 ret.push_back(bound.getValue()); 544 continue; 545 } 546 } 547 ret.push_back(ShapedType::kDynamic); 548 } 549 } 550 assert(ret.size() == rank); 551 return ret; 552 } 553 554 ValueRange 555 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, 556 ValueRange crds, 557 CrdTransDirectionKind dir) const { 558 if (!getImpl()) 559 return crds; 560 561 SmallVector<Type> retType( 562 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(), 563 builder.getIndexType()); 564 auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this); 565 return transOp.getOutCrds(); 566 } 567 568 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 569 // Open "<{" part. 570 if (failed(parser.parseLess())) 571 return {}; 572 if (failed(parser.parseLBrace())) 573 return {}; 574 575 // Process the data from the parsed dictionary value into struct-like data. 576 SmallVector<LevelType> lvlTypes; 577 SmallVector<SparseTensorDimSliceAttr> dimSlices; 578 AffineMap dimToLvl = {}; 579 AffineMap lvlToDim = {}; 580 unsigned posWidth = 0; 581 unsigned crdWidth = 0; 582 Attribute explicitVal; 583 Attribute implicitVal; 584 StringRef attrName; 585 SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth", 586 "explicitVal", "implicitVal"}; 587 while (succeeded(parser.parseOptionalKeyword(&attrName))) { 588 // Detect admissible keyword. 589 auto *it = find(keys, attrName); 590 if (it == keys.end()) { 591 parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; 592 return {}; 593 } 594 unsigned keyWordIndex = it - keys.begin(); 595 // Consume the `=` after keys 596 if (failed(parser.parseEqual())) 597 return {}; 598 // Dispatch on keyword. 599 switch (keyWordIndex) { 600 case 0: { // map 601 ir_detail::DimLvlMapParser cParser(parser); 602 auto res = cParser.parseDimLvlMap(); 603 if (failed(res)) 604 return {}; 605 const auto &dlm = *res; 606 607 const Level lvlRank = dlm.getLvlRank(); 608 for (Level lvl = 0; lvl < lvlRank; lvl++) 609 lvlTypes.push_back(dlm.getLvlType(lvl)); 610 611 const Dimension dimRank = dlm.getDimRank(); 612 for (Dimension dim = 0; dim < dimRank; dim++) 613 dimSlices.push_back(dlm.getDimSlice(dim)); 614 // NOTE: the old syntax requires an all-or-nothing approach to 615 // `dimSlices`; therefore, if any slice actually exists then we need 616 // to convert null-DSA into default/nop DSA. 617 const auto isDefined = [](SparseTensorDimSliceAttr slice) { 618 return static_cast<bool>(slice.getImpl()); 619 }; 620 if (llvm::any_of(dimSlices, isDefined)) { 621 const auto defaultSlice = 622 SparseTensorDimSliceAttr::get(parser.getContext()); 623 for (Dimension dim = 0; dim < dimRank; dim++) 624 if (!isDefined(dimSlices[dim])) 625 dimSlices[dim] = defaultSlice; 626 } else { 627 dimSlices.clear(); 628 } 629 630 dimToLvl = dlm.getDimToLvlMap(parser.getContext()); 631 lvlToDim = dlm.getLvlToDimMap(parser.getContext()); 632 break; 633 } 634 case 1: { // posWidth 635 Attribute attr; 636 if (failed(parser.parseAttribute(attr))) 637 return {}; 638 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 639 if (!intAttr) { 640 parser.emitError(parser.getNameLoc(), 641 "expected an integral position bitwidth"); 642 return {}; 643 } 644 posWidth = intAttr.getInt(); 645 break; 646 } 647 case 2: { // crdWidth 648 Attribute attr; 649 if (failed(parser.parseAttribute(attr))) 650 return {}; 651 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); 652 if (!intAttr) { 653 parser.emitError(parser.getNameLoc(), 654 "expected an integral index bitwidth"); 655 return {}; 656 } 657 crdWidth = intAttr.getInt(); 658 break; 659 } 660 case 3: { // explicitVal 661 Attribute attr; 662 if (failed(parser.parseAttribute(attr))) 663 return {}; 664 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { 665 explicitVal = result; 666 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { 667 explicitVal = result; 668 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { 669 explicitVal = result; 670 } else { 671 parser.emitError(parser.getNameLoc(), 672 "expected a numeric value for explicitVal"); 673 return {}; 674 } 675 break; 676 } 677 case 4: { // implicitVal 678 Attribute attr; 679 if (failed(parser.parseAttribute(attr))) 680 return {}; 681 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { 682 implicitVal = result; 683 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { 684 implicitVal = result; 685 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { 686 implicitVal = result; 687 } else { 688 parser.emitError(parser.getNameLoc(), 689 "expected a numeric value for implicitVal"); 690 return {}; 691 } 692 break; 693 } 694 } // switch 695 // Only last item can omit the comma. 696 if (parser.parseOptionalComma().failed()) 697 break; 698 } 699 700 // Close "}>" part. 701 if (failed(parser.parseRBrace())) 702 return {}; 703 if (failed(parser.parseGreater())) 704 return {}; 705 706 // Construct struct-like storage for attribute. 707 if (!lvlToDim || lvlToDim.isEmpty()) { 708 lvlToDim = inferLvlToDim(dimToLvl, parser.getContext()); 709 } 710 return parser.getChecked<SparseTensorEncodingAttr>( 711 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, 712 explicitVal, implicitVal, dimSlices); 713 } 714 715 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 716 auto map = static_cast<AffineMap>(getDimToLvl()); 717 // Empty affine map indicates identity map 718 if (!map) 719 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext()); 720 printer << "<{ map = "; 721 printSymbols(map, printer); 722 printer << '('; 723 printDimensions(map, printer, getDimSlices()); 724 printer << ") -> ("; 725 printLevels(map, printer, getLvlTypes()); 726 printer << ')'; 727 // Print remaining members only for non-default values. 728 if (getPosWidth()) 729 printer << ", posWidth = " << getPosWidth(); 730 if (getCrdWidth()) 731 printer << ", crdWidth = " << getCrdWidth(); 732 if (getExplicitVal()) { 733 printer << ", explicitVal = " << getExplicitVal(); 734 } 735 if (getImplicitVal()) 736 printer << ", implicitVal = " << getImplicitVal(); 737 printer << " }>"; 738 } 739 740 void SparseTensorEncodingAttr::printSymbols(AffineMap &map, 741 AsmPrinter &printer) const { 742 if (map.getNumSymbols() == 0) 743 return; 744 printer << '['; 745 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++) 746 printer << 's' << i << ", "; 747 if (map.getNumSymbols() >= 1) 748 printer << 's' << map.getNumSymbols() - 1; 749 printer << ']'; 750 } 751 752 void SparseTensorEncodingAttr::printDimensions( 753 AffineMap &map, AsmPrinter &printer, 754 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { 755 if (!dimSlices.empty()) { 756 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 757 printer << 'd' << i << " : " << dimSlices[i] << ", "; 758 if (map.getNumDims() >= 1) { 759 printer << 'd' << map.getNumDims() - 1 << " : " 760 << dimSlices[map.getNumDims() - 1]; 761 } 762 } else { 763 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) 764 printer << 'd' << i << ", "; 765 if (map.getNumDims() >= 1) 766 printer << 'd' << map.getNumDims() - 1; 767 } 768 } 769 770 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, 771 ArrayRef<LevelType> lvlTypes) const { 772 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { 773 map.getResult(i).print(printer.getStream()); 774 printer << " : " << toMLIRString(lvlTypes[i]) << ", "; 775 } 776 if (map.getNumResults() >= 1) { 777 auto lastIndex = map.getNumResults() - 1; 778 map.getResult(lastIndex).print(printer.getStream()); 779 printer << " : " << toMLIRString(lvlTypes[lastIndex]); 780 } 781 } 782 783 LogicalResult SparseTensorEncodingAttr::verify( 784 function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes, 785 AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth, 786 unsigned crdWidth, Attribute explicitVal, Attribute implicitVal, 787 ArrayRef<SparseTensorDimSliceAttr> dimSlices) { 788 if (!acceptBitWidth(posWidth)) 789 return emitError() << "unexpected position bitwidth: " << posWidth; 790 if (!acceptBitWidth(crdWidth)) 791 return emitError() << "unexpected coordinate bitwidth: " << crdWidth; 792 793 // Verify every COO segment. 794 auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); 795 while (it != lvlTypes.end()) { 796 if (it == lvlTypes.begin() || 797 !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) 798 return emitError() << "expected compressed or loose_compressed level " 799 "before singleton level"; 800 801 auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT); 802 if (!std::all_of(it, curCOOEnd, 803 [](LevelType i) { return isSingletonLT(i); })) 804 return emitError() << "expected all singleton lvlTypes " 805 "following a singleton level"; 806 // We can potentially support mixed SoA/AoS singleton levels. 807 if (!std::all_of(it, curCOOEnd, [it](LevelType i) { 808 return it->isa<LevelPropNonDefault::SoA>() == 809 i.isa<LevelPropNonDefault::SoA>(); 810 })) { 811 return emitError() << "expected all singleton lvlTypes stored in the " 812 "same memory layout (SoA vs AoS)."; 813 } 814 it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT); 815 } 816 817 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); 818 if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT)) 819 return emitError() << "Batch lvlType can only be leading levels."; 820 821 // SoA property can only be applied on singleton level. 822 auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) { 823 return lt.isa<LevelPropNonDefault::SoA>(); 824 }); 825 if (llvm::any_of(soaLvls, [](LevelType lt) { 826 return !lt.isa<LevelFormat::Singleton>(); 827 })) { 828 return emitError() << "SoA is only applicable to singleton lvlTypes."; 829 } 830 831 // TODO: audit formats that actually are supported by backend. 832 if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT); 833 it != std::end(lvlTypes)) { 834 if (it != lvlTypes.end() - 1) 835 return emitError() << "expected n_out_of_m to be the last level type"; 836 if (!std::all_of(lvlTypes.begin(), it, 837 [](LevelType i) { return isDenseLT(i); })) 838 return emitError() << "expected all dense lvlTypes " 839 "before a n_out_of_m level"; 840 if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) { 841 if (!isBlockSparsity(dimToLvl)) { 842 return emitError() 843 << "expected 1xm block structure for n_out_of_m level"; 844 } 845 auto sizes = getBlockSize(dimToLvl); 846 unsigned coefficient = 0; 847 for (const auto &elem : sizes) { 848 if (elem != 0) { 849 if (elem != coefficient && coefficient != 0) { 850 return emitError() << "expected only one blocked level " 851 "with the same coefficients"; 852 } 853 coefficient = elem; 854 } 855 } 856 if (coefficient != getM(*it)) { 857 return emitError() << "expected coeffiencts of Affine expressions " 858 "to be equal to m of n_out_of_m level"; 859 } 860 } 861 } 862 // Before we can check that the level-rank is consistent/coherent 863 // across all fields, we need to define it. The source-of-truth for 864 // the `getLvlRank` method is the length of the level-types array, 865 // since it must always be provided and have full rank; therefore we 866 // use that same source-of-truth here. 867 const Level lvlRank = lvlTypes.size(); 868 if (lvlRank == 0) 869 return emitError() << "expected a non-empty array for lvlTypes"; 870 // We save `dimRank` here because we'll also need it to verify `dimSlices`. 871 const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank; 872 if (dimToLvl) { 873 if (dimToLvl.getNumResults() != lvlRank) 874 return emitError() 875 << "level-rank mismatch between dimToLvl and lvlTypes: " 876 << dimToLvl.getNumResults() << " != " << lvlRank; 877 auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext()); 878 // Symbols can't be inferred but are acceptable. 879 if (!inferRes && dimToLvl.getNumSymbols() == 0) 880 return emitError() << "failed to infer lvlToDim from dimToLvl"; 881 if (lvlToDim && (inferRes != lvlToDim)) 882 return emitError() << "expected lvlToDim to be an inverse of dimToLvl"; 883 if (dimRank > lvlRank) 884 return emitError() << "unexpected dimToLvl mapping from " << dimRank 885 << " to " << lvlRank; 886 } 887 if (!dimSlices.empty()) { 888 if (dimSlices.size() != dimRank) 889 return emitError() 890 << "dimension-rank mismatch between dimSlices and dimToLvl: " 891 << dimSlices.size() << " != " << dimRank; 892 // Compiler support for `dimSlices` currently requires that the two 893 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.) 894 if (dimRank != lvlRank) 895 return emitError() 896 << "dimSlices expected dimension-rank to match level-rank: " 897 << dimRank << " != " << lvlRank; 898 } 899 return success(); 900 } 901 902 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 903 ArrayRef<Size> dimShape, Type elementType, 904 function_ref<InFlightDiagnostic()> emitError) const { 905 // Check structural integrity. In particular, this ensures that the 906 // level-rank is coherent across all the fields. 907 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(), 908 getPosWidth(), getCrdWidth(), getExplicitVal(), 909 getImplicitVal(), getDimSlices()))) 910 return failure(); 911 // Check integrity with tensor type specifics. In particular, we 912 // need only check that the dimension-rank of the tensor agrees with 913 // the dimension-rank of the encoding. 914 const Dimension dimRank = dimShape.size(); 915 if (dimRank == 0) 916 return emitError() << "expected non-scalar sparse tensor"; 917 if (getDimRank() != dimRank) 918 return emitError() 919 << "dimension-rank mismatch between encoding and tensor shape: " 920 << getDimRank() << " != " << dimRank; 921 if (auto expVal = getExplicitVal()) { 922 Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType(); 923 if (attrType != elementType) { 924 return emitError() << "explicit value type mismatch between encoding and " 925 << "tensor element type: " << attrType 926 << " != " << elementType; 927 } 928 } 929 if (auto impVal = getImplicitVal()) { 930 Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType(); 931 if (attrType != elementType) { 932 return emitError() << "implicit value type mismatch between encoding and " 933 << "tensor element type: " << attrType 934 << " != " << elementType; 935 } 936 // Currently, we only support zero as the implicit value. 937 auto impFVal = llvm::dyn_cast<FloatAttr>(impVal); 938 auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal); 939 auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal); 940 if ((impFVal && impFVal.getValue().isNonZero()) || 941 (impIntVal && !impIntVal.getValue().isZero()) || 942 (impComplexVal && (impComplexVal.getImag().isNonZero() || 943 impComplexVal.getReal().isNonZero()))) { 944 return emitError() << "implicit value must be zero"; 945 } 946 } 947 return success(); 948 } 949 950 Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const { 951 SmallVector<COOSegment> coo = getCOOSegments(); 952 assert(coo.size() == 1 || coo.empty()); 953 if (!coo.empty() && coo.front().isAoS()) { 954 return coo.front().lvlRange.first; 955 } 956 return getLvlRank(); 957 } 958 959 SmallVector<COOSegment> 960 mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const { 961 SmallVector<COOSegment> ret; 962 if (getLvlRank() <= 1) 963 return ret; 964 965 ArrayRef<LevelType> lts = getLvlTypes(); 966 Level l = 0; 967 while (l < getLvlRank()) { 968 auto lt = lts[l]; 969 if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) { 970 auto cur = lts.begin() + l; 971 auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) { 972 return !lt.isa<LevelFormat::Singleton>(); 973 }); 974 unsigned cooLen = std::distance(cur, end); 975 if (cooLen > 1) { 976 // To support mixed SoA/AoS COO, we should break the segment when the 977 // storage scheme changes, for now we faithfully assume that all 978 // consecutive singleton levels have the same storage format as verified 979 // STEA. 980 ret.push_back(COOSegment{std::make_pair(l, l + cooLen), 981 lts[l + 1].isa<LevelPropNonDefault::SoA>()}); 982 } 983 l += cooLen; 984 } else { 985 l++; 986 } 987 } 988 return ret; 989 } 990 991 //===----------------------------------------------------------------------===// 992 // SparseTensorType Methods. 993 //===----------------------------------------------------------------------===// 994 995 bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, 996 bool isUnique) const { 997 if (!hasEncoding()) 998 return false; 999 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl)) 1000 return false; 1001 for (Level l = startLvl + 1; l < lvlRank; ++l) 1002 if (!isSingletonLvl(l)) 1003 return false; 1004 // If isUnique is true, then make sure that the last level is unique, 1005 // that is, when lvlRank == 1, the only compressed level is unique, 1006 // and when lvlRank > 1, the last singleton is unique. 1007 return !isUnique || isUniqueLvl(lvlRank - 1); 1008 } 1009 1010 RankedTensorType 1011 mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { 1012 SmallVector<LevelType> lvlTypes; 1013 lvlTypes.reserve(lvlRank); 1014 // A non-unique compressed level at beginning (unless this is 1015 // also the last level, then it is unique). 1016 lvlTypes.push_back( 1017 *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); 1018 if (lvlRank > 1) { 1019 // Followed by n-2 non-unique singleton levels. 1020 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, 1021 *buildLevelType(LevelFormat::Singleton, ordered, false)); 1022 // Ends by a unique singleton level. 1023 lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true)); 1024 } 1025 auto enc = SparseTensorEncodingAttr::get( 1026 getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(), 1027 getCrdWidth(), getExplicitVal(), getImplicitVal()); 1028 return RankedTensorType::get(getDimShape(), getElementType(), enc); 1029 } 1030 1031 //===----------------------------------------------------------------------===// 1032 // Convenience Methods. 1033 //===----------------------------------------------------------------------===// 1034 1035 SparseTensorEncodingAttr 1036 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 1037 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type)) 1038 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding()); 1039 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type)) 1040 return mdtp.getEncoding(); 1041 return nullptr; 1042 } 1043 1044 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl, 1045 MLIRContext *context) { 1046 auto map = static_cast<AffineMap>(dimToLvl); 1047 AffineMap lvlToDim; 1048 // Return an empty lvlToDim when inference is not successful. 1049 if (!map || map.getNumSymbols() != 0) { 1050 lvlToDim = AffineMap(); 1051 } else if (map.isPermutation()) { 1052 lvlToDim = inversePermutation(map); 1053 } else if (isBlockSparsity(map)) { 1054 lvlToDim = inverseBlockSparsity(map, context); 1055 } 1056 return lvlToDim; 1057 } 1058 1059 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl, 1060 MLIRContext *context) { 1061 SmallVector<AffineExpr> lvlExprs; 1062 auto numLvls = dimToLvl.getNumResults(); 1063 lvlExprs.reserve(numLvls); 1064 // lvlExprComponents stores information of the floordiv and mod operations 1065 // applied to the same dimension, so as to build the lvlToDim map. 1066 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents; 1067 for (unsigned i = 0, n = numLvls; i < n; i++) { 1068 auto result = dimToLvl.getResult(i); 1069 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1070 if (result.getKind() == AffineExprKind::FloorDiv) { 1071 // Position of the dimension in dimToLvl. 1072 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 1073 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() && 1074 "expected only one floordiv for each dimension"); 1075 SmallVector<AffineExpr, 3> components; 1076 // Level variable for floordiv. 1077 components.push_back(getAffineDimExpr(i, context)); 1078 // Multiplier. 1079 components.push_back(binOp.getRHS()); 1080 // Map key is the position of the dimension. 1081 lvlExprComponents[pos] = components; 1082 } else if (result.getKind() == AffineExprKind::Mod) { 1083 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); 1084 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() && 1085 "expected floordiv before mod"); 1086 // Add level variable for mod to the same vector 1087 // of the corresponding floordiv. 1088 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context)); 1089 } else { 1090 assert(false && "expected floordiv or mod"); 1091 } 1092 } else { 1093 lvlExprs.push_back(getAffineDimExpr(i, context)); 1094 } 1095 } 1096 // Build lvlExprs from lvlExprComponents. 1097 // For example, for il = i floordiv 2 and ii = i mod 2, the components 1098 // would be [il, 2, ii]. It could be used to build the AffineExpr 1099 // i = il * 2 + ii in lvlToDim. 1100 for (auto &components : lvlExprComponents) { 1101 assert(components.second.size() == 3 && 1102 "expected 3 components to build lvlExprs"); 1103 auto mulOp = getAffineBinaryOpExpr( 1104 AffineExprKind::Mul, components.second[0], components.second[1]); 1105 auto addOp = 1106 getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]); 1107 lvlExprs.push_back(addOp); 1108 } 1109 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context); 1110 } 1111 1112 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { 1113 assert(isBlockSparsity(dimToLvl) && 1114 "expected dimToLvl to be block sparsity for calling getBlockSize"); 1115 SmallVector<unsigned> blockSize; 1116 for (auto result : dimToLvl.getResults()) { 1117 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1118 if (result.getKind() == AffineExprKind::Mod) { 1119 blockSize.push_back( 1120 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue()); 1121 } 1122 } else { 1123 blockSize.push_back(0); 1124 } 1125 } 1126 return blockSize; 1127 } 1128 1129 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { 1130 if (!dimToLvl) 1131 return false; 1132 std::map<unsigned, int64_t> coeffientMap; 1133 bool hasBlock = false; 1134 for (auto result : dimToLvl.getResults()) { 1135 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { 1136 // Check for "dim op const". 1137 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS()); 1138 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS()); 1139 if (!dimOp || !conOp || conOp.getValue() <= 0) 1140 return false; 1141 // Inspect "dim / const" or "dim % const". 1142 auto pos = dimOp.getPosition(); 1143 if (binOp.getKind() == AffineExprKind::FloorDiv) { 1144 // Expect only one floordiv for each dimension. 1145 if (coeffientMap.find(pos) != coeffientMap.end()) 1146 return false; 1147 // Record coefficient of the floordiv. 1148 coeffientMap[pos] = conOp.getValue(); 1149 } else if (binOp.getKind() == AffineExprKind::Mod) { 1150 // Expect floordiv before mod. 1151 if (coeffientMap.find(pos) == coeffientMap.end()) 1152 return false; 1153 // Expect mod to have the same coefficient as floordiv. 1154 if (conOp.getValue() != coeffientMap[pos]) 1155 return false; 1156 hasBlock = true; 1157 } else { 1158 return false; 1159 } 1160 } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) { 1161 auto pos = dimOp.getPosition(); 1162 // Expect dim to be unset. 1163 if (coeffientMap.find(pos) != coeffientMap.end()) 1164 return false; 1165 coeffientMap[pos] = 0; 1166 } else { 1167 return false; 1168 } 1169 } 1170 return hasBlock; 1171 } 1172 1173 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) { 1174 auto hasNonIdentityMap = [](Value v) { 1175 auto stt = tryGetSparseTensorType(v); 1176 return stt && !stt->isIdentity(); 1177 }; 1178 1179 return llvm::any_of(op->getOperands(), hasNonIdentityMap) || 1180 llvm::any_of(op->getResults(), hasNonIdentityMap); 1181 } 1182 1183 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { 1184 if (enc) { 1185 assert(enc.isPermutation() && "Non permutation map not supported"); 1186 if (const auto dimToLvl = enc.getDimToLvl()) 1187 return dimToLvl.getDimPosition(l); 1188 } 1189 return l; 1190 } 1191 1192 Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { 1193 if (enc) { 1194 assert(enc.isPermutation() && "Non permutation map not supported"); 1195 if (const auto lvlToDim = enc.getLvlToDim()) 1196 return lvlToDim.getDimPosition(d); 1197 } 1198 return d; 1199 } 1200 1201 /// We normalized sparse tensor encoding attribute by always using 1202 /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well 1203 /// as other variants) lead to the same storage specifier type, and stripping 1204 /// irrelevant fields that do not alter the sparse tensor memory layout. 1205 static SparseTensorEncodingAttr 1206 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { 1207 SmallVector<LevelType> lts; 1208 for (auto lt : enc.getLvlTypes()) 1209 lts.push_back(lt.stripStorageIrrelevantProperties()); 1210 1211 return SparseTensorEncodingAttr::get( 1212 enc.getContext(), lts, 1213 AffineMap(), // dimToLvl (irrelevant to storage specifier) 1214 AffineMap(), // lvlToDim (irrelevant to storage specifier) 1215 // Always use `index` for memSize and lvlSize instead of reusing 1216 // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA 1217 // value for different bitwidth, it also avoids casting between index and 1218 // integer (returned by DimOp) 1219 0, 0, 1220 Attribute(), // explicitVal (irrelevant to storage specifier) 1221 Attribute(), // implicitVal (irrelevant to storage specifier) 1222 enc.getDimSlices()); 1223 } 1224 1225 StorageSpecifierType 1226 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) { 1227 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding)); 1228 } 1229 1230 StorageSpecifierType 1231 StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError, 1232 MLIRContext *ctx, 1233 SparseTensorEncodingAttr encoding) { 1234 return Base::getChecked(emitError, ctx, 1235 getNormalizedEncodingForSpecifier(encoding)); 1236 } 1237 1238 //===----------------------------------------------------------------------===// 1239 // SparseTensorDialect Operations. 1240 //===----------------------------------------------------------------------===// 1241 1242 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) { 1243 return success(lvl < getSparseTensorType(tensor).getLvlRank()); 1244 } 1245 1246 static LogicalResult isMatchingWidth(Value mem, unsigned width) { 1247 const Type etp = getMemRefType(mem).getElementType(); 1248 return success(width == 0 ? etp.isIndex() : etp.isInteger(width)); 1249 } 1250 1251 static LogicalResult verifySparsifierGetterSetter( 1252 StorageSpecifierKind mdKind, std::optional<Level> lvl, 1253 TypedValue<StorageSpecifierType> md, Operation *op) { 1254 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) { 1255 return op->emitError( 1256 "redundant level argument for querying value memory size"); 1257 } 1258 1259 const auto enc = md.getType().getEncoding(); 1260 const Level lvlRank = enc.getLvlRank(); 1261 1262 if (mdKind == StorageSpecifierKind::DimOffset || 1263 mdKind == StorageSpecifierKind::DimStride) 1264 if (!enc.isSlice()) 1265 return op->emitError("requested slice data on non-slice tensor"); 1266 1267 if (mdKind != StorageSpecifierKind::ValMemSize) { 1268 if (!lvl) 1269 return op->emitError("missing level argument"); 1270 1271 const Level l = lvl.value(); 1272 if (l >= lvlRank) 1273 return op->emitError("requested level is out of bounds"); 1274 1275 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l)) 1276 return op->emitError( 1277 "requested position memory size on a singleton level"); 1278 } 1279 return success(); 1280 } 1281 1282 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) { 1283 switch (kind) { 1284 case SparseTensorFieldKind::CrdMemRef: 1285 return stt.getCrdType(); 1286 case SparseTensorFieldKind::PosMemRef: 1287 return stt.getPosType(); 1288 case SparseTensorFieldKind::ValMemRef: 1289 return stt.getElementType(); 1290 case SparseTensorFieldKind::StorageSpec: 1291 return nullptr; 1292 } 1293 llvm_unreachable("Unrecognizable FieldKind"); 1294 } 1295 1296 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, 1297 SparseTensorType stt, 1298 RankedTensorType valTp, 1299 TypeRange lvlTps) { 1300 if (requiresStaticShape && !stt.hasStaticDimShape()) 1301 return op->emitError("the sparse-tensor must have static shape"); 1302 if (!stt.hasEncoding()) 1303 return op->emitError("the sparse-tensor must have an encoding attribute"); 1304 1305 // Verifies the trailing COO. 1306 Level cooStartLvl = stt.getAoSCOOStart(); 1307 if (cooStartLvl < stt.getLvlRank()) { 1308 // We only supports trailing COO for now, must be the last input. 1309 auto cooTp = llvm::cast<ShapedType>(lvlTps.back()); 1310 // The coordinates should be in shape of <? x rank> 1311 unsigned expCOORank = stt.getLvlRank() - cooStartLvl; 1312 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) { 1313 op->emitError("input/output trailing COO level-ranks don't match"); 1314 } 1315 } 1316 1317 // Verifies that all types match. 1318 StorageLayout layout(stt.getEncoding()); 1319 if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref 1320 return op->emitError("inconsistent number of fields between input/output"); 1321 1322 unsigned idx = 0; 1323 bool misMatch = false; 1324 layout.foreachField([&idx, &misMatch, stt, valTp, 1325 lvlTps](FieldIndex fid, SparseTensorFieldKind fKind, 1326 Level lvl, LevelType lt) -> bool { 1327 if (fKind == SparseTensorFieldKind::StorageSpec) 1328 return true; 1329 1330 Type inputTp = nullptr; 1331 if (fKind == SparseTensorFieldKind::ValMemRef) { 1332 inputTp = valTp; 1333 } else { 1334 assert(fid == idx && stt.getLvlType(lvl) == lt); 1335 inputTp = lvlTps[idx++]; 1336 } 1337 // The input element type and expected element type should match. 1338 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType(); 1339 Type expElemTp = getFieldElemType(stt, fKind); 1340 if (inpElemTp != expElemTp) { 1341 misMatch = true; 1342 return false; // to terminate the iteration 1343 } 1344 return true; 1345 }); 1346 1347 if (misMatch) 1348 return op->emitError("input/output element-types don't match"); 1349 return success(); 1350 } 1351 1352 LogicalResult AssembleOp::verify() { 1353 const auto valuesTp = getRankedTensorType(getValues()); 1354 const auto lvlsTp = getLevels().getTypes(); 1355 const auto resTp = getSparseTensorType(getResult()); 1356 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp); 1357 } 1358 1359 LogicalResult DisassembleOp::verify() { 1360 if (getOutValues().getType() != getRetValues().getType()) 1361 return emitError("output values and return value type mismatch"); 1362 1363 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels())) 1364 if (ot.getType() != rt.getType()) 1365 return emitError("output levels and return levels type mismatch"); 1366 1367 const auto valuesTp = getRankedTensorType(getRetValues()); 1368 const auto lvlsTp = getRetLevels().getTypes(); 1369 const auto srcTp = getSparseTensorType(getTensor()); 1370 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp); 1371 } 1372 1373 LogicalResult ConvertOp::verify() { 1374 if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) { 1375 if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) { 1376 if (tp1.getRank() != tp2.getRank()) 1377 return emitError("unexpected conversion mismatch in rank"); 1378 auto dstEnc = 1379 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding()); 1380 if (dstEnc && dstEnc.isSlice()) 1381 return emitError("cannot convert to a sparse tensor slice"); 1382 1383 auto shape1 = tp1.getShape(); 1384 auto shape2 = tp2.getShape(); 1385 // Accept size matches between the source and the destination type 1386 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or 1387 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 1388 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) 1389 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) 1390 return emitError("unexpected conversion mismatch in dimension ") << d; 1391 return success(); 1392 } 1393 } 1394 return emitError("unexpected type in convert"); 1395 } 1396 1397 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { 1398 if (getType() == getSource().getType()) 1399 return getSource(); 1400 return {}; 1401 } 1402 1403 bool ConvertOp::needsExtraSort() { 1404 SparseTensorType srcStt = getSparseTensorType(getSource()); 1405 SparseTensorType dstStt = getSparseTensorType(getDest()); 1406 1407 // We do not need an extra sort when returning unordered sparse tensors or 1408 // dense tensor since dense tensor support random access. 1409 if (dstStt.isAllDense() || !dstStt.isAllOrdered()) 1410 return false; 1411 1412 if (srcStt.isAllOrdered() && dstStt.isAllOrdered() && 1413 srcStt.hasSameDimToLvl(dstStt)) { 1414 return false; 1415 } 1416 1417 // Source and dest tensors are ordered in different ways. We only do direct 1418 // dense to sparse conversion when the dense input is defined by a sparse 1419 // constant. Note that we can theoretically always directly convert from dense 1420 // inputs by rotating dense loops but it leads to bad cache locality and hurt 1421 // performance. 1422 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>()) 1423 if (isa<SparseElementsAttr>(constOp.getValue())) 1424 return false; 1425 1426 return true; 1427 } 1428 1429 LogicalResult CrdTranslateOp::verify() { 1430 uint64_t inRank = getEncoder().getLvlRank(); 1431 uint64_t outRank = getEncoder().getDimRank(); 1432 1433 if (getDirection() == CrdTransDirectionKind::dim2lvl) 1434 std::swap(inRank, outRank); 1435 1436 if (inRank != getInCrds().size() || outRank != getOutCrds().size()) 1437 return emitError("Coordinate rank mismatch with encoding"); 1438 1439 return success(); 1440 } 1441 1442 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor, 1443 SmallVectorImpl<OpFoldResult> &results) { 1444 if (getEncoder().isIdentity()) { 1445 results.assign(getInCrds().begin(), getInCrds().end()); 1446 return success(); 1447 } 1448 if (getEncoder().isPermutation()) { 1449 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl 1450 ? getEncoder().getDimToLvl() 1451 : getEncoder().getLvlToDim(); 1452 for (AffineExpr exp : perm.getResults()) 1453 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]); 1454 return success(); 1455 } 1456 1457 // Fuse dim2lvl/lvl2dim pairs. 1458 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>(); 1459 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) { 1460 return v.getDefiningOp() == def; 1461 }); 1462 if (!sameDef) 1463 return failure(); 1464 1465 bool oppositeDir = def.getDirection() != getDirection(); 1466 bool sameOracle = 1467 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl(); 1468 bool sameCount = def.getNumResults() == getInCrds().size(); 1469 if (!oppositeDir || !sameOracle || !sameCount) 1470 return failure(); 1471 1472 // The definition produces the coordinates in the same order as the input 1473 // coordinates. 1474 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()), 1475 [](auto valuePair) { 1476 auto [lhs, rhs] = valuePair; 1477 return lhs == rhs; 1478 }); 1479 1480 if (!sameOrder) 1481 return failure(); 1482 // l1 = dim2lvl (lvl2dim l0) 1483 // ==> l0 1484 results.append(def.getInCrds().begin(), def.getInCrds().end()); 1485 return success(); 1486 } 1487 1488 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source, 1489 int64_t index) { 1490 Value val = builder.create<arith::ConstantIndexOp>(state.location, index); 1491 return build(builder, state, source, val); 1492 } 1493 1494 LogicalResult LvlOp::verify() { 1495 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) { 1496 auto stt = getSparseTensorType(getSource()); 1497 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank()) 1498 emitError("Level index exceeds the rank of the input sparse tensor"); 1499 } 1500 return success(); 1501 } 1502 1503 std::optional<uint64_t> LvlOp::getConstantLvlIndex() { 1504 return getConstantIntValue(getIndex()); 1505 } 1506 1507 Speculation::Speculatability LvlOp::getSpeculatability() { 1508 auto constantIndex = getConstantLvlIndex(); 1509 if (!constantIndex) 1510 return Speculation::NotSpeculatable; 1511 1512 assert(constantIndex < 1513 cast<RankedTensorType>(getSource().getType()).getRank()); 1514 return Speculation::Speculatable; 1515 } 1516 1517 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) { 1518 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex()); 1519 if (!lvlIndex) 1520 return {}; 1521 1522 Level lvl = lvlIndex.getAPSInt().getZExtValue(); 1523 auto stt = getSparseTensorType(getSource()); 1524 if (lvl >= stt.getLvlRank()) { 1525 // Follows the same convention used by tensor.dim operation. Out of bound 1526 // indices produce undefined behavior but are still valid IR. Don't choke on 1527 // them. 1528 return {}; 1529 } 1530 1531 // Helper lambda to build an IndexAttr. 1532 auto getIndexAttr = [this](int64_t lvlSz) { 1533 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz)); 1534 }; 1535 1536 SmallVector<Size> lvlShape = stt.getLvlShape(); 1537 if (!ShapedType::isDynamic(lvlShape[lvl])) 1538 return getIndexAttr(lvlShape[lvl]); 1539 1540 return {}; 1541 } 1542 1543 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState, 1544 SparseTensorEncodingAttr dstEnc, Value source) { 1545 auto srcStt = getSparseTensorType(source); 1546 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape(); 1547 SmallVector<int64_t> dstDimShape = 1548 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim); 1549 auto dstTp = 1550 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc); 1551 return build(odsBuilder, odsState, dstTp, source); 1552 } 1553 1554 LogicalResult ReinterpretMapOp::verify() { 1555 auto srcStt = getSparseTensorType(getSource()); 1556 auto dstStt = getSparseTensorType(getDest()); 1557 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes(); 1558 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes(); 1559 1560 if (srcLvlTps.size() != dstLvlTps.size()) 1561 return emitError("Level rank mismatch between source/dest tensors"); 1562 1563 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps)) 1564 if (srcLvlTp != dstLvlTp) 1565 return emitError("Level type mismatch between source/dest tensors"); 1566 1567 if (srcStt.getPosWidth() != dstStt.getPosWidth() || 1568 srcStt.getCrdWidth() != dstStt.getCrdWidth()) { 1569 return emitError("Crd/Pos width mismatch between source/dest tensors"); 1570 } 1571 1572 if (srcStt.getElementType() != dstStt.getElementType()) 1573 return emitError("Element type mismatch between source/dest tensors"); 1574 1575 SmallVector<Size> srcLvlShape = srcStt.getLvlShape(); 1576 SmallVector<Size> dstLvlShape = dstStt.getLvlShape(); 1577 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) { 1578 if (srcLvlSz != dstLvlSz) { 1579 // Should we allow one side to be dynamic size, e.g., <?x?> should be 1580 // compatible to <3x4>? For now, we require all the level sizes to be 1581 // *exactly* matched for simplicity. 1582 return emitError("Level size mismatch between source/dest tensors"); 1583 } 1584 } 1585 1586 return success(); 1587 } 1588 1589 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) { 1590 if (getSource().getType() == getDest().getType()) 1591 return getSource(); 1592 1593 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) { 1594 // A -> B, B -> A ==> A 1595 if (def.getSource().getType() == getDest().getType()) 1596 return def.getSource(); 1597 } 1598 return {}; 1599 } 1600 1601 template <typename ToBufferOp> 1602 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, 1603 OpaqueProperties prop, 1604 RegionRange region, 1605 SmallVectorImpl<mlir::Type> &ret) { 1606 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region); 1607 SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); 1608 Type elemTp = nullptr; 1609 bool withStride = false; 1610 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) { 1611 elemTp = stt.getPosType(); 1612 } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> || 1613 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) { 1614 elemTp = stt.getCrdType(); 1615 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>) 1616 withStride = stt.getAoSCOOStart() <= adaptor.getLevel(); 1617 } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) { 1618 elemTp = stt.getElementType(); 1619 } 1620 1621 assert(elemTp && "unhandled operation."); 1622 SmallVector<int64_t> bufShape = stt.getBatchLvlShape(); 1623 bufShape.push_back(ShapedType::kDynamic); 1624 1625 auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get( 1626 stt.getContext(), ShapedType::kDynamic, 1627 {ShapedType::kDynamic}) 1628 : StridedLayoutAttr(); 1629 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout)); 1630 return success(); 1631 } 1632 1633 LogicalResult ToPositionsOp::verify() { 1634 auto stt = getSparseTensorType(getTensor()); 1635 if (failed(lvlIsInBounds(getLevel(), getTensor()))) 1636 return emitError("requested level is out of bounds"); 1637 if (failed(isMatchingWidth(getResult(), stt.getPosWidth()))) 1638 return emitError("unexpected type for positions"); 1639 return success(); 1640 } 1641 1642 LogicalResult 1643 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 1644 ValueRange ops, DictionaryAttr attr, 1645 OpaqueProperties prop, RegionRange region, 1646 SmallVectorImpl<mlir::Type> &ret) { 1647 return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret); 1648 } 1649 1650 LogicalResult ToCoordinatesOp::verify() { 1651 auto stt = getSparseTensorType(getTensor()); 1652 if (failed(lvlIsInBounds(getLevel(), getTensor()))) 1653 return emitError("requested level is out of bounds"); 1654 if (failed(isMatchingWidth(getResult(), stt.getCrdWidth()))) 1655 return emitError("unexpected type for coordinates"); 1656 return success(); 1657 } 1658 1659 LogicalResult 1660 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, 1661 ValueRange ops, DictionaryAttr attr, 1662 OpaqueProperties prop, RegionRange region, 1663 SmallVectorImpl<mlir::Type> &ret) { 1664 return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret); 1665 } 1666 1667 LogicalResult ToCoordinatesBufferOp::verify() { 1668 auto stt = getSparseTensorType(getTensor()); 1669 if (stt.getAoSCOOStart() >= stt.getLvlRank()) 1670 return emitError("expected sparse tensor with a COO region"); 1671 return success(); 1672 } 1673 1674 LogicalResult ToCoordinatesBufferOp::inferReturnTypes( 1675 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, 1676 DictionaryAttr attr, OpaqueProperties prop, RegionRange region, 1677 SmallVectorImpl<mlir::Type> &ret) { 1678 return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region, 1679 ret); 1680 } 1681 1682 LogicalResult ToValuesOp::verify() { 1683 auto stt = getSparseTensorType(getTensor()); 1684 auto mtp = getMemRefType(getResult()); 1685 if (stt.getElementType() != mtp.getElementType()) 1686 return emitError("unexpected mismatch in element types"); 1687 return success(); 1688 } 1689 1690 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx, 1691 std::optional<Location> loc, 1692 ValueRange ops, DictionaryAttr attr, 1693 OpaqueProperties prop, 1694 RegionRange region, 1695 SmallVectorImpl<mlir::Type> &ret) { 1696 return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret); 1697 } 1698 1699 LogicalResult ToSliceOffsetOp::verify() { 1700 auto rank = getRankedTensorType(getSlice()).getRank(); 1701 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) 1702 return emitError("requested dimension out of bound"); 1703 return success(); 1704 } 1705 1706 LogicalResult ToSliceStrideOp::verify() { 1707 auto rank = getRankedTensorType(getSlice()).getRank(); 1708 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) 1709 return emitError("requested dimension out of bound"); 1710 return success(); 1711 } 1712 1713 LogicalResult GetStorageSpecifierOp::verify() { 1714 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), 1715 getSpecifier(), getOperation()); 1716 } 1717 1718 template <typename SpecifierOp> 1719 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) { 1720 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>(); 1721 } 1722 1723 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { 1724 const StorageSpecifierKind kind = getSpecifierKind(); 1725 const auto lvl = getLevel(); 1726 for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op)) 1727 if (kind == op.getSpecifierKind() && lvl == op.getLevel()) 1728 return op.getValue(); 1729 return {}; 1730 } 1731 1732 LogicalResult SetStorageSpecifierOp::verify() { 1733 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), 1734 getSpecifier(), getOperation()); 1735 } 1736 1737 template <class T> 1738 static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, 1739 const char *regionName, 1740 TypeRange inputTypes, Type outputType) { 1741 unsigned numArgs = region.getNumArguments(); 1742 unsigned expectedNum = inputTypes.size(); 1743 if (numArgs != expectedNum) 1744 return op->emitError() << regionName << " region must have exactly " 1745 << expectedNum << " arguments"; 1746 1747 for (unsigned i = 0; i < numArgs; i++) { 1748 Type typ = region.getArgument(i).getType(); 1749 if (typ != inputTypes[i]) 1750 return op->emitError() << regionName << " region argument " << (i + 1) 1751 << " type mismatch"; 1752 } 1753 Operation *term = region.front().getTerminator(); 1754 YieldOp yield = dyn_cast<YieldOp>(term); 1755 if (!yield) 1756 return op->emitError() << regionName 1757 << " region must end with sparse_tensor.yield"; 1758 if (!yield.hasSingleResult() || 1759 yield.getSingleResult().getType() != outputType) 1760 return op->emitError() << regionName << " region yield type mismatch"; 1761 1762 return success(); 1763 } 1764 1765 LogicalResult BinaryOp::verify() { 1766 NamedAttrList attrs = (*this)->getAttrs(); 1767 Type leftType = getX().getType(); 1768 Type rightType = getY().getType(); 1769 Type outputType = getOutput().getType(); 1770 Region &overlap = getOverlapRegion(); 1771 Region &left = getLeftRegion(); 1772 Region &right = getRightRegion(); 1773 1774 // Check correct number of block arguments and return type for each 1775 // non-empty region. 1776 if (!overlap.empty()) { 1777 if (failed(verifyNumBlockArgs(this, overlap, "overlap", 1778 TypeRange{leftType, rightType}, outputType))) 1779 return failure(); 1780 } 1781 if (!left.empty()) { 1782 if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, 1783 outputType))) 1784 return failure(); 1785 } else if (getLeftIdentity()) { 1786 if (leftType != outputType) 1787 return emitError("left=identity requires first argument to have the same " 1788 "type as the output"); 1789 } 1790 if (!right.empty()) { 1791 if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType}, 1792 outputType))) 1793 return failure(); 1794 } else if (getRightIdentity()) { 1795 if (rightType != outputType) 1796 return emitError("right=identity requires second argument to have the " 1797 "same type as the output"); 1798 } 1799 return success(); 1800 } 1801 1802 LogicalResult UnaryOp::verify() { 1803 Type inputType = getX().getType(); 1804 Type outputType = getOutput().getType(); 1805 1806 // Check correct number of block arguments and return type for each 1807 // non-empty region. 1808 Region &present = getPresentRegion(); 1809 if (!present.empty()) { 1810 if (failed(verifyNumBlockArgs(this, present, "present", 1811 TypeRange{inputType}, outputType))) 1812 return failure(); 1813 } 1814 Region &absent = getAbsentRegion(); 1815 if (!absent.empty()) { 1816 if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{}, 1817 outputType))) 1818 return failure(); 1819 // Absent branch can only yield invariant values. 1820 Block *absentBlock = &absent.front(); 1821 Block *parent = getOperation()->getBlock(); 1822 Value absentVal = 1823 cast<YieldOp>(absentBlock->getTerminator()).getSingleResult(); 1824 if (auto arg = dyn_cast<BlockArgument>(absentVal)) { 1825 if (arg.getOwner() == parent) 1826 return emitError("absent region cannot yield linalg argument"); 1827 } else if (Operation *def = absentVal.getDefiningOp()) { 1828 if (!isa<arith::ConstantOp>(def) && 1829 (def->getBlock() == absentBlock || def->getBlock() == parent)) 1830 return emitError("absent region cannot yield locally computed value"); 1831 } 1832 } 1833 return success(); 1834 } 1835 1836 bool ConcatenateOp::needsExtraSort() { 1837 SparseTensorType dstStt = getSparseTensorType(*this); 1838 if (dstStt.isAllDense() || !dstStt.isAllOrdered()) 1839 return false; 1840 1841 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) { 1842 return getSparseTensorType(op).hasSameDimToLvl(dstStt); 1843 }); 1844 // TODO: When conDim != 0, as long as conDim corresponding to the first level 1845 // in all input/output buffers, and all input/output buffers have the same 1846 // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate 1847 // CSC matrices along column). 1848 bool directLowerable = 1849 allSameOrdered && getDimension() == 0 && dstStt.isIdentity(); 1850 return !directLowerable; 1851 } 1852 1853 LogicalResult ConcatenateOp::verify() { 1854 const auto dstTp = getSparseTensorType(*this); 1855 const Dimension concatDim = getDimension(); 1856 const Dimension dimRank = dstTp.getDimRank(); 1857 1858 if (getInputs().size() <= 1) 1859 return emitError("Need at least two tensors to concatenate."); 1860 1861 if (concatDim >= dimRank) 1862 return emitError(llvm::formatv( 1863 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})", 1864 concatDim, dimRank)); 1865 1866 for (const auto &it : llvm::enumerate(getInputs())) { 1867 const auto i = it.index(); 1868 const auto srcTp = getSparseTensorType(it.value()); 1869 if (srcTp.hasDynamicDimShape()) 1870 return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i)); 1871 const Dimension srcDimRank = srcTp.getDimRank(); 1872 if (srcDimRank != dimRank) 1873 return emitError( 1874 llvm::formatv("Input tensor ${0} has a different rank (rank={1}) " 1875 "from the output tensor (rank={2}).", 1876 i, srcDimRank, dimRank)); 1877 } 1878 1879 for (Dimension d = 0; d < dimRank; d++) { 1880 const Size dstSh = dstTp.getDimShape()[d]; 1881 if (d == concatDim) { 1882 if (!ShapedType::isDynamic(dstSh)) { 1883 // If we reach here, then all inputs have static shapes. So we 1884 // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)` 1885 // to avoid redundant assertions in the loop. 1886 Size sumSz = 0; 1887 for (const auto src : getInputs()) 1888 sumSz += getSparseTensorType(src).getDimShape()[d]; 1889 // If all dimension are statically known, the sum of all the input 1890 // dimensions should be equal to the output dimension. 1891 if (sumSz != dstSh) 1892 return emitError( 1893 "The concatenation dimension of the output tensor should be the " 1894 "sum of all the concatenation dimensions of the input tensors."); 1895 } 1896 } else { 1897 Size prev = dstSh; 1898 for (const auto src : getInputs()) { 1899 const auto sh = getSparseTensorType(src).getDimShape()[d]; 1900 if (!ShapedType::isDynamic(prev) && sh != prev) 1901 return emitError("All dimensions (expect for the concatenating one) " 1902 "should be equal."); 1903 prev = sh; 1904 } 1905 } 1906 } 1907 1908 return success(); 1909 } 1910 1911 void PushBackOp::build(OpBuilder &builder, OperationState &result, 1912 Value curSize, Value inBuffer, Value value) { 1913 build(builder, result, curSize, inBuffer, value, Value()); 1914 } 1915 1916 LogicalResult PushBackOp::verify() { 1917 if (Value n = getN()) { 1918 std::optional<int64_t> nValue = getConstantIntValue(n); 1919 if (nValue && nValue.value() < 1) 1920 return emitOpError("n must be not less than 1"); 1921 } 1922 return success(); 1923 } 1924 1925 LogicalResult CompressOp::verify() { 1926 const auto stt = getSparseTensorType(getTensor()); 1927 if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size())) 1928 return emitOpError("incorrect number of coordinates"); 1929 return success(); 1930 } 1931 1932 void ForeachOp::build( 1933 OpBuilder &builder, OperationState &result, Value tensor, 1934 ValueRange initArgs, AffineMapAttr order, 1935 function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)> 1936 bodyBuilder) { 1937 build(builder, result, initArgs.getTypes(), tensor, initArgs, order); 1938 // Builds foreach body. 1939 if (!bodyBuilder) 1940 return; 1941 const auto stt = getSparseTensorType(tensor); 1942 const Dimension dimRank = stt.getDimRank(); 1943 1944 // Starts with `dimRank`-many coordinates. 1945 SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType()); 1946 // Followed by one value. 1947 blockArgTypes.push_back(stt.getElementType()); 1948 // Followed by the reduction variables. 1949 blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end()); 1950 1951 SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc()); 1952 1953 OpBuilder::InsertionGuard guard(builder); 1954 auto ®ion = *result.regions.front(); 1955 Block *bodyBlock = 1956 builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); 1957 bodyBuilder(builder, result.location, 1958 bodyBlock->getArguments().slice(0, dimRank), 1959 bodyBlock->getArguments()[dimRank], 1960 bodyBlock->getArguments().drop_front(dimRank + 1)); 1961 } 1962 1963 LogicalResult ForeachOp::verify() { 1964 const auto t = getSparseTensorType(getTensor()); 1965 const Dimension dimRank = t.getDimRank(); 1966 const auto args = getBody()->getArguments(); 1967 1968 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank()) 1969 return emitError("Level traverse order does not match tensor's level rank"); 1970 1971 if (dimRank + 1 + getInitArgs().size() != args.size()) 1972 return emitError("Unmatched number of arguments in the block"); 1973 1974 if (getNumResults() != getInitArgs().size()) 1975 return emitError("Mismatch in number of init arguments and results"); 1976 1977 if (getResultTypes() != getInitArgs().getTypes()) 1978 return emitError("Mismatch in types of init arguments and results"); 1979 1980 // Cannot mark this const, because the getters aren't. 1981 auto yield = cast<YieldOp>(getBody()->getTerminator()); 1982 if (yield.getNumOperands() != getNumResults() || 1983 yield.getOperands().getTypes() != getResultTypes()) 1984 return emitError("Mismatch in types of yield values and results"); 1985 1986 const auto iTp = IndexType::get(getContext()); 1987 for (Dimension d = 0; d < dimRank; d++) 1988 if (args[d].getType() != iTp) 1989 emitError( 1990 llvm::formatv("Expecting Index type for argument at index {0}", d)); 1991 1992 const auto elemTp = t.getElementType(); 1993 const auto valueTp = args[dimRank].getType(); 1994 if (elemTp != valueTp) 1995 emitError(llvm::formatv("Unmatched element type between input tensor and " 1996 "block argument, expected:{0}, got: {1}", 1997 elemTp, valueTp)); 1998 return success(); 1999 } 2000 2001 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) { 2002 if (getSparseTensorEncoding(getInputCoo().getType()) == 2003 getSparseTensorEncoding(getResultCoo().getType())) 2004 return getInputCoo(); 2005 2006 return {}; 2007 } 2008 2009 LogicalResult ReorderCOOOp::verify() { 2010 SparseTensorType srcStt = getSparseTensorType(getInputCoo()); 2011 SparseTensorType dstStt = getSparseTensorType(getResultCoo()); 2012 2013 if (!srcStt.isCOOType() || !dstStt.isCOOType()) 2014 emitError("Expected COO sparse tensors only"); 2015 2016 if (!srcStt.hasSameDimToLvl(dstStt)) 2017 emitError("Unmatched dim2lvl map between input and result COO"); 2018 2019 if (srcStt.getPosType() != dstStt.getPosType() || 2020 srcStt.getCrdType() != dstStt.getCrdType() || 2021 srcStt.getElementType() != dstStt.getElementType()) 2022 emitError("Unmatched storage format between input and result COO"); 2023 2024 return success(); 2025 } 2026 2027 LogicalResult ReduceOp::verify() { 2028 Type inputType = getX().getType(); 2029 Region &formula = getRegion(); 2030 return verifyNumBlockArgs(this, formula, "reduce", 2031 TypeRange{inputType, inputType}, inputType); 2032 } 2033 2034 LogicalResult SelectOp::verify() { 2035 Builder b(getContext()); 2036 Type inputType = getX().getType(); 2037 Type boolType = b.getI1Type(); 2038 Region &formula = getRegion(); 2039 return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType}, 2040 boolType); 2041 } 2042 2043 LogicalResult SortOp::verify() { 2044 AffineMap xPerm = getPermMap(); 2045 uint64_t nx = xPerm.getNumDims(); 2046 if (nx < 1) 2047 emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx)); 2048 2049 if (!xPerm.isPermutation()) 2050 emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm)); 2051 2052 // We can't check the size of the buffers when n or buffer dimensions aren't 2053 // compile-time constants. 2054 std::optional<int64_t> cn = getConstantIntValue(getN()); 2055 if (!cn) 2056 return success(); 2057 2058 // Verify dimensions. 2059 const auto checkDim = [&](Value v, Size minSize, const char *message) { 2060 const Size sh = getMemRefType(v).getShape()[0]; 2061 if (!ShapedType::isDynamic(sh) && sh < minSize) 2062 emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize)); 2063 }; 2064 uint64_t n = cn.value(); 2065 uint64_t ny = 0; 2066 if (auto nyAttr = getNyAttr()) 2067 ny = nyAttr.getInt(); 2068 checkDim(getXy(), n * (nx + ny), 2069 "Expected dimension(xy) >= n * (rank(perm_map) + ny)"); 2070 for (Value opnd : getYs()) 2071 checkDim(opnd, n, "Expected dimension(y) >= n"); 2072 2073 return success(); 2074 } 2075 2076 //===----------------------------------------------------------------------===// 2077 // Sparse Tensor Iteration Operations. 2078 //===----------------------------------------------------------------------===// 2079 2080 IterSpaceType IteratorType::getIterSpaceType() const { 2081 return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(), 2082 getHiLvl()); 2083 } 2084 2085 IteratorType IterSpaceType::getIteratorType() const { 2086 return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl()); 2087 } 2088 2089 /// Parses a level range in the form "$lo `to` $hi" 2090 /// or simply "$lo" if $hi - $lo = 1 2091 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo, 2092 Level &lvlHi) { 2093 if (parser.parseInteger(lvlLo)) 2094 return failure(); 2095 2096 if (succeeded(parser.parseOptionalKeyword("to"))) { 2097 if (parser.parseInteger(lvlHi)) 2098 return failure(); 2099 } else { 2100 lvlHi = lvlLo + 1; 2101 } 2102 2103 if (lvlHi <= lvlLo) 2104 parser.emitError(parser.getNameLoc(), 2105 "expect larger level upper bound than lower bound"); 2106 2107 return success(); 2108 } 2109 2110 /// Parses a level range in the form "$lo `to` $hi" 2111 /// or simply "$lo" if $hi - $lo = 1 2112 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr, 2113 IntegerAttr &lvlHiAttr) { 2114 Level lvlLo, lvlHi; 2115 if (parseLevelRange(parser, lvlLo, lvlHi)) 2116 return failure(); 2117 2118 lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo); 2119 lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi); 2120 return success(); 2121 } 2122 2123 /// Prints a level range in the form "$lo `to` $hi" 2124 /// or simply "$lo" if $hi - $lo = 1 2125 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) { 2126 2127 if (lo + 1 == hi) 2128 p << lo; 2129 else 2130 p << lo << " to " << hi; 2131 } 2132 2133 /// Prints a level range in the form "$lo `to` $hi" 2134 /// or simply "$lo" if $hi - $lo = 1 2135 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, 2136 IntegerAttr lvlHi) { 2137 unsigned lo = lvlLo.getValue().getZExtValue(); 2138 unsigned hi = lvlHi.getValue().getZExtValue(); 2139 printLevelRange(p, lo, hi); 2140 } 2141 2142 /// Parses a list of `optional` defined list in the form of 2143 /// "(%val0, _, %val1, ...)", where `_` is used to annotate that the 2144 /// corresponding value is not defined (e.g., to represent an undefined 2145 /// coordinate in the sparse iteration space). 2146 static ParseResult parseOptionalDefinedList( 2147 OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, 2148 SmallVectorImpl<OpAsmParser::Argument> &definedArgs, 2149 unsigned maxCnt = std::numeric_limits<unsigned>::max(), 2150 OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) { 2151 unsigned cnt = 0; 2152 ParseResult crdList = 2153 parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult { 2154 if (parser.parseOptionalKeyword("_")) { 2155 if (parser.parseArgument(definedArgs.emplace_back())) 2156 return failure(); 2157 definedSet.set(cnt); 2158 } 2159 cnt += 1; 2160 return success(); 2161 }); 2162 2163 if (cnt > maxCnt) 2164 return parser.emitError(parser.getNameLoc(), 2165 "parsed more value than expected."); 2166 2167 if (failed(crdList)) { 2168 return parser.emitError( 2169 parser.getNameLoc(), 2170 "expecting SSA value or \"_\" for level coordinates"); 2171 } 2172 assert(definedArgs.size() == definedSet.count()); 2173 return success(); 2174 } 2175 2176 static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, 2177 Block::BlockArgListType blocksArgs, 2178 I64BitSet definedSet) { 2179 if (definedSet.empty()) 2180 return; 2181 2182 for (unsigned i = 0; i < size; i++) { 2183 if (definedSet[i]) { 2184 p << blocksArgs.front(); 2185 blocksArgs = blocksArgs.drop_front(); 2186 } else { 2187 p << "_"; 2188 } 2189 if (i != size - 1) 2190 p << ", "; 2191 } 2192 assert(blocksArgs.empty()); 2193 } 2194 2195 static ParseResult 2196 parseUsedCoordList(OpAsmParser &parser, OperationState &state, 2197 SmallVectorImpl<OpAsmParser::Argument> &coords) { 2198 // Parse "at(%crd0, _, ...)" 2199 I64BitSet crdUsedLvlSet; 2200 if (succeeded(parser.parseOptionalKeyword("at")) && 2201 failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords))) 2202 return failure(); 2203 2204 // Always use IndexType for the coordinate. 2205 for (auto &coord : coords) 2206 coord.type = parser.getBuilder().getIndexType(); 2207 2208 // Set the CrdUsedLvl bitset. 2209 state.addAttribute("crdUsedLvls", 2210 parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet)); 2211 return success(); 2212 } 2213 2214 static ParseResult 2215 parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, 2216 SmallVectorImpl<OpAsmParser::Argument> &iterators, 2217 SmallVectorImpl<OpAsmParser::Argument> &blockArgs) { 2218 SmallVector<OpAsmParser::UnresolvedOperand> spaces; 2219 SmallVector<OpAsmParser::UnresolvedOperand> initArgs; 2220 2221 // Parse "%iters, ... in %spaces, ..." 2222 if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") || 2223 parser.parseOperandList(spaces)) 2224 return failure(); 2225 2226 if (iterators.size() != spaces.size()) 2227 return parser.emitError( 2228 parser.getNameLoc(), 2229 "mismatch in number of sparse iterators and sparse spaces"); 2230 2231 if (failed(parseUsedCoordList(parser, state, blockArgs))) 2232 return failure(); 2233 size_t numCrds = blockArgs.size(); 2234 2235 // Parse "iter_args(%arg = %init, ...)" 2236 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); 2237 if (hasIterArgs) 2238 if (parser.parseAssignmentList(blockArgs, initArgs)) 2239 return failure(); 2240 2241 SmallVector<Type> iterSpaceTps; 2242 // parse ": sparse_tensor.iter_space -> ret" 2243 if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) 2244 return failure(); 2245 if (iterSpaceTps.size() != spaces.size()) 2246 return parser.emitError(parser.getNameLoc(), 2247 "mismatch in number of iteration space operands " 2248 "and iteration space types"); 2249 2250 for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) { 2251 IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp); 2252 if (!spaceTp) 2253 return parser.emitError(parser.getNameLoc(), 2254 "expected sparse_tensor.iter_space type for " 2255 "iteration space operands"); 2256 it.type = spaceTp.getIteratorType(); 2257 } 2258 2259 if (hasIterArgs) 2260 if (parser.parseArrowTypeList(state.types)) 2261 return failure(); 2262 2263 // Resolves input operands. 2264 if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), 2265 state.operands)) 2266 return failure(); 2267 2268 if (hasIterArgs) { 2269 // Strip off leading args that used for coordinates. 2270 MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds); 2271 if (args.size() != initArgs.size() || args.size() != state.types.size()) { 2272 return parser.emitError( 2273 parser.getNameLoc(), 2274 "mismatch in number of iteration arguments and return values"); 2275 } 2276 2277 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { 2278 it.type = tp; 2279 if (parser.resolveOperand(init, tp, state.operands)) 2280 return failure(); 2281 } 2282 } 2283 return success(); 2284 } 2285 2286 static ParseResult 2287 parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, 2288 SmallVectorImpl<Value> &spacesVals, 2289 SmallVectorImpl<OpAsmParser::Argument> &blockArgs) { 2290 2291 // Parse "(%spaces, ...)" 2292 SmallVector<OpAsmParser::UnresolvedOperand> spaces; 2293 if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren)) 2294 return failure(); 2295 2296 SmallVector<OpAsmParser::Argument> coords; 2297 if (failed(parseUsedCoordList(parser, state, coords))) 2298 return failure(); 2299 size_t numCrds = coords.size(); 2300 2301 // Parse "iter_args(%arg = %init, ...)" 2302 SmallVector<OpAsmParser::UnresolvedOperand> initArgs; 2303 bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); 2304 if (hasIterArgs) 2305 if (parser.parseAssignmentList(blockArgs, initArgs)) 2306 return failure(); 2307 blockArgs.append(coords); 2308 2309 SmallVector<Type> iterSpaceTps; 2310 // parse ": (sparse_tensor.iter_space, ...) -> ret" 2311 if (parser.parseColon() || parser.parseLParen() || 2312 parser.parseTypeList(iterSpaceTps) || parser.parseRParen()) 2313 return failure(); 2314 2315 if (iterSpaceTps.size() != spaces.size()) 2316 return parser.emitError(parser.getNameLoc(), 2317 "mismatch in number of iteration space operands " 2318 "and iteration space types"); 2319 2320 if (hasIterArgs) 2321 if (parser.parseArrowTypeList(state.types)) 2322 return failure(); 2323 2324 // Resolves input sparse iteration spaces. 2325 if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), 2326 spacesVals)) 2327 return failure(); 2328 state.operands.append(spacesVals); 2329 2330 if (hasIterArgs) { 2331 // Strip off trailing args that used for coordinates. 2332 MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); 2333 if (args.size() != initArgs.size() || args.size() != state.types.size()) { 2334 return parser.emitError( 2335 parser.getNameLoc(), 2336 "mismatch in number of iteration arguments and return values"); 2337 } 2338 2339 for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { 2340 it.type = tp; 2341 if (parser.resolveOperand(init, tp, state.operands)) 2342 return failure(); 2343 } 2344 } 2345 return success(); 2346 } 2347 2348 LogicalResult ExtractIterSpaceOp::inferReturnTypes( 2349 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, 2350 DictionaryAttr attr, OpaqueProperties prop, RegionRange region, 2351 SmallVectorImpl<mlir::Type> &ret) { 2352 2353 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region); 2354 SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); 2355 ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(), 2356 adaptor.getHiLvl())); 2357 return success(); 2358 } 2359 2360 LogicalResult ExtractIterSpaceOp::verify() { 2361 if (getLoLvl() >= getHiLvl()) 2362 return emitOpError("expected smaller level low than level high"); 2363 2364 TypedValue<IteratorType> pIter = getParentIter(); 2365 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) { 2366 return emitOpError( 2367 "parent iterator should be specified iff level lower bound equals 0"); 2368 } 2369 2370 if (pIter) { 2371 IterSpaceType spaceTp = getExtractedSpace().getType(); 2372 if (pIter.getType().getEncoding() != spaceTp.getEncoding()) 2373 return emitOpError( 2374 "mismatch in parent iterator encoding and iteration space encoding."); 2375 2376 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl()) 2377 return emitOpError("parent iterator should be used to extract an " 2378 "iteration space from a consecutive level."); 2379 } 2380 2381 return success(); 2382 } 2383 2384 LogicalResult ExtractValOp::verify() { 2385 auto stt = getSparseTensorType(getTensor()); 2386 auto itTp = getIterator().getType(); 2387 2388 if (stt.getEncoding() != itTp.getEncoding()) 2389 return emitOpError("mismatch in tensor encoding and iterator encoding."); 2390 2391 if (stt.getLvlRank() != itTp.getHiLvl()) 2392 return emitOpError("must use last-level iterator to extract values. "); 2393 2394 return success(); 2395 } 2396 2397 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> { 2398 using OpRewritePattern::OpRewritePattern; 2399 2400 LogicalResult matchAndRewrite(IterateOp iterateOp, 2401 PatternRewriter &rewriter) const override { 2402 I64BitSet newUsedLvls(0); 2403 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments()); 2404 for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) { 2405 if (auto crd = iterateOp.getLvlCrd(i)) { 2406 if (crd->getUsers().empty()) 2407 toRemove.set(crd->getArgNumber()); 2408 else 2409 newUsedLvls.set(i); 2410 } 2411 } 2412 2413 // All coordinates are used. 2414 if (toRemove.none()) 2415 return failure(); 2416 2417 rewriter.startOpModification(iterateOp); 2418 iterateOp.setCrdUsedLvls(newUsedLvls); 2419 iterateOp.getBody()->eraseArguments(toRemove); 2420 rewriter.finalizeOpModification(iterateOp); 2421 return success(); 2422 } 2423 }; 2424 2425 void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, 2426 mlir::MLIRContext *context) { 2427 results.add<RemoveUnusedLvlCrds>(context); 2428 } 2429 2430 void IterateOp::build(OpBuilder &builder, OperationState &odsState, 2431 Value iterSpace, ValueRange initArgs) { 2432 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim(); 2433 // All ones. 2434 I64BitSet set((1 << rank) - 1); 2435 return build(builder, odsState, iterSpace, initArgs, set); 2436 } 2437 2438 void IterateOp::build(OpBuilder &builder, OperationState &odsState, 2439 Value iterSpace, ValueRange initArgs, 2440 I64BitSet crdUsedLvls) { 2441 OpBuilder::InsertionGuard guard(builder); 2442 2443 odsState.addOperands(iterSpace); 2444 odsState.addOperands(initArgs); 2445 odsState.getOrAddProperties<Properties>().crdUsedLvls = 2446 builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls); 2447 Region *bodyRegion = odsState.addRegion(); 2448 odsState.addTypes(initArgs.getTypes()); 2449 Block *bodyBlock = builder.createBlock(bodyRegion); 2450 2451 // First argument, sparse iterator 2452 bodyBlock->addArgument( 2453 llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(), 2454 odsState.location); 2455 2456 // Followed by a list of used coordinates. 2457 for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++) 2458 bodyBlock->addArgument(builder.getIndexType(), odsState.location); 2459 2460 // Followed by a list of user-provided loop arguments. 2461 for (Value v : initArgs) 2462 bodyBlock->addArgument(v.getType(), v.getLoc()); 2463 } 2464 2465 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { 2466 OpAsmParser::Argument iterator; 2467 OpAsmParser::UnresolvedOperand iterSpace; 2468 2469 SmallVector<OpAsmParser::Argument> iters, iterArgs; 2470 if (parseSparseIterateLoop(parser, result, iters, iterArgs)) 2471 return failure(); 2472 if (iters.size() != 1) 2473 return parser.emitError(parser.getNameLoc(), 2474 "expected only one iterator/iteration space"); 2475 2476 iters.append(iterArgs); 2477 Region *body = result.addRegion(); 2478 if (parser.parseRegion(*body, iters)) 2479 return failure(); 2480 2481 IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); 2482 2483 // Parse the optional attribute list. 2484 if (parser.parseOptionalAttrDict(result.attributes)) 2485 return failure(); 2486 2487 return success(); 2488 } 2489 2490 /// Prints the initialization list in the form of 2491 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>) 2492 /// where 'inner' values are assumed to be region arguments and 'outer' values 2493 /// are regular SSA values. 2494 static void printInitializationList(OpAsmPrinter &p, 2495 Block::BlockArgListType blocksArgs, 2496 ValueRange initializers, 2497 StringRef prefix = "") { 2498 assert(blocksArgs.size() == initializers.size() && 2499 "expected same length of arguments and initializers"); 2500 if (initializers.empty()) 2501 return; 2502 2503 p << prefix << '('; 2504 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) { 2505 p << std::get<0>(it) << " = " << std::get<1>(it); 2506 }); 2507 p << ")"; 2508 } 2509 2510 template <typename SparseLoopOp> 2511 static LogicalResult verifySparseLoopOp(SparseLoopOp op) { 2512 if (op.getInitArgs().size() != op.getNumResults()) { 2513 return op.emitOpError( 2514 "mismatch in number of loop-carried values and defined values"); 2515 } 2516 if (op.getCrdUsedLvls().max() > op.getSpaceDim()) 2517 return op.emitOpError("required out-of-bound coordinates"); 2518 2519 return success(); 2520 } 2521 2522 LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); } 2523 LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); } 2524 2525 void IterateOp::print(OpAsmPrinter &p) { 2526 p << " " << getIterator() << " in " << getIterSpace(); 2527 if (!getCrdUsedLvls().empty()) { 2528 p << " at("; 2529 printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls()); 2530 p << ")"; 2531 } 2532 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args"); 2533 2534 p << " : " << getIterSpace().getType() << " "; 2535 if (!getInitArgs().empty()) 2536 p.printArrowTypeList(getInitArgs().getTypes()); 2537 2538 p << " "; 2539 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 2540 /*printBlockTerminators=*/!getInitArgs().empty()); 2541 } 2542 2543 LogicalResult IterateOp::verifyRegions() { 2544 if (getIterator().getType() != getIterSpace().getType().getIteratorType()) 2545 return emitOpError("mismatch in iterator and iteration space type"); 2546 if (getNumRegionIterArgs() != getNumResults()) 2547 return emitOpError( 2548 "mismatch in number of basic block args and defined values"); 2549 2550 auto initArgs = getInitArgs(); 2551 auto iterArgs = getRegionIterArgs(); 2552 auto yieldVals = getYieldedValues(); 2553 auto opResults = getResults(); 2554 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), 2555 opResults.size()})) { 2556 return emitOpError() << "number mismatch between iter args and results."; 2557 } 2558 2559 for (auto [i, init, iter, yield, ret] : 2560 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { 2561 if (init.getType() != ret.getType()) 2562 return emitOpError() << "types mismatch between " << i 2563 << "th iter operand and defined value"; 2564 if (iter.getType() != ret.getType()) 2565 return emitOpError() << "types mismatch between " << i 2566 << "th iter region arg and defined value"; 2567 if (yield.getType() != ret.getType()) 2568 return emitOpError() << "types mismatch between " << i 2569 << "th yield value and defined value"; 2570 } 2571 2572 return success(); 2573 } 2574 2575 /// OpInterfaces' methods implemented by IterateOp. 2576 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; } 2577 2578 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() { 2579 return getInitArgsMutable(); 2580 } 2581 2582 Block::BlockArgListType IterateOp::getRegionIterArgs() { 2583 return getRegion().getArguments().take_back(getNumRegionIterArgs()); 2584 } 2585 2586 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() { 2587 return cast<sparse_tensor::YieldOp>( 2588 getRegion().getBlocks().front().getTerminator()) 2589 .getResultsMutable(); 2590 } 2591 2592 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); } 2593 2594 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { 2595 return getInitArgs(); 2596 } 2597 2598 void IterateOp::getSuccessorRegions(RegionBranchPoint point, 2599 SmallVectorImpl<RegionSuccessor> ®ions) { 2600 // Both the operation itself and the region may be branching into the body 2601 // or back into the operation itself. 2602 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); 2603 // It is possible for loop not to enter the body. 2604 regions.push_back(RegionSuccessor(getResults())); 2605 } 2606 2607 void CoIterateOp::build(OpBuilder &builder, OperationState &odsState, 2608 ValueRange iterSpaces, ValueRange initArgs, 2609 unsigned numCases) { 2610 unsigned rank = 2611 cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim(); 2612 // All ones. 2613 I64BitSet set((1 << rank) - 1); 2614 // Generates all-zero case bits (they only serve as placeholders), which are 2615 // supposed to be overriden later. We need to preallocate all the regions as 2616 // mlir::Region cannot be dynamically added later after the operation is 2617 // created. 2618 SmallVector<int64_t> caseBits(numCases, 0); 2619 ArrayAttr cases = builder.getI64ArrayAttr(caseBits); 2620 return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces, 2621 initArgs, set, cases, 2622 /*caseRegionsCount=*/numCases); 2623 } 2624 2625 ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) { 2626 2627 SmallVector<Value> spaces; 2628 // The block argument list of each regions, it is arranged in the order of 2629 // ([used coordinate list], [loop iterations args], [sparse iterator list]). 2630 SmallVector<OpAsmParser::Argument> blockArgs; 2631 if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs)) 2632 return failure(); 2633 2634 result.addAttribute("operandSegmentSizes", 2635 parser.getBuilder().getDenseI32ArrayAttr( 2636 {static_cast<int32_t>(spaces.size()), 2637 static_cast<int32_t>(result.types.size())})); 2638 2639 SmallVector<Attribute> cases; 2640 while (succeeded(parser.parseOptionalKeyword("case"))) { 2641 // Parse one region per case. 2642 I64BitSet definedItSet; 2643 SmallVector<OpAsmParser::Argument> definedIts; 2644 if (parseOptionalDefinedList(parser, result, definedItSet, definedIts, 2645 spaces.size(), OpAsmParser::Delimiter::None)) 2646 return failure(); 2647 2648 cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet)); 2649 2650 for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) { 2651 // Resolve the iterator type based on the iteration space type. 2652 auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType()); 2653 definedIts[i].type = spaceTp.getIteratorType(); 2654 } 2655 definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end()); 2656 Region *body = result.addRegion(); 2657 if (parser.parseRegion(*body, definedIts)) 2658 return failure(); 2659 2660 CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); 2661 } 2662 2663 result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases)); 2664 2665 // Parse the optional attribute list. 2666 if (parser.parseOptionalAttrDict(result.attributes)) 2667 return failure(); 2668 2669 return success(); 2670 } 2671 2672 void CoIterateOp::print(OpAsmPrinter &p) { 2673 p << " ("; 2674 llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; }); 2675 p << ")"; 2676 2677 if (!getCrdUsedLvls().empty()) { 2678 p << " at("; 2679 printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls()); 2680 p << ")"; 2681 } 2682 2683 printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args"); 2684 2685 p << " : (" << getIterSpaces().getTypes() << ")"; 2686 if (!getInitArgs().empty()) 2687 p.printArrowTypeList(getInitArgs().getTypes()); 2688 2689 for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) { 2690 p.printNewline(); 2691 p << "case "; 2692 printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx), 2693 getRegionDefinedSpace(idx)); 2694 p << " "; 2695 p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false, 2696 /*printBlockTerminators=*/!getInitArgs().empty()); 2697 } 2698 } 2699 2700 ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) { 2701 return cast<sparse_tensor::YieldOp>( 2702 getRegion(regionIdx).getBlocks().front().getTerminator()) 2703 .getResults(); 2704 } 2705 2706 LogicalResult CoIterateOp::verifyRegions() { 2707 for (unsigned r = 0, e = getNumRegions(); r < e; r++) { 2708 if (getNumRegionIterArgs() != getNumResults()) 2709 return emitOpError( 2710 "mismatch in number of basic block args and defined values"); 2711 2712 auto initArgs = getInitArgs(); 2713 auto iterArgs = getRegionIterArgs(r); 2714 auto yieldVals = getYieldedValues(r); 2715 auto opResults = getResults(); 2716 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), 2717 opResults.size()})) { 2718 return emitOpError() 2719 << "number mismatch between iter args and results on " << r 2720 << "th region"; 2721 } 2722 2723 for (auto [i, init, iter, yield, ret] : 2724 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { 2725 if (init.getType() != ret.getType()) 2726 return emitOpError() 2727 << "types mismatch between " << i 2728 << "th iter operand and defined value on " << r << "th region"; 2729 if (iter.getType() != ret.getType()) 2730 return emitOpError() << "types mismatch between " << i 2731 << "th iter region arg and defined value on " << r 2732 << "th region"; 2733 if (yield.getType() != ret.getType()) 2734 return emitOpError() 2735 << "types mismatch between " << i 2736 << "th yield value and defined value on " << r << "th region"; 2737 } 2738 } 2739 2740 auto cases = getRegionDefinedSpaces(); 2741 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end()); 2742 if (set.size() != getNumRegions()) 2743 return emitOpError("contains duplicated cases."); 2744 2745 return success(); 2746 } 2747 2748 //===----------------------------------------------------------------------===// 2749 // Sparse Tensor Dialect Setups. 2750 //===----------------------------------------------------------------------===// 2751 2752 /// Materialize a single constant operation from a given attribute value with 2753 /// the desired resultant type. 2754 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder, 2755 Attribute value, Type type, 2756 Location loc) { 2757 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc)) 2758 return op; 2759 return nullptr; 2760 } 2761 2762 namespace { 2763 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface { 2764 using OpAsmDialectInterface::OpAsmDialectInterface; 2765 2766 AliasResult getAlias(Attribute attr, raw_ostream &os) const override { 2767 if (isa<SparseTensorEncodingAttr>(attr)) { 2768 os << "sparse"; 2769 return AliasResult::OverridableAlias; 2770 } 2771 return AliasResult::NoAlias; 2772 } 2773 }; 2774 } // namespace 2775 2776 void SparseTensorDialect::initialize() { 2777 addInterface<SparseTensorAsmDialectInterface>(); 2778 addAttributes< 2779 #define GET_ATTRDEF_LIST 2780 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 2781 >(); 2782 addTypes< 2783 #define GET_TYPEDEF_LIST 2784 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" 2785 >(); 2786 addOperations< 2787 #define GET_OP_LIST 2788 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 2789 >(); 2790 declarePromisedInterfaces< 2791 bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp, 2792 NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp, 2793 ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>(); 2794 } 2795 2796 #define GET_OP_CLASSES 2797 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 2798 2799 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 2800