xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (revision 9cbc1f29cabc01c02a523c11d098c00650f6955c)
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "mlir/Dialect/Affine/Utils.h"
13 
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tensor/Utils/Utils.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypeInterfaces.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Support/LLVM.h"
35 #include "mlir/Transforms/RegionUtils.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 #include <type_traits>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 
55 /// Try to vectorize `convOp` as a convolution.
56 static FailureOr<Operation *>
57 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58                      ArrayRef<int64_t> inputVecSizes = {},
59                      ArrayRef<bool> inputVecScalableFlags = {},
60                      bool flatten1DDepthwiseConv = false);
61 
62 /// Return the unique instance of OpType in `block` if it is indeed unique.
63 /// Return null if zero or more than one instance exists.
64 template <typename OpType>
65 static OpType getSingleOpOfType(Block &block) {
66   OpType res;
67   block.walk([&](OpType op) {
68     if (res) {
69       res = nullptr;
70       return WalkResult::interrupt();
71     }
72     res = op;
73     return WalkResult::advance();
74   });
75   return res;
76 }
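// A brief usage sketch (hypothetical caller, not taken from elsewhere in this
// file): looking for a lone combiner op in a block.
//
//   if (auto addOp = getSingleOpOfType<arith::AddFOp>(block)) {
//     // Exactly one arith.addf exists in `block`.
//   }
//
// The helper returns a null op both when no instance and when several
// instances are present, so a single null check covers both cases.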
77 
78 /// Helper function to extract the input slices after the filter is unrolled
79 /// along kw.
80 static SmallVector<Value>
81 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
82                        int64_t nSize, int64_t wSize, int64_t cSize,
83                        int64_t kwSize, int strideW, int dilationW,
84                        int64_t wSizeStep, bool isSingleChanneled) {
85   SmallVector<Value> result;
86   if (isSingleChanneled) {
87     // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
88     // convolution.
89     SmallVector<int64_t> sizes = {wSizeStep};
90     SmallVector<int64_t> strides = {1};
91     for (int64_t kw = 0; kw < kwSize; ++kw) {
92       for (int64_t w = 0; w < wSize; w += wSizeStep) {
93         result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
94             loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
95       }
96     }
97   } else {
98     // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
99     // for channeled convolution.
100     SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
101     SmallVector<int64_t> strides = {1, 1, 1};
102     for (int64_t kw = 0; kw < kwSize; ++kw) {
103       for (int64_t w = 0; w < wSize; w += wSizeStep) {
104         result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
105             loc, input,
106             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
107             sizes, strides));
108       }
109     }
110   }
111   return result;
112 }
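// As a concrete illustration of the channeled case above (sizes chosen
// arbitrarily): with nSize = 1, wSize = 2, cSize = 3, kwSize = 2,
// wSizeStep = 1, strideW = 1 and dilationW = 2, four slices of size
// {1, 1, 3} are extracted from `input` at offsets
//   [0, 0, 0], [0, 1, 0]   (kw = 0)
//   [0, 2, 0], [0, 3, 0]   (kw = 1)
// i.e. one slice per (kw, w) pair, shifted by sw * w + dw * kw along the w
// dimension.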
113 
114 /// Helper function to extract the filter slices after the filter is unrolled
115 /// along kw.
116 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
117                                                   Location loc, Value filter,
118                                                   int64_t kwSize) {
119   SmallVector<Value> result;
120   // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
121   // non-channeled convolution] @ [kw].
122   for (int64_t kw = 0; kw < kwSize; ++kw) {
123     result.push_back(rewriter.create<vector::ExtractOp>(
124         loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
125   }
126   return result;
127 }
128 
129 /// Helper function to extract the result slices after the filter is unrolled
130 /// along kw.
131 static SmallVector<Value>
132 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
133                         int64_t nSize, int64_t wSize, int64_t fSize,
134                         int64_t wSizeStep, bool isSingleChanneled) {
135   SmallVector<Value> result;
136   if (isSingleChanneled) {
137     // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
138     SmallVector<int64_t> sizes = {wSizeStep};
139     SmallVector<int64_t> strides = {1};
140     for (int64_t w = 0; w < wSize; w += wSizeStep) {
141       result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
142           loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
143     }
144   } else {
145     // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
146     // convolution.
147     SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
148     SmallVector<int64_t> strides = {1, 1, 1};
149     for (int64_t w = 0; w < wSize; w += wSizeStep) {
150       result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
151           loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
152     }
153   }
154   return result;
155 }
156 
157 /// Helper function to insert the computed result slices.
158 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
159                                     Value res, int64_t wSize, int64_t wSizeStep,
160                                     SmallVectorImpl<Value> &resVals,
161                                     bool isSingleChanneled) {
162 
163   if (isSingleChanneled) {
164     // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
165     // This does not depend on kw.
166     SmallVector<int64_t> strides = {1};
167     for (int64_t w = 0; w < wSize; w += wSizeStep) {
168       res = rewriter.create<vector::InsertStridedSliceOp>(
169           loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
170     }
171   } else {
172     // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
173     // convolution. This does not depend on kw.
174     SmallVector<int64_t> strides = {1, 1, 1};
175     for (int64_t w = 0; w < wSize; w += wSizeStep) {
176       res = rewriter.create<vector::InsertStridedSliceOp>(
177           loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
178           strides);
179     }
180   }
181   return res;
182 }
183 
184 /// Contains the vectorization state and related methods used across the
185 /// vectorization process of a given operation.
186 struct VectorizationState {
187   VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
188 
189   /// Initializes the vectorization state, including the computation of the
190   /// canonical vector shape for vectorization.
191   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
192                           ArrayRef<int64_t> inputVectorSizes,
193                           ArrayRef<bool> inputScalableVecDims);
194 
195   /// Returns the canonical vector shape used to vectorize the iteration space.
196   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
197 
198   /// Returns the vector dimensions that are scalable in the canonical vector
199   /// shape.
200   ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
201 
202   /// Returns a vector type of the provided `elementType` with the canonical
203   /// vector shape and the corresponding fixed/scalable dimension flags. If
204   /// `dimPermutation` is provided, the canonical vector dimensions are permuted
205   /// accordingly.
206   VectorType getCanonicalVecType(
207       Type elementType,
208       std::optional<AffineMap> dimPermutation = std::nullopt) const {
209     SmallVector<int64_t> vectorShape;
210     SmallVector<bool> scalableDims;
211     if (dimPermutation.has_value()) {
212       vectorShape =
213           applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
214       scalableDims =
215           applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
216     } else {
217       vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
218       scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
219     }
220 
221     return VectorType::get(vectorShape, elementType, scalableDims);
222   }
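  // For example (illustrative values only): with canonicalVecShape = {4, 8},
  // scalableVecDims = {false, true} and a permutation map
  // (d0, d1) -> (d1, d0), an f32 element type yields
  //   vector<[8]x4xf32>
  // while omitting the permutation yields
  //   vector<4x[8]xf32>.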
223 
224   /// Masks an operation with the canonical vector mask if the operation needs
225   /// masking. Returns the masked operation or the original operation if masking
226   /// is not needed. If provided, the canonical mask for this operation is
227   /// permuted using `maybeIndexingMap`.
228   Operation *
229   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
231 
232 private:
233   /// Initializes the iteration space static sizes using the Linalg op
234   /// information. This may become more complicated in the future.
235   void initIterSpaceStaticSizes(LinalgOp linalgOp) {
236     iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
237   }
238 
239   /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
240   /// all the static and dynamic dimensions of the iteration space to be
241   /// vectorized and stores them in `iterSpaceValueSizes`.
242   LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
243                                               LinalgOp linalgOp);
244 
245   /// Create or retrieve an existing mask value to mask `opToMask` in the
246   /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
247   /// mask is permuted using that map. If a new mask is created, it will be
248   /// cached for future users.
249   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
250                            LinalgOp linalgOp,
251                            std::optional<AffineMap> maybeMaskingMap);
252 
253   /// Check whether this permutation map can be used for masking. At the
254   /// moment we only make sure that there are no broadcast dimensions, but this
255   /// might change if indexing maps evolve.
256   bool isValidMaskingMap(AffineMap maskingMap) {
257     return maskingMap.getBroadcastDims().size() == 0;
258   }
259 
260   /// Turn the input indexing map into a valid masking map.
261   ///
262   /// The input indexing map may contain "zero" results, e.g.:
263   ///    (d0, d1, d2, d3) -> (d2, d1, d0, 0)
264   /// Applying such maps to canonical vector shapes like this one:
265   ///    (1, 16, 16, 4)
266   /// would yield an invalid vector shape like this:
267   ///    (16, 16, 1, 0)
268   /// Instead, drop the broadcasting dims that make no sense for masking perm.
269   /// maps:
270   ///    (d0, d1, d2, d3) -> (d2, d1, d0)
271   /// This way, the corresponding vector/mask type will be:
272   ///    vector<16x16x1xty>
273   /// rather than this invalid Vector type:
274   ///    vector<16x16x1x0xty>
275   AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
276     return indexingMap.dropZeroResults();
277   }
278 
279   // Holds the compile-time static sizes of the iteration space to vectorize.
280   // Dynamic dimensions are represented using ShapedType::kDynamic.
281   SmallVector<int64_t> iterSpaceStaticSizes;
282 
283   /// Holds the value sizes of the iteration space to vectorize. Static
284   /// dimensions are represented by 'arith.constant' and dynamic
285   /// dimensions by 'tensor/memref.dim'.
286   SmallVector<Value> iterSpaceValueSizes;
287 
288   /// Holds the canonical vector shape used to vectorize the iteration space.
289   SmallVector<int64_t> canonicalVecShape;
290 
291   /// Holds the vector dimensions that are scalable in the canonical vector
292   /// shape.
293   SmallVector<bool> scalableVecDims;
294 
295   /// Holds the active masks for permutations of the canonical vector iteration
296   /// space.
297   DenseMap<AffineMap, Value> activeMaskCache;
298 
299   /// Global vectorization guard for the incoming rewriter. It's initialized
300   /// when the vectorization state is initialized.
301   OpBuilder::InsertionGuard rewriterGuard;
302 };
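// A rough usage sketch of the state (schematic; it mirrors how the state is
// driven by the vectorization entry points further down in this file, with
// `elemTy`, `indexingMap` and `newOp` standing in for caller-provided values):
//
//   VectorizationState state(rewriter);
//   if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
//                              inputScalableVecDims)))
//     return failure();
//   VectorType vecTy = state.getCanonicalVecType(elemTy, indexingMap);
//   Operation *maybeMasked =
//       state.maskOperation(rewriter, newOp, linalgOp, indexingMap);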
303 
304 LogicalResult
305 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
306                                                   LinalgOp linalgOp) {
307   // TODO: Support 0-d vectors.
308   for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
309     if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
310       // Create constant index op for static dimensions.
311       iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
312           linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
313       continue;
314     }
315 
316     // Find an operand defined on this dimension of the iteration space to
317     // extract the runtime dimension size.
318     Value operand;
319     unsigned operandDimPos;
320     if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
321                                                          operandDimPos)))
322       return failure();
323 
324     Value dynamicDim = linalgOp.hasPureTensorSemantics()
325                            ? (Value)rewriter.create<tensor::DimOp>(
326                                  linalgOp.getLoc(), operand, operandDimPos)
327                            : (Value)rewriter.create<memref::DimOp>(
328                                  linalgOp.getLoc(), operand, operandDimPos);
329     iterSpaceValueSizes.push_back(dynamicDim);
330   }
331 
332   return success();
333 }
334 
335 /// Initializes the vectorization state, including the computation of the
336 /// canonical vector shape for vectorization.
337 // TODO: Move this to the constructor when we can remove the failure cases.
338 LogicalResult
339 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
340                               ArrayRef<int64_t> inputVectorSizes,
341                               ArrayRef<bool> inputScalableVecDims) {
342   // Initialize the insertion point.
343   rewriter.setInsertionPoint(linalgOp);
344 
345   if (!inputVectorSizes.empty()) {
346     // Get the canonical vector shape from the input vector sizes provided. This
347     // path should be taken to vectorize code with dynamic shapes and when using
348     // vector sizes greater than the iteration space sizes.
349     canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
350     scalableVecDims.append(inputScalableVecDims.begin(),
351                            inputScalableVecDims.end());
352   } else {
353     // Compute the canonical vector shape from the operation shape. If there are
354     // dynamic shapes, the operation won't be vectorized. We assume all the
355     // vector dimensions are fixed.
356     canonicalVecShape = linalgOp.getStaticLoopRanges();
357     scalableVecDims.append(linalgOp.getNumLoops(), false);
358   }
359 
360   LDBG("Canonical vector shape: ");
361   LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
362   LLVM_DEBUG(llvm::dbgs() << "\n");
363   LDBG("Scalable vector dims: ");
364   LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
365   LLVM_DEBUG(llvm::dbgs() << "\n");
366 
367   if (ShapedType::isDynamicShape(canonicalVecShape))
368     return failure();
369 
370   // Initialize iteration space static sizes.
371   initIterSpaceStaticSizes(linalgOp);
372 
373   // Generate 'arith.constant' and 'tensor/memref.dim' operations for
374   // all the static and dynamic dimensions of the iteration space, needed to
375   // compute a mask during vectorization.
376   if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
377     return failure();
378 
379   return success();
380 }
381 
382 /// Create or retrieve an existing mask value to mask `opToMask` in the
383 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
384 /// is permuted using that map. If a new mask is created, it will be cached for
385 /// future users.
386 Value VectorizationState::getOrCreateMaskFor(
387     RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
388     std::optional<AffineMap> maybeMaskingMap) {
389 
390   assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
391          "Ill-formed masking map.");
392 
393   // No mask is needed if the operation is not maskable.
394   auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
395   if (!maskableOp)
396     return Value();
397 
398   assert(!maskableOp.isMasked() &&
399          "Masking an operation that is already masked");
400 
401   // If no masking map was provided, use an identity map with the loop dims.
402   assert((!maybeMaskingMap || *maybeMaskingMap) &&
403          "Unexpected null mask permutation map");
404   AffineMap maskingMap =
405       maybeMaskingMap ? *maybeMaskingMap
406                       : AffineMap::getMultiDimIdentityMap(
407                             linalgOp.getNumLoops(), rewriter.getContext());
408 
409   LDBG("Masking map: " << maskingMap << "\n");
410 
411   // Return the active mask for the masking map of this operation if it was
412   // already created.
413   auto activeMaskIt = activeMaskCache.find(maskingMap);
414   if (activeMaskIt != activeMaskCache.end()) {
415     Value mask = activeMaskIt->second;
416     LDBG("Reusing mask: " << mask << "\n");
417     return mask;
418   }
419 
420   // Compute permuted projection of the iteration space to be masked and the
421   // corresponding mask shape. If the resulting iteration space dimensions are
422   // static and identical to the mask shape, masking is not needed for this
423   // operation.
424   // TODO: Improve this check. Only projected permutation indexing maps are
425   // supported.
426   SmallVector<int64_t> permutedStaticSizes =
427       applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
428   auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
429   auto maskShape = maskType.getShape();
430 
431   LDBG("Mask shape: ");
432   LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
433   LLVM_DEBUG(llvm::dbgs() << "\n");
434 
435   if (permutedStaticSizes == maskShape) {
436     LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
437     activeMaskCache[maskingMap] = Value();
438     return Value();
439   }
440 
441   // Permute the iteration space value sizes to compute the mask upper bounds.
442   SmallVector<Value> upperBounds =
443       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
444   assert(!maskShape.empty() && !upperBounds.empty() &&
445          "Masked 0-d vectors are not supported yet");
446 
447   // Create the mask based on the dimension values.
448   Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
449                                                      maskType, upperBounds);
450   LDBG("Creating new mask: " << mask << "\n");
451   activeMaskCache[maskingMap] = mask;
452   return mask;
453 }
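// For instance (schematic IR, assuming a 2-D canonical vector shape of 4x8
// over a dynamically sized iteration space), the mask created here looks
// like:
//
//   %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
//
// where %d0 and %d1 are the iteration space sizes computed in
// precomputeIterSpaceValueSizes, permuted by the masking map.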
454 
455 Operation *
456 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
457                                   LinalgOp linalgOp,
458                                   std::optional<AffineMap> maybeIndexingMap) {
459   LDBG("Trying to mask: " << *opToMask << "\n");
460 
461   std::optional<AffineMap> maybeMaskingMap = std::nullopt;
462   if (maybeIndexingMap)
463     maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
464 
465   // Create or retrieve mask for this operation.
466   Value mask =
467       getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
468 
469   if (!mask) {
470     LDBG("No mask required\n");
471     return opToMask;
472   }
473 
474   // Wrap the operation with a new `vector.mask` and update D-U chain.
475   assert(opToMask && "Expected a valid operation to mask");
476   auto maskOp = cast<vector::MaskOp>(
477       mlir::vector::maskOperation(rewriter, opToMask, mask));
478   Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
479 
480   for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
481     rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
482                                   maskOpTerminator);
483 
484   LDBG("Masked operation: " << *maskOp << "\n");
485   return maskOp;
486 }
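// Schematically (illustrative IR only), masking a transfer_read with the
// canonical 4x8 mask from the example above yields:
//
//   %read = vector.mask %mask {
//     vector.transfer_read %src[%c0, %c0], %pad
//         : tensor<?x?xf32>, vector<4x8xf32>
//   } : vector<4x8xi1> -> vector<4x8xf32>
//
// and every use of the original transfer_read outside the mask region is
// redirected to the corresponding vector.mask result.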
487 
488 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
489 /// projected permutation, compress the unused dimensions to serve as a
490 /// permutation_map for a vector transfer operation.
491 /// For example, given a linalg op such as:
492 ///
493 /// ```
494 ///   %0 = linalg.generic {
495 ///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
496 ///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
497 ///      }
498 ///     ins(%0 : tensor<2x3x4xf32>)
499 ///    outs(%1 : tensor<5x6xf32>)
500 /// ```
501 ///
502 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
503 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
504 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
505 static AffineMap reindexIndexingMap(AffineMap map) {
506   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
507          "expected projected permutation");
508   auto res = compressUnusedDims(map);
509   assert(res.getNumDims() ==
510              (res.getNumResults() - res.getNumOfZeroResults()) &&
511          "expected reindexed map with same number of dims and results");
512   return res;
513 }
514 
515 /// Helper enum to represent conv1d input traversal order.
516 enum class Conv1DOpOrder {
517   W,   // Corresponds to non-channeled 1D convolution operation.
518   Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
519   Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
520 };
521 
522 /// Helper data structure to represent the result of vectorization.
523 /// In certain specific cases, like terminators, we do not want to propagate/replace results.
524 enum VectorizationStatus {
525   /// Op failed to vectorize.
526   Failure = 0,
527   /// Op vectorized and custom function took care of replacement logic
528   NoReplace,
529   /// Op vectorized into a new Op whose results will replace original Op's
530   /// results.
531   NewOp
532   // TODO: support values if Op vectorized to Many-Ops whose results we need to
533   // aggregate for replacement.
534 };
535 struct VectorizationResult {
536   /// Return status from vectorizing the current op.
537   enum VectorizationStatus status = VectorizationStatus::Failure;
538   /// New vectorized operation to replace the current op.
539   /// Replacement behavior is specified by `status`.
540   Operation *newOp;
541 };
542 
543 std::optional<vector::CombiningKind>
544 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
545   using ::mlir::vector::CombiningKind;
546 
547   if (!combinerOp)
548     return std::nullopt;
549   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
550       .Case<arith::AddIOp, arith::AddFOp>(
551           [&](auto op) { return CombiningKind::ADD; })
552       .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
553       .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
554       .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
555       .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
556       .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
557       .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
558       .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
559       .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
560       .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
561       .Case<arith::MulIOp, arith::MulFOp>(
562           [&](auto op) { return CombiningKind::MUL; })
563       .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
564       .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
565       .Default([&](auto op) { return std::nullopt; });
566 }
567 
568 /// Check whether `outputOperand` is a reduction with a single combiner
569 /// operation. Return the combiner operation of the reduction. Return
570 /// nullptr otherwise. Multiple reduction operations would impose an
571 /// ordering between reduction dimensions, which is currently unsupported in
572 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
573 /// max(min(X)).
574 // TODO: use in LinalgOp verification, there is a circular dependency atm.
575 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
576   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
577   unsigned outputPos =
578       outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
579   // Only single combiner operations are supported for now.
580   SmallVector<Operation *, 4> combinerOps;
581   if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
582       combinerOps.size() != 1)
583     return nullptr;
584 
585   // Return the combiner operation.
586   return combinerOps[0];
587 }
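// For example, in a generic with a single additive reduction (illustrative
// IR, indexing maps elided), the combiner returned for the sole output
// operand is the arith.addf:
//
//   linalg.generic {iterator_types = ["parallel", "reduction"], ...}
//       ins(%in : tensor<8x16xf32>) outs(%acc : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     %sum = arith.addf %a, %b : f32
//     linalg.yield %sum : f32
//   }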
588 
589 /// Broadcast `value` to a vector of type `dstType` if possible. Return `value`
590 /// otherwise.
591 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
592   auto dstVecType = dyn_cast<VectorType>(dstType);
593   // If no shape to broadcast to, just return `value`.
594   if (dstVecType.getRank() == 0)
595     return value;
596   if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
597       vector::BroadcastableToResult::Success)
598     return value;
599   Location loc = b.getInsertionPoint()->getLoc();
600   return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
601 }
602 
603 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
604 /// assumes that `reductionOp` has two operands and one of them is the reduction
605 /// initial value.
606 // Note: this is a true builder that notifies the OpBuilder listener.
607 // TODO: Consider moving as a static helper on the ReduceOp.
608 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
609                                       Value valueToReduce, Value acc,
610                                       ArrayRef<bool> dimsToMask) {
611   auto maybeKind = getCombinerOpKind(reduceOp);
612   assert(maybeKind && "Failed precondition: could not get reduction kind");
613   return b.create<vector::MultiDimReductionOp>(
614       reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
615 }
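// For instance, with an arith.addf combiner and dimsToMask = {false, true}
// (i.e. reducing dimension 1), the op built here is, schematically:
//
//   %r = vector.multi_reduction <add>, %valueToReduce, %acc [1]
//       : vector<8x16xf32> to vector<8xf32>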
616 
617 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
618   return llvm::to_vector(
619       llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
620 }
621 
622 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
623 /// reduction iterator.
624 static bool hasReductionIterator(LinalgOp &op) {
625   return isa<linalg::ReduceOp>(op) ||
626          (isa<linalg::GenericOp>(op) &&
627           llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
628 }
629 
630 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
631 /// to all `0`, where `outputOperand` is an output operand of the LinalgOp
632 /// currently being vectorized. For rank-0 values, a 0-d vector.transfer_write is built.
633 /// Return the produced value or null if no value is produced.
634 // Note: this is a true builder that notifies the OpBuilder listener.
635 // TODO: Consider moving as a static helper on the ReduceOp.
636 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
637                               OpOperand *outputOperand,
638                               VectorizationState &state) {
639   Location loc = value.getLoc();
640   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
641   AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
642 
643   // Compute the vector type of the value to store. This type should be an
644   // identity or projection of the canonical vector type without any permutation
645   // applied, given that any permutation in a transfer write happens as part of
646   // the write itself.
647   AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
648       opOperandMap.getContext(), opOperandMap.getNumInputs(),
649       [&](AffineDimExpr dimExpr) -> bool {
650         return llvm::is_contained(opOperandMap.getResults(), dimExpr);
651       });
652   auto vectorType = state.getCanonicalVecType(
653       getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
654 
655   Operation *write;
656   if (vectorType.getRank() > 0) {
657     AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
658     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
659                                rewriter.create<arith::ConstantIndexOp>(loc, 0));
660     value = broadcastIfNeeded(rewriter, value, vectorType);
661     assert(value.getType() == vectorType && "Incorrect type");
662     write = rewriter.create<vector::TransferWriteOp>(
663         loc, value, outputOperand->get(), indices, writeMap);
664   } else {
665     // 0-d case is still special: do not invert the reindexing writeMap.
666     if (!isa<VectorType>(value.getType()))
667       value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
668     assert(value.getType() == vectorType && "Incorrect type");
669     write = rewriter.create<vector::TransferWriteOp>(
670         loc, value, outputOperand->get(), ValueRange{});
671   }
672 
673   write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
674 
675   // If masked, set in-bounds to true. Masking guarantees that the access will
676   // be in-bounds.
677   if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
678     auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
679     SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
680     maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
681   }
682 
683   LDBG("vectorized op: " << *write << "\n");
684   if (!write->getResults().empty())
685     return write->getResult(0);
686   return Value();
687 }
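// As an illustration (schematic IR), writing a vector<4x8xf32> value into a
// matching tensor init operand at all-zero indices produces:
//
//   %w = vector.transfer_write %value, %init[%c0, %c0]
//       : vector<4x8xf32>, tensor<?x?xf32>
//
// possibly wrapped in a vector.mask as produced by maskOperation, in which
// case the in_bounds attribute of the write is set to all-true.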
688 
689 // Custom vectorization precondition function type. This is intended to be used
690 // with CustomVectorizationHook. Returns success if the corresponding custom
691 // hook can vectorize the op.
692 using CustomVectorizationPrecondition =
693     std::function<LogicalResult(Operation *, bool)>;
694 
695 // Custom vectorization function type. Produce a vector form of Operation*
696 // assuming all its vectorized operands are already in the IRMapping.
697 // Return nullptr if the Operation cannot be vectorized.
698 using CustomVectorizationHook =
699     std::function<VectorizationResult(Operation *, const IRMapping &)>;
700 
701 /// Helper function to vectorize the terminator of a `linalgOp`. New result
702 /// vector values are appended to `newResults`. Return
703 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
704 /// should not try to map produced operations and instead return the results
705 /// using the `newResults` vector, making them available to the vectorization
706 /// algorithm for RAUW. This function is meant to be used as a
707 /// CustomVectorizationHook.
708 static VectorizationResult
709 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
710                      const IRMapping &bvm, VectorizationState &state,
711                      LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
712   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
713   if (!yieldOp)
714     return VectorizationResult{VectorizationStatus::Failure, nullptr};
715   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
716     // TODO: Scan for an opportunity for reuse.
717     // TODO: use a map.
718     Value vectorValue = bvm.lookup(output.value());
719     Value newResult =
720         buildVectorWrite(rewriter, vectorValue,
721                          linalgOp.getDpsInitOperand(output.index()), state);
722     if (newResult)
723       newResults.push_back(newResult);
724   }
725 
726   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
727 }
728 
729 /// Helper function to vectorize the index operations of a `linalgOp`. Return
730 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
731 /// should map the produced operations. This function is meant to be used as a
732 /// CustomVectorizationHook.
733 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
734                                                 VectorizationState &state,
735                                                 Operation *op,
736                                                 LinalgOp linalgOp) {
737   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
738   if (!indexOp)
739     return VectorizationResult{VectorizationStatus::Failure, nullptr};
740   auto loc = indexOp.getLoc();
741   // Compute the static loop sizes of the index op.
742   ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
743   auto dim = indexOp.getDim();
744   // Compute a one-dimensional index vector for the index op dimension.
745   auto indexVectorType =
746       VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
747                       state.getScalableVecDims()[dim]);
748   auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
749   // Return the one-dimensional index vector if it lives in the trailing
750   // dimension of the iteration space since the vectorization algorithm in this
751   // case can handle the broadcast.
752   if (dim == targetShape.size() - 1)
753     return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
754   // Otherwise permute the targetShape to move the index dimension last,
755   // broadcast the one-dimensional index vector to the permuted shape, and
756   // finally transpose the broadcasted index vector to undo the permutation.
757   auto permPattern =
758       llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
759   std::swap(permPattern[dim], permPattern.back());
760   auto permMap =
761       AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
762 
763   auto broadCastOp = rewriter.create<vector::BroadcastOp>(
764       loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
765       indexSteps);
766   SmallVector<int64_t> transposition =
767       llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
768   std::swap(transposition.back(), transposition[dim]);
769   auto transposeOp =
770       rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
771   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
772 }
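// For example (illustrative shapes), with a canonical vector shape of 4x8,
// `linalg.index 1` (the trailing dim) simply becomes:
//
//   %idx = vector.step : vector<8xindex>
//
// whereas `linalg.index 0` is materialized as step + broadcast + transpose:
//
//   %s = vector.step : vector<4xindex>
//   %b = vector.broadcast %s : vector<4xindex> to vector<8x4xindex>
//   %t = vector.transpose %b, [1, 0] : vector<8x4xindex> to vector<4x8xindex>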
773 
774 /// Helper function to check if the tensor.extract can be vectorized by the
775 /// custom hook vectorizeTensorExtract.
776 static LogicalResult
777 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
778   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
779   if (!extractOp)
780     return failure();
781 
782   if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
783     return failure();
784 
785   // Check the index type, but only for non 0-d tensors (for which we do need
786   // access indices).
787   if (not extractOp.getIndices().empty()) {
788     if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
789       return failure();
790   }
791 
792   if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
793         return !VectorType::isValidElementType(type);
794       })) {
795     return failure();
796   }
797 
798   return success();
799 }
800 
801 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
802 /// generated from `tensor.extract`. The offset is calculated as follows
803 /// (example using scalar values):
804 ///
805 ///    offset = extractOp.indices[0]
806 ///    for (i = 1; i < numIndices; i++)
807 ///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
808 ///
809 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
810 ///  offset = ( ( 1 ) * 80 +  2 ) * 15  + 3
811 static Value calculateGatherOffset(RewriterBase &rewriter,
812                                    VectorizationState &state,
813                                    tensor::ExtractOp extractOp,
814                                    const IRMapping &bvm) {
815   // The vector of indices for GatherOp should be shaped as the output vector.
816   auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
817   auto loc = extractOp.getLoc();
818 
819   Value offset = broadcastIfNeeded(
820       rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
821 
822   const size_t numIndices = extractOp.getIndices().size();
823   for (size_t i = 1; i < numIndices; i++) {
824     Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
825 
826     auto dimSize = broadcastIfNeeded(
827         rewriter,
828         rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
829         indexVecType);
830 
831     offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
832 
833     auto extractOpIndex = broadcastIfNeeded(
834         rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
835 
836     offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
837   }
838 
839   return offset;
840 }
841 
842 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
843 
844 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
845 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
846 /// represents a contiguous load operation.
847 ///
848 /// Note that when calling this hook, it is assumed that the output vector is
849 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
850 /// labelled as a gather load before entering this method.
851 ///
852 /// Following on from the above, it is assumed that:
853 ///   * for statically shaped loops, when no masks are used, only one dim is !=
854 ///   1 (that's what the shape of the output vector is based on).
855 ///   * for dynamically shaped loops, there might be more non-unit dims
856 ///   as the output vector type is user-specified.
857 ///
858 /// TODO: Statically shaped loops + vector masking
859 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
860   SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
861   assert(
862       (linalgOp.hasDynamicShape() ||
863        llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
864       "For statically shaped Linalg Ops, only one "
865       "non-unit loop dim is expected");
866   assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
867 
868   size_t idx = loopRanges.size() - 1;
869   for (; idx != 0; idx--)
870     if (loopRanges[idx] != 1)
871       break;
872 
873   return idx;
874 }
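// For instance, for static loop ranges (1, 16, 1) this returns 1, and for
// (1, 1, 4) it returns 2: the scan starts from the back and stops at the
// first non-unit dimension.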
875 
876 /// Checks whether `val` can be used for calculating a loop invariant index.
877 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
878                                VectorType resType) {
879 
880   assert(((llvm::count_if(resType.getShape(),
881                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
882          "n-D vectors are not yet supported");
883 
884   // Block arguments from outside _this_ linalg.generic are effectively loop invariant.
885   // However, analysing block arguments for _this_ linalg.generic Op is a bit
886   // tricky. Just bail out in the latter case.
887   // TODO: We could try analysing the corresponding affine map here.
888   auto *block = linalgOp.getBlock();
889   if (isa<BlockArgument>(val))
890     return llvm::all_of(block->getArguments(),
891                         [&val](Value v) { return (v != val); });
892 
893   Operation *defOp = val.getDefiningOp();
894   assert(defOp && "This is neither a block argument nor an operation result");
895 
896   // IndexOp is loop invariant as long as its result remains constant across
897   // iterations. Note that for dynamic shapes, the corresponding dim will also
898   // be conservatively treated as != 1.
899   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
900     return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
901   }
902 
903   auto *ancestor = block->findAncestorOpInBlock(*defOp);
904 
905   // Values defined outside `linalgOp` are loop invariant.
906   if (!ancestor)
907     return true;
908 
909   // Values defined inside `linalgOp`, which are constant, are loop invariant.
910   if (isa<arith::ConstantOp>(ancestor))
911     return true;
912 
913   bool result = true;
914   for (auto op : ancestor->getOperands())
915     result &= isLoopInvariantIdx(linalgOp, op, resType);
916 
917   return result;
918 }
919 
920 /// Check whether `val` could be used for calculating the trailing index for a
921 /// contiguous load operation.
922 ///
923 /// There are currently 3 types of values that are allowed here:
924 ///   1. loop-invariant values,
925 ///   2. values that increment by 1 with every loop iteration,
926 ///   3. results of basic arithmetic operations (linear and continuous)
927 ///      involving 1., 2. and 3.
928 /// This method returns true if indeed only such values are used in calculating
929 /// `val`.
930 ///
931 /// Additionally, the trailing index for a contiguous load operation should
932 /// increment by 1 with every loop iteration, i.e. be based on:
933 ///   * `linalg.index <dim>` ,
934 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
935 /// `linalg.index <dim>` increments by 1 with every loop iteration).
936 /// `foundIndexOp` is updated to `true` when such Op is found.
937 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
938                                 bool &foundIndexOp, VectorType resType) {
939 
940   assert(((llvm::count_if(resType.getShape(),
941                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
942          "n-D vectors are not yet supported");
943 
944   // Block arguments from outside _this_ linalg.generic are effectively loop invariant.
945   // However, analysing block arguments for _this_ linalg.generic Op is a bit
946   // tricky. Just bail out in the latter case.
947   // TODO: We could try analysing the corresponding affine map here.
948   auto *block = linalgOp.getBlock();
949   if (isa<BlockArgument>(val))
950     return llvm::all_of(block->getArguments(),
951                         [&val](Value v) { return (v != val); });
952 
953   Operation *defOp = val.getDefiningOp();
954   assert(defOp && "This is neither a block argument nor an operation result");
955 
956   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
957     auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
958 
959     foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
960     return true;
961   }
962 
963   auto *ancestor = block->findAncestorOpInBlock(*defOp);
964 
965   if (!ancestor)
966     return false;
967 
968   // Conservatively reject Ops that could lead to indices with stride other
969   // than 1.
970   if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
971     return false;
972 
973   bool result = false;
974   for (auto op : ancestor->getOperands())
975     result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
976 
977   return result;
978 }
979 
980 /// Infer the memory access pattern for the input ExtractOp
981 ///
982 /// Based on the ExtractOp result shape and the access indices, decides whether
983 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
984 /// or a gather load. When analysing the ExtractOp indices (to identify
985 /// contiguous laods), this method looks for "loop" invariant indices (e.g.
986 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
987 /// Op).
988 ///
989 /// Note that it is always safe to use gather load operations for contiguous
990 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
991 /// that `extractOp` is a gather load.
992 static VectorMemoryAccessKind
993 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
994                                     LinalgOp &linalgOp, VectorType resType) {
995 
996   auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
997 
998   // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
999   if (inputShape.getShape().empty())
1000     return VectorMemoryAccessKind::ScalarBroadcast;
1001 
1002   // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1003   // otherwise.
1004   bool isOutput1DVector =
1005       (llvm::count_if(resType.getShape(),
1006                       [](int64_t dimSize) { return dimSize > 1; }) == 1);
1007   // 1. Assume that it's a gather load when reading non-1D vector.
1008   if (!isOutput1DVector)
1009     return VectorMemoryAccessKind::Gather;
1010 
1011   bool leadingIdxsLoopInvariant = true;
1012 
1013   // 2. Analyze the leading indices of `extractOp`.
1014   // Look at the way each index is calculated and decide whether it is suitable
1015   // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1016   // gather load.
1017   auto indices = extractOp.getIndices();
1018   auto leadIndices = indices.drop_back(1);
1019 
1020   for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1021     if (inputShape.getShape()[i] == 1)
1022       continue;
1023 
1024     leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1025   }
1026 
1027   if (!leadingIdxsLoopInvariant) {
1028     LDBG("Found gather load: " << extractOp);
1029     return VectorMemoryAccessKind::Gather;
1030   }
1031 
1032   // 3. Analyze the trailing index for `extractOp`.
1033   // At this point we know that the leading indices are loop invariant. This
1034   // means that it is potentially a scalar or a contiguous load. We can decide
1035   // based on the trailing idx.
1036   auto extractOpTrailingIdx = indices.back();
1037 
1038   // 3a. Scalar broadcast load
1039   // If the trailing index is loop invariant then this is a scalar load.
1040   if (leadingIdxsLoopInvariant &&
1041       isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1042     LDBG("Found scalar broadcast load: " << extractOp);
1043 
1044     return VectorMemoryAccessKind::ScalarBroadcast;
1045   }
1046 
1047   // 3b. Contiguous loads
1048   // The trailing `extractOp` index should increment with every loop iteration.
1049   // This effectively means that it must be based on the trailing loop index.
1050   // This is what the following bool captures.
1051   bool foundIndexOp = false;
1052   bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1053                                               foundIndexOp, resType);
1054   // TODO: Support generating contiguous loads for column vectors - that will
1055   // require adding a permutation map to transfer_read Ops.
1056   bool isRowVector = resType.getShape().back() != 1;
1057   isContiguousLoad &= (foundIndexOp && isRowVector);
1058 
1059   if (isContiguousLoad) {
1060     LDBG("Found contigous load: " << extractOp);
1061     return VectorMemoryAccessKind::Contiguous;
1062   }
1063 
1064   // 4. Fallback case - gather load.
1065   LDBG("Found gather load: " << extractOp);
1066   return VectorMemoryAccessKind::Gather;
1067 }
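// Two illustrative cases (schematic IR inside a linalg.generic body, with the
// hypothetical tensors %src : tensor<3x8xf32> and %table : tensor<8xindex>):
//
//   %i  = linalg.index 1 : index                            // trailing non-unit dim
//   %v0 = tensor.extract %src[%c0, %i] : tensor<3x8xf32>    // -> Contiguous
//
//   %j  = tensor.extract %table[%i] : tensor<8xindex>       // data-dependent
//   %v1 = tensor.extract %src[%c0, %j] : tensor<3x8xf32>    // -> Gather
//
// A loop-invariant trailing index would instead be classified as
// ScalarBroadcast.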
1068 
1069 /// Helper function to vectorize the tensor.extract operations. Returns
1070 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1071 /// should map the produced operations. This function is meant to be used as a
1072 /// CustomVectorizationHook.
1073 static VectorizationResult
1074 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1075                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1076   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1077   if (!extractOp)
1078     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1079   auto loc = extractOp.getLoc();
1080 
1081   // Compute the canonical vector type for the result of the extract op.
1082   auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1083   auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1084       loc,
1085       DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1086                                 /*value=*/true));
1087   auto passThruConstantOp =
1088       rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1089 
1090   // Base indices are currently set to 0. We will need to re-visit if more
1091   // generic scenarios are to be supported.
1092   SmallVector<Value> baseIndices(
1093       extractOp.getIndices().size(),
1094       rewriter.create<arith::ConstantIndexOp>(loc, 0));
1095 
1096   VectorMemoryAccessKind memAccessKind =
1097       getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1098 
1099   // 1. Handle gather access
1100   if (memAccessKind == VectorMemoryAccessKind::Gather) {
1101     Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1102 
1103     // Generate the gather load
1104     Operation *gatherOp = rewriter.create<vector::GatherOp>(
1105         loc, resultType, extractOp.getTensor(), baseIndices, offset,
1106         maskConstantOp, passThruConstantOp);
1107     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1108 
1109     LDBG("Vectorised as gather load: " << extractOp << "\n");
1110     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1111   }
1112 
1113   // 2. Handle:
1114   //  a. scalar loads + broadcast,
1115   //  b. contiguous loads.
1116   // Both cases use vector.transfer_read.
1117 
1118   // Collect indices for `vector.transfer_read`. At this point, the indices will
1119   // either be scalars or would have been broadcast to vectors matching the
1120   // result type. For indices that are vectors, there are two options:
1121   //    * for non-trailing indices, all elements are identical (contiguous
1122   //      loads are identified by looking for non-trailing indices that are
1123   //      invariant with respect to the corresponding linalg.generic), or
1124   //    * for trailing indices, the index vector will contain values with stride
1125   //      one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1126   //      needed.
1127   // This means that
1128   //   * for scalar indices - just re-use it,
1129   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1130   //    (0th) element and use that.
1131   SmallVector<Value> transferReadIdxs;
1132   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1133     Value idx = bvm.lookup(extractOp.getIndices()[i]);
1134     if (idx.getType().isIndex()) {
1135       transferReadIdxs.push_back(idx);
1136       continue;
1137     }
1138 
1139     auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1140         loc,
1141         VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1142                         resultType.getScalableDims().back()),
1143         idx);
1144     transferReadIdxs.push_back(
1145         rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
1146   }
1147 
1148   // `tensor.extract` is always in-bounds, hence the following holds.
1149   auto dstRank = resultType.getRank();
1150   auto srcRank = extractOp.getTensor().getType().getRank();
1151   SmallVector<bool> inBounds(dstRank, true);
1152 
1153   // 2a. Handle scalar broadcast access.
1154   if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1155     MLIRContext *ctx = rewriter.getContext();
1156     SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1157     auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1158 
1159     auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1160         loc, resultType, extractOp.getTensor(), transferReadIdxs,
1161         permutationMap, inBounds);
1162 
1163     // Mask this broadcasting xfer_read here rather than relying on the generic
1164     // path (the generic path assumes identity masking map, which wouldn't be
1165     // valid here).
1166     SmallVector<int64_t> readMaskShape = {1};
1167     auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1168     auto allTrue = rewriter.create<vector::ConstantMaskOp>(
1169         loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1170     auto *maskedReadOp =
1171         mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1172 
1173     LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1174     return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
1175   }
1176 
1177   // 2b. Handle contiguous access.
1178   auto permutationMap = AffineMap::getMinorIdentityMap(
1179       srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1180 
1181   int32_t rankDiff = dstRank - srcRank;
1182   // When dstRank > srcRank, broadcast the source tensor to the unit leading
1183   // dims so that the ranks match. This is done by extending the map with 0s.
1184   // For example, for dstRank = 3, srcRank = 2, the following map created
1185   // above:
1186   //    (d0, d1) --> (d0, d1)
1187   // is extended as:
1188   //    (d0, d1) --> (0, d0, d1)
1189   while (rankDiff > 0) {
1190     permutationMap = permutationMap.insertResult(
1191         mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1192     rankDiff--;
1193   }
1194 
1195   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1196       loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1197       inBounds);
1198 
1199   LDBG("Vectorised as contiguous load: " << extractOp);
1200   return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1201 }
1202 
1203 /// Emit reduction operations if the shape of the value to reduce is different
1204 /// from the result shape.
1205 // Note: this is a true builder that notifies the OpBuilder listener.
1206 // TODO: Consider moving as a static helper on the ReduceOp.
1207 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1208                                  Value reduceValue, Value initialValue,
1209                                  const IRMapping &bvm) {
1210   Value reduceVec = bvm.lookup(reduceValue);
1211   Value outputVec = bvm.lookup(initialValue);
1212   auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1213   auto outputType = dyn_cast<VectorType>(outputVec.getType());
1214   // Reduce only if needed as the value may already have been reduced for
1215   // contraction vectorization.
1216   if (!reduceType ||
1217       (outputType && reduceType.getShape() == outputType.getShape()))
1218     return nullptr;
1219   SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1220   return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1221 }
1222 
1223 /// Generic vectorization for a single operation `op`, given already vectorized
1224 /// operands carried by `bvm`. Vectorization occurs as follows:
1225 ///   1. Try to apply any of the `customVectorizationHooks` and return its
1226 ///   result on success.
1227 ///   2. Clone any constant in the current scope without vectorization: each
1228 ///   consumer of the constant will later determine the shape to which the
1229 ///   constant needs to be broadcast.
1230 ///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1231 ///   of the `customVectorizationHooks` to cover such cases.
1232 ///   4. Clone `op` in vector form to a vector of shape prescribed by the first
1233 ///   operand of maximal rank. Other operands have smaller rank and are
1234 ///   broadcast accordingly. It is assumed this broadcast is always legal,
1235 ///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
1236 ///
1237 /// This function assumes all operands of `op` have been vectorized and are in
1238 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1239 /// a topologically-sorted list of ops.
1240 /// This function does not update `bvm` but returns a VectorizationStatus that
1241 /// instructs the caller what `bvm` update needs to occur.
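///
/// For example (an illustrative sketch, not copied from a specific test): an
/// `arith.addf %a, %b : f32` in the body, whose operands are already mapped to
/// values of type vector<4x8xf32> in `bvm`, is cloned via step 4 as
/// `arith.addf %va, %vb : vector<4x8xf32>`; an operand of smaller rank, e.g.
/// vector<8xf32>, would first be broadcast to vector<4x8xf32>.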
1242 static VectorizationResult
1243 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1244                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1245                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1246   LDBG("vectorize op " << *op << "\n");
1247 
1248   // 1. Try to apply any CustomVectorizationHook.
1249   if (!customVectorizationHooks.empty()) {
1250     for (auto &customFunc : customVectorizationHooks) {
1251       VectorizationResult result = customFunc(op, bvm);
1252       if (result.status == VectorizationStatus::Failure)
1253         continue;
1254       return result;
1255     }
1256   }
1257 
1258   // 2. Constant ops don't get vectorized but rather broadcast at their users.
1259   // Clone so that the constant is not confined to the linalgOp block.
1260   if (isa<arith::ConstantOp, func::ConstantOp>(op))
1261     return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1262 
1263   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1264   if (!OpTrait::hasElementwiseMappableTraits(op))
1265     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1266 
1267   // 4. Check if the operation is a reduction.
1268   SmallVector<std::pair<Value, Value>> reductionOperands;
1269   for (Value operand : op->getOperands()) {
1270     auto blockArg = dyn_cast<BlockArgument>(operand);
1271     if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1272         blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1273       continue;
1274     SmallVector<Operation *> reductionOps;
1275     Value reduceValue = matchReduction(
1276         linalgOp.getRegionOutputArgs(),
1277         blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1278     if (!reduceValue)
1279       continue;
1280     reductionOperands.push_back(std::make_pair(reduceValue, operand));
1281   }
1282   if (!reductionOperands.empty()) {
1283     assert(reductionOperands.size() == 1);
1284     Operation *reduceOp =
1285         reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1286                        reductionOperands[0].second, bvm);
1287     if (reduceOp)
1288       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1289   }
1290 
1291   // 5. Generic vectorization path for ElementwiseMappable ops.
1292   //   a. Get the first max ranked shape.
1293   VectorType firstMaxRankedType;
1294   for (Value operand : op->getOperands()) {
1295     auto vecOperand = bvm.lookup(operand);
1296     assert(vecOperand && "Vector operand couldn't be found");
1297 
1298     auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1299     if (vecType && (!firstMaxRankedType ||
1300                     firstMaxRankedType.getRank() < vecType.getRank()))
1301       firstMaxRankedType = vecType;
1302   }
1303   //   b. Broadcast each operand if needed.
1304   SmallVector<Value> vecOperands;
1305   for (Value scalarOperand : op->getOperands()) {
1306     Value vecOperand = bvm.lookup(scalarOperand);
1307     assert(vecOperand && "Vector operand couldn't be found");
1308 
1309     if (firstMaxRankedType) {
1310       auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1311                                      getElementTypeOrSelf(vecOperand.getType()),
1312                                      firstMaxRankedType.getScalableDims());
1313       vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1314     } else {
1315       vecOperands.push_back(vecOperand);
1316     }
1317   }
1318   //   c. For elementwise ops, the result is a vector with the firstMaxRankedType shape.
1319   SmallVector<Type> resultTypes;
1320   for (Type resultType : op->getResultTypes()) {
1321     resultTypes.push_back(
1322         firstMaxRankedType
1323             ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1324                               firstMaxRankedType.getScalableDims())
1325             : resultType);
1326   }
1327   //   d. Build and return the new op.
1328   return VectorizationResult{
1329       VectorizationStatus::NewOp,
1330       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1331                       resultTypes, op->getAttrs())};
1332 }
1333 
1334 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1335 /// vector form. Generic vectorization proceeds as follows:
1336 ///   1. Verify the `linalgOp` has one non-empty region.
1337 ///   2. Values defined above the region are mapped to themselves and will be
1338 ///   broadcasted on a per-need basis by their consumers.
1339 ///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1340 ///   load).
1341 ///   TODO: Reuse opportunities for RAR dependencies.
1342 ///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
1343 ///   4b. Register CustomVectorizationHook for IndexOp to access the
1344 ///   iteration indices.
1345 ///   5. Iteratively call vectorizeOneOp on the region operations.
1346 ///
1347 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1348 /// performed to the maximal common vector size implied by the `linalgOp`
1349 /// iteration space. This eager broadcasting is introduced in the
1350 /// permutation_map of the vector.transfer_read operations. The eager
1351 /// broadcasting makes it trivial to detrmine where broadcast, transposes and
1352 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1353 /// the absence of good canonicalizations, the amount of work increases.
1354 /// This is not deemed a problem as we expect canonicalizations and foldings to
1355 /// aggressively clean up the useless work.
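///
/// For example (an illustrative sketch; the shapes and maps are assumptions,
/// not output copied from a test): an elementwise linalg.generic adding two
/// tensor<8xf32> operands is rewritten to, roughly:
///   %a = vector.transfer_read %lhs[%c0], %cst : tensor<8xf32>, vector<8xf32>
///   %b = vector.transfer_read %rhs[%c0], %cst : tensor<8xf32>, vector<8xf32>
///   %s = arith.addf %a, %b : vector<8xf32>
///   %w = vector.transfer_write %s, %init[%c0] : vector<8xf32>, tensor<8xf32>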
1356 static LogicalResult
1357 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1358                          LinalgOp linalgOp,
1359                          SmallVectorImpl<Value> &newResults) {
1360   LDBG("Vectorizing operation as linalg generic\n");
1361   Block *block = linalgOp.getBlock();
1362 
1363   // 2. Values defined above the region can only be broadcast for now. Make them
1364   // map to themselves.
1365   IRMapping bvm;
1366   SetVector<Value> valuesSet;
1367   mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1368   bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1369 
1370   if (linalgOp.getNumDpsInits() == 0)
1371     return failure();
1372 
1373   // 3. Turn all BBArgs into vector.transfer_read / load.
1374   Location loc = linalgOp.getLoc();
1375   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1376   for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1377     BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1378     if (linalgOp.isScalar(opOperand)) {
1379       bvm.map(bbarg, opOperand->get());
1380       continue;
1381     }
1382 
1383     // 3.a. Convert the indexing map for this input/output to a transfer read
1384     // permutation map and masking map.
1385     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1386 
1387     AffineMap readMap;
1388     VectorType readType;
1389     Type elemType = getElementTypeOrSelf(opOperand->get());
1390     if (linalgOp.isDpsInput(opOperand)) {
1391       // 3.a.i. For input reads we use the canonical vector shape.
1392       readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1393       readType = state.getCanonicalVecType(elemType);
1394     } else {
1395       // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1396       // reductions), the vector shape is computed by mapping the canonical
1397       // vector shape to the output domain and back to the canonical domain.
1398       readMap = inversePermutation(reindexIndexingMap(indexingMap));
1399       readType =
1400           state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1401     }
1402 
1403     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1404 
1405     Operation *read = rewriter.create<vector::TransferReadOp>(
1406         loc, readType, opOperand->get(), indices, readMap);
1407     read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1408     Value readValue = read->getResult(0);
1409 
1410     // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1411     // will be in-bounds.
1412     if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1413       SmallVector<bool> inBounds(readType.getRank(), true);
1414       cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1415           .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1416     }
1417 
1418     // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1419     // TODO: remove this.
1420     if (readType.getRank() == 0)
1421       readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
1422                                                      ArrayRef<int64_t>());
1423 
1424     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1425                                  << "\n");
1426     bvm.map(bbarg, readValue);
1427     bvm.map(opOperand->get(), readValue);
1428   }
1429 
1430   SmallVector<CustomVectorizationHook> hooks;
1431   // 4a. Register CustomVectorizationHook for yieldOp.
1432   CustomVectorizationHook vectorizeYield =
1433       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1434     return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1435   };
1436   hooks.push_back(vectorizeYield);
1437 
1438   // 4b. Register CustomVectorizationHook for indexOp.
1439   CustomVectorizationHook vectorizeIndex =
1440       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1441     return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1442   };
1443   hooks.push_back(vectorizeIndex);
1444 
1445   // 4c. Register CustomVectorizationHook for extractOp.
1446   CustomVectorizationHook vectorizeExtract =
1447       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1448     return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1449   };
1450   hooks.push_back(vectorizeExtract);
1451 
1452   // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
1453   for (Operation &op : block->getOperations()) {
1454     VectorizationResult result =
1455         vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1456     if (result.status == VectorizationStatus::Failure) {
1457       LDBG("failed to vectorize: " << op << "\n");
1458       return failure();
1459     }
1460     if (result.status == VectorizationStatus::NewOp) {
1461       Operation *maybeMaskedOp =
1462           state.maskOperation(rewriter, result.newOp, linalgOp);
1463       LDBG("New vector op: " << *maybeMaskedOp << "\n");
1464       bvm.map(op.getResults(), maybeMaskedOp->getResults());
1465     }
1466   }
1467 
1468   return success();
1469 }
1470 
1471 /// Given a tensor::PackOp, return the `dest` shape before any packing
1472 /// permutations.
1473 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1474                                               ArrayRef<int64_t> destShape) {
1475   return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1476 }
1477 
1478 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1479 /// create an empty destination tensor and create a TransferWriteOp from the
1480 /// input to the empty tensor. If the destination shape is not the same as the
1481 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1482 /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1483 /// inBounds attribute of the transfer write op instead of masking.
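///
/// For example (illustrative only; the concrete shapes are assumptions):
/// writing a vector<4x8xf32> `input` into a destination of size %d0 x 8 with
/// inputVectorSizes = [4, 8] and masking enabled would produce, roughly:
///   %empty = tensor.empty(%d0) : tensor<?x8xf32>
///   %mask = vector.create_mask %d0, %c8 : vector<4x8xi1>
///   %write = vector.mask %mask {
///     vector.transfer_write %input, %empty[%c0, %c0]
///         {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x8xf32>
///   } : vector<4x8xi1> -> tensor<?x8xf32>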
1484 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1485                                            Value input,
1486                                            SmallVector<OpFoldResult> destSizes,
1487                                            ArrayRef<int64_t> inputVectorSizes,
1488                                            bool useInBoundsInsteadOfMasking) {
1489 
1490   auto inputType = cast<VectorType>(input.getType());
1491   Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1492                                                inputType.getElementType());
1493   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1494   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1495   auto destShape = cast<ShapedType>(dest.getType()).getShape();
1496   SmallVector<bool> inBoundsVal(rank, true);
1497   if (useInBoundsInsteadOfMasking) {
1498     // Update the inBounds attribute.
1499     for (unsigned i = 0; i < rank; i++)
1500       inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1501                        !ShapedType::isDynamic(destShape[i]);
1502   }
1503   Operation *write = builder.create<vector::TransferWriteOp>(
1504       loc,
1505       /*vector=*/input,
1506       /*source=*/dest,
1507       /*indices=*/SmallVector<Value>(rank, zero),
1508       /*inBounds=*/inBoundsVal);
1509   assert(llvm::none_of(
1510              destShape.drop_front(inputVectorSizes.size()),
1511              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1512          "Only dims aligned with inputVectorSizes may be dynamic");
1513   if (useInBoundsInsteadOfMasking)
1514     return write;
1515   bool needMaskForWrite = !llvm::equal(
1516       inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1517   if (needMaskForWrite) {
1518     SmallVector<int64_t> writeMaskShape;
1519     writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1520     writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1521                           destShape.end());
1522     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1523     Value maskForWrite =
1524         builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1525     write = mlir::vector::maskOperation(builder, write, maskForWrite);
1526   }
1527   return write;
1528 }
1529 
1530 /// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1531 /// padding value and (3) input vector sizes into:
1532 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1533 /// As in the following example:
1534 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1535 ///     into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1536 ///
1537 /// This pack would be vectorized to:
1538 ///
1539 /// %load = vector.mask %mask {
1540 ///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1541 ///         {in_bounds = [true, true, true]} :
1542 ///         tensor<32x7x16xf32>, vector<32x8x16xf32>
1543 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1544 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1545 ///                                         to vector<32x4x2x1x16xf32>
1546 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1547 ///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1548 /// %write = vector.transfer_write %transpose,
1549 ///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1550 ///     {in_bounds = [true, true, true, true, true]}
1551 ///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1552 ///
1553 /// If the (3) input vector sizes are not provided, the vector sizes are
1554 /// determined by the result tensor shape. Also, we update the inBounds
1555 /// attribute instead of masking.
1556 static LogicalResult
1557 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1558                         ArrayRef<int64_t> inputVectorSizes,
1559                         SmallVectorImpl<Value> &newResults) {
1560   OpBuilder::InsertionGuard g(rewriter);
1561   rewriter.setInsertionPoint(packOp);
1562 
1563   Location loc = packOp.getLoc();
1564   auto padValue = packOp.getPaddingValue();
1565   if (!padValue) {
1566     padValue = rewriter.create<arith::ConstantOp>(
1567         loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1568   }
1569   ReifiedRankedShapedTypeDims reifiedReturnShapes;
1570   LogicalResult status =
1571       cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1572           .reifyResultShapes(rewriter, reifiedReturnShapes);
1573   (void)status; // prevent unused variable warning on non-assert builds.
1574   assert(succeeded(status) && "failed to reify result shapes");
1575 
1576   // If the input vector sizes are not provided, then the vector sizes are
1577   // determined by the result tensor shape. In case the vector sizes aren't
1578   // provided, we update the inBounds attribute instead of masking.
1579   bool useInBoundsInsteadOfMasking = false;
1580   if (inputVectorSizes.empty()) {
1581     ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1582     inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1583     useInBoundsInsteadOfMasking = true;
1584   }
1585 
1586   // Create masked TransferReadOp.
1587   SmallVector<int64_t> inputShape(inputVectorSizes);
1588   auto innerTiles = packOp.getStaticInnerTiles();
1589   auto innerDimsPos = packOp.getInnerDimsPos();
1590   auto outerDimsPerm = packOp.getOuterDimsPerm();
1591   if (!outerDimsPerm.empty())
1592     applyPermutationToVector(inputShape,
1593                              invertPermutationVector(outerDimsPerm));
1594   for (auto [idx, size] : enumerate(innerTiles))
1595     inputShape[innerDimsPos[idx]] *= size;
1596   auto maskedRead = vector::createReadOrMaskedRead(
1597       rewriter, loc, packOp.getSource(), inputShape, padValue,
1598       useInBoundsInsteadOfMasking);
1599 
1600   // Create ShapeCastOp.
1601   SmallVector<int64_t> destShape(inputVectorSizes);
1602   destShape.append(innerTiles.begin(), innerTiles.end());
1603   auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1604                                        packOp.getDestType().getElementType());
1605   auto shapeCastOp =
1606       rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1607 
1608   // Create TransposeOp.
1609   auto destPermutation =
1610       invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1611   auto transposeOp = rewriter.create<vector::TransposeOp>(
1612       loc, shapeCastOp.getResult(), destPermutation);
1613 
1614   // Create TransferWriteOp.
1615   Operation *write = createWriteOrMaskedWrite(
1616       rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1617       inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
1618   newResults.push_back(write->getResult(0));
1619   return success();
1620 }
1621 
1622 /// Vectorize a `tensor::UnPackOp` into these 4 ops:
1623 ///   vector::TransferReadOp - reads a vector from the source tensor,
1624 ///   vector::TransposeOp - transposes the read vector,
1625 ///   vector::ShapeCastOp - reshapes the data based on the target,
1626 ///   vector::TransferWriteOp - writes the result vector back to the
1627 ///   destination tensor.
1628 /// If the vector sizes are not provided:
1629 ///   * the vector sizes are determined by the input operand and attributes,
1630 ///   * the inBounds attribute is updated instead of masking.
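///
/// For example (an illustrative sketch; the shapes and the transpose
/// permutation are assumptions, not copied from a test):
///   %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 8]
///       into %dst : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
/// would become, roughly:
///   %read = vector.transfer_read %src[...]
///       : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
///   %tr = vector.transpose %read, [0, 2, 1, 3]
///       : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
///   %sc = vector.shape_cast %tr : vector<1x8x1x8xf32> to vector<8x8xf32>
///   %write = vector.transfer_write %sc, %empty[%c0, %c0]
///       : vector<8x8xf32>, tensor<8x8xf32>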
1631 static LogicalResult
1632 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1633                           ArrayRef<int64_t> inputVectorSizes,
1634                           SmallVectorImpl<Value> &newResults) {
1635 
1636   OpBuilder::InsertionGuard g(rewriter);
1637   rewriter.setInsertionPoint(unpackOp);
1638 
1639   RankedTensorType unpackTensorType = unpackOp.getSourceType();
1640 
1641   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1642   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1643   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1644   bool useInBoundsInsteadOfMasking = false;
1645   ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1646 
1647   auto destSize = unpackOp.getDestRank();
1648 
1649   if (!inputVectorSizes.empty())
1650     assert(inputVectorSizes.size() == destSize &&
1651            "Incorrect number of input vector sizes");
1652 
1653   // vectorSizes is the shape of the vector that will be used to do final
1654   // write on the destination tensor. It is set like this: Let's say the
1655   // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1656   // Thus:
1657   // 1. vectorSizes = sourceShape.take_front(N)
1658   // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1659   // 3. multiply all the locations in vectorSizes pointed to by innerDimPos by
1660   //    the corresponding innerTiles value.
1661   SmallVector<int64_t> vectorSizes(inputVectorSizes);
1662   if (vectorSizes.empty()) {
1663     llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1664     if (!outerDimsPerm.empty())
1665       applyPermutationToVector(vectorSizes, outerDimsPerm);
1666     for (auto [i, pos] : llvm::enumerate(innerDimPos))
1667       vectorSizes[pos] *= innerTiles[i];
1668 
1669     useInBoundsInsteadOfMasking = true;
1670   }
1671 
1672   // readVectorSizes is the size of the vector used to read and apply the mask.
1673   // It is set like this: let's say the vectorSizes (VS) array has size 'N',
1674   // the sourceShape (SS) has size 'M' where M >= N, and the InnerTileSizes
1675   // (IT) have size M-N.
1676   // Thus:
1677   // - initially: readVectorSizes = vectorSizes
1678   // - divide all the readVectorSizes locations pointed to by innerDimPos
1679   //   by the corresponding innerTiles value,
1680   // - if outer_dims_perm is present: apply that permutation to readVectorSizes,
1681   // - append the remaining shape from SS.
1682   // E.g., let's say unpackTensorType.getShape() = <8x8x32x16>,
1683   // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes are [512,
1684   // 128] and outer_dims_perm is [1, 0]; then the read shape is:
1685   //   ReadVectorSizes (initial): [512, 128]
1686   //   Final value (after innerDim adjustment): [512/32, 128/16]
1687   //                                            = [16, 8]
1688   //   After applying outer_dims_perm: [8, 16]
1689   //   After appending the rest of the sourceShape: [8, 16, 32, 16]
1690 
1691   SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1692 
1693   for (auto [index, size] : enumerate(innerTiles)) {
1694     readVectorSizes[innerDimPos[index]] =
1695         llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1696   }
1697   if (!outerDimsPerm.empty()) {
1698     applyPermutationToVector(readVectorSizes, outerDimsPerm);
1699   }
1700   readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1701                          sourceShape.end());
1702 
1703   ReifiedRankedShapedTypeDims reifiedRetShapes;
1704   LogicalResult status =
1705       cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1706           .reifyResultShapes(rewriter, reifiedRetShapes);
1707   if (status.failed()) {
1708     LDBG("Unable to reify result shapes of " << unpackOp);
1709     return failure();
1710   }
1711   Location loc = unpackOp->getLoc();
1712 
1713   auto padValue = rewriter.create<arith::ConstantOp>(
1714       loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1715 
1716   // Read result, mask if necessary. If transferReadOp shape is not equal
1717   // to shape of source, then a mask is necessary.
1718   Value readResult = vector::createReadOrMaskedRead(
1719       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1720       /*useInBoundsInsteadOfMasking=*/false);
1721 
1722   PackingMetadata packMetadata;
1723   SmallVector<int64_t> lastDimToInsertPosPerm =
1724       tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1725   ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1726   SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1727   mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1728   applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1729   RankedTensorType stripMineTensorType =
1730       RankedTensorType::get(stripMineShape, stripMineElemType);
1731   // Transpose the appropriate rows to match output.
1732   vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1733       loc, readResult, lastDimToInsertPosPerm);
1734 
1735   // Collapse the vector to the size required by result.
1736   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1737       stripMineTensorType, packMetadata.reassociations);
1738   mlir::VectorType vecCollapsedType =
1739       VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1740   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1741       loc, vecCollapsedType, transposeOp->getResult(0));
1742 
1743   // writeVectorSizes has to match the shape_cast result shape for dynamic
1744   // sizes; otherwise the verifier complains that the mask size is invalid.
1745   SmallVector<int64_t> writeVectorSizes(
1746       unpackOp.getDestType().hasStaticShape()
1747           ? vectorSizes
1748           : shapeCastOp.getResultVectorType().getShape());
1749   Operation *write = createWriteOrMaskedWrite(
1750       rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1751       writeVectorSizes, useInBoundsInsteadOfMasking);
1752   newResults.push_back(write->getResult(0));
1753   return success();
1754 }
1755 
1756 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1757 /// and (3) all-zero lowPad to
1758 ///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
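///
/// For example (an illustrative sketch; shapes and vector sizes are
/// assumptions):
///   %pad = tensor.pad %src low[0, 0] high[%h0, %h1] { ... tensor.yield %cst }
///       : tensor<?x?xf32> to tensor<2x4xf32>
/// vectorized with inputVectorSizes = [2, 4] becomes, roughly:
///   %read = vector.mask %m {
///     vector.transfer_read %src[%c0, %c0], %cst
///         : tensor<?x?xf32>, vector<2x4xf32>
///   } : vector<2x4xi1> -> vector<2x4xf32>
///   %write = vector.transfer_write %read, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>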
1759 static LogicalResult
1760 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1761                        ArrayRef<int64_t> inputVectorSizes,
1762                        SmallVectorImpl<Value> &newResults) {
1763   auto padValue = padOp.getConstantPaddingValue();
1764   Location loc = padOp.getLoc();
1765 
1766   // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1767   OpBuilder::InsertionGuard g(rewriter);
1768   rewriter.setInsertionPoint(padOp);
1769 
1770   ReifiedRankedShapedTypeDims reifiedReturnShapes;
1771   LogicalResult status =
1772       cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1773           .reifyResultShapes(rewriter, reifiedReturnShapes);
1774   (void)status; // prevent unused variable warning on non-assert builds
1775   assert(succeeded(status) && "failed to reify result shapes");
1776   auto maskedRead = vector::createReadOrMaskedRead(
1777       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1778       /*useInBoundsInsteadOfMasking=*/false);
1779   Operation *write = createWriteOrMaskedWrite(
1780       rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1781       /*useInBoundsInsteadOfMasking=*/false);
1782   newResults.push_back(write->getResult(0));
1783   return success();
1784 }
1785 
1786 // TODO: probably need some extra checks for reduction followed by consumer
1787 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1788 static LogicalResult reductionPreconditions(LinalgOp op) {
1789   if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1790     LDBG("reduction precondition failed: no reduction iterator\n");
1791     return failure();
1792   }
1793   for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1794     AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1795     if (indexingMap.isPermutation())
1796       continue;
1797 
1798     Operation *reduceOp = matchLinalgReduction(&opOperand);
1799     if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1800       LDBG("reduction precondition failed: reduction detection failed\n");
1801       return failure();
1802     }
1803   }
1804   return success();
1805 }
1806 
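/// Preconditions for vectorizing a convolution with dynamic shapes: only 1D
/// depthwise (NWC x WC) convolutions are supported, and only the channel
/// dimension may be dynamic. For example (illustrative shapes, not taken from
/// a specific test), a linalg.depthwise_conv_1d_nwc_wc with
///   ins(%input, %filter : tensor<1x8x?xf32>, tensor<3x?xf32>)
///   outs(%out : tensor<1x6x?xf32>)
/// passes this check, whereas a dynamic spatial dimension (e.g. an input of
/// type tensor<1x?x4xf32>) does not.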
1807 static LogicalResult
1808 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1809                                    bool flatten1DDepthwiseConv) {
1810   if (flatten1DDepthwiseConv) {
1811     LDBG("Vectorization of flattened convs with dynamic shapes is not "
1812          "supported\n");
1813     return failure();
1814   }
1815 
1816   if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1817     LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1818     return failure();
1819   }
1820 
1821   // Support dynamic shapes in 1D depthwise convolution, but only in the
1822   // _channel_ dimension.
1823   Value lhs = conv.getDpsInputOperand(0)->get();
1824   ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1825   auto shapeWithoutCh = lhsShape.drop_back(1);
1826   if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1827     LDBG("Dynamically-shaped op vectorization precondition failed: only "
1828          "channel dim can be dynamic\n");
1829     return failure();
1830   }
1831 
1832   return success();
1833 }
1834 
1835 static LogicalResult
1836 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1837                                      bool flatten1DDepthwiseConv) {
1838   if (isa<ConvolutionOpInterface>(op.getOperation()))
1839     return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1840 
1841   if (hasReductionIterator(op))
1842     return reductionPreconditions(op);
1843 
1844   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1845   // linalg.copy ops and ops that implement ContractionOpInterface for now.
1846   if (!isElementwise(op) &&
1847       !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1848           op.getOperation()))
1849     return failure();
1850 
1851   LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1852   return success();
1853 }
1854 
1855 /// Need to check if the inner-tiles are static/constant.
1856 static LogicalResult
1857 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1858                               ArrayRef<int64_t> inputVectorSizes) {
1859 
1860   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1861         return !getConstantIntValue(res).has_value();
1862       })) {
1863     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1864     return failure();
1865   }
1866   ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1867   bool satisfyEmptyCond = inputVectorSizes.empty() &&
1868                           unpackOp.getDestType().hasStaticShape() &&
1869                           unpackOp.getSourceType().hasStaticShape();
1870   if (!satisfyEmptyCond &&
1871       failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1872     return failure();
1873 
1874   return success();
1875 }
1876 
1877 static LogicalResult vectorizeLinalgOpPrecondition(
1878     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1879     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
1880   // Tensors with a dimension of size 0 cannot be vectorized.
1881   if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1882     return failure();
1883   // Check API contract for input vector sizes.
1884   if (!inputVectorSizes.empty() &&
1885       failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1886                                               inputVectorSizes)))
1887     return failure();
1888 
1889   if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1890                                         linalgOp, flatten1DDepthwiseConv))) {
1891     LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1892     return failure();
1893   }
1894 
1895   SmallVector<CustomVectorizationPrecondition> customPreconditions;
1896 
1897   // Register CustomVectorizationPrecondition for extractOp.
1898   customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1899 
1900   // All types in the body should be a supported element type for VectorType.
1901   for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1902     // Check if any custom hook can vectorize the inner op.
1903     if (llvm::any_of(
1904             customPreconditions,
1905             [&](const CustomVectorizationPrecondition &customPrecondition) {
1906               return succeeded(
1907                   customPrecondition(&innerOp, vectorizeNDExtract));
1908             })) {
1909       continue;
1910     }
1911     if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1912           return !VectorType::isValidElementType(type);
1913         })) {
1914       return failure();
1915     }
1916     if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1917           return !VectorType::isValidElementType(type);
1918         })) {
1919       return failure();
1920     }
1921   }
1922   if (isElementwise(linalgOp))
1923     return success();
1924 
1925   // TODO: isaConvolutionOpInterface that can also infer from generic
1926   // features. But we will still need stride/dilation attributes that will be
1927   // annoying to reverse-engineer...
1928   if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1929     return success();
1930   // TODO: the common vector shape is equal to the static loop sizes only when
1931   // all indexing maps are projected permutations. For convs and stencils the
1932   // logic will need to evolve.
1933   if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1934     LDBG("precondition failed: not projected permutations\n");
1935     return failure();
1936   }
1937   if (failed(reductionPreconditions(linalgOp))) {
1938     LDBG("precondition failed: reduction preconditions\n");
1939     return failure();
1940   }
1941   return success();
1942 }
1943 
1944 static LogicalResult
1945 vectorizePackOpPrecondition(tensor::PackOp packOp,
1946                             ArrayRef<int64_t> inputVectorSizes) {
1947   auto padValue = packOp.getPaddingValue();
1948   Attribute cstAttr;
1949   if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1950     LDBG("pad value is not constant: " << packOp << "\n");
1951     return failure();
1952   }
1953   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1954   bool satisfyEmptyCond = true;
1955   if (inputVectorSizes.empty()) {
1956     if (!packOp.getDestType().hasStaticShape() ||
1957         !packOp.getSourceType().hasStaticShape())
1958       satisfyEmptyCond = false;
1959   }
1960 
1961   if (!satisfyEmptyCond &&
1962       failed(vector::isValidMaskedInputVector(
1963           resultTensorShape.take_front(packOp.getSourceRank()),
1964           inputVectorSizes)))
1965     return failure();
1966 
1967   if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1968         return !getConstantIntValue(v).has_value();
1969       })) {
1970     LDBG("inner_tiles must be constant: " << packOp << "\n");
1971     return failure();
1972   }
1973 
1974   return success();
1975 }
1976 
1977 static LogicalResult
1978 vectorizePadOpPrecondition(tensor::PadOp padOp,
1979                            ArrayRef<int64_t> inputVectorSizes) {
1980   auto padValue = padOp.getConstantPaddingValue();
1981   if (!padValue) {
1982     LDBG("pad value is not constant: " << padOp << "\n");
1983     return failure();
1984   }
1985 
1986   ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1987   if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1988                                               inputVectorSizes)))
1989     return failure();
1990 
1991   if (llvm::any_of(padOp.getLow(), [](Value v) {
1992         std::optional<int64_t> res = getConstantIntValue(v);
1993         return !res.has_value() || res.value() != 0;
1994       })) {
1995     LDBG("low pad must all be zero: " << padOp << "\n");
1996     return failure();
1997   }
1998 
1999   return success();
2000 }
2001 
2002 /// Preconditions for scalable vectors. This is quite restrictive - it models
2003 /// the fact that in practice we would only make selected dimensions scalable.
2004 static LogicalResult
2005 vectorizeScalableVectorPrecondition(Operation *op,
2006                                     ArrayRef<int64_t> inputVectorSizes,
2007                                     ArrayRef<bool> inputScalableVecDims) {
2008   assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2009          "Number of input vector sizes and scalable dims doesn't match");
2010 
2011   size_t numOfScalableDims =
2012       llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2013 
2014   if (numOfScalableDims == 0)
2015     return success();
2016 
2017   auto linalgOp = dyn_cast<LinalgOp>(op);
2018 
2019   // Cond 1: There's been no need for scalable vectorisation of
2020   // non-linalg Ops so far
2021   if (!linalgOp)
2022     return failure();
2023 
2024   // Cond 2: There's been no need for more than 2 scalable dims so far
2025   if (numOfScalableDims > 2)
2026     return failure();
2027 
2028   // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2029   // it matches one of the supported cases:
2030   //  1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2031   //    (*).
2032   //  2. Exactly 2 dims are scalable and those are the _last two adjacent_
2033   //     parallel dims.
2034   //  3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2035   //  dim.
2036   // The 2nd restriction above means that only Matmul-like Ops are supported
2037   // when 2 dims are scalable, e.g. :
2038   //    * iterators = [parallel, parallel, reduction]
2039   //    * scalable flags = [true, true, false]
2040   //
2041   // (*) Unit dims get folded away in practice.
2042   // TODO: Relax these conditions as good motivating examples are identified.
2043 
2044   // Find the trailing-most scalable flag, scanning from the innermost dim.
2045   bool seenNonUnitParallel = false;
2046   auto iterators = linalgOp.getIteratorTypesArray();
2047   SmallVector<bool> scalableFlags(inputScalableVecDims);
2048   int64_t idx = scalableFlags.size() - 1;
2049   while (!scalableFlags[idx]) {
2050     bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2051     seenNonUnitParallel |=
2052         (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2053 
2054     iterators.pop_back();
2055     scalableFlags.pop_back();
2056     --idx;
2057   }
2058 
2059   // Analyze the iterator corresponding to that trailing-most scalable dim.
2060   switch (iterators.back()) {
2061   case utils::IteratorType::reduction: {
2062     // Check 3. above is met.
2063     if (iterators.size() != inputVectorSizes.size()) {
2064       LDBG("Non-trailing reduction dim requested for scalable "
2065            "vectorization\n");
2066       return failure();
2067     }
2068     if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2069       LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2070            "is not supported\n");
2071       return failure();
2072     }
2073     break;
2074   }
2075   case utils::IteratorType::parallel: {
2076     // Check 1. and 2. above are met.
2077     if (seenNonUnitParallel) {
2078       LDBG("Inner parallel dim not requested for scalable "
2079            "vectorization\n");
2080       return failure();
2081     }
2082     break;
2083   }
2084   }
2085 
2086   // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2087   // supported, for which we expect the following config:
2088   //    * iterators = [parallel, parallel, reduction]
2089   //    * scalable flags = [true, true, false]
2090   if (numOfScalableDims == 2) {
2091     // Disallow the case below, which breaks 3. above:
2092     //    * iterators = [..., parallel, reduction]
2093     //    * scalable flags = [..., true, true]
2094     if (iterators.back() == utils::IteratorType::reduction) {
2095       LDBG("Higher dim than the trailing reduction dim requested for scalable "
2096            "vectorization\n");
2097       return failure();
2098     }
2099     scalableFlags.pop_back();
2100     iterators.pop_back();
2101 
2102     if (!scalableFlags.back() ||
2103         (iterators.back() != utils::IteratorType::parallel))
2104       return failure();
2105   }
2106 
2107   // Don't let matmul ops with extended semantics (i.e. user-defined
2108   // indexing maps) through this transform.
2109   if (linalgOp.hasUserDefinedMaps())
2110     return failure();
2111 
2112   // Cond 4: Only the following ops are supported in the
2113   // presence of scalable vectors
2114   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2115                  isa<linalg::MatmulTransposeAOp>(op) ||
2116                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2117                  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2118 }
2119 
2120 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2121     Operation *op, ArrayRef<int64_t> inputVectorSizes,
2122     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2123     bool flatten1DDepthwiseConv) {
2124 
2125   if (!hasVectorizationImpl(op))
2126     return failure();
2127 
2128   if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2129                                                  inputScalableVecDims)))
2130     return failure();
2131 
2132   return TypeSwitch<Operation *, LogicalResult>(op)
2133       .Case<linalg::LinalgOp>([&](auto linalgOp) {
2134         return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2135                                              vectorizeNDExtract,
2136                                              flatten1DDepthwiseConv);
2137       })
2138       .Case<tensor::PadOp>([&](auto padOp) {
2139         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2140       })
2141       .Case<tensor::PackOp>([&](auto packOp) {
2142         return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2143       })
2144       .Case<tensor::UnPackOp>([&](auto unpackOp) {
2145         return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2146       })
2147       .Default([](auto) { return failure(); });
2148 }
2149 
2150 /// Converts affine.apply Ops to arithmetic operations.
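///
/// For example (an illustrative sketch; the affine map is an assumption):
///   %0 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%j]
/// inside the linalgOp body is expanded to, roughly:
///   %0 = arith.addi %i, %j : index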
2151 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2152   OpBuilder::InsertionGuard g(rewriter);
2153   auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2154 
2155   for (auto op : make_early_inc_range(toReplace)) {
2156     rewriter.setInsertionPoint(op);
2157     auto expanded = affine::expandAffineExpr(
2158         rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2159         op.getOperands().take_front(op.getAffineMap().getNumDims()),
2160         op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2161     rewriter.replaceOp(op, expanded);
2162   }
2163 }
2164 
2165 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2166   return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
2167       op);
2168 }
2169 
2170 /// Emit a suitable vector form for an operation. If provided,
2171 /// `inputVectorSizes` are used to vectorize this operation.
2172 /// `inputVectorSizes` must match the rank of the iteration space of the
2173 /// operation and the input vector sizes must be greater than or equal to
2174 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2175 /// also allows the vectorization of operations with dynamic shapes.
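///
/// A minimal usage sketch (the sizes and flags below are hypothetical and
/// shown for illustration only):
///   // Request a canonical vector shape of 8x[16], i.e. with the second
///   // dimension scalable.
///   if (failed(linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{8, 16},
///                                /*inputScalableVecDims=*/{false, true})))
///     return failure();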
2176 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2177                                       ArrayRef<int64_t> inputVectorSizes,
2178                                       ArrayRef<bool> inputScalableVecDims,
2179                                       bool vectorizeNDExtract,
2180                                       bool flatten1DDepthwiseConv) {
2181   LDBG("Attempting to vectorize:\n" << *op << "\n");
2182   LDBG("Input vector sizes: ");
2183   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2184   LLVM_DEBUG(llvm::dbgs() << "\n");
2185   LDBG("Input scalable vector dims: ");
2186   LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2187   LLVM_DEBUG(llvm::dbgs() << "\n");
2188 
2189   if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2190                                      vectorizeNDExtract,
2191                                      flatten1DDepthwiseConv))) {
2192     LDBG("Vectorization pre-conditions failed\n");
2193     return failure();
2194   }
2195 
2196   // Initialize vectorization state.
2197   VectorizationState state(rewriter);
2198   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2199     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2200                                inputScalableVecDims))) {
2201       LDBG("Vectorization state couldn't be initialized\n");
2202       return failure();
2203     }
2204   }
2205 
2206   SmallVector<Value> results;
2207   auto vectorizeResult =
2208       TypeSwitch<Operation *, LogicalResult>(op)
2209           .Case<linalg::LinalgOp>([&](auto linalgOp) {
2210             // TODO: isaConvolutionOpInterface that can also infer from
2211             // generic features. Will require stride/dilation attributes
2212             // inference.
2213             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2214               FailureOr<Operation *> convOr = vectorizeConvolution(
2215                   rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2216                   flatten1DDepthwiseConv);
2217               if (succeeded(convOr)) {
2218                 llvm::append_range(results, (*convOr)->getResults());
2219                 return success();
2220               }
2221 
2222               LDBG("Unsupported convolution can't be vectorized.\n");
2223               return failure();
2224             }
2225 
2226             LDBG("Vectorize generic by broadcasting to the canonical vector "
2227                  "shape\n");
2228 
2229             // Pre-process before proceeding.
2230             convertAffineApply(rewriter, linalgOp);
2231 
2232             // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2233             // to 'OpBuilder' when it is passed over to some methods like
2234             // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2235             // erase an op within these methods, the actual rewriter won't be
2236             // notified and we will end up with read-after-free issues!
2237             return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2238           })
2239           .Case<tensor::PadOp>([&](auto padOp) {
2240             return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2241                                           results);
2242           })
2243           .Case<tensor::PackOp>([&](auto packOp) {
2244             return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2245                                            results);
2246           })
2247           .Case<tensor::UnPackOp>([&](auto unpackOp) {
2248             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2249                                              inputVectorSizes, results);
2250           })
2251           .Default([](auto) { return failure(); });
2252 
2253   if (failed(vectorizeResult)) {
2254     LDBG("Vectorization failed\n");
2255     return failure();
2256   }
2257 
2258   if (!results.empty())
2259     rewriter.replaceOp(op, results);
2260   else
2261     rewriter.eraseOp(op);
2262 
2263   return success();
2264 }
2265 
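/// Vectorize a statically-shaped `memref.copy` by rewriting it into a
/// full-shape transfer_read / transfer_write pair. For example (an
/// illustrative sketch; the shapes are assumptions):
///   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
/// becomes, roughly:
///   %v = vector.transfer_read %src[%c0, %c0], %pad
///       : memref<4x8xf32>, vector<4x8xf32>
///   vector.transfer_write %v, %dst[%c0, %c0]
///       : vector<4x8xf32>, memref<4x8xf32>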
2266 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2267                                           memref::CopyOp copyOp) {
2268   auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2269   auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2270   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2271     return failure();
2272 
2273   auto srcElementType = getElementTypeOrSelf(srcType);
2274   auto dstElementType = getElementTypeOrSelf(dstType);
2275   if (!VectorType::isValidElementType(srcElementType) ||
2276       !VectorType::isValidElementType(dstElementType))
2277     return failure();
2278 
2279   auto readType = VectorType::get(srcType.getShape(), srcElementType);
2280   auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2281 
2282   Location loc = copyOp->getLoc();
2283   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2284   SmallVector<Value> indices(srcType.getRank(), zero);
2285 
2286   Value readValue = rewriter.create<vector::TransferReadOp>(
2287       loc, readType, copyOp.getSource(), indices,
2288       rewriter.getMultiDimIdentityMap(srcType.getRank()));
2289   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2290     readValue =
2291         rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
2292     readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2293   }
2294   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2295       loc, readValue, copyOp.getTarget(), indices,
2296       rewriter.getMultiDimIdentityMap(srcType.getRank()));
2297   rewriter.replaceOp(copyOp, writeValue->getResults());
2298   return success();
2299 }
2300 
2301 //----------------------------------------------------------------------------//
2302 // Misc. vectorization patterns.
2303 //----------------------------------------------------------------------------//
2304 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2305 /// given operation type OpTy.
2306 template <typename OpTy>
2307 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2308   using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2309 
2310   LogicalResult matchAndRewrite(tensor::PadOp padOp,
2311                                 PatternRewriter &rewriter) const final {
2312     bool changed = false;
2313     // Insert users in vector, because some users may be replaced/removed.
2314     for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2315       if (auto op = dyn_cast<OpTy>(user))
2316         changed |= rewriteUser(rewriter, padOp, op).succeeded();
2317     return success(changed);
2318   }
2319 
2320 protected:
2321   virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2322                                     tensor::PadOp padOp, OpTy op) const = 0;
2323 };
2324 
2325 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2326 /// ```
2327 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2328 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2329 ///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2330 /// ```
2331 /// is rewritten to:
2332 /// ```
2333 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2334 ///     {in_bounds = [true, true]}
2335 ///     : tensor<?x?xf32>, vector<17x5xf32>
2336 /// ```
2337 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2338 /// sure that the original padding value %cst was never used.
2339 ///
2340 /// This rewrite is possible if:
2341 /// - `xferOp` has no out-of-bounds dims or mask.
2342 /// - Low padding is static 0.
2343 /// - Single, scalar padding value.
2344 struct PadOpVectorizationWithTransferReadPattern
2345     : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2346   using VectorizePadOpUserPattern<
2347       vector::TransferReadOp>::VectorizePadOpUserPattern;
2348 
2349   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2350                             vector::TransferReadOp xferOp) const override {
2351     // Low padding must be static 0.
2352     if (!padOp.hasZeroLowPad())
2353       return failure();
2354     // Pad value must be a constant.
2355     auto padValue = padOp.getConstantPaddingValue();
2356     if (!padValue)
2357       return failure();
2358     // Padding value of existing `xferOp` is unused.
2359     if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2360       return failure();
2361 
2362     rewriter.modifyOpInPlace(xferOp, [&]() {
2363       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2364       xferOp->setAttr(xferOp.getInBoundsAttrName(),
2365                       rewriter.getBoolArrayAttr(inBounds));
2366       xferOp.getSourceMutable().assign(padOp.getSource());
2367       xferOp.getPaddingMutable().assign(padValue);
2368     });
2369 
2370     return success();
2371   }
2372 };
2373 
2374 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2375 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2376 /// value, where the same amount of padding is immediately removed again after
2377 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2378 /// tensor value and apply out-of-bounds masking. E.g.:
2379 /// ```
2380 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2381 ///     : tensor<...> to tensor<?x?xf32>
2382 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2383 /// %2 = vector.transfer_write %vec, %1[...]
2384 ///     : vector<17x5xf32>, tensor<17x5xf32>
2385 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2386 ///     : tensor<17x5xf32> to tensor<?x?xf32>
2387 /// ```
2388 /// is rewritten to:
2389 /// ```
2390 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2391 ///     : tensor<...> to tensor<?x?xf32>
2392 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2393 /// tensor<?x?xf32>
2394 /// ```
2395 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2396 /// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2397 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2398 /// from %r's old dimensions.
2399 ///
2400 /// This rewrite is possible if:
2401 /// - Low padding is static 0.
2402 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2403 ///   ExtractSliceOp trims the same amount of padding that was added
2404 ///   beforehand.
2405 /// - Single, scalar padding value.
2406 struct PadOpVectorizationWithTransferWritePattern
2407     : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2408   using VectorizePadOpUserPattern<
2409       vector::TransferWriteOp>::VectorizePadOpUserPattern;
2410 
2411   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2412                             vector::TransferWriteOp xferOp) const override {
2413     // TODO: support 0-d corner case.
2414     if (xferOp.getTransferRank() == 0)
2415       return failure();
2416 
2417     // Low padding must be static 0.
2418     if (!padOp.hasZeroLowPad())
2419       return failure();
2420     // Pad value must be a constant.
2421     auto padValue = padOp.getConstantPaddingValue();
2422     if (!padValue)
2423       return failure();
2424     // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2425     if (!xferOp->hasOneUse())
2426       return failure();
2427     auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2428     if (!trimPadding)
2429       return failure();
2430     // Only static zero offsets supported when trimming padding.
2431     if (!trimPadding.hasZeroOffset())
2432       return failure();
2433     // trimPadding must remove the amount of padding that was added earlier.
2434     if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2435       return failure();
2436 
2437     // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2438     rewriter.setInsertionPoint(xferOp);
2439 
2440     SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2441     auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2442         xferOp, padOp.getSource().getType(), xferOp.getVector(),
2443         padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2444         xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2445     rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2446 
2447     return success();
2448   }
2449 
2450   /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2451   /// i.e., same dimensions.
2452   ///
2453 /// Dimensions may be static, dynamic or a mix of both. In case of dynamic
2454   /// dimensions, this function tries to infer the (static) tensor size by
2455   /// looking at the defining op and utilizing op-specific knowledge.
2456   ///
2457   /// This is a conservative analysis. In case equal tensor sizes cannot be
2458   /// proven statically, this analysis returns `false` even though the tensor
2459   /// sizes may turn out to be equal at runtime.
2460   bool hasSameTensorSize(Value beforePadding,
2461                          tensor::ExtractSliceOp afterTrimming) const {
2462     // If the input to tensor::PadOp is a CastOp, try with both CastOp
2463     // result and CastOp operand.
2464     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2465       if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2466         return true;
2467 
2468     auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2469     auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2470     // Only RankedTensorType supported.
2471     if (!t1 || !t2)
2472       return false;
2473     // Rank of both values must be the same.
2474     if (t1.getRank() != t2.getRank())
2475       return false;
2476 
2477     // All static dimensions must be the same. Mixed cases (e.g., dimension
2478     // static in `t1` but dynamic in `t2`) are not supported.
2479     for (unsigned i = 0; i < t1.getRank(); ++i) {
2480       if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2481         return false;
2482       if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2483         return false;
2484     }
2485 
2486     // Nothing more to check if all dimensions are static.
2487     if (t1.getNumDynamicDims() == 0)
2488       return true;
2489 
2490     // All dynamic sizes must be the same. The only supported case at the
2491     // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2492     // thereof).
2493 
2494     // Apart from CastOp, only ExtractSliceOp is supported.
2495     auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2496     if (!beforeSlice)
2497       return false;
2498 
2499     assert(static_cast<size_t>(t1.getRank()) ==
2500            beforeSlice.getMixedSizes().size());
2501     assert(static_cast<size_t>(t2.getRank()) ==
2502            afterTrimming.getMixedSizes().size());
2503 
2504     for (unsigned i = 0; i < t1.getRank(); ++i) {
2505       // Skip static dimensions.
2506       if (!t1.isDynamicDim(i))
2507         continue;
2508       auto size1 = beforeSlice.getMixedSizes()[i];
2509       auto size2 = afterTrimming.getMixedSizes()[i];
2510 
2511       // Case 1: Same value or same constant int.
2512       if (isEqualConstantIntOrValue(size1, size2))
2513         continue;
2514 
2515       // Other cases: Take a deeper look at defining ops of values.
2516       auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2517       auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2518       if (!v1 || !v2)
2519         return false;
2520 
2521       // Case 2: Both values are identical AffineMinOps. (Should not happen if
2522       // CSE is run.)
2523       auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2524       auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2525       if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2526           minOp1.getOperands() == minOp2.getOperands())
2527         continue;
2528 
2529       // Add additional cases as needed.
2530     }
2531 
2532     // All tests passed.
2533     return true;
2534   }
2535 };
2536 
2537 /// Returns the effective Pad value for the input op, provided it's a scalar.
2538 ///
2539 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2540 /// this Op performs padding, retrieve the padding value provided that it's
2541 /// a scalar and static/fixed for all the padded values. Returns an empty value
2542 /// otherwise.
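///
/// A minimal sketch of one chain this can follow (SSA names and shapes are
/// illustrative only):
/// ```
///   %cst = arith.constant 0.0 : f32
///   %fill = linalg.fill ins(%cst : f32) outs(%dest : tensor<8x8xf32>)
///       -> tensor<8x8xf32>
///   %res = tensor.insert_slice %src into %fill[0, 0] [4, 4] [1, 1]
///       : tensor<4x4xf32> into tensor<8x8xf32>
/// ```
/// For `%res`, the tensor.insert_slice case below recurses into its
/// destination (`%fill`), and the linalg.fill case then returns `%cst` as the
/// effective pad value.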
2543 static Value getStaticPadVal(Operation *op) {
2544   if (!op)
2545     return {};
2546 
2547   // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
2548   // being broadcast, provided that it's a scalar.
2549   if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2550     auto source = bcast.getSource();
2551     if (llvm::dyn_cast<VectorType>(source.getType()))
2552       return {};
2553 
2554     return source;
2555   }
2556 
2557   // 2. linalg.fill - use the scalar input value that was used to fill the
2558   // output tensor.
2559   if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2560     return fill.getInputs()[0];
2561   }
2562 
2563   // 3. tensor.generate - can't guarantee that the value is fixed without
2564   // analysing the body, so bail out.
2565   if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2566     return {};
2567   }
2568 
2569   // 4. vector.transfer_write - inspect the vector that's being written. If
2570   // it contains a single value that has been broadcast (e.g. via
2571   // vector.broadcast), extract it; fail otherwise.
2572   if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2573     return getStaticPadVal(xferWrite.getVector().getDefiningOp());
2574 
2575   // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2576   // than the input tensor, then, provided it was generated with a constant
2577   // value (e.g. via linalg.fill), extract that value; fail otherwise.
2578   // TODO: Clarify the semantics when the input tensor is larger than the
2579   // destination.
2580   if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2581     return getStaticPadVal(slice.getDest().getDefiningOp());
2582 
2583   return {};
2584 }
2585 
2586 /// Rewrite tensor.insert_slice as a vector.transfer_read +
2587 /// vector.transfer_write pair. The vector size is inferred from the static
2588 /// dims in the input and output tensors. If a dim is dynamic in both the
2589 /// input and output tensors, bail out.
2590 ///
2591 /// Before:
2592 ///     !t_in_type = tensor<1x2x3xf32>
2593 ///     !t_out_type = tensor<9x8x7x1x2x3xf32>
2594 ///     !v_type = vector<1x2x3xf32>
2595 ///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
2596 ///     into !t_out_type
2597 /// After:
2598 ///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
2599 ///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
2600 ///
2601 /// TODO: Support masking
2602 struct InsertSliceVectorizePattern
2603     : public OpRewritePattern<tensor::InsertSliceOp> {
2604   using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
2605 
2606   LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
2607                                 PatternRewriter &rewriter) const final {
2608     auto sourceType = sliceOp.getSource().getType();
2609     if (!VectorType::isValidElementType(sourceType.getElementType()))
2610       return failure();
2611 
2612     auto resultType = sliceOp.getResultType();
2613 
2614     // 1. Get the pad value.
2615     // TransferReadOp requires a scalar padding value. Note that for in-bounds
2616     // accesses the value is actually irrelevant. There are 2 cases in which
2617     // xfer.read accesses are known to be in-bounds:
2618     //  1. The source shape is static (output vector sizes would be based on
2619     //     the source shape and hence all memory accesses would be in-bounds),
2620     //  2. Masking is used (output vector sizes would be user-provided, in which
2621     //     case it is assumed that all memory accesses are in-bounds). This
2622     //     remains a TODO.
2623     //
2624     // When the value is not known and not needed, use 0. Otherwise, bail out.
2625     Value padValue = getStaticPadVal(sliceOp);
2626     bool isOutOfBoundsRead = !sourceType.hasStaticShape();
2627 
2628     if (!padValue && isOutOfBoundsRead) {
2629       LDBG("Failed to get a pad value for out-of-bounds read access\n");
2630       return failure();
2631     }
2632 
2633     if (!padValue) {
2634       auto elemType = sourceType.getElementType();
2635       padValue = rewriter.create<arith::ConstantOp>(
2636           sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2637     }
2638 
2639     // 2. Get the vector shape and in-bounds attributes
2640     SmallVector<int64_t> vecShape;
2641     SmallVector<bool> readInBounds;
2642     SmallVector<bool> writeInBounds;
2643     size_t rankDiff = resultType.getRank() - sourceType.getRank();
2644     for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2645       if (!sourceType.isDynamicDim(i)) {
2646         vecShape.push_back(sourceType.getDimSize(i));
2647         // Source shape is statically known: Neither read nor write are
2648         // out-of-bounds.
2649         readInBounds.push_back(true);
2650         writeInBounds.push_back(true);
2651       } else if (!resultType.isDynamicDim(i)) {
2652         // Source shape is not statically known, but result shape is.
2653         // Vectorize with size of result shape. This may be larger than the
2654         // source size.
2655         // FIXME: Using rankDiff implies that the source tensor is inserted at
2656         // the end of the destination tensor. However, that's not required.
2657         vecShape.push_back(resultType.getDimSize(rankDiff + i));
2658         // Read may be out-of-bounds because the result size could be larger
2659         // than the source size.
2660         readInBounds.push_back(false);
2661         // Write will be in-bounds provided that the corresponding write idx is
2662         // 0. To keep this logic simple, conservatively mark it as out-of-bounds.
2663         writeInBounds.push_back(false);
2664       } else {
2665         // Neither source nor result dim of sliceOp is static. Cannot vectorize
2666         // the copy.
2667         // TODO: Add support for masking
2668         return failure();
2669       }
2670     }
2671     auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2672 
2673     // 3. Generate TransferReadOp.
2674     SmallVector<Value> readIndices(
2675         vecType.getRank(),
2676         rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
2677     auto read = rewriter.create<vector::TransferReadOp>(
2678         sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
2679         ArrayRef<bool>{readInBounds});
2680 
2681     // 4. Generate TransferWriteOp.
2682     auto writeIndices = getValueOrCreateConstantIndexOp(
2683         rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
2684 
2685     // 5. Finalize
2686     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2687         sliceOp, read, sliceOp.getDest(), writeIndices,
2688         ArrayRef<bool>{writeInBounds});
2689 
2690     return success();
2691   }
2692 };
2693 
2694 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2695 /// ```
2696 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2697 /// %r = tensor.insert_slice %0
2698 ///     into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2699 ///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2700 /// ```
2701 /// is rewritten to:
2702 /// ```
2703 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2704 ///     : tensor<?x?xf32>, vector<17x5xf32>
2705 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2706 ///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2707 /// ```
2708 ///
2709 /// This rewrite is possible if:
2710 /// - Low padding is static 0.
2711 /// - `padOp` result shape is static.
2712 /// - The entire padded tensor is inserted.
2713 ///   (Implies that sizes of `insertOp` are all static.)
2714 /// - Only unit strides in `insertOp`.
2715 /// - Single, scalar padding value.
2716 /// - `padOp` result not used as destination.
2717 struct PadOpVectorizationWithInsertSlicePattern
2718     : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2719   using VectorizePadOpUserPattern<
2720       tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2721 
2722   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2723                             tensor::InsertSliceOp insertOp) const override {
2724     // Low padding must be static 0.
2725     if (!padOp.hasZeroLowPad())
2726       return failure();
2727     // Only unit stride supported.
2728     if (!insertOp.hasUnitStride())
2729       return failure();
2730     // Pad value must be a constant.
2731     auto padValue = padOp.getConstantPaddingValue();
2732     if (!padValue)
2733       return failure();
2734     // Dynamic shapes not supported.
2735     if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2736       return failure();
2737     // Pad result not used as destination.
2738     if (insertOp.getDest() == padOp.getResult())
2739       return failure();
2740 
2741     auto vecType = VectorType::get(padOp.getType().getShape(),
2742                                    padOp.getType().getElementType());
2743     unsigned vecRank = vecType.getRank();
2744     unsigned tensorRank = insertOp.getType().getRank();
2745 
2746     // Check if sizes match: Insert the entire tensor into the most minor dims.
2747     // (No permutations allowed.)
2748     SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2749     expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2750     if (!llvm::all_of(
2751             llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2752               return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2753             }))
2754       return failure();
2755 
2756     // Insert the TransferReadOp and TransferWriteOp at the position of the
2757     // InsertSliceOp.
2758     rewriter.setInsertionPoint(insertOp);
2759 
2760     // Generate TransferReadOp: Read entire source tensor and add high
2761     // padding.
2762     SmallVector<Value> readIndices(
2763         vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2764     auto read = rewriter.create<vector::TransferReadOp>(
2765         padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2766 
2767     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2768     // specified offsets. Write is fully in-bounds because an InsertSliceOp's
2769     // source must fit into the destination at the specified offsets.
2770     auto writeIndices = getValueOrCreateConstantIndexOp(
2771         rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2772     SmallVector<bool> inBounds(vecRank, true);
2773     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2774         insertOp, read, insertOp.getDest(), writeIndices,
2775         ArrayRef<bool>{inBounds});
2776 
2777     return success();
2778   }
2779 };
2780 
2781 void mlir::linalg::populateInsertSliceVectorizationPatterns(
2782     RewritePatternSet &patterns) {
2783   patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
2784 }
2785 
2786 void mlir::linalg::populatePadOpVectorizationPatterns(
2787     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2788   patterns.add<PadOpVectorizationWithTransferReadPattern,
2789                PadOpVectorizationWithTransferWritePattern,
2790                PadOpVectorizationWithInsertSlicePattern>(
2791       patterns.getContext(), baseBenefit.getBenefit() + 1);
2792 }
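
// A minimal usage sketch (hypothetical client code, not part of this file;
// `funcOp` is an illustrative name): gather the patterns above into a
// RewritePatternSet and apply them with the greedy rewrite driver.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateInsertSliceVectorizationPatterns(patterns);
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));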
2793 
2794 //----------------------------------------------------------------------------//
2795 // Forwarding patterns
2796 //----------------------------------------------------------------------------//
2797 
2798 /// Check whether there is any interleaved use of any `values` between
2799 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2800 /// is in a different block.
2801 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2802                                     ValueRange values) {
2803   if (firstOp->getBlock() != secondOp->getBlock() ||
2804       !firstOp->isBeforeInBlock(secondOp)) {
2805     LDBG("interleavedUses precondition failed, firstOp: "
2806          << *firstOp << ", second op: " << *secondOp << "\n");
2807     return true;
2808   }
2809   for (auto v : values) {
2810     for (auto &u : v.getUses()) {
2811       Operation *owner = u.getOwner();
2812       if (owner == firstOp || owner == secondOp)
2813         continue;
2814       // TODO: this is too conservative, use dominance info in the future.
2815       if (owner->getBlock() == firstOp->getBlock() &&
2816           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2817         continue;
2818       LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2819                                     << ", second op: " << *secondOp << "\n");
2820       return true;
2821     }
2822   }
2823   return false;
2824 }
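
// For example (an illustration of the check above, not tied to a specific
// caller): if `firstOp` is a memref.copy into %subview and `secondOp` is a
// vector.transfer_read of the enclosing alloc, then any other user of
// %subview or of the alloc sitting between the two ops (e.g. a memref.store)
// makes this function return `true`.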
2825 
2826 /// Return the unique subview use of `v` if it is indeed unique, null
2827 /// otherwise.
2828 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2829   memref::SubViewOp subViewOp;
2830   for (auto &u : v.getUses()) {
2831     if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2832       if (subViewOp)
2833         return memref::SubViewOp();
2834       subViewOp = newSubViewOp;
2835     }
2836   }
2837   return subViewOp;
2838 }
2839 
2840 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2841 /// when available.
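///
/// A rough sketch of the rewrite performed below (IR names and shapes are
/// illustrative, and some operands/layouts are elided):
/// ```
///   linalg.fill ins(%pad : f32) outs(%alloc : memref<?x?xf32>)
///   %sv = memref.subview %alloc[...] [...] [...]
///   memref.copy %in, %sv : memref<?x?xf32, ...> to memref<?x?xf32, ...>
///   %v = vector.transfer_read %alloc[%c0, %c0], %pad
///       : memref<?x?xf32>, vector<4x4xf32>
/// ```
/// is forwarded to a single read from the copy source (the fill and the copy
/// are erased, and `in_bounds` is conservatively reset):
/// ```
///   %v = vector.transfer_read %in[%c0, %c0], %pad {in_bounds = [false, false]}
///       : memref<?x?xf32, ...>, vector<4x4xf32>
/// ```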
2842 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2843     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2844 
2845   // TODO: support mask.
2846   if (xferOp.getMask())
2847     return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2848 
2849   // Transfer into `view`.
2850   Value viewOrAlloc = xferOp.getSource();
2851   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2852       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2853     return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2854 
2855   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2856   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2857   if (!subViewOp)
2858     return rewriter.notifyMatchFailure(xferOp, "no subview found");
2859   Value subView = subViewOp.getResult();
2860 
2861   // Find the copy into `subView` without interleaved uses.
2862   memref::CopyOp copyOp;
2863   for (auto &u : subView.getUses()) {
2864     if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2865       assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2866       if (newCopyOp.getTarget() != subView)
2867         continue;
2868       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2869         continue;
2870       copyOp = newCopyOp;
2871       break;
2872     }
2873   }
2874   if (!copyOp)
2875     return rewriter.notifyMatchFailure(xferOp, "no copy found");
2876 
2877   // Find the fill into `viewOrAlloc` without interleaved uses before the
2878   // copy.
2879   FillOp maybeFillOp;
2880   for (auto &u : viewOrAlloc.getUses()) {
2881     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2882       assert(isa<MemRefType>(newFillOp.output().getType()));
2883       if (newFillOp.output() != viewOrAlloc)
2884         continue;
2885       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2886         continue;
2887       maybeFillOp = newFillOp;
2888       break;
2889     }
2890   }
2891   // Ensure padding matches.
2892   if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2893     return rewriter.notifyMatchFailure(xferOp,
2894                                        "padding value does not match fill");
2895 
2896   // `in` is the subview that memref.copy reads. Replace it.
2897   Value in = copyOp.getSource();
2898 
2899   // memref.copy + linalg.fill can be used to create a padded local buffer.
2900   // The `in_bounds` attribute is only valid on this padded buffer.
2901   // When forwarding to vector.transfer_read, the attribute must be reset
2902   // conservatively (all dims marked as potentially out-of-bounds).
2903   auto vectorType = xferOp.getVectorType();
2904   Value res = rewriter.create<vector::TransferReadOp>(
2905       xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
2906       xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2907       rewriter.getBoolArrayAttr(
2908           SmallVector<bool>(vectorType.getRank(), false)));
2909 
2910   if (maybeFillOp)
2911     rewriter.eraseOp(maybeFillOp);
2912   rewriter.eraseOp(copyOp);
2913   rewriter.replaceOp(xferOp, res);
2914 
2915   return success();
2916 }
2917 
2918 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2919 /// when available.
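///
/// A rough sketch of the rewrite performed below (the write-side mirror of
/// the pattern above; IR names are illustrative and some operands are
/// elided):
/// ```
///   %sv = memref.subview %alloc[...] [...] [...]
///   vector.transfer_write %v, %alloc[%c0, %c0]
///       : vector<4x4xf32>, memref<?x?xf32>
///   memref.copy %sv, %out : memref<?x?xf32, ...> to memref<?x?xf32, ...>
/// ```
/// is forwarded to a single write into the copy target (the copy and the
/// original transfer_write are erased, and `in_bounds` is conservatively
/// reset):
/// ```
///   vector.transfer_write %v, %out[%c0, %c0] {in_bounds = [false, false]}
///       : vector<4x4xf32>, memref<?x?xf32, ...>
/// ```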
2920 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2921     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2922   // TODO: support mask.
2923   if (xferOp.getMask())
2924     return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2925 
2926   // Transfer into `viewOrAlloc`.
2927   Value viewOrAlloc = xferOp.getSource();
2928   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2929       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2930     return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2931 
2932   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2933   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2934   if (!subViewOp)
2935     return rewriter.notifyMatchFailure(xferOp, "no subview found");
2936   Value subView = subViewOp.getResult();
2937 
2938   // Find the copy from `subView` without interleaved uses.
2939   memref::CopyOp copyOp;
2940   for (auto &u : subViewOp.getResult().getUses()) {
2941     if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2942       if (newCopyOp.getSource() != subView)
2943         continue;
2944       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2945         continue;
2946       copyOp = newCopyOp;
2947       break;
2948     }
2949   }
2950   if (!copyOp)
2951     return rewriter.notifyMatchFailure(xferOp, "no copy found");
2952 
2953   // `out` is the subview copied into that we replace.
2954   assert(isa<MemRefType>(copyOp.getTarget().getType()));
2955   Value out = copyOp.getTarget();
2956 
2957   // Forward vector.transfer into copy.
2958   // memref.copy + linalg.fill can be used to create a padded local buffer.
2959   // The `in_bounds` attribute is only valid on this padded buffer.
2960   // When forwarding to vector.transfer_write, the attribute must be reset
2961   // conservatively (all dims marked as potentially out-of-bounds).
2962   auto vector = xferOp.getVector();
2963   rewriter.create<vector::TransferWriteOp>(
2964       xferOp.getLoc(), vector, out, xferOp.getIndices(),
2965       xferOp.getPermutationMapAttr(), xferOp.getMask(),
2966       rewriter.getBoolArrayAttr(
2967           SmallVector<bool>(vector.getType().getRank(), false)));
2968 
2969   rewriter.eraseOp(copyOp);
2970   rewriter.eraseOp(xferOp);
2971 
2972   return success();
2973 }
2974 
2975 //===----------------------------------------------------------------------===//
2976 // Convolution vectorization patterns
2977 //===----------------------------------------------------------------------===//
2978 
2979 template <int N>
2980 static void bindShapeDims(ShapedType shapedType) {}
2981 
2982 template <int N, typename IntTy, typename... IntTy2>
2983 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2984   val = shapedType.getShape()[N];
2985   bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2986 }
2987 
2988 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
2989 template <typename... IntTy>
2990 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2991   bindShapeDims<0>(shapedType, vals...);
2992 }
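
// For example (a usage sketch): given a value of type `memref<8x16x4xf32>`,
// `bindShapeDims(shapedType, n, w, c)` binds n = 8, w = 16 and c = 4.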
2993 
2994 namespace {
2995 bool isCastOfBlockArgument(Operation *op) {
2996   return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2997          isa<BlockArgument>(op->getOperand(0));
2998 }
2999 
3000 bool isSupportedPoolKind(vector::CombiningKind kind) {
3001   switch (kind) {
3002   case vector::CombiningKind::ADD:
3003   case vector::CombiningKind::MAXNUMF:
3004   case vector::CombiningKind::MAXIMUMF:
3005   case vector::CombiningKind::MAXSI:
3006   case vector::CombiningKind::MAXUI:
3007   case vector::CombiningKind::MINNUMF:
3008   case vector::CombiningKind::MINIMUMF:
3009   case vector::CombiningKind::MINSI:
3010   case vector::CombiningKind::MINUI:
3011     return true;
3012   default:
3013     return false;
3014   }
3015 }
3016 
3017 /// Generate a vector implementation for either:
3018 /// ```
3019 ///   Op def: (     w,     kw  )
3020 ///    Iters: ({Par(), Red()})
3021 ///   Layout: {{w + kw}, {kw}, {w}}
3022 /// ```
3023 /// kw is unrolled.
3024 ///
3025 /// or
3026 ///
3027 /// ```
3028 ///   Op def: (     n,     w,     c,    kw,    f  )
3029 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
3030 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3031 /// ```
3032 /// kw is unrolled, w is unrolled iff strideW > 1.
3033 ///
3034 /// or
3035 ///
3036 /// ```
3037 ///   Op def: (     n,     c,     w,    f,    kw )
3038 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
3039 ///   Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3040 /// ```
3041 /// kw is unrolled, w is unrolled iff strideW > 1.
3042 ///
3043 /// or
3044 ///
3045 /// ```
3046 ///   Op def: (     n,     w,     c,    kw )
3047 ///    Iters: ({Par(), Par(), Par(), Red()})
3048 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3049 /// ```
3050 /// kw is unrolled, w is unrolled iff strideW > 1.
3051 struct Conv1DGenerator
3052     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3053   Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
3054                   int dilationW)
3055       : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
3056         strideW(strideW), dilationW(dilationW) {
3057     // Determine whether `linalgOp` can be handled by this generator.
3058     if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
3059       return;
3060     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3061     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3062     resShaped = linalgOp.getDpsInitOperand(0)->get();
3063     lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3064     rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3065     resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3066     if (!lhsShapedType || !rhsShapedType || !resShapedType)
3067       return;
3068     // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
3069     // (non-channeled convolution -> LHS and RES both have single dimensions).
3070     if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
3071         (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
3072       return;
3073 
3074     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3075     if (!reduceOp)
3076       return;
3077     redOp = reduceOp->getName().getIdentifier();
3078 
3079     if (!setOperKind(reduceOp))
3080       return;
3081     auto maybeKind = getCombinerOpKind(reduceOp);
3082     // Typically convolution will have a `Add` CombiningKind but for i1 type it
3083     // Typically convolution will have an `Add` CombiningKind but for i1 type it
3084     // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
3085     if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
3086                         *maybeKind != vector::CombiningKind::OR) &&
3087                        (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
3088       return;
3089     }
3090     reductionKind = maybeKind.value();
3091 
3092     auto rhsRank = rhsShapedType.getRank();
3093     switch (oper) {
3094     case Conv:
3095       if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
3096         return;
3097       break;
3098     case Pool:
3099       if (rhsRank != 1)
3100         return;
3101       break;
3102     }
3103     // The op is now known to be valid.
3104     valid = true;
3105   }
3106 
3107   /// Generate a vector implementation for:
3108   /// ```
3109   ///   Op def: (     w,     kw  )
3110   ///    Iters: ({Par(), Red()})
3111   ///   Layout: {{w + kw}, {kw}, {w}}
3112   /// ```
3113   /// kw is always unrolled.
3114   ///
3115   /// or
3116   ///
3117   /// ```
3118   ///   Op def: (     n,     w,     c,    kw,    f  )
3119   ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
3120   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3121   /// ```
3122   /// kw is always unrolled.
3123   /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3124   /// > 1.
3125   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3126     if (!valid)
3127       return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
3128 
3129     int64_t nSize, wSize, cSize, kwSize, fSize;
3130     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3131     bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3132     switch (conv1DOpOrder) {
3133     case Conv1DOpOrder::W:
3134       // Initialize unused dimensions
3135       nSize = fSize = cSize = 0;
3136       // out{W}
3137       bindShapeDims(resShapedType, wSize);
3138       // kernel{kw}
3139       bindShapeDims(rhsShapedType, kwSize);
3140       lhsShape = {// iw = ow + kw - 1
3141                   //   (i.e. 16 convolved with 3 -> 14)
3142                   (wSize + kwSize - 1)};
3143       rhsShape = {kwSize};
3144       resShape = {wSize};
3145       break;
3146     case Conv1DOpOrder::Nwc:
3147       // out{n, w, f}
3148       bindShapeDims(resShapedType, nSize, wSize, fSize);
3149       switch (oper) {
3150       case Conv:
3151         // kernel{kw, c, f}
3152         bindShapeDims(rhsShapedType, kwSize, cSize);
3153         break;
3154       case Pool:
3155         // kernel{kw}
3156         bindShapeDims(rhsShapedType, kwSize);
3157         cSize = fSize;
3158         break;
3159       }
3160       lhsShape = {nSize,
3161                   // iw = ow * sw + kw *  dw - 1
3162                   //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3163                   // Perform the proper inclusive -> exclusive -> inclusive.
3164                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3165                       1,
3166                   cSize};
3167       switch (oper) {
3168       case Conv:
3169         rhsShape = {kwSize, cSize, fSize};
3170         break;
3171       case Pool:
3172         rhsShape = {kwSize};
3173         break;
3174       }
3175       resShape = {nSize, wSize, fSize};
3176       break;
3177     case Conv1DOpOrder::Ncw:
3178       // out{n, f, w}
3179       bindShapeDims(resShapedType, nSize, fSize, wSize);
3180       switch (oper) {
3181       case Conv:
3182         // kernel{f, c, kw}
3183         bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3184         break;
3185       case Pool:
3186         // kernel{kw}
3187         bindShapeDims(rhsShapedType, kwSize);
3188         cSize = fSize;
3189         break;
3190       }
3191       lhsShape = {nSize, cSize,
3192                   // iw = ow * sw + kw *  dw - 1
3193                   //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3194                   // Perform the proper inclusive -> exclusive -> inclusive.
3195                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3196                       1};
3197       switch (oper) {
3198       case Conv:
3199         rhsShape = {fSize, cSize, kwSize};
3200         break;
3201       case Pool:
3202         rhsShape = {kwSize};
3203         break;
3204       }
3205       resShape = {nSize, fSize, wSize};
3206       break;
3207     }
3208 
3209     vector::TransferWriteOp write;
3210     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3211 
3212     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3213     // When strideW == 1, we can batch the contiguous loads and avoid
3214     // unrolling
3215     int64_t wSizeStep = strideW == 1 ? wSize : 1;
3216 
3217     Type lhsEltType = lhsShapedType.getElementType();
3218     Type rhsEltType = rhsShapedType.getElementType();
3219     Type resEltType = resShapedType.getElementType();
3220     auto lhsType = VectorType::get(lhsShape, lhsEltType);
3221     auto rhsType = VectorType::get(rhsShape, rhsEltType);
3222     auto resType = VectorType::get(resShape, resEltType);
3223     // Zero padding with the corresponding dimensions for lhs, rhs and res.
3224     SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3225     SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3226     SmallVector<Value> resPadding(resShape.size(), zero);
3227 
3228     // Read the whole lhs, rhs and res in one shot (with zero padding).
3229     Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3230                                                         lhsPadding);
3231     // This is needed only for Conv.
3232     Value rhs = nullptr;
3233     if (oper == Conv)
3234       rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3235                                                     rhsPadding);
3236     Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3237                                                         resPadding);
3238 
3239     // The base vectorization case for channeled convolution is input:
3240     // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3241     // vectorization case, we pre-transpose the input, weight, and output.
3242     switch (conv1DOpOrder) {
3243     case Conv1DOpOrder::W:
3244     case Conv1DOpOrder::Nwc:
3245       // Base case, so no transposes necessary.
3246       break;
3247     case Conv1DOpOrder::Ncw: {
3248       // To match base vectorization case, we pre-transpose current case.
3249       // ncw -> nwc
3250       static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3251       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3252       // fcw -> wcf
3253       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3254 
3255       // This is needed only for Conv.
3256       if (oper == Conv)
3257         rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3258       // nfw -> nwf
3259       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3260       res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3261       break;
3262     }
3263     }
3264 
3265     //===------------------------------------------------------------------===//
3266     // Begin vector-only rewrite part
3267     //===------------------------------------------------------------------===//
3268     // Unroll along kw and read slices of lhs and rhs.
3269     SmallVector<Value> lhsVals, rhsVals, resVals;
3270     lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3271                                      kwSize, strideW, dilationW, wSizeStep,
3272                                      isSingleChanneled);
3273     // Do not do for pooling.
3274     if (oper == Conv)
3275       rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3276     resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3277                                       wSizeStep, isSingleChanneled);
3278 
3279     auto linearIndex = [&](int64_t kw, int64_t w) {
3280       return kw * (wSize / wSizeStep) + w;
3281     };
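
    // For example (a sketch): with kwSize = 3, wSize = 4 and wSizeStep = 1
    // (i.e. strideW > 1), the loops below visit 12 lhs slices, indexed as
    // linearIndex(kw, w) = kw * 4 + w.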
3282 
3283     // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3284     // or perform an outerproduct for non-channeled convolution, or a simple
3285     // arith operation for pooling.
3286     for (int64_t kw = 0; kw < kwSize; ++kw) {
3287       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3288         switch (oper) {
3289         case Conv:
3290           if (isSingleChanneled) {
3291             resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3292                                                    lhsVals[linearIndex(kw, w)],
3293                                                    rhsVals[kw], resVals[w]);
3294           } else {
3295             resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3296                                                   lhsVals[linearIndex(kw, w)],
3297                                                   rhsVals[kw], resVals[w]);
3298           }
3299           break;
3300         case Pool:
3301           resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3302                                    resVals[w]);
3303           break;
3304         }
3305       }
3306     }
3307 
3308     res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3309                                  isSingleChanneled);
3310     //===------------------------------------------------------------------===//
3311     // End vector-only rewrite part
3312     //===------------------------------------------------------------------===//
3313 
3314     // The base vectorization case for channeled convolution is output:
3315     // {n,w,f}. To reuse the result from the base vectorization case, we
3316     // post-transpose the base case result.
3317     switch (conv1DOpOrder) {
3318     case Conv1DOpOrder::W:
3319     case Conv1DOpOrder::Nwc:
3320       // Base case, so no transposes necessary.
3321       break;
3322     case Conv1DOpOrder::Ncw: {
3323       // nwf -> nfw
3324       static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3325       res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3326       break;
3327     }
3328     }
3329 
3330     return rewriter
3331         .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3332         .getOperation();
3333   }
3334 
3335   // Take a value and widen it to have the same element type as `ty`.
3336   Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3337     const Type srcElementType = getElementTypeOrSelf(val.getType());
3338     const Type dstElementType = getElementTypeOrSelf(ty);
3339     assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3340     if (srcElementType == dstElementType)
3341       return val;
3342 
3343     const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3344     const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3345     const Type dstType =
3346         cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3347 
3348     if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3349       return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3350     }
3351 
3352     if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3353         srcWidth < dstWidth)
3354       return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3355 
3356     if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3357         srcWidth < dstWidth)
3358       return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3359 
3360     assert(false && "unhandled promotion case");
3361     return nullptr;
3362   }
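
  // For example (a sketch of the promotion rules above): an i8 vector promoted
  // to match an i32-elemented type goes through arith.extsi, an f16 vector
  // promoted to f32 goes through arith.extf, and an integer vector promoted to
  // a float element type goes through arith.sitofp.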
3363 
3364   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3365   Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3366                                  Value lhs, Value rhs, Value res) {
3367     vector::IteratorType par = vector::IteratorType::parallel;
3368     vector::IteratorType red = vector::IteratorType::reduction;
3369     AffineExpr n, w, f, c;
3370     bindDims(ctx, n, w, f, c);
3371     lhs = promote(rewriter, loc, lhs, res.getType());
3372     rhs = promote(rewriter, loc, rhs, res.getType());
3373     auto contractionOp = rewriter.create<vector::ContractionOp>(
3374         loc, lhs, rhs, res,
3375         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3376         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3377     contractionOp.setKind(reductionKind);
3378     return contractionOp;
3379   }
3380 
3381   // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3382   // convolution.
3383   Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3384                                   Value lhs, Value rhs, Value res) {
3385     return rewriter.create<vector::OuterProductOp>(
3386         loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3387   }
3388 
3389   // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3390   Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3391                     Value res) {
3392     if (isPoolExt)
3393       lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3394     return rewriter
3395         .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3396         ->getResult(0);
3397   }
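
  // For example (illustrative): for a max-pooling op, `redOp` is the name of
  // the combining op matched in the body (an arith maximum variant), and when
  // the body also contains an extension, `poolExtOp` first widens `lhs` to the
  // accumulator type before the reduction is created.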
3398 
3399   /// Generate a vector implementation for:
3400   /// ```
3401   ///   Op def: (     n,     w,     c,    kw)
3402   ///    Iters: ({Par(), Par(), Par(), Red()})
3403   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3404   /// ```
3405   /// kw is always unrolled.
3406   /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3407   /// > 1.
3408   FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3409                                        bool channelDimScalableFlag,
3410                                        bool flatten) {
3411     if (!valid)
3412       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3413 
3414     bool scalableChDim = false;
3415     bool useMasking = false;
3416     int64_t nSize, wSize, cSize, kwSize;
3417     // kernel{kw, c}
3418     bindShapeDims(rhsShapedType, kwSize, cSize);
3419     if (ShapedType::isDynamic(cSize)) {
3420       assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3421       cSize = channelDimVecSize;
3422       // Scalable vectors are only used when both conditions are met:
3423       //  1. channel dim is dynamic
3424       //  2. channelDimScalableFlag is set
3425       scalableChDim = channelDimScalableFlag;
3426       useMasking = true;
3427     }
3428 
3429     assert(!(useMasking && flatten) &&
3430            "Unsupported flattened conv with dynamic shapes");
3431 
3432     // out{n, w, c}
3433     bindShapeDims(resShapedType, nSize, wSize);
3434 
3435     vector::TransferWriteOp write;
3436     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3437 
3438     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3439     // When strideW == 1, we can batch the contiguous loads and avoid
3440     // unrolling
3441     int64_t wSizeStep = strideW == 1 ? wSize : 1;
3442 
3443     Type lhsEltType = lhsShapedType.getElementType();
3444     Type rhsEltType = rhsShapedType.getElementType();
3445     Type resEltType = resShapedType.getElementType();
3446     VectorType lhsType = VectorType::get(
3447         {nSize,
3448          // iw = ow * sw + kw *  dw - 1
3449          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3450          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3451          cSize},
3452         lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3453     VectorType rhsType =
3454         VectorType::get({kwSize, cSize}, rhsEltType,
3455                         /*scalableDims=*/{false, scalableChDim});
3456     VectorType resType =
3457         VectorType::get({nSize, wSize, cSize}, resEltType,
3458                         /*scalableDims=*/{false, false, scalableChDim});
3459 
3460     // Masks the input xfer Op along the channel dim, iff masking is required
3461     // (i.e. the channel dim is dynamic).
3462     auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3463                                ArrayRef<bool> scalableDims,
3464                                Operation *opToMask) {
3465       if (!useMasking)
3466         return opToMask;
3467       auto maskType =
3468           VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3469 
3470       SmallVector<bool> inBounds(maskShape.size(), true);
3471       auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3472       xferOp->setAttr(xferOp.getInBoundsAttrName(),
3473                       rewriter.getBoolArrayAttr(inBounds));
3474 
3475       SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3476           cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3477 
3478       Value maskOp =
3479           rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3480 
3481       return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3482     };
3483 
3484     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3485     // 0].
3486     Value lhs = rewriter.create<vector::TransferReadOp>(
3487         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3488     auto maybeMaskedLhs = maybeMaskXferOp(
3489         lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3490 
3491     // Read rhs slice of size {kw, c} @ [0, 0].
3492     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3493                                                         ValueRange{zero, zero});
3494     auto maybeMaskedRhs = maybeMaskXferOp(
3495         rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3496 
3497     // Read res slice of size {n, w, c} @ [0, 0, 0].
3498     Value res = rewriter.create<vector::TransferReadOp>(
3499         loc, resType, resShaped, ValueRange{zero, zero, zero});
3500     auto maybeMaskedRes = maybeMaskXferOp(
3501         resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3502 
3503     //===------------------------------------------------------------------===//
3504     // Begin vector-only rewrite part
3505     //===------------------------------------------------------------------===//
3506     // Unroll along kw and read slices of lhs and rhs.
3507     SmallVector<Value> lhsVals, rhsVals, resVals;
3508     SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3509     SmallVector<int64_t> inOutStrides = {1, 1, 1};
3510 
3511     // Extract lhs slice of size {n, wSizeStep, c}
3512     //   @ [0, sw * w + dw * kw, 0].
3513     for (int64_t kw = 0; kw < kwSize; ++kw) {
3514       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3515         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3516             loc, maybeMaskedLhs->getResult(0),
3517             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3518             inOutSliceSizes, inOutStrides));
3519       }
3520     }
3521     // Extract rhs slice of size {c} @ [kw].
3522     for (int64_t kw = 0; kw < kwSize; ++kw) {
3523       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3524           loc, maybeMaskedRhs->getResult(0),
3525           /*offsets=*/ArrayRef<int64_t>{kw}));
3526     }
3527     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3528     for (int64_t w = 0; w < wSize; w += wSizeStep) {
3529       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3530           loc, maybeMaskedRes->getResult(0),
3531           /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3532           inOutStrides));
3533     }
3534 
3535     auto linearIndex = [&](int64_t kw, int64_t w) {
3536       return kw * (wSize / wSizeStep) + w;
3537     };
3538 
3539     // Note - the scalable flags are ignored as flattening combined with
3540     // scalable vectorization is not supported.
3541     SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3542     auto lhsTypeAfterFlattening =
3543         VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3544     auto resTypeAfterFlattening =
3545         VectorType::get(inOutFlattenSliceSizes, resEltType);
3546 
3547     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3548     for (int64_t kw = 0; kw < kwSize; ++kw) {
3549       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3550         Value lhsVal = lhsVals[linearIndex(kw, w)];
3551         Value resVal = resVals[w];
3552         if (flatten) {
3553           // Flatten the input and output vectors (collapse the channel
3554           // dimension)
3555           lhsVal = rewriter.create<vector::ShapeCastOp>(
3556               loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3557           resVal = rewriter.create<vector::ShapeCastOp>(
3558               loc, resTypeAfterFlattening, resVals[w]);
3559         }
3560         resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3561                                                   rhsVals[kw], resVal, flatten);
3562         if (flatten) {
3563           // Un-flatten the output vector (restore the channel dimension)
3564           resVals[w] = rewriter.create<vector::ShapeCastOp>(
3565               loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3566         }
3567       }
3568     }
3569 
3570     // It's possible we failed to create the FMA.
3571     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3572       // Manually revert (in reverse order) to avoid leaving a bad IR state.
3573       for (auto &collection :
3574            {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3575         for (Value v : collection)
3576           rewriter.eraseOp(v.getDefiningOp());
3577       return rewriter.notifyMatchFailure(op, "failed to create FMA");
3578     }
3579 
3580     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3581     // This does not depend on kw.
3582     for (int64_t w = 0; w < wSize; w += wSizeStep) {
3583       maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3584           loc, resVals[w], maybeMaskedRes->getResult(0),
3585           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3586           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3587     }
3588     //===------------------------------------------------------------------===//
3589     // End vector-only rewrite part
3590     //===------------------------------------------------------------------===//
3591 
3592     // Write back res slice of size {n, w, c} @ [0, 0, 0].
3593     Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3594         loc, maybeMaskedRes->getResult(0), resShaped,
3595         ValueRange{zero, zero, zero});
3596     return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3597                            resOut);
3598   }
3599 
3600   /// Lower:
3601   ///   *  lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3602   ///   *  lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3603   /// to MulAcc.
3604   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3605                                      Value lhs, Value rhs, Value res,
3606                                      bool flatten) {
3607     auto rhsTy = cast<ShapedType>(rhs.getType());
3608     auto resTy = cast<ShapedType>(res.getType());
3609 
3610     // TODO(suderman): Change this to use a vector.ima intrinsic.
3611     lhs = promote(rewriter, loc, lhs, resTy);
3612 
3613     if (flatten) {
3614       // NOTE: The following logic won't work for scalable vectors. For this
3615       // reason, "flattening" is not supported when shapes are dynamic (this
3616       // should be captured by one of the pre-conditions).
3617 
3618       // There are two options for handling the filter:
3619       //  * shape_cast(broadcast(filter))
3620       //  * broadcast(shuffle(filter))
3621       // Opt for the option without shape_cast to simplify the codegen.
3622       auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3623       auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3624 
3625       SmallVector<int64_t, 16> indices;
3626       for (int i = 0; i < resSize / rhsSize; ++i) {
3627         for (int j = 0; j < rhsSize; ++j)
3628           indices.push_back(j);
3629       }
3630 
3631       rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3632     }
3633     // Broadcast the filter to match the output vector
3634     rhs = rewriter.create<vector::BroadcastOp>(
3635         loc, resTy.clone(rhsTy.getElementType()), rhs);
3636 
3637     rhs = promote(rewriter, loc, rhs, resTy);
3638 
3639     if (!lhs || !rhs)
3640       return nullptr;
3641 
3642     if (isa<FloatType>(resTy.getElementType()))
3643       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3644 
3645     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3646     return rewriter.create<arith::AddIOp>(loc, mul, res);
3647   }
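
  // For example (a sketch of the `flatten` path above): with a filter slice
  // `rhs` of type vector<4xf32> and a flattened `res` of type vector<1x8xf32>,
  // the shuffle repeats the filter as [0, 1, 2, 3, 0, 1, 2, 3] before it is
  // broadcast to the shape of `res`.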
3648 
3649   /// Entry point for non-channeled convolution:
3650   ///   {{w + kw}, {kw}, {w}}
3651   FailureOr<Operation *> generateNonChanneledConv() {
3652     AffineExpr w, kw;
3653     bindDims(ctx, w, kw);
3654     if (!iters({Par(), Red()}))
3655       return rewriter.notifyMatchFailure(op,
3656                                          "failed to match conv::W 1-par 1-red");
3657 
3658     // No transposition needed.
3659     if (layout({/*lhsIndex*/ {w + kw},
3660                 /*rhsIndex*/ {kw},
3661                 /*resIndex*/ {w}}))
3662       return conv(Conv1DOpOrder::W);
3663 
3664     return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3665   }
3666 
3667   /// Entry point that transposes into the common form:
3668   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3669   FailureOr<Operation *> generateNwcConv() {
3670     AffineExpr n, w, f, kw, c;
3671     bindDims(ctx, n, w, f, kw, c);
3672     if (!iters({Par(), Par(), Par(), Red(), Red()}))
3673       return rewriter.notifyMatchFailure(
3674           op, "failed to match conv::Nwc 3-par 2-red");
3675 
3676     // No transposition needed.
3677     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3678                 /*rhsIndex*/ {kw, c, f},
3679                 /*resIndex*/ {n, w, f}}))
3680       return conv(Conv1DOpOrder::Nwc);
3681 
3682     return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3683   }
3684 
3685   /// Entry point that transposes into the common form:
3686   ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3687   FailureOr<Operation *> generateNcwConv() {
3688     AffineExpr n, w, f, kw, c;
3689     bindDims(ctx, n, f, w, c, kw);
3690     if (!iters({Par(), Par(), Par(), Red(), Red()}))
3691       return rewriter.notifyMatchFailure(
3692           op, "failed to match conv::Ncw 3-par 2-red");
3693 
3694     if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3695                 /*rhsIndex*/ {f, c, kw},
3696                 /*resIndex*/ {n, f, w}}))
3697       return conv(Conv1DOpOrder::Ncw);
3698 
3699     return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3700   }
3701 
3702   /// Entry point that transposes into the common form:
3703   ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
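  /// For example (shapes and attributes illustrative), this matches pooling ops
  /// such as:
  ///   %0 = linalg.pooling_nwc_sum
  ///     {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  ///     ins(%input, %window : tensor<1x10x3xf32>, tensor<2xf32>)
  ///     outs(%init : tensor<1x9x3xf32>) -> tensor<1x9x3xf32>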
3704   FailureOr<Operation *> generateNwcPooling() {
3705     AffineExpr n, w, c, kw;
3706     bindDims(ctx, n, w, c, kw);
3707     if (!iters({Par(), Par(), Par(), Red()}))
3708       return rewriter.notifyMatchFailure(op,
3709                                          "failed to match pooling 3-par 1-red");
3710 
3711     // No transposition needed.
3712     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3713                 /*rhsIndex*/ {kw},
3714                 /*resIndex*/ {n, w, c}}))
3715       return conv(Conv1DOpOrder::Nwc);
3716 
3717     return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3718   }
3719 
3720   /// Entry point that transposes into the common form:
3721   ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
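  /// For example (shapes and attributes illustrative), this matches pooling ops
  /// such as:
  ///   %0 = linalg.pooling_ncw_sum
  ///     {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  ///     ins(%input, %window : tensor<1x3x10xf32>, tensor<2xf32>)
  ///     outs(%init : tensor<1x3x9xf32>) -> tensor<1x3x9xf32>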
3722   FailureOr<Operation *> generateNcwPooling() {
3723     AffineExpr n, w, c, kw;
3724     bindDims(ctx, n, c, w, kw);
3725     if (!iters({Par(), Par(), Par(), Red()}))
3726       return rewriter.notifyMatchFailure(op,
3727                                          "failed to match pooling 3-par 1-red");
3728 
3729     if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3730                 /*rhsIndex*/ {kw},
3731                 /*resIndex*/ {n, c, w}}))
3732       return conv(Conv1DOpOrder::Ncw);
3733 
3734     return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3735   }
3736 
3737   /// Entry point that transposes into the common form:
3738   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
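  /// For example (shapes and attributes illustrative), this matches depthwise
  /// convolutions such as:
  ///   %0 = linalg.depthwise_conv_1d_nwc_wc
  ///     {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  ///     ins(%input, %filter : tensor<1x10x8xf32>, tensor<2x8xf32>)
  ///     outs(%init : tensor<1x9x8xf32>) -> tensor<1x9x8xf32>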
3739   FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3740                                              bool vecChDimScalableFlag = false,
3741                                              bool flatten = false) {
3742     AffineExpr n, w, c, kw;
3743     bindDims(ctx, n, w, c, kw);
3744     if (!iters({Par(), Par(), Par(), Red()}))
3745       return rewriter.notifyMatchFailure(
3746           op, "failed to match depthwise::Nwc conv 3-par 1-red");
3747 
3748     // No transposition needed.
3749     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3750                 /*rhsIndex*/ {kw, c},
3751                 /*resIndex*/ {n, w, c}}))
3752       return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3753 
3754     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3755   }
3756 
3757 private:
3758   enum OperKind { Conv, Pool };
3759   bool valid = false;
3760   OperKind oper = Conv;
3761   StringAttr redOp;
3762   StringAttr poolExtOp;
3763   bool isPoolExt = false;
3764   int strideW, dilationW;
3765   Value lhsShaped, rhsShaped, resShaped;
3766   ShapedType lhsShapedType, rhsShapedType, resShapedType;
3767   vector::CombiningKind reductionKind;
3768 
3769   // Sets `oper`, `poolExtOp` and `isPoolExt` for valid conv/pooling ops.
3770   // Returns true iff the op is a valid conv/pooling op.
3771   // If the region has 2 ops (reduction + yield) or 3 ops (extension +
3772   // reduction + yield) and the rhs is not used, then it is the body of a
3773   // pooling op. If it is a convolution, check for a single `mul` predecessor;
3774   // the `mul` operands must be block arguments or extensions of block
3775   // arguments. Otherwise, check for one or zero `ext` predecessors; the
3776   // `ext` operands must be block arguments or extensions of block arguments.
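  // For example (illustrative region bodies, block-argument names assumed):
  //   Convolution body:   %0 = arith.mulf %in, %flt : f32
  //                       %1 = arith.addf %out, %0 : f32
  //                       linalg.yield %1 : f32
  //     -> the reduction (addf) has one block-argument operand.
  //   Sum-pooling body:   %0 = arith.addf %out, %in : f32
  //                       linalg.yield %0 : f32
  //     -> the reduction has two block-argument operands.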
3777   bool setOperKind(Operation *reduceOp) {
3778     int numBlockArguments =
3779         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3780     switch (numBlockArguments) {
3781     case 1: {
3782       // The op is a convolution if the feeder is a MulOp. A strength-reduced
3783       // version of MulOp for the i1 type is AndOp, which is also supported.
3784       // Otherwise, it can be a pooling op. This strength-reduction logic lives
3785       // in the `buildBinaryFn` helper in the Linalg dialect.
3786       auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3787                                          llvm::IsaPred<BlockArgument>);
3788       Operation *feedOp = (*feedValIt).getDefiningOp();
3789       if (isCastOfBlockArgument(feedOp)) {
3790         oper = Pool;
3791         isPoolExt = true;
3792         poolExtOp = feedOp->getName().getIdentifier();
3793       } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
3794                     (isa<arith::AndIOp>(feedOp) &&
3795                      feedOp->getResultTypes()[0].isInteger(1))) &&
3796                    llvm::all_of(feedOp->getOperands(), [](Value v) {
3797                      if (isa<BlockArgument>(v))
3798                        return true;
3799                      if (Operation *op = v.getDefiningOp())
3800                        return isCastOfBlockArgument(op);
3801                      return false;
3802                    }))) {
3803         return false;
3804       }
3805       return true;
3806     }
3807     case 2:
3808       // Must be a pooling op.
3809       oper = Pool;
3810       isPoolExt = false;
3811       return true;
3812     default:
3813       return false;
3814     }
3815   }
3816 };
3817 } // namespace
3818 
3819 /// Helper function to vectorize a LinalgOp with convolution semantics.
3820 // TODO: extend the generic vectorization to support windows and drop this.
3821 static FailureOr<Operation *> vectorizeConvolution(
3822     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3823     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3824   // The ConvolutionOpInterface guarantees the existence of the
3825   // strides/dilations attributes. However, we do not need to rely on that: we
3826   // simply use them if present, fall back to the defaults otherwise, and let
3827   // the generic conv matcher in the Conv1DGenerator succeed or fail.
3828   auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3829   auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3830   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3831   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3832   Conv1DGenerator e(rewriter, op, stride, dilation);
3833   auto res = e.generateNonChanneledConv();
3834   if (succeeded(res))
3835     return res;
3836   res = e.generateNwcConv();
3837   if (succeeded(res))
3838     return res;
3839   res = e.generateNcwConv();
3840   if (succeeded(res))
3841     return res;
3842   res = e.generateNwcPooling();
3843   if (succeeded(res))
3844     return res;
3845   res = e.generateNcwPooling();
3846   if (succeeded(res))
3847     return res;
3848 
3849   // Only depthwise 1D NWC convs are left; these can be vectorized using masks
3850   // and scalable vectors. Note that, at the moment, the only dim that can be
3851   // dynamic (i.e. masked/scalable) is the channel dim (i.e. the trailing dim).
3852   uint64_t vecChDimSize = ShapedType::kDynamic;
3853   bool vecChDimScalableFlag = false;
3854   if (!inputVecSizes.empty()) {
3855     // Only use the input vector size corresponding to the channel dim. Other
3856     // vector dims will be inferred from the Ops.
3857     assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3858             isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3859            "Not a 1D depthwise conv!");
3860     size_t chDimIdx =
3861         TypeSwitch<Operation *, size_t>(op)
3862             .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3863             .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3864 
3865     vecChDimSize = inputVecSizes[chDimIdx];
3866     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3867   }
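  // For example (illustrative): when vectorizing a
  // linalg.depthwise_conv_1d_nwc_wc with vector sizes [1, 8, [4]], chDimIdx is
  // 2, so vecChDimSize == 4 and vecChDimScalableFlag == true.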
3868   return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3869                                flatten1DDepthwiseConv);
3870 }
3871 
3872 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
3873   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3874 
3875   LogicalResult matchAndRewrite(LinalgOp op,
3876                                 PatternRewriter &rewriter) const override {
3877     FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3878     if (failed(resultOrFail))
3879       return failure();
3880     Operation *newOp = *resultOrFail;
3881     if (newOp->getNumResults() == 0) {
3882       rewriter.eraseOp(op.getOperation());
3883       return success();
3884     }
3885     assert(newOp->getNumResults() == 1 && "expected single result");
3886     rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3887     return success();
3888   }
3889 };
3890 
3891 void mlir::linalg::populateConvolutionVectorizationPatterns(
3892     RewritePatternSet &patterns, PatternBenefit benefit) {
3893   patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3894 }
3895 }
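
// Typical usage (illustrative sketch, not part of this file): clients add the
// pattern via this entry point and then run a greedy rewrite driver, e.g.:
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));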