//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/IterationGraphSorter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP helper to implement a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Demaps non-trivial inputs.
    bool changed = false;
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns) {
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
        changed = true;
      }
    }

    // CRTP call.
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
    return changed ? success() : status;
  }
};

// Collects the positions of all AffineDimExprs used in an affine expression.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// Checks whether an affine expression is admissible: only plain AffineDimExprs
// are allowed on the output, and mod/floordiv/ceildiv are disallowed anywhere.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput){};

  // We only allow AffineDimExpr on output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floor div and ceil div on inputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};

// The first BitVector stores levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExprs that are used in the inadmissible expr.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}

// Builds an AffineMap that replaces the indices in `idxMap` with levels such
// that all the inadmissible affine expressions can be eliminated.
// For example, we can rewrite
//   idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
//   idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
//   inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                           -> ((l0 * 2 + l2) floordiv 2,
//                               (l1 * 3 + l3) floordiv 3,
//                               (l0 * 2 + l2) mod 2,
//                               (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with levels, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `inAdLvls.count()` dimensions
// for the replaced levels, and the remaining ones for the unchanged dimensions.
// For example, to handle
//   idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// which is a typical map for block_2to4 sparsity, the function returns
//   inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which (l0, l1) together replace `d1`, yet they appear before `d0` in the
// resulting affine map. The index (loop) order can later be canonicalized by a
// topological sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap is not equal to the dim2Lvl map; it is computed by
  // composing dim2Lvl(idx2Dim). They are only equal when idx2Dim is an
  // identity map.
  // TODO: we might fail here; in that case we should really return
  // failure instead of an assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimension at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExprs that are not used in inadmissible
  // level expressions. We use the first inAdLvls.count() dims to represent the
  // replaced levels; the remaining ones are reserved for unchanged dimensions.
  // Note that the results from the inverse map computed previously do not
  // follow this convention, and we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l` with
    // the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level; it
    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it needs
      // to be inverted. Get the inversion from the inverse map, and fix the
      // mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update iterator types.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}

// Translates the index maps in the linalg::GenericOp from idx->dim maps to
// idx->lvl maps. Returns std::nullopt if the index maps cannot be translated
// to an admissible form; otherwise returns the translated index map array and
// the iterator type array.
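//
// As an illustrative sketch (the exact maps depend on the encoding): for an
// operand with a 2x4 block dim2lvl map such as
//   dim2lvl = (d0, d1) -> (d0 floordiv 2, d1 floordiv 4, d0 mod 2, d1 mod 4)
// and an identity idx2dim map, the composed idx2lvl map initially contains
// floordiv/mod expressions; the fixed-point loop below then expands the loop
// space so that every level is eventually addressed by a plain AffineDimExpr.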
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is an idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // Composing dim2Lvl(idx2dim) yields an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during
  // dim2lvl translation.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (!ShapedType::isDynamic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  unsigned boundedNum = 0;
  // A fixed-point algorithm.
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dimensions in the AffineMap were introduced
        // to resolve previous inadmissible expressions. We cannot replace
        // them, as that might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx map into all index maps to eliminate
        // inadmissible expressions.
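        // (Illustrative note: for a static level size, e.g., lvlSz == 2, the
        // cstMapping built above folds `l floordiv 2` to 0 and `l mod 2` to
        // `l` for the newly introduced in-block dimension `l`, since that
        // dimension only ranges over [0, lvlSz).)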
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  }

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}

// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                          val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}

static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index maps.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel cannot be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
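    // (The adaptor operands were wrapped in demapping reinterpret_map ops by
    // DemapInsRewriter, so from here on the op computes directly on the
    // level-space tensors; the result is remapped back right after the op.)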
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};

struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all masks that we use for iteration graph
    // computation. It must be ordered from the strictest to the least strict.
    // Ideally (though it is not guaranteed), the earlier a constraint mask
    // can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // Else try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel cannot be scheduled: cycle detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel cannot be scheduled.");
    }

    // Mark the GenericOp to avoid recursive matching.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return success();

    assert(order.isPermutation());
    // `order` is the original loop -> sorted loop map.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Invert `order` to get the sorted loop -> original loop map.
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeOpModification(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible for sparsification.
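  /// Roughly: for a sparse output, every output level except possibly the
  /// innermost one must be covered by a non-reduction loop prefix, so that
  /// insertions happen in order (see the `nest >= rank - 1` check below).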
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      if (linalg::isReductionIterator(iteratorTypes[loopId]))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  }

  // Last resort cycle resolution.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute a topological sort while leaving out every sparse input tensor
    // in succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expressions are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling the loops without constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order);
      // Sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Find a permutation such that the results in `idxMap` are sorted.
      // For example, for
      //   (d0, d1, d2, d3) -> (d2, d1, d0)
      // the loops are scheduled in the order d0->d1->d2->d3. To resolve the
      // cycle, we find a permutation such that perm(d2, d1, d0) = (d0, d1, d2),
      // i.e., the transposed tensor's levels are visited in the same order as
      // the loop scheduling order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      std::sort(lvlSeq.begin(), lvlSeq.end(), [](auto &lhs, auto &rhs) -> bool {
        return lhs.first < rhs.first;
      });
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The result of the idxMap must be unsorted.
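      // (For instance, idxMap results (d2, d1, d0) give lvlSeq
      // [(2,0), (1,1), (0,2)]; sorted by level this becomes
      // [(0,2), (1,1), (2,0)], hence perm = [2, 1, 0].)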
      assert(!dimToLvl.isIdentity());

      // Inserting the transpose
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });
      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }
};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
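    // Retype the allocation to its demapped (level-space) form with the
    // level-space dynamic sizes computed above, and reinterpret the result
    // back to the original dim-space type for all existing users.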
    rewriter.startOpModification(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeOpModification(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};

struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};

struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AssembleOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseResult(op));
    auto stt = getSparseTensorType(op.getResult());
    rewriter.modifyOpInPlace(
        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
    rewriter.setInsertionPointAfter(op);
    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
    return success();
  }
};

struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseOperandOrResult(op));
    rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
      op.getTensorMutable().assign(adaptor.getTensor());
    });
    return success();
  }
};

struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with a sparse input/output and a non-identity
    // dim2lvl map.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constants as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startOpModification(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
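    // (Only results backed by a non-identity dim2lvl map are retyped to their
    // demapped form; they are remapped back after the op, see below.)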
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrds, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrds, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getResult());
          stt && !stt->isIdentity()) {
        Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in the
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
        patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
                 SparseDisassembleDemapper, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}
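
// Illustrative usage sketch (assumption: these patterns are normally driven
// by the sparse_tensor reinterpret-map pass via a greedy rewrite):
//   RewritePatternSet patterns(ctx);
//   populateSparseReinterpretMap(patterns, ReinterpretMapScope::kAll);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));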