//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "SparseTensorDescriptor.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// ExecutionEngine/SparseTensorUtils helper functions.
//===----------------------------------------------------------------------===//

OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
  switch (width) {
  case 64:
    return OverheadType::kU64;
  case 32:
    return OverheadType::kU32;
  case 16:
    return OverheadType::kU16;
  case 8:
    return OverheadType::kU8;
  case 0:
    return OverheadType::kIndex;
  }
  llvm_unreachable("Unsupported overhead bitwidth");
}

OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
  if (tp.isIndex())
    return OverheadType::kIndex;
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return overheadTypeEncoding(intTp.getWidth());
  llvm_unreachable("Unknown overhead type");
}

Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return builder.getIndexType();
  case OverheadType::kU64:
    return builder.getIntegerType(64);
  case OverheadType::kU32:
    return builder.getIntegerType(32);
  case OverheadType::kU16:
    return builder.getIntegerType(16);
  case OverheadType::kU8:
    return builder.getIntegerType(8);
  }
  llvm_unreachable("Unknown OverheadType");
}

OverheadType
mlir::sparse_tensor::posTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getPosWidth());
}

OverheadType
mlir::sparse_tensor::crdTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getCrdWidth());
}

// TODO: we ought to add some `static_assert` tests to ensure that the
// `STEA::get{Pos,Crd}Type` methods agree with `getOverheadType(builder,
// {pos,crd}OverheadTypeEncoding(enc))`.

// TODO: Adjust the naming convention for the constructors of
// `OverheadType` so we can use the `MLIR_SPARSETENSOR_FOREVERY_O` x-macro
// here instead of `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`; to further reduce
// the possibility of typo bugs or things getting out of sync.
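// The suffix returned below is appended to the name of a runtime support
// function to select the bitwidth-specific overload; e.g., kU32 yields "32"
// and kIndex yields "0" (illustrative: a hypothetical base name "foo" would
// thus resolve to "foo32" or "foo0", respectively).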
StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return "0";
#define CASE(ONAME, O)                                                         \
  case OverheadType::kU##ONAME:                                                \
    return #ONAME;
    MLIR_SPARSETENSOR_FOREVERY_FIXED_O(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown OverheadType");
}

StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
  return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
}

PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
  if (elemTp.isF64())
    return PrimaryType::kF64;
  if (elemTp.isF32())
    return PrimaryType::kF32;
  if (elemTp.isF16())
    return PrimaryType::kF16;
  if (elemTp.isBF16())
    return PrimaryType::kBF16;
  if (elemTp.isInteger(64))
    return PrimaryType::kI64;
  if (elemTp.isInteger(32))
    return PrimaryType::kI32;
  if (elemTp.isInteger(16))
    return PrimaryType::kI16;
  if (elemTp.isInteger(8))
    return PrimaryType::kI8;
  if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
    auto complexEltTp = complexTp.getElementType();
    if (complexEltTp.isF64())
      return PrimaryType::kC64;
    if (complexEltTp.isF32())
      return PrimaryType::kC32;
  }
  llvm_unreachable("Unknown primary type");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
  switch (pt) {
#define CASE(VNAME, V)                                                         \
  case PrimaryType::k##VNAME:                                                  \
    return #VNAME;
    MLIR_SPARSETENSOR_FOREVERY_V(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown PrimaryType");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
  return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
}

//===----------------------------------------------------------------------===//
// Misc code generators.
//===----------------------------------------------------------------------===//

Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
                             Type dstTp) {
  const Type srcTp = value.getType();
  if (srcTp == dstTp)
    return value;

  // int <=> index
  if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
    return builder.create<arith::IndexCastOp>(loc, dstTp, value);

  const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
  return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}

Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
                                       Value elem, Type dstTp) {
  if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
    // Scalars can only be converted to 0-ranked tensors.
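    // E.g., an i32 element destined for a 0-d tensor<f64> is first cast to
    // f64 and then wrapped in the tensor.from_elements op created below.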
    assert(rtp.getRank() == 0);
    elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
  }
  return sparse_tensor::genCast(builder, loc, elem, dstTp);
}

Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
                                  ValueRange s) {
  Value load = builder.create<memref::LoadOp>(loc, mem, s);
  if (!isa<IndexType>(load.getType())) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
    load =
        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
  }
  return load;
}

mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
  if (isa<FloatType>(tp))
    return builder.getFloatAttr(tp, 1.0);
  if (isa<IndexType>(tp))
    return builder.getIndexAttr(1);
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
  if (isa<RankedTensorType, VectorType>(tp)) {
    auto shapedTp = cast<ShapedType>(tp);
    if (auto one = getOneAttr(builder, shapedTp.getElementType()))
      return DenseElementsAttr::get(shapedTp, one);
  }
  llvm_unreachable("Unsupported attribute type");
}

Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
                                        Value v) {
  Type tp = v.getType();
  Value zero = constantZero(builder, loc, tp);
  if (isa<FloatType>(tp))
    return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                         zero);
  if (tp.isIntOrIndex())
    return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                         zero);
  if (dyn_cast<ComplexType>(tp))
    return builder.create<complex::NotEqualOp>(loc, v, zero);
  llvm_unreachable("Non-numeric type");
}

void mlir::sparse_tensor::genReshapeDstShape(
    OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
    ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
    ArrayRef<ReassociationIndices> reassociation) {
  // Collapse shape.
  if (reassociation.size() < srcShape.size()) {
    unsigned start = 0;
    for (const auto &map : llvm::enumerate(reassociation)) {
      auto dstDim = constantIndex(builder, loc, 1);
      for (unsigned i = start; i < start + map.value().size(); i++) {
        dstDim = builder.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
      }
      dstShape.push_back(dstDim);
      start = start + map.value().size();
    }
    assert(start == srcShape.size());
    return;
  }

  // Expand shape.
  assert(reassociation.size() == srcShape.size());
  unsigned start = 0;
  // Expand the i-th dimension in srcShape.
  for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
    const auto &map = reassociation[i];
    auto srcDim = srcShape[i];
    // Iterate through dimensions expanded from the i-th dimension.
    for (unsigned j = start; j < start + map.size(); j++) {
      // There can be only one dynamically sized dimension among the dimensions
      // expanded from the i-th dimension in srcShape.
      // For example, if srcDim = 8, then the expanded shape could be <2x?x2>,
      // but not <2x?x?>.
      if (staticDstShape[j] == ShapedType::kDynamic) {
        // The expanded dimension has dynamic size. We compute the dimension
        // by dividing srcDim by the product of the static dimensions.
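        // E.g., for the <2x?x2> example above with srcDim = 8, the static
        // product is 2 * 2 = 4 and the dynamic size evaluates to 8 / 4 = 2.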
        Size product = 1;
        for (unsigned k = start; k < start + map.size(); k++) {
          if (staticDstShape[k] != ShapedType::kDynamic) {
            product *= staticDstShape[k];
          }
        }
        // Compute the dynamic dimension size.
        Value productVal = constantIndex(builder, loc, product);
        Value dynamicSize =
            builder.create<arith::DivUIOp>(loc, srcDim, productVal);
        dstShape.push_back(dynamicSize);
      } else {
        // The expanded dimension is statically known.
        dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
      }
    }
    start = start + map.size();
  }
  assert(start == staticDstShape.size());
}

void mlir::sparse_tensor::reshapeCvs(
    OpBuilder &builder, Location loc,
    ArrayRef<ReassociationIndices> reassociation, // NOLINT
    ValueRange srcSizes, ValueRange srcCvs,       // NOLINT
    ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs) {
  const unsigned srcRank = srcSizes.size();
  const unsigned dstRank = dstSizes.size();
  assert(srcRank == srcCvs.size() && "Source rank mismatch");
  const bool isCollapse = srcRank > dstRank;
  const ValueRange sizes = isCollapse ? srcSizes : dstSizes;
  // Iterate over reassociation map.
  unsigned i = 0;
  unsigned start = 0;
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in dimension slice.
    Value linear = constantIndex(builder, loc, 1);
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::MulIOp>(loc, linear, sizes[j]);
    }
    // Start expansion.
    Value val;
    if (!isCollapse)
      val = srcCvs[i];
    // Iterate over dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::DivUIOp>(loc, linear, sizes[j]);
      if (isCollapse) {
        const Value mul = builder.create<arith::MulIOp>(loc, srcCvs[j], linear);
        val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
      } else {
        const Value old = val;
        val = builder.create<arith::DivUIOp>(loc, val, linear);
        assert(dstCvs.size() == j);
        dstCvs.push_back(val);
        val = builder.create<arith::RemUIOp>(loc, old, linear);
      }
    }
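    // At this point, a collapse has folded the coordinates of the group into
    // one linearized value (e.g., collapsing two dims of sizes (m, n) yields
    // cv0 * n + cv1), while an expansion has emitted the inverse div/rem
    // chain that recovers the per-dimension coordinates.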
    // Finalize collapse.
    if (isCollapse) {
      assert(dstCvs.size() == i);
      dstCvs.push_back(val);
    }
    start += map.value().size();
    i++;
  }
  assert(dstCvs.size() == dstRank);
}

FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
                                               TypeRange resultType,
                                               ValueRange operands,
                                               EmitCInterface emitCInterface) {
  MLIRContext *context = module.getContext();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        module.getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}

func::CallOp mlir::sparse_tensor::createFuncCall(
    OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
    ValueRange operands, EmitCInterface emitCInterface) {
  auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  FlatSymbolRefAttr fn =
      getFunc(module, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(loc, resultType, fn, operands);
}

Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) {
  return LLVM::LLVMPointerType::get(ctx);
}

Type mlir::sparse_tensor::getOpaquePointerType(Builder &builder) {
  return getOpaquePointerType(builder.getContext());
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc,
                                     unsigned sz, Type tp, bool staticShape) {
  if (staticShape) {
    auto memTp = MemRefType::get({sz}, tp);
    return builder.create<memref::AllocaOp>(loc, memTp);
  }
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz,
                                     Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc,
                                           Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
                                        ValueRange values) {
  const unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
                                            RankedTensorType tensorTp,
                                            ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamic)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}
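// Illustrative usage sketch (hypothetical values, not upstream code): for a
// 2-D tensor type with one dynamic dimension, a zero-initialized staging
// buffer can be obtained and later released as
//   Value buf = allocDenseTensor(builder, loc, rtp, sizes);
//   ...
//   deallocDenseTensor(builder, loc, buf);
// where `sizes` supplies one index Value per dimension of `rtp`.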
void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc,
                                             Value buffer) {
  builder.create<memref::DeallocOp>(loc, buffer);
}

void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
                                       SmallVectorImpl<Value> &sizes,
                                       Location loc, Value src) {
  const Dimension dimRank = getSparseTensorType(src).getDimRank();
  for (Dimension d = 0; d < dimRank; d++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d));
}

Operation *mlir::sparse_tensor::getTop(Operation *op) {
  for (; isa<scf::ForOp>(op->getParentOp()) ||
         isa<scf::WhileOp>(op->getParentOp()) ||
         isa<scf::ParallelOp>(op->getParentOp()) ||
         isa<scf::IfOp>(op->getParentOp());
       op = op->getParentOp())
    ;
  return op;
}

void sparse_tensor::foreachInSparseConstant(
    OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
    function_ref<void(ArrayRef<Value>, Value)> callback) {
  if (!order)
    order = builder.getMultiDimIdentityMap(attr.getType().getRank());

  auto stt = SparseTensorType(getRankedTensorType(attr));
  const Dimension dimRank = stt.getDimRank();
  const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
  const auto values = attr.getValues().getValues<Attribute>();

  // This is like the `Element<V>` class in the runtime library, but for
  // MLIR attributes. In the future we may want to move this out into
  // a proper class definition to help improve code legibility (e.g.,
  // `first` -> `coords`, `second` -> `value`) as well as being able
  // to factor out analogues of `ElementLT<V>` for the sort below, etc.
  using ElementAttr = std::pair<SmallVector<IntegerAttr>, Attribute>;

  // Construct the COO from the SparseElementsAttr.
  SmallVector<ElementAttr> elems;
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    elems.emplace_back();
    elems.back().second = values[i];
    auto &coords = elems.back().first;
    coords.reserve(dimRank);
    for (Dimension d = 0; d < dimRank; d++)
      coords.push_back(coordinates[i * dimRank + d]);
  }

  // Sort the sparse element attributes based on coordinates.
  std::sort(elems.begin(), elems.end(),
            [order](const ElementAttr &lhs, const ElementAttr &rhs) {
              if (std::addressof(lhs) == std::addressof(rhs))
                return false;

              auto lhsCoords = llvm::map_to_vector(
                  lhs.first, [](IntegerAttr i) { return i.getInt(); });
              auto rhsCoords = llvm::map_to_vector(
                  rhs.first, [](IntegerAttr i) { return i.getInt(); });

              SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
              SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
              // Sort the elements based on the level coordinates.
              for (Level l = 0; l < order.getNumResults(); l++) {
                if (lhsLvlCrds[l] == rhsLvlCrds[l])
                  continue;
                return lhsLvlCrds[l] < rhsLvlCrds[l];
              }
              llvm_unreachable("no equal coordinate in sparse element attr");
            });

  SmallVector<Value> cvs;
  cvs.reserve(dimRank);
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    // Remap coordinates.
    cvs.clear();
    for (Dimension d = 0; d < dimRank; d++) {
      auto crd = elems[i].first[d].getInt();
      cvs.push_back(builder.create<arith::ConstantIndexOp>(loc, crd));
    }
    // Remap value.
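    // A complex element is stored as an ArrayAttr holding its real and
    // imaginary parts, whereas every other element type carries a TypedAttr
    // that can be materialized with an arith::ConstantOp directly (see below).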
    Value val;
    if (isa<ComplexType>(attr.getElementType())) {
      auto valAttr = cast<ArrayAttr>(elems[i].second);
      val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
                                                valAttr);
    } else {
      auto valAttr = cast<TypedAttr>(elems[i].second);
      val = builder.create<arith::ConstantOp>(loc, valAttr);
    }
    assert(val);
    callback(cvs, val);
  }
}

SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
                                          size_t size, Value mem,
                                          size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
  assert(offsetIdx == 0 || offsetIdx < size);
#endif // NDEBUG
  SmallVector<Value> vs;
  vs.reserve(size);
  for (unsigned i = 0; i < size; i++) {
    Value v = builder.create<memref::LoadOp>(loc, mem,
                                             constantIndex(builder, loc, i));
    if (i == offsetIdx && offsetVal)
      v = builder.create<arith::AddIOp>(loc, v, offsetVal);
    vs.push_back(v);
  }
  return vs;
}

void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
                             ValueRange vs, size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const size_t vsize = vs.size();
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
  assert(offsetIdx == 0 || offsetIdx < vsize);
#endif // NDEBUG
  for (const auto &v : llvm::enumerate(vs)) {
    const Value w =
        (offsetIdx == v.index() && offsetVal)
            ? builder.create<arith::AddIOp>(loc, v.value(), offsetVal)
            : v.value();
    builder.create<memref::StoreOp>(loc, w, mem,
                                    constantIndex(builder, loc, v.index()));
  }
}

TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
  auto tTp = llvm::cast<TensorType>(tensor.getType());
  auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
  return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
      .getResult();
}

Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
  if (offset.has_value())
    return constantIndex(builder, loc, *offset);
  return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
}

Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
  if (stride.has_value())
    return constantIndex(builder, loc, *stride);
  return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
}

Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
                               SparseTensorType stt, Value tensor,
                               /*out*/ SmallVectorImpl<Value> &dimSizesValues,
                               /*out*/ Value &dimSizesBuffer) {
  // Construct the dimension **shapes** buffer. The buffer contains the static
  // size per dimension, or otherwise a zero for a dynamic size.
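  // E.g., a source of type tensor<8x?xf64> yields the shapes buffer [8, 0].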
  Dimension dimRank = stt.getDimRank();
  dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
  for (const Size sz : stt.getDimShape()) {
    const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
    dimSizesValues.push_back(constantIndex(builder, loc, s));
  }
  Value dimShapesBuffer = allocaBuffer(builder, loc, dimSizesValues);
  // Create the `CheckedSparseTensorReader`. This reader performs a
  // consistency check on the static sizes, but accepts any size for a
  // dimension with a dynamic size.
  Type opaqueTp = getOpaquePointerType(builder);
  Type eltTp = stt.getElementType();
  Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
  Value reader =
      createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
                     {tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
          .getResult(0);
  // For static shapes, the shape buffer can be used right away. For dynamic
  // shapes, use the information from the reader to construct a buffer that
  // supplies the actual size for each dynamic dimension.
  dimSizesBuffer = dimShapesBuffer;
  if (stt.hasDynamicDimShape()) {
    Type indexTp = builder.getIndexType();
    auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
    dimSizesBuffer =
        createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
                       reader, EmitCInterface::On)
            .getResult(0);
    // Also convert the dim-shape values into dim-size values, just in case
    // subsequent clients need the values (DCE will remove them if unused).
    for (Dimension d = 0; d < dimRank; d++) {
      if (stt.isDynamicDim(d))
        dimSizesValues[d] = builder.create<memref::LoadOp>(
            loc, dimSizesBuffer, constantIndex(builder, loc, d));
    }
  }
  return reader;
}

Value sparse_tensor::genMapBuffers(
    OpBuilder &builder, Location loc, SparseTensorType stt,
    ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
    /*out*/ SmallVectorImpl<Value> &lvlSizesValues,
    /*out*/ Value &dim2lvlBuffer,
    /*out*/ Value &lvl2dimBuffer) {
  const Dimension dimRank = stt.getDimRank();
  const Level lvlRank = stt.getLvlRank();
  lvlSizesValues.clear();
  lvlSizesValues.reserve(lvlRank);
  // For an identity mapping, the dim2lvl and lvl2dim mappings are
  // identical, as are dimSizes and lvlSizes, so buffers are reused
  // as much as possible.
  if (stt.isIdentity()) {
    assert(dimRank == lvlRank);
    SmallVector<Value> iotaValues;
    iotaValues.reserve(lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      iotaValues.push_back(constantIndex(builder, loc, l));
      lvlSizesValues.push_back(dimSizesValues[l]);
    }
    dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
    return dimSizesBuffer; // now lvlSizesBuffer
  }
  // Otherwise, some code needs to be generated to set up the buffers.
  // This code deals with permutations as well as non-permutations that
  // arise from rank-changing blocking.
  const auto dimToLvl = stt.getDimToLvl();
  const auto lvlToDim = stt.getLvlToDim();
  SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
  SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
  // Generate dim2lvl.
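  // E.g. (illustrative), a 2x2 block-sparse format may use
  //   dimToLvl = (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
  // so that every level expression matches one of the cases (1)-(3) below,
  // and the level sizes become size(d0)/2, size(d1)/2, 2, and 2.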
  assert(lvlRank == dimToLvl.getNumResults());
  for (Level l = 0; l < lvlRank; l++) {
    AffineExpr exp = dimToLvl.getResult(l);
    // We expect:
    // (1) l = d
    // (2) l = d / c
    // (3) l = d % c
    Dimension d = 0;
    uint64_t cf = 0, cm = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      d = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::FloorDiv: {
      auto floor = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(floor.getLHS()).getPosition();
      cf = cast<AffineConstantExpr>(floor.getRHS()).getValue();
      break;
    }
    case AffineExprKind::Mod: {
      auto mod = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(mod.getLHS()).getPosition();
      cm = cast<AffineConstantExpr>(mod.getRHS()).getValue();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
    }
    dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
    // Compute the level sizes.
    // (1) l = d : size(d)
    // (2) l = d / c : size(d) / c
    // (3) l = d % c : c
    Value lvlSz;
    if (cm == 0) {
      lvlSz = dimSizesValues[d];
      if (cf != 0)
        lvlSz = builder.create<arith::DivUIOp>(loc, lvlSz,
                                               constantIndex(builder, loc, cf));
    } else {
      lvlSz = constantIndex(builder, loc, cm);
    }
    lvlSizesValues.push_back(lvlSz);
  }
  // Generate lvl2dim.
  assert(dimRank == lvlToDim.getNumResults());
  for (Dimension d = 0; d < dimRank; d++) {
    AffineExpr exp = lvlToDim.getResult(d);
    // We expect:
    // (1) d = l
    // (2) d = l' * c + l
    Level l = 0, ll = 0;
    uint64_t c = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      l = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::Add: {
      // The add always has the mul (l' * c) on the lhs and the plain level
      // variable l on the rhs.
      auto add = cast<AffineBinaryOpExpr>(exp);
      assert(add.getLHS().getKind() == AffineExprKind::Mul);
      auto mul = cast<AffineBinaryOpExpr>(add.getLHS());
      ll = cast<AffineDimExpr>(mul.getLHS()).getPosition();
      c = cast<AffineConstantExpr>(mul.getRHS()).getValue();
      l = cast<AffineDimExpr>(add.getRHS()).getPosition();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
    }
    lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
  }
  // Return buffers.
  dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
  lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
  return allocaBuffer(builder, loc, lvlSizesValues); // lvlSizesBuffer
}