xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (revision fac349a169976f822fb27f03e623fa0d28aec1f3)
1 //===- VectorLinearize.cpp - vector linearization transforms --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns and pass for linearizing ND vectors into 1D.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Support/LogicalResult.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include <cstdint>
25 #include <numeric>
26 
27 using namespace mlir;
28 
29 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
30   auto resultTypes = op->getResultTypes();
31   for (auto resType : resultTypes) {
32     VectorType vecType = dyn_cast<VectorType>(resType);
33     // Reject index since getElementTypeBitWidth will abort for Index types.
34     if (!vecType || vecType.getElementType().isIndex())
35       return false;
36     // There are no dimension to fold if it is a 0-D vector.
37     if (vecType.getRank() == 0)
38       return false;
39     unsigned trailingVecDimBitWidth =
40         vecType.getShape().back() * vecType.getElementTypeBitWidth();
41     if (trailingVecDimBitWidth >= targetBitWidth)
42       return false;
43   }
44   return true;
45 }
46 
47 namespace {
48 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
49   using OpConversionPattern::OpConversionPattern;
50   LinearizeConstant(
51       const TypeConverter &typeConverter, MLIRContext *context,
52       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
53       PatternBenefit benefit = 1)
54       : OpConversionPattern(typeConverter, context, benefit),
55         targetVectorBitWidth(targetVectBitWidth) {}
56   LogicalResult
57   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
58                   ConversionPatternRewriter &rewriter) const override {
59     Location loc = constOp.getLoc();
60     auto resType =
61         getTypeConverter()->convertType<VectorType>(constOp.getType());
62 
63     if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
64       return rewriter.notifyMatchFailure(
65           loc,
66           "Cannot linearize a constant scalable vector that's not a splat");
67 
68     if (!resType)
69       return rewriter.notifyMatchFailure(loc, "can't convert return type");
70     if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
71       return rewriter.notifyMatchFailure(
72           loc, "Can't flatten since targetBitWidth <= OpSize");
73     auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
74     if (!dstElementsAttr)
75       return rewriter.notifyMatchFailure(loc, "unsupported attr type");
76 
77     dstElementsAttr = dstElementsAttr.reshape(resType);
78     rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
79                                                    dstElementsAttr);
80     return success();
81   }
82 
83 private:
84   unsigned targetVectorBitWidth;
85 };
86 
87 struct LinearizeVectorizable final
88     : OpTraitConversionPattern<OpTrait::Vectorizable> {
89   using OpTraitConversionPattern::OpTraitConversionPattern;
90 
91 public:
92   LinearizeVectorizable(
93       const TypeConverter &typeConverter, MLIRContext *context,
94       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
95       PatternBenefit benefit = 1)
96       : OpTraitConversionPattern(typeConverter, context, benefit),
97         targetVectorBitWidth(targetVectBitWidth) {}
98   LogicalResult
99   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
100                   ConversionPatternRewriter &rewriter) const override {
101     if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
102       return rewriter.notifyMatchFailure(
103           op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
104     FailureOr<Operation *> newOp =
105         convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
106     if (failed(newOp))
107       return failure();
108 
109     rewriter.replaceOp(op, (*newOp)->getResults());
110     return success();
111   }
112 
113 private:
114   unsigned targetVectorBitWidth;
115 };
116 
117 /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
118 /// on a linearized vector.
119 /// Following,
120 ///   vector.extract_strided_slice %source
121 ///         { offsets = [..], strides = [..], sizes = [..] }
122 /// is converted to :
123 ///   %source_1d = vector.shape_cast %source
124 ///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
125 ///   %out_nd = vector.shape_cast %out_1d
126 /// `shuffle_indices_1d` is computed using the offsets and sizes of the
127 /// extraction.
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  /// Rewrites vector.extract_strided_slice as a shuffle of the linearized
  /// source with explicit per-element indices. See the class comment for the
  /// overall shape of the rewrite.
  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = getTypeConverter()->convertType(extractOp.getType());
    assert(!(extractOp.getVector().getType().isScalable() ||
             cast<VectorType>(dstType).isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    // NOTE(review): only the leading stride is inspected here; presumably all
    // strides are expected to be 1 -- confirm whether the remaining entries
    // should be checked as well.
    if (!isConstantIntValue(strides[0], 1))
      return rewriter.notifyMatchFailure(
          extractOp, "Strided slice with stride != 1 is not supported.");
    Value srcVector = adaptor.getVector();
    // If kD offsets are specified for nD source vector (n > k), the granularity
    // of the extraction is greater than 1. In this case last (n-k) dimensions
    // form the extraction granularity.
    // Example :
    //  vector.extract_strided_slice %src {
    //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //      vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, extraction granularity is 8.
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    int64_t k = kD;
    // Multiply out the trailing (n-k) dims: each extracted "slice" covers
    // this many contiguous scalars in the linearized source.
    while (k < nD) {
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
      ++k;
    }
    // Get total number of extracted slices (product of the kD sizes).
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
    }
    // Compute the strides of the source vector considering first k dimensions.
    // The innermost of the k dims steps by one whole granularity unit.
    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // Final shuffle indices has nExtractedSlices * extractGranularitySize
    // elements.
    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
                                          extractGranularitySize);
    // Compute the strides of the extracted kD vector.
    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
    // Compute extractedStrides (row-major strides of the `sizes` shape).
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
    }
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
    // and compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index (delinearize `i`
      // against extractedStrides).
      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array from linearizedIndex to linearizedIndex +
      // extractGranularitySize (one contiguous run per extracted slice).
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector,
        rewriter.getI64ArrayAttr(indices));
    return success();
  }

