//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "IterationGraphSorter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP to help implementing a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Demaps non-trivial inputs.
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns)
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);

    // CRTP call.
    OpAdaptor adaptor(deMappedIns, op);
    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
                                                          rewriter);
  }
};

// Collects the AffineDimExprs used by an affine expression into a BitVector.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {}
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// Checks whether an affine expression is admissible for a sparse level.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput) {}

  // We only allow AffineDimExpr on output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floordiv and ceildiv on inputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};

// The first BitVector stores the levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
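// For example (illustrative), for an input map
//   (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4),
// levels 1 and 2 use inadmissible floordiv/mod expressions, so the first
// BitVector is {1, 2} and the second is {d1}.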
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExprs that are used in the inadmissible expr.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}

// Builds the AffineMap to replace the dimensions in idxMap with levels such
// that all the inadmissible affine expressions can be eliminated.
// For example, we can rewrite
//   idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
//   idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
//   inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                           -> ((l0 * 2 + l2) floordiv 2,
//                               (l1 * 3 + l3) floordiv 3,
//                               (l0 * 2 + l2) mod 2,
//                               (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with a level, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdLvls.count()` dims
// for the replaced levels, and the remaining ones for the unchanged
// dimensions.
// For example, to handle
//   idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4),
// which is a typical map for block_2to4, the function returns
//   inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1),
// in which (l0, l1) together replace `d1`, yet they appear before `d0` in the
// resulting affine map.
// The index (loop) order can later be canonicalized by a topological sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap does not equal the dim2Lvl map; it is computed by
  // composing idx2Dim and dim2Lvl. They are only equal when idx2Dim is an
  // identity map.
  // TODO: we might fail here; in those cases we should really return
  // failure instead of an assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //       ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimension at the end.
    //       ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExprs that are not used in inadmissible
  // level expressions. We use the first inAdLvls.count() dims to represent the
  // replaced levels; the remaining ones are reserved for the unchanged
  // dimensions.
  // Note that the results of the inverse map computed previously do not follow
  // this convention, and we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l` with
    // the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level; it
    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it needs
      // to be inverted. Get the inversion from the inverse map, and fix the
      // mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update the iterator types.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}

// Translates the index map in the linalg::GenericOp from an idx->dim map to an
// idx->lvl map. Returns std::nullopt if the index map cannot be translated to
// an admissible form.
// Returns the translated index map array and the iterator type array.
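// For example (illustrative), for an operand whose dim2lvl map is the 2x3
// block map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// and whose original index map is the identity, the translated index map is
//   (l0, l1, l2, l3) -> (l0, l1, l2, l3),
// and each of the two original iterator types is duplicated for the two new
// level variables it gives rise to.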
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is an idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing idx2dim and dim2lvl, we get an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during
  // dim2lvl translation.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (!ShapedType::isDynamic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  unsigned boundedNum = 0;
  // A fixed-point algorithm.
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dims used in the AffineMap are introduced to
        // resolve previous inadmissible expressions. We cannot replace them,
        // as that might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx map with all the index maps to eliminate the
        // inadmissible expressions.
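        // E.g. (illustrative, assuming static level sizes), for the
        // block_2to4 map mentioned earlier, after composition, simplification,
        // and the constant folding below, the sparse operand's map
        // (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4) roughly becomes the
        // permutation (l0, l1, l2) -> (l2, l0, l1), while a dense operand's
        // identity map becomes (l0, l1, l2) -> (l2, l0 * 4 + l1).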
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  }

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}

// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                          val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}

static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single-output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel cannot be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startRootUpdate(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use the demapped arguments.
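    // (The adaptor operands were already wrapped in ReinterpretMapOps by the
    // DemapInsRewriter base class, so they carry the demapped types.)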
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeRootUpdate(linalgOp);

    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};

struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all the masks that we use for iteration graph
    // computation, ordered from the most strict to the least strict.
    // Ideally (though it might not be guaranteed), the earlier a constraint
    // mask can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {
        SortMask::kIncludeAll,        SortMask::kIncludeDense,
        SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
        SortMask::kIncludeUndef,      SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // Else try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel cannot be scheduled: cycle detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel cannot be scheduled.");
    }

    // Marks the GenericOp to avoid recursive matching.
    linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));

    // Already sorted.
    if (order.isIdentity())
      return failure();

    assert(order.isPermutation());
    // `order` is the original loop -> sorted loop map.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Invert `order` to get the sorted loop -> original loop map.
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startRootUpdate(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeRootUpdate(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible by sparsification.
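  /// E.g. (illustrative), for a rank-2 sparse output, nest >= 1 is required:
  /// the outermost sorted loop must not be a reduction, while reductions may
  /// appear anywhere after it.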
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      if (linalg::isReductionIterator(iteratorTypes[loopId]))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  }

  // Last resort cycle resolution.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute a topological sort while leaving out every sparse input tensor
    // in succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expressions are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling the loops without the constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order);
      // sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Find a permutation such that the results in `idxMap` are sorted.
      // For example, given results (d2, d1, d0) with loops scheduled in the
      // order d0->d1->d2->d3, to resolve the cycle we find a permutation perm
      // such that perm(d2, d1, d0) = (d0, d1, d2), so that the transposed
      // tensor's levels are visited in the same order as the loop scheduling
      // order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      std::sort(lvlSeq.begin(), lvlSeq.end(), [](auto &lhs, auto &rhs) -> bool {
        return lhs.first < rhs.first;
      });
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The results of the idxMap must be unsorted.
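      // E.g. (illustrative), for idxMap results (d2, d1, d0), lvlSeq is
      // [(2,0), (1,1), (0,2)]; sorting by the level gives perm = (2, 1, 0),
      // i.e., a transposition of the tensor's levels.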
      assert(!dimToLvl.isIdentity());

      // Inserting the transpose.
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      rewriter.updateRootInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });
      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }
};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
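    // Demap the allocation in place: the op now allocates the level-space
    // (demapped) type directly, using the dynamic level sizes derived from the
    // maximal coordinates above, and a remapping cast restores the original
    // type for all other users.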
    rewriter.startRootUpdate(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeRootUpdate(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};

struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};

struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse input/output with non-identity
    // dim2lvl maps.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startRootUpdate(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
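    // The original block arguments [dimCrds, val, initArgs] are replaced by
    // [lvlCrds, val, demapped initArgs]; level coordinates are translated back
    // to dimension coordinates inside the body so that existing uses keep
    // their semantics.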
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrd, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.

    // Update the yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getResult());
          stt && !stt->isIdentity()) {
        Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeRootUpdate(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in the
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
        patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}