xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (revision b4444dca47c41436aa781bfd38aac6eca856ef23)
1 //===- VectorLinearize.cpp - vector linearization transforms --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns and pass for linearizing ND vectors into 1D.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include <cstdint>
24 #include <numeric>
25 
26 using namespace mlir;
27 
28 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
29   auto resultTypes = op->getResultTypes();
30   for (auto resType : resultTypes) {
31     VectorType vecType = dyn_cast<VectorType>(resType);
32     // Reject index since getElementTypeBitWidth will abort for Index types.
33     if (!vecType || vecType.getElementType().isIndex())
34       return false;
35     // There are no dimension to fold if it is a 0-D vector.
36     if (vecType.getRank() == 0)
37       return false;
38     unsigned trailingVecDimBitWidth =
39         vecType.getShape().back() * vecType.getElementTypeBitWidth();
40     if (trailingVecDimBitWidth >= targetBitWidth)
41       return false;
42   }
43   return true;
44 }
45 
46 static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
47   VectorType vecType = dyn_cast<VectorType>(t);
48   // Reject index since getElementTypeBitWidth will abort for Index types.
49   if (!vecType || vecType.getElementType().isIndex())
50     return false;
51   // There are no dimension to fold if it is a 0-D vector.
52   if (vecType.getRank() == 0)
53     return false;
54   unsigned trailingVecDimBitWidth =
55       vecType.getShape().back() * vecType.getElementTypeBitWidth();
56   return trailingVecDimBitWidth <= targetBitWidth;
57 }
58 
59 namespace {
60 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
61   using OpConversionPattern::OpConversionPattern;
62   LinearizeConstant(
63       const TypeConverter &typeConverter, MLIRContext *context,
64       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
65       PatternBenefit benefit = 1)
66       : OpConversionPattern(typeConverter, context, benefit),
67         targetVectorBitWidth(targetVectBitWidth) {}
68   LogicalResult
69   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
70                   ConversionPatternRewriter &rewriter) const override {
71     Location loc = constOp.getLoc();
72     auto resType =
73         getTypeConverter()->convertType<VectorType>(constOp.getType());
74 
75     if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
76       return rewriter.notifyMatchFailure(
77           loc,
78           "Cannot linearize a constant scalable vector that's not a splat");
79 
80     if (!resType)
81       return rewriter.notifyMatchFailure(loc, "can't convert return type");
82     if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
83       return rewriter.notifyMatchFailure(
84           loc, "Can't flatten since targetBitWidth <= OpSize");
85     auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
86     if (!dstElementsAttr)
87       return rewriter.notifyMatchFailure(loc, "unsupported attr type");
88 
89     dstElementsAttr = dstElementsAttr.reshape(resType);
90     rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
91                                                    dstElementsAttr);
92     return success();
93   }
94 
95 private:
96   unsigned targetVectorBitWidth;
97 };
98 
99 struct LinearizeVectorizable final
100     : OpTraitConversionPattern<OpTrait::Vectorizable> {
101   using OpTraitConversionPattern::OpTraitConversionPattern;
102 
103 public:
104   LinearizeVectorizable(
105       const TypeConverter &typeConverter, MLIRContext *context,
106       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
107       PatternBenefit benefit = 1)
108       : OpTraitConversionPattern(typeConverter, context, benefit),
109         targetVectorBitWidth(targetVectBitWidth) {}
110   LogicalResult
111   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
112                   ConversionPatternRewriter &rewriter) const override {
113     if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
114       return rewriter.notifyMatchFailure(
115           op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
116     FailureOr<Operation *> newOp =
117         convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
118     if (failed(newOp))
119       return failure();
120 
121     rewriter.replaceOp(op, (*newOp)->getResults());
122     return success();
123   }
124 
125 private:
126   unsigned targetVectorBitWidth;
127 };
128 
/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
/// on a linearized vector.
/// Following,
///   vector.extract_strided_slice %source
///         { offsets = [..], strides = [..], sizes = [..] }
/// is converted to :
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
/// extraction.
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  /// `targetVectBitWidth`: ops whose trailing-dimension bit width is >= this
  /// value are left untouched (defaults to "always linearize").
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(extractOp.getType());
    assert(dstType && "vector type destination expected.");
    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
      return rewriter.notifyMatchFailure(extractOp,
                                         "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    // Only unit strides are supported: a non-unit stride cannot be expressed
    // as a shuffle of contiguous chunks of the source.
    if (!isConstantIntValue(strides[0], 1))
      return rewriter.notifyMatchFailure(
          extractOp, "Strided slice with stride != 1 is not supported.");
    Value srcVector = adaptor.getVector();
    // If kD offsets are specified for nD source vector (n > k), the granularity
    // of the extraction is greater than 1. In this case last (n-k) dimensions
    // form the extraction granularity.
    // Example :
    //  vector.extract_strided_slice %src {
    //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //      vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, extraction granularity is 8.
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    int64_t k = kD;
    // Product of the trailing (n-k) source dims = elements per slice.
    while (k < nD) {
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
      ++k;
    }
    // Get total number of extracted slices.
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
    }
    // Compute the strides of the source vector considering first k dimensions.
    // Innermost of the k dims strides by one granularity-sized chunk.
    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // Final shuffle indices has nExtractedSlices * extractGranularitySize
    // elements.
    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
                                          extractGranularitySize);
    // Compute the strides of the extracted kD vector.
    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
    // Compute extractedStrides (row-major strides of the `sizes` shape).
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
    }
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
    // and compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index.
      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array form linearizedIndex to linearizedIndex +
      // extractGranularitySize.
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector, indices);
    return success();
  }

