xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (revision 3ace685105d3b50bca68328bf0c945af22d70f23)
1 //===- VectorLinearize.cpp - vector linearization transforms --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns and pass for linearizing ND vectors into 1D.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include <cstdint>
24 #include <numeric>
25 
26 using namespace mlir;
27 
28 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
29   auto resultTypes = op->getResultTypes();
30   for (auto resType : resultTypes) {
31     VectorType vecType = dyn_cast<VectorType>(resType);
32     // Reject index since getElementTypeBitWidth will abort for Index types.
33     if (!vecType || vecType.getElementType().isIndex())
34       return false;
35     // There are no dimension to fold if it is a 0-D vector.
36     if (vecType.getRank() == 0)
37       return false;
38     unsigned trailingVecDimBitWidth =
39         vecType.getShape().back() * vecType.getElementTypeBitWidth();
40     if (trailingVecDimBitWidth >= targetBitWidth)
41       return false;
42   }
43   return true;
44 }
45 
46 static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
47   VectorType vecType = dyn_cast<VectorType>(t);
48   // Reject index since getElementTypeBitWidth will abort for Index types.
49   if (!vecType || vecType.getElementType().isIndex())
50     return false;
51   // There are no dimension to fold if it is a 0-D vector.
52   if (vecType.getRank() == 0)
53     return false;
54   unsigned trailingVecDimBitWidth =
55       vecType.getShape().back() * vecType.getElementTypeBitWidth();
56   return trailingVecDimBitWidth <= targetBitWidth;
57 }
58 
59 namespace {
60 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
61   using OpConversionPattern::OpConversionPattern;
62   LinearizeConstant(
63       const TypeConverter &typeConverter, MLIRContext *context,
64       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
65       PatternBenefit benefit = 1)
66       : OpConversionPattern(typeConverter, context, benefit),
67         targetVectorBitWidth(targetVectBitWidth) {}
68   LogicalResult
69   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
70                   ConversionPatternRewriter &rewriter) const override {
71     Location loc = constOp.getLoc();
72     auto resType =
73         getTypeConverter()->convertType<VectorType>(constOp.getType());
74 
75     if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
76       return rewriter.notifyMatchFailure(
77           loc,
78           "Cannot linearize a constant scalable vector that's not a splat");
79 
80     if (!resType)
81       return rewriter.notifyMatchFailure(loc, "can't convert return type");
82     if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
83       return rewriter.notifyMatchFailure(
84           loc, "Can't flatten since targetBitWidth <= OpSize");
85     auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
86     if (!dstElementsAttr)
87       return rewriter.notifyMatchFailure(loc, "unsupported attr type");
88 
89     dstElementsAttr = dstElementsAttr.reshape(resType);
90     rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
91                                                    dstElementsAttr);
92     return success();
93   }
94 
95 private:
96   unsigned targetVectorBitWidth;
97 };
98 
99 struct LinearizeVectorizable final
100     : OpTraitConversionPattern<OpTrait::Vectorizable> {
101   using OpTraitConversionPattern::OpTraitConversionPattern;
102 
103 public:
104   LinearizeVectorizable(
105       const TypeConverter &typeConverter, MLIRContext *context,
106       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
107       PatternBenefit benefit = 1)
108       : OpTraitConversionPattern(typeConverter, context, benefit),
109         targetVectorBitWidth(targetVectBitWidth) {}
110   LogicalResult
111   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
112                   ConversionPatternRewriter &rewriter) const override {
113     if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
114       return rewriter.notifyMatchFailure(
115           op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
116     FailureOr<Operation *> newOp =
117         convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
118     if (failed(newOp))
119       return failure();
120 
121     rewriter.replaceOp(op, (*newOp)->getResults());
122     return success();
123   }
124 
125 private:
126   unsigned targetVectorBitWidth;
127 };
128 
/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
/// on a linearized vector.
/// Following,
///   vector.extract_strided_slice %source
///         { offsets = [..], strides = [..], sizes = [..] }
/// is converted to :
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
/// extraction.
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(extractOp.getType());
    assert(dstType && "vector type destination expected.");
    // Scalable vectors have an unknown runtime length, so a static shuffle
    // mask cannot be built for them.
    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
      return rewriter.notifyMatchFailure(extractOp,
                                         "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    // Only unit strides yield contiguous element runs; other strides are
    // rejected (only the leading stride is inspected here).
    if (!isConstantIntValue(strides[0], 1))
      return rewriter.notifyMatchFailure(
          extractOp, "Strided slice with stride != 1 is not supported.");
    Value srcVector = adaptor.getVector();
    // If kD offsets are specified for nD source vector (n > k), the granularity
    // of the extraction is greater than 1. In this case last (n-k) dimensions
    // form the extraction granularity.
    // Example :
    //  vector.extract_strided_slice %src {
    //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //      vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, extraction granularity is 8.
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    int64_t k = kD;
    // Multiply together the trailing (n-k) dims: that product is the number of
    // contiguous elements each "slice" contributes.
    while (k < nD) {
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
      ++k;
    }
    // Get total number of extracted slices.
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
    }
    // Compute the strides of the source vector considering first k dimensions.
    // sourceStrides[i] is the number of scalar elements one step along source
    // dim i spans; the innermost of the k dims steps by the granularity.
    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // Final shuffle indices has nExtractedSlices * extractGranularitySize
    // elements.
    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
                                          extractGranularitySize);
    // Compute the strides of the extracted kD vector.
    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
    // Compute extractedStrides (row-major strides of the `sizes` shape).
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
    }
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
    // and compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index (delinearize `i`
      // against extractedStrides).
      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array form linearizedIndex to linearizedIndex +
      // extractGranularitySize.
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector, indices);
    return success();
  }

