xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (revision e54236dfb5982bc8358bad62a27e6048f06a0272)
1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
10 #include "mlir/Dialect/Vector/IR/VectorOps.h"
11 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
12 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/ImplicitLocOpBuilder.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 #define DEBUG_TYPE "vector-drop-unit-dim"
18 
19 using namespace mlir;
20 using namespace mlir::vector;
21 
22 // Trims leading one dimensions from `oldType` and returns the result type.
23 // Returns `vector<1xT>` if `oldType` only has one element.
24 static VectorType trimLeadingOneDims(VectorType oldType) {
25   ArrayRef<int64_t> oldShape = oldType.getShape();
26   ArrayRef<int64_t> newShape =
27       oldShape.drop_while([](int64_t dim) { return dim == 1; });
28   // Make sure we have at least 1 dimension per vector type requirements.
29   if (newShape.empty())
30     newShape = oldShape.take_back();
31   return VectorType::get(newShape, oldType.getElementType());
32 }
33 
34 /// Return a smallVector of size `rank` containing all zeros.
35 static SmallVector<int64_t> splatZero(int64_t rank) {
36   return SmallVector<int64_t>(rank, 0);
37 }
38 namespace {
39 
40 // Casts away leading one dimensions in vector.extract_strided_slice's vector
41 // input by inserting vector.broadcast.
42 struct CastAwayExtractStridedSliceLeadingOneDim
43     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
44   using OpRewritePattern::OpRewritePattern;
45 
46   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
47                                 PatternRewriter &rewriter) const override {
48     // vector.extract_strided_slice requires the input and output vector to have
49     // the same rank. Here we drop leading one dimensions from the input vector
50     // type to make sure we don't cause mismatch.
51     VectorType oldSrcType = extractOp.getVectorType();
52     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
53 
54     if (newSrcType.getRank() == oldSrcType.getRank())
55       return failure();
56 
57     int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
58 
59     VectorType oldDstType = extractOp.getType();
60     VectorType newDstType =
61         VectorType::get(oldDstType.getShape().drop_front(dropCount),
62                         oldDstType.getElementType());
63 
64     Location loc = extractOp.getLoc();
65 
66     Value newSrcVector = rewriter.create<vector::ExtractOp>(
67         loc, extractOp.getVector(), splatZero(dropCount));
68 
69     // The offsets/sizes/strides attribute can have a less number of elements
70     // than the input vector's rank: it is meant for the leading dimensions.
71     auto newOffsets = rewriter.getArrayAttr(
72         extractOp.getOffsets().getValue().drop_front(dropCount));
73     auto newSizes = rewriter.getArrayAttr(
74         extractOp.getSizes().getValue().drop_front(dropCount));
75     auto newStrides = rewriter.getArrayAttr(
76         extractOp.getStrides().getValue().drop_front(dropCount));
77 
78     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
79         loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
80 
81     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
82                                                      newExtractOp);
83 
84     return success();
85   }
86 };
87 
88 // Casts away leading one dimensions in vector.insert_strided_slice's vector
89 // inputs by inserting vector.broadcast.
90 struct CastAwayInsertStridedSliceLeadingOneDim
91     : public OpRewritePattern<vector::InsertStridedSliceOp> {
92   using OpRewritePattern::OpRewritePattern;
93 
94   LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
95                                 PatternRewriter &rewriter) const override {
96     VectorType oldSrcType = insertOp.getSourceVectorType();
97     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
98     VectorType oldDstType = insertOp.getDestVectorType();
99     VectorType newDstType = trimLeadingOneDims(oldDstType);
100 
101     int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
102     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
103     if (srcDropCount == 0 && dstDropCount == 0)
104       return failure();
105 
106     // Trim leading one dimensions from both operands.
107     Location loc = insertOp.getLoc();
108 
109     Value newSrcVector = rewriter.create<vector::ExtractOp>(
110         loc, insertOp.getSource(), splatZero(srcDropCount));
111     Value newDstVector = rewriter.create<vector::ExtractOp>(
112         loc, insertOp.getDest(), splatZero(dstDropCount));
113 
114     auto newOffsets = rewriter.getArrayAttr(
115         insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
116     auto newStrides = rewriter.getArrayAttr(
117         insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
118 
119     auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
120         loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
121 
122     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
123                                                      newInsertOp);
124 
125     return success();
126   }
127 };
128 
129 // Casts away leading one dimensions in vector.insert's vector inputs by
130 // inserting vector.broadcast.
131 struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
132   using OpRewritePattern::OpRewritePattern;
133 
134   LogicalResult matchAndRewrite(vector::InsertOp insertOp,
135                                 PatternRewriter &rewriter) const override {
136     Type oldSrcType = insertOp.getSourceType();
137     Type newSrcType = oldSrcType;
138     int64_t oldSrcRank = 0, newSrcRank = 0;
139     if (auto type = oldSrcType.dyn_cast<VectorType>()) {
140       newSrcType = trimLeadingOneDims(type);
141       oldSrcRank = type.getRank();
142       newSrcRank = newSrcType.cast<VectorType>().getRank();
143     }
144 
145     VectorType oldDstType = insertOp.getDestVectorType();
146     VectorType newDstType = trimLeadingOneDims(oldDstType);
147 
148     int64_t srcDropCount = oldSrcRank - newSrcRank;
149     int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
150     if (srcDropCount == 0 && dstDropCount == 0)
151       return failure();
152 
153     // Trim leading one dimensions from both operands.
154     Location loc = insertOp.getLoc();
155 
156     Value newSrcVector = insertOp.getSource();
157     if (oldSrcRank != 0) {
158       newSrcVector = rewriter.create<vector::ExtractOp>(
159           loc, insertOp.getSource(), splatZero(srcDropCount));
160     }
161     Value newDstVector = rewriter.create<vector::ExtractOp>(
162         loc, insertOp.getDest(), splatZero(dstDropCount));
163 
164     unsigned oldPosRank = insertOp.getPosition().getValue().size();
165     unsigned newPosRank = newDstType.getRank() - newSrcRank;
166     SmallVector<Attribute> newPositions = llvm::to_vector(
167         insertOp.getPosition().getValue().take_back(newPosRank));
168     if (newPosRank > oldPosRank) {
169       auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
170       newPositions.resize(newPosRank, zeroAttr);
171     }
172 
173     auto newInsertOp = rewriter.create<vector::InsertOp>(
174         loc, newDstType, newSrcVector, newDstVector,
175         rewriter.getArrayAttr(newPositions));
176 
177     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
178                                                      newInsertOp);
179 
180     return success();
181   }
182 };
183 
184 // Turns vector.transfer_read on vector with leading 1 dimensions into
185 // vector.shape_cast followed by vector.transfer_read on vector without leading
186 // 1 dimensions.
187 struct CastAwayTransferReadLeadingOneDim
188     : public OpRewritePattern<vector::TransferReadOp> {
189   using OpRewritePattern::OpRewritePattern;
190 
191   LogicalResult matchAndRewrite(vector::TransferReadOp read,
192                                 PatternRewriter &rewriter) const override {
193     // TODO: support 0-d corner case.
194     if (read.getTransferRank() == 0)
195       return failure();
196 
197     if (read.getMask())
198       return failure();
199 
200     auto shapedType = read.getSource().getType().cast<ShapedType>();
201     if (shapedType.getElementType() != read.getVectorType().getElementType())
202       return failure();
203 
204     VectorType oldType = read.getVectorType();
205     VectorType newType = trimLeadingOneDims(oldType);
206 
207     if (newType == oldType)
208       return failure();
209 
210     AffineMap oldMap = read.getPermutationMap();
211     ArrayRef<AffineExpr> newResults =
212         oldMap.getResults().take_back(newType.getRank());
213     AffineMap newMap =
214         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
215                        rewriter.getContext());
216 
217     ArrayAttr inBoundsAttr;
218     if (read.getInBounds())
219       inBoundsAttr = rewriter.getArrayAttr(
220           read.getInBoundsAttr().getValue().take_back(newType.getRank()));
221 
222     auto newRead = rewriter.create<vector::TransferReadOp>(
223         read.getLoc(), newType, read.getSource(), read.getIndices(),
224         AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
225         inBoundsAttr);
226     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
227 
228     return success();
229   }
230 };
231 
232 // Turns vector.transfer_write on vector with leading 1 dimensions into
233 // vector.shape_cast followed by vector.transfer_write on vector without leading
234 // 1 dimensions.
235 struct CastAwayTransferWriteLeadingOneDim
236     : public OpRewritePattern<vector::TransferWriteOp> {
237   using OpRewritePattern::OpRewritePattern;
238 
239   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
240                                 PatternRewriter &rewriter) const override {
241     // TODO: support 0-d corner case.
242     if (write.getTransferRank() == 0)
243       return failure();
244 
245     if (write.getMask())
246       return failure();
247 
248     auto shapedType = write.getSource().getType().dyn_cast<ShapedType>();
249     if (shapedType.getElementType() != write.getVectorType().getElementType())
250       return failure();
251 
252     VectorType oldType = write.getVectorType();
253     VectorType newType = trimLeadingOneDims(oldType);
254     if (newType == oldType)
255       return failure();
256     int64_t dropDim = oldType.getRank() - newType.getRank();
257 
258     AffineMap oldMap = write.getPermutationMap();
259     ArrayRef<AffineExpr> newResults =
260         oldMap.getResults().take_back(newType.getRank());
261     AffineMap newMap =
262         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
263                        rewriter.getContext());
264 
265     ArrayAttr inBoundsAttr;
266     if (write.getInBounds())
267       inBoundsAttr = rewriter.getArrayAttr(
268           write.getInBoundsAttr().getValue().take_back(newType.getRank()));
269 
270     auto newVector = rewriter.create<vector::ExtractOp>(
271         write.getLoc(), write.getVector(), splatZero(dropDim));
272     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
273         write, newVector, write.getSource(), write.getIndices(),
274         AffineMapAttr::get(newMap), inBoundsAttr);
275 
276     return success();
277   }
278 };
279 
280 /// Turns vector.contract on vector with leading 1 dimensions into
281 /// vector.extract followed by vector.contract on vector without leading
282 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
283 /// prior to extract.
284 struct CastAwayContractionLeadingOneDim
285     : public OpRewritePattern<vector::ContractionOp> {
286   using OpRewritePattern::OpRewritePattern;
287 
288   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
289                                 PatternRewriter &rewriter) const override {
290     VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
291     if (oldAccType == nullptr)
292       return failure();
293     if (oldAccType.getRank() < 2)
294       return failure();
295     // TODO: implement masks.
296     if (llvm::size(contractOp.getMasks()) != 0)
297       return failure();
298     if (oldAccType.getShape()[0] != 1)
299       return failure();
300     // currently we support only dropping one dim but the pattern can be applied
301     // greedily to drop more.
302     int64_t dropDim = 1;
303 
304     auto oldIndexingMaps = contractOp.getIndexingMaps();
305     SmallVector<AffineMap> newIndexingMaps;
306 
307     auto oldIteratorTypes = contractOp.getIteratorTypes();
308     SmallVector<Attribute> newIteratorTypes;
309 
310     int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
311 
312     if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
313       // only parallel type iterators can be dropped.
314       return failure();
315 
316     for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
317       int64_t currDim = it.index();
318       if (currDim == dimToDrop)
319         continue;
320       newIteratorTypes.push_back(it.value());
321     }
322 
323     SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
324                                    contractOp.getAcc()};
325     SmallVector<Value> newOperands;
326 
327     for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
328       // Check if the dim to be dropped exists as a leading dim in the operand
329       // if it does then we use vector.extract to drop it.
330       bool validExtract = false;
331       SmallVector<AffineExpr> results;
332       auto map = it.value();
333       int64_t orginalZeroDim = it.value().getDimPosition(0);
334       if (orginalZeroDim != dimToDrop) {
335         // There are two reasons to be in this path, 1. We need to
336         // tranpose the operand to make the dim to be dropped
337         // leading. 2. The dim to be dropped does not exist and in
338         // that case we dont want to add a unit tranpose but we must
339         // check all the indices to make sure this is the case.
340         bool tranposeNeeded = false;
341         SmallVector<int64_t> perm;
342         SmallVector<AffineExpr> transposeResults;
343 
344         for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
345           int64_t currDim = map.getDimPosition(i);
346           if (currDim == dimToDrop) {
347             tranposeNeeded = true;
348             perm.insert(perm.begin(), i);
349             auto targetExpr = rewriter.getAffineDimExpr(currDim);
350             transposeResults.insert(transposeResults.begin(), targetExpr);
351           } else {
352             perm.push_back(i);
353             auto targetExpr = rewriter.getAffineDimExpr(currDim);
354             transposeResults.push_back(targetExpr);
355           }
356         }
357         // Do the tranpose now if needed so that we can drop the
358         // correct dim using extract later.
359         if (tranposeNeeded) {
360           map = AffineMap::get(map.getNumDims(), 0, transposeResults,
361                                contractOp.getContext());
362           operands[it.index()] = rewriter.create<vector::TransposeOp>(
363               contractOp.getLoc(), operands[it.index()], perm);
364         }
365       }
366       // We have taken care to have the dim to be dropped be
367       // the leading dim. If its still not leading that means it
368       // does not exist in this operand and hence we do not need
369       // an extract.
370       if (map.getDimPosition(0) == dimToDrop)
371         validExtract = true;
372 
373       for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
374         int64_t currDim = map.getDimPosition(i);
375         if (currDim == dimToDrop)
376           // This is the dim we are dropping.
377           continue;
378         auto targetExpr = rewriter.getAffineDimExpr(
379             currDim < dimToDrop ? currDim : currDim - 1);
380         results.push_back(targetExpr);
381       }
382       newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
383                                                contractOp.getContext()));
384       // Extract if its a valid extraction, otherwise use the operand
385       // without extraction.
386       newOperands.push_back(validExtract
387                                 ? rewriter.create<vector::ExtractOp>(
388                                       contractOp.getLoc(), operands[it.index()],
389                                       splatZero(dropDim))
390                                 : operands[it.index()]);
391     }
392     auto newContractOp = rewriter.create<vector::ContractionOp>(
393         contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
394         rewriter.getAffineMapArrayAttr(newIndexingMaps),
395         rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
396     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
397         contractOp, contractOp->getResultTypes()[0], newContractOp);
398     return success();
399   }
400 };
401 
402 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
403 public:
404   CastAwayElementwiseLeadingOneDim(MLIRContext *context)
405       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
406 
407   LogicalResult matchAndRewrite(Operation *op,
408                                 PatternRewriter &rewriter) const override {
409     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
410       return failure();
411     auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
412     if (!vecType)
413       return failure();
414     VectorType newVecType = trimLeadingOneDims(vecType);
415     if (newVecType == vecType)
416       return failure();
417     int64_t dropDim = vecType.getRank() - newVecType.getRank();
418     SmallVector<Value, 4> newOperands;
419     for (Value operand : op->getOperands()) {
420       if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
421         newOperands.push_back(rewriter.create<vector::ExtractOp>(
422             op->getLoc(), operand, splatZero(dropDim)));
423       } else {
424         newOperands.push_back(operand);
425       }
426     }
427     Operation *newOp =
428         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
429                         newOperands, newVecType, op->getAttrs());
430     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
431                                                      newOp->getResult(0));
432     return success();
433   }
434 };
435 
436 } // namespace
437 
438 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
439     RewritePatternSet &patterns) {
440   patterns
441       .add<CastAwayExtractStridedSliceLeadingOneDim,
442            CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
443            CastAwayTransferReadLeadingOneDim,
444            CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
445            CastAwayContractionLeadingOneDim>(patterns.getContext());
446   populateShapeCastFoldingPatterns(patterns);
447 }
448