xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (revision b47d1787b51f55d69ef1b4f88e72cd54af451649)
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "mlir/Dialect/Affine/Utils.h"
13 
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tensor/Utils/Utils.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypeInterfaces.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Support/LLVM.h"
35 #include "mlir/Transforms/RegionUtils.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 #include <type_traits>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 
55 /// Try to vectorize `convOp` as a convolution.
56 static FailureOr<Operation *>
57 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58                      ArrayRef<int64_t> inputVecSizes = {},
59                      ArrayRef<bool> inputVecScalableFlags = {},
60                      bool flatten1DDepthwiseConv = false);
61 
62 /// Return the unique instance of OpType in `block` if it is indeed unique.
63 /// Return null if none or more than one instance exists.
64 template <typename OpType>
65 static OpType getSingleOpOfType(Block &block) {
66   OpType res;
67   block.walk([&](OpType op) {
68     if (res) {
69       res = nullptr;
70       return WalkResult::interrupt();
71     }
72     res = op;
73     return WalkResult::advance();
74   });
75   return res;
76 }
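
// Usage sketch (illustrative only; `body` is a hypothetical Block taken from
// the surrounding context):
//
//   if (linalg::YieldOp yield = getSingleOpOfType<linalg::YieldOp>(body)) {
//     // Exactly one linalg.yield lives in `body`; zero or several yield null.
//   }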
77 
78 /// Helper function to extract the input slices after the filter is unrolled
79 /// along kw.
80 static SmallVector<Value>
81 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
82                        int64_t nSize, int64_t wSize, int64_t cSize,
83                        int64_t kwSize, int strideW, int dilationW,
84                        int64_t wSizeStep, bool isSingleChanneled) {
85   SmallVector<Value> result;
86   if (isSingleChanneled) {
87     // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
88     // convolution.
89     SmallVector<int64_t> sizes{wSizeStep};
90     SmallVector<int64_t> strides{1};
91     for (int64_t kw = 0; kw < kwSize; ++kw) {
92       for (int64_t w = 0; w < wSize; w += wSizeStep) {
93         result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
94             loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
95       }
96     }
97   } else {
98     // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
99     // for channeled convolution.
100     SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
101     SmallVector<int64_t> strides{1, 1, 1};
102     for (int64_t kw = 0; kw < kwSize; ++kw) {
103       for (int64_t w = 0; w < wSize; w += wSizeStep) {
104         result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
105             loc, input,
106             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
107             sizes, strides));
108       }
109     }
110   }
111   return result;
112 }
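
// Worked example (for exposition only; the values are hypothetical): for a
// channeled convolution with strideW = 2 and dilationW = 3, the slice
// extracted at (kw = 1, w = 4) starts at W-offset
//   w * strideW + kw * dilationW = 4 * 2 + 1 * 3 = 11
// and has size {nSize, wSizeStep, cSize}.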
113 
114 /// Helper function to extract the filter slices after the filter is unrolled
115 /// along kw.
116 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
117                                                   Location loc, Value filter,
118                                                   int64_t kwSize) {
119   SmallVector<Value> result;
120   // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
121   // non-channeled convolution] @ [kw].
122   for (int64_t kw = 0; kw < kwSize; ++kw) {
123     result.push_back(rewriter.create<vector::ExtractOp>(
124         loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
125   }
126   return result;
127 }
128 
129 /// Helper function to extract the result slices after the filter is unrolled
130 /// along kw.
131 static SmallVector<Value>
132 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
133                         int64_t nSize, int64_t wSize, int64_t fSize,
134                         int64_t wSizeStep, bool isSingleChanneled) {
135   SmallVector<Value> result;
136   if (isSingleChanneled) {
137     // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
138     SmallVector<int64_t> sizes{wSizeStep};
139     SmallVector<int64_t> strides{1};
140     for (int64_t w = 0; w < wSize; w += wSizeStep) {
141       result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
142           loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
143     }
144   } else {
145     // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
146     // convolution.
147     SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
148     SmallVector<int64_t> strides{1, 1, 1};
149     for (int64_t w = 0; w < wSize; w += wSizeStep) {
150       result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
151           loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
152     }
153   }
154   return result;
155 }
156 
157 /// Helper function to insert the computed result slices.
158 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
159                                     Value res, int64_t wSize, int64_t wSizeStep,
160                                     SmallVectorImpl<Value> &resVals,
161                                     bool isSingleChanneled) {
162 
163   if (isSingleChanneled) {
164     // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
165     // This does not depend on kw.
166     SmallVector<int64_t> strides{1};
167     for (int64_t w = 0; w < wSize; w += wSizeStep) {
168       res = rewriter.create<vector::InsertStridedSliceOp>(
169           loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
170     }
171   } else {
172     // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
173     // convolution. This does not depend on kw.
174     SmallVector<int64_t> strides{1, 1, 1};
175     for (int64_t w = 0; w < wSize; w += wSizeStep) {
176       res = rewriter.create<vector::InsertStridedSliceOp>(
177           loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
178           strides);
179     }
180   }
181   return res;
182 }
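
// The four helpers above are meant to be composed roughly as follows
// (condensed sketch; the actual convolution vectorizer later in this file
// additionally handles masking, type promotion and the different layouts):
//
//   SmallVector<Value> lhss = extractConvInputSlices(...);
//   SmallVector<Value> rhss = extractConvFilterSlices(...);
//   SmallVector<Value> ress = extractConvResultSlices(...);
//   // Combine lhss[kw, w] with rhss[kw] into ress[w] for every kw and w-tile,
//   // then write all tiles back into the result vector:
//   res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, ress,
//                                isSingleChanneled);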
183 
184 /// Contains the vectorization state and related methods used across the
185 /// vectorization process of a given operation.
186 struct VectorizationState {
187   VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
188 
189   /// Initializes the vectorization state, including the computation of the
190   /// canonical vector shape for vectorization.
191   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
192                           ArrayRef<int64_t> inputVectorSizes,
193                           ArrayRef<bool> inputScalableVecDims);
194 
195   /// Returns the canonical vector shape used to vectorize the iteration space.
196   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
197 
198   /// Returns the vector dimensions that are scalable in the canonical vector
199   /// shape.
200   ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
201 
202   /// Returns a vector type of the provided `elementType` with the canonical
203   /// vector shape and the corresponding fixed/scalable dimensions bit. If
204   /// `dimPermutation` is provided, the canonical vector dimensions are permuted
205   /// accordingly.
206   VectorType getCanonicalVecType(
207       Type elementType,
208       std::optional<AffineMap> dimPermutation = std::nullopt) const {
209     SmallVector<int64_t> vectorShape;
210     SmallVector<bool> scalableDims;
211     if (dimPermutation.has_value()) {
212       vectorShape =
213           applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
214       scalableDims =
215           applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
216     } else {
217       vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
218       scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
219     }
220 
221     return VectorType::get(vectorShape, elementType, scalableDims);
222   }
223 
224   /// Masks an operation with the canonical vector mask if the operation needs
225   /// masking. Returns the masked operation or the original operation if masking
226   /// is not needed. If provided, the canonical mask for this operation is
227   /// permuted using `maybeMaskingMap`.
228   Operation *
229   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230                 std::optional<AffineMap> maybeMaskingMap = std::nullopt);
231 
232 private:
233   /// Initializes the iteration space static sizes using the Linalg op
234   /// information. This may become more complicated in the future.
235   void initIterSpaceStaticSizes(LinalgOp linalgOp) {
236     iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
237   }
238 
239   /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
240   /// all the static and dynamic dimensions of the iteration space to be
241   /// vectorized and store them in `iterSpaceValueSizes`.
242   LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
243                                               LinalgOp linalgOp);
244 
245   /// Create or retrieve an existing mask value to mask `opToMask` in the
246   /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
247   /// mask is permuted using that permutation map. If a new mask is created, it
248   /// will be cached for future users.
249   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
250                            LinalgOp linalgOp,
251                            std::optional<AffineMap> maybeMaskingMap);
252 
253   // Holds the compile-time static sizes of the iteration space to vectorize.
254   // Dynamic dimensions are represented using ShapedType::kDynamic.
255   SmallVector<int64_t> iterSpaceStaticSizes;
256 
257   /// Holds the value sizes of the iteration space to vectorize. Static
258   /// dimensions are represented by 'arith.constant' and dynamic
259   /// dimensions by 'tensor/memref.dim'.
260   SmallVector<Value> iterSpaceValueSizes;
261 
262   /// Holds the canonical vector shape used to vectorize the iteration space.
263   SmallVector<int64_t> canonicalVecShape;
264 
265   /// Holds the vector dimensions that are scalable in the canonical vector
266   /// shape.
267   SmallVector<bool> scalableVecDims;
268 
269   /// Holds the active masks for permutations of the canonical vector iteration
270   /// space.
271   DenseMap<AffineMap, Value> activeMaskCache;
272 
273   /// Global vectorization guard for the incoming rewriter. It's initialized
274   /// when the vectorization state is initialized.
275   OpBuilder::InsertionGuard rewriterGuard;
276 };
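
// Typical driver flow for VectorizationState (condensed sketch; the real entry
// points later in this file add precondition checks and custom hooks, and the
// unqualified names below are assumed to come from the caller):
//
//   VectorizationState state(rewriter);
//   if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
//                              inputScalableVecDims)))
//     return failure();
//   // Build vector ops against the canonical shape, e.g.:
//   VectorType vecType = state.getCanonicalVecType(elementType, indexingMap);
//   // ...create a transfer_read/transfer_write `op`, then mask it if needed:
//   Operation *maybeMasked = state.maskOperation(rewriter, op, linalgOp, map);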
277 
278 LogicalResult
279 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
280                                                   LinalgOp linalgOp) {
281   // TODO: Support 0-d vectors.
282   for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
283     if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
284       // Create constant index op for static dimensions.
285       iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
286           linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
287       continue;
288     }
289 
290     // Find an operand defined on this dimension of the iteration space to
291     // extract the runtime dimension size.
292     Value operand;
293     unsigned operandDimPos;
294     if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
295                                                          operandDimPos)))
296       return failure();
297 
298     Value dynamicDim = linalgOp.hasPureTensorSemantics()
299                            ? (Value)rewriter.create<tensor::DimOp>(
300                                  linalgOp.getLoc(), operand, operandDimPos)
301                            : (Value)rewriter.create<memref::DimOp>(
302                                  linalgOp.getLoc(), operand, operandDimPos);
303     iterSpaceValueSizes.push_back(dynamicDim);
304   }
305 
306   return success();
307 }
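
// For example (schematic IR only), for an iteration space [8, ?] with tensor
// semantics this precomputes roughly:
//   %c8 = arith.constant 8 : index                      // static dim 0
//   %c1 = arith.constant 1 : index
//   %d1 = tensor.dim %operand, %c1 : tensor<8x?xf32>    // dynamic dim 1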
308 
309 /// Initializes the vectorization state, including the computation of the
310 /// canonical vector shape for vectorization.
311 // TODO: Move this to the constructor when we can remove the failure cases.
312 LogicalResult
313 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
314                               ArrayRef<int64_t> inputVectorSizes,
315                               ArrayRef<bool> inputScalableVecDims) {
316   // Initialize the insertion point.
317   rewriter.setInsertionPoint(linalgOp);
318 
319   if (!inputVectorSizes.empty()) {
320     // Get the canonical vector shape from the input vector sizes provided. This
321     // path should be taken to vectorize code with dynamic shapes and when using
322     // vector sizes greater than the iteration space sizes.
323     canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
324     scalableVecDims.append(inputScalableVecDims.begin(),
325                            inputScalableVecDims.end());
326   } else {
327     // Compute the canonical vector shape from the operation shape. If there are
328     // dynamic shapes, the operation won't be vectorized. We assume all the
329     // vector dimensions are fixed.
330     canonicalVecShape = linalgOp.getStaticLoopRanges();
331     scalableVecDims.append(linalgOp.getNumLoops(), false);
332   }
333 
334   LDBG("Canonical vector shape: ");
335   LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
336   LLVM_DEBUG(llvm::dbgs() << "\n");
337   LDBG("Scalable vector dims: ");
338   LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
339   LLVM_DEBUG(llvm::dbgs() << "\n");
340 
341   if (ShapedType::isDynamicShape(canonicalVecShape))
342     return failure();
343 
344   // Initialize iteration space static sizes.
345   initIterSpaceStaticSizes(linalgOp);
346 
347   // Generate 'arith.constant' and 'tensor/memref.dim' operations for
348   // all the static and dynamic dimensions of the iteration space, needed to
349   // compute a mask during vectorization.
350   if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
351     return failure();
352 
353   return success();
354 }
355 
356 /// Create or retrieve an existing mask value to mask `opToMask` in the
357 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
358 /// is permuted using that permutation map. If a new mask is created, it will be
359 /// cached for future users.
360 Value VectorizationState::getOrCreateMaskFor(
361     RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
362     std::optional<AffineMap> maybeMaskingMap) {
363   // No mask is needed if the operation is not maskable.
364   auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
365   if (!maskableOp)
366     return Value();
367 
368   assert(!maskableOp.isMasked() &&
369          "Masking an operation that is already masked");
370 
371   // If no masking map was provided, use an identity map with the loop dims.
372   assert((!maybeMaskingMap || *maybeMaskingMap) &&
373          "Unexpected null mask permutation map");
374   AffineMap maskingMap =
375       maybeMaskingMap ? *maybeMaskingMap
376                       : AffineMap::getMultiDimIdentityMap(
377                             linalgOp.getNumLoops(), rewriter.getContext());
378 
379   LDBG("Masking map: " << maskingMap << "\n");
380 
381   // Return the active mask for the masking map of this operation if it was
382   // already created.
383   auto activeMaskIt = activeMaskCache.find(maskingMap);
384   if (activeMaskIt != activeMaskCache.end()) {
385     Value mask = activeMaskIt->second;
386     LDBG("Reusing mask: " << mask << "\n");
387     return mask;
388   }
389 
390   // Compute permuted projection of the iteration space to be masked and the
391   // corresponding mask shape. If the resulting iteration space dimensions are
392   // static and identical to the mask shape, masking is not needed for this
393   // operation.
394   // TODO: Improve this check. Only projected permutation indexing maps are
395   // supported.
396   SmallVector<int64_t> permutedStaticSizes =
397       applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
398   auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
399   auto maskShape = maskType.getShape();
400 
401   LDBG("Mask shape: ");
402   LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
403   LLVM_DEBUG(llvm::dbgs() << "\n");
404 
405   if (permutedStaticSizes == maskShape) {
406     LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
407     activeMaskCache[maskingMap] = Value();
408     return Value();
409   }
410 
411   // Permute the iteration space value sizes to compute the mask upper bounds.
412   SmallVector<Value> upperBounds =
413       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
414   assert(!maskShape.empty() && !upperBounds.empty() &&
415          "Masked 0-d vectors are not supported yet");
416 
417   // Create the mask based on the dimension values.
418   Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
419                                                      maskType, upperBounds);
420   LDBG("Creating new mask: " << mask << "\n");
421   activeMaskCache[maskingMap] = mask;
422   return mask;
423 }
424 
425 /// Masks an operation with the canonical vector mask if the operation needs
426 /// masking. Returns the masked operation or the original operation if masking
427 /// is not needed. If provided, the canonical mask for this operation is
428 /// permuted using `maybeMaskingMap`.
429 Operation *
430 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
431                                   LinalgOp linalgOp,
432                                   std::optional<AffineMap> maybeMaskingMap) {
433   LDBG("Trying to mask: " << *opToMask << "\n");
434 
435   // Create or retrieve mask for this operation.
436   Value mask =
437       getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
438 
439   if (!mask) {
440     LDBG("No mask required\n");
441     return opToMask;
442   }
443 
444   // Wrap the operation with a new `vector.mask` and update D-U chain.
445   assert(opToMask && "Expected a valid operation to mask");
446   auto maskOp = cast<vector::MaskOp>(
447       mlir::vector::maskOperation(rewriter, opToMask, mask));
448   Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
449 
450   for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
451     rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
452                                   maskOpTerminator);
453 
454   LDBG("Masked operation: " << *maskOp << "\n");
455   return maskOp;
456 }
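
// Schematically, masking wraps the vectorized op in a `vector.mask` region and
// reroutes its uses (illustrative, simplified IR):
//
//   %m = vector.create_mask %ub0, %ub1 : vector<4x8xi1>
//   %r = vector.mask %m { vector.transfer_read ... }
//          : vector<4x8xi1> -> vector<4x8xf32>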
457 
458 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
459 /// projected permutation, compress the unused dimensions to serve as a
460 /// permutation_map for a vector transfer operation.
461 /// For example, given a linalg op such as:
462 ///
463 /// ```
464 ///   %0 = linalg.generic {
465 ///        indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
466 ///                         affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>]
467 ///      }
468 ///     ins(%0 : tensor<2x3x4xf32>)
469 ///    outs(%1 : tensor<5x6xf32>)
470 /// ```
471 ///
472 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
473 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
474 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
475 static AffineMap reindexIndexingMap(AffineMap map) {
476   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
477          "expected projected permutation");
478   auto res = compressUnusedDims(map);
479   assert(res.getNumDims() == res.getNumResults() &&
480          "expected reindexed map with same number of dims and results");
481   return res;
482 }
483 
484 /// Helper enum to represent conv1d input traversal order.
485 enum class Conv1DOpOrder {
486   W,   // Corresponds to non-channeled 1D convolution operation.
487   Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
488   Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
489 };
490 
491 /// Helper data structure to represent the result of vectorization.
492 /// In certain specific cases, like terminators, we do not want to propagate/replace the vectorized results.
493 enum VectorizationStatus {
494   /// Op failed to vectorize.
495   Failure = 0,
496   /// Op vectorized and custom function took care of replacement logic
497   NoReplace,
498   /// Op vectorized into a new Op whose results will replace original Op's
499   /// results.
500   NewOp
501   // TODO: support the case where an op vectorizes to multiple ops whose
502   // results need to be aggregated for replacement.
503 };
504 struct VectorizationResult {
505   /// Return status from vectorizing the current op.
506   enum VectorizationStatus status = VectorizationStatus::Failure;
507   /// New vectorized operation to replace the current op.
508   /// Replacement behavior is specified by `status`.
509   Operation *newOp;
510 };
511 
512 std::optional<vector::CombiningKind>
513 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
514   using ::mlir::vector::CombiningKind;
515 
516   if (!combinerOp)
517     return std::nullopt;
518   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
519       .Case<arith::AddIOp, arith::AddFOp>(
520           [&](auto op) { return CombiningKind::ADD; })
521       .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
522       .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
523       .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
524       .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
525       .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
526       .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
527       .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
528       .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
529       .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
530       .Case<arith::MulIOp, arith::MulFOp>(
531           [&](auto op) { return CombiningKind::MUL; })
532       .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
533       .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
534       .Default([&](auto op) { return std::nullopt; });
535 }
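
// For instance, a reduction whose region combines values with `arith.addf`
// yields CombiningKind::ADD, which buildMultiDimReduce below turns into a
// matching vector.multi_reduction (illustrative usage):
//
//   std::optional<vector::CombiningKind> kind = getCombinerOpKind(combinerOp);
//   if (kind && *kind == vector::CombiningKind::ADD) { /* ... */ }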
536 
537 /// Check whether `outputOperand` is a reduction with a single combiner
538 /// operation. Return the combiner operation of the reduction. Return
539 /// nullptr otherwise. Multiple reduction operations would impose an
540 /// ordering between reduction dimensions and are currently unsupported in
541 /// Linalg. This limitation is motivated by the fact that e.g.
542 /// min(max(X)) != max(min(X)).
543 // TODO: use in LinalgOp verification, there is a circular dependency atm.
544 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
545   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
546   unsigned outputPos =
547       outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
548   // Only single combiner operations are supported for now.
549   SmallVector<Operation *, 4> combinerOps;
550   if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
551       combinerOps.size() != 1)
552     return nullptr;
553 
554   // Return the combiner operation.
555   return combinerOps[0];
556 }
557 
558 /// Broadcast `value` to a vector of the shape of `dstType` if possible. Return
559 /// `value` unchanged otherwise.
560 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
561   auto dstVecType = dyn_cast<VectorType>(dstType);
562   // If no shape to broadcast to, just return `value`.
563   if (dstVecType.getRank() == 0)
564     return value;
565   if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
566       vector::BroadcastableToResult::Success)
567     return value;
568   Location loc = b.getInsertionPoint()->getLoc();
569   return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
570 }
571 
572 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
573 /// assumes that `reductionOp` has two operands and one of them is the reduction
574 /// initial value.
575 // Note: this is a true builder that notifies the OpBuilder listener.
576 // TODO: Consider moving as a static helper on the ReduceOp.
577 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
578                                       Value valueToReduce, Value acc,
579                                       ArrayRef<bool> dimsToMask) {
580   auto maybeKind = getCombinerOpKind(reduceOp);
581   assert(maybeKind && "Failed precondition: could not get reduction kind");
582   return b.create<vector::MultiDimReductionOp>(
583       reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
584 }
585 
586 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
587   return llvm::to_vector(
588       llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
589 }
590 
591 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
592 /// reduction iterator.
593 static bool hasReductionIterator(LinalgOp &op) {
594   return isa<linalg::ReduceOp>(op) ||
595          (isa<linalg::GenericOp>(op) &&
596           llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
597 }
598 
599 /// Build a vector.transfer_write of `value` into `outputOperand` at indices
600 /// set to all `0`, where `outputOperand` is an output operand of the LinalgOp
601 /// currently being vectorized. If the output has rank 0, build a 0-d write.
602 /// Return the produced value or null if no value is produced.
603 // Note: this is a true builder that notifies the OpBuilder listener.
604 // TODO: Consider moving as a static helper on the ReduceOp.
605 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
606                               OpOperand *outputOperand,
607                               VectorizationState &state) {
608   Location loc = value.getLoc();
609   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
610   AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
611 
612   // Compute the vector type of the value to store. This type should be an
613   // identity or projection of the canonical vector type without any permutation
614   // applied, given that any permutation in a transfer write happens as part of
615   // the write itself.
616   AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
617       opOperandMap.getContext(), opOperandMap.getNumInputs(),
618       [&](AffineDimExpr dimExpr) -> bool {
619         return llvm::is_contained(opOperandMap.getResults(), dimExpr);
620       });
621   auto vectorType = state.getCanonicalVecType(
622       getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
623 
624   Operation *write;
625   if (vectorType.getRank() > 0) {
626     AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
627     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
628                                rewriter.create<arith::ConstantIndexOp>(loc, 0));
629     value = broadcastIfNeeded(rewriter, value, vectorType);
630     assert(value.getType() == vectorType && "Incorrect type");
631     write = rewriter.create<vector::TransferWriteOp>(
632         loc, value, outputOperand->get(), indices, writeMap);
633   } else {
634     // 0-d case is still special: do not invert the reindexing writeMap.
635     if (!isa<VectorType>(value.getType()))
636       value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
637     assert(value.getType() == vectorType && "Incorrect type");
638     write = rewriter.create<vector::TransferWriteOp>(
639         loc, value, outputOperand->get(), ValueRange{});
640   }
641 
642   write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
643 
644   // If masked, set in-bounds to true. Masking guarantees that the access will
645   // be in-bounds.
646   if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
647     auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
648     SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
649     maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
650   }
651 
652   LDBG("vectorized op: " << *write << "\n");
653   if (!write->getResults().empty())
654     return write->getResult(0);
655   return Value();
656 }
657 
658 // Custom vectorization precondition function type. This is intended to be used
659 // with CustomVectorizationHook. Returns success if the corresponding custom
660 // hook can vectorize the op.
661 using CustomVectorizationPrecondition =
662     std::function<LogicalResult(Operation *, bool)>;
663 
664 // Custom vectorization function type. Produce a vector form of Operation*
665 // assuming all its vectorized operands are already in the IRMapping.
666 // Return nullptr if the Operation cannot be vectorized.
667 using CustomVectorizationHook =
668     std::function<VectorizationResult(Operation *, const IRMapping &)>;
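
// A custom hook is typically bound as a lambda that closes over the rewriter
// and the vectorization state, mirroring how the yield/index/extract hooks are
// wired up further down in this file (condensed sketch):
//
//   CustomVectorizationHook vectorizeYield =
//       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
//     return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp,
//                                 newResults);
//   };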
669 
670 /// Helper function to vectorize the terminator of a `linalgOp`. New result
671 /// vector values are appended to `newResults`. Return
672 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
673 /// should not try to map produced operations and instead return the results
674 /// using the `newResults` vector, making them available to the vectorization
675 /// algorithm for RAUW. This function is meant to be used as a
676 /// CustomVectorizationHook.
677 static VectorizationResult
678 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
679                      const IRMapping &bvm, VectorizationState &state,
680                      LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
681   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
682   if (!yieldOp)
683     return VectorizationResult{VectorizationStatus::Failure, nullptr};
684   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
685     // TODO: Scan for an opportunity for reuse.
686     // TODO: use a map.
687     Value vectorValue = bvm.lookup(output.value());
688     Value newResult =
689         buildVectorWrite(rewriter, vectorValue,
690                          linalgOp.getDpsInitOperand(output.index()), state);
691     if (newResult)
692       newResults.push_back(newResult);
693   }
694 
695   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
696 }
697 
698 /// Helper function to vectorize the index operations of a `linalgOp`. Return
699 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
700 /// should map the produced operations. This function is meant to be used as a
701 /// CustomVectorizationHook.
702 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
703                                                 VectorizationState &state,
704                                                 Operation *op,
705                                                 LinalgOp linalgOp) {
706   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
707   if (!indexOp)
708     return VectorizationResult{VectorizationStatus::Failure, nullptr};
709   auto loc = indexOp.getLoc();
710   // Get the canonical vector shape of the iteration space.
711   ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
712   auto dim = indexOp.getDim();
713   // Compute a one-dimensional index vector for the index op dimension.
714   auto indexVectorType =
715       VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
716                       state.getScalableVecDims()[dim]);
717   auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
718   // Return the one-dimensional index vector if it lives in the trailing
719   // dimension of the iteration space since the vectorization algorithm in this
720   // case can handle the broadcast.
721   if (dim == targetShape.size() - 1)
722     return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
723   // Otherwise permute the targetShape to move the index dimension last,
724   // broadcast the one-dimensional index vector to the permuted shape, and
725   // finally transpose the broadcasted index vector to undo the permutation.
726   auto permPattern =
727       llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
728   std::swap(permPattern[dim], permPattern.back());
729   auto permMap =
730       AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
731 
732   auto broadCastOp = rewriter.create<vector::BroadcastOp>(
733       loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
734       indexSteps);
735   SmallVector<int64_t> transposition =
736       llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
737   std::swap(transposition.back(), transposition[dim]);
738   auto transposeOp =
739       rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
740   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
741 }
742 
743 /// Helper function to check if the tensor.extract can be vectorized by the
744 /// custom hook vectorizeTensorExtract.
745 static LogicalResult
746 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
747   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
748   if (!extractOp)
749     return failure();
750 
751   if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
752     return failure();
753 
754   // Check the index type, but only for non 0-d tensors (for which we do need
755   // access indices).
756   if (!extractOp.getIndices().empty()) {
757     if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
758       return failure();
759   }
760 
761   if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
762         return !VectorType::isValidElementType(type);
763       })) {
764     return failure();
765   }
766 
767   return success();
768 }
769 
770 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
771 /// generated from `tensor.extract`. The offset is calculated as follows
772 /// (example using scalar values):
773 ///
774 ///    offset = extractOp.indices[0]
775 ///    for (i = 1; i < numIndices; i++)
776 ///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
777 ///
778 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
779 ///  offset = ((1) * 80 + 2) * 15 + 3 = 1233
780 static Value calculateGatherOffset(RewriterBase &rewriter,
781                                    VectorizationState &state,
782                                    tensor::ExtractOp extractOp,
783                                    const IRMapping &bvm) {
784   // The vector of indices for GatherOp should be shaped as the output vector.
785   auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
786   auto loc = extractOp.getLoc();
787 
788   Value offset = broadcastIfNeeded(
789       rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
790 
791   const size_t numIndices = extractOp.getIndices().size();
792   for (size_t i = 1; i < numIndices; i++) {
793     Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
794 
795     auto dimSize = broadcastIfNeeded(
796         rewriter,
797         rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
798         indexVecType);
799 
800     offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
801 
802     auto extractOpIndex = broadcastIfNeeded(
803         rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
804 
805     offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
806   }
807 
808   return offset;
809 }
810 
811 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
812 
813 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
814 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
815 /// represents a contiguous load operation.
816 ///
817 /// Note that when calling this hook, it is assumed that the output vector is
818 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
819 /// labelled as a gather load before entering this method.
820 ///
821 /// Following on from the above, it is assumed that:
822 ///   * for statically shaped loops, when no masks are used, only one dim is !=
823 ///   1 (that's what the shape of the output vector is based on).
824 ///   * for dynamically shaped loops, there might be more non-unit dims
825 ///   as the output vector type is user-specified.
826 ///
827 /// TODO: Statically shaped loops + vector masking
828 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
829   SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
830   assert((linalgOp.hasDynamicShape() ||
831           llvm::count_if(loopRanges,
832                          [](int64_t dim) { return dim != 1; }) == 1) &&
833          "For statically shaped Linalg Ops, only one "
834          "non-unit loop dim is expected");
835 
836   int64_t idx = static_cast<int64_t>(loopRanges.size()) - 1;
837   for (; idx >= 0; idx--)
838     if (loopRanges[idx] != 1)
839       break;
840 
841   return idx;
842 }
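
// Example (for exposition): with static loop ranges [1, 8, 1, 1] the trailing
// non-unit loop dim is at index 1, so this returns 1; for [1, 1, 32] it
// returns 2.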
843 
844 /// Checks whether `val` can be used for calculating a loop invariant index.
845 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
846                                VectorType resType) {
847 
848   assert(((llvm::count_if(resType.getShape(),
849                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
850          "n-D vectors are not yet supported");
851 
852   // Block arguments from outside _this_ linalg.generic are effectively loop
853   // invariant. However, analysing block arguments of _this_ linalg.generic Op
854   // is a bit tricky. Just bail out in the latter case.
855   // TODO: We could try analysing the corresponding affine map here.
856   auto *block = linalgOp.getBlock();
857   if (isa<BlockArgument>(val))
858     return llvm::all_of(block->getArguments(),
859                         [&val](Value v) { return (v != val); });
860 
861   Operation *defOp = val.getDefiningOp();
862   assert(defOp && "This is neither a block argument nor an operation result");
863 
864   // IndexOp is loop invariant as long as its result remains constant across
865   // iterations. Note that for dynamic shapes, the corresponding dim will also
866   // be conservatively treated as != 1.
867   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
868     return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
869   }
870 
871   auto *ancestor = block->findAncestorOpInBlock(*defOp);
872 
873   // Values defined outside `linalgOp` are loop invariant.
874   if (!ancestor)
875     return true;
876 
877   // Values defined inside `linalgOp` that are constant are loop invariant.
878   if (isa<arith::ConstantOp>(ancestor))
879     return true;
880 
881   bool result = true;
882   for (auto op : ancestor->getOperands())
883     result &= isLoopInvariantIdx(linalgOp, op, resType);
884 
885   return result;
886 }
887 
888 /// Check whether `val` could be used for calculating the trailing index for a
889 /// contiguous load operation.
890 ///
891 /// There are currently 3 types of values that are allowed here:
892 ///   1. loop-invariant values,
893 ///   2. values that increment by 1 with every loop iteration,
894 ///   3. results of basic arithmetic operations (linear and continuous)
895 ///      involving 1., 2. and 3.
896 /// This method returns true if indeed only such values are used in calculating
897 /// `val`.
898 ///
899 /// Additionally, the trailing index for a contiguous load operation should
900 /// increment by 1 with every loop iteration, i.e. be based on:
901 ///   * `linalg.index <dim>` ,
902 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
903 /// `linalg.index <dim>` increments by 1 with every loop iteration).
904 /// `foundIndexOp` is updated to `true` when such Op is found.
905 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
906                                 bool &foundIndexOp, VectorType resType) {
907 
908   assert(((llvm::count_if(resType.getShape(),
909                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
910          "n-D vectors are not yet supported");
911 
912   // Block arguments from outside _this_ linalg.generic are effectively loop
913   // invariant. However, analysing block arguments of _this_ linalg.generic Op
914   // is a bit tricky. Just bail out in the latter case.
915   // TODO: We could try analysing the corresponding affine map here.
916   auto *block = linalgOp.getBlock();
917   if (isa<BlockArgument>(val))
918     return llvm::all_of(block->getArguments(),
919                         [&val](Value v) { return (v != val); });
920 
921   Operation *defOp = val.getDefiningOp();
922   assert(defOp && "This is neither a block argument nor an operation result");
923 
924   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
925     auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
926 
927     foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
928     return true;
929   }
930 
931   auto *ancestor = block->findAncestorOpInBlock(*defOp);
932 
933   if (!ancestor)
934     return false;
935 
936   // Conservatively reject Ops that could lead to indices with stride other
937   // than 1.
938   if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
939     return false;
940 
941   bool result = false;
942   for (auto op : ancestor->getOperands())
943     result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
944 
945   return result;
946 }
947 
948 /// Infer the memory access pattern for the input ExtractOp
949 ///
950 /// Based on the ExtractOp result shape and the access indices, decides whether
951 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
952 /// or a gather load. When analysing the ExtractOp indices (to identify
953 /// contiguous loads), this method looks for "loop" invariant indices (e.g.
954 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
955 /// Op).
956 ///
957 /// Note that it is always safe to use gather load operations for contiguous
958 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
959 /// that `extractOp` is a gather load.
960 static VectorMemoryAccessKind
961 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
962                                     LinalgOp &linalgOp, VectorType resType) {
963 
964   auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
965 
966   // 0. Is the source tensor 0-D? If yes then this is a scalar broadcast.
967   if (inputShape.getShape().empty())
968     return VectorMemoryAccessKind::ScalarBroadcast;
969 
970   // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
971   // otherwise.
972   bool isOutput1DVector =
973       (llvm::count_if(resType.getShape(),
974                       [](int64_t dimSize) { return dimSize > 1; }) == 1);
975   // 1. Assume that it's a gather load when reading non-1D vector.
976   if (!isOutput1DVector)
977     return VectorMemoryAccessKind::Gather;
978 
979   bool leadingIdxsLoopInvariant = true;
980 
981   // 2. Analyze the leading indices of `extractOp`.
982   // Look at the way each index is calculated and decide whether it is suitable
983   // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
984   // gather load.
985   auto indices = extractOp.getIndices();
986   auto leadIndices = indices.drop_back(1);
987 
988   for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
989     if (inputShape.getShape()[i] == 1)
990       continue;
991 
992     leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
993   }
994 
995   if (!leadingIdxsLoopInvariant) {
996     LDBG("Found gather load: " << extractOp);
997     return VectorMemoryAccessKind::Gather;
998   }
999 
1000   // 3. Analyze the trailing index for `extractOp`.
1001   // At this point we know that the leading indices are loop invariant. This
1002   // means that is potentially a scalar or a contiguous load. We can decide
1003   // based on the trailing idx.
1004   auto extractOpTrailingIdx = indices.back();
1005 
1006   // 3a. Scalar broadcast load
1007   // If the trailing index is loop invariant then this is a scalar load.
1008   if (leadingIdxsLoopInvariant &&
1009       isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1010     LDBG("Found scalar broadcast load: " << extractOp);
1011 
1012     return VectorMemoryAccessKind::ScalarBroadcast;
1013   }
1014 
1015   // 3b. Contiguous loads
1016   // The trailing `extractOp` index should increment with every loop iteration.
1017   // This effectively means that it must be based on the trailing loop index.
1018   // This is what the following bool captures.
1019   bool foundIndexOp = false;
1020   bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1021                                               foundIndexOp, resType);
1022   // TODO: Support generating contiguous loads for column vectors - that will
1023   // require adding a permutation map to transfer_read Ops.
1024   bool isRowVector = resType.getShape().back() != 1;
1025   isContiguousLoad &= (foundIndexOp && isRowVector);
1026 
1027   if (isContiguousLoad) {
1028     LDBG("Found contigous load: " << extractOp);
1029     return VectorMemoryAccessKind::Contiguous;
1030   }
1031 
1032   // 4. Fallback case - gather load.
1033   LDBG("Found gather load: " << extractOp);
1034   return VectorMemoryAccessKind::Gather;
1035 }
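
// Classification examples (illustrative, simplified body of a linalg.generic
// with a [1, 32] iteration space, i.e. only the trailing loop dim is non-unit):
//
//   // Contiguous: the leading index is loop invariant and the trailing index
//   // increments by 1 with every iteration.
//   %c0 = arith.constant 0 : index
//   %j  = linalg.index 1 : index
//   %v  = tensor.extract %src[%c0, %j] : tensor<4x32xf32>
//
//   // Gather: the trailing index is data dependent, so contiguity cannot be
//   // proven and the conservative fallback kicks in.
//   %k  = tensor.extract %idxs[%j] : tensor<32xindex>
//   %w  = tensor.extract %src[%c0, %k] : tensor<4x32xf32>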
1036 
1037 /// Helper function to vectorize the tensor.extract operations. Returns
1038 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1039 /// should map the produced operations. This function is meant to be used as a
1040 /// CustomVectorizationHook.
1041 static VectorizationResult
1042 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1043                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1044   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1045   if (!extractOp)
1046     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1047   auto loc = extractOp.getLoc();
1048 
1049   // Compute the canonical vector type for the result of the extract op.
1050   auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1051   auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1052       loc,
1053       DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1054                                 /*value=*/true));
1055   auto passThruConstantOp =
1056       rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1057 
1058   // Base indices are currently set to 0. We will need to re-visit if more
1059   // generic scenarios are to be supported.
1060   SmallVector<Value> baseIndices(
1061       extractOp.getIndices().size(),
1062       rewriter.create<arith::ConstantIndexOp>(loc, 0));
1063 
1064   VectorMemoryAccessKind memAccessKind =
1065       getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1066 
1067   // 1. Handle gather access
1068   if (memAccessKind == VectorMemoryAccessKind::Gather) {
1069     Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1070 
1071     // Generate the gather load
1072     Operation *gatherOp = rewriter.create<vector::GatherOp>(
1073         loc, resultType, extractOp.getTensor(), baseIndices, offset,
1074         maskConstantOp, passThruConstantOp);
1075     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1076 
1077     LDBG("Vectorised as gather load: " << extractOp << "\n");
1078     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1079   }
1080 
1081   // 2. Handle:
1082   //  a. scalar loads + broadcast,
1083   //  b. contiguous loads.
1084   // Both cases use vector.transfer_read.
1085 
1086   assert(llvm::count_if(resultType.getShape(),
1087                         [](uint64_t dim) { return dim != 1; }) &&
1088          "Contiguous loads and scalar loads + broadcast only support 1-D "
1089          "vectors ATM!");
1090 
1091   // Collect indices for `vector.transfer_read`. At this point, the indices will
1092   // either be scalars or would have been broadcast to vectors matching the
1093   // result type. For indices that are vectors, there are two options:
1094   //    * for non-trailing indices, all elements are identical (contiguous
1095   //      loads are identified by looking for non-trailing indices that are
1096   //      invariant with respect to the corresponding linalg.generic), or
1097   //    * for trailing indices, the index vector will contain values with stride
1098   //      one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1099   //      needed.
1100   // This means that
1101   //   * for scalar indices - just re-use it,
1102   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1103   //    (0th) element and use that.
1104   SmallVector<Value> transferReadIdxs;
1105   auto zero = rewriter.create<arith::ConstantOp>(
1106       loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
1107   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1108     Value idx = bvm.lookup(extractOp.getIndices()[i]);
1109     if (idx.getType().isIndex()) {
1110       transferReadIdxs.push_back(idx);
1111       continue;
1112     }
1113 
1114     auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1115         loc,
1116         VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1117                         resultType.getScalableDims().back()),
1118         idx);
1119     transferReadIdxs.push_back(
1120         rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1121   }
1122 
1123   // `tensor.extract` is always in-bounds, hence the following holds.
1124   auto dstRank = resultType.getRank();
1125   auto srcRank = extractOp.getTensor().getType().getRank();
1126   SmallVector<bool> inBounds(dstRank, true);
1127 
1128   // 2a. Handle scalar broadcast access.
1129   if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1130     MLIRContext *ctx = rewriter.getContext();
1131     SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1132     auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1133 
1134     auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1135         loc, resultType, extractOp.getTensor(), transferReadIdxs,
1136         permutationMap, inBounds);
1137 
1138     LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1139     return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1140   }
1141 
1142   // 2b. Handle contiguous access.
1143   auto permutationMap = AffineMap::getMinorIdentityMap(
1144       srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1145 
1146   int32_t rankDiff = dstRank - srcRank;
1147   // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1148   // dims so that the ranks match. This is done by extending the map with 0s.
1149   // For example, for dstRank = 3, srcRank = 2, the following map created
1150   // above:
1151   //    (d0, d1) --> (d0, d1)
1152   // is extended as:
1153   //    (d0, d1) --> (0, d0, d1)
1154   while (rankDiff > 0) {
1155     permutationMap = permutationMap.insertResult(
1156         mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1157     rankDiff--;
1158   }
1159 
1160   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1161       loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1162       inBounds);
1163 
1164   LDBG("Vectorised as contiguous load: " << extractOp);
1165   return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1166 }
1167 
1168 /// Emit reduction operations if the shape of the value to reduce is different
1169 /// from the result shape.
1170 // Note: this is a true builder that notifies the OpBuilder listener.
1171 // TODO: Consider moving as a static helper on the ReduceOp.
1172 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1173                                  Value reduceValue, Value initialValue,
1174                                  const IRMapping &bvm) {
1175   Value reduceVec = bvm.lookup(reduceValue);
1176   Value outputVec = bvm.lookup(initialValue);
1177   auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1178   auto outputType = dyn_cast<VectorType>(outputVec.getType());
1179   // Reduce only if needed as the value may already have been reduced for
1180   // contraction vectorization.
1181   if (!reduceType ||
1182       (outputType && reduceType.getShape() == outputType.getShape()))
1183     return nullptr;
1184   SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1185   return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1186 }
1187 
1188 /// Generic vectorization for a single operation `op`, given already vectorized
1189 /// operands carried by `bvm`. Vectorization occurs as follows:
1190 ///   1. Try to apply any of the `customVectorizationHooks` and return its
1191 ///   result on success.
1192 ///   2. Clone any constant in the current scope without vectorization: each
1193 ///   consumer of the constant will later determine the shape to which the
1194   ///   constant needs to be broadcast.
1195 ///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1196 ///   of the `customVectorizationHooks` to cover such cases.
1197 ///   4. Clone `op` in vector form to a vector of shape prescribed by the first
1198 ///   operand of maximal rank. Other operands have smaller rank and are
1199 ///   broadcast accordingly. It is assumed this broadcast is always legal,
1200 ///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
1201 ///
1202 /// This function assumes all operands of `op` have been vectorized and are in
1203 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1204 /// a topologically-sorted list of ops.
1205 /// This function does not update `bvm` but returns a VectorizationStatus that
1206 /// instructs the caller what `bvm` update needs to occur.
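///
/// Illustrative sketch (hypothetical operands): with %a mapped to a
/// vector<4x8xf32> and %b mapped to a scalar f32 in `bvm`, a body op such as
///   %0 = arith.mulf %a, %b : f32
/// is rewritten along the lines of (step 4 above):
///   %vb = vector.broadcast %b : f32 to vector<4x8xf32>
///   %vres = arith.mulf %va, %vb : vector<4x8xf32>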
1207 static VectorizationResult
1208 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1209                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1210                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1211   LDBG("vectorize op " << *op << "\n");
1212 
1213   // 1. Try to apply any CustomVectorizationHook.
1214   if (!customVectorizationHooks.empty()) {
1215     for (auto &customFunc : customVectorizationHooks) {
1216       VectorizationResult result = customFunc(op, bvm);
1217       if (result.status == VectorizationStatus::Failure)
1218         continue;
1219       return result;
1220     }
1221   }
1222 
1223   // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1224   // Clone so that the constant is not confined to the linalgOp block.
1225   if (isa<arith::ConstantOp, func::ConstantOp>(op))
1226     return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1227 
1228   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1229   if (!OpTrait::hasElementwiseMappableTraits(op))
1230     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1231 
1232   // 4. Check if the operation is a reduction.
1233   SmallVector<std::pair<Value, Value>> reductionOperands;
1234   for (Value operand : op->getOperands()) {
1235     auto blockArg = dyn_cast<BlockArgument>(operand);
1236     if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1237         blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1238       continue;
1239     SmallVector<Operation *> reductionOps;
1240     Value reduceValue = matchReduction(
1241         linalgOp.getRegionOutputArgs(),
1242         blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1243     if (!reduceValue)
1244       continue;
1245     reductionOperands.push_back(std::make_pair(reduceValue, operand));
1246   }
1247   if (!reductionOperands.empty()) {
1248     assert(reductionOperands.size() == 1);
1249     Operation *reduceOp =
1250         reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1251                        reductionOperands[0].second, bvm);
1252     if (reduceOp)
1253       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1254   }
1255 
1256   // 5. Generic vectorization path for ElementwiseMappable ops.
1257   //   a. Get the first max ranked shape.
1258   VectorType firstMaxRankedType;
1259   for (Value operand : op->getOperands()) {
1260     auto vecOperand = bvm.lookup(operand);
1261     assert(vecOperand && "Vector operand couldn't be found");
1262 
1263     auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1264     if (vecType && (!firstMaxRankedType ||
1265                     firstMaxRankedType.getRank() < vecType.getRank()))
1266       firstMaxRankedType = vecType;
1267   }
1268   //   b. Broadcast each op if needed.
1269   SmallVector<Value> vecOperands;
1270   for (Value scalarOperand : op->getOperands()) {
1271     Value vecOperand = bvm.lookup(scalarOperand);
1272     assert(vecOperand && "Vector operand couldn't be found");
1273 
1274     if (firstMaxRankedType) {
1275       auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1276                                      getElementTypeOrSelf(vecOperand.getType()),
1277                                      firstMaxRankedType.getScalableDims());
1278       vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1279     } else {
1280       vecOperands.push_back(vecOperand);
1281     }
1282   }
1283   //   c. For elementwise ops, the result has the shape of firstMaxRankedType.
1284   SmallVector<Type> resultTypes;
1285   for (Type resultType : op->getResultTypes()) {
1286     resultTypes.push_back(
1287         firstMaxRankedType
1288             ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1289                               firstMaxRankedType.getScalableDims())
1290             : resultType);
1291   }
1292   //   d. Build and return the new op.
1293   return VectorizationResult{
1294       VectorizationStatus::NewOp,
1295       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1296                       resultTypes, op->getAttrs())};
1297 }
1298 
1299 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1300 /// vector form. Generic vectorization proceeds as follows:
1301 ///   1. Verify the `linalgOp` has one non-empty region.
1302 ///   2. Values defined above the region are mapped to themselves and will be
1303 ///   broadcasted on a per-need basis by their consumers.
1304 ///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1305 ///   load).
1306 ///   TODO: Reuse opportunities for RAR dependencies.
1307 ///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
1308 ///   4b. Register CustomVectorizationHook for IndexOp to access the
1309 ///   iteration indices.
1310 ///   5. Iteratively call vectorizeOneOp on the region operations.
1311 ///
1312 /// Eager broadcasting is performed to the maximal common vector size implied
1313 /// by the `linalgOp` iteration space. This eager broadcasting is introduced in
1314 /// the permutation_map of the vector.transfer_read operations. The eager
1315 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1316 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1318 /// the absence of good canonicalizations, the amount of work increases.
1319 /// This is not deemed a problem as we expect canonicalizations and foldings to
1320 /// aggressively clean up the useless work.
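///
/// Illustrative sketch (a hypothetical, simplified 1-D elementwise generic;
/// masking and canonical-shape details omitted):
///   linalg.generic ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
///                  outs(%init : tensor<8xf32>) { ... arith.addf ... }
/// is rewritten roughly as:
///   %va = vector.transfer_read %a[%c0], %cst : tensor<8xf32>, vector<8xf32>
///   %vb = vector.transfer_read %b[%c0], %cst : tensor<8xf32>, vector<8xf32>
///   %vr = arith.addf %va, %vb : vector<8xf32>
///   %w = vector.transfer_write %vr, %init[%c0] : vector<8xf32>, tensor<8xf32>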
1321 static LogicalResult
1322 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1323                          LinalgOp linalgOp,
1324                          SmallVectorImpl<Value> &newResults) {
1325   LDBG("Vectorizing operation as linalg generic\n");
1326   Block *block = linalgOp.getBlock();
1327 
1328   // 2. Values defined above the region can only be broadcast for now. Make them
1329   // map to themselves.
1330   IRMapping bvm;
1331   SetVector<Value> valuesSet;
1332   mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1333   bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1334 
1335   if (linalgOp.getNumDpsInits() == 0)
1336     return failure();
1337 
1338   // 3. Turn all BBArgs into vector.transfer_read / load.
1339   Location loc = linalgOp.getLoc();
1340   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1341   for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1342     BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1343     if (linalgOp.isScalar(opOperand)) {
1344       bvm.map(bbarg, opOperand->get());
1345       continue;
1346     }
1347 
1348     // 3.a. Convert the indexing map for this input/output to a transfer read
1349     // permutation map and masking map.
1350     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1351 
1352     // Remove zeros from indexing map to use it as masking map.
1353     SmallVector<int64_t> zeroPos;
1354     auto results = indexingMap.getResults();
1355     for (const auto &result : llvm::enumerate(results)) {
1356       if (isa<AffineConstantExpr>(result.value())) {
1357         zeroPos.push_back(result.index());
1358       }
1359     }
1360     AffineMap maskingMap = indexingMap.dropResults(zeroPos);
1361 
1362     AffineMap readMap;
1363     VectorType readType;
1364     Type elemType = getElementTypeOrSelf(opOperand->get());
1365     if (linalgOp.isDpsInput(opOperand)) {
1366       // 3.a.i. For input reads we use the canonical vector shape.
1367       readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1368       readType = state.getCanonicalVecType(elemType);
1369     } else {
1370       // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1371       // reductions), the vector shape is computed by mapping the canonical
1372       // vector shape to the output domain and back to the canonical domain.
1373       readMap = inversePermutation(reindexIndexingMap(indexingMap));
1374       readType =
1375           state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1376     }
1377 
1378     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1379 
1380     // Make sure that the in_bounds attribute corresponding to a broadcast dim
1381     // is `true`
1382     SmallVector<unsigned> broadcastedDims = readMap.getBroadcastDims();
1383     SmallVector<bool> inBounds(readType.getRank(), false);
1384 
1385     for (auto idx : broadcastedDims)
1386       inBounds[idx] = true;
1387 
1388     Operation *read = rewriter.create<vector::TransferReadOp>(
1389         loc, readType, opOperand->get(), indices, readMap,
1390         ArrayRef<bool>(inBounds));
1391     read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1392     Value readValue = read->getResult(0);
1393 
1394     // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1395     // will be in-bounds.
1396     if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1397       SmallVector<bool> inBounds(readType.getRank(), true);
1398       cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1399           .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1400     }
1401 
1402     // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1403     // TODO: remove this.
1404     if (readType.getRank() == 0)
1405       readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1406 
1407     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1408                                  << "\n");
1409     bvm.map(bbarg, readValue);
1410     bvm.map(opOperand->get(), readValue);
1411   }
1412 
1413   SmallVector<CustomVectorizationHook> hooks;
1414   // 4a. Register CustomVectorizationHook for yieldOp.
1415   CustomVectorizationHook vectorizeYield =
1416       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1417     return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1418   };
1419   hooks.push_back(vectorizeYield);
1420 
1421   // 4b. Register CustomVectorizationHook for indexOp.
1422   CustomVectorizationHook vectorizeIndex =
1423       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1424     return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1425   };
1426   hooks.push_back(vectorizeIndex);
1427 
1428   // 4c. Register CustomVectorizationHook for extractOp.
1429   CustomVectorizationHook vectorizeExtract =
1430       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1431     return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1432   };
1433   hooks.push_back(vectorizeExtract);
1434 
1435   // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
1436   for (Operation &op : block->getOperations()) {
1437     VectorizationResult result =
1438         vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1439     if (result.status == VectorizationStatus::Failure) {
1440       LDBG("failed to vectorize: " << op << "\n");
1441       return failure();
1442     }
1443     if (result.status == VectorizationStatus::NewOp) {
1444       Operation *maybeMaskedOp =
1445           state.maskOperation(rewriter, result.newOp, linalgOp);
1446       LDBG("New vector op: " << *maybeMaskedOp << "\n");
1447       bvm.map(op.getResults(), maybeMaskedOp->getResults());
1448     }
1449   }
1450 
1451   return success();
1452 }
1453 
1454 /// Given a tensor::PackOp, return the `dest` shape before any packing
1455 /// permutations.
1456 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1457                                               ArrayRef<int64_t> destShape) {
1458   return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1459 }
1460 
1461 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1462 /// create an empty destination tensor and create a TransferWriteOp from the
1463 /// input to the empty tensor. If the destination shape is not the same as the
1464 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1465 /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1466 /// inBounds attribute of the transfer write op instead of masking.
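///
/// Illustrative sketch (hypothetical sizes): writing a vector<8xf32> into a
/// destination reified as tensor<?xf32> (dynamic size %d) with
/// inputVectorSizes = [8] and masking enabled would produce roughly:
///   %dest = tensor.empty(%d) : tensor<?xf32>
///   %mask = vector.create_mask %d : vector<8xi1>
///   %w = vector.mask %mask {
///     vector.transfer_write %input, %dest[%c0] {in_bounds = [true]}
///         : vector<8xf32>, tensor<?xf32>
///   } : vector<8xi1> -> tensor<?xf32>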
1467 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1468                                            Value input,
1469                                            SmallVector<OpFoldResult> destSizes,
1470                                            ArrayRef<int64_t> inputVectorSizes,
1471                                            bool useInBoundsInsteadOfMasking) {
1472 
1473   auto inputType = cast<VectorType>(input.getType());
1474   Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1475                                                inputType.getElementType());
1476   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1477   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1478   auto destShape = cast<ShapedType>(dest.getType()).getShape();
1479   SmallVector<bool> inBoundsVal(rank, true);
1480   if (useInBoundsInsteadOfMasking) {
1481     // Update the inBounds attribute.
1482     for (unsigned i = 0; i < rank; i++)
1483       inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1484                        !ShapedType::isDynamic(destShape[i]);
1485   }
1486   Operation *write = builder.create<vector::TransferWriteOp>(
1487       loc,
1488       /*vector=*/input,
1489       /*source=*/dest,
1490       /*indices=*/SmallVector<Value>(rank, zero),
1491       /*inBounds=*/inBoundsVal);
1492   assert(llvm::none_of(
1493              destShape.drop_front(inputVectorSizes.size()),
1494              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1495          "Only dims aligned with inputVectorSizes may be dynamic");
1496   if (useInBoundsInsteadOfMasking)
1497     return write;
1498   bool needMaskForWrite = !llvm::equal(
1499       inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1500   if (needMaskForWrite) {
1501     SmallVector<int64_t> writeMaskShape;
1502     writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1503     writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1504                           destShape.end());
1505     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1506     Value maskForWrite =
1507         builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1508     write = mlir::vector::maskOperation(builder, write, maskForWrite);
1509   }
1510   return write;
1511 }
1512 
1513 /// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1514 /// padding value and (3) input vector sizes into:
1515 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1516 /// As in the following example:
1517 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1518 ///     into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1519 ///
1520 /// This pack would be vectorized to:
1521 ///
1522 /// %load = vector.mask %mask {
1523 ///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1524 ///         {in_bounds = [true, true, true]} :
1525 ///         tensor<32x7x16xf32>, vector<32x8x16xf32>
1526 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1527 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1528 ///                                         to vector<32x4x2x1x16xf32>
1529 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1530 ///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1531 /// %write = vector.transfer_write %transpose,
1532 ///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1533 ///     {in_bounds = [true, true, true, true, true]}
1534 ///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1535 ///
1536 /// If the (3) input vector sizes are not provided, the vector sizes are
1537 /// determined by the result tensor shape. In that case, we also update the
1538 /// inBounds attribute instead of masking.
1539 static LogicalResult
1540 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1541                         ArrayRef<int64_t> inputVectorSizes,
1542                         SmallVectorImpl<Value> &newResults) {
1543   OpBuilder::InsertionGuard g(rewriter);
1544   rewriter.setInsertionPoint(packOp);
1545 
1546   Location loc = packOp.getLoc();
1547   auto padValue = packOp.getPaddingValue();
1548   if (!padValue) {
1549     padValue = rewriter.create<arith::ConstantOp>(
1550         loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1551   }
1552   ReifiedRankedShapedTypeDims reifiedReturnShapes;
1553   LogicalResult status =
1554       cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1555           .reifyResultShapes(rewriter, reifiedReturnShapes);
1556   (void)status; // prevent unused variable warning on non-assert builds.
1557   assert(succeeded(status) && "failed to reify result shapes");
1558 
1559   // If the input vector sizes are not provided, then the vector sizes are
1560   // determined by the result tensor shape. In that case, we also update the
1561   // inBounds attribute instead of masking.
1562   bool useInBoundsInsteadOfMasking = false;
1563   if (inputVectorSizes.empty()) {
1564     ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1565     inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1566     useInBoundsInsteadOfMasking = true;
1567   }
1568 
1569   // Create masked TransferReadOp.
1570   SmallVector<int64_t> inputShape(inputVectorSizes);
1571   auto innerTiles = packOp.getStaticInnerTiles();
1572   auto innerDimsPos = packOp.getInnerDimsPos();
1573   auto outerDimsPerm = packOp.getOuterDimsPerm();
1574   if (!outerDimsPerm.empty())
1575     applyPermutationToVector(inputShape,
1576                              invertPermutationVector(outerDimsPerm));
1577   for (auto [idx, size] : enumerate(innerTiles))
1578     inputShape[innerDimsPos[idx]] *= size;
1579   auto maskedRead = vector::createReadOrMaskedRead(
1580       rewriter, loc, packOp.getSource(), inputShape, padValue,
1581       useInBoundsInsteadOfMasking);
1582 
1583   // Create ShapeCastOp.
1584   SmallVector<int64_t> destShape(inputVectorSizes);
1585   destShape.append(innerTiles.begin(), innerTiles.end());
1586   auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1587                                        packOp.getDestType().getElementType());
1588   auto shapeCastOp =
1589       rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1590 
1591   // Create TransposeOp.
1592   auto destPermutation =
1593       invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1594   auto transposeOp = rewriter.create<vector::TransposeOp>(
1595       loc, shapeCastOp.getResult(), destPermutation);
1596 
1597   // Create TransferWriteOp.
1598   Operation *write = createWriteOrMaskedWrite(
1599       rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1600       inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
1601   newResults.push_back(write->getResult(0));
1602   return success();
1603 }
1604 
1605 /// Vectorize a `tensor::UnPackOp` into these 4 ops:
1606 ///   vector::TransferReadOp - reads a vector from the source tensor,
1607 ///   vector::TransposeOp - transposes the read vector,
1608 ///   vector::ShapeCastOp - reshapes the data based on the target,
1609 ///   vector::TransferWriteOp - writes the result vector back to the
1610 ///   destination tensor.
1611 ///   If the vector sizes are not provided:
1612 ///   * the vector sizes are determined by the input operand and attributes,
1613 ///   * the inBounds attribute is updated instead of masking.
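///
/// Illustrative sketch (hypothetical static shapes; masking omitted):
///   tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [16, 2]
///       into %dst : tensor<2x4x16x2xf32> -> tensor<32x8xf32>
/// would be rewritten roughly as:
///   %rd = vector.transfer_read %src[%c0, %c0, %c0, %c0], %cst
///       : tensor<2x4x16x2xf32>, vector<2x4x16x2xf32>
///   %tr = vector.transpose %rd, [0, 2, 1, 3]
///       : vector<2x4x16x2xf32> to vector<2x16x4x2xf32>
///   %sc = vector.shape_cast %tr : vector<2x16x4x2xf32> to vector<32x8xf32>
///   %wr = vector.transfer_write %sc, %empty[%c0, %c0]
///       : vector<32x8xf32>, tensor<32x8xf32>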
1614 static LogicalResult
1615 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1616                           ArrayRef<int64_t> inputVectorSizes,
1617                           SmallVectorImpl<Value> &newResults) {
1618 
1619   OpBuilder::InsertionGuard g(rewriter);
1620   rewriter.setInsertionPoint(unpackOp);
1621 
1622   RankedTensorType unpackTensorType = unpackOp.getSourceType();
1623 
1624   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1625   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1626   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1627   bool useInBoundsInsteadOfMasking = false;
1628   ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1629 
1630   auto destSize = unpackOp.getDestRank();
1631 
1632   if (!inputVectorSizes.empty())
1633     assert(inputVectorSizes.size() == destSize &&
1634            "Incorrect number of input vector sizes");
1635 
1636   // vectorSizes is the shape of the vector that will be used for the final
1637   // write on the destination tensor. It is set like this: Let's say the
1638   // source tensor has rank 'M' and the dest tensor has rank 'N', where N <= M.
1639   // Thus:
1640   // 1. vectorSizes = sourceShape.take_front(N)
1641   // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1642   // 3. multiply all the locations in vectorSizes pointed to by innerDimPos by
1643   //    the corresponding innerTiles attribute value.
1644   SmallVector<int64_t> vectorSizes(inputVectorSizes);
1645   if (vectorSizes.empty()) {
1646     llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1647     if (!outerDimsPerm.empty())
1648       applyPermutationToVector(vectorSizes, outerDimsPerm);
1649     for (auto [i, pos] : llvm::enumerate(innerDimPos))
1650       vectorSizes[pos] *= innerTiles[i];
1651 
1652     useInBoundsInsteadOfMasking = true;
1653   }
1654 
1655   // readVectorSizes is the size of the vector used to read and apply the
1656   // mask. It is set like this: Let's say the vectorSizes (VS) array has size
1657   // 'N', the sourceShape (SS) has size 'M' where M >= N, and the
1658   // InnerTileSizes (IT) have size M-N.
1659   // Thus:
1660   // - initially: readVectorSizes = vectorSizes
1661   // - Divide all the readVectorSizes locations pointed to by innerDimPos
1662   //   by the innerTileSize attribute value.
1663   // - if outer_dims_perm is present: apply that permutation to readVectorSizes.
1664   // - Append the remaining shape from SS.
1665   // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
1666   // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes are [512,
1667   // 128] and outer_dims_perm is [1, 0]; then the read shape is:
1668   //   ReadVectorSizes(initial): [512, 128]
1669   //   Final Value(after innerDim Adjustment): [512/32, 128/16]
1670   //                                           = [16, 8]
1671   //   After applying outer_dims_perm: [8, 16]
1672   //   After appending the rest of the sourceShape: [8, 16, 32, 16]
1673 
1674   SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1675 
1676   for (auto [index, size] : enumerate(innerTiles)) {
1677     readVectorSizes[innerDimPos[index]] =
1678         llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1679   }
1680   if (!outerDimsPerm.empty()) {
1681     applyPermutationToVector(readVectorSizes, outerDimsPerm);
1682   }
1683   readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1684                          sourceShape.end());
1685 
1686   ReifiedRankedShapedTypeDims reifiedRetShapes;
1687   LogicalResult status =
1688       cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1689           .reifyResultShapes(rewriter, reifiedRetShapes);
1690   if (status.failed()) {
1691     LDBG("Unable to reify result shapes of " << unpackOp);
1692     return failure();
1693   }
1694   Location loc = unpackOp->getLoc();
1695 
1696   auto padValue = rewriter.create<arith::ConstantOp>(
1697       loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1698 
1699   // Read result, mask if necessary. If transferReadOp shape is not equal
1700   // to shape of source, then a mask is necessary.
1701   Value readResult = vector::createReadOrMaskedRead(
1702       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1703       /*useInBoundsInsteadOfMasking=*/false);
1704 
1705   PackingMetadata packMetadata;
1706   SmallVector<int64_t> lastDimToInsertPosPerm =
1707       tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1708   ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1709   SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1710   mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1711   applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1712   RankedTensorType stripMineTensorType =
1713       RankedTensorType::get(stripMineShape, stripMineElemType);
1714   // Transpose the appropriate rows to match output.
1715   vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1716       loc, readResult, lastDimToInsertPosPerm);
1717 
1718   // Collapse the vector to the size required by result.
1719   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1720       stripMineTensorType, packMetadata.reassociations);
1721   mlir::VectorType vecCollapsedType =
1722       VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1723   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1724       loc, vecCollapsedType, transposeOp->getResult(0));
1725 
1726   // writeVectorSizes has to match the shape_cast result shape for dynamic
1727   // sizes, otherwise the verifier complains that the mask size is invalid.
1728   SmallVector<int64_t> writeVectorSizes(
1729       unpackOp.getDestType().hasStaticShape()
1730           ? vectorSizes
1731           : shapeCastOp.getResultVectorType().getShape());
1732   Operation *write = createWriteOrMaskedWrite(
1733       rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1734       writeVectorSizes, useInBoundsInsteadOfMasking);
1735   newResults.push_back(write->getResult(0));
1736   return success();
1737 }
1738 
1739 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1740 /// and (3) all-zero lowPad to
1741 ///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
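///
/// Illustrative sketch (hypothetical shapes): padding a tensor<?xf32> source
/// to a tensor<8xf32> result with inputVectorSizes = [8] becomes roughly:
///   %rd = vector.mask %mask {
///     vector.transfer_read %src[%c0], %pad_value : tensor<?xf32>, vector<8xf32>
///   } : vector<8xi1> -> vector<8xf32>
///   %wr = vector.transfer_write %rd, %empty[%c0] {in_bounds = [true]}
///       : vector<8xf32>, tensor<8xf32>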
1742 static LogicalResult
1743 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1744                        ArrayRef<int64_t> inputVectorSizes,
1745                        SmallVectorImpl<Value> &newResults) {
1746   auto padValue = padOp.getConstantPaddingValue();
1747   Location loc = padOp.getLoc();
1748 
1749   // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1750   OpBuilder::InsertionGuard g(rewriter);
1751   rewriter.setInsertionPoint(padOp);
1752 
1753   ReifiedRankedShapedTypeDims reifiedReturnShapes;
1754   LogicalResult status =
1755       cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1756           .reifyResultShapes(rewriter, reifiedReturnShapes);
1757   (void)status; // prevent unused variable warning on non-assert builds
1758   assert(succeeded(status) && "failed to reify result shapes");
1759   auto maskedRead = vector::createReadOrMaskedRead(
1760       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1761       /*useInBoundsInsteadOfMasking=*/false);
1762   Operation *write = createWriteOrMaskedWrite(
1763       rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1764       /*useInBoundsInsteadOfMasking=*/false);
1765   newResults.push_back(write->getResult(0));
1766   return success();
1767 }
1768 
1769 // TODO: probably need some extra checks for reduction followed by consumer
1770 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1771 static LogicalResult reductionPreconditions(LinalgOp op) {
1772   if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1773     LDBG("reduction precondition failed: no reduction iterator\n");
1774     return failure();
1775   }
1776   for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1777     AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1778     if (indexingMap.isPermutation())
1779       continue;
1780 
1781     Operation *reduceOp = matchLinalgReduction(&opOperand);
1782     if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1783       LDBG("reduction precondition failed: reduction detection failed\n");
1784       return failure();
1785     }
1786   }
1787   return success();
1788 }
1789 
1790 static LogicalResult
1791 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1792                                    bool flatten1DDepthwiseConv) {
1793   if (flatten1DDepthwiseConv) {
1794     LDBG("Vectorization of flattened convs with dynamic shapes is not "
1795          "supported\n");
1796     return failure();
1797   }
1798 
1799   if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1800     LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1801     return failure();
1802   }
1803 
1804   // Support dynamic shapes in 1D depthwise convolution, but only in the
1805   // _channel_ dimension.
1806   Value lhs = conv.getDpsInputOperand(0)->get();
1807   ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1808   auto shapeWithoutCh = lhsShape.drop_back(1);
1809   if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1810     LDBG("Dynamically-shaped op vectorization precondition failed: only "
1811          "channel dim can be dynamic\n");
1812     return failure();
1813   }
1814 
1815   return success();
1816 }
1817 
1818 static LogicalResult
1819 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1820                                      bool flatten1DDepthwiseConv) {
1821   if (isa<ConvolutionOpInterface>(op.getOperation()))
1822     return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1823 
1824   if (hasReductionIterator(op))
1825     return reductionPreconditions(op);
1826 
1827   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1828   // linalg.copy ops and ops that implement ContractionOpInterface for now.
1829   if (!isElementwise(op) &&
1830       !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1831           op.getOperation()))
1832     return failure();
1833 
1834   LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1835   return success();
1836 }
1837 
1838 /// Need to check if the inner-tiles are static/constant.
1839 static LogicalResult
1840 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1841                               ArrayRef<int64_t> inputVectorSizes) {
1842 
1843   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1844         return !getConstantIntValue(res).has_value();
1845       })) {
1846     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1847     return failure();
1848   }
1849   ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1850   bool satisfyEmptyCond = inputVectorSizes.empty() &&
1851                           unpackOp.getDestType().hasStaticShape() &&
1852                           unpackOp.getSourceType().hasStaticShape();
1853   if (!satisfyEmptyCond &&
1854       failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1855     return failure();
1856 
1857   return success();
1858 }
1859 
1860 static LogicalResult vectorizeLinalgOpPrecondition(
1861     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1862     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
1863   // A tensor with a dimension of size 0 cannot be vectorized.
1864   if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1865     return failure();
1866   // Check API contract for input vector sizes.
1867   if (!inputVectorSizes.empty() &&
1868       failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1869                                               inputVectorSizes)))
1870     return failure();
1871 
1872   if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1873                                         linalgOp, flatten1DDepthwiseConv))) {
1874     LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1875     return failure();
1876   }
1877 
1878   SmallVector<CustomVectorizationPrecondition> customPreconditions;
1879 
1880   // Register CustomVectorizationPrecondition for extractOp.
1881   customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1882 
1883   // All types in the body should be a supported element type for VectorType.
1884   for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1885     // Check if any custom hook can vectorize the inner op.
1886     if (llvm::any_of(
1887             customPreconditions,
1888             [&](const CustomVectorizationPrecondition &customPrecondition) {
1889               return succeeded(
1890                   customPrecondition(&innerOp, vectorizeNDExtract));
1891             })) {
1892       continue;
1893     }
1894     if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1895           return !VectorType::isValidElementType(type);
1896         })) {
1897       return failure();
1898     }
1899     if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1900           return !VectorType::isValidElementType(type);
1901         })) {
1902       return failure();
1903     }
1904   }
1905   if (isElementwise(linalgOp))
1906     return success();
1907 
1908   // TODO: isaConvolutionOpInterface that can also infer from generic
1909   // features. But we will still need stride/dilation attributes that will be
1910   // annoying to reverse-engineer...
1911   if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1912     return success();
1913   // TODO: the common vector shape is equal to the static loop sizes only when
1914   // all indexing maps are projected permutations. For convs and stencils the
1915   // logic will need to evolve.
1916   if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1917     LDBG("precondition failed: not projected permutations\n");
1918     return failure();
1919   }
1920   if (failed(reductionPreconditions(linalgOp))) {
1921     LDBG("precondition failed: reduction preconditions\n");
1922     return failure();
1923   }
1924   return success();
1925 }
1926 
1927 static LogicalResult
1928 vectorizePackOpPrecondition(tensor::PackOp packOp,
1929                             ArrayRef<int64_t> inputVectorSizes) {
1930   auto padValue = packOp.getPaddingValue();
1931   Attribute cstAttr;
1932   if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1933     LDBG("pad value is not constant: " << packOp << "\n");
1934     return failure();
1935   }
1936   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1937   bool satisfyEmptyCond = true;
1938   if (inputVectorSizes.empty()) {
1939     if (!packOp.getDestType().hasStaticShape() ||
1940         !packOp.getSourceType().hasStaticShape())
1941       satisfyEmptyCond = false;
1942   }
1943 
1944   if (!satisfyEmptyCond &&
1945       failed(vector::isValidMaskedInputVector(
1946           resultTensorShape.take_front(packOp.getSourceRank()),
1947           inputVectorSizes)))
1948     return failure();
1949 
1950   if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1951         return !getConstantIntValue(v).has_value();
1952       })) {
1953     LDBG("inner_tiles must be constant: " << packOp << "\n");
1954     return failure();
1955   }
1956 
1957   return success();
1958 }
1959 
1960 static LogicalResult
1961 vectorizePadOpPrecondition(tensor::PadOp padOp,
1962                            ArrayRef<int64_t> inputVectorSizes) {
1963   auto padValue = padOp.getConstantPaddingValue();
1964   if (!padValue) {
1965     LDBG("pad value is not constant: " << padOp << "\n");
1966     return failure();
1967   }
1968 
1969   ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1970   if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1971                                               inputVectorSizes)))
1972     return failure();
1973 
1974   if (llvm::any_of(padOp.getLow(), [](Value v) {
1975         std::optional<int64_t> res = getConstantIntValue(v);
1976         return !res.has_value() || res.value() != 0;
1977       })) {
1978     LDBG("low pad must all be zero: " << padOp << "\n");
1979     return failure();
1980   }
1981 
1982   return success();
1983 }
1984 
1985 /// Preconditions for scalable vectors. This is quite restrictive - it models
1986 /// the fact that in practice we would only make selected dimensions scalable.
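///
/// Illustrative sketch (hypothetical sizes): for a linalg.matmul with
/// iterators [parallel, parallel, reduction], inputVectorSizes = [4, 16, 8]
/// and inputScalableVecDims = [false, true, false] (only the trailing
/// parallel dim scalable) is accepted, whereas [true, false, false] is
/// rejected because the scalable dim is not the innermost parallel dim.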
1987 static LogicalResult
1988 vectorizeScalableVectorPrecondition(Operation *op,
1989                                     ArrayRef<int64_t> inputVectorSizes,
1990                                     ArrayRef<bool> inputScalableVecDims) {
1991   assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1992          "Number of input vector sizes and scalable dims doesn't match");
1993 
1994   size_t numOfScalableDims =
1995       llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
1996 
1997   if (numOfScalableDims == 0)
1998     return success();
1999 
2000   auto linalgOp = dyn_cast<LinalgOp>(op);
2001 
2002   // Cond 1: There's been no need for scalable vectorisation of
2003   // non-linalg Ops so far
2004   if (!linalgOp)
2005     return failure();
2006 
2007   // Cond 2: There's been no need for more than 2 scalable dims so far
2008   if (numOfScalableDims > 2)
2009     return failure();
2010 
2011   // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2012   // it matches one of the supported cases:
2013   //  1. exactly 1 dim is scalable and that's the _last_ parallel dim
2014   //  2. exactly 2 dims are scalable and those are the _last two adjacent_
2015   //     parallel dims
2016   //  3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
2017   // The 2nd restriction above means that only Matmul-like Ops are supported
2018   // when 2 dims are scalable, e.g. :
2019   //    * iterators = [parallel, parallel, reduction]
2020   //    * scalable flags = [true, true, false]
2021 
2022   // Find the trailing (innermost) scalable flag.
2023   bool seenParallel = false;
2024   auto iterators = linalgOp.getIteratorTypesArray();
2025   SmallVector<bool> scalableFlags(inputScalableVecDims);
2026   while (!scalableFlags.back()) {
2027     seenParallel |= (iterators.back() == utils::IteratorType::parallel);
2028 
2029     iterators.pop_back();
2030     scalableFlags.pop_back();
2031   }
2032 
2033   switch (iterators.back()) {
2034   case utils::IteratorType::reduction: {
2035     // Check 3. above is met.
2036     if (iterators.size() != inputVectorSizes.size()) {
2037       LDBG("Non-trailing reduction dim requested for scalable "
2038            "vectorization\n");
2039       return failure();
2040     }
2041     if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2042       LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2043            "is not supported\n");
2044       return failure();
2045     }
2046     break;
2047   }
2048   case utils::IteratorType::parallel: {
2049     // Check 1. and 2. above are met.
2050     if (seenParallel) {
2051       LDBG("Inner parallel dim not requested for scalable "
2052            "vectorization\n");
2053       return failure();
2054     }
2055     break;
2056   }
2057   }
2058 
2059   // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2060   // supported, for which we expect the following config:
2061   //    * iterators = [parallel, parallel, reduction]
2062   //    * scalable flags = [true, true, false]
2063   if (numOfScalableDims == 2) {
2064     // Disallow the case below, which breaks 3. above:
2065     //    * iterators = [..., parallel, reduction]
2066     //    * scalable flags = [..., true, true]
2067     if (iterators.back() == utils::IteratorType::reduction) {
2068       LDBG("Higher dim than the trailing reduction dim requested for scalable "
2069            "vectorization\n");
2070       return failure();
2071     }
2072     scalableFlags.pop_back();
2073     iterators.pop_back();
2074 
2075     if (!scalableFlags.back() ||
2076         (iterators.back() != utils::IteratorType::parallel))
2077       return failure();
2078   }
2079 
2080   // Cond 4: Only the following ops are supported in the
2081   // presence of scalable vectors
2082   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2083                  isa<linalg::MatmulTransposeAOp>(op) ||
2084                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2085                  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2086 }
2087 
2088 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2089     Operation *op, ArrayRef<int64_t> inputVectorSizes,
2090     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2091     bool flatten1DDepthwiseConv) {
2092   if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2093                                                  inputScalableVecDims)))
2094     return failure();
2095 
2096   return TypeSwitch<Operation *, LogicalResult>(op)
2097       .Case<linalg::LinalgOp>([&](auto linalgOp) {
2098         return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2099                                              vectorizeNDExtract,
2100                                              flatten1DDepthwiseConv);
2101       })
2102       .Case<tensor::PadOp>([&](auto padOp) {
2103         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2104       })
2105       .Case<tensor::PackOp>([&](auto packOp) {
2106         return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2107       })
2108       .Case<tensor::UnPackOp>([&](auto unpackOp) {
2109         return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2110       })
2111       .Default([](auto) { return failure(); });
2112 }
2113 
2114 /// Converts affine.apply Ops to arithmetic operations.
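///
/// Illustrative sketch (hypothetical map): an op such as
///   %0 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%j]
/// is expanded into the equivalent arithmetic, roughly:
///   %0 = arith.addi %i, %j : index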
2115 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2116   OpBuilder::InsertionGuard g(rewriter);
2117   auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2118 
2119   for (auto op : make_early_inc_range(toReplace)) {
2120     rewriter.setInsertionPoint(op);
2121     auto expanded = affine::expandAffineExpr(
2122         rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2123         op.getOperands().take_front(op.getAffineMap().getNumDims()),
2124         op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2125     rewriter.replaceOp(op, expanded);
2126   }
2127 }
2128 
2129 /// Emit a suitable vector form for an operation. If provided,
2130 /// `inputVectorSizes` are used to vectorize this operation.
2131 /// `inputVectorSizes` must match the rank of the iteration space of the
2132 /// operation and the input vector sizes must be greater than or equal to
2133 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2134 /// also allows the vectorization of operations with dynamic shapes.
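///
/// Illustrative call site (a hypothetical sketch; the sizes and flags are
/// chosen only for the example):
///   if (failed(linalg::vectorize(rewriter, linalgOp,
///                                /*inputVectorSizes=*/{8, 16},
///                                /*inputScalableVecDims=*/{false, false},
///                                /*vectorizeNDExtract=*/false,
///                                /*flatten1DDepthwiseConv=*/false)))
///     return failure();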
2135 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2136                                       ArrayRef<int64_t> inputVectorSizes,
2137                                       ArrayRef<bool> inputScalableVecDims,
2138                                       bool vectorizeNDExtract,
2139                                       bool flatten1DDepthwiseConv) {
2140   LDBG("Attempting to vectorize:\n" << *op << "\n");
2141   LDBG("Input vector sizes: ");
2142   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2143   LLVM_DEBUG(llvm::dbgs() << "\n");
2144   LDBG("Input scalable vector dims: ");
2145   LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2146   LLVM_DEBUG(llvm::dbgs() << "\n");
2147 
2148   if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2149                                      vectorizeNDExtract,
2150                                      flatten1DDepthwiseConv))) {
2151     LDBG("Vectorization pre-conditions failed\n");
2152     return failure();
2153   }
2154 
2155   // Initialize vectorization state.
2156   VectorizationState state(rewriter);
2157   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2158     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2159                                inputScalableVecDims))) {
2160       LDBG("Vectorization state couldn't be initialized\n");
2161       return failure();
2162     }
2163   }
2164 
2165   SmallVector<Value> results;
2166   auto vectorizeResult =
2167       TypeSwitch<Operation *, LogicalResult>(op)
2168           .Case<linalg::LinalgOp>([&](auto linalgOp) {
2169             // TODO: isaConvolutionOpInterface that can also infer from
2170             // generic features. Will require stride/dilation attributes
2171             // inference.
2172             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2173               FailureOr<Operation *> convOr = vectorizeConvolution(
2174                   rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2175                   flatten1DDepthwiseConv);
2176               if (succeeded(convOr)) {
2177                 llvm::append_range(results, (*convOr)->getResults());
2178                 return success();
2179               }
2180 
2181               LDBG("Unsupported convolution can't be vectorized.\n");
2182               return failure();
2183             }
2184 
2185             LDBG("Vectorize generic by broadcasting to the canonical vector "
2186                  "shape\n");
2187 
2188             // Pre-process before proceeding.
2189             convertAffineApply(rewriter, linalgOp);
2190 
2191             // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2192             // to 'OpBuilder' when it is passed over to some methods like
2193             // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2194             // erase an op within these methods, the actual rewriter won't be
2195             // notified and we will end up with read-after-free issues!
2196             return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2197           })
2198           .Case<tensor::PadOp>([&](auto padOp) {
2199             return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2200                                           results);
2201           })
2202           .Case<tensor::PackOp>([&](auto packOp) {
2203             return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2204                                            results);
2205           })
2206           .Case<tensor::UnPackOp>([&](auto unpackOp) {
2207             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2208                                              inputVectorSizes, results);
2209           })
2210           .Default([](auto) { return failure(); });
2211 
2212   if (failed(vectorizeResult)) {
2213     LDBG("Vectorization failed\n");
2214     return failure();
2215   }
2216 
2217   if (!results.empty())
2218     rewriter.replaceOp(op, results);
2219   else
2220     rewriter.eraseOp(op);
2221 
2222   return success();
2223 }
2224 
2225 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2226                                           memref::CopyOp copyOp) {
2227   auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2228   auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2229   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2230     return failure();
2231 
2232   auto srcElementType = getElementTypeOrSelf(srcType);
2233   auto dstElementType = getElementTypeOrSelf(dstType);
2234   if (!VectorType::isValidElementType(srcElementType) ||
2235       !VectorType::isValidElementType(dstElementType))
2236     return failure();
2237 
2238   auto readType = VectorType::get(srcType.getShape(), srcElementType);
2239   auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2240 
2241   Location loc = copyOp->getLoc();
2242   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2243   SmallVector<Value> indices(srcType.getRank(), zero);
2244 
2245   Value readValue = rewriter.create<vector::TransferReadOp>(
2246       loc, readType, copyOp.getSource(), indices,
2247       rewriter.getMultiDimIdentityMap(srcType.getRank()));
2248   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2249     readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
2250     readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2251   }
2252   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2253       loc, readValue, copyOp.getTarget(), indices,
2254       rewriter.getMultiDimIdentityMap(srcType.getRank()));
2255   rewriter.replaceOp(copyOp, writeValue->getResults());
2256   return success();
2257 }
2258 
2259 //----------------------------------------------------------------------------//
2260 // Misc. vectorization patterns.
2261 //----------------------------------------------------------------------------//
2262 
2263 /// Helper function that retrieves the value of an IntegerAttr.
2264 static int64_t getIntFromAttr(Attribute attr) {
2265   return cast<IntegerAttr>(attr).getInt();
2266 }
2267 
2268 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
2269 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
2270 /// not supported.
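/// For example (illustrative): `{%v, IntegerAttr(4)}` becomes
/// `{%v, arith.constant 4 : index}`.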
2271 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
2272                                            ArrayRef<OpFoldResult> ofrs) {
2273   SmallVector<Value> result;
2274   for (auto o : ofrs) {
2275     if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
2276       result.push_back(val);
2277     } else {
2278       result.push_back(rewriter.create<arith::ConstantIndexOp>(
2279           loc, getIntFromAttr(o.template get<Attribute>())));
2280     }
2281   }
2282   return result;
2283 }
2284 
2285 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
2286 /// InsertSliceOp. For now, only constant padding values are supported.
2287 /// If there is enough static type information, TransferReadOps and
2288 /// TransferWriteOps may be generated instead of InsertSliceOps.
2289 struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
2290   GenericPadOpVectorizationPattern(MLIRContext *context,
2291                                    PatternBenefit benefit = 1)
2292       : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
2293   /// Vectorize the copying of a tensor::PadOp's source. This is possible if
2294   /// each dimension size is statically known in the source type or the result
2295   /// type (or both).
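  ///
  /// Illustrative sketch (hypothetical shapes): copying a tensor<?x5xf32>
  /// source that is padded into a tensor<6x5xf32> result becomes roughly:
  ///   %rd = vector.transfer_read %src[%c0, %c0], %pad
  ///       {in_bounds = [false, true]} : tensor<?x5xf32>, vector<6x5xf32>
  ///   vector.transfer_write %rd, %dest[%lo0, %lo1]
  ///       : vector<6x5xf32>, tensor<6x5xf32>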
2296   static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
2297                                         tensor::PadOp padOp, Value dest) {
2298     auto sourceType = padOp.getSourceType();
2299     auto resultType = padOp.getResultType();
2300     if (!VectorType::isValidElementType(sourceType.getElementType()))
2301       return failure();
2302 
2303     // Copy cannot be vectorized if pad value is non-constant and source shape
2304     // is dynamic. In case of a dynamic source shape, padding must be appended
2305     // by TransferReadOp, but TransferReadOp supports only constant padding.
2306     auto padValue = padOp.getConstantPaddingValue();
2307     if (!padValue) {
2308       if (!sourceType.hasStaticShape())
2309         return failure();
2310       // Create dummy padding value.
2311       auto elemType = sourceType.getElementType();
2312       padValue = rewriter.create<arith::ConstantOp>(
2313           padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2314     }
2315 
2316     SmallVector<int64_t> vecShape;
2317     SmallVector<bool> readInBounds;
2318     SmallVector<bool> writeInBounds;
2319     for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2320       if (!sourceType.isDynamicDim(i)) {
2321         vecShape.push_back(sourceType.getDimSize(i));
2322         // Source shape is statically known: neither read nor write is
2323         // out-of-bounds.
2324         readInBounds.push_back(true);
2325         writeInBounds.push_back(true);
2326       } else if (!resultType.isDynamicDim(i)) {
2327         // Source shape is not statically known, but result shape is.
2328         // Vectorize with size of result shape. This may be larger than the
2329         // source size.
2330         vecShape.push_back(resultType.getDimSize(i));
2331         // Read may be out-of-bounds because the result size could be larger
2332         // than the source size.
2333         readInBounds.push_back(false);
2334         // Write is out-of-bounds if low padding > 0.
2335         writeInBounds.push_back(
2336             getConstantIntValue(padOp.getMixedLowPad()[i]) ==
2337             static_cast<int64_t>(0));
2338       } else {
2339         // Neither source nor result dim of padOp is static. Cannot vectorize
2340         // the copy.
2341         return failure();
2342       }
2343     }
2344     auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2345 
2346     // Generate TransferReadOp.
2347     SmallVector<Value> readIndices(
2348         vecType.getRank(),
2349         rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2350     auto read = rewriter.create<vector::TransferReadOp>(
2351         padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
2352         ArrayRef<bool>{readInBounds});
2353 
2354     // If `dest` is a FillOp and the TransferWriteOp would overwrite the
2355     // entire tensor, write directly to the FillOp's operand.
2356     if (llvm::equal(vecShape, resultType.getShape()) &&
2357         llvm::all_of(writeInBounds, [](bool b) { return b; }))
2358       if (auto fill = dest.getDefiningOp<FillOp>())
2359         dest = fill.output();
2360 
2361     // Generate TransferWriteOp.
2362     auto writeIndices =
2363         ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
2364     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2365         padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
2366 
2367     return success();
2368   }
2369 };
2370 
2371 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2372 /// given operation type OpTy.
2373 template <typename OpTy>
2374 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2375   using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2376 
2377   LogicalResult matchAndRewrite(tensor::PadOp padOp,
2378                                 PatternRewriter &rewriter) const final {
2379     bool changed = false;
2380     // Insert users in vector, because some users may be replaced/removed.
2381     for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2382       if (auto op = dyn_cast<OpTy>(user))
2383         changed |= rewriteUser(rewriter, padOp, op).succeeded();
2384     return success(changed);
2385   }
2386 
2387 protected:
2388   virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2389                                     tensor::PadOp padOp, OpTy op) const = 0;
2390 };
2391 
2392 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2393 /// ```
2394 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2395 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2396 ///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2397 /// ```
2398 /// is rewritten to:
2399 /// ```
2400 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2401 ///     {in_bounds = [true, true]}
2402 ///     : tensor<?x?xf32>, vector<17x5xf32>
2403 /// ```
2404 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2405 /// sure that the original padding value %cst was never used.
2406 ///
2407 /// This rewrite is possible if:
2408 /// - `xferOp` has no out-of-bounds dims or mask.
2409 /// - Low padding is static 0.
2410 /// - Single, scalar padding value.
2411 struct PadOpVectorizationWithTransferReadPattern
2412     : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2413   using VectorizePadOpUserPattern<
2414       vector::TransferReadOp>::VectorizePadOpUserPattern;
2415 
2416   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2417                             vector::TransferReadOp xferOp) const override {
2418     // Low padding must be static 0.
2419     if (!padOp.hasZeroLowPad())
2420       return failure();
2421     // Pad value must be a constant.
2422     auto padValue = padOp.getConstantPaddingValue();
2423     if (!padValue)
2424       return failure();
2425     // Padding value of existing `xferOp` is unused.
2426     if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2427       return failure();
2428 
2429     rewriter.modifyOpInPlace(xferOp, [&]() {
2430       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2431       xferOp->setAttr(xferOp.getInBoundsAttrName(),
2432                       rewriter.getBoolArrayAttr(inBounds));
2433       xferOp.getSourceMutable().assign(padOp.getSource());
2434       xferOp.getPaddingMutable().assign(padValue);
2435     });
2436 
2437     return success();
2438   }
2439 };
2440 
2441 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2442 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2443 /// value, where the same amount of padding is immediately removed again after
2444 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2445 /// tensor value and apply out-of-bounds masking. E.g.:
2446 /// ```
2447 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2448 ///     : tensor<...> to tensor<?x?xf32>
2449 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2450 /// %2 = vector.transfer_write %vec, %1[...]
2451 ///     : vector<17x5xf32>, tensor<17x5xf32>
2452 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2453 ///     : tensor<17x5xf32> to tensor<?x?xf32>
2454 /// ```
2455 /// is rewritten to:
2456 /// ```
2457 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2458 ///     : tensor<...> to tensor<?x?xf32>
2459 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2460 /// tensor<?x?xf32>
2461 /// ```
2462 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2463 /// TransferWriteOp to the same size as the input of the tensor::PadOp (or an
2464 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2465 /// from %r's old dimensions.
2466 ///
2467 /// This rewrite is possible if:
2468 /// - Low padding is static 0.
2469 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2470 ///   ExtractSliceOp trims the same amount of padding that was added
2471 ///   beforehand.
2472 /// - Single, scalar padding value.
2473 struct PadOpVectorizationWithTransferWritePattern
2474     : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2475   using VectorizePadOpUserPattern<
2476       vector::TransferWriteOp>::VectorizePadOpUserPattern;
2477 
2478   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2479                             vector::TransferWriteOp xferOp) const override {
2480     // TODO: support 0-d corner case.
2481     if (xferOp.getTransferRank() == 0)
2482       return failure();
2483 
2484     // Low padding must be static 0.
2485     if (!padOp.hasZeroLowPad())
2486       return failure();
2487     // Pad value must be a constant.
2488     auto padValue = padOp.getConstantPaddingValue();
2489     if (!padValue)
2490       return failure();
2491     // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2492     if (!xferOp->hasOneUse())
2493       return failure();
2494     auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2495     if (!trimPadding)
2496       return failure();
2497     // Only static zero offsets supported when trimming padding.
2498     if (!trimPadding.hasZeroOffset())
2499       return failure();
2500     // trimPadding must remove the amount of padding that was added earlier.
2501     if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2502       return failure();
2503 
2504     // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2505     rewriter.setInsertionPoint(xferOp);
2506 
2507     SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2508     auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2509         xferOp, padOp.getSource().getType(), xferOp.getVector(),
2510         padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2511         xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2512     rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2513 
2514     return success();
2515   }
2516 
2517   /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2518   /// i.e., same dimensions.
2519   ///
2520   /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2521   /// dimensions, this function tries to infer the (static) tensor size by
2522   /// looking at the defining op and utilizing op-specific knowledge.
2523   ///
2524   /// This is a conservative analysis. In case equal tensor sizes cannot be
2525   /// proven statically, this analysis returns `false` even though the tensor
2526   /// sizes may turn out to be equal at runtime.
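  ///
  /// For example (illustrative, with placeholder values %t, %w and %s), the
  /// dynamic sizes below can be proven equal because both slices use the same
  /// SSA value `%s`:
  /// ```
  ///   %0 = tensor.extract_slice %t[0, 0] [%s, 4] [1, 1]
  ///       : tensor<?x4xf32> to tensor<?x4xf32>
  ///   %1 = tensor.extract_slice %w[0, 0] [%s, 4] [1, 1]
  ///       : tensor<17x4xf32> to tensor<?x4xf32>
  /// ```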
2527   bool hasSameTensorSize(Value beforePadding,
2528                          tensor::ExtractSliceOp afterTrimming) const {
2529     // If the input to tensor::PadOp is a CastOp, try with both CastOp
2530     // result and CastOp operand.
2531     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2532       if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2533         return true;
2534 
2535     auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2536     auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2537     // Only RankedTensorType supported.
2538     if (!t1 || !t2)
2539       return false;
2540     // Rank of both values must be the same.
2541     if (t1.getRank() != t2.getRank())
2542       return false;
2543 
2544     // All static dimensions must be the same. Mixed cases (e.g., dimension
2545     // static in `t1` but dynamic in `t2`) are not supported.
2546     for (unsigned i = 0; i < t1.getRank(); ++i) {
2547       if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2548         return false;
2549       if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2550         return false;
2551     }
2552 
2553     // Nothing more to check if all dimensions are static.
2554     if (t1.getNumDynamicDims() == 0)
2555       return true;
2556 
2557     // All dynamic sizes must be the same. The only supported case at the
2558     // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2559     // thereof).
2560 
2561     // Apart from CastOp, only ExtractSliceOp is supported.
2562     auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2563     if (!beforeSlice)
2564       return false;
2565 
2566     assert(static_cast<size_t>(t1.getRank()) ==
2567            beforeSlice.getMixedSizes().size());
2568     assert(static_cast<size_t>(t2.getRank()) ==
2569            afterTrimming.getMixedSizes().size());
2570 
2571     for (unsigned i = 0; i < t1.getRank(); ++i) {
2572       // Skip static dimensions.
2573       if (!t1.isDynamicDim(i))
2574         continue;
2575       auto size1 = beforeSlice.getMixedSizes()[i];
2576       auto size2 = afterTrimming.getMixedSizes()[i];
2577 
2578       // Case 1: Same value or same constant int.
2579       if (isEqualConstantIntOrValue(size1, size2))
2580         continue;
2581 
2582       // Other cases: Take a deeper look at defining ops of values.
2583       auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2584       auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2585       if (!v1 || !v2)
2586         return false;
2587 
2588       // Case 2: Both values are identical AffineMinOps. (Should not happen if
2589       // CSE is run.)
2590       auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2591       auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2592       if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2593           minOp1.getOperands() == minOp2.getOperands())
2594         continue;
2595 
2596       // Add additional cases as needed.
2597     }
2598 
2599     // All tests passed.
2600     return true;
2601   }
2602 };
2603 
2604 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2605 /// ```
2606 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2607 /// %r = tensor.insert_slice %0
2608 ///     into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2609 ///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2610 /// ```
2611 /// is rewritten to:
2612 /// ```
2613 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2614 ///     : tensor<?x?xf32>, vector<17x5xf32>
2615 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2616 ///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2617 /// ```
2618 ///
2619 /// This rewrite is possible if:
2620 /// - Low padding is static 0.
2621 /// - `padOp` result shape is static.
2622 /// - The entire padded tensor is inserted.
2623 ///   (Implies that sizes of `insertOp` are all static.)
2624 /// - Only unit strides in `insertOp`.
2625 /// - Single, scalar padding value.
2626 /// - `padOp` result not used as destination.
2627 struct PadOpVectorizationWithInsertSlicePattern
2628     : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2629   using VectorizePadOpUserPattern<
2630       tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2631 
2632   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2633                             tensor::InsertSliceOp insertOp) const override {
2634     // Low padding must be static 0.
2635     if (!padOp.hasZeroLowPad())
2636       return failure();
2637     // Only unit stride supported.
2638     if (!insertOp.hasUnitStride())
2639       return failure();
2640     // Pad value must be a constant.
2641     auto padValue = padOp.getConstantPaddingValue();
2642     if (!padValue)
2643       return failure();
2644     // Dynamic shapes not supported.
2645     if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2646       return failure();
2647     // Pad result not used as destination.
2648     if (insertOp.getDest() == padOp.getResult())
2649       return failure();
2650 
2651     auto vecType = VectorType::get(padOp.getType().getShape(),
2652                                    padOp.getType().getElementType());
2653     unsigned vecRank = vecType.getRank();
2654     unsigned tensorRank = insertOp.getType().getRank();
2655 
2656     // Check if sizes match: Insert the entire tensor into most minor dims.
2657     // (No permutations allowed.)
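    // E.g. inserting a vector<17x5xf32> into a rank-4 tensor requires the
    // InsertSliceOp sizes to be [1, 1, 17, 5].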
2658     SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2659     expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2660     if (!llvm::all_of(
2661             llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2662               return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2663             }))
2664       return failure();
2665 
2666     // Insert the TransferReadOp and TransferWriteOp at the position of the
2667     // InsertSliceOp.
2668     rewriter.setInsertionPoint(insertOp);
2669 
2670     // Generate TransferReadOp: Read entire source tensor and add high
2671     // padding.
2672     SmallVector<Value> readIndices(
2673         vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2674     auto read = rewriter.create<vector::TransferReadOp>(
2675         padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2676 
2677     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
2678     // specified offsets. Write is fully in-bounds because an InsertSliceOp's
2679     // source must fit into the destination at the specified offsets.
2680     auto writeIndices =
2681         ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2682     SmallVector<bool> inBounds(vecRank, true);
2683     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2684         insertOp, read, insertOp.getDest(), writeIndices,
2685         ArrayRef<bool>{inBounds});
2686 
2687     return success();
2688   }
2689 };
2690 
2691 void mlir::linalg::populatePadOpVectorizationPatterns(
2692     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2693   patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
2694                                                  baseBenefit);
2695   // Try these specialized patterns first before resorting to the generic one.
2696   patterns.add<PadOpVectorizationWithTransferReadPattern,
2697                PadOpVectorizationWithTransferWritePattern,
2698                PadOpVectorizationWithInsertSlicePattern>(
2699       patterns.getContext(), baseBenefit.getBenefit() + 1);
2700 }
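
// A typical way to apply these patterns (illustrative; `funcOp` and `ctx` are
// placeholders for the enclosing function and MLIRContext):
//   RewritePatternSet patterns(ctx);
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));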
2701 
2702 //----------------------------------------------------------------------------//
2703 // Forwarding patterns
2704 //----------------------------------------------------------------------------//
2705 
2706 /// Check whether there is any interleaved use of any `values` between
2707 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2708 /// is in a different block.
2709 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2710                                     ValueRange values) {
2711   if (firstOp->getBlock() != secondOp->getBlock() ||
2712       !firstOp->isBeforeInBlock(secondOp)) {
2713     LDBG("interleavedUses precondition failed, firstOp: "
2714          << *firstOp << ", second op: " << *secondOp << "\n");
2715     return true;
2716   }
2717   for (auto v : values) {
2718     for (auto &u : v.getUses()) {
2719       Operation *owner = u.getOwner();
2720       if (owner == firstOp || owner == secondOp)
2721         continue;
2722       // TODO: this is too conservative, use dominance info in the future.
2723       if (owner->getBlock() == firstOp->getBlock() &&
2724           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2725         continue;
2726       LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2727                                     << ", second op: " << *secondOp << "\n");
2728       return true;
2729     }
2730   }
2731   return false;
2732 }
2733 
2734 /// Return the unique subview use of `v` if it is indeed unique, null
2735 /// otherwise.
2736 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2737   memref::SubViewOp subViewOp;
2738   for (auto &u : v.getUses()) {
2739     if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2740       if (subViewOp)
2741         return memref::SubViewOp();
2742       subViewOp = newSubViewOp;
2743     }
2744   }
2745   return subViewOp;
2746 }
2747 
2748 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2749 /// when available.
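/// Illustrative example of the forwarding below (types abbreviated): given
/// ```
///   linalg.fill ins(%cst : f32) outs(%alloc : memref<...>)
///   %sv = memref.subview %alloc[...] [...] [1, 1] : memref<...> to memref<...>
///   memref.copy %in, %sv : memref<...> to memref<...>
///   %v = vector.transfer_read %alloc[%c0, %c0], %cst : memref<...>, vector<...>
/// ```
/// the transfer_read is rewritten to read directly from `%in` (with all
/// in_bounds entries conservatively set to false), and the fill and copy are
/// erased.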
2750 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2751     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2752 
2753   // TODO: support mask.
2754   if (xferOp.getMask())
2755     return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2756 
2757   // Transfer into `view`.
2758   Value viewOrAlloc = xferOp.getSource();
2759   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2760       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2761     return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2762 
2763   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2764   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2765   if (!subViewOp)
2766     return rewriter.notifyMatchFailure(xferOp, "no subview found");
2767   Value subView = subViewOp.getResult();
2768 
2769   // Find the copy into `subView` without interleaved uses.
2770   memref::CopyOp copyOp;
2771   for (auto &u : subView.getUses()) {
2772     if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2773       assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2774       if (newCopyOp.getTarget() != subView)
2775         continue;
2776       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2777         continue;
2778       copyOp = newCopyOp;
2779       break;
2780     }
2781   }
2782   if (!copyOp)
2783     return rewriter.notifyMatchFailure(xferOp, "no copy found");
2784 
2785   // Find the fill into `viewOrAlloc` without interleaved uses before the
2786   // copy.
2787   FillOp maybeFillOp;
2788   for (auto &u : viewOrAlloc.getUses()) {
2789     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2790       assert(isa<MemRefType>(newFillOp.output().getType()));
2791       if (newFillOp.output() != viewOrAlloc)
2792         continue;
2793       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2794         continue;
2795       maybeFillOp = newFillOp;
2796       break;
2797     }
2798   }
2799   // Ensure padding matches.
2800   if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2801     return rewriter.notifyMatchFailure(xferOp,
2802                                        "padding value does not match fill");
2803 
2804   // `in` is the subview that memref.copy reads. Replace it.
2805   Value in = copyOp.getSource();
2806 
2807   // memref.copy + linalg.fill can be used to create a padded local buffer.
2808   // The `in_bounds` attribute is only valid for this padded buffer.
2809   // When forwarding to vector.transfer_read, the attribute must be reset
2810   // conservatively (all dims potentially out-of-bounds).
2811   auto vectorType = xferOp.getVectorType();
2812   Value res = rewriter.create<vector::TransferReadOp>(
2813       xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
2814       xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2815       rewriter.getBoolArrayAttr(
2816           SmallVector<bool>(vectorType.getRank(), false)));
2817 
2818   if (maybeFillOp)
2819     rewriter.eraseOp(maybeFillOp);
2820   rewriter.eraseOp(copyOp);
2821   rewriter.replaceOp(xferOp, res);
2822 
2823   return success();
2824 }
2825 
2826 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2827 /// when available.
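/// Illustrative example of the forwarding below (types abbreviated): a
/// `vector.transfer_write %vec, %alloc[...]` followed by
/// `memref.copy %sv, %out` (where `%sv` is a subview of `%alloc`) is rewritten
/// into a single `vector.transfer_write %vec, %out[...]` with conservative
/// (all-false) in_bounds; the copy and the original write are erased.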
2828 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2829     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2830   // TODO: support mask.
2831   if (xferOp.getMask())
2832     return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2833 
2834   // Transfer into `viewOrAlloc`.
2835   Value viewOrAlloc = xferOp.getSource();
2836   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2837       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2838     return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2839 
2840   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2841   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2842   if (!subViewOp)
2843     return rewriter.notifyMatchFailure(xferOp, "no subview found");
2844   Value subView = subViewOp.getResult();
2845 
2846   // Find the copy from `subView` without interleaved uses.
2847   memref::CopyOp copyOp;
2848   for (auto &u : subViewOp.getResult().getUses()) {
2849     if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2850       if (newCopyOp.getSource() != subView)
2851         continue;
2852       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2853         continue;
2854       copyOp = newCopyOp;
2855       break;
2856     }
2857   }
2858   if (!copyOp)
2859     return rewriter.notifyMatchFailure(xferOp, "no copy found");
2860 
2861   // `out` is the buffer the subview is copied into; write to it directly.
2862   assert(isa<MemRefType>(copyOp.getTarget().getType()));
2863   Value out = copyOp.getTarget();
2864 
2865   // Forward vector.transfer into copy.
2866   // memref.copy + linalg.fill can be used to create a padded local buffer.
2867   // The `in_bounds` attribute is only valid for this padded buffer.
2868   // When forwarding to vector.transfer_write, the attribute must be reset
2869   // conservatively (all dims potentially out-of-bounds).
2870   auto vector = xferOp.getVector();
2871   rewriter.create<vector::TransferWriteOp>(
2872       xferOp.getLoc(), vector, out, xferOp.getIndices(),
2873       xferOp.getPermutationMapAttr(), xferOp.getMask(),
2874       rewriter.getBoolArrayAttr(
2875           SmallVector<bool>(vector.getType().getRank(), false)));
2876 
2877   rewriter.eraseOp(copyOp);
2878   rewriter.eraseOp(xferOp);
2879 
2880   return success();
2881 }
2882 
2883 //===----------------------------------------------------------------------===//
2884 // Convolution vectorization patterns
2885 //===----------------------------------------------------------------------===//
2886 
2887 template <int N>
2888 static void bindShapeDims(ShapedType shapedType) {}
2889 
2890 template <int N, typename IntTy, typename... IntTy2>
2891 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2892   val = shapedType.getShape()[N];
2893   bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2894 }
2895 
2896 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
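/// E.g. `bindShapeDims(shapedType, n, w, c)` binds `n`, `w` and `c` to dims 0,
/// 1 and 2 of `shapedType`.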
2897 template <typename... IntTy>
2898 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2899   bindShapeDims<0>(shapedType, vals...);
2900 }
2901 
2902 namespace {
2903 bool isCastOfBlockArgument(Operation *op) {
2904   return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2905          isa<BlockArgument>(op->getOperand(0));
2906 }
2907 
2908 bool isSupportedPoolKind(vector::CombiningKind kind) {
2909   switch (kind) {
2910   case vector::CombiningKind::ADD:
2911   case vector::CombiningKind::MAXNUMF:
2912   case vector::CombiningKind::MAXIMUMF:
2913   case vector::CombiningKind::MAXSI:
2914   case vector::CombiningKind::MAXUI:
2915   case vector::CombiningKind::MINNUMF:
2916   case vector::CombiningKind::MINIMUMF:
2917   case vector::CombiningKind::MINSI:
2918   case vector::CombiningKind::MINUI:
2919     return true;
2920   default:
2921     return false;
2922   }
2923 }
2924 
2925 /// Generate a vector implementation for either:
2926 /// ```
2927 ///   Op def: (     w,     kw  )
2928 ///    Iters: ({Par(), Red()})
2929 ///   Layout: {{w + kw}, {kw}, {w}}
2930 /// ```
2931 /// kw is unrolled.
2932 ///
2933 /// or
2934 ///
2935 /// ```
2936 ///   Op def: (     n,     w,     c,    kw,    f  )
2937 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
2938 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2939 /// ```
2940 /// kw is unrolled, w is unrolled iff dilationW > 1.
2941 ///
2942 /// or
2943 ///
2944 /// ```
2945 ///   Op def: (     n,     c,     w,    f,    kw )
2946 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
2947 ///   Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2948 /// ```
2949 /// kw is unrolled, w is unrolled iff dilationW > 1.
2950 ///
2951 /// or
2952 ///
2953 /// ```
2954 ///   Op def: (     n,     w,     c,    kw )
2955 ///    Iters: ({Par(), Par(), Par(), Red()})
2956 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2957 /// ```
2958 /// kw is unrolled, w is unrolled iff dilationW > 1.
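///
/// For example (with illustrative shapes), the second op definition above
/// matches ops such as:
/// ```
///   %0 = linalg.conv_1d_nwc_wcf
///     {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
///     ins(%input, %filter : tensor<1x14x3xf32>, tensor<3x3x8xf32>)
///     outs(%init : tensor<1x12x8xf32>) -> tensor<1x12x8xf32>
/// ```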
2959 struct Conv1DGenerator
2960     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
2961   Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
2962                   int dilationW)
2963       : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
2964         strideW(strideW), dilationW(dilationW) {
2965     // Determine whether `linalgOp` can be vectorized by this generator.
2966     if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2967       return;
2968     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2969     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2970     resShaped = linalgOp.getDpsInitOperand(0)->get();
2971     lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2972     rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2973     resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2974     if (!lhsShapedType || !rhsShapedType || !resShapedType)
2975       return;
2976     // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2977     // (non-channeled convolution -> LHS and RES both have single dimensions).
2978     if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2979         (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2980       return;
2981 
2982     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
2983     if (!reduceOp)
2984       return;
2985     redOp = reduceOp->getName().getIdentifier();
2986 
2987     if (!setOperKind(reduceOp))
2988       return;
2989     auto maybeKind = getCombinerOpKind(reduceOp);
2990     if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2991                        (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2992       return;
2993     }
2994 
2995     auto rhsRank = rhsShapedType.getRank();
2996     switch (oper) {
2997     case Conv:
2998       if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2999         return;
3000       break;
3001     case Pool:
3002       if (rhsRank != 1)
3003         return;
3004       break;
3005     }
3006     // The op is now known to be valid.
3007     valid = true;
3008   }
3009 
3010   /// Generate a vector implementation for:
3011   /// ```
3012   ///   Op def: (     w,     kw  )
3013   ///    Iters: ({Par(), Red()})
3014   ///   Layout: {{w + kw}, {kw}, {w}}
3015   /// ```
3016   /// kw is always unrolled.
3017   ///
3018   /// or
3019   ///
3020   /// ```
3021   ///   Op def: (     n,     w,     c,    kw,    f  )
3022   ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
3023   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3024   /// ```
3025   /// kw is always unrolled.
3026   /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3027   /// > 1.
3028   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3029     if (!valid)
3030       return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
3031 
3032     int64_t nSize, wSize, cSize, kwSize, fSize;
3033     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3034     bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3035     switch (conv1DOpOrder) {
3036     case Conv1DOpOrder::W:
3037       // Initialize unused dimensions
3038       nSize = fSize = cSize = 0;
3039       // out{W}
3040       bindShapeDims(resShapedType, wSize);
3041       // kernel{kw}
3042       bindShapeDims(rhsShapedType, kwSize);
3043       lhsShape = {// iw = ow + kw - 1
3044                   //   (i.e. 16 convolved with 3 -> 14)
3045                   (wSize + kwSize - 1)};
3046       rhsShape = {kwSize};
3047       resShape = {wSize};
3048       break;
3049     case Conv1DOpOrder::Nwc:
3050       // out{n, w, f}
3051       bindShapeDims(resShapedType, nSize, wSize, fSize);
3052       switch (oper) {
3053       case Conv:
3054         // kernel{kw, c, f}
3055         bindShapeDims(rhsShapedType, kwSize, cSize);
3056         break;
3057       case Pool:
3058         // kernel{kw}
3059         bindShapeDims(rhsShapedType, kwSize);
3060         cSize = fSize;
3061         break;
3062       }
3063       lhsShape = {nSize,
3064                   // iw = ow * sw + kw *  dw - 1
3065                   //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3066                   // Perform the proper inclusive -> exclusive -> inclusive.
3067                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3068                       1,
3069                   cSize};
3070       switch (oper) {
3071       case Conv:
3072         rhsShape = {kwSize, cSize, fSize};
3073         break;
3074       case Pool:
3075         rhsShape = {kwSize};
3076         break;
3077       }
3078       resShape = {nSize, wSize, fSize};
3079       break;
3080     case Conv1DOpOrder::Ncw:
3081       // out{n, f, w}
3082       bindShapeDims(resShapedType, nSize, fSize, wSize);
3083       switch (oper) {
3084       case Conv:
3085         // kernel{f, c, kw}
3086         bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3087         break;
3088       case Pool:
3089         // kernel{kw}
3090         bindShapeDims(rhsShapedType, kwSize);
3091         cSize = fSize;
3092         break;
3093       }
3094       lhsShape = {nSize, cSize,
3095                   // iw = ow * sw + kw *  dw - 1
3096                   //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3097                   // Perform the proper inclusive -> exclusive -> inclusive.
3098                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3099                       1};
3100       switch (oper) {
3101       case Conv:
3102         rhsShape = {fSize, cSize, kwSize};
3103         break;
3104       case Pool:
3105         rhsShape = {kwSize};
3106         break;
3107       }
3108       resShape = {nSize, fSize, wSize};
3109       break;
3110     }
3111 
3112     vector::TransferWriteOp write;
3113     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3114 
3115     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3116     // When strideW == 1, we can batch the contiguous loads and avoid
3117     // unrolling
3118     int64_t wSizeStep = strideW == 1 ? wSize : 1;
3119 
3120     Type lhsEltType = lhsShapedType.getElementType();
3121     Type rhsEltType = rhsShapedType.getElementType();
3122     Type resEltType = resShapedType.getElementType();
3123     auto lhsType = VectorType::get(lhsShape, lhsEltType);
3124     auto rhsType = VectorType::get(rhsShape, rhsEltType);
3125     auto resType = VectorType::get(resShape, resEltType);
3126     // Zero padding with the corresponding dimensions for lhs, rhs and res.
3127     SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3128     SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3129     SmallVector<Value> resPadding(resShape.size(), zero);
3130 
3131     // Read the whole lhs, rhs and res in one shot (with zero padding).
3132     Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3133                                                         lhsPadding);
3134     // This is needed only for Conv.
3135     Value rhs = nullptr;
3136     if (oper == Conv)
3137       rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3138                                                     rhsPadding);
3139     Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3140                                                         resPadding);
3141 
3142     // The base vectorization case for channeled convolution is input:
3143     // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse this base case,
3144     // we pre-transpose the input, weight, and output as needed.
3145     switch (conv1DOpOrder) {
3146     case Conv1DOpOrder::W:
3147     case Conv1DOpOrder::Nwc:
3148       // Base case, so no transposes necessary.
3149       break;
3150     case Conv1DOpOrder::Ncw: {
3151       // To match base vectorization case, we pre-transpose current case.
3152       // ncw -> nwc
3153       static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3154       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3155       // fcw -> wcf
3156       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3157 
3158       // This is needed only for Conv.
3159       if (oper == Conv)
3160         rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3161       // nfw -> nwf
3162       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3163       res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3164       break;
3165     }
3166     }
3167 
3168     //===------------------------------------------------------------------===//
3169     // Begin vector-only rewrite part
3170     //===------------------------------------------------------------------===//
3171     // Unroll along kw and read slices of lhs and rhs.
3172     SmallVector<Value> lhsVals, rhsVals, resVals;
3173     lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3174                                      kwSize, strideW, dilationW, wSizeStep,
3175                                      isSingleChanneled);
3176     // Do not do for pooling.
3177     if (oper == Conv)
3178       rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3179     resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3180                                       wSizeStep, isSingleChanneled);
3181 
3182     auto linearIndex = [&](int64_t kw, int64_t w) {
3183       return kw * (wSize / wSizeStep) + w;
3184     };
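    // E.g. when strideW == 1 (wSizeStep == wSize), w only takes the value 0 and
    // the index is simply kw; when w is unrolled (wSizeStep == 1), the index is
    // kw * wSize + w.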
3185 
3186     // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3187     // or an outer product for non-channeled convolution, or a simple arith
3188     // operation for pooling.
3189     for (int64_t kw = 0; kw < kwSize; ++kw) {
3190       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3191         switch (oper) {
3192         case Conv:
3193           if (isSingleChanneled) {
3194             resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3195                                                    lhsVals[linearIndex(kw, w)],
3196                                                    rhsVals[kw], resVals[w]);
3197           } else {
3198             resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3199                                                   lhsVals[linearIndex(kw, w)],
3200                                                   rhsVals[kw], resVals[w]);
3201           }
3202           break;
3203         case Pool:
3204           resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3205                                    resVals[w]);
3206           break;
3207         }
3208       }
3209     }
3210 
3211     res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3212                                  isSingleChanneled);
3213     //===------------------------------------------------------------------===//
3214     // End vector-only rewrite part
3215     //===------------------------------------------------------------------===//
3216 
3217     // The base vectorization case for channeled convolution is output:
3218     // {n,w,f}. To reuse the result from the base vectorization case, we
3219     // post-transpose the base case result.
3220     switch (conv1DOpOrder) {
3221     case Conv1DOpOrder::W:
3222     case Conv1DOpOrder::Nwc:
3223       // Base case, so no transposes necessary.
3224       break;
3225     case Conv1DOpOrder::Ncw: {
3226       // nwf -> nfw
3227       static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3228       res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3229       break;
3230     }
3231     }
3232 
3233     return rewriter
3234         .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3235         .getOperation();
3236   }
3237 
3238   // Take a value and widen it to have the same element type as `ty`.
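  // E.g. an i8 value promoted to match an f32 result type is converted with
  // arith.sitofp; integer widening uses arith.extsi and float widening uses
  // arith.extf. Narrowing and float-to-int cases are unhandled and assert.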
3239   Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3240     const Type srcElementType = getElementTypeOrSelf(val.getType());
3241     const Type dstElementType = getElementTypeOrSelf(ty);
3242     assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3243     if (srcElementType == dstElementType)
3244       return val;
3245 
3246     const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3247     const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3248     const Type dstType =
3249         cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3250 
3251     if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3252       return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3253     }
3254 
3255     if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3256         srcWidth < dstWidth)
3257       return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3258 
3259     if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3260         srcWidth < dstWidth)
3261       return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3262 
3263     assert(false && "unhandled promotion case");
3264     return nullptr;
3265   }
3266 
3267   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3268   Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3269                                  Value lhs, Value rhs, Value res) {
3270     vector::IteratorType par = vector::IteratorType::parallel;
3271     vector::IteratorType red = vector::IteratorType::reduction;
3272     AffineExpr n, w, f, c;
3273     bindDims(ctx, n, w, f, c);
3274     lhs = promote(rewriter, loc, lhs, res.getType());
3275     rhs = promote(rewriter, loc, rhs, res.getType());
3276     return rewriter.create<vector::ContractionOp>(
3277         loc, lhs, rhs, res,
3278         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3279         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3280   }
3281 
3282   // Create an outer product: lhs{w} * rhs{1} -> res{w} for single-channel
3283   // convolution.
3284   Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3285                                   Value lhs, Value rhs, Value res) {
3286     return rewriter.create<vector::OuterProductOp>(
3287         loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3288   }
3289 
3290   // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3291   Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3292                     Value res) {
3293     if (isPoolExt)
3294       lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3295     return rewriter
3296         .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3297         ->getResult(0);
3298   }
3299 
3300   /// Generate a vector implementation for:
3301   /// ```
3302   ///   Op def: (     n,     w,     c,    kw)
3303   ///    Iters: ({Par(), Par(), Par(), Red()})
3304   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3305   /// ```
3306   /// kw is always unrolled.
3307   /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3308   /// > 1.
3309   FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3310                                        bool channelDimScalableFlag,
3311                                        bool flatten) {
3312     if (!valid)
3313       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3314 
3315     bool scalableChDim = false;
3316     bool useMasking = false;
3317     int64_t nSize, wSize, cSize, kwSize;
3318     // kernel{kw, c}
3319     bindShapeDims(rhsShapedType, kwSize, cSize);
3320     if (ShapedType::isDynamic(cSize)) {
3321       assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3322       cSize = channelDimVecSize;
3323       // Scalable vectors are only used when both conditions are met:
3324       //  1. channel dim is dynamic
3325       //  2. channelDimScalableFlag is set
3326       scalableChDim = channelDimScalableFlag;
3327       useMasking = true;
3328     }
3329 
3330     assert(!(useMasking && flatten) &&
3331            "Unsupported flattened conv with dynamic shapes");
3332 
3333     // out{n, w, c}
3334     bindShapeDims(resShapedType, nSize, wSize);
3335 
3336     vector::TransferWriteOp write;
3337     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3338 
3339     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3340     // When strideW == 1, we can batch the contiguous loads and avoid
3341     // unrolling
3342     int64_t wSizeStep = strideW == 1 ? wSize : 1;
3343 
3344     Type lhsEltType = lhsShapedType.getElementType();
3345     Type rhsEltType = rhsShapedType.getElementType();
3346     Type resEltType = resShapedType.getElementType();
3347     VectorType lhsType = VectorType::get(
3348         {nSize,
3349          // iw = ow * sw + kw *  dw - 1
3350          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3351          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3352          cSize},
3353         lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3354     VectorType rhsType =
3355         VectorType::get({kwSize, cSize}, rhsEltType,
3356                         /*scalableDims=*/{false, scalableChDim});
3357     VectorType resType =
3358         VectorType::get({nSize, wSize, cSize}, resEltType,
3359                         /*scalableDims=*/{false, false, scalableChDim});
3360 
3361     // Masks the xfer Op along the channel dim, iff masking is required (i.e.
3362     // the channel dim size is dynamic).
3363     auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3364                                ArrayRef<bool> scalableDims,
3365                                Operation *opToMask) {
3366       if (!useMasking)
3367         return opToMask;
3368       auto maskType =
3369           VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3370 
3371       SmallVector<bool> inBounds(maskShape.size(), true);
3372       auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3373       xferOp->setAttr(xferOp.getInBoundsAttrName(),
3374                       rewriter.getBoolArrayAttr(inBounds));
3375 
3376       SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3377           cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3378 
3379       Value maskOp =
3380           rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3381 
3382       return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3383     };
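    // E.g. when masking is required, a transfer op created below ends up
    // wrapped (illustratively) as:
    //   %mask = vector.create_mask %c1, %w, %ch : vector<1x8x[4]xi1>
    //   %0 = vector.mask %mask { vector.transfer_read ... }
    //            : vector<1x8x[4]xi1> -> vector<1x8x[4]xf32>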
3384 
3385     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3386     // 0].
3387     Value lhs = rewriter.create<vector::TransferReadOp>(
3388         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3389     auto maybeMaskedLhs = maybeMaskXferOp(
3390         lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3391 
3392     // Read rhs slice of size {kw, c} @ [0, 0].
3393     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3394                                                         ValueRange{zero, zero});
3395     auto maybeMaskedRhs = maybeMaskXferOp(
3396         rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3397 
3398     // Read res slice of size {n, w, c} @ [0, 0, 0].
3399     Value res = rewriter.create<vector::TransferReadOp>(
3400         loc, resType, resShaped, ValueRange{zero, zero, zero});
3401     auto maybeMaskedRes = maybeMaskXferOp(
3402         resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3403 
3404     //===------------------------------------------------------------------===//
3405     // Begin vector-only rewrite part
3406     //===------------------------------------------------------------------===//
3407     // Unroll along kw and read slices of lhs and rhs.
3408     SmallVector<Value> lhsVals, rhsVals, resVals;
3409     auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3410     auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3411 
3412     // Extract lhs slice of size {n, wSizeStep, c}
3413     //   @ [0, sw * w + dw * kw, 0].
3414     for (int64_t kw = 0; kw < kwSize; ++kw) {
3415       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3416         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3417             loc, maybeMaskedLhs->getResult(0),
3418             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3419             inOutSliceSizes, inOutStrides));
3420       }
3421     }
3422     // Extract rhs slice of size {c} @ [kw].
3423     for (int64_t kw = 0; kw < kwSize; ++kw) {
3424       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3425           loc, maybeMaskedRhs->getResult(0),
3426           /*offsets=*/ArrayRef<int64_t>{kw}));
3427     }
3428     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3429     for (int64_t w = 0; w < wSize; w += wSizeStep) {
3430       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3431           loc, maybeMaskedRes->getResult(0),
3432           /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3433           inOutStrides));
3434     }
3435 
3436     auto linearIndex = [&](int64_t kw, int64_t w) {
3437       return kw * (wSize / wSizeStep) + w;
3438     };
3439 
3440     // Note - the scalable flags are ignored as flattening combined with
3441     // scalable vectorization is not supported.
3442     auto inOutFlattenSliceSizes =
3443         SmallVector<int64_t>{nSize, wSizeStep * cSize};
3444     auto lhsTypeAfterFlattening =
3445         VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3446     auto resTypeAfterFlattening =
3447         VectorType::get(inOutFlattenSliceSizes, resEltType);
3448 
3449     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3450     for (int64_t kw = 0; kw < kwSize; ++kw) {
3451       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3452         Value lhsVal = lhsVals[linearIndex(kw, w)];
3453         Value resVal = resVals[w];
3454         if (flatten) {
3455           // Flatten the input and output vectors (collapse the channel
3456           // dimension)
3457           lhsVal = rewriter.create<vector::ShapeCastOp>(
3458               loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3459           resVal = rewriter.create<vector::ShapeCastOp>(
3460               loc, resTypeAfterFlattening, resVals[w]);
3461         }
3462         resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3463                                                   rhsVals[kw], resVal, flatten);
3464         if (flatten) {
3465           // Un-flatten the output vector (restore the channel dimension)
3466           resVals[w] = rewriter.create<vector::ShapeCastOp>(
3467               loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3468         }
3469       }
3470     }
3471 
3472     // It's possible we failed to create the FMA.
3473     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3474       // Manually revert (in reverse order) to avoid leaving a bad IR state.
3475       for (auto &collection :
3476            {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3477         for (Value v : collection)
3478           rewriter.eraseOp(v.getDefiningOp());
3479       return rewriter.notifyMatchFailure(op, "failed to create FMA");
3480     }
3481 
3482     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3483     // This does not depend on kw.
3484     for (int64_t w = 0; w < wSize; w += wSizeStep) {
3485       maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3486           loc, resVals[w], maybeMaskedRes->getResult(0),
3487           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3488           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3489     }
3490     //===------------------------------------------------------------------===//
3491     // End vector-only rewrite part
3492     //===------------------------------------------------------------------===//
3493 
3494     // Write back res slice of size {n, w, c} @ [0, 0, 0].
3495     Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3496         loc, maybeMaskedRes->getResult(0), resShaped,
3497         ValueRange{zero, zero, zero});
3498     return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3499                            resOut);
3500   }
3501 
3502   /// Lower:
3503   ///   *  lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3504   ///   *  lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3505   /// to MulAcc.
3506   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3507                                      Value lhs, Value rhs, Value res,
3508                                      bool flatten) {
3509     auto rhsTy = cast<ShapedType>(rhs.getType());
3510     auto resTy = cast<ShapedType>(res.getType());
3511 
3512     // TODO(suderman): Change this to use a vector.ima intrinsic.
3513     lhs = promote(rewriter, loc, lhs, resTy);
3514 
3515     if (flatten) {
3516       // NOTE: The following logic won't work for scalable vectors. For this
3517       // reason, "flattening" is not supported when shapes are dynamic (this
3518       // should be captured by one of the pre-conditions).
3519 
3520       // There are two options for handling the filter:
3521       //  * shape_cast(broadcast(filter))
3522       //  * broadcast(shuffle(filter))
3523       // Opt for the option without shape_cast to simplify the codegen.
3524       auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3525       auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3526 
3527       SmallVector<int64_t, 16> indices;
3528       for (int i = 0; i < resSize / rhsSize; ++i) {
3529         for (int j = 0; j < rhsSize; ++j)
3530           indices.push_back(j);
3531       }
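      // E.g. with rhsSize == 4 and resSize == 12 the shuffle indices are
      // [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3].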
3532 
3533       rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3534     }
3535     // Broadcast the filter to match the output vector
3536     rhs = rewriter.create<vector::BroadcastOp>(
3537         loc, resTy.clone(rhsTy.getElementType()), rhs);
3538 
3539     rhs = promote(rewriter, loc, rhs, resTy);
3540 
3541     if (!lhs || !rhs)
3542       return nullptr;
3543 
3544     if (isa<FloatType>(resTy.getElementType()))
3545       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3546 
3547     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3548     return rewriter.create<arith::AddIOp>(loc, mul, res);
3549   }
3550 
3551   /// Entry point for non-channeled convolution:
3552   ///   {{w + kw}, {kw}, {w}}
3553   FailureOr<Operation *> generateNonChanneledConv() {
3554     AffineExpr w, kw;
3555     bindDims(ctx, w, kw);
3556     if (!iters({Par(), Red()}))
3557       return rewriter.notifyMatchFailure(op,
3558                                          "failed to match conv::W 1-par 1-red");
3559 
3560     // No transposition needed.
3561     if (layout({/*lhsIndex*/ {w + kw},
3562                 /*rhsIndex*/ {kw},
3563                 /*resIndex*/ {w}}))
3564       return conv(Conv1DOpOrder::W);
3565 
3566     return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3567   }
3568 
3569   /// Entry point that transposes into the common form:
3570   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3571   FailureOr<Operation *> generateNwcConv() {
3572     AffineExpr n, w, f, kw, c;
3573     bindDims(ctx, n, w, f, kw, c);
3574     if (!iters({Par(), Par(), Par(), Red(), Red()}))
3575       return rewriter.notifyMatchFailure(
3576           op, "failed to match conv::Nwc 3-par 2-red");
3577 
3578     // No transposition needed.
3579     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3580                 /*rhsIndex*/ {kw, c, f},
3581                 /*resIndex*/ {n, w, f}}))
3582       return conv(Conv1DOpOrder::Nwc);
3583 
3584     return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3585   }
3586 
3587   /// Entry point that transposes into the common form:
3588   ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3589   FailureOr<Operation *> generateNcwConv() {
3590     AffineExpr n, w, f, kw, c;
3591     bindDims(ctx, n, f, w, c, kw);
3592     if (!iters({Par(), Par(), Par(), Red(), Red()}))
3593       return rewriter.notifyMatchFailure(
3594           op, "failed to match conv::Ncw 3-par 2-red");
3595 
3596     if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3597                 /*rhsIndex*/ {f, c, kw},
3598                 /*resIndex*/ {n, f, w}}))
3599       return conv(Conv1DOpOrder::Ncw);
3600 
3601     return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3602   }
3603 
3604   /// Entry point that transposes into the common form:
3605   ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3606   FailureOr<Operation *> generateNwcPooling() {
3607     AffineExpr n, w, c, kw;
3608     bindDims(ctx, n, w, c, kw);
3609     if (!iters({Par(), Par(), Par(), Red()}))
3610       return rewriter.notifyMatchFailure(op,
3611                                          "failed to match pooling 3-par 1-red");
3612 
3613     // No transposition needed.
3614     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3615                 /*rhsIndex*/ {kw},
3616                 /*resIndex*/ {n, w, c}}))
3617       return conv(Conv1DOpOrder::Nwc);
3618 
3619     return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3620   }
3621 
3622   /// Entry point that transposes into the common form:
3623   ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3624   FailureOr<Operation *> generateNcwPooling() {
3625     AffineExpr n, w, c, kw;
3626     bindDims(ctx, n, c, w, kw);
3627     if (!iters({Par(), Par(), Par(), Red()}))
3628       return rewriter.notifyMatchFailure(op,
3629                                          "failed to match pooling 3-par 1-red");
3630 
3631     if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3632                 /*rhsIndex*/ {kw},
3633                 /*resIndex*/ {n, c, w}}))
3634       return conv(Conv1DOpOrder::Ncw);
3635 
3636     return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3637   }
3638 
3639   /// Entry point that transposes into the common form:
3640   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3641   FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3642                                              bool vecChDimScalableFlag = false,
3643                                              bool flatten = false) {
3644     AffineExpr n, w, c, kw;
3645     bindDims(ctx, n, w, c, kw);
3646     if (!iters({Par(), Par(), Par(), Red()}))
3647       return rewriter.notifyMatchFailure(
3648           op, "failed to match depthwise::Nwc conv 3-par 1-red");
3649 
3650     // No transposition needed.
3651     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3652                 /*rhsIndex*/ {kw, c},
3653                 /*resIndex*/ {n, w, c}}))
3654       return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3655 
3656     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3657   }
3658 
3659 private:
3660   enum OperKind { Conv, Pool };
3661   bool valid = false;
3662   OperKind oper = Conv;
3663   StringAttr redOp;
3664   StringAttr poolExtOp;
3665   bool isPoolExt = false;
3666   int strideW, dilationW;
3667   Value lhsShaped, rhsShaped, resShaped;
3668   ShapedType lhsShapedType, rhsShapedType, resShapedType;
3669 
3670   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3671   // Returns true iff it is a valid conv/pooling op.
3672   // If the region has 2 ops (reduction + yield) or 3 ops (extension +
3673   // reduction + yield) and rhs is not used, then it is the body of a pooling.
3674   // If it is a conv, check for a single `mul` predecessor. The `mul` operands
3675   // must be block arguments or extensions of block arguments.
3676   // Otherwise, check for one or zero `ext` predecessors. The `ext` operands
3677   // must be block arguments or extensions of block arguments.
3678   bool setOperKind(Operation *reduceOp) {
3679     int numBlockArguments =
3680         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3681     switch (numBlockArguments) {
3682     case 1: {
3683       // This will be a convolution if the feeder is a MulOp.
3684       // Otherwise, it may be pooling.
3685       auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3686                                          llvm::IsaPred<BlockArgument>);
3687       Operation *feedOp = (*feedValIt).getDefiningOp();
3688       if (isCastOfBlockArgument(feedOp)) {
3689         oper = Pool;
3690         isPoolExt = true;
3691         poolExtOp = feedOp->getName().getIdentifier();
3692       } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
3693                    llvm::all_of(feedOp->getOperands(), [](Value v) {
3694                      if (isa<BlockArgument>(v))
3695                        return true;
3696                      if (Operation *op = v.getDefiningOp())
3697                        return isCastOfBlockArgument(op);
3698                      return false;
3699                    }))) {
3700         return false;
3701       }
3702       return true;
3703     }
3704     case 2:
3705       // Both operands are block arguments: must be a pooling.
3706       oper = Pool;
3707       isPoolExt = false;
3708       return true;
3709     default:
3710       return false;
3711     }
3712   }
3713 };
3714 } // namespace
3715 
3716 /// Helper function to vectorize a LinalgOp with convolution semantics.
3717 // TODO: extend the generic vectorization to support windows and drop this.
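     //
     // The generators below are tried in order: non-channeled conv, NWC/NCW
     // conv, NWC/NCW pooling, and finally the (possibly masked/scalable)
     // depthwise dilated conv.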
3718 static FailureOr<Operation *> vectorizeConvolution(
3719     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3720     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3721   // The ConvolutionOpInterface gives us guarantees of existence for
3722   // strides/dilations. However, we do not need to rely on those: we simply
3723   // use them if present and otherwise fall back to the defaults, letting the
3724   // generic conv matcher in the Conv1DGenerator succeed or fail.
3725   auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3726   auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3727   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3728   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3729   Conv1DGenerator e(rewriter, op, stride, dilation);
3730   auto res = e.generateNonChanneledConv();
3731   if (succeeded(res))
3732     return res;
3733   res = e.generateNwcConv();
3734   if (succeeded(res))
3735     return res;
3736   res = e.generateNcwConv();
3737   if (succeeded(res))
3738     return res;
3739   res = e.generateNwcPooling();
3740   if (succeeded(res))
3741     return res;
3742   res = e.generateNcwPooling();
3743   if (succeeded(res))
3744     return res;
3745 
3746   // Only depthwise 1D convs are left - these can be vectorized using masks
3747   // and scalable vectors. Note that at the moment the only dim that can be
3748   // dynamic (i.e. masked/scalable) is the channel dim.
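       //
       // Illustrative example (hypothetical sizes): for a
       // linalg.depthwise_conv_1d_nwc_wc vectorized with
       // inputVecSizes = {1, 8, 4} and inputScalableVecDims = {false, false,
       // true}, the channel dim index is 2, so vecChDimSize = 4 and
       // vecChDimScalableFlag = true.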
3749   uint64_t vecChDimSize = ShapedType::kDynamic;
3750   bool vecChDimScalableFlag = false;
3751   if (!inputVecSizes.empty()) {
3752     // Only use the input vector size corresponding to the channel dim. Other
3753     // vector dims will be inferred from the Ops.
3754     assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3755             isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3756            "Not a 1D depthwise conv!");
3757     size_t chDimIdx =
3758         TypeSwitch<Operation *, size_t>(op)
3759             .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3760             .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3761 
3762     vecChDimSize = inputVecSizes[chDimIdx];
3763     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3764   }
3765   return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3766                                flatten1DDepthwiseConv);
3767 }
3768 
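     /// Rewrite pattern wrapping `vectorizeConvolution`: on success, the linalg
     /// op is erased when the rewrite produces no results, or replaced by the
     /// single result of the newly created vector ops.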
3769 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
3770   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3771 
3772   LogicalResult matchAndRewrite(LinalgOp op,
3773                                 PatternRewriter &rewriter) const override {
3774     FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3775     if (failed(resultOrFail))
3776       return failure();
3777     Operation *newOp = *resultOrFail;
3778     if (newOp->getNumResults() == 0) {
3779       rewriter.eraseOp(op.getOperation());
3780       return success();
3781     }
3782     assert(newOp->getNumResults() == 1 && "expected single result");
3783     rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3784     return success();
3785   }
3786 };
3787 
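     // Illustrative usage (sketch, not part of this file): a pass would
     // typically populate and apply these patterns greedily, e.g.:
     //   RewritePatternSet patterns(funcOp.getContext());
     //   linalg::populateConvolutionVectorizationPatterns(patterns, /*benefit=*/1);
     //   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));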
3788 void mlir::linalg::populateConvolutionVectorizationPatterns(
3789     RewritePatternSet &patterns, PatternBenefit benefit) {
3790   patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3791 }
3792