private:
  // Ops whose trailing vector dimension meets/exceeds this width are left
  // untouched.
  unsigned targetVectorBitWidth;
};
242 
/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
///   vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to :
///   %v1_1d = vector.shape_cast %v1
///   %v2_1d = vector.shape_cast %v2
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
struct LinearizeVectorShuffle final
    : public OpConversionPattern<vector::ShuffleOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorShuffle(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
    assert(dstType && "vector type destination expected.");
    // The assert is used because vector.shuffle does not support scalable
    // vectors.
    assert(!(shuffleOp.getV1VectorType().isScalable() ||
             shuffleOp.getV2VectorType().isScalable() ||
             dstType.isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");

    Value vec1 = adaptor.getV1();
    Value vec2 = adaptor.getV2();
    int shuffleSliceLen = 1;
    int rank = shuffleOp.getV1().getType().getRank();

    // If rank > 1, we need to do the shuffle in the granularity of slices
    // instead of scalars. Size of the slice is equal to the rank-1 innermost
    // dims. Mask of the shuffle op specifies which slice to take from the
    // outermost dim.
    if (rank > 1) {
      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
      for (unsigned i = 1; i < shape.size(); ++i) {
        shuffleSliceLen *= shape[i];
      }
    }

    // For each value in the mask, we generate the indices of the source vectors
    // that needs to be shuffled to the destination vector. If shuffleSliceLen >
    // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
    // elements) instead of scalars.
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
    for (auto [i, value] : llvm::enumerate(mask)) {
      // Each mask entry expands to shuffleSliceLen consecutive 1-D indices
      // starting at value * shuffleSliceLen.
      std::iota(indices.begin() + shuffleSliceLen * i,
                indices.begin() + shuffleSliceLen * (i + 1),
                shuffleSliceLen * value);
    }

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
                                                   vec2, indices);
    return success();
  }

private:
  // Ops whose trailing vector dimension meets/exceeds this width are left
  // untouched.
  unsigned targetVectorBitWidth;
};
317 
318 /// This pattern converts the ExtractOp to a ShuffleOp that works on a
319 /// linearized vector.
320 /// Following,
321 ///   vector.extract %source [ position ]
322 /// is converted to :
323 ///   %source_1d = vector.shape_cast %source
324 ///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
325 ///   %out_nd = vector.shape_cast %out_1d
326 /// `shuffle_indices_1d` is computed using the position of the original extract.
327 struct LinearizeVectorExtract final
328     : public OpConversionPattern<vector::ExtractOp> {
329   using OpConversionPattern::OpConversionPattern;
330   LinearizeVectorExtract(
331       const TypeConverter &typeConverter, MLIRContext *context,
332       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
333       PatternBenefit benefit = 1)
334       : OpConversionPattern(typeConverter, context, benefit),
335         targetVectorBitWidth(targetVectBitWidth) {}
336   LogicalResult
337   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
338                   ConversionPatternRewriter &rewriter) const override {
339     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
340     if (!dstTy)
341       return rewriter.notifyMatchFailure(extractOp,
342                                          "expected n-D vector type.");
343 
344     if (extractOp.getVector().getType().isScalable() ||
345         cast<VectorType>(dstTy).isScalable())
346       return rewriter.notifyMatchFailure(extractOp,
347                                          "scalable vectors are not supported.");
348     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
349       return rewriter.notifyMatchFailure(
350           extractOp, "Can't flatten since targetBitWidth <= OpSize");
351 
352     // Dynamic position is not supported.
353     if (extractOp.hasDynamicPosition())
354       return rewriter.notifyMatchFailure(extractOp,
355                                          "dynamic position is not supported.");
356 
357     llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
358     int64_t size = extractOp.getVector().getType().getNumElements();
359 
360     // Compute linearized offset.
361     int64_t linearizedOffset = 0;
362     llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
363     for (auto [i, off] : llvm::enumerate(offsets)) {
364       size /= shape[i];
365       linearizedOffset += offsets[i] * size;
366     }
367 
368     llvm::SmallVector<int64_t, 2> indices(size);
369     std::iota(indices.begin(), indices.end(), linearizedOffset);
370     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
371         extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
372 
373     return success();
374   }
375 
376 private:
377   unsigned targetVectorBitWidth;
378 };
379 
/// This pattern converts the InsertOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
///   vector.insert %source %destination [ position ]
/// is converted to :
///   %source_1d = vector.shape_cast %source
///   %destination_1d = vector.shape_cast %destination
///   %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
///   ] %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original insert.
struct LinearizeVectorInsert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorInsert(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstTy = getTypeConverter()->convertType<VectorType>(
        insertOp.getDestVectorType());
    assert(dstTy && "vector type destination expected.");
    // A static shuffle mask cannot be built for scalable vectors.
    if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalable vectors are not supported.");

    // Note: unlike the other patterns, the threshold is applied to the
    // *source* type of the insert, with <= semantics.
    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                         targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          insertOp, "Can't flatten since targetBitWidth < OpSize");

    // dynamic position is not supported
    if (insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(insertOp,
                                         "dynamic position is not supported.");
    auto srcTy = insertOp.getSourceType();
    auto srcAsVec = dyn_cast<VectorType>(srcTy);
    uint64_t srcSize = 0;
    if (srcAsVec) {
      srcSize = srcAsVec.getNumElements();
    } else {
      // Scalar sources would need a different lowering (no shape_cast).
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalars are not supported.");
    }

    auto dstShape = insertOp.getDestVectorType().getShape();
    const auto dstSize = insertOp.getDestVectorType().getNumElements();
    auto dstSizeForOffsets = dstSize;

    // compute linearized offset
    int64_t linearizedOffset = 0;
    auto offsetsNd = insertOp.getStaticPosition();
    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
      dstSizeForOffsets /= dstShape[dim];
      linearizedOffset += offset * dstSizeForOffsets;
    }

    // Build the shuffle mask over (dest, source): indices [0, dstSize) select
    // from dest, indices [dstSize, dstSize + srcSize) select from source.
    llvm::SmallVector<int64_t, 2> indices(dstSize);
    auto origValsUntil = indices.begin();
    std::advance(origValsUntil, linearizedOffset);
    std::iota(indices.begin(), origValsUntil,
              0); // original values that remain [0, offset)
    auto newValsUntil = origValsUntil;
    std::advance(newValsUntil, srcSize);
    std::iota(origValsUntil, newValsUntil,
              dstSize); // new values [offset, offset+srcNumElements)
    std::iota(newValsUntil, indices.end(),
              linearizedOffset + srcSize); // the rest of original values
                                           // [offset+srcNumElements, end)

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);

    return success();
  }

