//===- VectorLinearize.cpp - vector linearization transforms --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and a pass for linearizing ND vectors into 1D.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <limits>
#include <numeric>

using namespace mlir;

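/// Returns true if all results of `op` are vectors (with a non-index element
/// type and rank >= 1) whose trailing-dimension bit width is strictly less
/// than `targetBitWidth`. For example, with targetBitWidth = 128, an op
/// producing vector<4x2xi32> qualifies (trailing dim: 2 x 32 = 64 bits),
/// while one producing vector<4x8xi32> does not (8 x 32 = 256 bits).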
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
  auto resultTypes = op->getResultTypes();
  for (auto resType : resultTypes) {
    VectorType vecType = dyn_cast<VectorType>(resType);
    // Reject index since getElementTypeBitWidth will abort for Index types.
    if (!vecType || vecType.getElementType().isIndex())
      return false;
    // There are no dimensions to fold if it is a 0-D vector.
    if (vecType.getRank() == 0)
      return false;
    unsigned trailingVecDimBitWidth =
        vecType.getShape().back() * vecType.getElementTypeBitWidth();
    if (trailingVecDimBitWidth >= targetBitWidth)
      return false;
  }
  return true;
}

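/// Returns true if `t` is a vector type (with a non-index element type and
/// rank >= 1) whose trailing-dimension bit width is less than or equal to
/// `targetBitWidth`.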
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
  VectorType vecType = dyn_cast<VectorType>(t);
  // Reject index since getElementTypeBitWidth will abort for Index types.
  if (!vecType || vecType.getElementType().isIndex())
    return false;
  // There are no dimensions to fold if it is a 0-D vector.
  if (vecType.getRank() == 0)
    return false;
  unsigned trailingVecDimBitWidth =
      vecType.getShape().back() * vecType.getElementTypeBitWidth();
  return trailingVecDimBitWidth <= targetBitWidth;
}

namespace {
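/// This pattern linearizes the result of an arith.constant with a dense
/// vector attribute by reshaping the attribute to the flattened 1-D vector
/// type, e.g. a dense constant of type vector<2x2xf32> becomes a dense
/// constant of type vector<4xf32>.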
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeConstant(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = constOp.getLoc();
    auto resType =
        getTypeConverter()->convertType<VectorType>(constOp.getType());
    if (!resType)
      return rewriter.notifyMatchFailure(loc, "can't convert return type");

    if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
      return rewriter.notifyMatchFailure(
          loc,
          "Cannot linearize a constant scalable vector that's not a splat");
    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          loc, "Can't flatten since targetBitWidth <= OpSize");
    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
    if (!dstElementsAttr)
      return rewriter.notifyMatchFailure(loc, "unsupported attr type");

    dstElementsAttr = dstElementsAttr.reshape(resType);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
                                                   dstElementsAttr);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

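/// This pattern linearizes any operation with the Vectorizable trait by
/// converting its vector result types to their 1-D equivalents, e.g. an
/// arith.addi on vector<2x2xi32> operands becomes an arith.addi on
/// vector<4xi32> operands.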
struct LinearizeVectorizable final
    : OpTraitConversionPattern<OpTrait::Vectorizable> {
  using OpTraitConversionPattern::OpTraitConversionPattern;

public:
  LinearizeVectorizable(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
    FailureOr<Operation *> newOp =
        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
    if (failed(newOp))
      return failure();

    rewriter.replaceOp(op, (*newOp)->getResults());
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
/// on a linearized vector.
/// Following,
///   vector.extract_strided_slice %source
///         { offsets = [..], strides = [..], sizes = [..] }
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
/// extraction.
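/// For example, with this pattern,
///   vector.extract_strided_slice %source
///         { offsets = [1], sizes = [2], strides = [1] }
///         : vector<4x2xf32> to vector<2x2xf32>
/// has an extraction granularity of 2 (the trailing dimension), and the
/// computed 1-D shuffle indices are [2, 3, 4, 5].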
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(extractOp.getType());
    assert(dstType && "vector type destination expected.");
    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
      return rewriter.notifyMatchFailure(extractOp,
                                         "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    if (!isConstantIntValue(strides[0], 1))
      return rewriter.notifyMatchFailure(
          extractOp, "Strided slice with stride != 1 is not supported.");
    Value srcVector = adaptor.getVector();
    // If kD offsets are specified for an nD source vector (n > k), the
    // granularity of the extraction is greater than 1. In this case, the last
    // (n - k) dimensions form the extraction granularity.
    // Example:
    //  vector.extract_strided_slice %src {
    //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //      vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, the extraction granularity is 8.
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    for (int64_t k = kD; k < nD; ++k)
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
    // Get the total number of extracted slices.
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
    }
    // Compute the strides of the source vector considering the first k
    // dimensions.
    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // The final shuffle indices array has nExtractedSlices *
    // extractGranularitySize elements.
    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
                                          extractGranularitySize);
    // Compute the strides of the extracted kD vector.
    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
    }
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
    // and compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index.
      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array from linearizedIndex to linearizedIndex +
      // extractGranularitySize.
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector,
        rewriter.getI64ArrayAttr(indices));
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
///   vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to:
///   %v1_1d = vector.shape_cast %v1
///   %v2_1d = vector.shape_cast %v2
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
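/// For example, for two vector<2x2xf32> operands the slice length is 2, so
///   vector.shuffle %v1, %v2 [0, 2]
/// becomes a 1-D shuffle with indices [0, 1, 4, 5]: row 0 of %v1 followed by
/// row 0 of %v2 in the linearized operands.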
struct LinearizeVectorShuffle final
    : public OpConversionPattern<vector::ShuffleOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorShuffle(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstType =
        getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
    assert(dstType && "vector type destination expected.");
    // The assert is used because vector.shuffle does not support scalable
    // vectors.
    assert(!(shuffleOp.getV1VectorType().isScalable() ||
             shuffleOp.getV2VectorType().isScalable() ||
             dstType.isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");

    Value vec1 = adaptor.getV1();
    Value vec2 = adaptor.getV2();
    int shuffleSliceLen = 1;
    int rank = shuffleOp.getV1().getType().getRank();
    // If rank > 1, we need to do the shuffle in the granularity of slices
    // instead of scalars. The slice size equals the product of the (rank - 1)
    // innermost dimensions. The mask of the shuffle op specifies which slice
    // to take from the outermost dimension.
    if (rank > 1) {
      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
      for (unsigned i = 1; i < shape.size(); ++i) {
        shuffleSliceLen *= shape[i];
      }
    }

    // For each value in the mask, we generate the indices of the source
    // vectors that need to be shuffled to the destination vector. If
    // shuffleSliceLen > 1, we need to shuffle slices (runs of shuffleSliceLen
    // consecutive elements) instead of scalars.
    ArrayAttr mask = shuffleOp.getMask();
    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
    for (auto [i, value] :
         llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
      int64_t v = value.getZExtValue();
      std::iota(indices.begin() + shuffleSliceLen * i,
                indices.begin() + shuffleSliceLen * (i + 1),
                shuffleSliceLen * v);
    }

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ExtractOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
///   vector.extract %source [ position ]
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original extract.
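/// For example, extracting position [1] from a vector<2x4xf32> yields a
/// linearized offset of 1 * 4 = 4 and shuffle indices [4, 5, 6, 7].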
struct LinearizeVectorExtract final
    : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtract(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstTy = getTypeConverter()->convertType(extractOp.getType());
    if (!dstTy)
      return rewriter.notifyMatchFailure(extractOp,
                                         "can't convert result type.");
    if (extractOp.getVector().getType().isScalable() ||
        cast<VectorType>(dstTy).isScalable())
      return rewriter.notifyMatchFailure(extractOp,
                                         "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    // Dynamic position is not supported.
    if (extractOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(extractOp,
                                         "dynamic position is not supported.");

    llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
    int64_t size = extractOp.getVector().getType().getNumElements();

    // Compute linearized offset.
    int64_t linearizedOffset = 0;
    llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
    for (auto [i, off] : llvm::enumerate(offsets)) {
      size /= shape[i];
      linearizedOffset += off * size;
    }

    llvm::SmallVector<int64_t, 2> indices(size);
    std::iota(indices.begin(), indices.end(), linearizedOffset);
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
        rewriter.getI64ArrayAttr(indices));

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the InsertOp to a ShuffleOp that works on a
/// linearized vector.
/// Following,
///   vector.insert %source %destination [ position ]
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %destination_1d = vector.shape_cast %destination
///   %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original insert.
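/// For example, inserting a vector<2xf32> at position [1] of a
/// vector<2x2xf32> yields a linearized offset of 1 * 2 = 2 and shuffle
/// indices [0, 1, 4, 5]: indices 0-1 keep the destination's first row, while
/// 4-5 select the source (which follows the destination's 4 elements in the
/// shuffle operand order).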
struct LinearizeVectorInsert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorInsert(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType dstTy = getTypeConverter()->convertType<VectorType>(
        insertOp.getDestVectorType());
    assert(dstTy && "vector type destination expected.");
    if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalable vectors are not supported.");

    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                         targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          insertOp, "Can't flatten since targetBitWidth < OpSize");

    // Dynamic position is not supported.
    if (insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(insertOp,
                                         "dynamic position is not supported.");
    auto srcTy = insertOp.getSourceType();
    auto srcAsVec = dyn_cast<VectorType>(srcTy);
    if (!srcAsVec)
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalars are not supported.");
    uint64_t srcSize = srcAsVec.getNumElements();

    auto dstShape = insertOp.getDestVectorType().getShape();
    const auto dstSize = insertOp.getDestVectorType().getNumElements();
    auto dstSizeForOffsets = dstSize;

    // Compute the linearized offset.
    int64_t linearizedOffset = 0;
    auto offsetsNd = insertOp.getStaticPosition();
    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
      dstSizeForOffsets /= dstShape[dim];
      linearizedOffset += offset * dstSizeForOffsets;
    }

    llvm::SmallVector<int64_t, 2> indices(dstSize);
    auto origValsUntil = indices.begin();
    std::advance(origValsUntil, linearizedOffset);
    std::iota(indices.begin(), origValsUntil,
              0); // original values that remain [0, offset)
    auto newValsUntil = origValsUntil;
    std::advance(newValsUntil, srcSize);
    std::iota(origValsUntil, newValsUntil,
              dstSize); // new values [offset, offset+srcNumElements)
    std::iota(newValsUntil, indices.end(),
              linearizedOffset + srcSize); // the rest of original values
                                           // [offset+srcNumElements, end)

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
        rewriter.getI64ArrayAttr(indices));

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};
} // namespace

void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth) {

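  // Linearize n-D vector types to their flattened 1-D counterparts, e.g.
  // vector<2x4xf32> -> vector<8xf32>; types that cannot be linearized are
  // returned unchanged.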
  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (!isLinearizableVector(type))
      return type;

    return VectorType::get(type.getNumElements(), type.getElementType(),
                           type.isScalable());
  });

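  // Materialize conversions between the n-D and 1-D forms with a
  // vector.shape_cast, e.g. vector<8xf32> <-> vector<2x4xf32>.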
  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
        !isa<VectorType>(type))
      return nullptr;

    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
  };
  typeConverter.addArgumentMaterialization(materializeCast);
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if (isa<arith::ConstantOp>(op) ||
            op->hasTrait<OpTrait::Vectorizable>()) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
                      ? typeConverter.isLegal(op)
                      : true);
        }
        return std::nullopt;
      });

  patterns.add<LinearizeConstant, LinearizeVectorizable>(
      typeConverter, patterns.getContext(), targetBitWidth);
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned int targetBitWidth) {
  target.addDynamicallyLegalOp<vector::ShuffleOp>(
      [=](vector::ShuffleOp shuffleOp) -> bool {
        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
                   ? (typeConverter.isLegal(shuffleOp) &&
                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
                              .getRank() == 1)
                   : true;
      });
  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
      typeConverter, patterns.getContext(), targetBitWidth);
}