//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "SparseTensorDescriptor.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// ExecutionEngine/SparseTensorUtils helper functions.
//===----------------------------------------------------------------------===//

OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
  switch (width) {
  case 64:
    return OverheadType::kU64;
  case 32:
    return OverheadType::kU32;
  case 16:
    return OverheadType::kU16;
  case 8:
    return OverheadType::kU8;
  case 0:
    return OverheadType::kIndex;
  }
  llvm_unreachable("Unsupported overhead bitwidth");
}

OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
  if (tp.isIndex())
    return OverheadType::kIndex;
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return overheadTypeEncoding(intTp.getWidth());
  llvm_unreachable("Unknown overhead type");
}

Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return builder.getIndexType();
  case OverheadType::kU64:
    return builder.getIntegerType(64);
  case OverheadType::kU32:
    return builder.getIntegerType(32);
  case OverheadType::kU16:
    return builder.getIntegerType(16);
  case OverheadType::kU8:
    return builder.getIntegerType(8);
  }
  llvm_unreachable("Unknown OverheadType");
}

OverheadType
mlir::sparse_tensor::posTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getPosWidth());
}

OverheadType
mlir::sparse_tensor::crdTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getCrdWidth());
}

// TODO: We ought to add some `static_assert` tests to ensure that the
// `STEA::get{Pos,Crd}Type` methods agree with
// `getOverheadType(builder, {pos,crd}TypeEncoding(enc))`.

// TODO: Adjust the naming convention for the constructors of `OverheadType`
// so we can use the `MLIR_SPARSETENSOR_FOREVERY_O` x-macro here instead of
// `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`, to further reduce the possibility of
// typos or of things getting out of sync.
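// For illustration: `overheadTypeFunctionSuffix(OverheadType::kU32)` yields
// "32" and `overheadTypeFunctionSuffix(OverheadType::kIndex)` yields "0";
// the suffix is appended to the name of the corresponding runtime support
// function to select the matching overhead-type variant.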
StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return "0";
#define CASE(ONAME, O)                                                         \
  case OverheadType::kU##ONAME:                                                \
    return #ONAME;
    MLIR_SPARSETENSOR_FOREVERY_FIXED_O(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown OverheadType");
}

StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
  return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
}

PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
  if (elemTp.isF64())
    return PrimaryType::kF64;
  if (elemTp.isF32())
    return PrimaryType::kF32;
  if (elemTp.isF16())
    return PrimaryType::kF16;
  if (elemTp.isBF16())
    return PrimaryType::kBF16;
  if (elemTp.isInteger(64))
    return PrimaryType::kI64;
  if (elemTp.isInteger(32))
    return PrimaryType::kI32;
  if (elemTp.isInteger(16))
    return PrimaryType::kI16;
  if (elemTp.isInteger(8))
    return PrimaryType::kI8;
  if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
    auto complexEltTp = complexTp.getElementType();
    if (complexEltTp.isF64())
      return PrimaryType::kC64;
    if (complexEltTp.isF32())
      return PrimaryType::kC32;
  }
  llvm_unreachable("Unknown primary type");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
  switch (pt) {
#define CASE(VNAME, V)                                                         \
  case PrimaryType::k##VNAME:                                                  \
    return #VNAME;
    MLIR_SPARSETENSOR_FOREVERY_V(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown PrimaryType");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
  return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
}

//===----------------------------------------------------------------------===//
// Misc code generators.
//===----------------------------------------------------------------------===//

Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
                             Type dstTp) {
  const Type srcTp = value.getType();
  if (srcTp == dstTp)
    return value;

  // int <=> index
  if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
    return builder.create<arith::IndexCastOp>(loc, dstTp, value);

  const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
  return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}

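// Converts a scalar `elem` into `dstTp`. When the destination is a ranked
// tensor type it must be 0-ranked; the scalar is cast to the element type
// and then wrapped, e.g. (illustrative IR, assuming an f64 element):
//   %t = tensor.from_elements %e : tensor<f64>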
Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
                                       Value elem, Type dstTp) {
  if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
    // Scalars can only be converted to 0-ranked tensors.
    assert(rtp.getRank() == 0);
    elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
  }
  return sparse_tensor::genCast(builder, loc, elem, dstTp);
}

Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
                                  ValueRange s) {
  Value load = builder.create<memref::LoadOp>(loc, mem, s);
  if (!isa<IndexType>(load.getType())) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
    load =
        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
  }
  return load;
}

mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
  if (isa<FloatType>(tp))
    return builder.getFloatAttr(tp, 1.0);
  if (isa<IndexType>(tp))
    return builder.getIndexAttr(1);
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
  if (isa<RankedTensorType, VectorType>(tp)) {
    auto shapedTp = cast<ShapedType>(tp);
    if (auto one = getOneAttr(builder, shapedTp.getElementType()))
      return DenseElementsAttr::get(shapedTp, one);
  }
  llvm_unreachable("Unsupported attribute type");
}

Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
                                        Value v) {
  Type tp = v.getType();
  Value zero = constantZero(builder, loc, tp);
  if (isa<FloatType>(tp))
    return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                         zero);
  if (tp.isIntOrIndex())
    return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                         zero);
  if (isa<ComplexType>(tp))
    return builder.create<complex::NotEqualOp>(loc, v, zero);
  llvm_unreachable("Non-numeric type");
}

void mlir::sparse_tensor::genReshapeDstShape(
    OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
    ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
    ArrayRef<ReassociationIndices> reassociation) {
  // Collapse shape.
  if (reassociation.size() < srcShape.size()) {
    unsigned start = 0;
    for (const auto &map : llvm::enumerate(reassociation)) {
      auto dstDim = constantIndex(builder, loc, 1);
      for (unsigned i = start; i < start + map.value().size(); i++) {
        dstDim = builder.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
      }
      dstShape.push_back(dstDim);
      start = start + map.value().size();
    }
    assert(start == srcShape.size());
    return;
  }

  // Expand shape.
  assert(reassociation.size() == srcShape.size());
  unsigned start = 0;
  // Expand the i-th dimension in srcShape.
  for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
    const auto &map = reassociation[i];
    auto srcDim = srcShape[i];
    // Iterate through dimensions expanded from the i-th dimension.
    for (unsigned j = start; j < start + map.size(); j++) {
      // There can be only one dynamically sized dimension among the dimensions
      // expanded from the i-th dimension in srcShape.
      // For example, if srcDim = 8, then the expanded shape could be <2x?x2>,
      // but not <2x?x?>.
      if (staticDstShape[j] == ShapedType::kDynamic) {
        // The expanded dimension has dynamic size. We compute the dimension
        // by dividing srcDim by the product of the static dimensions.
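        // Continuing the example above: with srcDim = 8 and target shape
        // <2x?x2>, the product of the static sizes is 4, so the dynamic
        // dimension evaluates to 8 / 4 = 2 at runtime.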
        Size product = 1;
        for (unsigned k = start; k < start + map.size(); k++) {
          if (staticDstShape[k] != ShapedType::kDynamic) {
            product *= staticDstShape[k];
          }
        }
        // Compute the dynamic dimension size.
        Value productVal = constantIndex(builder, loc, product);
        Value dynamicSize =
            builder.create<arith::DivUIOp>(loc, srcDim, productVal);
        dstShape.push_back(dynamicSize);
      } else {
        // The expanded dimension is statically known.
        dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
      }
    }
    start = start + map.size();
  }
  assert(start == staticDstShape.size());
}

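// As an illustration, collapsing coordinates (i, j) of a 10x20 source yields
// the linearized destination coordinate i * 20 + j, while expanding that
// coordinate back recovers i = c / 20 and j = c % 20 (via divui/remui below).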
void mlir::sparse_tensor::reshapeCvs(
    OpBuilder &builder, Location loc,
    ArrayRef<ReassociationIndices> reassociation, // NOLINT
    ValueRange srcSizes, ValueRange srcCvs,       // NOLINT
    ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs) {
  const unsigned srcRank = srcSizes.size();
  const unsigned dstRank = dstSizes.size();
  assert(srcRank == srcCvs.size() && "Source rank mismatch");
  const bool isCollapse = srcRank > dstRank;
  const ValueRange sizes = isCollapse ? srcSizes : dstSizes;
  // Iterate over reassociation map.
  unsigned i = 0;
  unsigned start = 0;
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in dimension slice.
    Value linear = constantIndex(builder, loc, 1);
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::MulIOp>(loc, linear, sizes[j]);
    }
    // Start expansion.
    Value val;
    if (!isCollapse)
      val = srcCvs[i];
    // Iterate over dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::DivUIOp>(loc, linear, sizes[j]);
      if (isCollapse) {
        const Value mul = builder.create<arith::MulIOp>(loc, srcCvs[j], linear);
        val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
      } else {
        const Value old = val;
        val = builder.create<arith::DivUIOp>(loc, val, linear);
        assert(dstCvs.size() == j);
        dstCvs.push_back(val);
        val = builder.create<arith::RemUIOp>(loc, old, linear);
      }
    }
    // Finalize collapse.
    if (isCollapse) {
      assert(dstCvs.size() == i);
      dstCvs.push_back(val);
    }
    start += map.value().size();
    i++;
  }
  assert(dstCvs.size() == dstRank);
}

FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
                                               TypeRange resultType,
                                               ValueRange operands,
                                               EmitCInterface emitCInterface) {
  MLIRContext *context = module.getContext();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        module.getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}

func::CallOp mlir::sparse_tensor::createFuncCall(
    OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
    ValueRange operands, EmitCInterface emitCInterface) {
  auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  FlatSymbolRefAttr fn =
      getFunc(module, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(loc, resultType, fn, operands);
}

Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) {
  return LLVM::LLVMPointerType::get(ctx);
}

Type mlir::sparse_tensor::getOpaquePointerType(Builder &builder) {
  return getOpaquePointerType(builder.getContext());
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc,
                                     unsigned sz, Type tp, bool staticShape) {
  if (staticShape) {
    auto memTp = MemRefType::get({sz}, tp);
    return builder.create<memref::AllocaOp>(loc, memTp);
  }
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz,
                                     Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc,
                                           Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
                                        ValueRange values) {
  const unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

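// For illustration, for a tensor type tensor<?x32xf64> whose dynamic first
// dimension is bound to %d0, the allocation below amounts to roughly:
//   %buf = memref.alloc(%d0) : memref<?x32xf64>
//   linalg.fill ins(%zero : f64) outs(%buf : memref<?x32xf64>)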
Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
                                            RankedTensorType tensorTp,
                                            ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamic)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc,
                                             Value buffer) {
  builder.create<memref::DeallocOp>(loc, buffer);
}

void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
                                       SmallVectorImpl<Value> &sizes,
                                       Location loc, Value src) {
  const Dimension dimRank = getSparseTensorType(src).getDimRank();
  for (Dimension d = 0; d < dimRank; d++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d));
}

Operation *mlir::sparse_tensor::getTop(Operation *op) {
  for (; isa<scf::ForOp>(op->getParentOp()) ||
         isa<scf::WhileOp>(op->getParentOp()) ||
         isa<scf::ParallelOp>(op->getParentOp()) ||
         isa<scf::IfOp>(op->getParentOp());
       op = op->getParentOp())
    ;
  return op;
}

void sparse_tensor::foreachInSparseConstant(
    OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
    function_ref<void(ArrayRef<Value>, Value)> callback) {
  if (!order)
    order = builder.getMultiDimIdentityMap(attr.getType().getRank());

  auto stt = SparseTensorType(getRankedTensorType(attr));
  const Dimension dimRank = stt.getDimRank();
  const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
  const auto values = attr.getValues().getValues<Attribute>();

  // This is like the `Element<V>` class in the runtime library, but for
  // MLIR attributes. In the future we may want to move this out into
  // a proper class definition to help improve code legibility (e.g.,
  // `first` -> `coords`, `second` -> `value`), as well as to factor out
  // analogues of `ElementLT<V>` for the sort below, etc.
  using ElementAttr = std::pair<SmallVector<IntegerAttr>, Attribute>;

  // Construct the COO from the SparseElementsAttr.
  SmallVector<ElementAttr> elems;
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    elems.emplace_back();
    elems.back().second = values[i];
    auto &coords = elems.back().first;
    coords.reserve(dimRank);
    for (Dimension d = 0; d < dimRank; d++)
      coords.push_back(coordinates[i * dimRank + d]);
  }

  // Sort the sparse element attributes based on coordinates.
  std::sort(elems.begin(), elems.end(),
            [order](const ElementAttr &lhs, const ElementAttr &rhs) {
              if (std::addressof(lhs) == std::addressof(rhs))
                return false;

              auto lhsCoords = llvm::map_to_vector(
                  lhs.first, [](IntegerAttr i) { return i.getInt(); });
              auto rhsCoords = llvm::map_to_vector(
                  rhs.first, [](IntegerAttr i) { return i.getInt(); });

              SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
              SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
              // Sort the elements based on the level coordinates.
              for (Level l = 0; l < order.getNumResults(); l++) {
                if (lhsLvlCrds[l] == rhsLvlCrds[l])
                  continue;
                return lhsLvlCrds[l] < rhsLvlCrds[l];
              }
              llvm_unreachable("no equal coordinate in sparse element attr");
            });

  SmallVector<Value> cvs;
  cvs.reserve(dimRank);
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    // Remap coordinates.
    cvs.clear();
    for (Dimension d = 0; d < dimRank; d++) {
      auto crd = elems[i].first[d].getInt();
      cvs.push_back(builder.create<arith::ConstantIndexOp>(loc, crd));
    }
    // Remap value.
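    // Complex elements are stored as a two-element ArrayAttr (real and
    // imaginary parts) and are materialized with complex.constant; all other
    // elements become a plain arith.constant.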
    Value val;
    if (isa<ComplexType>(attr.getElementType())) {
      auto valAttr = cast<ArrayAttr>(elems[i].second);
      val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
                                                valAttr);
    } else {
      auto valAttr = cast<TypedAttr>(elems[i].second);
      val = builder.create<arith::ConstantOp>(loc, valAttr);
    }
    assert(val);
    callback(cvs, val);
  }
}

SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
                                          size_t size, Value mem,
                                          size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
  assert(offsetIdx == 0 || offsetIdx < size);
#endif // NDEBUG
  SmallVector<Value> vs;
  vs.reserve(size);
  for (unsigned i = 0; i < size; i++) {
    Value v = builder.create<memref::LoadOp>(loc, mem,
                                             constantIndex(builder, loc, i));
    if (i == offsetIdx && offsetVal)
      v = builder.create<arith::AddIOp>(loc, v, offsetVal);
    vs.push_back(v);
  }
  return vs;
}

void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
                             ValueRange vs, size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const size_t vsize = vs.size();
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
  assert(offsetIdx == 0 || offsetIdx < vsize);
#endif // NDEBUG
  for (const auto &v : llvm::enumerate(vs)) {
    const Value w =
        (offsetIdx == v.index() && offsetVal)
            ? builder.create<arith::AddIOp>(loc, v.value(), offsetVal)
            : v.value();
    builder.create<memref::StoreOp>(loc, w, mem,
                                    constantIndex(builder, loc, v.index()));
  }
}

TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
  auto tTp = llvm::cast<TensorType>(tensor.getType());
  auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
  return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
      .getResult();
}

Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
                                   Value tensor) {
  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
}

Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
  if (offset.has_value())
    return constantIndex(builder, loc, *offset);
  return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
}

Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
  if (stride.has_value())
    return constantIndex(builder, loc, *stride);
  return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
}

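// Generates code that sets up a reader for the sparse tensor input: the
// static dimension shapes are handed to the runtime so it can verify them
// against the actual input, and for dynamically shaped tensors the concrete
// dimension sizes are queried back from the reader.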
Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
                               SparseTensorType stt, Value tensor,
                               /*out*/ SmallVectorImpl<Value> &dimSizesValues,
                               /*out*/ Value &dimSizesBuffer) {
  // Construct the dimension **shapes** buffer. The buffer contains the static
  // size per dimension, or otherwise a zero for a dynamic size.
  Dimension dimRank = stt.getDimRank();
  dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
  for (const Size sz : stt.getDimShape()) {
    const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
    dimSizesValues.push_back(constantIndex(builder, loc, s));
  }
  Value dimShapesBuffer = allocaBuffer(builder, loc, dimSizesValues);
  // Create the `CheckedSparseTensorReader`. This reader performs a
  // consistency check on the static sizes, but accepts any size
  // of each dimension with a dynamic size.
  Type opaqueTp = getOpaquePointerType(builder);
  Type eltTp = stt.getElementType();
  Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
  Value reader =
      createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
                     {tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
          .getResult(0);
  // For static shapes, the shape buffer can be used right away. For dynamic
  // shapes, use the information from the reader to construct a buffer that
  // supplies the actual size for each dynamic dimension.
  dimSizesBuffer = dimShapesBuffer;
  if (stt.hasDynamicDimShape()) {
    Type indexTp = builder.getIndexType();
    auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
    dimSizesBuffer =
        createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
                       reader, EmitCInterface::On)
            .getResult(0);
    // Also convert the dim shapes values into dim sizes values, just in case
    // subsequent clients need the values (DCE will remove them when unused).
    for (Dimension d = 0; d < dimRank; d++) {
      if (stt.isDynamicDim(d))
        dimSizesValues[d] = builder.create<memref::LoadOp>(
            loc, dimSizesBuffer, constantIndex(builder, loc, d));
    }
  }
  return reader;
}

Value sparse_tensor::genMapBuffers(
    OpBuilder &builder, Location loc, SparseTensorType stt,
    ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
    /*out*/ SmallVectorImpl<Value> &lvlSizesValues,
    /*out*/ Value &dim2lvlBuffer,
    /*out*/ Value &lvl2dimBuffer) {
  const Dimension dimRank = stt.getDimRank();
  const Level lvlRank = stt.getLvlRank();
  lvlSizesValues.clear();
  lvlSizesValues.reserve(lvlRank);
  // For an identity mapping, the dim2lvl and lvl2dim mappings are identical,
  // as are dimSizes and lvlSizes, so buffers are reused as much as possible.
  if (stt.isIdentity()) {
    assert(dimRank == lvlRank);
    SmallVector<Value> iotaValues;
    iotaValues.reserve(lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      iotaValues.push_back(constantIndex(builder, loc, l));
      lvlSizesValues.push_back(dimSizesValues[l]);
    }
    dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
    return dimSizesBuffer; // now lvlSizesBuffer
  }
  // Otherwise, some code needs to be generated to set up the buffers.
  // This code deals with permutations as well as non-permutations that
  // arise from rank-changing blocking.
  const auto dimToLvl = stt.getDimToLvl();
  const auto lvlToDim = stt.getLvlToDim();
  SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
  SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
  // Generate dim2lvl.
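  // For example, with dimToLvl (d0, d1) -> (d0 floordiv 2, d1 floordiv 3,
  // d0 mod 2, d1 mod 3), i.e. a 2x3 block format, the generated level sizes
  // are size(d0)/2, size(d1)/3, 2, and 3.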
  assert(lvlRank == dimToLvl.getNumResults());
  for (Level l = 0; l < lvlRank; l++) {
    AffineExpr exp = dimToLvl.getResult(l);
    // We expect:
    //   (1) l = d
    //   (2) l = d / c
    //   (3) l = d % c
    Dimension d = 0;
    uint64_t cf = 0, cm = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      d = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::FloorDiv: {
      auto floor = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(floor.getLHS()).getPosition();
      cf = cast<AffineConstantExpr>(floor.getRHS()).getValue();
      break;
    }
    case AffineExprKind::Mod: {
      auto mod = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(mod.getLHS()).getPosition();
      cm = cast<AffineConstantExpr>(mod.getRHS()).getValue();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
    }
    dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
    // Compute the level sizes.
    //   (1) l = d     : size(d)
    //   (2) l = d / c : size(d) / c
    //   (3) l = d % c : c
    Value lvlSz;
    if (cm == 0) {
      lvlSz = dimSizesValues[d];
      if (cf != 0)
        lvlSz = builder.create<arith::DivUIOp>(loc, lvlSz,
                                               constantIndex(builder, loc, cf));
    } else {
      lvlSz = constantIndex(builder, loc, cm);
    }
    lvlSizesValues.push_back(lvlSz);
  }
  // Generate lvl2dim.
  assert(dimRank == lvlToDim.getNumResults());
  for (Dimension d = 0; d < dimRank; d++) {
    AffineExpr exp = lvlToDim.getResult(d);
    // We expect:
    //   (1) d = l
    //   (2) d = l' * c + l
    Level l = 0, ll = 0;
    uint64_t c = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      l = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::Add: {
      // Always mul on lhs, dim on rhs.
      auto add = cast<AffineBinaryOpExpr>(exp);
      assert(add.getLHS().getKind() == AffineExprKind::Mul);
      auto mul = cast<AffineBinaryOpExpr>(add.getLHS());
      ll = cast<AffineDimExpr>(mul.getLHS()).getPosition();
      c = cast<AffineConstantExpr>(mul.getRHS()).getValue();
      l = cast<AffineDimExpr>(add.getRHS()).getPosition();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
    }
    lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
  }
  // Return buffers.
  dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
  lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
  return allocaBuffer(builder, loc, lvlSizesValues); // lvlSizesBuffer
}