xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (revision 7b70baa9ef2d859a90a2873242b446cf21b1755e)
1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/Dialect/Vector/IR/VectorOps.h"
11 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
12 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/ImplicitLocOpBuilder.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 #define DEBUG_TYPE "vector-drop-unit-dim"
18 
19 using namespace mlir;
20 using namespace mlir::vector;
21 
22 // Trims leading one dimensions from `oldType` and returns the result type.
23 // Returns `vector<1xT>` if `oldType` only has one element.
24 static VectorType trimLeadingOneDims(VectorType oldType) {
25   ArrayRef<int64_t> oldShape = oldType.getShape();
26   ArrayRef<int64_t> newShape =
27       oldShape.drop_while([](int64_t dim) { return dim == 1; });
28   // Make sure we have at least 1 dimension per vector type requirements.
29   if (newShape.empty())
30     newShape = oldShape.take_back();
31   return VectorType::get(newShape, oldType.getElementType());
32 }
33 
34 /// Return a smallVector of size `rank` containing all zeros.
35 static SmallVector<int64_t> splatZero(int64_t rank) {
36   return SmallVector<int64_t>(rank, 0);
37 }
38 namespace {
39 
40 // Casts away leading one dimensions in vector.extract_strided_slice's vector
41 // input by inserting vector.broadcast.
42 struct CastAwayExtractStridedSliceLeadingOneDim
43     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
44   using OpRewritePattern::OpRewritePattern;
45 
46   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
47                                 PatternRewriter &rewriter) const override {
48     // vector.extract_strided_slice requires the input and output vector to have
49     // the same rank. Here we drop leading one dimensions from the input vector
50     // type to make sure we don't cause mismatch.
51     VectorType oldSrcType = extractOp.getSourceVectorType();
52     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
53 
54     if (newSrcType.getRank() == oldSrcType.getRank())
55       return failure();
56 
57     int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
58 
59     VectorType oldDstType = extractOp.getType();
60     VectorType newDstType =
61         VectorType::get(oldDstType.getShape().drop_front(dropCount),
62                         oldDstType.getElementType());
63 
64     Location loc = extractOp.getLoc();
65 
66     Value newSrcVector = rewriter.create<vector::ExtractOp>(
67         loc, extractOp.getVector(), splatZero(dropCount));
68 
69     // The offsets/sizes/strides attribute can have a less number of elements
70     // than the input vector's rank: it is meant for the leading dimensions.
71     auto newOffsets = rewriter.getArrayAttr(
72         extractOp.getOffsets().getValue().drop_front(dropCount));
73     auto newSizes = rewriter.getArrayAttr(
74         extractOp.getSizes().getValue().drop_front(dropCount));
75     auto newStrides = rewriter.getArrayAttr(
76         extractOp.getStrides().getValue().drop_front(dropCount));
77 
78     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
79         loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
80 
81     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
82                                                      newExtractOp);
83 
84     return success();
85   }
86 };
87 
88 // Casts away leading one dimensions in vector.insert_strided_slice's vector
89 // inputs by inserting vector.broadcast.
90 struct CastAwayInsertStridedSliceLeadingOneDim
91     : public OpRewritePattern<vector::InsertStridedSliceOp> {
92   using OpRewritePattern::OpRewritePattern;
93 
94   LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
95                                 PatternRewriter &rewriter) const override {
96     VectorType oldSrcType = insertOp.getSourceVectorType();
97     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
98     VectorType oldDstType = insertOp.getDestVectorType();
99     VectorType newDstType = trimLeadingOneDims(oldDstType);
100 
101     int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
102     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
103     if (srcDropCount == 0 && dstDropCount == 0)
104       return failure();
105 
106     // Trim leading one dimensions from both operands.
107     Location loc = insertOp.getLoc();
108 
109     Value newSrcVector = rewriter.create<vector::ExtractOp>(
110         loc, insertOp.getSource(), splatZero(srcDropCount));
111     Value newDstVector = rewriter.create<vector::ExtractOp>(
112         loc, insertOp.getDest(), splatZero(dstDropCount));
113 
114     auto newOffsets = rewriter.getArrayAttr(
115         insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
116     auto newStrides = rewriter.getArrayAttr(
117         insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
118 
119     auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
120         loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
121 
122     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
123                                                      newInsertOp);
124 
125     return success();
126   }
127 };
128 
129 // Casts away leading one dimensions in vector.insert's vector inputs by
130 // inserting vector.broadcast.
131 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
132   using OpRewritePattern::OpRewritePattern;
133 
134   LogicalResult matchAndRewrite(vector::InsertOp insertOp,
135                                 PatternRewriter &rewriter) const override {
136     Type oldSrcType = insertOp.getSourceType();
137     Type newSrcType = oldSrcType;
138     int64_t oldSrcRank = 0, newSrcRank = 0;
139     if (auto type = oldSrcType.dyn_cast<VectorType>()) {
140       newSrcType = trimLeadingOneDims(type);
141       oldSrcRank = type.getRank();
142       newSrcRank = newSrcType.cast<VectorType>().getRank();
143     }
144 
145     VectorType oldDstType = insertOp.getDestVectorType();
146     VectorType newDstType = trimLeadingOneDims(oldDstType);
147 
148     int64_t srcDropCount = oldSrcRank - newSrcRank;
149     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
150     if (srcDropCount == 0 && dstDropCount == 0)
151       return failure();
152 
153     // Trim leading one dimensions from both operands.
154     Location loc = insertOp.getLoc();
155 
156     Value newSrcVector = insertOp.getSource();
157     if (oldSrcRank != 0) {
158       newSrcVector = rewriter.create<vector::ExtractOp>(
159           loc, insertOp.getSource(), splatZero(srcDropCount));
160     }
161     Value newDstVector = rewriter.create<vector::ExtractOp>(
162         loc, insertOp.getDest(), splatZero(dstDropCount));
163 
164     unsigned oldPosRank = insertOp.getPosition().getValue().size();
165     unsigned newPosRank = newDstType.getRank() - newSrcRank;
166     SmallVector<Attribute> newPositions = llvm::to_vector(
167         insertOp.getPosition().getValue().take_back(newPosRank));
168     if (newPosRank > oldPosRank) {
169       auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
170       newPositions.resize(newPosRank, zeroAttr);
171     }
172 
173     auto newInsertOp = rewriter.create<vector::InsertOp>(
174         loc, newDstType, newSrcVector, newDstVector,
175         rewriter.getArrayAttr(newPositions));
176 
177     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
178                                                      newInsertOp);
179 
180     return success();
181   }
182 };
183 
184 // Turns vector.transfer_read on vector with leading 1 dimensions into
185 // vector.shape_cast followed by vector.transfer_read on vector without leading
186 // 1 dimensions.
187 struct CastAwayTransferReadLeadingOneDim
188     : public OpRewritePattern<vector::TransferReadOp> {
189   using OpRewritePattern::OpRewritePattern;
190 
191   LogicalResult matchAndRewrite(vector::TransferReadOp read,
192                                 PatternRewriter &rewriter) const override {
193     // TODO: support 0-d corner case.
194     if (read.getTransferRank() == 0)
195       return failure();
196 
197     if (read.getMask())
198       return failure();
199 
200     auto shapedType = read.getSource().getType().cast<ShapedType>();
201     if (shapedType.getElementType() != read.getVectorType().getElementType())
202       return failure();
203 
204     VectorType oldType = read.getVectorType();
205     VectorType newType = trimLeadingOneDims(oldType);
206 
207     if (newType == oldType)
208       return failure();
209 
210     AffineMap oldMap = read.getPermutationMap();
211     ArrayRef<AffineExpr> newResults =
212         oldMap.getResults().take_back(newType.getRank());
213     AffineMap newMap =
214         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
215                        rewriter.getContext());
216 
217     ArrayAttr inBoundsAttr;
218     if (read.getInBounds())
219       inBoundsAttr = rewriter.getArrayAttr(
220           read.getInBoundsAttr().getValue().take_back(newType.getRank()));
221 
222     auto newRead = rewriter.create<vector::TransferReadOp>(
223         read.getLoc(), newType, read.getSource(), read.getIndices(),
224         AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
225         inBoundsAttr);
226     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
227 
228     return success();
229   }
230 };
231 
232 // Turns vector.transfer_write on vector with leading 1 dimensions into
233 // vector.shape_cast followed by vector.transfer_write on vector without leading
234 // 1 dimensions.
235 struct CastAwayTransferWriteLeadingOneDim
236     : public OpRewritePattern<vector::TransferWriteOp> {
237   using OpRewritePattern::OpRewritePattern;
238 
239   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
240                                 PatternRewriter &rewriter) const override {
241     // TODO: support 0-d corner case.
242     if (write.getTransferRank() == 0)
243       return failure();
244 
245     if (write.getMask())
246       return failure();
247 
248     auto shapedType = write.getSource().getType().dyn_cast<ShapedType>();
249     if (shapedType.getElementType() != write.getVectorType().getElementType())
250       return failure();
251 
252     VectorType oldType = write.getVectorType();
253     VectorType newType = trimLeadingOneDims(oldType);
254     if (newType == oldType)
255       return failure();
256     int64_t dropDim = oldType.getRank() - newType.getRank();
257 
258     AffineMap oldMap = write.getPermutationMap();
259     ArrayRef<AffineExpr> newResults =
260         oldMap.getResults().take_back(newType.getRank());
261     AffineMap newMap =
262         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
263                        rewriter.getContext());
264 
265     ArrayAttr inBoundsAttr;
266     if (write.getInBounds())
267       inBoundsAttr = rewriter.getArrayAttr(
268           write.getInBoundsAttr().getValue().take_back(newType.getRank()));
269 
270     auto newVector = rewriter.create<vector::ExtractOp>(
271         write.getLoc(), write.getVector(), splatZero(dropDim));
272     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
273         write, newVector, write.getSource(), write.getIndices(),
274         AffineMapAttr::get(newMap), inBoundsAttr);
275 
276     return success();
277   }
278 };
279 
/// Turns vector.contract on vector with leading 1 dimensions into
/// vector.extract followed by vector.contract on vector without leading
/// 1 dimensions. Also performs transpose of lhs and rhs operands if required
/// prior to extract. The new contraction's result is broadcast back to the
/// original result type.
///
/// Applies only when the accumulator is a vector of rank >= 2 whose leading
/// dim is 1 and whose corresponding iteration dim is a parallel iterator.
/// Exactly one iteration dim is removed per application.
struct CastAwayContractionLeadingOneDim
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
    if (oldAccType == nullptr)
      return failure();
    if (oldAccType.getRank() < 2)
      return failure();
    if (oldAccType.getShape()[0] != 1)
      return failure();
    // currently we support only dropping one dim but the pattern can be applied
    // greedily to drop more.
    int64_t dropDim = 1;

    auto oldIndexingMaps = contractOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;

    auto oldIteratorTypes = contractOp.getIteratorTypes();
    SmallVector<Attribute> newIteratorTypes;

    // Iteration-space dim that indexes the accumulator's leading (unit) dim;
    // indexing map 2 belongs to the accumulator (see `operands` below).
    int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

    if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
      // only parallel type iterators can be dropped.
      return failure();

    // New iterator types: the old ones with `dimToDrop` removed.
    for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
      int64_t currDim = it.index();
      if (currDim == dimToDrop)
        continue;
      newIteratorTypes.push_back(it.value());
    }

    // Operands in indexing-map order (lhs, rhs, acc); entries may be replaced
    // by transposed values below.
    SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                   contractOp.getAcc()};
    SmallVector<Value> newOperands;

    for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
      // Check if the dim to be dropped exists as a leading dim in the operand;
      // if it does then we use vector.extract to drop it.
      bool validExtract = false;
      SmallVector<AffineExpr> results;
      auto map = it.value();
      int64_t orginalZeroDim = it.value().getDimPosition(0);
      if (orginalZeroDim != dimToDrop) {
        // There are two reasons to be in this path: 1. We need to
        // transpose the operand to make the dim to be dropped
        // leading. 2. The dim to be dropped does not exist, and in
        // that case we don't want to add a unit transpose, but we must
        // check all the indices to make sure this is the case.
        bool tranposeNeeded = false;
        SmallVector<int64_t> perm;
        SmallVector<AffineExpr> transposeResults;

        // Build a permutation that moves `dimToDrop` to the front while
        // keeping the relative order of all other operand dims.
        for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
          int64_t currDim = map.getDimPosition(i);
          if (currDim == dimToDrop) {
            tranposeNeeded = true;
            perm.insert(perm.begin(), i);
            auto targetExpr = rewriter.getAffineDimExpr(currDim);
            transposeResults.insert(transposeResults.begin(), targetExpr);
          } else {
            perm.push_back(i);
            auto targetExpr = rewriter.getAffineDimExpr(currDim);
            transposeResults.push_back(targetExpr);
          }
        }
        // Do the transpose now if needed so that we can drop the
        // correct dim using extract later.
        if (tranposeNeeded) {
          map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                               contractOp.getContext());
          operands[it.index()] = rewriter.create<vector::TransposeOp>(
              contractOp.getLoc(), operands[it.index()], perm);
        }
      }
      // We have taken care to have the dim to be dropped be
      // the leading dim. If it's still not leading, that means it
      // does not exist in this operand and hence we do not need
      // an extract.
      if (map.getDimPosition(0) == dimToDrop)
        validExtract = true;

      // New indexing map: remove `dimToDrop` and renumber every iteration dim
      // above it down by one, since the iteration space loses that dim.
      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop)
          // This is the dim we are dropping.
          continue;
        auto targetExpr = rewriter.getAffineDimExpr(
            currDim < dimToDrop ? currDim : currDim - 1);
        results.push_back(targetExpr);
      }
      newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                               contractOp.getContext()));
      // Extract if it's a valid extraction, otherwise use the operand
      // without extraction.
      newOperands.push_back(validExtract
                                ? rewriter.create<vector::ExtractOp>(
                                      contractOp.getLoc(), operands[it.index()],
                                      splatZero(dropDim))
                                : operands[it.index()]);
    }
    auto newContractOp = rewriter.create<vector::ContractionOp>(
        contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
        rewriter.getAffineMapArrayAttr(newIndexingMaps),
        rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
    // Restore the original (leading-unit-dim) result type.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        contractOp, contractOp->getResultTypes()[0], newContractOp);
    return success();
  }
};
398 
399 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
400 public:
401   CastAwayElementwiseLeadingOneDim(MLIRContext *context,
402                                    PatternBenefit benefit = 1)
403       : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
404 
405   LogicalResult matchAndRewrite(Operation *op,
406                                 PatternRewriter &rewriter) const override {
407     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
408       return failure();
409     auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
410     if (!vecType)
411       return failure();
412     VectorType newVecType = trimLeadingOneDims(vecType);
413     if (newVecType == vecType)
414       return failure();
415     int64_t dropDim = vecType.getRank() - newVecType.getRank();
416     SmallVector<Value, 4> newOperands;
417     for (Value operand : op->getOperands()) {
418       if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
419         newOperands.push_back(rewriter.create<vector::ExtractOp>(
420             op->getLoc(), operand, splatZero(dropDim)));
421       } else {
422         newOperands.push_back(operand);
423       }
424     }
425     Operation *newOp =
426         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
427                         newOperands, newVecType, op->getAttrs());
428     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
429                                                      newOp->getResult(0));
430     return success();
431   }
432 };
433 
434 } // namespace
435 
436 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
437     RewritePatternSet &patterns, PatternBenefit benefit) {
438   patterns
439       .add<CastAwayExtractStridedSliceLeadingOneDim,
440            CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
441            CastAwayTransferReadLeadingOneDim,
442            CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
443            CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
444   populateShapeCastFoldingPatterns(patterns, benefit);
445 }
446