xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (revision 16b75cd2bb439633d29c99a7663f2586e4068ecf)
1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/Dialect/Vector/IR/VectorOps.h"
11 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
12 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
13 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 #define DEBUG_TYPE "vector-drop-unit-dim"
18 
19 using namespace mlir;
20 using namespace mlir::vector;
21 
22 // Trims leading one dimensions from `oldType` and returns the result type.
23 // Returns `vector<1xT>` if `oldType` only has one element.
24 static VectorType trimLeadingOneDims(VectorType oldType) {
25   ArrayRef<int64_t> oldShape = oldType.getShape();
26   ArrayRef<int64_t> newShape =
27       oldShape.drop_while([](int64_t dim) { return dim == 1; });
28   // Make sure we have at least 1 dimension per vector type requirements.
29   if (newShape.empty())
30     newShape = oldShape.take_back();
31   return VectorType::get(newShape, oldType.getElementType());
32 }
33 
34 /// Return a smallVector of size `rank` containing all zeros.
35 static SmallVector<int64_t> splatZero(int64_t rank) {
36   return SmallVector<int64_t>(rank, 0);
37 }
38 namespace {
39 
40 // Casts away leading one dimensions in vector.extract_strided_slice's vector
41 // input by inserting vector.broadcast.
42 struct CastAwayExtractStridedSliceLeadingOneDim
43     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
44   using OpRewritePattern::OpRewritePattern;
45 
46   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
47                                 PatternRewriter &rewriter) const override {
48     // vector.extract_strided_slice requires the input and output vector to have
49     // the same rank. Here we drop leading one dimensions from the input vector
50     // type to make sure we don't cause mismatch.
51     VectorType oldSrcType = extractOp.getSourceVectorType();
52     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
53 
54     if (newSrcType.getRank() == oldSrcType.getRank())
55       return failure();
56 
57     int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
58 
59     VectorType oldDstType = extractOp.getType();
60     VectorType newDstType =
61         VectorType::get(oldDstType.getShape().drop_front(dropCount),
62                         oldDstType.getElementType());
63 
64     Location loc = extractOp.getLoc();
65 
66     Value newSrcVector = rewriter.create<vector::ExtractOp>(
67         loc, extractOp.getVector(), splatZero(dropCount));
68 
69     // The offsets/sizes/strides attribute can have a less number of elements
70     // than the input vector's rank: it is meant for the leading dimensions.
71     auto newOffsets = rewriter.getArrayAttr(
72         extractOp.getOffsets().getValue().drop_front(dropCount));
73     auto newSizes = rewriter.getArrayAttr(
74         extractOp.getSizes().getValue().drop_front(dropCount));
75     auto newStrides = rewriter.getArrayAttr(
76         extractOp.getStrides().getValue().drop_front(dropCount));
77 
78     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
79         loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
80 
81     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
82                                                      newExtractOp);
83 
84     return success();
85   }
86 };
87 
88 // Casts away leading one dimensions in vector.insert_strided_slice's vector
89 // inputs by inserting vector.broadcast.
90 struct CastAwayInsertStridedSliceLeadingOneDim
91     : public OpRewritePattern<vector::InsertStridedSliceOp> {
92   using OpRewritePattern::OpRewritePattern;
93 
94   LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
95                                 PatternRewriter &rewriter) const override {
96     VectorType oldSrcType = insertOp.getSourceVectorType();
97     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
98     VectorType oldDstType = insertOp.getDestVectorType();
99     VectorType newDstType = trimLeadingOneDims(oldDstType);
100 
101     int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
102     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
103     if (srcDropCount == 0 && dstDropCount == 0)
104       return failure();
105 
106     // Trim leading one dimensions from both operands.
107     Location loc = insertOp.getLoc();
108 
109     Value newSrcVector = rewriter.create<vector::ExtractOp>(
110         loc, insertOp.getSource(), splatZero(srcDropCount));
111     Value newDstVector = rewriter.create<vector::ExtractOp>(
112         loc, insertOp.getDest(), splatZero(dstDropCount));
113 
114     auto newOffsets = rewriter.getArrayAttr(
115         insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
116     auto newStrides = rewriter.getArrayAttr(
117         insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
118 
119     auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
120         loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
121 
122     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
123                                                      newInsertOp);
124 
125     return success();
126   }
127 };
128 
129 // Casts away leading one dimensions in vector.insert's vector inputs by
130 // inserting vector.broadcast.
131 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
132   using OpRewritePattern::OpRewritePattern;
133 
134   LogicalResult matchAndRewrite(vector::InsertOp insertOp,
135                                 PatternRewriter &rewriter) const override {
136     Type oldSrcType = insertOp.getSourceType();
137     Type newSrcType = oldSrcType;
138     int64_t oldSrcRank = 0, newSrcRank = 0;
139     if (auto type = dyn_cast<VectorType>(oldSrcType)) {
140       newSrcType = trimLeadingOneDims(type);
141       oldSrcRank = type.getRank();
142       newSrcRank = cast<VectorType>(newSrcType).getRank();
143     }
144 
145     VectorType oldDstType = insertOp.getDestVectorType();
146     VectorType newDstType = trimLeadingOneDims(oldDstType);
147 
148     int64_t srcDropCount = oldSrcRank - newSrcRank;
149     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
150     if (srcDropCount == 0 && dstDropCount == 0)
151       return failure();
152 
153     // Trim leading one dimensions from both operands.
154     Location loc = insertOp.getLoc();
155 
156     Value newSrcVector = insertOp.getSource();
157     if (oldSrcRank != 0) {
158       newSrcVector = rewriter.create<vector::ExtractOp>(
159           loc, insertOp.getSource(), splatZero(srcDropCount));
160     }
161     Value newDstVector = rewriter.create<vector::ExtractOp>(
162         loc, insertOp.getDest(), splatZero(dstDropCount));
163 
164     // New position rank needs to be computed in two steps: (1) if destination
165     // type has leading unit dims, we also trim the position array accordingly,
166     // then (2) if source type also has leading unit dims, we need to append
167     // zeroes to the position array accordingly.
168     unsigned oldPosRank = insertOp.getPosition().size();
169     unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
170     SmallVector<int64_t> newPositions =
171         llvm::to_vector(insertOp.getPosition().take_back(newPosRank));
172     newPositions.resize(newDstType.getRank() - newSrcRank, 0);
173 
174     auto newInsertOp = rewriter.create<vector::InsertOp>(
175         loc, newDstType, newSrcVector, newDstVector, newPositions);
176 
177     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
178                                                      newInsertOp);
179 
180     return success();
181   }
182 };
183 
184 // Turns vector.transfer_read on vector with leading 1 dimensions into
185 // vector.shape_cast followed by vector.transfer_read on vector without leading
186 // 1 dimensions.
187 struct CastAwayTransferReadLeadingOneDim
188     : public OpRewritePattern<vector::TransferReadOp> {
189   using OpRewritePattern::OpRewritePattern;
190 
191   LogicalResult matchAndRewrite(vector::TransferReadOp read,
192                                 PatternRewriter &rewriter) const override {
193     // TODO: support 0-d corner case.
194     if (read.getTransferRank() == 0)
195       return failure();
196 
197     if (read.getMask())
198       return failure();
199 
200     auto shapedType = cast<ShapedType>(read.getSource().getType());
201     if (shapedType.getElementType() != read.getVectorType().getElementType())
202       return failure();
203 
204     VectorType oldType = read.getVectorType();
205     VectorType newType = trimLeadingOneDims(oldType);
206 
207     if (newType == oldType)
208       return failure();
209 
210     AffineMap oldMap = read.getPermutationMap();
211     ArrayRef<AffineExpr> newResults =
212         oldMap.getResults().take_back(newType.getRank());
213     AffineMap newMap =
214         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
215                        rewriter.getContext());
216 
217     ArrayAttr inBoundsAttr;
218     if (read.getInBounds())
219       inBoundsAttr = rewriter.getArrayAttr(
220           read.getInBoundsAttr().getValue().take_back(newType.getRank()));
221 
222     auto newRead = rewriter.create<vector::TransferReadOp>(
223         read.getLoc(), newType, read.getSource(), read.getIndices(),
224         AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
225         inBoundsAttr);
226     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
227 
228     return success();
229   }
230 };
231 
232 // Turns vector.transfer_write on vector with leading 1 dimensions into
233 // vector.shape_cast followed by vector.transfer_write on vector without leading
234 // 1 dimensions.
235 struct CastAwayTransferWriteLeadingOneDim
236     : public OpRewritePattern<vector::TransferWriteOp> {
237   using OpRewritePattern::OpRewritePattern;
238 
239   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
240                                 PatternRewriter &rewriter) const override {
241     // TODO: support 0-d corner case.
242     if (write.getTransferRank() == 0)
243       return failure();
244 
245     if (write.getMask())
246       return failure();
247 
248     auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
249     if (shapedType.getElementType() != write.getVectorType().getElementType())
250       return failure();
251 
252     VectorType oldType = write.getVectorType();
253     VectorType newType = trimLeadingOneDims(oldType);
254     if (newType == oldType)
255       return failure();
256     int64_t dropDim = oldType.getRank() - newType.getRank();
257 
258     AffineMap oldMap = write.getPermutationMap();
259     ArrayRef<AffineExpr> newResults =
260         oldMap.getResults().take_back(newType.getRank());
261     AffineMap newMap =
262         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
263                        rewriter.getContext());
264 
265     ArrayAttr inBoundsAttr;
266     if (write.getInBounds())
267       inBoundsAttr = rewriter.getArrayAttr(
268           write.getInBoundsAttr().getValue().take_back(newType.getRank()));
269 
270     auto newVector = rewriter.create<vector::ExtractOp>(
271         write.getLoc(), write.getVector(), splatZero(dropDim));
272     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
273         write, newVector, write.getSource(), write.getIndices(),
274         AffineMapAttr::get(newMap), inBoundsAttr);
275 
276     return success();
277   }
278 };
279 
280 } // namespace
281 
/// Drops the leading unit dimension of a vector.contract accumulator.
///
/// The iteration-space dimension that indexes the accumulator's leading
/// dimension (taken from the accumulator indexing map) is removed from the
/// contraction:
///   * each operand whose map uses that dimension is, if needed, transposed so
///     the dimension becomes leading, and then narrowed with vector.extract at
///     position 0;
///   * operands whose maps do not use that dimension are forwarded unchanged;
///   * every indexing map loses the dimension, and later dims are renumbered
///     down by one;
///   * a new vector.contract is built on the narrowed operands, and its result
///     is broadcast back to the original result type.
///
/// Returns failure without creating any ops when the accumulator is not a
/// vector, has rank < 2, has a non-unit leading dimension, or when the
/// dimension to drop is not a parallel iterator. Only one dimension is dropped
/// per call; apply it repeatedly to drop more.
LogicalResult
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                               RewriterBase &rewriter) {
  VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
  // A non-vector accumulator has no leading dimension to drop.
  if (oldAccType == nullptr)
    return failure();
  // Keep at least one dimension in the new accumulator.
  if (oldAccType.getRank() < 2)
    return failure();
  if (oldAccType.getShape()[0] != 1)
    return failure();
  // Currently we support only dropping one dim, but the pattern can be applied
  // greedily to drop more. `dropDim` is the number of leading positions that
  // vector.extract removes below.
  int64_t dropDim = 1;

  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
  SmallVector<AffineMap> newIndexingMaps;

  auto oldIteratorTypes = contractOp.getIteratorTypes();
  SmallVector<Attribute> newIteratorTypes;

  // Iteration dimension indexing the accumulator's leading (unit) dim. Maps
  // are ordered (lhs, rhs, acc), matching `operands` below.
  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
    // only parallel type iterators can be dropped.
    return failure();

  // Keep every iterator type except the dropped one, in order.
  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
    int64_t currDim = it.index();
    if (currDim == dimToDrop)
      continue;
    newIteratorTypes.push_back(it.value());
  }

  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                 contractOp.getAcc()};
  SmallVector<Value> newOperands;

  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
    // Check if the dim to be dropped exists as a leading dim in the operand;
    // if it does, then we use vector.extract to drop it.
    // NOTE(review): getDimPosition presumably requires each map result to be a
    // plain dim expression — confirm this holds for all contraction maps.
    bool validExtract = false;
    SmallVector<AffineExpr> results;
    auto map = it.value();
    int64_t orginalZeroDim = it.value().getDimPosition(0);
    if (orginalZeroDim != dimToDrop) {
      // There are two reasons to be in this path: 1. we need to
      // transpose the operand to make the dim to be dropped
      // leading; 2. the dim to be dropped does not exist, in which
      // case we don't want to add a unit transpose, but we must
      // check all the indices to make sure this is the case.
      bool tranposeNeeded = false;
      SmallVector<int64_t> perm;
      SmallVector<AffineExpr> transposeResults;

      // Build a permutation that moves the dropped dim to the front while
      // keeping the relative order of all other dims.
      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop) {
          tranposeNeeded = true;
          perm.insert(perm.begin(), i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.insert(transposeResults.begin(), targetExpr);
        } else {
          perm.push_back(i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.push_back(targetExpr);
        }
      }
      // Do the transpose now if needed so that we can drop the
      // correct dim using extract later.
      if (tranposeNeeded) {
        map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                             contractOp.getContext());
        operands[it.index()] = rewriter.create<vector::TransposeOp>(
            contractOp.getLoc(), operands[it.index()], perm);
      }
    }
    // We have taken care to have the dim to be dropped be
    // the leading dim. If it's still not leading, that means it
    // does not exist in this operand and hence we do not need
    // an extract.
    if (map.getDimPosition(0) == dimToDrop)
      validExtract = true;

    // Rebuild the map without the dropped dim; dims after it shift down by one.
    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t currDim = map.getDimPosition(i);
      if (currDim == dimToDrop)
        // This is the dim we are dropping.
        continue;
      auto targetExpr = rewriter.getAffineDimExpr(
          currDim < dimToDrop ? currDim : currDim - 1);
      results.push_back(targetExpr);
    }
    newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                             contractOp.getContext()));
    // Extract if it's a valid extraction, otherwise use the operand
    // without extraction.
    newOperands.push_back(
        validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
                                                          operands[it.index()],
                                                          splatZero(dropDim))
                     : operands[it.index()]);
  }
  auto newContractOp = rewriter.create<vector::ContractionOp>(
      contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
      rewriter.getAffineMapArrayAttr(newIndexingMaps),
      rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
  // Broadcast the lower-rank result back to the original result type.
  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
      contractOp, contractOp->getResultTypes()[0], newContractOp);
  return success();
}
392 
393 namespace {
394 
395 /// Turns vector.contract on vector with leading 1 dimensions into
396 /// vector.extract followed by vector.contract on vector without leading
397 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
398 /// prior to extract.
399 struct CastAwayContractionLeadingOneDim
400     : public OpRewritePattern<vector::ContractionOp> {
401   using OpRewritePattern::OpRewritePattern;
402 
403   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
404                                 PatternRewriter &rewriter) const override {
405     return castAwayContractionLeadingOneDim(contractOp, rewriter);
406   }
407 };
408 
409 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
410 public:
411   CastAwayElementwiseLeadingOneDim(MLIRContext *context,
412                                    PatternBenefit benefit = 1)
413       : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
414 
415   LogicalResult matchAndRewrite(Operation *op,
416                                 PatternRewriter &rewriter) const override {
417     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
418       return failure();
419     auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
420     if (!vecType)
421       return failure();
422     VectorType newVecType = trimLeadingOneDims(vecType);
423     if (newVecType == vecType)
424       return failure();
425     int64_t dropDim = vecType.getRank() - newVecType.getRank();
426     SmallVector<Value, 4> newOperands;
427     for (Value operand : op->getOperands()) {
428       if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
429         newOperands.push_back(rewriter.create<vector::ExtractOp>(
430             op->getLoc(), operand, splatZero(dropDim)));
431       } else {
432         newOperands.push_back(operand);
433       }
434     }
435     Operation *newOp =
436         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
437                         newOperands, newVecType, op->getAttrs());
438     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
439                                                      newOp->getResult(0));
440     return success();
441   }
442 };
443 
444 } // namespace
445 
446 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
447     RewritePatternSet &patterns, PatternBenefit benefit) {
448   patterns
449       .add<CastAwayExtractStridedSliceLeadingOneDim,
450            CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
451            CastAwayTransferReadLeadingOneDim,
452            CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
453            CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
454   populateShapeCastFoldingPatterns(patterns, benefit);
455 }
456