private:
  // Ops whose trailing-dim bit width reaches this bound are left untouched.
  unsigned targetVectorBitWidth;
};
229 
230 /// This pattern converts the ShuffleOp that works on nD (n > 1)
231 /// vectors to a ShuffleOp that works on linearized vectors.
232 /// Following,
233 ///   vector.shuffle %v1, %v2 [ shuffle_indices ]
234 /// is converted to :
235 ///   %v1_1d = vector.shape_cast %v1
236 ///   %v2_1d = vector.shape_cast %v2
237 ///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
238 ///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
240 /// of the original shuffle operation.
241 struct LinearizeVectorShuffle final
242     : public OpConversionPattern<vector::ShuffleOp> {
243   using OpConversionPattern::OpConversionPattern;
244   LinearizeVectorShuffle(
245       const TypeConverter &typeConverter, MLIRContext *context,
246       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
247       PatternBenefit benefit = 1)
248       : OpConversionPattern(typeConverter, context, benefit),
249         targetVectorBitWidth(targetVectBitWidth) {}
250 
251   LogicalResult
252   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
253                   ConversionPatternRewriter &rewriter) const override {
254     Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
255     assert(!(shuffleOp.getV1VectorType().isScalable() ||
256              shuffleOp.getV2VectorType().isScalable() ||
257              cast<VectorType>(dstType).isScalable()) &&
258            "scalable vectors are not supported.");
259     if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
260       return rewriter.notifyMatchFailure(
261           shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
262 
263     Value vec1 = adaptor.getV1();
264     Value vec2 = adaptor.getV2();
265     int shuffleSliceLen = 1;
266     int rank = shuffleOp.getV1().getType().getRank();
267 
268     // If rank > 1, we need to do the shuffle in the granularity of slices
269     // instead of scalars. Size of the slice is equal to the rank-1 innermost
270     // dims. Mask of the shuffle op specifies which slice to take from the
271     // outermost dim.
272     if (rank > 1) {
273       llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
274       for (unsigned i = 1; i < shape.size(); ++i) {
275         shuffleSliceLen *= shape[i];
276       }
277     }
278 
279     // For each value in the mask, we generate the indices of the source vectors
280     // that needs to be shuffled to the destination vector. If shuffleSliceLen >
281     // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
282     // elements) instead of scalars.
283     ArrayAttr mask = shuffleOp.getMask();
284     int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
285     llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
286     for (auto [i, value] :
287          llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
288 
289       int64_t v = value.getZExtValue();
290       std::iota(indices.begin() + shuffleSliceLen * i,
291                 indices.begin() + shuffleSliceLen * (i + 1),
292                 shuffleSliceLen * v);
293     }
294 
295     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
296         shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
297     return success();
298   }
299 
300 private:
301   unsigned targetVectorBitWidth;
302 };
303 
304 /// This pattern converts the ExtractOp to a ShuffleOp that works on a
305 /// linearized vector.
306 /// Following,
307 ///   vector.extract %source [ position ]
308 /// is converted to :
309 ///   %source_1d = vector.shape_cast %source
310 ///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
311 ///   %out_nd = vector.shape_cast %out_1d
312 /// `shuffle_indices_1d` is computed using the position of the original extract.
313 struct LinearizeVectorExtract final
314     : public OpConversionPattern<vector::ExtractOp> {
315   using OpConversionPattern::OpConversionPattern;
316   LinearizeVectorExtract(
317       const TypeConverter &typeConverter, MLIRContext *context,
318       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
319       PatternBenefit benefit = 1)
320       : OpConversionPattern(typeConverter, context, benefit),
321         targetVectorBitWidth(targetVectBitWidth) {}
322   LogicalResult
323   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
324                   ConversionPatternRewriter &rewriter) const override {
325     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
326     assert(!(extractOp.getVector().getType().isScalable() ||
327              cast<VectorType>(dstTy).isScalable()) &&
328            "scalable vectors are not supported.");
329     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
330       return rewriter.notifyMatchFailure(
331           extractOp, "Can't flatten since targetBitWidth <= OpSize");
332 
333     // Dynamic position is not supported.
334     if (extractOp.hasDynamicPosition())
335       return rewriter.notifyMatchFailure(extractOp,
336                                          "dynamic position is not supported.");
337 
338     llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
339     int64_t size = extractOp.getVector().getType().getNumElements();
340 
341     // Compute linearized offset.
342     int64_t linearizedOffset = 0;
343     llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
344     for (auto [i, off] : llvm::enumerate(offsets)) {
345       size /= shape[i];
346       linearizedOffset += offsets[i] * size;
347     }
348 
349     llvm::SmallVector<int64_t, 2> indices(size);
350     std::iota(indices.begin(), indices.end(), linearizedOffset);
351     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
352         extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
353         rewriter.getI64ArrayAttr(indices));
354 
355     return success();
356   }
357 
358 private:
359   unsigned targetVectorBitWidth;
360 };
361 } // namespace
362 
363 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
364     TypeConverter &typeConverter, RewritePatternSet &patterns,
365     ConversionTarget &target, unsigned targetBitWidth) {
366 
367   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
368     if (!isLinearizableVector(type))
369       return type;
370 
371     return VectorType::get(type.getNumElements(), type.getElementType(),
372                            type.isScalable());
373   });
374 
375   auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
376                             Location loc) -> Value {
377     if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
378         !isa<VectorType>(type))
379       return nullptr;
380 
381     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
382   };
383   typeConverter.addArgumentMaterialization(materializeCast);
384   typeConverter.addSourceMaterialization(materializeCast);
385   typeConverter.addTargetMaterialization(materializeCast);
386   target.markUnknownOpDynamicallyLegal(
387       [=](Operation *op) -> std::optional<bool> {
388         if ((isa<arith::ConstantOp>(op) ||
389              op->hasTrait<OpTrait::Vectorizable>())) {
390           return (isLessThanTargetBitWidth(op, targetBitWidth)
391                       ? typeConverter.isLegal(op)
392                       : true);
393         }
394         return std::nullopt;
395       });
396 
397   patterns.add<LinearizeConstant, LinearizeVectorizable>(
398       typeConverter, patterns.getContext(), targetBitWidth);
399 }
400 
401 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
402     TypeConverter &typeConverter, RewritePatternSet &patterns,
403     ConversionTarget &target, unsigned int targetBitWidth) {
404   target.addDynamicallyLegalOp<vector::ShuffleOp>(
405       [=](vector::ShuffleOp shuffleOp) -> bool {
406         return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
407                    ? (typeConverter.isLegal(shuffleOp) &&
408                       cast<mlir::VectorType>(shuffleOp.getResult().getType())
409                               .getRank() == 1)
410                    : true;
411       });
412   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
413                LinearizeVectorExtractStridedSlice>(
414       typeConverter, patterns.getContext(), targetBitWidth);
415 }
416