//===- VectorLinearize.cpp - vector linearization transforms --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and a pass for linearizing ND vectors into 1D
// vectors.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <limits>
#include <numeric>

using namespace mlir;

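/// Return true if every vector result of `op` has a trailing dimension
/// narrower than `targetBitWidth` bits. For example (illustrative), a result
/// of type vector<4x8xf16> has a trailing dimension of 8 * 16 = 128 bits, so
/// it passes this check only when targetBitWidth > 128.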
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
  auto resultTypes = op->getResultTypes();
  for (auto resType : resultTypes) {
    VectorType vecType = dyn_cast<VectorType>(resType);
    // Reject index since getElementTypeBitWidth will abort for Index types.
    if (!vecType || vecType.getElementType().isIndex())
      return false;
    // There are no dimensions to fold if it is a 0-D vector.
    if (vecType.getRank() == 0)
      return false;
    unsigned trailingVecDimBitWidth =
        vecType.getShape().back() * vecType.getElementTypeBitWidth();
    if (trailingVecDimBitWidth >= targetBitWidth)
      return false;
  }
  return true;
}

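/// Return true if `t` is a vector type whose trailing dimension fits within
/// `targetBitWidth` bits (compare isLessThanTargetBitWidth above, which is
/// strict and applies to all results of an op).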
static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
  VectorType vecType = dyn_cast<VectorType>(t);
  // Reject index since getElementTypeBitWidth will abort for Index types.
  if (!vecType || vecType.getElementType().isIndex())
    return false;
  // There are no dimensions to fold if it is a 0-D vector.
  if (vecType.getRank() == 0)
    return false;
  unsigned trailingVecDimBitWidth =
      vecType.getShape().back() * vecType.getElementTypeBitWidth();
  return trailingVecDimBitWidth <= targetBitWidth;
}

namespace {
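/// This pattern linearizes vector-typed arith.constant ops by reshaping the
/// underlying DenseElementsAttr to the converted 1-D vector type.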
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeConstant(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = constOp.getLoc();
    auto resType =
        getTypeConverter()->convertType<VectorType>(constOp.getType());
    if (!resType)
      return rewriter.notifyMatchFailure(loc, "can't convert return type");

    if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
      return rewriter.notifyMatchFailure(
          loc,
          "Cannot linearize a constant scalable vector that's not a splat");

    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          loc, "Can't flatten since targetBitWidth <= OpSize");
    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
    if (!dstElementsAttr)
      return rewriter.notifyMatchFailure(loc, "unsupported attr type");

    dstElementsAttr = dstElementsAttr.reshape(resType);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
                                                   dstElementsAttr);
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

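/// This pattern linearizes ops that implement the Vectorizable trait by
/// converting their vector result types to the corresponding 1-D form; the
/// conversion driver rewrites the operands via the registered
/// materializations.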
struct LinearizeVectorizable final
    : OpTraitConversionPattern<OpTrait::Vectorizable> {
  using OpTraitConversionPattern::OpTraitConversionPattern;

public:
  LinearizeVectorizable(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpTraitConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
    FailureOr<Operation *> newOp =
        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
    if (failed(newOp))
      return failure();

    rewriter.replaceOp(op, (*newOp)->getResults());
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
/// on a linearized vector.
/// For example,
///   vector.extract_strided_slice %source
///         { offsets = [..], strides = [..], sizes = [..] }
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
/// extraction.
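///
/// As a concrete (illustrative) example, consider
///   vector.extract_strided_slice %v
///       { offsets = [1, 2], sizes = [2, 2], strides = [1, 1] }
///       : vector<4x8xf32> to vector<2x2xf32>
/// Linearized, %v is a vector<32xf32> with row stride 8, so the four extracted
/// elements live at linear positions 1*8+2, 1*8+3, 2*8+2, and 2*8+3, and the
/// emitted op is
///   vector.shuffle %v_1d, %v_1d [10, 11, 18, 19]
///       : vector<32xf32>, vector<32xf32>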
struct LinearizeVectorExtractStridedSlice final
    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtractStridedSlice(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = getTypeConverter()->convertType(extractOp.getType());
    assert(!(extractOp.getVector().getType().isScalable() ||
             cast<VectorType>(dstType).isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    ArrayAttr offsets = extractOp.getOffsets();
    ArrayAttr sizes = extractOp.getSizes();
    ArrayAttr strides = extractOp.getStrides();
    if (llvm::any_of(strides, [](Attribute stride) {
          return !isConstantIntValue(stride, 1);
        }))
      return rewriter.notifyMatchFailure(
          extractOp, "Strided slice with stride != 1 is not supported.");
    Value srcVector = adaptor.getVector();
    // If kD offsets are specified for an nD source vector (n > k), the
    // granularity of the extraction is greater than 1. In this case the last
    // (n - k) dimensions form the extraction granularity.
    // Example:
    //  vector.extract_strided_slice %src {
    //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
    //      vector<4x8x8xf32> to vector<2x2x8xf32>
    // Here, the extraction granularity is 8.
    int64_t extractGranularitySize = 1;
    int64_t nD = extractOp.getSourceVectorType().getRank();
    int64_t kD = (int64_t)offsets.size();
    for (int64_t k = kD; k < nD; ++k)
      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
    // Get the total number of extracted slices.
    int64_t nExtractedSlices = 1;
    for (Attribute size : sizes) {
      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
    }
    // Compute the strides of the source vector considering the first k
    // dimensions.
    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
    for (int i = kD - 2; i >= 0; --i) {
      sourceStrides[i] = sourceStrides[i + 1] *
                         extractOp.getSourceVectorType().getShape()[i + 1];
    }
    // The final shuffle indices array has nExtractedSlices *
    // extractGranularitySize elements.
    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
                                          extractGranularitySize);
    // Compute the strides of the extracted kD vector.
    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
    for (int i = kD - 2; i >= 0; --i) {
      extractedStrides[i] =
          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
    }
    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
    // and compute the multi-dimensional index and the corresponding linearized
    // index within the source vector.
    for (int64_t i = 0; i < nExtractedSlices; ++i) {
      int64_t index = i;
      // Compute the corresponding multi-dimensional index.
      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
      for (int64_t j = 0; j < kD; ++j) {
        multiDimIndex[j] = (index / extractedStrides[j]);
        index -= multiDimIndex[j] * extractedStrides[j];
      }
      // Compute the corresponding linearized index in the source vector,
      // i.e. shift the multiDimIndex by the offsets.
      int64_t linearizedIndex = 0;
      for (int64_t j = 0; j < kD; ++j) {
        linearizedIndex +=
            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
            sourceStrides[j];
      }
      // Fill the indices array from linearizedIndex to linearizedIndex +
      // extractGranularitySize.
      for (int64_t j = 0; j < extractGranularitySize; ++j) {
        indices[i * extractGranularitySize + j] = linearizedIndex + j;
      }
    }
    // Perform a shuffle to extract the kD vector.
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstType, srcVector, srcVector,
        rewriter.getI64ArrayAttr(indices));
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// For example,
///   vector.shuffle %v1, %v2 [ shuffle_indices ]
/// is converted to:
///   %v1_1d = vector.shape_cast %v1
///   %v2_1d = vector.shape_cast %v2
///   %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
/// of the original shuffle operation.
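///
/// As an illustrative example,
///   vector.shuffle %a, %b [1, 2] : vector<2x2xf32>, vector<2x2xf32>
/// selects rows, so each mask entry stands for a slice of 2 elements. The
/// linearized form is
///   vector.shuffle %a_1d, %b_1d [2, 3, 4, 5]
///       : vector<4xf32>, vector<4xf32>
/// where indices 2 and 3 address the second row of %a, and indices 4 and 5
/// address the first row of %b.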
struct LinearizeVectorShuffle final
    : public OpConversionPattern<vector::ShuffleOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorShuffle(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
    assert(!(shuffleOp.getV1VectorType().isScalable() ||
             shuffleOp.getV2VectorType().isScalable() ||
             cast<VectorType>(dstType).isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");

    Value vec1 = adaptor.getV1();
    Value vec2 = adaptor.getV2();
    int shuffleSliceLen = 1;
    int rank = shuffleOp.getV1().getType().getRank();

    // If rank > 1, the shuffle is done at the granularity of slices instead
    // of scalars. The size of a slice is the product of the (rank - 1)
    // innermost dimensions. Each mask entry selects a slice along the
    // outermost dimension.
    if (rank > 1) {
      llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
      for (unsigned i = 1; i < shape.size(); ++i) {
        shuffleSliceLen *= shape[i];
      }
    }

    // For each value in the mask, generate the indices of the source vector
    // elements that need to be shuffled into the destination vector. If
    // shuffleSliceLen > 1, whole slices (runs of shuffleSliceLen consecutive
    // elements) are shuffled instead of scalars.
    ArrayAttr mask = shuffleOp.getMask();
    int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
    llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
    for (auto [i, value] :
         llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
      int64_t v = value.getZExtValue();
      std::iota(indices.begin() + shuffleSliceLen * i,
                indices.begin() + shuffleSliceLen * (i + 1),
                shuffleSliceLen * v);
    }

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the ExtractOp to a ShuffleOp that works on a
/// linearized vector.
/// For example,
///   vector.extract %source [ position ]
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original extract.
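///
/// As an illustrative example, extracting position [1] from a
/// vector<2x4xf32> yields a vector<4xf32> at linearized offset 1 * 4 = 4,
/// so the emitted op is
///   vector.shuffle %src_1d, %src_1d [4, 5, 6, 7]
///       : vector<8xf32>, vector<8xf32>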
struct LinearizeVectorExtract final
    : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorExtract(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstTy = getTypeConverter()->convertType(extractOp.getType());
    assert(!(extractOp.getVector().getType().isScalable() ||
             cast<VectorType>(dstTy).isScalable()) &&
           "scalable vectors are not supported.");
    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          extractOp, "Can't flatten since targetBitWidth <= OpSize");

    // Dynamic position is not supported.
    if (extractOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(extractOp,
                                         "dynamic position is not supported.");

    llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
    int64_t size = extractOp.getVector().getType().getNumElements();

    // Compute the linearized offset.
    int64_t linearizedOffset = 0;
    llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
    for (auto [i, off] : llvm::enumerate(offsets)) {
      size /= shape[i];
      linearizedOffset += off * size;
    }

    llvm::SmallVector<int64_t, 2> indices(size);
    std::iota(indices.begin(), indices.end(), linearizedOffset);
    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
        rewriter.getI64ArrayAttr(indices));

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};

/// This pattern converts the InsertOp to a ShuffleOp that works on a
/// linearized vector.
/// For example,
///   vector.insert %source, %destination [ position ]
/// is converted to:
///   %source_1d = vector.shape_cast %source
///   %destination_1d = vector.shape_cast %destination
///   %out_1d = vector.shuffle %destination_1d, %source_1d
///       [ shuffle_indices_1d ]
///   %out_nd = vector.shape_cast %out_1d
/// `shuffle_indices_1d` is computed using the position of the original insert.
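///
/// As an illustrative example, inserting a vector<4xf32> at position [1] of a
/// vector<2x4xf32> destination gives linearized offset 1 * 4 = 4, so the
/// emitted op is
///   vector.shuffle %dest_1d, %src_1d [0, 1, 2, 3, 8, 9, 10, 11]
///       : vector<8xf32>, vector<4xf32>
/// Indices below 8 keep elements of the destination; indices 8..11 pick the
/// four source elements.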
struct LinearizeVectorInsert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;
  LinearizeVectorInsert(
      const TypeConverter &typeConverter, MLIRContext *context,
      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
    assert(!(insertOp.getDestVectorType().isScalable() ||
             cast<VectorType>(dstTy).isScalable()) &&
           "scalable vectors are not supported.");

    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                         targetVectorBitWidth))
      return rewriter.notifyMatchFailure(
          insertOp, "Can't flatten since targetBitWidth < OpSize");

    // Dynamic position is not supported.
    if (insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(insertOp,
                                         "dynamic position is not supported.");

    auto srcAsVec = dyn_cast<VectorType>(insertOp.getSourceType());
    if (!srcAsVec)
      return rewriter.notifyMatchFailure(insertOp,
                                         "scalars are not supported.");
    uint64_t srcSize = srcAsVec.getNumElements();

    auto dstShape = insertOp.getDestVectorType().getShape();
    const auto dstSize = insertOp.getDestVectorType().getNumElements();
    auto dstSizeForOffsets = dstSize;

    // Compute the linearized offset.
    int64_t linearizedOffset = 0;
    auto offsetsNd = insertOp.getStaticPosition();
    for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
      dstSizeForOffsets /= dstShape[dim];
      linearizedOffset += offset * dstSizeForOffsets;
    }

    llvm::SmallVector<int64_t, 2> indices(dstSize);
    auto origValsUntil = indices.begin();
    std::advance(origValsUntil, linearizedOffset);
    std::iota(indices.begin(), origValsUntil,
              0); // Original values that remain: [0, offset).
    auto newValsUntil = origValsUntil;
    std::advance(newValsUntil, srcSize);
    std::iota(origValsUntil, newValsUntil,
              dstSize); // New values: [offset, offset + srcNumElements).
    std::iota(newValsUntil, indices.end(),
              linearizedOffset + srcSize); // The rest of the original values:
                                           // [offset + srcNumElements, end).

    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
        rewriter.getI64ArrayAttr(indices));

    return success();
  }

private:
  unsigned targetVectorBitWidth;
};
} // namespace

void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth) {

  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (!isLinearizableVector(type))
      return type;

    return VectorType::get(type.getNumElements(), type.getElementType(),
                           type.isScalable());
  });

  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
        !isa<VectorType>(type))
      return nullptr;

    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
  };
  typeConverter.addArgumentMaterialization(materializeCast);
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
  target.markUnknownOpDynamicallyLegal(
      [=](Operation *op) -> std::optional<bool> {
        if (isa<arith::ConstantOp>(op) ||
            op->hasTrait<OpTrait::Vectorizable>()) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
                      ? typeConverter.isLegal(op)
                      : true);
        }
        return std::nullopt;
      });

  patterns.add<LinearizeConstant, LinearizeVectorizable>(
      typeConverter, patterns.getContext(), targetBitWidth);
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned int targetBitWidth) {
  target.addDynamicallyLegalOp<vector::ShuffleOp>(
      [=](vector::ShuffleOp shuffleOp) -> bool {
        return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
                   ? (typeConverter.isLegal(shuffleOp) &&
                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
                              .getRank() == 1)
                   : true;
      });
  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
      typeConverter, patterns.getContext(), targetBitWidth);
}
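
// A minimal usage sketch (illustrative, not part of this file's API surface):
// a pass would typically register both entry points against one converter and
// target, then run a partial conversion. The surrounding pass boilerplate
// below is hypothetical.
//
//   void runOnOperation() {
//     MLIRContext &ctx = getContext();
//     TypeConverter typeConverter;
//     RewritePatternSet patterns(&ctx);
//     ConversionTarget target(ctx);
//     vector::populateVectorLinearizeTypeConversionsAndLegality(
//         typeConverter, patterns, target, /*targetBitWidth=*/512);
//     vector::populateVectorLinearizeShuffleLikeOpsPatterns(
//         typeConverter, patterns, target, /*targetBitWidth=*/512);
//     if (failed(applyPartialConversion(getOperation(), target,
//                                       std::move(patterns))))
//       signalPassFailure();
//   }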