1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "mlir/Dialect/Affine/Utils.h"
13 
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tensor/Utils/Utils.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypeInterfaces.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Support/LLVM.h"
35 #include "mlir/Transforms/RegionUtils.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 #include <type_traits>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 
55 /// Try to vectorize `convOp` as a convolution.
56 static FailureOr<Operation *>
57 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58                      ArrayRef<int64_t> inputVecSizes = {},
59                      ArrayRef<bool> inputVecScalableFlags = {},
60                      bool flatten1DDepthwiseConv = false);
61 
62 /// Return the unique instance of OpType in `block` if it is indeed unique.
63 /// Return null if zero or more than one instance exists.
64 template <typename OpType>
65 static OpType getSingleOpOfType(Block &block) {
66   OpType res;
67   block.walk([&](OpType op) {
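    // A second instance was found: invalidate the result and stop the walk.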
68     if (res) {
69       res = nullptr;
70       return WalkResult::interrupt();
71     }
72     res = op;
73     return WalkResult::advance();
74   });
75   return res;
76 }
77 
78 /// Helper function to extract the input slices after the filter is unrolled
79 /// along kw.
80 static SmallVector<Value>
81 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
82                        int64_t nSize, int64_t wSize, int64_t cSize,
83                        int64_t kwSize, int strideW, int dilationW,
84                        int64_t wSizeStep, bool isSingleChanneled) {
85   SmallVector<Value> result;
86   if (isSingleChanneled) {
87     // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
88     // convolution.
89     SmallVector<int64_t> sizes{wSizeStep};
90     SmallVector<int64_t> strides{1};
91     for (int64_t kw = 0; kw < kwSize; ++kw) {
92       for (int64_t w = 0; w < wSize; w += wSizeStep) {
93         result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
94             loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
95       }
96     }
97   } else {
98     // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
99     // for channeled convolution.
100     SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
101     SmallVector<int64_t> strides{1, 1, 1};
102     for (int64_t kw = 0; kw < kwSize; ++kw) {
103       for (int64_t w = 0; w < wSize; w += wSizeStep) {
104         result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
105             loc, input,
106             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
107             sizes, strides));
108       }
109     }
110   }
111   return result;
112 }
113 
114 /// Helper function to extract the filter slices after the filter is unrolled
115 /// along kw.
116 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
117                                                   Location loc, Value filter,
118                                                   int64_t kwSize) {
119   SmallVector<Value> result;
120   // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
121   // non-channeled convolution] @ [kw].
122   for (int64_t kw = 0; kw < kwSize; ++kw) {
123     result.push_back(rewriter.create<vector::ExtractOp>(
124         loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
125   }
126   return result;
127 }
128 
129 /// Helper function to extract the result slices after the filter is unrolled
130 /// along kw.
131 static SmallVector<Value>
132 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
133                         int64_t nSize, int64_t wSize, int64_t fSize,
134                         int64_t wSizeStep, bool isSingleChanneled) {
135   SmallVector<Value> result;
136   if (isSingleChanneled) {
137     // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
138     SmallVector<int64_t> sizes{wSizeStep};
139     SmallVector<int64_t> strides{1};
140     for (int64_t w = 0; w < wSize; w += wSizeStep) {
141       result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
142           loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
143     }
144   } else {
145     // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
146     // convolution.
147     SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
148     SmallVector<int64_t> strides{1, 1, 1};
149     for (int64_t w = 0; w < wSize; w += wSizeStep) {
150       result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
151           loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
152     }
153   }
154   return result;
155 }
156 
157 /// Helper function to insert the computed result slices.
158 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
159                                     Value res, int64_t wSize, int64_t wSizeStep,
160                                     SmallVectorImpl<Value> &resVals,
161                                     bool isSingleChanneled) {
162 
163   if (isSingleChanneled) {
164     // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
165     // This does not depend on kw.
166     SmallVector<int64_t> strides{1};
167     for (int64_t w = 0; w < wSize; w += wSizeStep) {
168       res = rewriter.create<vector::InsertStridedSliceOp>(
169           loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
170     }
171   } else {
172     // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
173     // convolution. This does not depend on kw.
174     SmallVector<int64_t> strides{1, 1, 1};
175     for (int64_t w = 0; w < wSize; w += wSizeStep) {
176       res = rewriter.create<vector::InsertStridedSliceOp>(
177           loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
178           strides);
179     }
180   }
181   return res;
182 }
183 
184 /// Contains the vectorization state and related methods used across the
185 /// vectorization process of a given operation.
186 struct VectorizationState {
187   VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
188 
189   /// Initializes the vectorization state, including the computation of the
190   /// canonical vector shape for vectorization.
191   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
192                           ArrayRef<int64_t> inputVectorSizes,
193                           ArrayRef<bool> inputScalableVecDims);
194 
195   /// Returns the canonical vector shape used to vectorize the iteration space.
196   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
197 
198   /// Returns the vector dimensions that are scalable in the canonical vector
199   /// shape.
200   ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
201 
202   /// Returns a vector type of the provided `elementType` with the canonical
203   /// vector shape and the corresponding fixed/scalable dimensions bit. If
204   /// `dimPermutation` is provided, the canonical vector dimensions are permuted
205   /// accordingly.
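  /// For example (illustrative): with a canonical vector shape of [8, 16] and a
  /// permutation map (d0, d1) -> (d1, d0), this returns vector<16x8xf32> for an
  /// f32 element type.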
206   VectorType getCanonicalVecType(
207       Type elementType,
208       std::optional<AffineMap> dimPermutation = std::nullopt) const {
209     SmallVector<int64_t> vectorShape;
210     SmallVector<bool> scalableDims;
211     if (dimPermutation.has_value()) {
212       vectorShape =
213           applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
214       scalableDims =
215           applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
216     } else {
217       vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
218       scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
219     }
220 
221     return VectorType::get(vectorShape, elementType, scalableDims);
222   }
223 
224   /// Masks an operation with the canonical vector mask if the operation needs
225   /// masking. Returns the masked operation or the original operation if masking
226   /// is not needed. If provided, the canonical mask for this operation is
227   /// permuted using `maybeIndexingMap`.
228   Operation *
229   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
231 
232 private:
233   /// Initializes the iteration space static sizes using the Linalg op
234   /// information. This may become more complicated in the future.
235   void initIterSpaceStaticSizes(LinalgOp linalgOp) {
236     iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
237   }
238 
239   /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
240   /// all the static and dynamic dimensions of the iteration space to be
241   /// vectorized and store them in `iterSpaceValueSizes`.
242   LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
243                                               LinalgOp linalgOp);
244 
245   /// Create or retrieve an existing mask value to mask `opToMask` in the
246   /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
247   /// mask is permuted using that permutation map. If a new mask is created,
248   /// it will be cached for future users.
249   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
250                            LinalgOp linalgOp,
251                            std::optional<AffineMap> maybeMaskingMap);
252 
253   /// Check whether this permutation map can be used for masking. At the
254   /// moment we only make sure that there are no broadcast dimensions, but this
255   /// might change if indexing maps evolve.
256   bool isValidMaskingMap(AffineMap maskingMap) {
257     return maskingMap.getBroadcastDims().size() == 0;
258   }
259 
260   /// Turn the input indexing map into a valid masking map.
261   ///
262   /// The input indexing map may contain "zero" results, e.g.:
263   ///    (d0, d1, d2, d3) -> (d2, d1, d0, 0)
264   /// Applying such maps to canonical vector shapes like this one:
265   ///    (1, 16, 16, 4)
266   /// would yield an invalid vector shape like this:
267   ///    (16, 16, 1, 0)
268   /// Instead, drop the broadcasting dims that make no sense for masking perm.
269   /// maps:
270   ///    (d0, d1, d2, d3) -> (d2, d1, d0)
271   /// This way, the corresponding vector/mask type will be:
272   ///    vector<16x16x1xty>
273   /// rather than this invalid Vector type:
274   ///    vector<16x16x1x0xty>
275   AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
276     return indexingMap.dropZeroResults();
277   }
278 
279   // Holds the compile-time static sizes of the iteration space to vectorize.
280   // Dynamic dimensions are represented using ShapedType::kDynamic.
281   SmallVector<int64_t> iterSpaceStaticSizes;
282 
283   /// Holds the value sizes of the iteration space to vectorize. Static
284   /// dimensions are represented by 'arith.constant' and dynamic
285   /// dimensions by 'tensor/memref.dim'.
286   SmallVector<Value> iterSpaceValueSizes;
287 
288   /// Holds the canonical vector shape used to vectorize the iteration space.
289   SmallVector<int64_t> canonicalVecShape;
290 
291   /// Holds the vector dimensions that are scalable in the canonical vector
292   /// shape.
293   SmallVector<bool> scalableVecDims;
294 
295   /// Holds the active masks for permutations of the canonical vector iteration
296   /// space.
297   DenseMap<AffineMap, Value> activeMaskCache;
298 
299   /// Global vectorization guard for the incoming rewriter. It's initialized
300   /// when the vectorization state is initialized.
301   OpBuilder::InsertionGuard rewriterGuard;
302 };
303 
304 LogicalResult
305 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
306                                                   LinalgOp linalgOp) {
307   // TODO: Support 0-d vectors.
308   for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
309     if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
310       // Create constant index op for static dimensions.
311       iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
312           linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
313       continue;
314     }
315 
316     // Find an operand defined on this dimension of the iteration space to
317     // extract the runtime dimension size.
318     Value operand;
319     unsigned operandDimPos;
320     if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
321                                                          operandDimPos)))
322       return failure();
323 
324     Value dynamicDim = linalgOp.hasPureTensorSemantics()
325                            ? (Value)rewriter.create<tensor::DimOp>(
326                                  linalgOp.getLoc(), operand, operandDimPos)
327                            : (Value)rewriter.create<memref::DimOp>(
328                                  linalgOp.getLoc(), operand, operandDimPos);
329     iterSpaceValueSizes.push_back(dynamicDim);
330   }
331 
332   return success();
333 }
334 
335 /// Initializes the vectorization state, including the computation of the
336 /// canonical vector shape for vectorization.
337 // TODO: Move this to the constructor when we can remove the failure cases.
338 LogicalResult
339 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
340                               ArrayRef<int64_t> inputVectorSizes,
341                               ArrayRef<bool> inputScalableVecDims) {
342   // Initialize the insertion point.
343   rewriter.setInsertionPoint(linalgOp);
344 
345   if (!inputVectorSizes.empty()) {
346     // Get the canonical vector shape from the input vector sizes provided. This
347     // path should be taken to vectorize code with dynamic shapes and when using
348     // vector sizes greater than the iteration space sizes.
349     canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
350     scalableVecDims.append(inputScalableVecDims.begin(),
351                            inputScalableVecDims.end());
352   } else {
353     // Compute the canonical vector shape from the operation shape. If there are
354     // dynamic shapes, the operation won't be vectorized. We assume all the
355     // vector dimensions are fixed.
356     canonicalVecShape = linalgOp.getStaticLoopRanges();
357     scalableVecDims.append(linalgOp.getNumLoops(), false);
358   }
359 
360   LDBG("Canonical vector shape: ");
361   LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
362   LLVM_DEBUG(llvm::dbgs() << "\n");
363   LDBG("Scalable vector dims: ");
364   LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
365   LLVM_DEBUG(llvm::dbgs() << "\n");
366 
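  // Bail out if the canonical vector shape has dynamic dims, i.e. the loop
  // ranges are dynamic and no vector sizes were provided.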
367   if (ShapedType::isDynamicShape(canonicalVecShape))
368     return failure();
369 
370   // Initialize iteration space static sizes.
371   initIterSpaceStaticSizes(linalgOp);
372 
373   // Generate 'arith.constant' and 'tensor/memref.dim' operations for
374   // all the static and dynamic dimensions of the iteration space, needed to
375   // compute a mask during vectorization.
376   if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
377     return failure();
378 
379   return success();
380 }
381 
382 /// Create or retrieve an existing mask value to mask `opToMask` in the
383 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
384 /// mask is permuted using that permutation map. If a new mask is created, it
385 /// will be cached for future users.
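/// For example (illustrative): with a canonical vector shape of [8, 16], a
/// masking map (d0, d1) -> (d0, d1), and iteration-space sizes (%dyn, 16) where
/// the first dim is dynamic, this creates:
///   %mask = vector.create_mask %dyn, %c16 : vector<8x16xi1>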
386 Value VectorizationState::getOrCreateMaskFor(
387     RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
388     std::optional<AffineMap> maybeMaskingMap) {
389 
390   assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
391          "Ill-formed masking map.");
392 
393   // No mask is needed if the operation is not maskable.
394   auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
395   if (!maskableOp)
396     return Value();
397 
398   assert(!maskableOp.isMasked() &&
399          "Masking an operation that is already masked");
400 
401   // If no masking map was provided, use an identity map with the loop dims.
402   assert((!maybeMaskingMap || *maybeMaskingMap) &&
403          "Unexpected null mask permutation map");
404   AffineMap maskingMap =
405       maybeMaskingMap ? *maybeMaskingMap
406                       : AffineMap::getMultiDimIdentityMap(
407                             linalgOp.getNumLoops(), rewriter.getContext());
408 
409   LDBG("Masking map: " << maskingMap << "\n");
410 
411   // Return the active mask for the masking map of this operation if it was
412   // already created.
413   auto activeMaskIt = activeMaskCache.find(maskingMap);
414   if (activeMaskIt != activeMaskCache.end()) {
415     Value mask = activeMaskIt->second;
416     LDBG("Reusing mask: " << mask << "\n");
417     return mask;
418   }
419 
420   // Compute permuted projection of the iteration space to be masked and the
421   // corresponding mask shape. If the resulting iteration space dimensions are
422   // static and identical to the mask shape, masking is not needed for this
423   // operation.
424   // TODO: Improve this check. Only projected permutation indexing maps are
425   // supported.
426   SmallVector<int64_t> permutedStaticSizes =
427       applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
428   auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
429   auto maskShape = maskType.getShape();
430 
431   LDBG("Mask shape: ");
432   LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
433   LLVM_DEBUG(llvm::dbgs() << "\n");
434 
435   if (permutedStaticSizes == maskShape) {
436     LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
437     activeMaskCache[maskingMap] = Value();
438     return Value();
439   }
440 
441   // Permute the iteration space value sizes to compute the mask upper bounds.
442   SmallVector<Value> upperBounds =
443       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
444   assert(!maskShape.empty() && !upperBounds.empty() &&
445          "Masked 0-d vectors are not supported yet");
446 
447   // Create the mask based on the dimension values.
448   Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
449                                                      maskType, upperBounds);
450   LDBG("Creating new mask: " << mask << "\n");
451   activeMaskCache[maskingMap] = mask;
452   return mask;
453 }
454 
455 Operation *
456 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
457                                   LinalgOp linalgOp,
458                                   std::optional<AffineMap> maybeIndexingMap) {
459   LDBG("Trying to mask: " << *opToMask << "\n");
460 
461   std::optional<AffineMap> maybeMaskingMap = std::nullopt;
462   if (maybeIndexingMap)
463     maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
464 
465   // Create or retrieve mask for this operation.
466   Value mask =
467       getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
468 
469   if (!mask) {
470     LDBG("No mask required\n");
471     return opToMask;
472   }
473 
474   // Wrap the operation with a new `vector.mask` and update D-U chain.
475   assert(opToMask && "Expected a valid operation to mask");
476   auto maskOp = cast<vector::MaskOp>(
477       mlir::vector::maskOperation(rewriter, opToMask, mask));
478   Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
479 
480   for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
481     rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
482                                   maskOpTerminator);
483 
484   LDBG("Masked operation: " << *maskOp << "\n");
485   return maskOp;
486 }
487 
488 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
489 /// projectedPermutation, compress the unused dimensions to serve as a
490 /// permutation_map for a vector transfer operation.
491 /// For example, given a linalg op such as:
492 ///
493 /// ```
494 ///   %0 = linalg.generic {
495 ///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
496 ///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
497 ///      }
498 ///     ins(%0 : tensor<2x3x4xf32>)
499 ///    outs(%1 : tensor<5x6xf32>)
500 /// ```
501 ///
502 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
503 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
504 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
505 static AffineMap reindexIndexingMap(AffineMap map) {
506   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
507          "expected projected permutation");
508   auto res = compressUnusedDims(map);
509   assert(res.getNumDims() ==
510              (res.getNumResults() - res.getNumOfZeroResults()) &&
511          "expected reindexed map with same number of dims and results");
512   return res;
513 }
514 
515 /// Helper enum to represent conv1d input traversal order.
516 enum class Conv1DOpOrder {
517   W,   // Corresponds to non-channeled 1D convolution operation.
518   Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
519   Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
520 };
521 
522 /// Helper data structure to represent the result of vectorization.
523 /// In certain specific cases, like terminators, we do not want to propagate/replace results.
524 enum VectorizationStatus {
525   /// Op failed to vectorize.
526   Failure = 0,
527   /// Op vectorized and the custom function took care of the replacement logic.
528   NoReplace,
529   /// Op vectorized into a new Op whose results will replace original Op's
530   /// results.
531   NewOp
532   // TODO: support values if Op vectorized to Many-Ops whose results we need to
533   // aggregate for replacement.
534 };
535 struct VectorizationResult {
536   /// Return status from vectorizing the current op.
537   enum VectorizationStatus status = VectorizationStatus::Failure;
538   /// New vectorized operation to replace the current op.
539   /// Replacement behavior is specified by `status`.
540   Operation *newOp;
541 };
542 
543 std::optional<vector::CombiningKind>
544 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
545   using ::mlir::vector::CombiningKind;
546 
547   if (!combinerOp)
548     return std::nullopt;
549   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
550       .Case<arith::AddIOp, arith::AddFOp>(
551           [&](auto op) { return CombiningKind::ADD; })
552       .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
553       .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
554       .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
555       .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
556       .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
557       .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
558       .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
559       .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
560       .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
561       .Case<arith::MulIOp, arith::MulFOp>(
562           [&](auto op) { return CombiningKind::MUL; })
563       .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
564       .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
565       .Default([&](auto op) { return std::nullopt; });
566 }
567 
568 /// Check whether `outputOperand` is a reduction with a single combiner
569 /// operation. Return the combiner operation of the reduction. Return
570 /// nullptr otherwise. Multiple reduction operations would impose an
571 /// ordering between reduction dimensions and this is currently unsupported in
572 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
573 /// max(min(X))
574 // TODO: use in LinalgOp verification, there is a circular dependency atm.
575 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
576   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
577   unsigned outputPos =
578       outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
579   // Only single combiner operations are supported for now.
580   SmallVector<Operation *, 4> combinerOps;
581   if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
582       combinerOps.size() != 1)
583     return nullptr;
584 
585   // Return the combiner operation.
586   return combinerOps[0];
587 }
588 
589 /// Broadcast `value` to a vector of type `dstType` if possible. Return `value`
590 /// otherwise.
591 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
592   auto dstVecType = dyn_cast<VectorType>(dstType);
593   // If no shape to broadcast to, just return `value`.
594   if (dstVecType.getRank() == 0)
595     return value;
596   if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
597       vector::BroadcastableToResult::Success)
598     return value;
599   Location loc = b.getInsertionPoint()->getLoc();
600   return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
601 }
602 
603 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
604 /// assumes that `reductionOp` has two operands and one of them is the reduction
605 /// initial value.
606 // Note: this is a true builder that notifies the OpBuilder listener.
607 // TODO: Consider moving as a static helper on the ReduceOp.
608 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
609                                       Value valueToReduce, Value acc,
610                                       ArrayRef<bool> dimsToMask) {
611   auto maybeKind = getCombinerOpKind(reduceOp);
612   assert(maybeKind && "Failed precondition: could not get reduction kind");
613   return b.create<vector::MultiDimReductionOp>(
614       reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
615 }
616 
617 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
618   return llvm::to_vector(
619       llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
620 }
621 
622 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
623 /// reduction iterator.
624 static bool hasReductionIterator(LinalgOp &op) {
625   return isa<linalg::ReduceOp>(op) ||
626          (isa<linalg::GenericOp>(op) &&
627           llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
628 }
629 
630 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
631 /// to all `0`, where `outputOperand` is an output operand of the LinalgOp
632 /// currently being vectorized. For the rank-0 case, a 0-d transfer write is built.
633 /// Return the produced value or null if no value is produced.
634 // Note: this is a true builder that notifies the OpBuilder listener.
635 // TODO: Consider moving as a static helper on the ReduceOp.
636 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
637                               OpOperand *outputOperand,
638                               VectorizationState &state) {
639   Location loc = value.getLoc();
640   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
641   AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
642 
643   // Compute the vector type of the value to store. This type should be an
644   // identity or projection of the canonical vector type without any permutation
645   // applied, given that any permutation in a transfer write happens as part of
646   // the write itself.
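  // E.g. (illustrative): for an output indexing map (d0, d1, d2) -> (d2, d0),
  // the filtered identity map is (d0, d1, d2) -> (d0, d2).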
647   AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
648       opOperandMap.getContext(), opOperandMap.getNumInputs(),
649       [&](AffineDimExpr dimExpr) -> bool {
650         return llvm::is_contained(opOperandMap.getResults(), dimExpr);
651       });
652   auto vectorType = state.getCanonicalVecType(
653       getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
654 
655   Operation *write;
656   if (vectorType.getRank() > 0) {
657     AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
658     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
659                                rewriter.create<arith::ConstantIndexOp>(loc, 0));
660     value = broadcastIfNeeded(rewriter, value, vectorType);
661     assert(value.getType() == vectorType && "Incorrect type");
662     write = rewriter.create<vector::TransferWriteOp>(
663         loc, value, outputOperand->get(), indices, writeMap);
664   } else {
665     // 0-d case is still special: do not invert the reindexing writeMap.
666     if (!isa<VectorType>(value.getType()))
667       value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
668     assert(value.getType() == vectorType && "Incorrect type");
669     write = rewriter.create<vector::TransferWriteOp>(
670         loc, value, outputOperand->get(), ValueRange{});
671   }
672 
673   write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
674 
675   // If masked, set in-bounds to true. Masking guarantees that the access will
676   // be in-bounds.
677   if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
678     auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
679     SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
680     maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
681   }
682 
683   LDBG("vectorized op: " << *write << "\n");
684   if (!write->getResults().empty())
685     return write->getResult(0);
686   return Value();
687 }
688 
689 // Custom vectorization precondition function type. This is intended to be used
690 // with CustomVectorizationHook. Returns success if the corresponding custom
691 // hook can vectorize the op.
692 using CustomVectorizationPrecondition =
693     std::function<LogicalResult(Operation *, bool)>;
694 
695 // Custom vectorization function type. Produce a vector form of Operation*
696 // assuming all its vectorized operands are already in the IRMapping.
697 // Return nullptr if the Operation cannot be vectorized.
698 using CustomVectorizationHook =
699     std::function<VectorizationResult(Operation *, const IRMapping &)>;
700 
701 /// Helper function to vectorize the terminator of a `linalgOp`. New result
702 /// vector values are appended to `newResults`. Return
703 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
704 /// should not try to map produced operations and instead return the results
705 /// using the `newResults` vector, making them available to the vectorization
706 /// algorithm for RAUW. This function is meant to be used as a
707 /// CustomVectorizationHook.
708 static VectorizationResult
709 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
710                      const IRMapping &bvm, VectorizationState &state,
711                      LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
712   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
713   if (!yieldOp)
714     return VectorizationResult{VectorizationStatus::Failure, nullptr};
715   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
716     // TODO: Scan for an opportunity for reuse.
717     // TODO: use a map.
718     Value vectorValue = bvm.lookup(output.value());
719     Value newResult =
720         buildVectorWrite(rewriter, vectorValue,
721                          linalgOp.getDpsInitOperand(output.index()), state);
722     if (newResult)
723       newResults.push_back(newResult);
724   }
725 
726   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
727 }
728 
729 /// Helper function to vectorize the index operations of a `linalgOp`. Return
730 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
731 /// should map the produced operations. This function is meant to be used as a
732 /// CustomVectorizationHook.
733 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
734                                                 VectorizationState &state,
735                                                 Operation *op,
736                                                 LinalgOp linalgOp) {
737   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
738   if (!indexOp)
739     return VectorizationResult{VectorizationStatus::Failure, nullptr};
740   auto loc = indexOp.getLoc();
741   // Compute the static loop sizes of the index op.
742   ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
743   auto dim = indexOp.getDim();
744   // Compute a one-dimensional index vector for the index op dimension.
745   auto indexVectorType =
746       VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
747                       state.getScalableVecDims()[dim]);
748   auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
749   // Return the one-dimensional index vector if it lives in the trailing
750   // dimension of the iteration space since the vectorization algorithm in this
751   // case can handle the broadcast.
752   if (dim == targetShape.size() - 1)
753     return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
754   // Otherwise permute the targetShape to move the index dimension last,
755   // broadcast the one-dimensional index vector to the permuted shape, and
756   // finally transpose the broadcasted index vector to undo the permutation.
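  // E.g. (illustrative): for a canonical vector shape of [2, 8, 4] and dim = 1,
  // the vector<8xindex> step vector is broadcast to vector<2x4x8xindex> and then
  // transposed back to vector<2x8x4xindex>.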
757   auto permPattern =
758       llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
759   std::swap(permPattern[dim], permPattern.back());
760   auto permMap =
761       AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
762 
763   auto broadCastOp = rewriter.create<vector::BroadcastOp>(
764       loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
765       indexSteps);
766   SmallVector<int64_t> transposition =
767       llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
768   std::swap(transposition.back(), transposition[dim]);
769   auto transposeOp =
770       rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
771   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
772 }
773 
774 /// Helper function to check if the tensor.extract can be vectorized by the
775 /// custom hook vectorizeTensorExtract.
776 static LogicalResult
777 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
778   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
779   if (!extractOp)
780     return failure();
781 
782   if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
783     return failure();
784 
785   // Check the index type, but only for non 0-d tensors (for which we do need
786   // access indices).
787   if (not extractOp.getIndices().empty()) {
788     if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
789       return failure();
790   }
791 
792   if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
793         return !VectorType::isValidElementType(type);
794       })) {
795     return failure();
796   }
797 
798   return success();
799 }
800 
801 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
802 /// generated from `tensor.extract`. The offset is calculated as follows
803 /// (example using scalar values):
804 ///
805 ///    offset = extractOp.indices[0]
806 ///    for (i = 1; i < numIndices; i++)
807 ///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
808 ///
809 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
810 ///  offset = ( ( 1 ) * 80 +  2 ) * 15  + 3
811 static Value calculateGatherOffset(RewriterBase &rewriter,
812                                    VectorizationState &state,
813                                    tensor::ExtractOp extractOp,
814                                    const IRMapping &bvm) {
815   // The vector of indices for GatherOp should be shaped as the output vector.
816   auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
817   auto loc = extractOp.getLoc();
818 
819   Value offset = broadcastIfNeeded(
820       rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
821 
822   const size_t numIndices = extractOp.getIndices().size();
823   for (size_t i = 1; i < numIndices; i++) {
824     Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
825 
826     auto dimSize = broadcastIfNeeded(
827         rewriter,
828         rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
829         indexVecType);
830 
831     offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
832 
833     auto extractOpIndex = broadcastIfNeeded(
834         rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
835 
836     offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
837   }
838 
839   return offset;
840 }
841 
842 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
843 
844 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
845 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
846 /// represents a contiguous load operation.
847 ///
848 /// Note that when calling this hook, it is assumed that the output vector is
849 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
850 /// labelled as a gather load before entering this method.
851 ///
852 /// Following on from the above, it is assumed that:
853 ///   * for statically shaped loops, when no masks are used, only one dim is !=
854 ///   1 (that's what the shape of the output vector is based on).
855 ///   * for dynamically shaped loops, there might be more non-unit dims
856 ///   as the output vector type is user-specified.
857 ///
858 /// TODO: Statically shaped loops + vector masking
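///
/// For example (illustrative): for static loop ranges [1, 16, 1], this returns
/// index 1.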
859 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
860   SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
861   assert(
862       (linalgOp.hasDynamicShape() ||
863        llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
864       "For statically shaped Linalg Ops, only one "
865       "non-unit loop dim is expected");
866   assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
867 
868   size_t idx = loopRanges.size() - 1;
869   for (; idx != 0; idx--)
870     if (loopRanges[idx] != 1)
871       break;
872 
873   return idx;
874 }
875 
876 /// Checks whether `val` can be used for calculating a loop invariant index.
877 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
878                                VectorType resType) {
879 
880   assert(((llvm::count_if(resType.getShape(),
881                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
882          "n-D vectors are not yet supported");
883 
884   // Block arguments from outside _this_ linalg.generic are loop invariant.
885   // However, analysing block arguments for _this_ linalg.generic Op is a bit
886   // tricky. Just bail out in the latter case.
887   // TODO: We could try analysing the corresponding affine map here.
888   auto *block = linalgOp.getBlock();
889   if (isa<BlockArgument>(val))
890     return llvm::all_of(block->getArguments(),
891                         [&val](Value v) { return (v != val); });
892 
893   Operation *defOp = val.getDefiningOp();
894   assert(defOp && "This is neither a block argument nor an operation result");
895 
896   // IndexOp is loop invariant as long as its result remains constant across
897   // iterations. Note that for dynamic shapes, the corresponding dim will also
898   // be conservatively treated as != 1.
899   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
900     return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
901   }
902 
903   auto *ancestor = block->findAncestorOpInBlock(*defOp);
904 
905   // Values defined outside `linalgOp` are loop invariant.
906   if (!ancestor)
907     return true;
908 
909   // Values defined inside `linalgOp`, which are constant, are loop invariant.
910   if (isa<arith::ConstantOp>(ancestor))
911     return true;
912 
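  // Otherwise, conservatively require all operands of the defining op to be
  // loop invariant themselves.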
913   bool result = true;
914   for (auto op : ancestor->getOperands())
915     result &= isLoopInvariantIdx(linalgOp, op, resType);
916 
917   return result;
918 }
919 
920 /// Check whether `val` could be used for calculating the trailing index for a
921 /// contiguous load operation.
922 ///
923 /// There are currently 3 types of values that are allowed here:
924 ///   1. loop-invariant values,
925 ///   2. values that increment by 1 with every loop iteration,
926 ///   3. results of basic arithmetic operations (linear and continuous)
927 ///      involving 1., 2. and 3.
928 /// This method returns true if indeed only such values are used in calculating
929 /// `val`.
930 ///
931 /// Additionally, the trailing index for a contiguous load operation should
932 /// increment by 1 with every loop iteration, i.e. be based on:
933 ///   * `linalg.index <dim>` ,
934 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
935 /// `linalg.index <dim>` increments by 1 with every loop iteration).
936 /// `foundIndexOp` is updated to `true` when such Op is found.
937 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
938                                 bool &foundIndexOp, VectorType resType) {
939 
940   assert(((llvm::count_if(resType.getShape(),
941                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
942          "n-D vectors are not yet supported");
943 
944   // Block arguments from outside _this_ linalg.generic are loop invariant.
945   // However, analysing block arguments for _this_ linalg.generic Op is a bit
946   // tricky. Just bail out in the latter case.
947   // TODO: We could try analysing the corresponding affine map here.
948   auto *block = linalgOp.getBlock();
949   if (isa<BlockArgument>(val))
950     return llvm::all_of(block->getArguments(),
951                         [&val](Value v) { return (v != val); });
952 
953   Operation *defOp = val.getDefiningOp();
954   assert(defOp && "This is neither a block argument nor an operation result");
955 
956   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
957     auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
958 
959     foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
960     return true;
961   }
962 
963   auto *ancestor = block->findAncestorOpInBlock(*defOp);
964 
965   if (!ancestor)
966     return false;
967 
968   // Conservatively reject Ops that could lead to indices with stride other
969   // than 1.
970   if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
971     return false;
972 
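  // Recurse into the operands of the (already validated) defining op: at least
  // one of them must trace back to a suitable index computation.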
973   bool result = false;
974   for (auto op : ancestor->getOperands())
975     result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
976 
977   return result;
978 }
979 
980 /// Infer the memory access pattern for the input ExtractOp
981 ///
982 /// Based on the ExtractOp result shape and the access indices, decides whether
983 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
984 /// or a gather load. When analysing the ExtractOp indices (to identify
985 /// contiguous loads), this method looks for "loop" invariant indices (e.g.
986 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
987 /// Op).
988 ///
989 /// Note that it is always safe to use gather load operations for contiguous
990 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
991 /// that `extractOp` is a gather load.
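///
/// For example (illustrative): inside a linalg.generic whose trailing loop dim
/// d1 is non-unit, `tensor.extract %src[%c0, %i1]` with `%i1 = linalg.index 1`
/// is treated as a contiguous load, whereas swapping the indices
/// (`%src[%i1, %c0]`) results in a gather load.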
992 static VectorMemoryAccessKind
993 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
994                                     LinalgOp &linalgOp, VectorType resType) {
995 
996   auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
997 
998   // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
999   if (inputShape.getShape().empty())
1000     return VectorMemoryAccessKind::ScalarBroadcast;
1001 
1002   // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1003   // otherwise.
1004   bool isOutput1DVector =
1005       (llvm::count_if(resType.getShape(),
1006                       [](int64_t dimSize) { return dimSize > 1; }) == 1);
1007   // 1. Assume that it's a gather load when reading non-1D vector.
1008   if (!isOutput1DVector)
1009     return VectorMemoryAccessKind::Gather;
1010 
1011   bool leadingIdxsLoopInvariant = true;
1012 
1013   // 2. Analyze the leading indices of `extractOp`.
1014   // Look at the way each index is calculated and decide whether it is suitable
1015   // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1016   // gather load.
1017   auto indices = extractOp.getIndices();
1018   auto leadIndices = indices.drop_back(1);
1019 
1020   for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1021     if (inputShape.getShape()[i] == 1)
1022       continue;
1023 
1024     leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1025   }
1026 
1027   if (!leadingIdxsLoopInvariant) {
1028     LDBG("Found gather load: " << extractOp);
1029     return VectorMemoryAccessKind::Gather;
1030   }
1031 
1032   // 3. Analyze the trailing index for `extractOp`.
1033   // At this point we know that the leading indices are loop invariant. This
1034   // means that it is potentially a scalar or a contiguous load. We can decide
1035   // based on the trailing idx.
1036   auto extractOpTrailingIdx = indices.back();
1037 
1038   // 3a. Scalar broadcast load
1039   // If the trailing index is loop invariant then this is a scalar load.
1040   if (leadingIdxsLoopInvariant &&
1041       isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1042     LDBG("Found scalar broadcast load: " << extractOp);
1043 
1044     return VectorMemoryAccessKind::ScalarBroadcast;
1045   }
1046 
1047   // 3b. Contiguous loads
1048   // The trailing `extractOp` index should increment with every loop iteration.
1049   // This effectively means that it must be based on the trailing loop index.
1050   // This is what the following bool captures.
1051   bool foundIndexOp = false;
1052   bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1053                                               foundIndexOp, resType);
1054   // TODO: Support generating contiguous loads for column vectors - that will
1055   // require adding a permutation map to transfer_read Ops.
1056   bool isRowVector = resType.getShape().back() != 1;
1057   isContiguousLoad &= (foundIndexOp && isRowVector);
1058 
1059   if (isContiguousLoad) {
1060     LDBG("Found contigous load: " << extractOp);
1061     return VectorMemoryAccessKind::Contiguous;
1062   }
1063 
1064   // 4. Fallback case - gather load.
1065   LDBG("Found gather load: " << extractOp);
1066   return VectorMemoryAccessKind::Gather;
1067 }
1068 
1069 /// Helper function to vectorize the tensor.extract operations. Returns
1070 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1071 /// should map the produced operations. This function is meant to be used as a
1072 /// CustomVectorizationHook.
1073 static VectorizationResult
1074 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1075                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1076   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1077   if (!extractOp)
1078     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1079   auto loc = extractOp.getLoc();
1080 
1081   // Compute the static loop sizes of the extract op.
1082   auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1083   auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1084       loc,
1085       DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1086                                 /*value=*/true));
1087   auto passThruConstantOp =
1088       rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1089 
1090   // Base indices are currently set to 0. We will need to re-visit if more
1091   // generic scenarios are to be supported.
1092   SmallVector<Value> baseIndices(
1093       extractOp.getIndices().size(),
1094       rewriter.create<arith::ConstantIndexOp>(loc, 0));
1095 
1096   VectorMemoryAccessKind memAccessKind =
1097       getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1098 
1099   // 1. Handle gather access
1100   if (memAccessKind == VectorMemoryAccessKind::Gather) {
1101     Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1102 
1103     // Generate the gather load
1104     Operation *gatherOp = rewriter.create<vector::GatherOp>(
1105         loc, resultType, extractOp.getTensor(), baseIndices, offset,
1106         maskConstantOp, passThruConstantOp);
1107     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1108 
1109     LDBG("Vectorised as gather load: " << extractOp << "\n");
1110     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1111   }
1112 
1113   // 2. Handle:
1114   //  a. scalar loads + broadcast,
1115   //  b. contiguous loads.
1116   // Both cases use vector.transfer_read.
1117 
1118   assert(llvm::count_if(resultType.getShape(),
1119                         [](uint64_t dim) { return dim != 1; }) &&
1120          "Contiguous loads and scalar loads + broadcast only support 1-D "
1121          "vectors ATM!");
1122 
1123   // Collect indices for `vector.transfer_read`. At this point, the indices will
1124   // either be scalars or would have been broadcast to vectors matching the
1125   // result type. For indices that are vectors, there are two options:
1126   //    * for non-trailing indices, all elements are identical (contiguous
1127   //      loads are identified by looking for non-trailing indices that are
1128   //      invariant with respect to the corresponding linalg.generic), or
1129   //    * for trailing indices, the index vector will contain values with stride
1130   //      one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1131   //      needed.
1132   // This means that
1133   //   * for scalar indices - just re-use it,
1134   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1135   //    (0th) element and use that.
1136   SmallVector<Value> transferReadIdxs;
1137   auto zero = rewriter.create<arith::ConstantOp>(
1138       loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
1139   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1140     Value idx = bvm.lookup(extractOp.getIndices()[i]);
1141     if (idx.getType().isIndex()) {
1142       transferReadIdxs.push_back(idx);
1143       continue;
1144     }
1145 
1146     auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1147         loc,
1148         VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1149                         resultType.getScalableDims().back()),
1150         idx);
1151     transferReadIdxs.push_back(
1152         rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1153   }
1154 
1155   // `tensor.extract` is always in-bounds, hence the following holds.
1156   auto dstRank = resultType.getRank();
1157   auto srcRank = extractOp.getTensor().getType().getRank();
1158   SmallVector<bool> inBounds(dstRank, true);
1159 
1160   // 2a. Handle scalar broadcast access.
1161   if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1162     MLIRContext *ctx = rewriter.getContext();
1163     SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1164     auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1165 
1166     auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1167         loc, resultType, extractOp.getTensor(), transferReadIdxs,
1168         permutationMap, inBounds);
1169 
1170     LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1171     return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1172   }
1173 
1174   // 2b. Handle contiguous access.
1175   auto permutationMap = AffineMap::getMinorIdentityMap(
1176       srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1177 
1178   int32_t rankDiff = dstRank - srcRank;
1179   // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1180   // dims so that the ranks match. This is done by extending the map with 0s.
1181   // For example, for dstRank = 3, srcRank = 2, the following map created
1182   // above:
1183   //    (d0, d1) --> (d0, d1)
1184   // is extended as:
1185   //    (d0, d1) --> (0, d0, d1)
1186   while (rankDiff > 0) {
1187     permutationMap = permutationMap.insertResult(
1188         mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1189     rankDiff--;
1190   }
1191 
1192   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1193       loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1194       inBounds);
1195 
1196   LDBG("Vectorised as contiguous load: " << extractOp);
1197   return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1198 }
1199 
1200 /// Emit reduction operations if the shape of the value to reduce is different
1201 /// from the result shape.
1202 // Note: this is a true builder that notifies the OpBuilder listener.
1203 // TODO: Consider moving as a static helper on the ReduceOp.
1204 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1205                                  Value reduceValue, Value initialValue,
1206                                  const IRMapping &bvm) {
1207   Value reduceVec = bvm.lookup(reduceValue);
1208   Value outputVec = bvm.lookup(initialValue);
1209   auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1210   auto outputType = dyn_cast<VectorType>(outputVec.getType());
1211   // Reduce only if needed as the value may already have been reduced for
1212   // contraction vectorization.
1213   if (!reduceType ||
1214       (outputType && reduceType.getShape() == outputType.getShape()))
1215     return nullptr;
1216   SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1217   return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1218 }
1219 
1220 /// Generic vectorization for a single operation `op`, given already vectorized
1221 /// operands carried by `bvm`. Vectorization occurs as follows:
1222 ///   1. Try to apply any of the `customVectorizationHooks` and return its
1223 ///   result on success.
1224 ///   2. Clone any constant in the current scope without vectorization: each
1225 ///   consumer of the constant will later determine the shape to which the
1226   ///   constant needs to be broadcast.
1227 ///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1228 ///   of the `customVectorizationHooks` to cover such cases.
1229 ///   4. Clone `op` in vector form to a vector of shape prescribed by the first
1230 ///   operand of maximal rank. Other operands have smaller rank and are
1231 ///   broadcast accordingly. It is assumed this broadcast is always legal,
1232 ///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
1233 ///
1234 /// This function assumes all operands of `op` have been vectorized and are in
1235 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1236 /// a topologically-sorted list of ops.
1237 /// This function does not update `bvm` but returns a VectorizationStatus that
1238 /// instructs the caller what `bvm` update needs to occur.
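///
/// As an illustrative sketch (not verbatim output), if `%v` is an
/// already-vectorized operand of type vector<4x8xf32> and `%s` is a scalar f32
/// operand, an elementwise `arith.addf` in the body is rewritten roughly as:
/// ```
/// %b = vector.broadcast %s : f32 to vector<4x8xf32>
/// %r = arith.addf %v, %b : vector<4x8xf32>
/// ```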
1239 static VectorizationResult
1240 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1241                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1242                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1243   LDBG("vectorize op " << *op << "\n");
1244 
1245   // 1. Try to apply any CustomVectorizationHook.
1246   if (!customVectorizationHooks.empty()) {
1247     for (auto &customFunc : customVectorizationHooks) {
1248       VectorizationResult result = customFunc(op, bvm);
1249       if (result.status == VectorizationStatus::Failure)
1250         continue;
1251       return result;
1252     }
1253   }
1254 
1255   // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1256   // Clone so that the constant is not confined to the linalgOp block.
1257   if (isa<arith::ConstantOp, func::ConstantOp>(op))
1258     return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1259 
1260   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1261   if (!OpTrait::hasElementwiseMappableTraits(op))
1262     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1263 
1264   // 4 . Check if the operation is a reduction.
1265   SmallVector<std::pair<Value, Value>> reductionOperands;
1266   for (Value operand : op->getOperands()) {
1267     auto blockArg = dyn_cast<BlockArgument>(operand);
1268     if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1269         blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1270       continue;
1271     SmallVector<Operation *> reductionOps;
1272     Value reduceValue = matchReduction(
1273         linalgOp.getRegionOutputArgs(),
1274         blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1275     if (!reduceValue)
1276       continue;
1277     reductionOperands.push_back(std::make_pair(reduceValue, operand));
1278   }
1279   if (!reductionOperands.empty()) {
1280     assert(reductionOperands.size() == 1);
1281     Operation *reduceOp =
1282         reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1283                        reductionOperands[0].second, bvm);
1284     if (reduceOp)
1285       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1286   }
1287 
1288   // 5. Generic vectorization path for ElementwiseMappable ops.
1289   //   a. Get the first max ranked shape.
1290   VectorType firstMaxRankedType;
1291   for (Value operand : op->getOperands()) {
1292     auto vecOperand = bvm.lookup(operand);
1293     assert(vecOperand && "Vector operand couldn't be found");
1294 
1295     auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1296     if (vecType && (!firstMaxRankedType ||
1297                     firstMaxRankedType.getRank() < vecType.getRank()))
1298       firstMaxRankedType = vecType;
1299   }
1300   //   b. Broadcast each op if needed.
1301   SmallVector<Value> vecOperands;
1302   for (Value scalarOperand : op->getOperands()) {
1303     Value vecOperand = bvm.lookup(scalarOperand);
1304     assert(vecOperand && "Vector operand couldn't be found");
1305 
1306     if (firstMaxRankedType) {
1307       auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1308                                      getElementTypeOrSelf(vecOperand.getType()),
1309                                      firstMaxRankedType.getScalableDims());
1310       vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1311     } else {
1312       vecOperands.push_back(vecOperand);
1313     }
1314   }
1315   //   c. For elementwise ops, the result is a vector with the firstMaxRankedType shape.
1316   SmallVector<Type> resultTypes;
1317   for (Type resultType : op->getResultTypes()) {
1318     resultTypes.push_back(
1319         firstMaxRankedType
1320             ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1321                               firstMaxRankedType.getScalableDims())
1322             : resultType);
1323   }
1324   //   d. Build and return the new op.
1325   return VectorizationResult{
1326       VectorizationStatus::NewOp,
1327       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1328                       resultTypes, op->getAttrs())};
1329 }
1330 
1331 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1332 /// vector form. Generic vectorization proceeds as follows:
1333 ///   1. Verify the `linalgOp` has one non-empty region.
1334 ///   2. Values defined above the region are mapped to themselves and will be
1335 ///   broadcasted on a per-need basis by their consumers.
1336 ///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1337 ///   load).
1338 ///   TODO: Reuse opportunities for RAR dependencies.
1339 ///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
1340 ///   4b. Register CustomVectorizationHook for IndexOp to access the
1341 ///   iteration indices.
1342 ///   5. Iteratively call vectorizeOneOp on the region operations.
1343 ///
1344 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1345 /// performed to the maximal common vector size implied by the `linalgOp`
1346 /// iteration space. This eager broadcasting is introduced in the
1347 /// permutation_map of the vector.transfer_read operations. The eager
1348 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1349 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1350 /// the absence of good canonicalizations, the amount of work increases.
1351 /// This is not deemed a problem as we expect canonicalizations and foldings to
1352 /// aggressively clean up the useless work.
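///
/// As an illustrative sketch (assuming a static, purely elementwise op; the
/// actual output also contains index constants and, for dynamic shapes,
/// masks), adding two tensor<8x16xf32> operands is vectorized roughly as:
/// ```
/// %lhs = vector.transfer_read %a[%c0, %c0], %cst
///     {in_bounds = [true, true]} : tensor<8x16xf32>, vector<8x16xf32>
/// %rhs = vector.transfer_read %b[%c0, %c0], %cst
///     {in_bounds = [true, true]} : tensor<8x16xf32>, vector<8x16xf32>
/// %add = arith.addf %lhs, %rhs : vector<8x16xf32>
/// %res = vector.transfer_write %add, %init[%c0, %c0]
///     {in_bounds = [true, true]} : vector<8x16xf32>, tensor<8x16xf32>
/// ```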
1353 static LogicalResult
1354 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1355                          LinalgOp linalgOp,
1356                          SmallVectorImpl<Value> &newResults) {
1357   LDBG("Vectorizing operation as linalg generic\n");
1358   Block *block = linalgOp.getBlock();
1359 
1360   // 2. Values defined above the region can only be broadcast for now. Make them
1361   // map to themselves.
1362   IRMapping bvm;
1363   SetVector<Value> valuesSet;
1364   mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1365   bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1366 
1367   if (linalgOp.getNumDpsInits() == 0)
1368     return failure();
1369 
1370   // 3. Turn all BBArgs into vector.transfer_read / load.
1371   Location loc = linalgOp.getLoc();
1372   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1373   for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1374     BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1375     if (linalgOp.isScalar(opOperand)) {
1376       bvm.map(bbarg, opOperand->get());
1377       continue;
1378     }
1379 
1380     // 3.a. Convert the indexing map for this input/output to a transfer read
1381     // permutation map and masking map.
1382     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1383 
1384     AffineMap readMap;
1385     VectorType readType;
1386     Type elemType = getElementTypeOrSelf(opOperand->get());
1387     if (linalgOp.isDpsInput(opOperand)) {
1388       // 3.a.i. For input reads we use the canonical vector shape.
1389       readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1390       readType = state.getCanonicalVecType(elemType);
1391     } else {
1392       // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1393       // reductions), the vector shape is computed by mapping the canonical
1394       // vector shape to the output domain and back to the canonical domain.
1395       readMap = inversePermutation(reindexIndexingMap(indexingMap));
1396       readType =
1397           state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1398     }
1399 
1400     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1401 
1402     Operation *read = rewriter.create<vector::TransferReadOp>(
1403         loc, readType, opOperand->get(), indices, readMap);
1404     read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1405     Value readValue = read->getResult(0);
1406 
1407     // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1408     // will be in-bounds.
1409     if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1410       SmallVector<bool> inBounds(readType.getRank(), true);
1411       cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1412           .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1413     }
1414 
1415     // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1416     // TODO: remove this.
1417     if (readType.getRank() == 0)
1418       readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1419 
1420     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1421                                  << "\n");
1422     bvm.map(bbarg, readValue);
1423     bvm.map(opOperand->get(), readValue);
1424   }
1425 
1426   SmallVector<CustomVectorizationHook> hooks;
1427   // 4a. Register CustomVectorizationHook for yieldOp.
1428   CustomVectorizationHook vectorizeYield =
1429       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1430     return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1431   };
1432   hooks.push_back(vectorizeYield);
1433 
1434   // 4b. Register CustomVectorizationHook for indexOp.
1435   CustomVectorizationHook vectorizeIndex =
1436       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1437     return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1438   };
1439   hooks.push_back(vectorizeIndex);
1440 
1441   // 4c. Register CustomVectorizationHook for extractOp.
1442   CustomVectorizationHook vectorizeExtract =
1443       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1444     return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1445   };
1446   hooks.push_back(vectorizeExtract);
1447 
1448   // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
1449   for (Operation &op : block->getOperations()) {
1450     VectorizationResult result =
1451         vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1452     if (result.status == VectorizationStatus::Failure) {
1453       LDBG("failed to vectorize: " << op << "\n");
1454       return failure();
1455     }
1456     if (result.status == VectorizationStatus::NewOp) {
1457       Operation *maybeMaskedOp =
1458           state.maskOperation(rewriter, result.newOp, linalgOp);
1459       LDBG("New vector op: " << *maybeMaskedOp << "\n");
1460       bvm.map(op.getResults(), maybeMaskedOp->getResults());
1461     }
1462   }
1463 
1464   return success();
1465 }
1466 
1467 /// Given a tensor::PackOp, return the `dest` shape before any packing
1468 /// permutations.
1469 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1470                                               ArrayRef<int64_t> destShape) {
1471   return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1472 }
1473 
1474 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1475 /// create an empty destination tensor and create a TransferWriteOp from the
1476 /// input to the empty tensor. If the destination shape is not the same as the
1477 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1478 /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1479 /// inBounds attribute of the transfer write op instead of masking.
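///
/// As an illustrative sketch (assuming `input` is a vector<8x16xf32>, the
/// destination sizes are (8, %d1) with a dynamic trailing size, and masking
/// is required), the emitted IR looks roughly like:
/// ```
/// %empty = tensor.empty(%d1) : tensor<8x?xf32>
/// %mask = vector.create_mask %c8, %d1 : vector<8x16xi1>
/// %write = vector.mask %mask {
///   vector.transfer_write %input, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<8x16xf32>, tensor<8x?xf32>
/// } : vector<8x16xi1> -> tensor<8x?xf32>
/// ```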
1480 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1481                                            Value input,
1482                                            SmallVector<OpFoldResult> destSizes,
1483                                            ArrayRef<int64_t> inputVectorSizes,
1484                                            bool useInBoundsInsteadOfMasking) {
1485 
1486   auto inputType = cast<VectorType>(input.getType());
1487   Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1488                                                inputType.getElementType());
1489   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1490   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1491   auto destShape = cast<ShapedType>(dest.getType()).getShape();
1492   SmallVector<bool> inBoundsVal(rank, true);
1493   if (useInBoundsInsteadOfMasking) {
1494     // Update the inBounds attribute.
1495     for (unsigned i = 0; i < rank; i++)
1496       inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1497                        !ShapedType::isDynamic(destShape[i]);
1498   }
1499   Operation *write = builder.create<vector::TransferWriteOp>(
1500       loc,
1501       /*vector=*/input,
1502       /*source=*/dest,
1503       /*indices=*/SmallVector<Value>(rank, zero),
1504       /*inBounds=*/inBoundsVal);
1505   assert(llvm::none_of(
1506              destShape.drop_front(inputVectorSizes.size()),
1507              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1508          "Only dims aligned with inputVectorSizes may be dynamic");
1509   if (useInBoundsInsteadOfMasking)
1510     return write;
1511   bool needMaskForWrite = !llvm::equal(
1512       inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1513   if (needMaskForWrite) {
1514     SmallVector<int64_t> writeMaskShape;
1515     writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1516     writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1517                           destShape.end());
1518     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1519     Value maskForWrite =
1520         builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1521     write = mlir::vector::maskOperation(builder, write, maskForWrite);
1522   }
1523   return write;
1524 }
1525 
1526 /// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1527 /// padding value and (3) input vector sizes into:
1528 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1529 /// As in the following example:
1530 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1531 ///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1532 ///
1533 /// This pack would be vectorized to:
1534 ///
1535 /// %load = vector.mask %mask {
1536 ///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1537 ///         {in_bounds = [true, true, true]} :
1538 ///         tensor<32x7x16xf32>, vector<32x8x16xf32>
1539 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1540 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1541 ///                                         to vector<32x4x2x1x16xf32>
1542 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1543 ///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1544 /// %write = vector.transfer_write %transpose,
1545 ///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1546 ///     {in_bounds = [true, true, true, true, true]}
1547 ///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1548 ///
1549 /// If the (3) input vector sizes are not provided, the vector sizes are
1550 /// determined by the result tensor shape. Also, we update the inBounds
1551 /// attribute instead of masking.
1552 static LogicalResult
1553 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1554                         ArrayRef<int64_t> inputVectorSizes,
1555                         SmallVectorImpl<Value> &newResults) {
1556   OpBuilder::InsertionGuard g(rewriter);
1557   rewriter.setInsertionPoint(packOp);
1558 
1559   Location loc = packOp.getLoc();
1560   auto padValue = packOp.getPaddingValue();
1561   if (!padValue) {
1562     padValue = rewriter.create<arith::ConstantOp>(
1563         loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1564   }
1565   ReifiedRankedShapedTypeDims reifiedReturnShapes;
1566   LogicalResult status =
1567       cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1568           .reifyResultShapes(rewriter, reifiedReturnShapes);
1569   (void)status; // prevent unused variable warning on non-assert builds.
1570   assert(succeeded(status) && "failed to reify result shapes");
1571 
1572   // If the input vector sizes are not provided, then the vector sizes are
1573   // determined by the result tensor shape and we update the inBounds
1574   // attribute instead of masking.
1575   bool useInBoundsInsteadOfMasking = false;
1576   if (inputVectorSizes.empty()) {
1577     ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1578     inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1579     useInBoundsInsteadOfMasking = true;
1580   }
1581 
1582   // Create masked TransferReadOp.
1583   SmallVector<int64_t> inputShape(inputVectorSizes);
1584   auto innerTiles = packOp.getStaticInnerTiles();
1585   auto innerDimsPos = packOp.getInnerDimsPos();
1586   auto outerDimsPerm = packOp.getOuterDimsPerm();
1587   if (!outerDimsPerm.empty())
1588     applyPermutationToVector(inputShape,
1589                              invertPermutationVector(outerDimsPerm));
1590   for (auto [idx, size] : enumerate(innerTiles))
1591     inputShape[innerDimsPos[idx]] *= size;
1592   auto maskedRead = vector::createReadOrMaskedRead(
1593       rewriter, loc, packOp.getSource(), inputShape, padValue,
1594       useInBoundsInsteadOfMasking);
1595 
1596   // Create ShapeCastOp.
1597   SmallVector<int64_t> destShape(inputVectorSizes);
1598   destShape.append(innerTiles.begin(), innerTiles.end());
1599   auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1600                                        packOp.getDestType().getElementType());
1601   auto shapeCastOp =
1602       rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1603 
1604   // Create TransposeOp.
1605   auto destPermutation =
1606       invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1607   auto transposeOp = rewriter.create<vector::TransposeOp>(
1608       loc, shapeCastOp.getResult(), destPermutation);
1609 
1610   // Create TransferWriteOp.
1611   Operation *write = createWriteOrMaskedWrite(
1612       rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1613       inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
1614   newResults.push_back(write->getResult(0));
1615   return success();
1616 }
1617 
1618 /// Vectorize a `tensor::UnPackOp` to these 4 ops:
1619 ///   vector::TransferReadOp - reads a vector from the source tensor,
1620 ///   vector::TransposeOp - transposes the read vector,
1621 ///   vector::ShapeCastOp - reshapes the data based on the target,
1622 ///   vector::TransferWriteOp - writes the result vector back to the
1623 ///   destination tensor.
1624 /// If the vector sizes are not provided:
1625 ///   * the vector sizes are determined by the input operand and attributes,
1626 ///   * the inBounds attribute is updated instead of masking.
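///
/// As an illustrative sketch (assuming static shapes and no outer_dims_perm;
/// the transpose permutation below is specific to this example), the
/// following unpack:
/// ```
/// %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 8]
///     into %dst : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
/// ```
/// is vectorized roughly as:
/// ```
/// %read = vector.transfer_read %src[%c0, %c0, %c0, %c0], %cst
///     {in_bounds = [true, true, true, true]}
///     : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
/// %transpose = vector.transpose %read, [0, 2, 1, 3]
///     : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
/// %shape_cast = vector.shape_cast %transpose
///     : vector<1x8x1x8xf32> to vector<8x8xf32>
/// %write = vector.transfer_write %shape_cast, %empty[%c0, %c0]
///     {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32>
/// ```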
1627 static LogicalResult
1628 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1629                           ArrayRef<int64_t> inputVectorSizes,
1630                           SmallVectorImpl<Value> &newResults) {
1631 
1632   OpBuilder::InsertionGuard g(rewriter);
1633   rewriter.setInsertionPoint(unpackOp);
1634 
1635   RankedTensorType unpackTensorType = unpackOp.getSourceType();
1636 
1637   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1638   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1639   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1640   bool useInBoundsInsteadOfMasking = false;
1641   ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1642 
1643   auto destSize = unpackOp.getDestRank();
1644 
1645   if (!inputVectorSizes.empty())
1646     assert(inputVectorSizes.size() == destSize &&
1647            "Incorrect number of input vector sizes");
1648 
1649   // vectorSizes is the shape of the vector that will be used for the final
1650   // write on the destination tensor. It is set like this: let's say the
1651   // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1652   // Thus:
1653   // 1. vectorSizes = sourceShape.take_front(N)
1654   // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1655   // 3. multiply all the locations in vectorSizes pointed to by innerDimPos by
1656   //    the corresponding innerTiles value.
1657   SmallVector<int64_t> vectorSizes(inputVectorSizes);
1658   if (vectorSizes.empty()) {
1659     llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1660     if (!outerDimsPerm.empty())
1661       applyPermutationToVector(vectorSizes, outerDimsPerm);
1662     for (auto [i, pos] : llvm::enumerate(innerDimPos))
1663       vectorSizes[pos] *= innerTiles[i];
1664 
1665     useInBoundsInsteadOfMasking = true;
1666   }
1667 
1668   // readVectorSizes is the shape of the vector used to read the source tensor
1669   // and to build the read mask. It is set like this: let's say the vectorSizes
1670   // (VS) array has size 'N', the sourceShape (SS) has size 'M' where M >= N,
1671   // and the InnerTileSizes (IT) have size M-N.
1672   // Thus:
1673   // - initially: readVectorSizes = vectorSizes
1674   // - Divide all the readVectorSizes locations pointed to by innerDimPos
1675   //   by the corresponding innerTiles value.
1676   // - if outer_dims_perm is present: apply that permutation to readVectorSizes.
1677   // - Append the remaining shape from SS
1678   // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>
1679   // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1680   // 128] and outer_dims_perm is [1, 0] then read shape is:
1681   //   ReadVectorSizes(initial): [512, 128]
1682   //   Final Value(after innerDim Adjustment): [512/32, 128/16]
1683   //                                           = [16, 8]
1684   //   After applying outer_dims_perm: [8, 16]
1685   //   After appending the rest of the sourceShape: [8, 16, 32, 16]
1686 
1687   SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1688 
1689   for (auto [index, size] : enumerate(innerTiles)) {
1690     readVectorSizes[innerDimPos[index]] =
1691         llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1692   }
1693   if (!outerDimsPerm.empty()) {
1694     applyPermutationToVector(readVectorSizes, outerDimsPerm);
1695   }
1696   readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1697                          sourceShape.end());
1698 
1699   ReifiedRankedShapedTypeDims reifiedRetShapes;
1700   LogicalResult status =
1701       cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1702           .reifyResultShapes(rewriter, reifiedRetShapes);
1703   if (status.failed()) {
1704     LDBG("Unable to reify result shapes of " << unpackOp);
1705     return failure();
1706   }
1707   Location loc = unpackOp->getLoc();
1708 
1709   auto padValue = rewriter.create<arith::ConstantOp>(
1710       loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1711 
1712   // Read result, mask if necessary. If the transferReadOp shape is not equal
1713   // to the shape of the source, then a mask is necessary.
1714   Value readResult = vector::createReadOrMaskedRead(
1715       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1716       /*useInBoundsInsteadOfMasking=*/false);
1717 
1718   PackingMetadata packMetadata;
1719   SmallVector<int64_t> lastDimToInsertPosPerm =
1720       tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1721   ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1722   SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1723   mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1724   applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1725   RankedTensorType stripMineTensorType =
1726       RankedTensorType::get(stripMineShape, stripMineElemType);
1727   // Transpose the appropriate rows to match output.
1728   vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1729       loc, readResult, lastDimToInsertPosPerm);
1730 
1731   // Collapse the vector to the size required by result.
1732   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1733       stripMineTensorType, packMetadata.reassociations);
1734   mlir::VectorType vecCollapsedType =
1735       VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1736   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1737       loc, vecCollapsedType, transposeOp->getResult(0));
1738 
1739   // writeVectorSizes has to match the shape_cast result shape for dynamic
1740   // sizes, otherwise the verifier complains that the mask size is invalid.
1741   SmallVector<int64_t> writeVectorSizes(
1742       unpackOp.getDestType().hasStaticShape()
1743           ? vectorSizes
1744           : shapeCastOp.getResultVectorType().getShape());
1745   Operation *write = createWriteOrMaskedWrite(
1746       rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1747       writeVectorSizes, useInBoundsInsteadOfMasking);
1748   newResults.push_back(write->getResult(0));
1749   return success();
1750 }
1751 
1752 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1753 /// and (3) all-zero lowPad to
1754 ///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
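///
/// As an illustrative sketch (assuming input vector sizes [4, 8] and a
/// dynamically shaped source with sizes %d0 x %d1), the emitted IR looks
/// roughly like:
/// ```
/// %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
/// %read = vector.mask %mask {
///   vector.transfer_read %src[%c0, %c0], %pad_value
///       {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32>
/// } : vector<4x8xi1> -> vector<4x8xf32>
/// %write = vector.transfer_write %read, %empty[%c0, %c0]
///     {in_bounds = [true, true]} : vector<4x8xf32>, tensor<4x8xf32>
/// ```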
1755 static LogicalResult
1756 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1757                        ArrayRef<int64_t> inputVectorSizes,
1758                        SmallVectorImpl<Value> &newResults) {
1759   auto padValue = padOp.getConstantPaddingValue();
1760   Location loc = padOp.getLoc();
1761 
1762   // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1763   OpBuilder::InsertionGuard g(rewriter);
1764   rewriter.setInsertionPoint(padOp);
1765 
1766   ReifiedRankedShapedTypeDims reifiedReturnShapes;
1767   LogicalResult status =
1768       cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1769           .reifyResultShapes(rewriter, reifiedReturnShapes);
1770   (void)status; // prevent unused variable warning on non-assert builds
1771   assert(succeeded(status) && "failed to reify result shapes");
1772   auto maskedRead = vector::createReadOrMaskedRead(
1773       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1774       /*useInBoundsInsteadOfMasking=*/false);
1775   Operation *write = createWriteOrMaskedWrite(
1776       rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1777       /*useInBoundsInsteadOfMasking=*/false);
1778   newResults.push_back(write->getResult(0));
1779   return success();
1780 }
1781 
1782 // TODO: probably need some extra checks for reduction followed by consumer
1783 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1784 static LogicalResult reductionPreconditions(LinalgOp op) {
1785   if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1786     LDBG("reduction precondition failed: no reduction iterator\n");
1787     return failure();
1788   }
1789   for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1790     AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1791     if (indexingMap.isPermutation())
1792       continue;
1793 
1794     Operation *reduceOp = matchLinalgReduction(&opOperand);
1795     if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1796       LDBG("reduction precondition failed: reduction detection failed\n");
1797       return failure();
1798     }
1799   }
1800   return success();
1801 }
1802 
1803 static LogicalResult
1804 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1805                                    bool flatten1DDepthwiseConv) {
1806   if (flatten1DDepthwiseConv) {
1807     LDBG("Vectorization of flattened convs with dynamic shapes is not "
1808          "supported\n");
1809     return failure();
1810   }
1811 
1812   if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1813     LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1814     return failure();
1815   }
1816 
1817   // Support dynamic shapes in 1D depthwise convolution, but only in the
1818   // _channel_ dimension.
1819   Value lhs = conv.getDpsInputOperand(0)->get();
1820   ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1821   auto shapeWithoutCh = lhsShape.drop_back(1);
1822   if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1823     LDBG("Dynamically-shaped op vectorization precondition failed: only "
1824          "channel dim can be dynamic\n");
1825     return failure();
1826   }
1827 
1828   return success();
1829 }
1830 
1831 static LogicalResult
1832 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1833                                      bool flatten1DDepthwiseConv) {
1834   if (isa<ConvolutionOpInterface>(op.getOperation()))
1835     return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1836 
1837   if (hasReductionIterator(op))
1838     return reductionPreconditions(op);
1839 
1840   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1841   // linalg.copy ops and ops that implement ContractionOpInterface for now.
1842   if (!isElementwise(op) &&
1843       !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1844           op.getOperation()))
1845     return failure();
1846 
1847   LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1848   return success();
1849 }
1850 
1851 /// Checks whether the inner tiles of `unpackOp` are static/constant.
1852 static LogicalResult
1853 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1854                               ArrayRef<int64_t> inputVectorSizes) {
1855 
1856   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1857         return !getConstantIntValue(res).has_value();
1858       })) {
1859     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1860     return failure();
1861   }
1862   ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1863   bool satisfyEmptyCond = inputVectorSizes.empty() &&
1864                           unpackOp.getDestType().hasStaticShape() &&
1865                           unpackOp.getSourceType().hasStaticShape();
1866   if (!satisfyEmptyCond &&
1867       failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1868     return failure();
1869 
1870   return success();
1871 }
1872 
1873 static LogicalResult vectorizeLinalgOpPrecondition(
1874     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1875     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
1876   // Tensors with a dimension of size 0 cannot be vectorized.
1877   if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1878     return failure();
1879   // Check API contract for input vector sizes.
1880   if (!inputVectorSizes.empty() &&
1881       failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1882                                               inputVectorSizes)))
1883     return failure();
1884 
1885   if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1886                                         linalgOp, flatten1DDepthwiseConv))) {
1887     LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1888     return failure();
1889   }
1890 
1891   SmallVector<CustomVectorizationPrecondition> customPreconditions;
1892 
1893   // Register CustomVectorizationPrecondition for extractOp.
1894   customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1895 
1896   // All types in the body should be a supported element type for VectorType.
1897   for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1898     // Check if any custom hook can vectorize the inner op.
1899     if (llvm::any_of(
1900             customPreconditions,
1901             [&](const CustomVectorizationPrecondition &customPrecondition) {
1902               return succeeded(
1903                   customPrecondition(&innerOp, vectorizeNDExtract));
1904             })) {
1905       continue;
1906     }
1907     if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1908           return !VectorType::isValidElementType(type);
1909         })) {
1910       return failure();
1911     }
1912     if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1913           return !VectorType::isValidElementType(type);
1914         })) {
1915       return failure();
1916     }
1917   }
1918   if (isElementwise(linalgOp))
1919     return success();
1920 
1921   // TODO: isaConvolutionOpInterface that can also infer from generic
1922   // features. But we will still need stride/dilation attributes that will be
1923   // annoying to reverse-engineer...
1924   if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1925     return success();
1926   // TODO: the common vector shape is equal to the static loop sizes only when
1927   // all indexing maps are projected permutations. For convs and stencils the
1928   // logic will need to evolve.
1929   if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1930     LDBG("precondition failed: not projected permutations\n");
1931     return failure();
1932   }
1933   if (failed(reductionPreconditions(linalgOp))) {
1934     LDBG("precondition failed: reduction preconditions\n");
1935     return failure();
1936   }
1937   return success();
1938 }
1939 
1940 static LogicalResult
1941 vectorizePackOpPrecondition(tensor::PackOp packOp,
1942                             ArrayRef<int64_t> inputVectorSizes) {
1943   auto padValue = packOp.getPaddingValue();
1944   Attribute cstAttr;
1945   if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1946     LDBG("pad value is not constant: " << packOp << "\n");
1947     return failure();
1948   }
1949   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1950   bool satisfyEmptyCond = true;
1951   if (inputVectorSizes.empty()) {
1952     if (!packOp.getDestType().hasStaticShape() ||
1953         !packOp.getSourceType().hasStaticShape())
1954       satisfyEmptyCond = false;
1955   }
1956 
1957   if (!satisfyEmptyCond &&
1958       failed(vector::isValidMaskedInputVector(
1959           resultTensorShape.take_front(packOp.getSourceRank()),
1960           inputVectorSizes)))
1961     return failure();
1962 
1963   if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1964         return !getConstantIntValue(v).has_value();
1965       })) {
1966     LDBG("inner_tiles must be constant: " << packOp << "\n");
1967     return failure();
1968   }
1969 
1970   return success();
1971 }
1972 
1973 static LogicalResult
1974 vectorizePadOpPrecondition(tensor::PadOp padOp,
1975                            ArrayRef<int64_t> inputVectorSizes) {
1976   auto padValue = padOp.getConstantPaddingValue();
1977   if (!padValue) {
1978     LDBG("pad value is not constant: " << padOp << "\n");
1979     return failure();
1980   }
1981 
1982   ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1983   if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1984                                               inputVectorSizes)))
1985     return failure();
1986 
1987   if (llvm::any_of(padOp.getLow(), [](Value v) {
1988         std::optional<int64_t> res = getConstantIntValue(v);
1989         return !res.has_value() || res.value() != 0;
1990       })) {
1991     LDBG("low pad must all be zero: " << padOp << "\n");
1992     return failure();
1993   }
1994 
1995   return success();
1996 }
1997 
1998 /// Preconditions for scalable vectors. This is quite restrictive - it models
1999 /// the fact that in practice we would only make selected dimensions scalable.
2000 static LogicalResult
2001 vectorizeScalableVectorPrecondition(Operation *op,
2002                                     ArrayRef<int64_t> inputVectorSizes,
2003                                     ArrayRef<bool> inputScalableVecDims) {
2004   assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2005          "Number of input vector sizes and scalable dims doesn't match");
2006 
2007   size_t numOfScalableDims =
2008       llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2009 
2010   if (numOfScalableDims == 0)
2011     return success();
2012 
2013   auto linalgOp = dyn_cast<LinalgOp>(op);
2014 
2015   // Cond 1: There's been no need for scalable vectorisation of
2016   // non-linalg Ops so far
2017   if (!linalgOp)
2018     return failure();
2019 
2020   // Cond 2: There's been no need for more than 2 scalable dims so far
2021   if (numOfScalableDims > 2)
2022     return failure();
2023 
2024   // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2025   // it matches one of the supported cases:
2026   //  1. exactly 1 dim is scalable and that's the _last_ parallel dim
2027   //  2. exactly 2 dims are scalable and those are the _last two adjacent_
2028   //     parallel dims
2029   //  3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
2030   // The 2nd restriction above means that only Matmul-like Ops are supported
2031   // when 2 dims are scalable, e.g. :
2032   //    * iterators = [parallel, parallel, reduction]
2033   //    * scalable flags = [true, true, false]
2034 
2035   // Find the trailing scalable flag (scanning from the innermost dim).
2036   bool seenParallel = false;
2037   auto iterators = linalgOp.getIteratorTypesArray();
2038   SmallVector<bool> scalableFlags(inputScalableVecDims);
2039   while (!scalableFlags.back()) {
2040     seenParallel |= (iterators.back() == utils::IteratorType::parallel);
2041 
2042     iterators.pop_back();
2043     scalableFlags.pop_back();
2044   }
2045 
2046   switch (iterators.back()) {
2047   case utils::IteratorType::reduction: {
2048     // Check 3. above is met.
2049     if (iterators.size() != inputVectorSizes.size()) {
2050       LDBG("Non-trailing reduction dim requested for scalable "
2051            "vectorization\n");
2052       return failure();
2053     }
2054     if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2055       LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2056            "is not supported\n");
2057       return failure();
2058     }
2059     break;
2060   }
2061   case utils::IteratorType::parallel: {
2062     // Check 1. and 2. above are met.
2063     if (seenParallel) {
2064       LDBG("Inner parallel dim not requested for scalable "
2065            "vectorization\n");
2066       return failure();
2067     }
2068     break;
2069   }
2070   }
2071 
2072   // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2073   // supported, for which we expect the following config:
2074   //    * iterators = [parallel, parallel, reduction]
2075   //    * scalable flags = [true, true, false]
2076   if (numOfScalableDims == 2) {
2077     // Disallow the case below, which breaks 3. above:
2078     //    * iterators = [..., parallel, reduction]
2079     //    * scalable flags = [..., true, true]
2080     if (iterators.back() == utils::IteratorType::reduction) {
2081       LDBG("Higher dim than the trailing reduction dim requested for scalable "
2082            "vectorization\n");
2083       return failure();
2084     }
2085     scalableFlags.pop_back();
2086     iterators.pop_back();
2087 
2088     if (!scalableFlags.back() ||
2089         (iterators.back() != utils::IteratorType::parallel))
2090       return failure();
2091   }
2092 
2093   // Do not let matmuls with extended semantics (user-defined indexing maps)
2094   // through this transform.
2095   if (linalgOp.hasUserDefinedMaps())
2096     return failure();
2097 
2098   // Cond 4: Only the following ops are supported in the
2099   // presence of scalable vectors
2100   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2101                  isa<linalg::MatmulTransposeAOp>(op) ||
2102                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2103                  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2104 }
2105 
2106 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2107     Operation *op, ArrayRef<int64_t> inputVectorSizes,
2108     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2109     bool flatten1DDepthwiseConv) {
2110 
2111   if (!hasVectorizationImpl(op))
2112     return failure();
2113 
2114   if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2115                                                  inputScalableVecDims)))
2116     return failure();
2117 
2118   return TypeSwitch<Operation *, LogicalResult>(op)
2119       .Case<linalg::LinalgOp>([&](auto linalgOp) {
2120         return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2121                                              vectorizeNDExtract,
2122                                              flatten1DDepthwiseConv);
2123       })
2124       .Case<tensor::PadOp>([&](auto padOp) {
2125         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2126       })
2127       .Case<tensor::PackOp>([&](auto packOp) {
2128         return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2129       })
2130       .Case<tensor::UnPackOp>([&](auto unpackOp) {
2131         return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2132       })
2133       .Default([](auto) { return failure(); });
2134 }
2135 
2136 /// Converts affine.apply Ops to arithmetic operations.
2137 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2138   OpBuilder::InsertionGuard g(rewriter);
2139   auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2140 
2141   for (auto op : make_early_inc_range(toReplace)) {
2142     rewriter.setInsertionPoint(op);
2143     auto expanded = affine::expandAffineExpr(
2144         rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2145         op.getOperands().take_front(op.getAffineMap().getNumDims()),
2146         op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2147     rewriter.replaceOp(op, expanded);
2148   }
2149 }
2150 
2151 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2152   return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
2153       op);
2154 }
2155 
2156 /// Emit a suitable vector form for an operation. If provided,
2157 /// `inputVectorSizes` are used to vectorize this operation.
2158 /// `inputVectorSizes` must match the rank of the iteration space of the
2159 /// operation and the input vector sizes must be greater than or equal to
2160 /// their counterpart iteration space sizes, if static. `inputVectorShapes`
2161 /// also allows the vectorization of operations with dynamic shapes.
2162 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2163                                       ArrayRef<int64_t> inputVectorSizes,
2164                                       ArrayRef<bool> inputScalableVecDims,
2165                                       bool vectorizeNDExtract,
2166                                       bool flatten1DDepthwiseConv) {
2167   LDBG("Attempting to vectorize:\n" << *op << "\n");
2168   LDBG("Input vector sizes: ");
2169   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2170   LLVM_DEBUG(llvm::dbgs() << "\n");
2171   LDBG("Input scalable vector dims: ");
2172   LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2173   LLVM_DEBUG(llvm::dbgs() << "\n");
2174 
2175   if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2176                                      vectorizeNDExtract,
2177                                      flatten1DDepthwiseConv))) {
2178     LDBG("Vectorization pre-conditions failed\n");
2179     return failure();
2180   }
2181 
2182   // Initialize vectorization state.
2183   VectorizationState state(rewriter);
2184   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2185     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2186                                inputScalableVecDims))) {
2187       LDBG("Vectorization state couldn't be initialized\n");
2188       return failure();
2189     }
2190   }
2191 
2192   SmallVector<Value> results;
2193   auto vectorizeResult =
2194       TypeSwitch<Operation *, LogicalResult>(op)
2195           .Case<linalg::LinalgOp>([&](auto linalgOp) {
2196             // TODO: isaConvolutionOpInterface that can also infer from
2197             // generic features. Will require stride/dilation attributes
2198             // inference.
2199             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2200               FailureOr<Operation *> convOr = vectorizeConvolution(
2201                   rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2202                   flatten1DDepthwiseConv);
2203               if (succeeded(convOr)) {
2204                 llvm::append_range(results, (*convOr)->getResults());
2205                 return success();
2206               }
2207 
2208               LDBG("Unsupported convolution can't be vectorized.\n");
2209               return failure();
2210             }
2211 
2212             LDBG("Vectorize generic by broadcasting to the canonical vector "
2213                  "shape\n");
2214 
2215             // Pre-process before proceeding.
2216             convertAffineApply(rewriter, linalgOp);
2217 
2218             // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2219             // to 'OpBuilder' when it is passed over to some methods like
2220             // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2221             // erase an op within these methods, the actual rewriter won't be
2222             // notified and we will end up with read-after-free issues!
2223             return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2224           })
2225           .Case<tensor::PadOp>([&](auto padOp) {
2226             return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2227                                           results);
2228           })
2229           .Case<tensor::PackOp>([&](auto packOp) {
2230             return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2231                                            results);
2232           })
2233           .Case<tensor::UnPackOp>([&](auto unpackOp) {
2234             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2235                                              inputVectorSizes, results);
2236           })
2237           .Default([](auto) { return failure(); });
2238 
2239   if (failed(vectorizeResult)) {
2240     LDBG("Vectorization failed\n");
2241     return failure();
2242   }
2243 
2244   if (!results.empty())
2245     rewriter.replaceOp(op, results);
2246   else
2247     rewriter.eraseOp(op);
2248 
2249   return success();
2250 }
2251 
2252 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2253                                           memref::CopyOp copyOp) {
2254   auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2255   auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2256   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2257     return failure();
2258 
2259   auto srcElementType = getElementTypeOrSelf(srcType);
2260   auto dstElementType = getElementTypeOrSelf(dstType);
2261   if (!VectorType::isValidElementType(srcElementType) ||
2262       !VectorType::isValidElementType(dstElementType))
2263     return failure();
2264 
2265   auto readType = VectorType::get(srcType.getShape(), srcElementType);
2266   auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2267 
2268   Location loc = copyOp->getLoc();
2269   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2270   SmallVector<Value> indices(srcType.getRank(), zero);
2271 
2272   Value readValue = rewriter.create<vector::TransferReadOp>(
2273       loc, readType, copyOp.getSource(), indices,
2274       rewriter.getMultiDimIdentityMap(srcType.getRank()));
2275   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2276     readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
2277     readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2278   }
2279   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2280       loc, readValue, copyOp.getTarget(), indices,
2281       rewriter.getMultiDimIdentityMap(srcType.getRank()));
2282   rewriter.replaceOp(copyOp, writeValue->getResults());
2283   return success();
2284 }
2285 
2286 //----------------------------------------------------------------------------//
2287 // Misc. vectorization patterns.
2288 //----------------------------------------------------------------------------//
2289 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2290 /// given operation type OpTy.
2291 template <typename OpTy>
2292 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2293   using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2294 
2295   LogicalResult matchAndRewrite(tensor::PadOp padOp,
2296                                 PatternRewriter &rewriter) const final {
2297     bool changed = false;
2298     // Collect users in a vector first, because some may be replaced/removed.
2299     for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2300       if (auto op = dyn_cast<OpTy>(user))
2301         changed |= rewriteUser(rewriter, padOp, op).succeeded();
2302     return success(changed);
2303   }
2304 
2305 protected:
2306   virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2307                                     tensor::PadOp padOp, OpTy op) const = 0;
2308 };
2309 
2310 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2311 /// ```
2312 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2313 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2314 ///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2315 /// ```
2316 /// is rewritten to:
2317 /// ```
2318 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2319 ///     {in_bounds = [true, true]}
2320 ///     : tensor<?x?xf32>, vector<17x5xf32>
2321 /// ```
2322 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2323 /// sure that the original padding value %cst was never used.
2324 ///
2325 /// This rewrite is possible if:
2326 /// - `xferOp` has no out-of-bounds dims or mask.
2327 /// - Low padding is static 0.
2328 /// - Single, scalar padding value.
2329 struct PadOpVectorizationWithTransferReadPattern
2330     : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2331   using VectorizePadOpUserPattern<
2332       vector::TransferReadOp>::VectorizePadOpUserPattern;
2333 
2334   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2335                             vector::TransferReadOp xferOp) const override {
2336     // Low padding must be static 0.
2337     if (!padOp.hasZeroLowPad())
2338       return failure();
2339     // Pad value must be a constant.
2340     auto padValue = padOp.getConstantPaddingValue();
2341     if (!padValue)
2342       return failure();
2343     // Padding value of existing `xferOp` is unused.
2344     if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2345       return failure();
2346 
2347     rewriter.modifyOpInPlace(xferOp, [&]() {
2348       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2349       xferOp->setAttr(xferOp.getInBoundsAttrName(),
2350                       rewriter.getBoolArrayAttr(inBounds));
2351       xferOp.getSourceMutable().assign(padOp.getSource());
2352       xferOp.getPaddingMutable().assign(padValue);
2353     });
2354 
2355     return success();
2356   }
2357 };
2358 
2359 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2360 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2361 /// value, where the same amount of padding is immediately removed again after
2362 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2363 /// tensor value and apply out-of-bounds masking. E.g.:
2364 /// ```
2365 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2366 ///     : tensor<...> to tensor<?x?xf32>
2367 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2368 /// %2 = vector.transfer_write %vec, %1[...]
2369 ///     : vector<17x5xf32>, tensor<17x5xf32>
2370 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2371 ///     : tensor<17x5xf32> to tensor<?x?xf32>
2372 /// ```
2373 /// is rewritten to:
2374 /// ```
2375 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2376 ///     : tensor<...> to tensor<?x?xf32>
2377 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2378 /// tensor<?x?xf32>
2379 /// ```
2380 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2381 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2382 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2383 /// from %r's old dimensions.
2384 ///
2385 /// This rewrite is possible if:
2386 /// - Low padding is static 0.
2387 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2388 ///   ExtractSliceOp trims the same amount of padding that was added
2389 ///   beforehand.
2390 /// - Single, scalar padding value.
2391 struct PadOpVectorizationWithTransferWritePattern
2392     : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2393   using VectorizePadOpUserPattern<
2394       vector::TransferWriteOp>::VectorizePadOpUserPattern;
2395 
2396   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2397                             vector::TransferWriteOp xferOp) const override {
2398     // TODO: support 0-d corner case.
2399     if (xferOp.getTransferRank() == 0)
2400       return failure();
2401 
2402     // Low padding must be static 0.
2403     if (!padOp.hasZeroLowPad())
2404       return failure();
2405     // Pad value must be a constant.
2406     auto padValue = padOp.getConstantPaddingValue();
2407     if (!padValue)
2408       return failure();
2409     // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2410     if (!xferOp->hasOneUse())
2411       return failure();
2412     auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2413     if (!trimPadding)
2414       return failure();
2415     // Only static zero offsets supported when trimming padding.
2416     if (!trimPadding.hasZeroOffset())
2417       return failure();
2418     // trimPadding must remove the amount of padding that was added earlier.
2419     if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2420       return failure();
2421 
2422     // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2423     rewriter.setInsertionPoint(xferOp);
2424 
2425     SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2426     auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2427         xferOp, padOp.getSource().getType(), xferOp.getVector(),
2428         padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2429         xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2430     rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2431 
2432     return success();
2433   }
2434 
2435   /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2436   /// i.e., same dimensions.
2437   ///
2438   /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2439   /// dimensions, this function tries to infer the (static) tensor size by
2440   /// looking at the defining op and utilizing op-specific knowledge.
2441   ///
2442   /// This is a conservative analysis. In case equal tensor sizes cannot be
2443   /// proven statically, this analysis returns `false` even though the tensor
2444   /// sizes may turn out to be equal at runtime.
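  ///
  /// For example (an illustrative sketch, not taken from a test), the dynamic
  /// dimension below is provably equal on both sides because both
  /// extract_slice ops use the same size value %sz:
  /// ```
  ///   %0 = tensor.extract_slice %t[0, 0] [%sz, 5] [1, 1]
  ///       : tensor<?x5xf32> to tensor<?x5xf32>
  ///   %1 = tensor.pad %0 ... : tensor<?x5xf32> to tensor<17x5xf32>
  ///   ...
  ///   %r = tensor.extract_slice %2[0, 0] [%sz, 5] [1, 1]
  ///       : tensor<17x5xf32> to tensor<?x5xf32>
  /// ```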
2445   bool hasSameTensorSize(Value beforePadding,
2446                          tensor::ExtractSliceOp afterTrimming) const {
2447     // If the input to tensor::PadOp is a CastOp, try with both CastOp
2448     // result and CastOp operand.
2449     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2450       if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2451         return true;
2452 
2453     auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2454     auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2455     // Only RankedTensorType supported.
2456     if (!t1 || !t2)
2457       return false;
2458     // Rank of both values must be the same.
2459     if (t1.getRank() != t2.getRank())
2460       return false;
2461 
2462     // All static dimensions must be the same. Mixed cases (e.g., dimension
2463     // static in `t1` but dynamic in `t2`) are not supported.
2464     for (unsigned i = 0; i < t1.getRank(); ++i) {
2465       if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2466         return false;
2467       if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2468         return false;
2469     }
2470 
2471     // Nothing more to check if all dimensions are static.
2472     if (t1.getNumDynamicDims() == 0)
2473       return true;
2474 
2475     // All dynamic sizes must be the same. The only supported case at the
2476     // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2477     // thereof).
2478 
2479     // Apart from CastOp, only ExtractSliceOp is supported.
2480     auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2481     if (!beforeSlice)
2482       return false;
2483 
2484     assert(static_cast<size_t>(t1.getRank()) ==
2485            beforeSlice.getMixedSizes().size());
2486     assert(static_cast<size_t>(t2.getRank()) ==
2487            afterTrimming.getMixedSizes().size());
2488 
2489     for (unsigned i = 0; i < t1.getRank(); ++i) {
2490       // Skip static dimensions.
2491       if (!t1.isDynamicDim(i))
2492         continue;
2493       auto size1 = beforeSlice.getMixedSizes()[i];
2494       auto size2 = afterTrimming.getMixedSizes()[i];
2495 
2496       // Case 1: Same value or same constant int.
2497       if (isEqualConstantIntOrValue(size1, size2))
2498         continue;
2499 
2500       // Other cases: Take a deeper look at defining ops of values.
2501       auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2502       auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2503       if (!v1 || !v2)
2504         return false;
2505 
2506       // Case 2: Both values are identical AffineMinOps. (Should not happen if
2507       // CSE is run.)
2508       auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2509       auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2510       if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2511           minOp1.getOperands() == minOp2.getOperands())
2512         continue;
2513 
2514       // Add additional cases as needed.
2515     }
2516 
2517     // All tests passed.
2518     return true;
2519   }
2520 };
2521 
2522 /// Returns the effective Pad value for the input op, provided it's a scalar.
2523 ///
2524 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2525 /// this Op performs padding, retrieve the padding value provided that it's
2526 /// a scalar and static/fixed for all the padded values. Returns an empty value
2527 /// otherwise.
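///
/// For example (an illustrative sketch), given
/// ```
///   %cst = arith.constant 0.0 : f32
///   %fill = linalg.fill ins(%cst : f32) outs(%dest : tensor<8x8xf32>)
///             -> tensor<8x8xf32>
///   %res = tensor.insert_slice %src into %fill ...
/// ```
/// the effective pad value for %res is %cst: the insert_slice destination is
/// traced back to the linalg.fill, whose scalar input is returned.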
2528 static Value getStaticPadVal(Operation *op) {
2529   if (!op)
2530     return {};
2531 
2532   // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value that's
2533   // being broadcast, provided that it's a scalar.
2534   if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2535     auto source = bcast.getSource();
2536     if (llvm::dyn_cast<VectorType>(source.getType()))
2537       return {};
2538 
2539     return source;
2540   }
2541 
2542   // 2. linalg.fill - use the scalar input value that used to fill the output
2543   // tensor.
2544   if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2545     return fill.getInputs()[0];
2546   }
2547 
2548   // 3. tensor.generate - we can't guarantee that the value is fixed without
2549   // further analysis, so bail out.
2550   if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2551     return {};
2552   }
2553 
2554   // 4. vector.transfer_write - inspect the vector that is being written. If
2555   // it contains a single value that has been broadcast (e.g. via
2556   // vector.broadcast), extract it; fail otherwise.
2557   if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2558     return getStaticPadVal(xferWrite.getVector().getDefiningOp());
2559 
2560   // 5. tensor.insert_slice - inspect the destination tensor. If it is larger
2561   // than the input tensor and its contents are constant (e.g. generated via
2562   // linalg.fill), extract the value used to generate it; fail otherwise.
2563   // TODO: Clarify the semantics when the input tensor is larger than the
2564   // destination.
2565   if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2566     return getStaticPadVal(slice.getDest().getDefiningOp());
2567 
2568   return {};
2569 }
2570 
2571 /// Rewrite tensor.insert_slice as a vector.transfer_read +
2572 /// vector.transfer_write pair. The vector size is inferred from the static
2573 /// dims in the input and output tensors. If a dim is dynamic in both the
2574 /// input and output tensors, the pattern bails out.
2575 ///
2576 /// Before:
2577 ///     !t_in_type = tensor<1x2x3xf32>
2578 ///     !t_out_type = tensor<9x8x7x1x2x3xf32>
2579 ///     !v_type = vector<1x2x3xf32>
2580 ///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
2581 ///     into !t_out_type
2582 /// After:
2583 ///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
2584 ///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
2585 ///
2586 /// TODO: Support masking
2587 struct InsertSliceVectorizePattern
2588     : public OpRewritePattern<tensor::InsertSliceOp> {
2589   using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
2590 
2591   LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
2592                                 PatternRewriter &rewriter) const final {
2593     auto sourceType = sliceOp.getSource().getType();
2594     if (!VectorType::isValidElementType(sourceType.getElementType()))
2595       return failure();
2596 
2597     auto resultType = sliceOp.getResultType();
2598 
2599     // 1. Get the pad value.
2600     // TransferReadOp requires a scalar padding value. Note that:
2601     //    * for in-bounds access, the value is actually irrelevant.
2602     //  There are 2 cases in which xfer.read accesses are known to be in-bounds:
2603     //  1. The source shape is static (output vector sizes would be based on
2604     //     the source shape and hence all memory accesses would be in-bounds),
2605     //  2. Masking is used (output vector sizes would be user-provided, in which
2606     //     case it is assumed that all memory accesses are in-bounds). This
2607     //     remains a TODO.
2608     //
2609     // When the value is not known and not needed, use 0. Otherwise, bail out.
2610     Value padValue = getStaticPadVal(sliceOp);
2611     bool isOutOfBoundsRead = !sourceType.hasStaticShape();
2612 
2613     if (!padValue && isOutOfBoundsRead) {
2614       LDBG("Failed to get a pad value for out-of-bounds read access\n");
2615       return failure();
2616     }
2617 
2618     if (!padValue) {
2619       auto elemType = sourceType.getElementType();
2620       padValue = rewriter.create<arith::ConstantOp>(
2621           sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2622     }
2623 
2624     // 2. Get the vector shape and in-bounds attributes
2625     SmallVector<int64_t> vecShape;
2626     SmallVector<bool> readInBounds;
2627     SmallVector<bool> writeInBounds;
2628     size_t rankDiff = resultType.getRank() - sourceType.getRank();
2629     for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2630       if (!sourceType.isDynamicDim(i)) {
2631         vecShape.push_back(sourceType.getDimSize(i));
2632         // Source shape is statically known: Neither read nor write are
2633         // out-of-bounds.
2634         readInBounds.push_back(true);
2635         writeInBounds.push_back(true);
2636       } else if (!resultType.isDynamicDim(i)) {
2637         // Source shape is not statically known, but result shape is.
2638         // Vectorize with size of result shape. This may be larger than the
2639         // source size.
2640         // FIXME: Using rankDiff implies that the source tensor is inserted at
2641         // the end of the destination tensor. However, that's not required.
2642         vecShape.push_back(resultType.getDimSize(rankDiff + i));
2643         // Read may be out-of-bounds because the result size could be larger
2644         // than the source size.
2645         readInBounds.push_back(false);
2646         // Write is in-bounds provided that the corresponding write index is 0.
2647         // To keep this logic simple, conservatively mark as out-of-bounds.
2648         writeInBounds.push_back(false);
2649       } else {
2650         // Neither the source dim nor the result dim is static. Cannot
2651         // vectorize the copy.
2652         // TODO: Add support for masking
2653         return failure();
2654       }
2655     }
2656     auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2657 
2658     // 3. Generate TransferReadOp.
2659     SmallVector<Value> readIndices(
2660         vecType.getRank(),
2661         rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
2662     auto read = rewriter.create<vector::TransferReadOp>(
2663         sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
2664         ArrayRef<bool>{readInBounds});
2665 
2666     // 4. Generate TransferWriteOp.
2667     auto writeIndices = getValueOrCreateConstantIndexOp(
2668         rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
2669 
2670     // 5. Finalize
2671     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2672         sliceOp, read, sliceOp.getDest(), writeIndices,
2673         ArrayRef<bool>{writeInBounds});
2674 
2675     return success();
2676   }
2677 };
2678 
2679 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2680 /// ```
2681 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2682 /// %r = tensor.insert_slice %0
2683 ///     into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2684 ///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2685 /// ```
2686 /// is rewritten to:
2687 /// ```
2688 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2689 ///     : tensor<?x?xf32>, vector<17x5xf32>
2690 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2691 ///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2692 /// ```
2693 ///
2694 /// This rewrite is possible if:
2695 /// - Low padding is static 0.
2696 /// - `padOp` result shape is static.
2697 /// - The entire padded tensor is inserted.
2698 ///   (Implies that sizes of `insertOp` are all static.)
2699 /// - Only unit strides in `insertOp`.
2700 /// - Single, scalar padding value.
2701 /// - `padOp` result not used as destination.
2702 struct PadOpVectorizationWithInsertSlicePattern
2703     : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2704   using VectorizePadOpUserPattern<
2705       tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2706 
2707   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2708                             tensor::InsertSliceOp insertOp) const override {
2709     // Low padding must be static 0.
2710     if (!padOp.hasZeroLowPad())
2711       return failure();
2712     // Only unit stride supported.
2713     if (!insertOp.hasUnitStride())
2714       return failure();
2715     // Pad value must be a constant.
2716     auto padValue = padOp.getConstantPaddingValue();
2717     if (!padValue)
2718       return failure();
2719     // Dynamic shapes not supported.
2720     if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2721       return failure();
2722     // Pad result not used as destination.
2723     if (insertOp.getDest() == padOp.getResult())
2724       return failure();
2725 
2726     auto vecType = VectorType::get(padOp.getType().getShape(),
2727                                    padOp.getType().getElementType());
2728     unsigned vecRank = vecType.getRank();
2729     unsigned tensorRank = insertOp.getType().getRank();
2730 
2731     // Check if sizes match: Insert the entire tensor into most minor dims.
2732     // (No permutations allowed.)
2733     SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2734     expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2735     if (!llvm::all_of(
2736             llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2737               return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2738             }))
2739       return failure();
2740 
2741     // Insert the TransferReadOp and TransferWriteOp at the position of the
2742     // InsertSliceOp.
2743     rewriter.setInsertionPoint(insertOp);
2744 
2745     // Generate TransferReadOp: Read entire source tensor and add high
2746     // padding.
2747     SmallVector<Value> readIndices(
2748         vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2749     auto read = rewriter.create<vector::TransferReadOp>(
2750         padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2751 
2752     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2753     // specified offsets. Write is fully in-bounds because an InsertSliceOp's
2754     // source must fit into the destination at the specified offsets.
2755     auto writeIndices = getValueOrCreateConstantIndexOp(
2756         rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2757     SmallVector<bool> inBounds(vecRank, true);
2758     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2759         insertOp, read, insertOp.getDest(), writeIndices,
2760         ArrayRef<bool>{inBounds});
2761 
2762     return success();
2763   }
2764 };
2765 
2766 void mlir::linalg::populateInsertSliceVectorizationPatterns(
2767     RewritePatternSet &patterns) {
2768   patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
2769 }
2770 
2771 void mlir::linalg::populatePadOpVectorizationPatterns(
2772     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2773   // TODO: The following pattern implements "decomposition" and
2774   // optional "vectorization". Separate "decomposition" into a separate
2775   // pre-processing pattern group.
2776   patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);
2777 
2778   // Try these specialized patterns first before resorting to the generic one.
2779   patterns.add<PadOpVectorizationWithTransferReadPattern,
2780                PadOpVectorizationWithTransferWritePattern,
2781                PadOpVectorizationWithInsertSlicePattern>(
2782       patterns.getContext(), baseBenefit.getBenefit() + 1);
2783 }
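
// Illustrative usage sketch (not part of this file's API): the populate
// helpers above are typically combined with a greedy rewrite driver inside a
// pass, e.g.:
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   linalg::populateInsertSliceVectorizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
//
// The exact driver and failure handling depend on the enclosing pass.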
2784 
2785 //===----------------------------------------------------------------------===//
2786 // Forwarding patterns
2787 //===----------------------------------------------------------------------===//
2788 
2789 /// Check whether there is any interleaved use of any `values` between
2790 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2791 /// is in a different block.
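///
/// For example (schematic), with `firstOp` = the memref.copy, `secondOp` =
/// the transfer_read and `values` = {%alloc, %sv}:
/// ```
///   %sv = memref.subview %alloc ...
///   memref.copy %in, %sv
///   "some.op"(%sv) : (memref<8xf32>) -> ()
///   %v = vector.transfer_read %alloc ...
/// ```
/// the use of %sv by "some.op" sits between the two ops, so this returns
/// true.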
2792 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2793                                     ValueRange values) {
2794   if (firstOp->getBlock() != secondOp->getBlock() ||
2795       !firstOp->isBeforeInBlock(secondOp)) {
2796     LDBG("interleavedUses precondition failed, firstOp: "
2797          << *firstOp << ", second op: " << *secondOp << "\n");
2798     return true;
2799   }
2800   for (auto v : values) {
2801     for (auto &u : v.getUses()) {
2802       Operation *owner = u.getOwner();
2803       if (owner == firstOp || owner == secondOp)
2804         continue;
2805       // TODO: this is too conservative, use dominance info in the future.
2806       if (owner->getBlock() == firstOp->getBlock() &&
2807           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2808         continue;
2809       LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2810                                     << ", second op: " << *secondOp << "\n");
2811       return true;
2812     }
2813   }
2814   return false;
2815 }
2816 
2817 /// Return the unique subview use of `v` if it is indeed unique, null
2818 /// otherwise.
2819 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2820   memref::SubViewOp subViewOp;
2821   for (auto &u : v.getUses()) {
2822     if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2823       if (subViewOp)
2824         return memref::SubViewOp();
2825       subViewOp = newSubViewOp;
2826     }
2827   }
2828   return subViewOp;
2829 }
2830 
2831 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2832 /// when available.
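///
/// Schematic example of the forwarding performed below (illustrative only):
/// ```
///   %sv = memref.subview %alloc ...
///   linalg.fill ins(%pad : f32) outs(%alloc : memref<...>)
///   memref.copy %in, %sv
///   %v = vector.transfer_read %alloc[...], %pad ...
/// ```
/// is rewritten so that %v is read directly from %in (with conservative, i.e.
/// all-false, in_bounds), and the linalg.fill and memref.copy are erased.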
2833 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2834     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2835 
2836   // TODO: support mask.
2837   if (xferOp.getMask())
2838     return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2839 
2840   // Transfer into `view`.
2841   Value viewOrAlloc = xferOp.getSource();
2842   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2843       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2844     return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2845 
2846   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2847   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2848   if (!subViewOp)
2849     return rewriter.notifyMatchFailure(xferOp, "no subview found");
2850   Value subView = subViewOp.getResult();
2851 
2852   // Find the copy into `subView` without interleaved uses.
2853   memref::CopyOp copyOp;
2854   for (auto &u : subView.getUses()) {
2855     if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2856       assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2857       if (newCopyOp.getTarget() != subView)
2858         continue;
2859       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2860         continue;
2861       copyOp = newCopyOp;
2862       break;
2863     }
2864   }
2865   if (!copyOp)
2866     return rewriter.notifyMatchFailure(xferOp, "no copy found");
2867 
2868   // Find the fill into `viewOrAlloc` without interleaved uses before the
2869   // copy.
2870   FillOp maybeFillOp;
2871   for (auto &u : viewOrAlloc.getUses()) {
2872     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2873       assert(isa<MemRefType>(newFillOp.output().getType()));
2874       if (newFillOp.output() != viewOrAlloc)
2875         continue;
2876       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2877         continue;
2878       maybeFillOp = newFillOp;
2879       break;
2880     }
2881   }
2882   // Ensure padding matches.
2883   if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2884     return rewriter.notifyMatchFailure(xferOp,
2885                                        "padding value does not match fill");
2886 
2887   // `in` is the buffer that memref.copy reads from. Read from it directly.
2888   Value in = copyOp.getSource();
2889 
2890   // memref.copy + linalg.fill can be used to create a padded local buffer.
2891   // The `masked` attribute is only valid on this padded buffer.
2892   // When forwarding to vector.transfer_read, the attribute must be reset
2893   // conservatively.
2894   auto vectorType = xferOp.getVectorType();
2895   Value res = rewriter.create<vector::TransferReadOp>(
2896       xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
2897       xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2898       rewriter.getBoolArrayAttr(
2899           SmallVector<bool>(vectorType.getRank(), false)));
2900 
2901   if (maybeFillOp)
2902     rewriter.eraseOp(maybeFillOp);
2903   rewriter.eraseOp(copyOp);
2904   rewriter.replaceOp(xferOp, res);
2905 
2906   return success();
2907 }
2908 
2909 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2910 /// when available.
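///
/// Schematic example of the forwarding performed below (illustrative only):
/// ```
///   %sv = memref.subview %alloc ...
///   vector.transfer_write %v, %alloc[...] ...
///   memref.copy %sv, %out
/// ```
/// is rewritten so that %v is written directly into %out (with conservative,
/// i.e. all-false, in_bounds), and the original write and the memref.copy are
/// erased.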
2911 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2912     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2913   // TODO: support mask.
2914   if (xferOp.getMask())
2915     return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2916 
2917   // Transfer into `viewOrAlloc`.
2918   Value viewOrAlloc = xferOp.getSource();
2919   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2920       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2921     return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2922 
2923   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2924   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2925   if (!subViewOp)
2926     return rewriter.notifyMatchFailure(xferOp, "no subview found");
2927   Value subView = subViewOp.getResult();
2928 
2929   // Find the copy from `subView` without interleaved uses.
2930   memref::CopyOp copyOp;
2931   for (auto &u : subViewOp.getResult().getUses()) {
2932     if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2933       if (newCopyOp.getSource() != subView)
2934         continue;
2935       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2936         continue;
2937       copyOp = newCopyOp;
2938       break;
2939     }
2940   }
2941   if (!copyOp)
2942     return rewriter.notifyMatchFailure(xferOp, "no copy found");
2943 
2944   // `out` is the buffer that memref.copy writes into. Write to it directly.
2945   assert(isa<MemRefType>(copyOp.getTarget().getType()));
2946   Value out = copyOp.getTarget();
2947 
2948   // Forward vector.transfer into copy.
2949   // memref.copy + linalg.fill can be used to create a padded local buffer.
2950   // The `masked` attribute is only valid on this padded buffer.
2951   // When forwarding to vector.transfer_write, the attribute must be reset
2952   // conservatively.
2953   auto vector = xferOp.getVector();
2954   rewriter.create<vector::TransferWriteOp>(
2955       xferOp.getLoc(), vector, out, xferOp.getIndices(),
2956       xferOp.getPermutationMapAttr(), xferOp.getMask(),
2957       rewriter.getBoolArrayAttr(
2958           SmallVector<bool>(vector.getType().getRank(), false)));
2959 
2960   rewriter.eraseOp(copyOp);
2961   rewriter.eraseOp(xferOp);
2962 
2963   return success();
2964 }
2965 
2966 //===----------------------------------------------------------------------===//
2967 // Convolution vectorization patterns
2968 //===----------------------------------------------------------------------===//
2969 
2970 template <int N>
2971 static void bindShapeDims(ShapedType shapedType) {}
2972 
2973 template <int N, typename IntTy, typename... IntTy2>
2974 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2975   val = shapedType.getShape()[N];
2976   bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2977 }
2978 
2979 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
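///
/// For example (illustrative):
/// ```
///   int64_t n, w, c;
///   bindShapeDims(shapedType, n, w, c); // shape [4, 8, 16] -> n=4, w=8, c=16
/// ```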
2980 template <typename... IntTy>
2981 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2982   bindShapeDims<0>(shapedType, vals...);
2983 }
2984 
2985 namespace {
2986 bool isCastOfBlockArgument(Operation *op) {
2987   return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2988          isa<BlockArgument>(op->getOperand(0));
2989 }
2990 
2991 bool isSupportedPoolKind(vector::CombiningKind kind) {
2992   switch (kind) {
2993   case vector::CombiningKind::ADD:
2994   case vector::CombiningKind::MAXNUMF:
2995   case vector::CombiningKind::MAXIMUMF:
2996   case vector::CombiningKind::MAXSI:
2997   case vector::CombiningKind::MAXUI:
2998   case vector::CombiningKind::MINNUMF:
2999   case vector::CombiningKind::MINIMUMF:
3000   case vector::CombiningKind::MINSI:
3001   case vector::CombiningKind::MINUI:
3002     return true;
3003   default:
3004     return false;
3005   }
3006 }
3007 
3008 /// Generate a vector implementation for either:
3009 /// ```
3010 ///   Op def: (     w,     kw  )
3011 ///    Iters: ({Par(), Red()})
3012 ///   Layout: {{w + kw}, {kw}, {w}}
3013 /// ```
3014 /// kw is unrolled.
3015 ///
3016 /// or
3017 ///
3018 /// ```
3019 ///   Op def: (     n,     w,     c,    kw,    f  )
3020 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
3021 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3022 /// ```
3023 /// kw is unrolled, w is unrolled iff dilationW > 1.
3024 ///
3025 /// or
3026 ///
3027 /// ```
3028 ///   Op def: (     n,     c,     w,    f,    kw )
3029 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
3030 ///   Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3031 /// ```
3032 /// kw is unrolled, w is unrolled iff dilationW > 1.
3033 ///
3034 /// or
3035 ///
3036 /// ```
3037 ///   Op def: (     n,     w,     c,    kw )
3038 ///    Iters: ({Par(), Par(), Par(), Red()})
3039 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3040 /// ```
3041 /// kw is unrolled, w is unrolled iff dilationW > 1.
3042 struct Conv1DGenerator
3043     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3044   Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
3045                   int dilationW)
3046       : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
3047         strideW(strideW), dilationW(dilationW) {
3048     // Determine whether `linalgOp` can be vectorized by this generator.
3049     if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
3050       return;
3051     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3052     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3053     resShaped = linalgOp.getDpsInitOperand(0)->get();
3054     lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3055     rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3056     resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3057     if (!lhsShapedType || !rhsShapedType || !resShapedType)
3058       return;
3059     // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
3060     // (non-channeled convolution -> LHS and RHS both have single dimensions).
3061     if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
3062         (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
3063       return;
3064 
3065     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3066     if (!reduceOp)
3067       return;
3068     redOp = reduceOp->getName().getIdentifier();
3069 
3070     if (!setOperKind(reduceOp))
3071       return;
3072     auto maybeKind = getCombinerOpKind(reduceOp);
3073     // Typically a convolution will have an `Add` CombiningKind, but for i1 it
3074     // can get strength-reduced to `OR`, which is also supported. This strength
3075     // reduction logic is in the `buildBinaryFn` helper in the Linalg dialect.
3076     if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
3077                         *maybeKind != vector::CombiningKind::OR) &&
3078                        (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
3079       return;
3080     }
3081     reductionKind = maybeKind.value();
3082 
3083     auto rhsRank = rhsShapedType.getRank();
3084     switch (oper) {
3085     case Conv:
3086       if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
3087         return;
3088       break;
3089     case Pool:
3090       if (rhsRank != 1)
3091         return;
3092       break;
3093     }
3094     // The op is now known to be valid.
3095     valid = true;
3096   }
3097 
3098   /// Generate a vector implementation for:
3099   /// ```
3100   ///   Op def: (     w,     kw  )
3101   ///    Iters: ({Par(), Red()})
3102   ///   Layout: {{w + kw}, {kw}, {w}}
3103   /// ```
3104   /// kw is always unrolled.
3105   ///
3106   /// or
3107   ///
3108   /// ```
3109   ///   Op def: (     n,     w,     c,    kw,    f  )
3110   ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
3111   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3112   /// ```
3113   /// kw is always unrolled.
3114   /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW)
3115   /// is > 1.
3116   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3117     if (!valid)
3118       return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
3119 
3120     int64_t nSize, wSize, cSize, kwSize, fSize;
3121     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3122     bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3123     switch (conv1DOpOrder) {
3124     case Conv1DOpOrder::W:
3125       // Initialize unused dimensions
3126       nSize = fSize = cSize = 0;
3127       // out{W}
3128       bindShapeDims(resShapedType, wSize);
3129       // kernel{kw}
3130       bindShapeDims(rhsShapedType, kwSize);
3131       lhsShape = {// iw = ow + kw - 1
3132                   //   (i.e. 16 convolved with 3 -> 14)
3133                   (wSize + kwSize - 1)};
3134       rhsShape = {kwSize};
3135       resShape = {wSize};
3136       break;
3137     case Conv1DOpOrder::Nwc:
3138       // out{n, w, f}
3139       bindShapeDims(resShapedType, nSize, wSize, fSize);
3140       switch (oper) {
3141       case Conv:
3142         // kernel{kw, c, f}
3143         bindShapeDims(rhsShapedType, kwSize, cSize);
3144         break;
3145       case Pool:
3146         // kernel{kw}
3147         bindShapeDims(rhsShapedType, kwSize);
3148         cSize = fSize;
3149         break;
3150       }
3151       lhsShape = {nSize,
3152                   // iw = ow * sw + kw *  dw - 1
3153                   //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3154                   // Perform the proper inclusive -> exclusive -> inclusive.
3155                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3156                       1,
3157                   cSize};
3158       switch (oper) {
3159       case Conv:
3160         rhsShape = {kwSize, cSize, fSize};
3161         break;
3162       case Pool:
3163         rhsShape = {kwSize};
3164         break;
3165       }
3166       resShape = {nSize, wSize, fSize};
3167       break;
3168     case Conv1DOpOrder::Ncw:
3169       // out{n, f, w}
3170       bindShapeDims(resShapedType, nSize, fSize, wSize);
3171       switch (oper) {
3172       case Conv:
3173         // kernel{f, c, kw}
3174         bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3175         break;
3176       case Pool:
3177         // kernel{kw}
3178         bindShapeDims(rhsShapedType, kwSize);
3179         cSize = fSize;
3180         break;
3181       }
3182       lhsShape = {nSize, cSize,
3183                   // iw = ow * sw + kw *  dw - 1
3184                   //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3185                   // Perform the proper inclusive -> exclusive -> inclusive.
3186                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3187                       1};
3188       switch (oper) {
3189       case Conv:
3190         rhsShape = {fSize, cSize, kwSize};
3191         break;
3192       case Pool:
3193         rhsShape = {kwSize};
3194         break;
3195       }
3196       resShape = {nSize, fSize, wSize};
3197       break;
3198     }
3199 
3200     vector::TransferWriteOp write;
3201     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3202 
3203     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3204     // When strideW == 1, we can batch the contiguous loads and avoid
3205     // unrolling
3206     int64_t wSizeStep = strideW == 1 ? wSize : 1;
3207 
3208     Type lhsEltType = lhsShapedType.getElementType();
3209     Type rhsEltType = rhsShapedType.getElementType();
3210     Type resEltType = resShapedType.getElementType();
3211     auto lhsType = VectorType::get(lhsShape, lhsEltType);
3212     auto rhsType = VectorType::get(rhsShape, rhsEltType);
3213     auto resType = VectorType::get(resShape, resEltType);
3214     // Zero padding with the corresponding dimensions for lhs, rhs and res.
3215     SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3216     SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3217     SmallVector<Value> resPadding(resShape.size(), zero);
3218 
3219     // Read the whole lhs, rhs and res in one shot (with zero padding).
3220     Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3221                                                         lhsPadding);
3222     // This is needed only for Conv.
3223     Value rhs = nullptr;
3224     if (oper == Conv)
3225       rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3226                                                     rhsPadding);
3227     Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3228                                                         resPadding);
3229 
3230     // The base vectorization case for channeled convolution is input:
3231     // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3232     // vectorization case, we do pre transpose on input, weight, and output.
3233     switch (conv1DOpOrder) {
3234     case Conv1DOpOrder::W:
3235     case Conv1DOpOrder::Nwc:
3236       // Base case, so no transposes necessary.
3237       break;
3238     case Conv1DOpOrder::Ncw: {
3239       // To match base vectorization case, we pre-transpose current case.
3240       // ncw -> nwc
3241       static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3242       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3243       // fcw -> wcf
3244       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3245 
3246       // This is needed only for Conv.
3247       if (oper == Conv)
3248         rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3249       // nfw -> nwf
3250       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3251       res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3252       break;
3253     }
3254     }
3255 
3256     //===------------------------------------------------------------------===//
3257     // Begin vector-only rewrite part
3258     //===------------------------------------------------------------------===//
3259     // Unroll along kw and read slices of lhs and rhs.
3260     SmallVector<Value> lhsVals, rhsVals, resVals;
3261     lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3262                                      kwSize, strideW, dilationW, wSizeStep,
3263                                      isSingleChanneled);
3264     // Do not do for pooling.
3265     if (oper == Conv)
3266       rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3267     resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3268                                       wSizeStep, isSingleChanneled);
3269 
3270     auto linearIndex = [&](int64_t kw, int64_t w) {
3271       return kw * (wSize / wSizeStep) + w;
3272     };
3273 
3274     // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3275     // or perform outerproduct for non-channeled convolution or perform simple
3276     // arith operation for pooling
3277     for (int64_t kw = 0; kw < kwSize; ++kw) {
3278       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3279         switch (oper) {
3280         case Conv:
3281           if (isSingleChanneled) {
3282             resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3283                                                    lhsVals[linearIndex(kw, w)],
3284                                                    rhsVals[kw], resVals[w]);
3285           } else {
3286             resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3287                                                   lhsVals[linearIndex(kw, w)],
3288                                                   rhsVals[kw], resVals[w]);
3289           }
3290           break;
3291         case Pool:
3292           resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3293                                    resVals[w]);
3294           break;
3295         }
3296       }
3297     }
3298 
3299     res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3300                                  isSingleChanneled);
3301     //===------------------------------------------------------------------===//
3302     // End vector-only rewrite part
3303     //===------------------------------------------------------------------===//
3304 
3305     // The base vectorization case for channeled convolution is output:
3306     // {n,w,f} To reuse the result from base pattern vectorization case, we
3307     // post transpose the base case result.
3308     switch (conv1DOpOrder) {
3309     case Conv1DOpOrder::W:
3310     case Conv1DOpOrder::Nwc:
3311       // Base case, so no transposes necessary.
3312       break;
3313     case Conv1DOpOrder::Ncw: {
3314       // nwf -> nfw
3315       static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3316       res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3317       break;
3318     }
3319     }
3320 
3321     return rewriter
3322         .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3323         .getOperation();
3324   }
3325 
3326   // Take a value and widen to have the same element type as `ty`.
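  // For example (illustrative): an i8 vector promoted to match an i32 vector
  // type gets an arith.extsi; an i32 vector promoted to an f32 vector type
  // gets an arith.sitofp; matching element types return the value unchanged.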
3327   Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3328     const Type srcElementType = getElementTypeOrSelf(val.getType());
3329     const Type dstElementType = getElementTypeOrSelf(ty);
3330     assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3331     if (srcElementType == dstElementType)
3332       return val;
3333 
3334     const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3335     const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3336     const Type dstType =
3337         cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3338 
3339     if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3340       return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3341     }
3342 
3343     if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3344         srcWidth < dstWidth)
3345       return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3346 
3347     if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3348         srcWidth < dstWidth)
3349       return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3350 
3351     assert(false && "unhandled promotion case");
3352     return nullptr;
3353   }
3354 
3355   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
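  // For example (illustrative), for lhs = vector<1x4x8xf32>,
  // rhs = vector<8x16xf32> and res = vector<1x4x16xf32>, this emits a
  // vector.contract over dims (n, w, f, c) with indexing maps
  // {(n, w, c), (c, f), (n, w, f)} and iterator types
  // [parallel, parallel, parallel, reduction].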
3356   Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3357                                  Value lhs, Value rhs, Value res) {
3358     vector::IteratorType par = vector::IteratorType::parallel;
3359     vector::IteratorType red = vector::IteratorType::reduction;
3360     AffineExpr n, w, f, c;
3361     bindDims(ctx, n, w, f, c);
3362     lhs = promote(rewriter, loc, lhs, res.getType());
3363     rhs = promote(rewriter, loc, rhs, res.getType());
3364     auto contractionOp = rewriter.create<vector::ContractionOp>(
3365         loc, lhs, rhs, res,
3366         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3367         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3368     contractionOp.setKind(reductionKind);
3369     return contractionOp;
3370   }
3371 
3372   // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3373   // convolution.
3374   Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3375                                   Value lhs, Value rhs, Value res) {
3376     return rewriter.create<vector::OuterProductOp>(
3377         loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3378   }
3379 
3380   // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3381   Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3382                     Value res) {
3383     if (isPoolExt)
3384       lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3385     return rewriter
3386         .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3387         ->getResult(0);
3388   }
3389 
3390   /// Generate a vector implementation for:
3391   /// ```
3392   ///   Op def: (     n,     w,     c,    kw)
3393   ///    Iters: ({Par(), Par(), Par(), Red()})
3394   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3395   /// ```
3396   /// kw is always unrolled.
3397   /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW)
3398   /// is > 1.
3399   FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3400                                        bool channelDimScalableFlag,
3401                                        bool flatten) {
3402     if (!valid)
3403       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3404 
3405     bool scalableChDim = false;
3406     bool useMasking = false;
3407     int64_t nSize, wSize, cSize, kwSize;
3408     // kernel{kw, c}
3409     bindShapeDims(rhsShapedType, kwSize, cSize);
3410     if (ShapedType::isDynamic(cSize)) {
3411       assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3412       cSize = channelDimVecSize;
3413       // Scalable vectors are only used when both conditions are met:
3414       //  1. channel dim is dynamic
3415       //  2. channelDimScalableFlag is set
3416       scalableChDim = channelDimScalableFlag;
3417       useMasking = true;
3418     }
3419 
3420     assert(!(useMasking && flatten) &&
3421            "Unsupported flattened conv with dynamic shapes");
3422 
3423     // out{n, w, c}
3424     bindShapeDims(resShapedType, nSize, wSize);
3425 
3426     vector::TransferWriteOp write;
3427     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3428 
3429     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3430     // When strideW == 1, we can batch the contiguous loads and avoid
3431     // unrolling
3432     int64_t wSizeStep = strideW == 1 ? wSize : 1;
3433 
3434     Type lhsEltType = lhsShapedType.getElementType();
3435     Type rhsEltType = rhsShapedType.getElementType();
3436     Type resEltType = resShapedType.getElementType();
3437     VectorType lhsType = VectorType::get(
3438         {nSize,
3439          // iw = ow * sw + kw *  dw - 1
3440          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3441          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3442          cSize},
3443         lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3444     VectorType rhsType =
3445         VectorType::get({kwSize, cSize}, rhsEltType,
3446                         /*scalableDims=*/{false, scalableChDim});
3447     VectorType resType =
3448         VectorType::get({nSize, wSize, cSize}, resEltType,
3449                         /*scalableDims=*/{false, false, scalableChDim});
3450 
3451     // Masks the input xfer Op along the channel dim, iff masking is required
3452     // (i.e. the channel dim is dynamic).
3453     auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3454                                ArrayRef<bool> scalableDims,
3455                                Operation *opToMask) {
3456       if (!useMasking)
3457         return opToMask;
3458       auto maskType =
3459           VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3460 
3461       SmallVector<bool> inBounds(maskShape.size(), true);
3462       auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3463       xferOp->setAttr(xferOp.getInBoundsAttrName(),
3464                       rewriter.getBoolArrayAttr(inBounds));
3465 
3466       SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3467           cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3468 
3469       Value maskOp =
3470           rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3471 
3472       return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3473     };
3474 
3475     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3476     // 0].
3477     Value lhs = rewriter.create<vector::TransferReadOp>(
3478         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3479     auto maybeMaskedLhs = maybeMaskXferOp(
3480         lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3481 
3482     // Read rhs slice of size {kw, c} @ [0, 0].
3483     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3484                                                         ValueRange{zero, zero});
3485     auto maybeMaskedRhs = maybeMaskXferOp(
3486         rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3487 
3488     // Read res slice of size {n, w, c} @ [0, 0, 0].
3489     Value res = rewriter.create<vector::TransferReadOp>(
3490         loc, resType, resShaped, ValueRange{zero, zero, zero});
3491     auto maybeMaskedRes = maybeMaskXferOp(
3492         resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3493 
3494     //===------------------------------------------------------------------===//
3495     // Begin vector-only rewrite part
3496     //===------------------------------------------------------------------===//
3497     // Unroll along kw and read slices of lhs and rhs.
3498     SmallVector<Value> lhsVals, rhsVals, resVals;
3499     auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3500     auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3501 
3502     // Extract lhs slice of size {n, wSizeStep, c}
3503     //   @ [0, sw * w + dw * kw, 0].
3504     for (int64_t kw = 0; kw < kwSize; ++kw) {
3505       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3506         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3507             loc, maybeMaskedLhs->getResult(0),
3508             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3509             inOutSliceSizes, inOutStrides));
3510       }
3511     }
3512     // Extract rhs slice of size {c} @ [kw].
3513     for (int64_t kw = 0; kw < kwSize; ++kw) {
3514       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3515           loc, maybeMaskedRhs->getResult(0),
3516           /*offsets=*/ArrayRef<int64_t>{kw}));
3517     }
3518     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3519     for (int64_t w = 0; w < wSize; w += wSizeStep) {
3520       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3521           loc, maybeMaskedRes->getResult(0),
3522           /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3523           inOutStrides));
3524     }
3525 
3526     auto linearIndex = [&](int64_t kw, int64_t w) {
3527       return kw * (wSize / wSizeStep) + w;
3528     };
3529 
3530     // Note - the scalable flags are ignored as flattening combined with
3531     // scalable vectorization is not supported.
3532     auto inOutFlattenSliceSizes =
3533         SmallVector<int64_t>{nSize, wSizeStep * cSize};
3534     auto lhsTypeAfterFlattening =
3535         VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3536     auto resTypeAfterFlattening =
3537         VectorType::get(inOutFlattenSliceSizes, resEltType);
3538 
3539     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3540     for (int64_t kw = 0; kw < kwSize; ++kw) {
3541       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3542         Value lhsVal = lhsVals[linearIndex(kw, w)];
3543         Value resVal = resVals[w];
3544         if (flatten) {
3545           // Flatten the input and output vectors (collapse the channel
3546           // dimension)
3547           lhsVal = rewriter.create<vector::ShapeCastOp>(
3548               loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3549           resVal = rewriter.create<vector::ShapeCastOp>(
3550               loc, resTypeAfterFlattening, resVals[w]);
3551         }
3552         resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3553                                                   rhsVals[kw], resVal, flatten);
3554         if (flatten) {
3555           // Un-flatten the output vector (restore the channel dimension)
3556           resVals[w] = rewriter.create<vector::ShapeCastOp>(
3557               loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3558         }
3559       }
3560     }
3561 
3562     // It's possible we failed to create the FMA.
3563     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3564       // Manually revert (in reverse order) to avoid leaving a bad IR state.
3565       for (auto &collection :
3566            {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3567         for (Value v : collection)
3568           rewriter.eraseOp(v.getDefiningOp());
3569       return rewriter.notifyMatchFailure(op, "failed to create FMA");
3570     }
3571 
3572     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3573     // This does not depend on kw.
3574     for (int64_t w = 0; w < wSize; w += wSizeStep) {
3575       maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3576           loc, resVals[w], maybeMaskedRes->getResult(0),
3577           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3578           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3579     }
3580     //===------------------------------------------------------------------===//
3581     // End vector-only rewrite part
3582     //===------------------------------------------------------------------===//
3583 
3584     // Write back res slice of size {n, w, c} @ [0, 0, 0].
3585     Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3586         loc, maybeMaskedRes->getResult(0), resShaped,
3587         ValueRange{zero, zero, zero});
3588     return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3589                            resOut);
3590   }
3591 
3592   /// Lower:
3593   ///   *  lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3594   ///   *  lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3595   /// to MulAcc.
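  ///
  /// For example (illustrative), with w = 3, c = 2 and flatten = true, the
  /// filter <f0, f1> is shuffled into <f0, f1, f0, f1, f0, f1>, broadcast to
  /// the (flattened) output shape, and then combined with the input via
  /// vector.fma for floats or arith.muli + arith.addi for integers.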
3596   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3597                                      Value lhs, Value rhs, Value res,
3598                                      bool flatten) {
3599     auto rhsTy = cast<ShapedType>(rhs.getType());
3600     auto resTy = cast<ShapedType>(res.getType());
3601 
3602     // TODO(suderman): Change this to use a vector.ima intrinsic.
3603     lhs = promote(rewriter, loc, lhs, resTy);
3604 
3605     if (flatten) {
3606       // NOTE: The following logic won't work for scalable vectors. For this
3607       // reason, "flattening" is not supported when shapes are dynamic (this
3608       // should be captured by one of the pre-conditions).
3609 
3610       // There are two options for handling the filter:
3611       //  * shape_cast(broadcast(filter))
3612       //  * broadcast(shuffle(filter))
3613       // Opt for the option without shape_cast to simplify the codegen.
3614       auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3615       auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3616 
3617       SmallVector<int64_t, 16> indices;
3618       for (int i = 0; i < resSize / rhsSize; ++i) {
3619         for (int j = 0; j < rhsSize; ++j)
3620           indices.push_back(j);
3621       }
3622 
3623       rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3624     }
3625     // Broadcast the filter to match the output vector
3626     rhs = rewriter.create<vector::BroadcastOp>(
3627         loc, resTy.clone(rhsTy.getElementType()), rhs);
3628 
3629     rhs = promote(rewriter, loc, rhs, resTy);
3630 
3631     if (!lhs || !rhs)
3632       return nullptr;
3633 
3634     if (isa<FloatType>(resTy.getElementType()))
3635       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3636 
3637     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3638     return rewriter.create<arith::AddIOp>(loc, mul, res);
3639   }
3640 
3641   /// Entry point for non-channeled convolution:
3642   ///   {{w + kw}, {kw}, {w}}
3643   FailureOr<Operation *> generateNonChanneledConv() {
3644     AffineExpr w, kw;
3645     bindDims(ctx, w, kw);
3646     if (!iters({Par(), Red()}))
3647       return rewriter.notifyMatchFailure(op,
3648                                          "failed to match conv::W 1-par 1-red");
3649 
3650     // No transposition needed.
3651     if (layout({/*lhsIndex*/ {w + kw},
3652                 /*rhsIndex*/ {kw},
3653                 /*resIndex*/ {w}}))
3654       return conv(Conv1DOpOrder::W);
3655 
3656     return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3657   }
3658 
3659   /// Entry point that transposes into the common form:
3660   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3661   FailureOr<Operation *> generateNwcConv() {
3662     AffineExpr n, w, f, kw, c;
3663     bindDims(ctx, n, w, f, kw, c);
3664     if (!iters({Par(), Par(), Par(), Red(), Red()}))
3665       return rewriter.notifyMatchFailure(
3666           op, "failed to match conv::Nwc 3-par 2-red");
3667 
3668     // No transposition needed.
3669     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3670                 /*rhsIndex*/ {kw, c, f},
3671                 /*resIndex*/ {n, w, f}}))
3672       return conv(Conv1DOpOrder::Nwc);
3673 
3674     return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3675   }
3676 
3677   /// Entry point that transposes into the common form:
3678   ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3679   FailureOr<Operation *> generateNcwConv() {
3680     AffineExpr n, w, f, kw, c;
3681     bindDims(ctx, n, f, w, c, kw);
3682     if (!iters({Par(), Par(), Par(), Red(), Red()}))
3683       return rewriter.notifyMatchFailure(
3684           op, "failed to match conv::Ncw 3-par 2-red");
3685 
3686     if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3687                 /*rhsIndex*/ {f, c, kw},
3688                 /*resIndex*/ {n, f, w}}))
3689       return conv(Conv1DOpOrder::Ncw);
3690 
3691     return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3692   }
3693 
3694   /// Entry point that transposes into the common form:
3695   ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3696   FailureOr<Operation *> generateNwcPooling() {
3697     AffineExpr n, w, c, kw;
3698     bindDims(ctx, n, w, c, kw);
3699     if (!iters({Par(), Par(), Par(), Red()}))
3700       return rewriter.notifyMatchFailure(op,
3701                                          "failed to match pooling 3-par 1-red");
3702 
3703     // No transposition needed.
3704     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3705                 /*rhsIndex*/ {kw},
3706                 /*resIndex*/ {n, w, c}}))
3707       return conv(Conv1DOpOrder::Nwc);
3708 
3709     return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3710   }
3711 
3712   /// Entry point that transposes into the common form:
3713   ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
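  /// Matched by, e.g., `linalg.pooling_ncw_sum` and `linalg.pooling_ncw_max`.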
3714   FailureOr<Operation *> generateNcwPooling() {
3715     AffineExpr n, w, c, kw;
3716     bindDims(ctx, n, c, w, kw);
3717     if (!iters({Par(), Par(), Par(), Red()}))
3718       return rewriter.notifyMatchFailure(op,
3719                                          "failed to match pooling 3-par 1-red");
3720 
3721     if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3722                 /*rhsIndex*/ {kw},
3723                 /*resIndex*/ {n, c, w}}))
3724       return conv(Conv1DOpOrder::Ncw);
3725 
3726     return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3727   }
3728 
3729   /// Entry point that transposes into the common form:
3730   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
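  /// Matched by, e.g., `linalg.depthwise_conv_1d_nwc_wc`.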
3731   FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3732                                              bool vecChDimScalableFlag = false,
3733                                              bool flatten = false) {
3734     AffineExpr n, w, c, kw;
3735     bindDims(ctx, n, w, c, kw);
3736     if (!iters({Par(), Par(), Par(), Red()}))
3737       return rewriter.notifyMatchFailure(
3738           op, "failed to match depthwise::Nwc conv 3-par 1-red");
3739 
3740     // No transposition needed.
3741     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3742                 /*rhsIndex*/ {kw, c},
3743                 /*resIndex*/ {n, w, c}}))
3744       return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3745 
3746     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3747   }
3748 
3749 private:
3750   enum OperKind { Conv, Pool };
3751   bool valid = false;
3752   OperKind oper = Conv;
3753   StringAttr redOp;
3754   StringAttr poolExtOp;
3755   bool isPoolExt = false;
3756   int strideW, dilationW;
3757   Value lhsShaped, rhsShaped, resShaped;
3758   ShapedType lhsShapedType, rhsShapedType, resShapedType;
3759   vector::CombiningKind reductionKind;
3760 
3761   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops. Returns
3762   // true iff `reduceOp` belongs to a valid conv/pooling op:
3763   //  * Pooling: the region has 2 ops (reduction + yield) or 3 ops (extension
3764   //    + reduction + yield) and the rhs (filter) is not used. The reduction
3765   //    may have at most one `ext` predecessor, whose operands must be block
3766   //    arguments or extensions of block arguments.
3767   //  * Convolution: the reduction must have a single `mul` (or `and` for i1)
3768   //    predecessor whose operands are block arguments or extensions thereof.
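  // Illustrative region bodies (f32, sum-pooling shown; other reductions are
  // analogous):
  //   Convolution:  %0 = arith.mulf %in, %filter : f32
  //                 %1 = arith.addf %out, %0 : f32
  //                 linalg.yield %1 : f32
  //   Pooling:      %0 = arith.addf %out, %in : f32
  //                 linalg.yield %0 : f32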
3769   bool setOperKind(Operation *reduceOp) {
3770     int numBlockArguments =
3771         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3772     switch (numBlockArguments) {
3773     case 1: {
3774       // This is a convolution if the feeder is a MulOp. A strength-reduced
3775       // version of MulOp for the i1 type is AndOp, which is also supported.
3776       // Otherwise, it can be a pooling. This strength-reduction logic lives
3777       // in the `buildBinaryFn` helper in the Linalg dialect.
3778       auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3779                                          llvm::IsaPred<BlockArgument>);
3780       Operation *feedOp = (*feedValIt).getDefiningOp();
3781       if (isCastOfBlockArgument(feedOp)) {
3782         oper = Pool;
3783         isPoolExt = true;
3784         poolExtOp = feedOp->getName().getIdentifier();
3785       } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
3786                     (isa<arith::AndIOp>(feedOp) &&
3787                      feedOp->getResultTypes()[0].isInteger(1))) &&
3788                    llvm::all_of(feedOp->getOperands(), [](Value v) {
3789                      if (isa<BlockArgument>(v))
3790                        return true;
3791                      if (Operation *op = v.getDefiningOp())
3792                        return isCastOfBlockArgument(op);
3793                      return false;
3794                    }))) {
3795         return false;
3796       }
3797       return true;
3798     }
3799     case 2:
3800       // Must be a pooling op.
3801       oper = Pool;
3802       isPoolExt = false;
3803       return true;
3804     default:
3805       return false;
3806     }
3807   }
3808 };
3809 } // namespace
3810 
3811 /// Helper function to vectorize a LinalgOp with convolution semantics.
3812 // TODO: extend the generic vectorization to support windows and drop this.
3813 static FailureOr<Operation *> vectorizeConvolution(
3814     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3815     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3816   // The ConvolutionOpInterface guarantees that strides/dilations exist.
3817   // However, we do not need to rely on that: simply use them if present,
3818   // otherwise fall back to the default (1) and let the generic conv matcher
3819   // in the Conv1DGenerator succeed or fail.
3820   auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3821   auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3822   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3823   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3824   Conv1DGenerator e(rewriter, op, stride, dilation);
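  // Try each supported 1-D conv/pooling form in turn; the first generator
  // whose iterator kinds and indexing maps match emits the vectorized op.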
3825   auto res = e.generateNonChanneledConv();
3826   if (succeeded(res))
3827     return res;
3828   res = e.generateNwcConv();
3829   if (succeeded(res))
3830     return res;
3831   res = e.generateNcwConv();
3832   if (succeeded(res))
3833     return res;
3834   res = e.generateNwcPooling();
3835   if (succeeded(res))
3836     return res;
3837   res = e.generateNcwPooling();
3838   if (succeeded(res))
3839     return res;
3840 
3841   // Only depthwise 1D NWC convs are left; these can be vectorized using
3842   // masks and scalable vectors. Note that, at the moment, the only dim that
3843   // can be dynamic (i.e. masked/scalable) is the channel (trailing) dim.
3844   uint64_t vecChDimSize = ShapedType::kDynamic;
3845   bool vecChDimScalableFlag = false;
3846   if (!inputVecSizes.empty()) {
3847     // Only use the input vector size corresponding to the channel dim. Other
3848     // vector dims will be inferred from the Ops.
3849     assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3850             isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3851            "Not a 1D depthwise conv!");
3852     size_t chDimIdx =
3853         TypeSwitch<Operation *, size_t>(op)
3854             .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3855             .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
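    // For example (hypothetical sizes): for an NWC depthwise conv with
    // inputVecSizes = [1, 8, 4], chDimIdx == 2 selects the channel size 4.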
3856 
3857     vecChDimSize = inputVecSizes[chDimIdx];
3858     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3859   }
3860   return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3861                                flatten1DDepthwiseConv);
3862 }
3863 
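/// Rewrite pattern that forwards to `vectorizeConvolution` and, on success,
/// replaces the original linalg op with the vectorized result (or erases it
/// when the vectorized op has no results).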
3864 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
3865   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3866 
3867   LogicalResult matchAndRewrite(LinalgOp op,
3868                                 PatternRewriter &rewriter) const override {
3869     FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3870     if (failed(resultOrFail))
3871       return failure();
3872     Operation *newOp = *resultOrFail;
3873     if (newOp->getNumResults() == 0) {
3874       rewriter.eraseOp(op.getOperation());
3875       return success();
3876     }
3877     assert(newOp->getNumResults() == 1 && "expected single result");
3878     rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3879     return success();
3880   }
3881 };
3882 
3883 void mlir::linalg::populateConvolutionVectorizationPatterns(
3884     RewritePatternSet &patterns, PatternBenefit benefit) {
3885   patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3886 }
3887