private:
  // Ops at or above this trailing-dimension bit width are not linearized.
  unsigned targetVectorBitWidth;
};
242 
/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
///   vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to :
///   %v1_1d = vector.shape_cast %v1
///   %v2_1d = vector.shape_cast %v2
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
struct LinearizeVectorShuffle final
    : public OpConversionPattern<vector::ShuffleOp> {
  using OpConversionPattern::OpConversionPattern;
  /// `targetVectBitWidth`: ops whose trailing-dimension bit width is >= this
  /// value are left untouched (defaults to "always linearize").
  LinearizeVectorShuffle(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
    assert(dstType && "vector type destination expected.");
    // The assert is used because vector.shuffle does not support scalable
    // vectors.
    assert(!(shuffleOp.getV1VectorType().isScalable() ||
             shuffleOp.getV2VectorType().isScalable() ||
             dstType.isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");

    Value vec1 = adaptor.getV1();
    Value vec2 = adaptor.getV2();
    int shuffleSliceLen = 1;
    int rank = shuffleOp.getV1().getType().getRank();

    // If rank > 1, we need to do the shuffle in the granularity of slices
    // instead of scalars. Size of the slice is equal to the rank-1 innermost
    // dims. Mask of the shuffle op specifies which slice to take from the
    // outermost dim.
    if (rank > 1) {
      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
      for (unsigned i = 1; i < shape.size(); ++i) {
        shuffleSliceLen *= shape[i];
      }
    }

    // For each value in the mask, we generate the indices of the source vectors
    // that needs to be shuffled to the destination vector. If shuffleSliceLen >
    // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
    // elements) instead of scalars.
    ArrayRef<int64_t> mask = shuffleOp.getMask();
    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
    for (auto [i, value] : llvm::enumerate(mask)) {
      // Mask entry `value` selects slice `value`; expand it to the
      // consecutive element indices of that slice.
      std::iota(indices.begin() + shuffleSliceLen * i,
                indices.begin() + shuffleSliceLen * (i + 1),
                shuffleSliceLen * value);
    }

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
                                                   vec2, indices);
    return success();
  }

private:
  // Ops at or above this trailing-dimension bit width are not linearized.
  unsigned targetVectorBitWidth;
};
317 
318 /// This pattern converts the ExtractOp to a ShuffleOp that works on a
319 /// linearized vector.
320 /// Following,
321 ///   vector.extract %source [ position ]
322 /// is converted to :
323 ///   %source_1d = vector.shape_cast %source
324 ///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
325 ///   %out_nd = vector.shape_cast %out_1d
326 /// `shuffle_indices_1d` is computed using the position of the original extract.
327 struct LinearizeVectorExtract final
328     : public OpConversionPattern<vector::ExtractOp> {
329   using OpConversionPattern::OpConversionPattern;
330   LinearizeVectorExtract(
331       const TypeConverter &typeConverter, MLIRContext *context,
332       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
333       PatternBenefit benefit = 1)
334       : OpConversionPattern(typeConverter, context, benefit),
335         targetVectorBitWidth(targetVectBitWidth) {}
336   LogicalResult
337   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
338                   ConversionPatternRewriter &rewriter) const override {
339     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
340     if (extractOp.getVector().getType().isScalable() ||
341         cast<VectorType>(dstTy).isScalable())
342       return rewriter.notifyMatchFailure(extractOp,
343                                          "scalable vectors are not supported.");
344     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
345       return rewriter.notifyMatchFailure(
346           extractOp, "Can't flatten since targetBitWidth <= OpSize");
347 
348     // Dynamic position is not supported.
349     if (extractOp.hasDynamicPosition())
350       return rewriter.notifyMatchFailure(extractOp,
351                                          "dynamic position is not supported.");
352 
353     llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
354     int64_t size = extractOp.getVector().getType().getNumElements();
355 
356     // Compute linearized offset.
357     int64_t linearizedOffset = 0;
358     llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
359     for (auto [i, off] : llvm::enumerate(offsets)) {
360       size /= shape[i];
361       linearizedOffset += offsets[i] * size;
362     }
363 
364     llvm::SmallVector<int64_t, 2> indices(size);
365     std::iota(indices.begin(), indices.end(), linearizedOffset);
366     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
367         extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
368 
369     return success();
370   }
371 
372 private:
373   unsigned targetVectorBitWidth;
374 };
375 
/// This pattern converts the InsertOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
///   vector.insert %source %destination [ position ]
/// is converted to :
///   %source_1d = vector.shape_cast %source
///   %destination_1d = vector.shape_cast %destination
///   %out_1d = vector.shuffle %destination_1d, %source_1d
///             [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original insert.
struct LinearizeVectorInsert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;
  /// `targetVectBitWidth`: inserts whose source trailing-dimension bit width
  /// exceeds this value are left untouched (defaults to "always linearize").
  LinearizeVectorInsert(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstTy = getTypeConverter()->convertType<VectorType>(
        insertOp.getDestVectorType());
    assert(dstTy && "vector type destination expected.");
    if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalable vectors are not supported.");

    // Unlike the other patterns, the threshold is checked on the *source*
    // (inserted) type, and inclusively.
    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                         targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          insertOp, "Can't flatten since targetBitWidth < OpSize");

    // dynamic position is not supported
    if (insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(insertOp,
                                         "dynamic position is not supported.");
    auto srcTy = insertOp.getSourceType();
    auto srcAsVec = dyn_cast<VectorType>(srcTy);
    uint64_t srcSize = 0;
    if (srcAsVec) {
      srcSize = srcAsVec.getNumElements();
    } else {
      // Scalar insertion is left to other patterns.
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalars are not supported.");
    }

    auto dstShape = insertOp.getDestVectorType().getShape();
    const auto dstSize = insertOp.getDestVectorType().getNumElements();
    auto dstSizeForOffsets = dstSize;

    // compute linearized offset (row-major position of the inserted slice
    // within the flattened destination)
    int64_t linearizedOffset = 0;
    auto offsetsNd = insertOp.getStaticPosition();
    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
      dstSizeForOffsets /= dstShape[dim];
      linearizedOffset += offset * dstSizeForOffsets;
    }

    // Build the shuffle mask: indices [0, dstSize) select elements of the
    // (linearized) destination, [dstSize, dstSize + srcSize) select elements
    // of the source.
    llvm::SmallVector<int64_t, 2> indices(dstSize);
    auto origValsUntil = indices.begin();
    std::advance(origValsUntil, linearizedOffset);
    std::iota(indices.begin(), origValsUntil,
              0); // original values that remain [0, offset)
    auto newValsUntil = origValsUntil;
    std::advance(newValsUntil, srcSize);
    std::iota(origValsUntil, newValsUntil,
              dstSize); // new values [offset, offset+srcNumElements)
    std::iota(newValsUntil, indices.end(),
              linearizedOffset + srcSize); // the rest of original values
                                           // [offset+srcNumElements, end)

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);

    return success();
  }

