//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Translates a "simple" map according to an identity lvl-map.
static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
                              AffineMap map) {
  unsigned lvlRank = stt.getLvlRank();
  AffineMap lvl2dim = stt.getLvlToDim();
  assert(lvl2dim.getNumInputs() == lvlRank);
  SmallVector<AffineExpr> exps;
  for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
    unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
    exps.push_back(lvl2dim.getResult(pos));
  }
  return AffineMap::get(lvlRank, 0, exps, builder.getContext());
}

// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                          val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}

// Reinterprets each value in `outs` to its corresponding type in `types`,
// inserting a ReinterpretMapOp only where the types differ.
static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}

/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
static bool hasNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

// Generates a clone of the given linalg generic operation, but with
// remapped arguments, index maps, and iteration types.
//
// TODO: As described below, this is proof-of-concept code which makes a lot
// of simplifying assumptions for now.
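//
// For example (a sketch only; the concrete maps depend on the encoding):
// given a 4-level output whose lvl2dim map is
// (i, j, k, l) -> (i * 2 + k, j * 3 + l), a "simple" input map
// (d0, d1) -> (d0, d1) is translated to the 4-input map
// (i, j, k, l) -> (i * 2 + k, j * 3 + l), the output map becomes the
// 4-dim identity, and two extra "parallel" iterator types are prepended.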
//
static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
                                          linalg::GenericOp linalgOp,
                                          SparseTensorType stt, Value out) {
  unsigned dimRank = stt.getDimRank();
  unsigned lvlRank = stt.getLvlRank();
  SmallVector<Value> inputOps = linalgOp.getInputs();
  SmallVector<Value> outputOps = {out};
  SmallVector<AffineMap> indexMaps;
  SmallVector<utils::IteratorType> iterTypes;
  // Translate the index maps, except output map, which is lvl-identity.
  auto maps = linalgOp.getIndexingMapsArray();
  for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
    indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
  indexMaps.push_back(
      AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
  // Add additional "parallel" iteration types at the top.
  for (unsigned i = 0, diff = lvlRank - dimRank; i < diff; i++)
    iterTypes.push_back(utils::IteratorType::parallel);
  for (auto &i : linalgOp.getIteratorTypesArray())
    iterTypes.push_back(i);
  // Generate the new linalg generic operation and clone body.
  auto newOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
      iterTypes);
  rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
                             newOp.getRegion().begin());
  return newOp;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpReinterpretMap(MLIRContext *context)
      : OpRewritePattern<linalg::GenericOp>(context) {}

  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite single output operations with pure tensor semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics())
      return failure();
    // Scan all operands, inspect sparse tensors.
    //
    // TODO: generalize this proof-of-concept algorithm, since the current
    //       implementation accepts only simple indexing maps, and one
    //       non-permutation sparse tensor, which must have an identity
    //       indexing map and be the output.
    //
    OpOperand *tx = nullptr;
    for (OpOperand &t : linalgOp->getOpOperands()) {
      // Ensure every index map is "simple".
      const auto map = linalgOp.getMatchingIndexingMap(&t);
      for (unsigned i = 0, n = map.getNumResults(); i < n; i++)
        if (map.getResult(i).getKind() != AffineExprKind::DimId)
          return failure();
      // Inspect sparse operands.
      auto stt = tryGetSparseTensorType(t.get());
      if (stt && stt->hasEncoding()) {
        if (stt->isPermutation())
          continue;
        assert(stt->getDimRank() < stt->getLvlRank()); // only allowed non-perm
        if (tx)
          return failure(); // more than one non-perm
        if (!map.isIdentity())
          return failure(); // no ID indexing map on the non-perm
        tx = &t;
      }
    }
    // Found a non-permutation, rewrite when this is the output.
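    // E.g. (a sketch only; actual types depend on the encoding):
    //   %0 = linalg.generic ... outs(%t) ...
    // is rewritten, over the level space, into
    //   %d = sparse_tensor.reinterpret_map %t    // demap to level space
    //   %1 = linalg.generic ... outs(%d) ...     // lvl-identity output map
    //   %2 = sparse_tensor.reinterpret_map %1    // remap to original type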
    if (tx && tx == linalgOp.getDpsInitOperand(0)) {
      auto stt = getSparseTensorType(tx->get());
      auto demap = genDemap(rewriter, stt.getEncoding(), tx->get());
      auto newOp = genGenericLinalg(rewriter, linalgOp, stt, demap);
      auto remap = genRemap(rewriter, stt.getEncoding(), newOp.getResult(0));
      rewriter.replaceOp(linalgOp, remap);
      return success();
    }
    return failure();
  }
};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

// CRTP to help implement a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Demaps non-trivial inputs.
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns)
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);

    // CRTP call.
    OpAdaptor adaptor(deMappedIns);
    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
                                                          rewriter);
  }
};

struct TensorAllocDemapper
    : public OpRewritePattern<bufferization::AllocTensorOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    // Compute the maximum coordinate for every dimension, consuming the
    // dynamic size operands for dynamic dimensions.
    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    // Translate the maximum coordinates into level space and derive the
    // dynamic level sizes from them.
    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
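    // Update the alloc in place: size the demapped alloc by the computed
    // dynamic level sizes, then remap so that all remaining uses still see
    // the original (mapped) tensor type.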
    rewriter.startRootUpdate(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeRootUpdate(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};

struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};

struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse inputs or outputs that have
    // non-identity dim2lvl maps.
    if (!hasNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startRootUpdate(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
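    // New (demapped) block arguments are appended first and the old
    // (dim-based) ones are erased afterwards; the "Block Args:" comments
    // trace the argument list through each step of the rewiring.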
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrds, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrds, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getResult());
          stt && !stt->isIdentity()) {
        Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeRootUpdate(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in the
    // reinterpret_map that remaps the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap>(patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper, TensorInsertDemapper, ForeachOpDemapper>(
        patterns.getContext());
  }
}
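
// A minimal usage sketch (an assumption for illustration; the actual pass
// wiring lives elsewhere in the sparse tensor transforms):
//
//   RewritePatternSet patterns(op->getContext());
//   populateSparseReinterpretMap(patterns, ReinterpretMapScope::kAll);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));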