//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Translates a "simple" map according to an identity lvl-map.
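//
// For example (an illustrative block-sparse case, not taken from any test),
// given the level-to-dimension map
//   lvl2dim = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 2 + l3)
// the "simple" operand map (d0, d1) -> (d1, d0) translates into the
// lvl-space map (l0, l1, l2, l3) -> (l1 * 2 + l3, l0 * 2 + l2).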
static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
                              AffineMap map) {
  unsigned lvlRank = stt.getLvlRank();
  AffineMap lvl2dim = stt.getLvlToDim();
  assert(lvl2dim.getNumInputs() == lvlRank);
  SmallVector<AffineExpr> exps;
  for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
    unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
    exps.push_back(lvl2dim.getResult(pos));
  }
  return AffineMap::get(lvlRank, 0, exps, builder.getContext());
}

// Generates a "de"mapping reinterpretation of the map.
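// That is, the value is reinterpreted with an encoding in which the dim2lvl
// (and thus lvl2dim) mapping has been stripped, exposing the raw storage
// levels directly.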
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                          val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}

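// Reinterprets each value in `outs` whose type differs from the corresponding
// entry in `types`, so that the returned values carry the requested types.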
static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}

/// Whether the operation has any sparse tensor operand or result with a
/// non-identity dim2lvl map.
static bool hasNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

// Generates a clone of the given linalg generic operation, but with
// remapped arguments, index maps, and iteration types.
//
// TODO: As described below, this is proof-of-concept code which makes a lot
//       of simplifying assumptions for now.
//
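// For example (illustrative only), with a 2x2 block-sparse output whose
//   dim2lvl = (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2)
// a 2-d generic op over (d0, d1) becomes a 4-d generic op over
// (l0, l1, l2, l3): input maps are translated through lvl2dim, the output
// map becomes the 4-d identity, and two extra "parallel" iterators are
// prepended for the additional levels.
//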
static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
                                          linalg::GenericOp linalgOp,
                                          SparseTensorType stt, Value out) {
  unsigned dimRank = stt.getDimRank();
  unsigned lvlRank = stt.getLvlRank();
  SmallVector<Value> inputOps = linalgOp.getInputs();
  SmallVector<Value> outputOps = {out};
  SmallVector<AffineMap> indexMaps;
  SmallVector<utils::IteratorType> iterTypes;
  // Translate the index maps, except the output map, which is lvl-identity.
  auto maps = linalgOp.getIndexingMapsArray();
  for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
    indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
  indexMaps.push_back(
      AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
  // Add additional "parallel" iteration types at the top.
  for (unsigned i = 0, diff = lvlRank - dimRank; i < diff; i++)
    iterTypes.push_back(utils::IteratorType::parallel);
  for (auto &i : linalgOp.getIteratorTypesArray())
    iterTypes.push_back(i);
  // Generate the new linalg generic operation and clone body.
  auto newOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
      iterTypes);
  rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
                             newOp.getRegion().begin());
  return newOp;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpReinterpretMap(MLIRContext *context)
      : OpRewritePattern<linalg::GenericOp>(context) {}

  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite single output operations with pure tensor semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics())
      return failure();
    // Scan all operands, inspect sparse tensors.
    //
    // TODO: generalize this proof-of-concept algorithm, since the current
    //       implementation accepts only simple indexing maps, and one
    //       non-permutation sparse tensor, which must have an identity
    //       indexing map and be the output.
    //
    OpOperand *tx = nullptr;
    for (OpOperand &t : linalgOp->getOpOperands()) {
      // Ensure every index map is "simple".
      const auto map = linalgOp.getMatchingIndexingMap(&t);
      for (unsigned i = 0, n = map.getNumResults(); i < n; i++)
        if (map.getResult(i).getKind() != AffineExprKind::DimId)
          return failure();
      // Inspect sparse operands.
      auto stt = tryGetSparseTensorType(t.get());
      if (stt && stt->hasEncoding()) {
        if (stt->isPermutation())
          continue;
        assert(stt->getDimRank() < stt->getLvlRank()); // only allowed non-perm
        if (tx)
          return failure(); // more than one non-perm
        if (!map.isIdentity())
          return failure(); // no ID indexing map on the non-perm
        tx = &t;
      }
    }
    // Found a non-permutation; rewrite when this is the output.
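    // The rewrite forms a demap/compute/remap sandwich (schematically):
    //   %d = sparse_tensor.reinterpret_map %out ...  (strip the dim2lvl map)
    //   %r = linalg.generic ... outs(%d) ...         (compute in lvl-space)
    //        sparse_tensor.reinterpret_map %r ...    (restore the dim2lvl map)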
    if (tx && tx == linalgOp.getDpsInitOperand(0)) {
      auto stt = getSparseTensorType(tx->get());
      auto demap = genDemap(rewriter, stt.getEncoding(), tx->get());
      auto newOp = genGenericLinalg(rewriter, linalgOp, stt, demap);
      auto remap = genRemap(rewriter, stt.getEncoding(), newOp.getResult(0));
      rewriter.replaceOp(linalgOp, remap);
      return success();
    }
    return failure();
  }
};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generic ops.
//===----------------------------------------------------------------------===//

// CRTP to help implement a rewriter that demaps all its inputs.
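// Subclasses are expected to provide the actual rewriting logic as:
//   LogicalResult rewriteOp(SourceOp op, OpAdaptor adaptor,
//                           PatternRewriter &rewriter) const;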
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Demaps non-trivial inputs.
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns)
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);

    // CRTP call.
    OpAdaptor adaptor(deMappedIns);
    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
                                                          rewriter);
  }
};

struct TensorAllocDemapper
    : public OpRewritePattern<bufferization::AllocTensorOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

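    // Note: level sizes are derived by translating the largest in-bounds
    // dimension coordinates through dim2lvl and adding one, since the
    // encoding maps coordinates rather than sizes.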
    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
    rewriter.startRootUpdate(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeRootUpdate(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};

struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};

struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse inputs or outputs that have
    // non-identity dim2lvl maps.
    if (!hasNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startRootUpdate(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrd, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
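    // For example (illustrative only), with dimRank = 2, lvlRank = 3, and a
    // single init arg, the layout is now [d0, d1, v, a, l0, l1, l2, v', a'];
    // the original arguments are erased group by group below.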
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getResult());
          stt && !stt->isIdentity()) {
        Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeRootUpdate(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all uses of the foreach results, except the use in the
    // reinterpret_map that remaps the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

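// Installs the patterns by scope: `kGenericOnly` adds only the linalg.generic
// rewriting, `kExceptGeneric` adds only the demappers for the remaining
// operations, and `kAll` adds both sets.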
void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap>(patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper, TensorInsertDemapper, ForeachOpDemapper>(
        patterns.getContext());
  }
}