private:
  // Ops whose trailing vector dimension meets/exceeds this width are left
  // untouched.
  unsigned targetVectorBitWidth;
};
462 } // namespace
463 
/// Populates `typeConverter` with the ND-vector -> 1-D-vector conversion,
/// registers shape_cast materializations, and marks constant/Vectorizable
/// ops as dynamically legal based on `targetBitWidth`.
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth) {

  // Flatten any linearizable vector type to a rank-1 vector with the same
  // element count, element type, and scalability.
  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (!isLinearizableVector(type))
      return type;

    return VectorType::get(type.getNumElements(), type.getElementType(),
                           type.isScalable());
  });

  // Vector-to-vector casts between original and linearized forms are
  // materialized as vector.shape_cast; anything else is unsupported (null).
  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
        !isa<VectorType>(type))
      return nullptr;

    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
  };
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  // Constants and Vectorizable ops below the bit-width threshold are legal
  // only once their types are converted; ops at/above the threshold are
  // always legal (left untouched). Other ops are not constrained here.
  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if ((isa<arith::ConstantOp>(op) ||
             op->hasTrait<OpTrait::Vectorizable>())) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
                      ? typeConverter.isLegal(op)
                      : true);
        }
        return std::nullopt;
      });

  patterns.add<LinearizeConstant, LinearizeVectorizable>(
      typeConverter, patterns.getContext(), targetBitWidth);
}
500 
501 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
502     const TypeConverter &typeConverter, RewritePatternSet &patterns,
503     ConversionTarget &target, unsigned int targetBitWidth) {
504   target.addDynamicallyLegalOp<vector::ShuffleOp>(
505       [=](vector::ShuffleOp shuffleOp) -> bool {
506         return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
507                    ? (typeConverter.isLegal(shuffleOp) &&
508                       cast<mlir::VectorType>(shuffleOp.getResult().getType())
509                               .getRank() == 1)
510                    : true;
511       });
512   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
513                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
514       typeConverter, patterns.getContext(), targetBitWidth);
515 }
516