//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP to help implement a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Demaps non-trivial inputs.
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns)
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);

    // CRTP call.
    OpAdaptor adaptor(deMappedIns, op);
    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
                                                          rewriter);
  }
};

// Collects all the AffineDimExprs used in an affine expression.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// Checks whether an affine expression is admissible for the sparsifier.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput){};

  // We only allow AffineDimExprs on outputs.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floordiv and ceildiv on inputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};

// The first BitVector stores the levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExprs that are used in the inadmissible expr.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}

// Builds the AffineMap that replaces the dimensions (idx) in `idxMap` with
// levels (lvl) such that all the inadmissible affine expressions can be
// eliminated.
// For example, we can rewrite
//   idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
//   idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
//   inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                           -> ((l0 * 2 + l2) floordiv 2,
//                               (l1 * 3 + l3) floordiv 3,
//                               (l0 * 2 + l2) mod 2,
//                               (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with a level, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdLvls.count()` dims
// for the replaced levels, and the remaining ones for the unchanged
// dimensions.
// For example, to handle
//   idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4),
// which is a typical map for block_2to4, the function returns:
//   inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which (l0, l1) together replace `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topological sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap is not equal to the dim2lvl map: it is computed by
  // composing dim2Lvl with idx2Dim. They are only equal when idx2Dim is an
  // identity map.
  // TODO: we might fail here; in those cases we should really return
  // failure instead of hitting an assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimensions at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExprs that are not used in inadmissible
  // level expressions. We use the first inAdLvls.count() dims to represent the
  // replaced levels; the remaining ones are reserved for the unchanged
  // dimensions. Note that the results of the inverse map computed previously
  // do not follow this convention, so we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l`
    // with the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level; it
    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume that a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it
      // needs to be inverted. Get the inverted expression from the inverse
      // map, and fix the mismatch captured by the loop above.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update the iterator types.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}

// Translates the index maps in the linalg::GenericOp from idx->dim maps to
// idx->lvl maps. Returns std::nullopt if the index maps cannot be translated
// into an admissible form.
// Returns the translated index map array and the iterator type array.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is an idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing the idx2dim map with the dim2lvl map, we get an
      // idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during
  // dim2lvl translation.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (!ShapedType::isDynamic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  unsigned boundedNum = 0;
  // A fixed-point algorithm.
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dims used in the AffineMap were introduced
        // to resolve previous inadmissible expressions. We cannot replace
        // them, as that might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation; should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx map into all the index maps to eliminate the
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  };

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}

// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                          val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}

static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}

/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
static bool hasNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single-output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index maps.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startRootUpdate(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeRootUpdate(linalgOp);

    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
    rewriter.startRootUpdate(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeRootUpdate(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};

struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};

struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse input/output with non-identity
    // dim2lvl maps.
    if (!hasNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startRootUpdate(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrd, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getResult());
          stt && !stt->isIdentity()) {
        Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeRootUpdate(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in the
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap>(patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}