xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (revision 942b403ff1a412778c9fb97bd53b44e35b544b0e)
1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/Dialect/Vector/IR/VectorOps.h"
11 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
12 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/ImplicitLocOpBuilder.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 #define DEBUG_TYPE "vector-drop-unit-dim"
18 
19 using namespace mlir;
20 using namespace mlir::vector;
21 
22 // Trims leading one dimensions from `oldType` and returns the result type.
23 // Returns `vector<1xT>` if `oldType` only has one element.
24 static VectorType trimLeadingOneDims(VectorType oldType) {
25   ArrayRef<int64_t> oldShape = oldType.getShape();
26   ArrayRef<int64_t> newShape =
27       oldShape.drop_while([](int64_t dim) { return dim == 1; });
28   // Make sure we have at least 1 dimension per vector type requirements.
29   if (newShape.empty())
30     newShape = oldShape.take_back();
31   return VectorType::get(newShape, oldType.getElementType());
32 }
33 
34 /// Return a smallVector of size `rank` containing all zeros.
35 static SmallVector<int64_t> splatZero(int64_t rank) {
36   return SmallVector<int64_t>(rank, 0);
37 }
38 namespace {
39 
40 // Casts away leading one dimensions in vector.extract_strided_slice's vector
41 // input by inserting vector.broadcast.
42 struct CastAwayExtractStridedSliceLeadingOneDim
43     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
44   using OpRewritePattern::OpRewritePattern;
45 
46   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
47                                 PatternRewriter &rewriter) const override {
48     // vector.extract_strided_slice requires the input and output vector to have
49     // the same rank. Here we drop leading one dimensions from the input vector
50     // type to make sure we don't cause mismatch.
51     VectorType oldSrcType = extractOp.getSourceVectorType();
52     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
53 
54     if (newSrcType.getRank() == oldSrcType.getRank())
55       return failure();
56 
57     int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
58 
59     VectorType oldDstType = extractOp.getType();
60     VectorType newDstType =
61         VectorType::get(oldDstType.getShape().drop_front(dropCount),
62                         oldDstType.getElementType());
63 
64     Location loc = extractOp.getLoc();
65 
66     Value newSrcVector = rewriter.create<vector::ExtractOp>(
67         loc, extractOp.getVector(), splatZero(dropCount));
68 
69     // The offsets/sizes/strides attribute can have a less number of elements
70     // than the input vector's rank: it is meant for the leading dimensions.
71     auto newOffsets = rewriter.getArrayAttr(
72         extractOp.getOffsets().getValue().drop_front(dropCount));
73     auto newSizes = rewriter.getArrayAttr(
74         extractOp.getSizes().getValue().drop_front(dropCount));
75     auto newStrides = rewriter.getArrayAttr(
76         extractOp.getStrides().getValue().drop_front(dropCount));
77 
78     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
79         loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
80 
81     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
82                                                      newExtractOp);
83 
84     return success();
85   }
86 };
87 
88 // Casts away leading one dimensions in vector.insert_strided_slice's vector
89 // inputs by inserting vector.broadcast.
90 struct CastAwayInsertStridedSliceLeadingOneDim
91     : public OpRewritePattern<vector::InsertStridedSliceOp> {
92   using OpRewritePattern::OpRewritePattern;
93 
94   LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
95                                 PatternRewriter &rewriter) const override {
96     VectorType oldSrcType = insertOp.getSourceVectorType();
97     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
98     VectorType oldDstType = insertOp.getDestVectorType();
99     VectorType newDstType = trimLeadingOneDims(oldDstType);
100 
101     int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
102     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
103     if (srcDropCount == 0 && dstDropCount == 0)
104       return failure();
105 
106     // Trim leading one dimensions from both operands.
107     Location loc = insertOp.getLoc();
108 
109     Value newSrcVector = rewriter.create<vector::ExtractOp>(
110         loc, insertOp.getSource(), splatZero(srcDropCount));
111     Value newDstVector = rewriter.create<vector::ExtractOp>(
112         loc, insertOp.getDest(), splatZero(dstDropCount));
113 
114     auto newOffsets = rewriter.getArrayAttr(
115         insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
116     auto newStrides = rewriter.getArrayAttr(
117         insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
118 
119     auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
120         loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
121 
122     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
123                                                      newInsertOp);
124 
125     return success();
126   }
127 };
128 
129 // Casts away leading one dimensions in vector.insert's vector inputs by
130 // inserting vector.broadcast.
131 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
132   using OpRewritePattern::OpRewritePattern;
133 
134   LogicalResult matchAndRewrite(vector::InsertOp insertOp,
135                                 PatternRewriter &rewriter) const override {
136     Type oldSrcType = insertOp.getSourceType();
137     Type newSrcType = oldSrcType;
138     int64_t oldSrcRank = 0, newSrcRank = 0;
139     if (auto type = oldSrcType.dyn_cast<VectorType>()) {
140       newSrcType = trimLeadingOneDims(type);
141       oldSrcRank = type.getRank();
142       newSrcRank = newSrcType.cast<VectorType>().getRank();
143     }
144 
145     VectorType oldDstType = insertOp.getDestVectorType();
146     VectorType newDstType = trimLeadingOneDims(oldDstType);
147 
148     int64_t srcDropCount = oldSrcRank - newSrcRank;
149     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
150     if (srcDropCount == 0 && dstDropCount == 0)
151       return failure();
152 
153     // Trim leading one dimensions from both operands.
154     Location loc = insertOp.getLoc();
155 
156     Value newSrcVector = insertOp.getSource();
157     if (oldSrcRank != 0) {
158       newSrcVector = rewriter.create<vector::ExtractOp>(
159           loc, insertOp.getSource(), splatZero(srcDropCount));
160     }
161     Value newDstVector = rewriter.create<vector::ExtractOp>(
162         loc, insertOp.getDest(), splatZero(dstDropCount));
163 
164     // New position rank needs to be computed in two steps: (1) if destination
165     // type has leading unit dims, we also trim the position array accordingly,
166     // then (2) if source type also has leading unit dims, we need to append
167     // zeroes to the position array accordingly.
168     unsigned oldPosRank = insertOp.getPosition().getValue().size();
169     unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
170     SmallVector<Attribute> newPositions = llvm::to_vector(
171         insertOp.getPosition().getValue().take_back(newPosRank));
172     if (srcDropCount >= dstDropCount) {
173       auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
174       newPositions.resize(newPosRank + srcDropCount, zeroAttr);
175     }
176 
177     auto newInsertOp = rewriter.create<vector::InsertOp>(
178         loc, newDstType, newSrcVector, newDstVector,
179         rewriter.getArrayAttr(newPositions));
180 
181     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
182                                                      newInsertOp);
183 
184     return success();
185   }
186 };
187 
188 // Turns vector.transfer_read on vector with leading 1 dimensions into
189 // vector.shape_cast followed by vector.transfer_read on vector without leading
190 // 1 dimensions.
191 struct CastAwayTransferReadLeadingOneDim
192     : public OpRewritePattern<vector::TransferReadOp> {
193   using OpRewritePattern::OpRewritePattern;
194 
195   LogicalResult matchAndRewrite(vector::TransferReadOp read,
196                                 PatternRewriter &rewriter) const override {
197     // TODO: support 0-d corner case.
198     if (read.getTransferRank() == 0)
199       return failure();
200 
201     if (read.getMask())
202       return failure();
203 
204     auto shapedType = read.getSource().getType().cast<ShapedType>();
205     if (shapedType.getElementType() != read.getVectorType().getElementType())
206       return failure();
207 
208     VectorType oldType = read.getVectorType();
209     VectorType newType = trimLeadingOneDims(oldType);
210 
211     if (newType == oldType)
212       return failure();
213 
214     AffineMap oldMap = read.getPermutationMap();
215     ArrayRef<AffineExpr> newResults =
216         oldMap.getResults().take_back(newType.getRank());
217     AffineMap newMap =
218         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
219                        rewriter.getContext());
220 
221     ArrayAttr inBoundsAttr;
222     if (read.getInBounds())
223       inBoundsAttr = rewriter.getArrayAttr(
224           read.getInBoundsAttr().getValue().take_back(newType.getRank()));
225 
226     auto newRead = rewriter.create<vector::TransferReadOp>(
227         read.getLoc(), newType, read.getSource(), read.getIndices(),
228         AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
229         inBoundsAttr);
230     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
231 
232     return success();
233   }
234 };
235 
236 // Turns vector.transfer_write on vector with leading 1 dimensions into
237 // vector.shape_cast followed by vector.transfer_write on vector without leading
238 // 1 dimensions.
239 struct CastAwayTransferWriteLeadingOneDim
240     : public OpRewritePattern<vector::TransferWriteOp> {
241   using OpRewritePattern::OpRewritePattern;
242 
243   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
244                                 PatternRewriter &rewriter) const override {
245     // TODO: support 0-d corner case.
246     if (write.getTransferRank() == 0)
247       return failure();
248 
249     if (write.getMask())
250       return failure();
251 
252     auto shapedType = write.getSource().getType().dyn_cast<ShapedType>();
253     if (shapedType.getElementType() != write.getVectorType().getElementType())
254       return failure();
255 
256     VectorType oldType = write.getVectorType();
257     VectorType newType = trimLeadingOneDims(oldType);
258     if (newType == oldType)
259       return failure();
260     int64_t dropDim = oldType.getRank() - newType.getRank();
261 
262     AffineMap oldMap = write.getPermutationMap();
263     ArrayRef<AffineExpr> newResults =
264         oldMap.getResults().take_back(newType.getRank());
265     AffineMap newMap =
266         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
267                        rewriter.getContext());
268 
269     ArrayAttr inBoundsAttr;
270     if (write.getInBounds())
271       inBoundsAttr = rewriter.getArrayAttr(
272           write.getInBoundsAttr().getValue().take_back(newType.getRank()));
273 
274     auto newVector = rewriter.create<vector::ExtractOp>(
275         write.getLoc(), write.getVector(), splatZero(dropDim));
276     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
277         write, newVector, write.getSource(), write.getIndices(),
278         AffineMapAttr::get(newMap), inBoundsAttr);
279 
280     return success();
281   }
282 };
283 
284 /// Turns vector.contract on vector with leading 1 dimensions into
285 /// vector.extract followed by vector.contract on vector without leading
286 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
287 /// prior to extract.
288 struct CastAwayContractionLeadingOneDim
289     : public OpRewritePattern<vector::ContractionOp> {
290   using OpRewritePattern::OpRewritePattern;
291 
292   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
293                                 PatternRewriter &rewriter) const override {
294     VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
295     if (oldAccType == nullptr)
296       return failure();
297     if (oldAccType.getRank() < 2)
298       return failure();
299     if (oldAccType.getShape()[0] != 1)
300       return failure();
301     // currently we support only dropping one dim but the pattern can be applied
302     // greedily to drop more.
303     int64_t dropDim = 1;
304 
305     auto oldIndexingMaps = contractOp.getIndexingMapsArray();
306     SmallVector<AffineMap> newIndexingMaps;
307 
308     auto oldIteratorTypes = contractOp.getIteratorTypes();
309     SmallVector<Attribute> newIteratorTypes;
310 
311     int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
312 
313     if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
314       // only parallel type iterators can be dropped.
315       return failure();
316 
317     for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
318       int64_t currDim = it.index();
319       if (currDim == dimToDrop)
320         continue;
321       newIteratorTypes.push_back(it.value());
322     }
323 
324     SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
325                                    contractOp.getAcc()};
326     SmallVector<Value> newOperands;
327 
328     for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
329       // Check if the dim to be dropped exists as a leading dim in the operand
330       // if it does then we use vector.extract to drop it.
331       bool validExtract = false;
332       SmallVector<AffineExpr> results;
333       auto map = it.value();
334       int64_t orginalZeroDim = it.value().getDimPosition(0);
335       if (orginalZeroDim != dimToDrop) {
336         // There are two reasons to be in this path, 1. We need to
337         // tranpose the operand to make the dim to be dropped
338         // leading. 2. The dim to be dropped does not exist and in
339         // that case we dont want to add a unit tranpose but we must
340         // check all the indices to make sure this is the case.
341         bool tranposeNeeded = false;
342         SmallVector<int64_t> perm;
343         SmallVector<AffineExpr> transposeResults;
344 
345         for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
346           int64_t currDim = map.getDimPosition(i);
347           if (currDim == dimToDrop) {
348             tranposeNeeded = true;
349             perm.insert(perm.begin(), i);
350             auto targetExpr = rewriter.getAffineDimExpr(currDim);
351             transposeResults.insert(transposeResults.begin(), targetExpr);
352           } else {
353             perm.push_back(i);
354             auto targetExpr = rewriter.getAffineDimExpr(currDim);
355             transposeResults.push_back(targetExpr);
356           }
357         }
358         // Do the tranpose now if needed so that we can drop the
359         // correct dim using extract later.
360         if (tranposeNeeded) {
361           map = AffineMap::get(map.getNumDims(), 0, transposeResults,
362                                contractOp.getContext());
363           operands[it.index()] = rewriter.create<vector::TransposeOp>(
364               contractOp.getLoc(), operands[it.index()], perm);
365         }
366       }
367       // We have taken care to have the dim to be dropped be
368       // the leading dim. If its still not leading that means it
369       // does not exist in this operand and hence we do not need
370       // an extract.
371       if (map.getDimPosition(0) == dimToDrop)
372         validExtract = true;
373 
374       for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
375         int64_t currDim = map.getDimPosition(i);
376         if (currDim == dimToDrop)
377           // This is the dim we are dropping.
378           continue;
379         auto targetExpr = rewriter.getAffineDimExpr(
380             currDim < dimToDrop ? currDim : currDim - 1);
381         results.push_back(targetExpr);
382       }
383       newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
384                                                contractOp.getContext()));
385       // Extract if its a valid extraction, otherwise use the operand
386       // without extraction.
387       newOperands.push_back(validExtract
388                                 ? rewriter.create<vector::ExtractOp>(
389                                       contractOp.getLoc(), operands[it.index()],
390                                       splatZero(dropDim))
391                                 : operands[it.index()]);
392     }
393     auto newContractOp = rewriter.create<vector::ContractionOp>(
394         contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
395         rewriter.getAffineMapArrayAttr(newIndexingMaps),
396         rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
397     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
398         contractOp, contractOp->getResultTypes()[0], newContractOp);
399     return success();
400   }
401 };
402 
403 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
404 public:
405   CastAwayElementwiseLeadingOneDim(MLIRContext *context,
406                                    PatternBenefit benefit = 1)
407       : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
408 
409   LogicalResult matchAndRewrite(Operation *op,
410                                 PatternRewriter &rewriter) const override {
411     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
412       return failure();
413     auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
414     if (!vecType)
415       return failure();
416     VectorType newVecType = trimLeadingOneDims(vecType);
417     if (newVecType == vecType)
418       return failure();
419     int64_t dropDim = vecType.getRank() - newVecType.getRank();
420     SmallVector<Value, 4> newOperands;
421     for (Value operand : op->getOperands()) {
422       if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
423         newOperands.push_back(rewriter.create<vector::ExtractOp>(
424             op->getLoc(), operand, splatZero(dropDim)));
425       } else {
426         newOperands.push_back(operand);
427       }
428     }
429     Operation *newOp =
430         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
431                         newOperands, newVecType, op->getAttrs());
432     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
433                                                      newOp->getResult(0));
434     return success();
435   }
436 };
437 
438 } // namespace
439 
440 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
441     RewritePatternSet &patterns, PatternBenefit benefit) {
442   patterns
443       .add<CastAwayExtractStridedSliceLeadingOneDim,
444            CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
445            CastAwayTransferReadLeadingOneDim,
446            CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
447            CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
448   populateShapeCastFoldingPatterns(patterns, benefit);
449 }
450