private:
  // Inserts whose source exceeds this bit width are not linearized.
  unsigned targetVectorBitWidth;
};
458 } // namespace
459 
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth) {

  // Map every linearizable n-D vector type to its flattened 1-D form;
  // anything else (including non-linearizable vectors) is kept unchanged.
  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (!isLinearizableVector(type))
      return type;

    return VectorType::get(type.getNumElements(), type.getElementType(),
                           type.isScalable());
  });

  // Bridge between original (n-D) and converted (1-D) vector values with a
  // vector.shape_cast; returns nullptr (no materialization) unless it is a
  // single-value vector-to-vector cast.
  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
        !isa<VectorType>(type))
      return nullptr;

    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
  };
  typeConverter.addArgumentMaterialization(materializeCast);
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  // arith.constant and Vectorizable ops are illegal (i.e. to be rewritten)
  // only when they are below the bit-width threshold and still carry
  // unconverted types; all other ops are left to other legality rules.
  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if ((isa<arith::ConstantOp>(op) ||
             op->hasTrait<OpTrait::Vectorizable>())) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
                      ? typeConverter.isLegal(op)
                      : true);
        }
        return std::nullopt;
      });

  patterns.add<LinearizeConstant, LinearizeVectorizable>(
      typeConverter, patterns.getContext(), targetBitWidth);
}
497 
498 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
499     TypeConverter &typeConverter, RewritePatternSet &patterns,
500     ConversionTarget &target, unsigned int targetBitWidth) {
501   target.addDynamicallyLegalOp<vector::ShuffleOp>(
502       [=](vector::ShuffleOp shuffleOp) -> bool {
503         return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
504                    ? (typeConverter.isLegal(shuffleOp) &&
505                       cast<mlir::VectorType>(shuffleOp.getResult().getType())
506                               .getRank() == 1)
507                    : true;
508       });
509   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
510                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
511       typeConverter, patterns.getContext(), targetBitWidth);
512 }
513