xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (revision 8868c02cda875d1efe1646affa01656ef268ffed)
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "mlir/Dialect/Affine/Utils.h"
13 
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tensor/Utils/Utils.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypeInterfaces.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Support/LLVM.h"
35 #include "mlir/Transforms/RegionUtils.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 #include <type_traits>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 
55 /// Try to vectorize `convOp` as a convolution.
56 static FailureOr<Operation *>
57 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58                      ArrayRef<int64_t> inputVecSizes = {},
59                      ArrayRef<bool> inputVecScalableFlags = {},
60                      bool flatten1DDepthwiseConv = false);
61 
62 /// Return the unique instance of OpType in `block` if it is indeed unique.
63 /// Return null if zero or more than one instance exists.
64 template <typename OpType>
65 static OpType getSingleOpOfType(Block &block) {
66   OpType res;
67   block.walk([&](OpType op) {
68     if (res) {
69       res = nullptr;
70       return WalkResult::interrupt();
71     }
72     res = op;
73     return WalkResult::advance();
74   });
75   return res;
76 }
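// Illustrative usage (hypothetical caller): retrieve the unique arith.addf in
// `body`, or nullptr if there are zero or several such ops:
//   auto addOp = getSingleOpOfType<arith::AddFOp>(body);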
77 
78 /// Helper function to extract the input slices after the filter is unrolled
79 /// along kw.
80 static SmallVector<Value>
81 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
82                        int64_t nSize, int64_t wSize, int64_t cSize,
83                        int64_t kwSize, int strideW, int dilationW,
84                        int64_t wSizeStep, bool isSingleChanneled) {
85   SmallVector<Value> result;
86   if (isSingleChanneled) {
87     // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
88     // convolution.
89     SmallVector<int64_t> sizes{wSizeStep};
90     SmallVector<int64_t> strides{1};
91     for (int64_t kw = 0; kw < kwSize; ++kw) {
92       for (int64_t w = 0; w < wSize; w += wSizeStep) {
93         result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
94             loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
95       }
96     }
97   } else {
98     // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
99     // for channeled convolution.
100     SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
101     SmallVector<int64_t> strides{1, 1, 1};
102     for (int64_t kw = 0; kw < kwSize; ++kw) {
103       for (int64_t w = 0; w < wSize; w += wSizeStep) {
104         result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
105             loc, input,
106             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
107             sizes, strides));
108       }
109     }
110   }
111   return result;
112 }
113 
114 /// Helper function to extract the filter slices after the filter is unrolled
115 /// along kw.
116 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
117                                                   Location loc, Value filter,
118                                                   int64_t kwSize) {
119   SmallVector<Value> result;
120   // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
121   // non-channeled convolution] @ [kw].
122   for (int64_t kw = 0; kw < kwSize; ++kw) {
123     result.push_back(rewriter.create<vector::ExtractOp>(
124         loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
125   }
126   return result;
127 }
128 
129 /// Helper function to extract the result slices after the filter is unrolled
130 /// along kw.
131 static SmallVector<Value>
132 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
133                         int64_t nSize, int64_t wSize, int64_t fSize,
134                         int64_t wSizeStep, bool isSingleChanneled) {
135   SmallVector<Value> result;
136   if (isSingleChanneled) {
137     // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
138     SmallVector<int64_t> sizes{wSizeStep};
139     SmallVector<int64_t> strides{1};
140     for (int64_t w = 0; w < wSize; w += wSizeStep) {
141       result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
142           loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
143     }
144   } else {
145     // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
146     // convolution.
147     SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
148     SmallVector<int64_t> strides{1, 1, 1};
149     for (int64_t w = 0; w < wSize; w += wSizeStep) {
150       result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
151           loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
152     }
153   }
154   return result;
155 }
156 
157 /// Helper function to insert the computed result slices.
158 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
159                                     Value res, int64_t wSize, int64_t wSizeStep,
160                                     SmallVectorImpl<Value> &resVals,
161                                     bool isSingleChanneled) {
162 
163   if (isSingleChanneled) {
164     // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
165     // This does not depend on kw.
166     SmallVector<int64_t> strides{1};
167     for (int64_t w = 0; w < wSize; w += wSizeStep) {
168       res = rewriter.create<vector::InsertStridedSliceOp>(
169           loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
170     }
171   } else {
172     // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
173     // convolution. This does not depend on kw.
174     SmallVector<int64_t> strides{1, 1, 1};
175     for (int64_t w = 0; w < wSize; w += wSizeStep) {
176       res = rewriter.create<vector::InsertStridedSliceOp>(
177           loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
178           strides);
179     }
180   }
181   return res;
182 }
183 
184 /// Contains the vectorization state and related methods used across the
185 /// vectorization process of a given operation.
186 struct VectorizationState {
187   VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
188 
189   /// Initializes the vectorization state, including the computation of the
190   /// canonical vector shape for vectorization.
191   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
192                           ArrayRef<int64_t> inputVectorSizes,
193                           ArrayRef<bool> inputScalableVecDims);
194 
195   /// Returns the canonical vector shape used to vectorize the iteration space.
196   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
197 
198   /// Returns the vector dimensions that are scalable in the canonical vector
199   /// shape.
200   ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
201 
202   /// Returns a vector type of the provided `elementType` with the canonical
203   /// vector shape and the corresponding fixed/scalable dimensions bit. If
204   /// `dimPermutation` is provided, the canonical vector dimensions are permuted
205   /// accordingly.
206   VectorType getCanonicalVecType(
207       Type elementType,
208       std::optional<AffineMap> dimPermutation = std::nullopt) const {
209     SmallVector<int64_t> vectorShape;
210     SmallVector<bool> scalableDims;
211     if (dimPermutation.has_value()) {
212       vectorShape =
213           applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
214       scalableDims =
215           applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
216     } else {
217       vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
218       scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
219     }
220 
221     return VectorType::get(vectorShape, elementType, scalableDims);
222   }
223 
224   /// Masks an operation with the canonical vector mask if the operation needs
225   /// masking. Returns the masked operation or the original operation if masking
226   /// is not needed. If provided, the canonical mask for this operation is
227   /// permuted using `maybeMaskingMap`.
228   Operation *
229   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230                 std::optional<AffineMap> maybeMaskingMap = std::nullopt);
231 
232 private:
233   /// Initializes the iteration space static sizes using the Linalg op
234   /// information. This may become more complicated in the future.
235   void initIterSpaceStaticSizes(LinalgOp linalgOp) {
236     iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
237   }
238 
239   /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
240   /// all the static and dynamic dimensions of the iteration space to be
241   /// vectorized and store them in `iterSpaceValueSizes`.
242   LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
243                                               LinalgOp linalgOp);
244 
245   /// Create or retrieve an existing mask value to mask `opToMask` in the
246   /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
247   /// is permuted using that map. If a new mask is created, it will be
248   /// cached for future users.
249   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
250                            LinalgOp linalgOp,
251                            std::optional<AffineMap> maybeMaskingMap);
252 
253   // Holds the compile-time static sizes of the iteration space to vectorize.
254   // Dynamic dimensions are represented using ShapedType::kDynamic.
255   SmallVector<int64_t> iterSpaceStaticSizes;
256 
257   /// Holds the value sizes of the iteration space to vectorize. Static
258   /// dimensions are represented by 'arith.constant' and dynamic
259   /// dimensions by 'tensor/memref.dim'.
260   SmallVector<Value> iterSpaceValueSizes;
261 
262   /// Holds the canonical vector shape used to vectorize the iteration space.
263   SmallVector<int64_t> canonicalVecShape;
264 
265   /// Holds the vector dimensions that are scalable in the canonical vector
266   /// shape.
267   SmallVector<bool> scalableVecDims;
268 
269   /// Holds the active masks for permutations of the canonical vector iteration
270   /// space.
271   DenseMap<AffineMap, Value> activeMaskCache;
272 
273   /// Global vectorization guard for the incoming rewriter. It's initialized
274   /// when the vectorization state is initialized.
275   OpBuilder::InsertionGuard rewriterGuard;
276 };
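// Illustrative example (assumed sizes): with canonical vector shape [8, 16]
// and scalable dims [false, true], getCanonicalVecType(f32) yields
// vector<8x[16]xf32>, and with the permutation map (d0, d1) -> (d1, d0) it
// yields vector<[16]x8xf32>.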
277 
278 LogicalResult
279 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
280                                                   LinalgOp linalgOp) {
281   // TODO: Support 0-d vectors.
282   for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
283     if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
284       // Create constant index op for static dimensions.
285       iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
286           linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
287       continue;
288     }
289 
290     // Find an operand defined on this dimension of the iteration space to
291     // extract the runtime dimension size.
292     Value operand;
293     unsigned operandDimPos;
294     if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
295                                                          operandDimPos)))
296       return failure();
297 
298     Value dynamicDim = linalgOp.hasPureTensorSemantics()
299                            ? (Value)rewriter.create<tensor::DimOp>(
300                                  linalgOp.getLoc(), operand, operandDimPos)
301                            : (Value)rewriter.create<memref::DimOp>(
302                                  linalgOp.getLoc(), operand, operandDimPos);
303     iterSpaceValueSizes.push_back(dynamicDim);
304   }
305 
306   return success();
307 }
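// Illustrative example (assumed shapes): for an op with loop ranges (8, ?),
// where the dynamic dimension maps to dim 1 of a tensor operand, this records
// roughly:
//   %c8 = arith.constant 8 : index
//   %c1 = arith.constant 1 : index
//   %d1 = tensor.dim %operand, %c1 : tensor<8x?xf32>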
308 
309 /// Initializes the vectorization state, including the computation of the
310 /// canonical vector shape for vectorization.
311 // TODO: Move this to the constructor when we can remove the failure cases.
312 LogicalResult
313 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
314                               ArrayRef<int64_t> inputVectorSizes,
315                               ArrayRef<bool> inputScalableVecDims) {
316   // Initialize the insertion point.
317   rewriter.setInsertionPoint(linalgOp);
318 
319   if (!inputVectorSizes.empty()) {
320     // Get the canonical vector shape from the input vector sizes provided. This
321     // path should be taken to vectorize code with dynamic shapes and when using
322     // vector sizes greater than the iteration space sizes.
323     canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
324     scalableVecDims.append(inputScalableVecDims.begin(),
325                            inputScalableVecDims.end());
326   } else {
327     // Compute the canonical vector shape from the operation shape. If there are
328     // dynamic shapes, the operation won't be vectorized. We assume all the
329     // vector dimensions are fixed.
330     canonicalVecShape = linalgOp.getStaticLoopRanges();
331     scalableVecDims.append(linalgOp.getNumLoops(), false);
332   }
333 
334   LDBG("Canonical vector shape: ");
335   LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
336   LLVM_DEBUG(llvm::dbgs() << "\n");
337   LDBG("Scalable vector dims: ");
338   LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
339   LLVM_DEBUG(llvm::dbgs() << "\n");
340 
341   if (ShapedType::isDynamicShape(canonicalVecShape))
342     return failure();
343 
344   // Initialize iteration space static sizes.
345   initIterSpaceStaticSizes(linalgOp);
346 
347   // Generate 'arith.constant' and 'tensor/memref.dim' operations for
348   // all the static and dynamic dimensions of the iteration space, needed to
349   // compute a mask during vectorization.
350   if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
351     return failure();
352 
353   return success();
354 }
355 
356 /// Create or retrieve an existing mask value to mask `opToMask` in the
357 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is
358 /// permuted using that map. If a new mask is created, it will be cached for
359 /// future users.
360 Value VectorizationState::getOrCreateMaskFor(
361     RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
362     std::optional<AffineMap> maybeMaskingMap) {
363   // No mask is needed if the operation is not maskable.
364   auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
365   if (!maskableOp)
366     return Value();
367 
368   assert(!maskableOp.isMasked() &&
369          "Masking an operation that is already masked");
370 
371   // If no masking map was provided, use an identity map with the loop dims.
372   assert((!maybeMaskingMap || *maybeMaskingMap) &&
373          "Unexpected null mask permutation map");
374   AffineMap maskingMap =
375       maybeMaskingMap ? *maybeMaskingMap
376                       : AffineMap::getMultiDimIdentityMap(
377                             linalgOp.getNumLoops(), rewriter.getContext());
378 
379   LDBG("Masking map: " << maskingMap << "\n");
380 
381   // Return the active mask for the masking map of this operation if it was
382   // already created.
383   auto activeMaskIt = activeMaskCache.find(maskingMap);
384   if (activeMaskIt != activeMaskCache.end()) {
385     Value mask = activeMaskIt->second;
386     LDBG("Reusing mask: " << mask << "\n");
387     return mask;
388   }
389 
390   // Compute permuted projection of the iteration space to be masked and the
391   // corresponding mask shape. If the resulting iteration space dimensions are
392   // static and identical to the mask shape, masking is not needed for this
393   // operation.
394   // TODO: Improve this check. Only projected permutation indexing maps are
395   // supported.
396   SmallVector<int64_t> permutedStaticSizes =
397       applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
398   auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
399   auto maskShape = maskType.getShape();
400 
401   LDBG("Mask shape: ");
402   LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
403   LLVM_DEBUG(llvm::dbgs() << "\n");
404 
405   if (permutedStaticSizes == maskShape) {
406     LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
407     activeMaskCache[maskingMap] = Value();
408     return Value();
409   }
410 
411   // Permute the iteration space value sizes to compute the mask upper bounds.
412   SmallVector<Value> upperBounds =
413       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
414   assert(!maskShape.empty() && !upperBounds.empty() &&
415          "Masked 0-d vectors are not supported yet");
416 
417   // Create the mask based on the dimension values.
418   Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
419                                                      maskType, upperBounds);
420   LDBG("Creating new mask: " << mask << "\n");
421   activeMaskCache[maskingMap] = mask;
422   return mask;
423 }
424 
425 /// Masks an operation with the canonical vector mask if the operation needs
426 /// masking. Returns the masked operation or the original operation if masking
427 /// is not needed. If provided, the canonical mask for this operation is
428 /// permuted using `maybeMaskingMap`.
429 Operation *
430 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
431                                   LinalgOp linalgOp,
432                                   std::optional<AffineMap> maybeMaskingMap) {
433   LDBG("Trying to mask: " << *opToMask << "\n");
434 
435   // Create or retrieve mask for this operation.
436   Value mask =
437       getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
438 
439   if (!mask) {
440     LDBG("No mask required\n");
441     return opToMask;
442   }
443 
444   // Wrap the operation with a new `vector.mask` and update D-U chain.
445   assert(opToMask && "Expected a valid operation to mask");
446   auto maskOp = cast<vector::MaskOp>(
447       mlir::vector::maskOperation(rewriter, opToMask, mask));
448   Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
449 
450   for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
451     rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
452                                   maskOpTerminator);
453 
454   LDBG("Masked operation: " << *maskOp << "\n");
455   return maskOp;
456 }
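// Illustrative example (assumed types): masking a transfer_read of a
// dynamically shaped tensor with canonical vector shape [8, 16] produces IR
// along the lines of:
//   %mask = vector.create_mask %d0, %d1 : vector<8x16xi1>
//   %read = vector.mask %mask {
//     vector.transfer_read %t[%c0, %c0], %pad
//         : tensor<?x?xf32>, vector<8x16xf32>
//   } : vector<8x16xi1> -> vector<8x16xf32>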
457 
458 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
459 /// projectedPermutation, compress the unused dimensions to serve as a
460 /// permutation_map for a vector transfer operation.
461 /// For example, given a linalg op such as:
462 ///
463 /// ```
464 ///   %0 = linalg.generic {
465 ///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
466 ///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
467 ///      }
468 ///     ins(%0 : tensor<2x3x4xf32>)
469 ///    outs(%1 : tensor<5x6xf32>)
470 /// ```
471 ///
472 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
473 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
474 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
475 static AffineMap reindexIndexingMap(AffineMap map) {
476   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
477          "expected projected permutation");
478   auto res = compressUnusedDims(map);
479   assert(res.getNumDims() == res.getNumResults() &&
480          "expected reindexed map with same number of dims and results");
481   return res;
482 }
483 
484 /// Helper enum to represent conv1d input traversal order.
485 enum class Conv1DOpOrder {
486   W,   // Corresponds to non-channeled 1D convolution operation.
487   Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
488   Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
489 };
490 
491 /// Helper data structure to represent the result of vectorization.
492 /// In certain specific cases, like terminators, we do not want to propagate/replace results.
493 enum VectorizationStatus {
494   /// Op failed to vectorize.
495   Failure = 0,
496   /// Op vectorized and the custom function took care of the replacement logic.
497   NoReplace,
498   /// Op vectorized into a new Op whose results will replace original Op's
499   /// results.
500   NewOp
501   // TODO: support values if Op vectorized to Many-Ops whose results we need to
502   // aggregate for replacement.
503 };
504 struct VectorizationResult {
505   /// Return status from vectorizing the current op.
506   enum VectorizationStatus status = VectorizationStatus::Failure;
507   /// New vectorized operation to replace the current op.
508   /// Replacement behavior is specified by `status`.
509   Operation *newOp;
510 };
511 
512 std::optional<vector::CombiningKind>
513 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
514   using ::mlir::vector::CombiningKind;
515 
516   if (!combinerOp)
517     return std::nullopt;
518   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
519       .Case<arith::AddIOp, arith::AddFOp>(
520           [&](auto op) { return CombiningKind::ADD; })
521       .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
522       .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
523       .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
524       .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
525       .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
526       .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
527       .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
528       .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
529       .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
530       .Case<arith::MulIOp, arith::MulFOp>(
531           [&](auto op) { return CombiningKind::MUL; })
532       .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
533       .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
534       .Default([&](auto op) { return std::nullopt; });
535 }
536 
537 /// Check whether `outputOperand` is a reduction with a single combiner
538 /// operation. Return the combiner operation of the reduction. Return
539 /// nullptr otherwise. Multiple reduction operations would impose an
540 /// ordering between reduction dimensions and is currently unsupported in
541 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
542 /// max(min(X))
543 // TODO: use in LinalgOp verification, there is a circular dependency atm.
544 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
545   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
546   unsigned outputPos =
547       outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
548   // Only single combiner operations are supported for now.
549   SmallVector<Operation *, 4> combinerOps;
550   if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
551       combinerOps.size() != 1)
552     return nullptr;
553 
554   // Return the combiner operation.
555   return combinerOps[0];
556 }
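// Illustrative example: for a reduction whose body computes
//   %sum = arith.addf %in, %acc : f32
//   linalg.yield %sum : f32
// matchLinalgReduction returns the arith.addf op, and getCombinerOpKind maps
// it to vector::CombiningKind::ADD.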
557 
558 /// Broadcast `value` to a vector of `dstType` if possible. Return `value`
559 /// otherwise.
560 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
561   auto dstVecType = dyn_cast<VectorType>(dstType);
562   // If there is no vector shape to broadcast to, just return `value`.
563   if (!dstVecType || dstVecType.getRank() == 0)
564     return value;
565   if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
566       vector::BroadcastableToResult::Success)
567     return value;
568   Location loc = b.getInsertionPoint()->getLoc();
569   return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
570 }
571 
572 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
573 /// assumes that `reductionOp` has two operands and one of them is the reduction
574 /// initial value.
575 // Note: this is a true builder that notifies the OpBuilder listener.
576 // TODO: Consider moving as a static helper on the ReduceOp.
577 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
578                                       Value valueToReduce, Value acc,
579                                       ArrayRef<bool> dimsToMask) {
580   auto maybeKind = getCombinerOpKind(reduceOp);
581   assert(maybeKind && "Failed precondition: could not get reduction kind");
582   return b.create<vector::MultiDimReductionOp>(
583       reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
584 }
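// Illustrative example (assumed shapes): reducing a vector<4x8xf32> along its
// trailing dimension into a vector<4xf32> accumulator builds roughly:
//   %r = vector.multi_reduction <add>, %val, %acc [1]
//       : vector<4x8xf32> to vector<4xf32>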
585 
586 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
587   return llvm::to_vector(
588       llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
589 }
590 
591 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
592 /// reduction iterator.
593 static bool hasReductionIterator(LinalgOp &op) {
594   return isa<linalg::ReduceOp>(op) ||
595          (isa<linalg::GenericOp>(op) &&
596           llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
597 }
598 
599 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
600 /// to all `0`, where `outputOperand` is an output operand of the LinalgOp
601 /// currently being vectorized. If the output has rank 0, build a 0-d write.
602 /// Return the produced value or null if no value is produced.
603 // Note: this is a true builder that notifies the OpBuilder listener.
604 // TODO: Consider moving as a static helper on the ReduceOp.
605 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
606                               OpOperand *outputOperand,
607                               VectorizationState &state) {
608   Location loc = value.getLoc();
609   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
610   AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
611 
612   // Compute the vector type of the value to store. This type should be an
613   // identity or projection of the canonical vector type without any permutation
614   // applied, given that any permutation in a transfer write happens as part of
615   // the write itself.
616   AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
617       opOperandMap.getContext(), opOperandMap.getNumInputs(),
618       [&](AffineDimExpr dimExpr) -> bool {
619         return llvm::is_contained(opOperandMap.getResults(), dimExpr);
620       });
621   auto vectorType = state.getCanonicalVecType(
622       getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
623 
624   Operation *write;
625   if (vectorType.getRank() > 0) {
626     AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
627     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
628                                rewriter.create<arith::ConstantIndexOp>(loc, 0));
629     value = broadcastIfNeeded(rewriter, value, vectorType);
630     assert(value.getType() == vectorType && "Incorrect type");
631     write = rewriter.create<vector::TransferWriteOp>(
632         loc, value, outputOperand->get(), indices, writeMap);
633   } else {
634     // 0-d case is still special: do not invert the reindexing writeMap.
635     if (!isa<VectorType>(value.getType()))
636       value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
637     assert(value.getType() == vectorType && "Incorrect type");
638     write = rewriter.create<vector::TransferWriteOp>(
639         loc, value, outputOperand->get(), ValueRange{});
640   }
641 
642   write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
643 
644   // If masked, set in-bounds to true. Masking guarantees that the access will
645   // be in-bounds.
646   if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
647     auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
648     SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
649     maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
650   }
651 
652   LDBG("vectorized op: " << *write << "\n");
653   if (!write->getResults().empty())
654     return write->getResult(0);
655   return Value();
656 }
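// Illustrative example (assumed identity output indexing map): writing a
// vector<8x4xf32> yielded value into a tensor<8x4xf32> init produces roughly:
//   %c0 = arith.constant 0 : index
//   %w  = vector.transfer_write %val, %init[%c0, %c0]
//       : vector<8x4xf32>, tensor<8x4xf32>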
657 
658 // Custom vectorization precondition function type. This is intended to be used
659 // with CustomVectorizationHook. Returns success if the corresponding custom
660 // hook can vectorize the op.
661 using CustomVectorizationPrecondition =
662     std::function<LogicalResult(Operation *, bool)>;
663 
664 // Custom vectorization function type. Produce a vector form of Operation*
665 // assuming all its vectorized operands are already in the IRMapping.
666 // Return nullptr if the Operation cannot be vectorized.
667 using CustomVectorizationHook =
668     std::function<VectorizationResult(Operation *, const IRMapping &)>;
669 
670 /// Helper function to vectorize the terminator of a `linalgOp`. New result
671 /// vector values are appended to `newResults`. Return
672 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
673 /// should not try to map produced operations and instead return the results
674 /// using the `newResults` vector, making them available to the vectorization
675 /// algorithm for RAUW. This function is meant to be used as a
676 /// CustomVectorizationHook.
677 static VectorizationResult
678 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
679                      const IRMapping &bvm, VectorizationState &state,
680                      LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
681   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
682   if (!yieldOp)
683     return VectorizationResult{VectorizationStatus::Failure, nullptr};
684   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
685     // TODO: Scan for an opportunity for reuse.
686     // TODO: use a map.
687     Value vectorValue = bvm.lookup(output.value());
688     Value newResult =
689         buildVectorWrite(rewriter, vectorValue,
690                          linalgOp.getDpsInitOperand(output.index()), state);
691     if (newResult)
692       newResults.push_back(newResult);
693   }
694 
695   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
696 }
697 
698 /// Helper function to vectorize the index operations of a `linalgOp`. Return
699 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
700 /// should map the produced operations. This function is meant to be used as a
701 /// CustomVectorizationHook.
702 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
703                                                 VectorizationState &state,
704                                                 Operation *op,
705                                                 LinalgOp linalgOp) {
706   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
707   if (!indexOp)
708     return VectorizationResult{VectorizationStatus::Failure, nullptr};
709   auto loc = indexOp.getLoc();
710   // Compute the static loop sizes of the index op.
711   ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
712   auto dim = indexOp.getDim();
713   // Compute a one-dimensional index vector for the index op dimension.
714   auto indexVectorType =
715       VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
716                       state.getScalableVecDims()[dim]);
717   auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
718   // Return the one-dimensional index vector if it lives in the trailing
719   // dimension of the iteration space since the vectorization algorithm in this
720   // case can handle the broadcast.
721   if (dim == targetShape.size() - 1)
722     return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
723   // Otherwise permute the targetShape to move the index dimension last,
724   // broadcast the one-dimensional index vector to the permuted shape, and
725   // finally transpose the broadcasted index vector to undo the permutation.
726   auto permPattern =
727       llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
728   std::swap(permPattern[dim], permPattern.back());
729   auto permMap =
730       AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
731 
732   auto broadCastOp = rewriter.create<vector::BroadcastOp>(
733       loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
734       indexSteps);
735   SmallVector<int64_t> transposition =
736       llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
737   std::swap(transposition.back(), transposition[dim]);
738   auto transposeOp =
739       rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
740   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
741 }
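// Illustrative example (assumed shapes): with canonical vector shape [4, 8],
// `linalg.index 0` (a non-trailing dim) is vectorized roughly as:
//   %step  = vector.step : vector<4xindex>
//   %bcast = vector.broadcast %step : vector<4xindex> to vector<8x4xindex>
//   %idx   = vector.transpose %bcast, [1, 0]
//       : vector<8x4xindex> to vector<4x8xindex>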
742 
743 /// Helper function to check if the tensor.extract can be vectorized by the
744 /// custom hook vectorizeTensorExtract.
745 static LogicalResult
746 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
747   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
748   if (!extractOp)
749     return failure();
750 
751   if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
752     return failure();
753 
754   // Check the index type, but only for non 0-d tensors (for which we do need
755   // access indices).
756   if (not extractOp.getIndices().empty()) {
757     if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
758       return failure();
759   }
760 
761   if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
762         return !VectorType::isValidElementType(type);
763       })) {
764     return failure();
765   }
766 
767   return success();
768 }
769 
770 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
771 /// generated from `tensor.extract`. The offset is calculated as follows
772 /// (example using scalar values):
773 ///
774 ///    offset = extractOp.indices[0]
775 ///    for (i = 1; i < numIndices; i++)
776 ///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
777 ///
778 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
779 ///  offset = ( ( 1 ) * 80 +  2 ) * 15  + 3
780 static Value calculateGatherOffset(RewriterBase &rewriter,
781                                    VectorizationState &state,
782                                    tensor::ExtractOp extractOp,
783                                    const IRMapping &bvm) {
784   // The vector of indices for GatherOp should be shaped as the output vector.
785   auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
786   auto loc = extractOp.getLoc();
787 
788   Value offset = broadcastIfNeeded(
789       rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
790 
791   const size_t numIndices = extractOp.getIndices().size();
792   for (size_t i = 1; i < numIndices; i++) {
793     Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
794 
795     auto dimSize = broadcastIfNeeded(
796         rewriter,
797         rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
798         indexVecType);
799 
800     offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
801 
802     auto extractOpIndex = broadcastIfNeeded(
803         rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
804 
805     offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
806   }
807 
808   return offset;
809 }
810 
811 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
812 
813 /// Checks whether \p val can be used for calculating a loop-invariant index.
814 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
815 
816   auto targetShape = linalgOp.getStaticLoopRanges();
817   assert(((llvm::count_if(targetShape,
818                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
819          "n-D vectors are not yet supported");
820   assert(targetShape.back() != 1 &&
821          "1-D vectors with the trailing dim equal to 1 are not yet supported");
822 
823   // Blocks outside _this_ linalg.generic are effectively loop invariant.
824   // However, analysing block arguments for _this_ linalg.generic Op is a bit
825   // tricky. Just bail out in the latter case.
826   // TODO: We could try analysing the corresponding affine map here.
827   auto *block = linalgOp.getBlock();
828   if (isa<BlockArgument>(val))
829     return llvm::all_of(block->getArguments(),
830                         [&val](Value v) { return (v != val); });
831 
832   Operation *defOp = val.getDefiningOp();
833   assert(defOp && "This is neither a block argument nor an operation result");
834 
835   // IndexOp is loop invariant as long as its result remains constant across
836   // iterations. Given the assumptions on the loop ranges above, only the
837   // trailing loop dim ever changes.
838   auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
839   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
840     return (indexOp.getDim() != trailingLoopDim);
841 
842   auto *ancestor = block->findAncestorOpInBlock(*defOp);
843 
844   // Values defined outside `linalgOp` are loop invariant.
845   if (!ancestor)
846     return true;
847 
848   // Values defined inside `linalgOp`, which are constant, are loop invariant.
849   if (isa<arith::ConstantOp>(ancestor))
850     return true;
851 
852   bool result = true;
853   for (auto op : ancestor->getOperands())
854     result &= isLoopInvariantIdx(linalgOp, op);
855 
856   return result;
857 }
858 
859 /// Check whether \p val could be used for calculating the trailing index for a
860 /// contiguous load operation.
861 ///
862 /// There are currently 3 types of values that are allowed here:
863 ///   1. loop-invariant values,
864 ///   2. values that increment by 1 with every loop iteration,
865 ///   3. results of basic arithmetic operations (linear and continuous)
866 ///      involving 1., 2. and 3.
867 /// This method returns True if indeed only such values are used in calculating
868 /// \p val.
869 ///
870 /// Additionally, the trailing index for a contiguous load operation should
871 /// increment by 1 with every loop iteration, i.e. be based on:
872 ///   * `linalg.index <dim>` ,
873 /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
874 /// updated to `true` when such an op is found.
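/// Illustrative example (with the trailing loop dim being dim 1):
///   %c1  = arith.constant 1 : index
///   %i   = linalg.index 1 : index
///   %idx = arith.addi %i, %c1 : index
/// `%idx` increments by 1 with every iteration and is a valid index for a
/// contiguous load.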
875 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
876                                 bool &foundIndexOp) {
877 
878   auto targetShape = linalgOp.getStaticLoopRanges();
879   assert(((llvm::count_if(targetShape,
880                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
881          "n-D vectors are not yet supported");
882   assert(targetShape.back() != 1 &&
883          "1-D vectors with the trailing dim 1 are not yet supported");
884 
885   // Blocks outside _this_ linalg.generic are effectively loop invariant.
886   // However, analysing block arguments for _this_ linalg.generic Op is a bit
887   // tricky. Just bail out in the latter case.
888   // TODO: We could try analysing the corresponding affine map here.
889   auto *block = linalgOp.getBlock();
890   if (isa<BlockArgument>(val))
891     return llvm::all_of(block->getArguments(),
892                         [&val](Value v) { return (v != val); });
893 
894   Operation *defOp = val.getDefiningOp();
895   assert(defOp && "This is neither a block argument nor an operation result");
896 
897   // Given the assumption on the loop ranges above, only the trailing loop
898   // index is not constant.
899   auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
900   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
901     foundIndexOp = (indexOp.getDim() == trailingLoopDim);
902     return true;
903   }
904 
905   auto *ancestor = block->findAncestorOpInBlock(*defOp);
906 
907   if (!ancestor)
908     return false;
909 
910   // Conservatively reject Ops that could lead to indices with stride other
911   // than 1.
912   if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
913     return false;
914 
915   bool result = false;
916   for (auto op : ancestor->getOperands())
917     result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
918 
919   return result;
920 }
921 
922 /// Infer the memory access pattern for the input ExtractOp
923 ///
924 /// Based on the operation shapes and indices (usually based on the iteration
925 /// space of the parent `linalgOp` operation), decides whether the input
926 /// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
927 /// gather load.
928 ///
929 /// Note that it is always safe to use gather load operations for contiguous
930 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
931 /// that `extractOp` is a gather load.
932 static VectorMemoryAccessKind
933 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
934                                     LinalgOp &linalgOp) {
935 
936   auto targetShape = linalgOp.getStaticLoopRanges();
937   auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
938 
939   // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
940   if (inputShape.getShape().empty())
941     return VectorMemoryAccessKind::ScalarBroadcast;
942 
943   // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
944   // gather load.
945   // TODO: Relax this condition.
946   if (linalgOp.hasDynamicShape())
947     return VectorMemoryAccessKind::Gather;
948 
949   // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
950   // otherwise.
951   bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
952                              return dimSize > 1;
953                            }) == 1);
954 
955   // 1. Assume that it's a gather load when reading non-1D vector.
956   if (!isOutput1DVector)
957     return VectorMemoryAccessKind::Gather;
958 
959   bool leadingIdxsLoopInvariant = true;
960 
961   // 2. Analyze the leading indices of `extractOp`.
962   // Look at the way each index is calculated and decide whether it is suitable
963   // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
964   // gather load.
965   auto indices = extractOp.getIndices();
966   auto leadIndices = indices.drop_back(1);
967 
968   for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
969     if (inputShape.getShape()[i] == 1)
970       continue;
971 
972     leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
973   }
974 
975   if (!leadingIdxsLoopInvariant) {
976     LDBG("Found gather load: " << extractOp);
977     return VectorMemoryAccessKind::Gather;
978   }
979 
980   // 3. Analyze the trailing index for `extractOp`.
981   // At this point we know that the leading indices are loop invariant. This
982   // means that it is potentially a scalar or a contiguous load. We can decide
983   // based on the trailing idx.
984   auto extractOpTrailingIdx = indices.back();
985 
986   // 3a. Scalar broadcast load
987   // If the trailing index is loop invariant then this is a scalar load.
988   if (leadingIdxsLoopInvariant &&
989       isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
990     LDBG("Found scalar broadcast load: " << extractOp);
991 
992     return VectorMemoryAccessKind::ScalarBroadcast;
993   }
994 
995   // 3b. Contiguous loads
996   // The trailing `extractOp` index should increment with every loop iteration.
997   // This effectively means that it must be based on the trailing loop index.
998   // This is what the following bool captures.
999   bool foundIndexOp = false;
1000   bool isContiguousLoad =
1001       isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
1002   isContiguousLoad &= foundIndexOp;
1003 
1004   if (isContiguousLoad) {
1005     LDBG("Found contiguous load: " << extractOp);
1006     return VectorMemoryAccessKind::Contiguous;
1007   }
1008 
1009   // 4. Fallback case - gather load.
1010   LDBG("Found gather load: " << extractOp);
1011   return VectorMemoryAccessKind::Gather;
1012 }
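// Illustrative example (assumed shapes): for
//   %v = tensor.extract %src[%c0, %i] : tensor<3x8xf32>
// inside a linalg.generic with static loop ranges [1, 8], where %c0 is a
// constant and %i is `linalg.index 1`, the leading index is loop invariant
// and the trailing index increments by 1, so the access is classified as
// Contiguous.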
1013 
1014 /// Helper function to vectorize the tensor.extract operations. Returns
1015 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1016 /// should map the produced operations. This function is meant to be used as a
1017 /// CustomVectorizationHook.
1018 static VectorizationResult
1019 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1020                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1021   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1022   if (!extractOp)
1023     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1024   auto loc = extractOp.getLoc();
1025 
1026   // Compute the static loop sizes of the extract op.
1027   auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1028   auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1029       loc,
1030       DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1031                                 /*value=*/true));
1032   auto passThruConstantOp =
1033       rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1034 
1035   // Base indices are currently set to 0. We will need to re-visit if more
1036   // generic scenarios are to be supported.
1037   SmallVector<Value> baseIndices(
1038       extractOp.getIndices().size(),
1039       rewriter.create<arith::ConstantIndexOp>(loc, 0));
1040 
1041   VectorMemoryAccessKind memAccessKind =
1042       getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
1043 
1044   // 1. Handle gather access
1045   if (memAccessKind == VectorMemoryAccessKind::Gather) {
1046     Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1047 
1048     // Generate the gather load
1049     Operation *gatherOp = rewriter.create<vector::GatherOp>(
1050         loc, resultType, extractOp.getTensor(), baseIndices, offset,
1051         maskConstantOp, passThruConstantOp);
1052     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1053 
1054     LDBG("Vectorised as gather load: " << extractOp << "\n");
1055     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1056   }
1057 
1058   // 2. Handle:
1059   //  a. scalar loads + broadcast,
1060   //  b. contiguous loads.
1061   // Both cases use vector.transfer_read.
1062 
1063   // Collect indices for `vector.transfer_read`. At this point, the indices will
1064   // either be scalars or would have been broadcast to vectors matching the
1065   // result type. For indices that are vectors, there are two options:
1066   //    * for non-trailing indices, all elements are identical (contiguous
1067   //      loads are identified by looking for non-trailing indices that are
1068   //      invariant with respect to the corresponding linalg.generic), or
1069   //    * for trailing indices, the index vector will contain values with stride
1070   //      one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1071   //      needed.
1072   // This means that
1073   //   * for scalar indices - just re-use it,
1074   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1075   //    (0th) element and use that.
1076   SmallVector<Value> transferReadIdxs;
1077   auto zero = rewriter.create<arith::ConstantOp>(
1078       loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
1079   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1080     Value idx = bvm.lookup(extractOp.getIndices()[i]);
1081     if (idx.getType().isIndex()) {
1082       transferReadIdxs.push_back(idx);
1083       continue;
1084     }
1085 
1086     auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1087         loc,
1088         VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1089                         resultType.getScalableDims().back()),
1090         idx);
1091     transferReadIdxs.push_back(
1092         rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1093   }
1094 
1095   // `tensor.extract` is always in-bounds, hence the following holds.
1096   auto dstRank = resultType.getRank();
1097   auto srcRank = extractOp.getTensor().getType().getRank();
1098   SmallVector<bool> inBounds(dstRank, true);
1099 
1100   // 2a. Handle scalar broadcast access.
1101   if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1102     MLIRContext *ctx = rewriter.getContext();
1103     SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1104     auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1105 
1106     auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1107         loc, resultType, extractOp.getTensor(), transferReadIdxs,
1108         permutationMap, inBounds);
1109 
1110     LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1111     return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1112   }
1113 
1114   // 2b. Handle contiguous access.
1115   auto permutationMap = AffineMap::getMinorIdentityMap(
1116       srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1117 
1118   int32_t rankDiff = dstRank - srcRank;
1119   // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1120   // dims so that the ranks match. This is done by extending the map with 0s.
1121   // For example, for dstRank = 3, srcRank = 2, the following map created
1122   // above:
1123   //    (d0, d1) --> (d0, d1)
1124   // is extended as:
1125   //    (d0, d1) --> (0, d0, d1)
1126   while (rankDiff > 0) {
1127     permutationMap = permutationMap.insertResult(
1128         mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1129     rankDiff--;
1130   }
1131 
1132   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1133       loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1134       inBounds);
1135 
1136   LDBG("Vectorised as contiguous load: " << extractOp);
1137   return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1138 }
1139 
1140 /// Emit reduction operations if the shape of the value to reduce is different
1141 /// from the result shape.
1142 // Note: this is a true builder that notifies the OpBuilder listener.
1143 // TODO: Consider moving as a static helper on the ReduceOp.
1144 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1145                                  Value reduceValue, Value initialValue,
1146                                  const IRMapping &bvm) {
1147   Value reduceVec = bvm.lookup(reduceValue);
1148   Value outputVec = bvm.lookup(initialValue);
1149   auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1150   auto outputType = dyn_cast<VectorType>(outputVec.getType());
1151   // Reduce only if needed as the value may already have been reduced for
1152   // contraction vectorization.
1153   if (!reduceType ||
1154       (outputType && reduceType.getShape() == outputType.getShape()))
1155     return nullptr;
1156   SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1157   return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1158 }
1159 
1160 /// Generic vectorization for a single operation `op`, given already vectorized
1161 /// operands carried by `bvm`. Vectorization occurs as follows:
1162 ///   1. Try to apply any of the `customVectorizationHooks` and return its
1163 ///   result on success.
1164 ///   2. Clone any constant in the current scope without vectorization: each
1165 ///   consumer of the constant will later determine the shape to which the
1166 ///   constant needs to be broadcast.
1167 ///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1168 ///   of the `customVectorizationHooks` to cover such cases.
1169 ///   4. Clone `op` in vector form to a vector of shape prescribed by the first
1170 ///   operand of maximal rank. Other operands have smaller rank and are
1171 ///   broadcast accordingly. It is assumed this broadcast is always legal,
1172 ///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
1173 ///
1174 /// This function assumes all operands of `op` have been vectorized and are in
1175 /// the `bvm` mapping. As a consequence, this function is meant to be called  on
1176 /// a topologically-sorted list of ops.
1177 /// This function does not update `bvm` but returns a VectorizationStatus that
1178 /// instructs the caller what `bvm` update needs to occur.
1179 static VectorizationResult
1180 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1181                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1182                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1183   LDBG("vectorize op " << *op << "\n");
1184 
1185   // 1. Try to apply any CustomVectorizationHook.
1186   if (!customVectorizationHooks.empty()) {
1187     for (auto &customFunc : customVectorizationHooks) {
1188       VectorizationResult result = customFunc(op, bvm);
1189       if (result.status == VectorizationStatus::Failure)
1190         continue;
1191       return result;
1192     }
1193   }
1194 
1195   // 2. Constant ops don't get vectorized but rather broadcast at their users.
1196   // Clone so that the constant is not confined to the linalgOp block.
1197   if (isa<arith::ConstantOp, func::ConstantOp>(op))
1198     return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1199 
1200   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1201   if (!OpTrait::hasElementwiseMappableTraits(op))
1202     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1203 
1204   // 4 . Check if the operation is a reduction.
1205   SmallVector<std::pair<Value, Value>> reductionOperands;
1206   for (Value operand : op->getOperands()) {
1207     auto blockArg = dyn_cast<BlockArgument>(operand);
1208     if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1209         blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1210       continue;
1211     SmallVector<Operation *> reductionOps;
1212     Value reduceValue = matchReduction(
1213         linalgOp.getRegionOutputArgs(),
1214         blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1215     if (!reduceValue)
1216       continue;
1217     reductionOperands.push_back(std::make_pair(reduceValue, operand));
1218   }
1219   if (!reductionOperands.empty()) {
1220     assert(reductionOperands.size() == 1);
1221     Operation *reduceOp =
1222         reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1223                        reductionOperands[0].second, bvm);
1224     if (reduceOp)
1225       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1226   }
1227 
1228   // 5. Generic vectorization path for ElementwiseMappable ops.
1229   //   a. Get the first max ranked shape.
1230   VectorType firstMaxRankedType;
1231   for (Value operand : op->getOperands()) {
1232     auto vecOperand = bvm.lookup(operand);
1233     assert(vecOperand && "Vector operand couldn't be found");
1234 
1235     auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1236     if (vecType && (!firstMaxRankedType ||
1237                     firstMaxRankedType.getRank() < vecType.getRank()))
1238       firstMaxRankedType = vecType;
1239   }
1240   //   b. Broadcast each operand if needed.
1241   SmallVector<Value> vecOperands;
1242   for (Value scalarOperand : op->getOperands()) {
1243     Value vecOperand = bvm.lookup(scalarOperand);
1244     assert(vecOperand && "Vector operand couldn't be found");
1245 
1246     if (firstMaxRankedType) {
1247       auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1248                                      getElementTypeOrSelf(vecOperand.getType()),
1249                                      firstMaxRankedType.getScalableDims());
1250       vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1251     } else {
1252       vecOperands.push_back(vecOperand);
1253     }
1254   }
1255   //   c. For elementwise ops, the result is the vector with the first max ranked shape.
1256   SmallVector<Type> resultTypes;
1257   for (Type resultType : op->getResultTypes()) {
1258     resultTypes.push_back(
1259         firstMaxRankedType
1260             ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1261                               firstMaxRankedType.getScalableDims())
1262             : resultType);
1263   }
1264   //   d. Build and return the new op.
1265   return VectorizationResult{
1266       VectorizationStatus::NewOp,
1267       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1268                       resultTypes, op->getAttrs())};
1269 }
1270 
1271 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1272 /// vector form. Generic vectorization proceeds as follows:
1273 ///   1. Verify the `linalgOp` has one non-empty region.
1274 ///   2. Values defined above the region are mapped to themselves and will be
1275 ///   broadcasted on a per-need basis by their consumers.
1276 ///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1277 ///   load).
1278 ///   TODO: Reuse opportunities for RAR dependencies.
1279 ///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
1280 ///   4b. Register CustomVectorizationHook for IndexOp to access the
1281 ///   iteration indices.
1282 ///   5. Iteratively call vectorizeOneOp on the region operations.
1283 ///
1284 /// Eager broadcasting is performed to the maximal common vector size implied
1285 /// by the `linalgOp` iteration space. This eager broadcasting is introduced in
1286 /// the
1287 /// permutation_map of the vector.transfer_read operations. The eager
1288 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1289 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1290 /// the absence of good canonicalizations, the amount of work increases.
1291 /// This is not deemed a problem as we expect canonicalizations and foldings to
1292 /// aggressively clean up the useless work.
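///
/// As a rough illustration (a simplified sketch, not verified compiler output;
/// value names are hypothetical), an elementwise linalg.generic that adds two
/// tensor<8xf32> values would be rewritten to something like:
///
///   %lhs = vector.transfer_read %in0[%c0], %pad : tensor<8xf32>, vector<8xf32>
///   %rhs = vector.transfer_read %in1[%c0], %pad : tensor<8xf32>, vector<8xf32>
///   %add = arith.addf %lhs, %rhs : vector<8xf32>
///   %res = vector.transfer_write %add, %init[%c0]
///       : vector<8xf32>, tensor<8xf32>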
1293 static LogicalResult
1294 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1295                          LinalgOp linalgOp,
1296                          SmallVectorImpl<Value> &newResults) {
1297   LDBG("Vectorizing operation as linalg generic\n");
1298   Block *block = linalgOp.getBlock();
1299 
1300   // 2. Values defined above the region can only be broadcast for now. Make them
1301   // map to themselves.
1302   IRMapping bvm;
1303   SetVector<Value> valuesSet;
1304   mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1305   bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1306 
1307   if (linalgOp.getNumDpsInits() == 0)
1308     return failure();
1309 
1310   // 3. Turn all BBArgs into vector.transfer_read / load.
1311   Location loc = linalgOp.getLoc();
1312   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1313   for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1314     BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1315     if (linalgOp.isScalar(opOperand)) {
1316       bvm.map(bbarg, opOperand->get());
1317       continue;
1318     }
1319 
1320     // 3.a. Convert the indexing map for this input/output to a transfer read
1321     // permutation map and masking map.
1322     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1323 
1324     // Remove zeros from indexing map to use it as masking map.
1325     SmallVector<int64_t> zeroPos;
1326     auto results = indexingMap.getResults();
1327     for (const auto &result : llvm::enumerate(results)) {
1328       if (isa<AffineConstantExpr>(result.value())) {
1329         zeroPos.push_back(result.index());
1330       }
1331     }
1332     AffineMap maskingMap = indexingMap.dropResults(zeroPos);
1333 
1334     AffineMap readMap;
1335     VectorType readType;
1336     Type elemType = getElementTypeOrSelf(opOperand->get());
1337     if (linalgOp.isDpsInput(opOperand)) {
1338       // 3.a.i. For input reads we use the canonical vector shape.
1339       readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1340       readType = state.getCanonicalVecType(elemType);
1341     } else {
1342       // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1343       // reductions), the vector shape is computed by mapping the canonical
1344       // vector shape to the output domain and back to the canonical domain.
1345       readMap = inversePermutation(reindexIndexingMap(indexingMap));
1346       readType =
1347           state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1348     }
1349 
1350     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1351 
1352     // Make sure that the in_bounds attribute corresponding to a broadcast dim
1353     // is `true`
1354     SmallVector<unsigned> broadcastedDims = readMap.getBroadcastDims();
1355     SmallVector<bool> inBounds(readType.getRank(), false);
1356 
1357     for (auto idx : broadcastedDims)
1358       inBounds[idx] = true;
1359 
1360     Operation *read = rewriter.create<vector::TransferReadOp>(
1361         loc, readType, opOperand->get(), indices, readMap,
1362         ArrayRef<bool>(inBounds));
1363     read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1364     Value readValue = read->getResult(0);
1365 
1366     // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1367     // will be in-bounds.
1368     if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1369       SmallVector<bool> inBounds(readType.getRank(), true);
1370       cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1371           .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1372     }
1373 
1374     // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1375     // TODO: remove this.
1376     if (readType.getRank() == 0)
1377       readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1378 
1379     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1380                                  << "\n");
1381     bvm.map(bbarg, readValue);
1382     bvm.map(opOperand->get(), readValue);
1383   }
1384 
1385   SmallVector<CustomVectorizationHook> hooks;
1386   // 4a. Register CustomVectorizationHook for yieldOp.
1387   CustomVectorizationHook vectorizeYield =
1388       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1389     return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1390   };
1391   hooks.push_back(vectorizeYield);
1392 
1393   // 4b. Register CustomVectorizationHook for indexOp.
1394   CustomVectorizationHook vectorizeIndex =
1395       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1396     return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1397   };
1398   hooks.push_back(vectorizeIndex);
1399 
1400   // 4c. Register CustomVectorizationHook for extractOp.
1401   CustomVectorizationHook vectorizeExtract =
1402       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1403     return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1404   };
1405   hooks.push_back(vectorizeExtract);
1406 
1407   // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
1408   for (Operation &op : block->getOperations()) {
1409     VectorizationResult result =
1410         vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1411     if (result.status == VectorizationStatus::Failure) {
1412       LDBG("failed to vectorize: " << op << "\n");
1413       return failure();
1414     }
1415     if (result.status == VectorizationStatus::NewOp) {
1416       Operation *maybeMaskedOp =
1417           state.maskOperation(rewriter, result.newOp, linalgOp);
1418       LDBG("New vector op: " << *maybeMaskedOp << "\n");
1419       bvm.map(op.getResults(), maybeMaskedOp->getResults());
1420     }
1421   }
1422 
1423   return success();
1424 }
1425 
1426 /// Given a tensor::PackOp, return the `dest` shape before any packing
1427 /// permutations.
1428 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1429                                               ArrayRef<int64_t> destShape) {
1430   return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1431 }
1432 
1433 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1434 /// create an empty destination tensor and create a TransferWriteOp from the
1435 /// input to the empty tensor. If the destination shape is not the same as the
1436 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1437 /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1438 /// inBounds attribute of the transfer write op instead of masking.
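///
/// Sketch of the masked case (illustrative only, not verified output; assumes
/// a 1-D destination with dynamic size %d and inputVectorSizes = [8]):
///
///   %empty = tensor.empty(%d) : tensor<?xf32>
///   %mask  = vector.create_mask %d : vector<8xi1>
///   %write = vector.mask %mask {
///     vector.transfer_write %input, %empty[%c0] {in_bounds = [true]}
///         : vector<8xf32>, tensor<?xf32>
///   } : vector<8xi1> -> tensor<?xf32>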
1439 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1440                                            Value input,
1441                                            SmallVector<OpFoldResult> destSizes,
1442                                            ArrayRef<int64_t> inputVectorSizes,
1443                                            bool useInBoundsInsteadOfMasking) {
1444 
1445   auto inputType = cast<VectorType>(input.getType());
1446   Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1447                                                inputType.getElementType());
1448   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1449   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1450   auto destShape = cast<ShapedType>(dest.getType()).getShape();
1451   SmallVector<bool> inBoundsVal(rank, true);
1452   if (useInBoundsInsteadOfMasking) {
1453     // Update the inBounds attribute.
1454     for (unsigned i = 0; i < rank; i++)
1455       inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1456                        !ShapedType::isDynamic(destShape[i]);
1457   }
1458   Operation *write = builder.create<vector::TransferWriteOp>(
1459       loc,
1460       /*vector=*/input,
1461       /*source=*/dest,
1462       /*indices=*/SmallVector<Value>(rank, zero),
1463       /*inBounds=*/inBoundsVal);
1464   assert(llvm::none_of(
1465              destShape.drop_front(inputVectorSizes.size()),
1466              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1467          "Only dims aligned with inputVectorSizes may be dynamic");
1468   if (useInBoundsInsteadOfMasking)
1469     return write;
1470   bool needMaskForWrite = !llvm::equal(
1471       inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1472   if (needMaskForWrite) {
1473     SmallVector<int64_t> writeMaskShape;
1474     writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1475     writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1476                           destShape.end());
1477     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1478     Value maskForWrite =
1479         builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1480     write = mlir::vector::maskOperation(builder, write, maskForWrite);
1481   }
1482   return write;
1483 }
1484 
1485 /// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1486 /// padding value and (3) input vector sizes into:
1487 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1488 /// As in the following example:
1489 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1490 ///     into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1491 ///
1492 /// This pack would be vectorized to:
1493 ///
1494 /// %load = vector.mask %mask {
1495 ///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1496 ///         {in_bounds = [true, true, true]} :
1497 ///         tensor<32x7x16xf32>, vector<32x8x16xf32>
1498 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1499 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1500 ///                                         to vector<32x4x2x1x16xf32>
1501 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1502 ///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1503 /// %write = vector.transfer_write %transpose,
1504 ///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1505 ///     {in_bounds = [true, true, true, true, true]}
1506 ///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1507 ///
1508 /// If the (3) input vector sizes are not provided, the vector sizes are
1509 /// determined by the result tensor shape. Also, we update the inBounds
1510 /// attribute instead of masking.
1511 static LogicalResult
1512 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1513                         ArrayRef<int64_t> inputVectorSizes,
1514                         SmallVectorImpl<Value> &newResults) {
1515   OpBuilder::InsertionGuard g(rewriter);
1516   rewriter.setInsertionPoint(packOp);
1517 
1518   Location loc = packOp.getLoc();
1519   auto padValue = packOp.getPaddingValue();
1520   if (!padValue) {
1521     padValue = rewriter.create<arith::ConstantOp>(
1522         loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1523   }
1524   ReifiedRankedShapedTypeDims reifiedReturnShapes;
1525   LogicalResult status =
1526       cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1527           .reifyResultShapes(rewriter, reifiedReturnShapes);
1528   (void)status; // prevent unused variable warning on non-assert builds.
1529   assert(succeeded(status) && "failed to reify result shapes");
1530 
1531   // If the input vector sizes are not provided, then the vector sizes are
1532   // determined by the result tensor shape and, instead of masking, we update
1533   // the inBounds attribute.
1534   bool useInBoundsInsteadOfMasking = false;
1535   if (inputVectorSizes.empty()) {
1536     ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1537     inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1538     useInBoundsInsteadOfMasking = true;
1539   }
1540 
1541   // Create masked TransferReadOp.
1542   SmallVector<int64_t> inputShape(inputVectorSizes);
1543   auto innerTiles = packOp.getStaticInnerTiles();
1544   auto innerDimsPos = packOp.getInnerDimsPos();
1545   auto outerDimsPerm = packOp.getOuterDimsPerm();
1546   if (!outerDimsPerm.empty())
1547     applyPermutationToVector(inputShape,
1548                              invertPermutationVector(outerDimsPerm));
1549   for (auto [idx, size] : enumerate(innerTiles))
1550     inputShape[innerDimsPos[idx]] *= size;
1551   auto maskedRead = vector::createReadOrMaskedRead(
1552       rewriter, loc, packOp.getSource(), inputShape, padValue,
1553       useInBoundsInsteadOfMasking);
1554 
1555   // Create ShapeCastOp.
1556   SmallVector<int64_t> destShape(inputVectorSizes);
1557   destShape.append(innerTiles.begin(), innerTiles.end());
1558   auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1559                                        packOp.getDestType().getElementType());
1560   auto shapeCastOp =
1561       rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1562 
1563   // Create TransposeOp.
1564   auto destPermutation =
1565       invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1566   auto transposeOp = rewriter.create<vector::TransposeOp>(
1567       loc, shapeCastOp.getResult(), destPermutation);
1568 
1569   // Create TransferWriteOp.
1570   Operation *write = createWriteOrMaskedWrite(
1571       rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1572       inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
1573   newResults.push_back(write->getResult(0));
1574   return success();
1575 }
1576 
1577 /// Vectorize a `tensor::UnPackOp` to these 4 Ops:
1578 ///   vector::TransferReadOp - Reads a vector from the source tensor.
1579 ///   vector::TransposeOp - Transposes the read vector.
1580 ///   vector::ShapeCastOp - Reshapes the data based on the target.
1581 ///   vector::TransferWriteOp - Writes the result vector back to the
1582 ///   destination tensor.
1583 /// If the vector sizes are not provided:
1584 ///   * the vector sizes are determined by the input operand and attributes,
1585 ///   * the inBounds attribute is updated instead of masking.
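///
/// Sketch (illustrative only, hand-written and not verified output) for
///   %unpack = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [16]
///       into %dst : tensor<4x16xf32> -> tensor<64xf32>
/// the generated IR is roughly:
///
///   %read  = vector.transfer_read %src[%c0, %c0], %cst
///       : tensor<4x16xf32>, vector<4x16xf32>
///   %tr    = vector.transpose %read, [0, 1]
///       : vector<4x16xf32> to vector<4x16xf32>
///   %cast  = vector.shape_cast %tr : vector<4x16xf32> to vector<64xf32>
///   %empty = tensor.empty() : tensor<64xf32>
///   %write = vector.transfer_write %cast, %empty[%c0]
///       : vector<64xf32>, tensor<64xf32>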
1586 static LogicalResult
1587 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1588                           ArrayRef<int64_t> inputVectorSizes,
1589                           SmallVectorImpl<Value> &newResults) {
1590 
1591   OpBuilder::InsertionGuard g(rewriter);
1592   rewriter.setInsertionPoint(unpackOp);
1593 
1594   RankedTensorType unpackTensorType = unpackOp.getSourceType();
1595 
1596   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1597   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1598   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1599   bool useInBoundsInsteadOfMasking = false;
1600   ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1601 
1602   auto destSize = unpackOp.getDestRank();
1603 
1604   if (!inputVectorSizes.empty())
1605     assert(inputVectorSizes.size() == destSize &&
1606            "Incorrect number of input vector sizes");
1607 
1608   // vectorSizes is the shape of the vector that will be used to do the final
1609   // write on the destination tensor. It is set like this: Let's say the
1610   // source tensor is rank 'M' and the dest tensor is rank 'N', where N <= M.
1611   // Thus:
1612   // 1. vectorSizes = sourceShape.take_front(N)
1613   // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1614   // 3. multiply all the locations in vectorSizes pointed to by innerDimPos by
1615   //    the innerTiles attribute value.
1616   SmallVector<int64_t> vectorSizes(inputVectorSizes);
1617   if (vectorSizes.empty()) {
1618     llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1619     if (!outerDimsPerm.empty())
1620       applyPermutationToVector(vectorSizes, outerDimsPerm);
1621     for (auto [i, pos] : llvm::enumerate(innerDimPos))
1622       vectorSizes[pos] *= innerTiles[i];
1623 
1624     useInBoundsInsteadOfMasking = true;
1625   }
1626 
1627   // readVectorSizes is the size of the vector used for the read and the mask.
1628   // It is set like this: Let's say the vectorSizes (VS) array has size 'N',
1629   // the sourceShape (SS) has size 'M' where M >= N, and the InnerTileSizes
1630   // (IT) have size M-N.
1631   // Thus:
1632   // - initially: readVectorSizes = vectorSizes
1633   // - Divide all the readVectorSizes locations pointed to by innerDimPos
1634   //   by the innerTiles attribute value.
1635   // - If outer_dims_perm is present: apply that permutation to readVectorSizes.
1636   // - Append the remaining shape from SS.
1637   // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
1638   // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes = [512,
1639   // 128] and outer_dims_perm = [1, 0]; then the read shape is:
1640   //   readVectorSizes(initial): [512, 128]
1641   //   After the innerDim adjustment: [512/32, 128/16]
1642   //                                  = [16, 8]
1643   //   After applying outer_dims_perm: [8, 16]
1644   //   After appending the rest of the sourceShape: [8, 16, 32, 16]
1645 
1646   SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1647 
1648   for (auto [index, size] : enumerate(innerTiles)) {
1649     readVectorSizes[innerDimPos[index]] =
1650         llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1651   }
1652   if (!outerDimsPerm.empty()) {
1653     applyPermutationToVector(readVectorSizes, outerDimsPerm);
1654   }
1655   readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1656                          sourceShape.end());
1657 
1658   ReifiedRankedShapedTypeDims reifiedRetShapes;
1659   LogicalResult status =
1660       cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1661           .reifyResultShapes(rewriter, reifiedRetShapes);
1662   if (status.failed()) {
1663     LDBG("Unable to reify result shapes of " << unpackOp);
1664     return failure();
1665   }
1666   Location loc = unpackOp->getLoc();
1667 
1668   auto padValue = rewriter.create<arith::ConstantOp>(
1669       loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1670 
1671   // Read result, mask if necessary. If transferReadOp shape is not equal
1672   // to shape of source, then a mask is necessary.
1673   Value readResult = vector::createReadOrMaskedRead(
1674       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1675       /*useInBoundsInsteadOfMasking=*/false);
1676 
1677   PackingMetadata packMetadata;
1678   SmallVector<int64_t> lastDimToInsertPosPerm =
1679       tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1680   ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1681   SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1682   mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1683   applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1684   RankedTensorType stripMineTensorType =
1685       RankedTensorType::get(stripMineShape, stripMineElemType);
1686   // Transpose the appropriate rows to match output.
1687   vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1688       loc, readResult, lastDimToInsertPosPerm);
1689 
1690   // Collapse the vector to the size required by result.
1691   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1692       stripMineTensorType, packMetadata.reassociations);
1693   mlir::VectorType vecCollapsedType =
1694       VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1695   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1696       loc, vecCollapsedType, transposeOp->getResult(0));
1697 
1698   // writeVectorSizes has to match the shape_cast result shape for dynamic
1699   // sizes, otherwise the verifier complains that the mask size is invalid.
1700   SmallVector<int64_t> writeVectorSizes(
1701       unpackOp.getDestType().hasStaticShape()
1702           ? vectorSizes
1703           : shapeCastOp.getResultVectorType().getShape());
1704   Operation *write = createWriteOrMaskedWrite(
1705       rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1706       writeVectorSizes, useInBoundsInsteadOfMasking);
1707   newResults.push_back(write->getResult(0));
1708   return success();
1709 }
1710 
1711 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1712 /// and (3) all-zero lowPad to
1713 ///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
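///
/// For illustration (a hand-written sketch, not verified output; names are
/// hypothetical): padding a tensor<1x2xf32> source to a tensor<2x4xf32> result
/// with inputVectorSizes = [2, 4] becomes roughly:
///
///   %read  = vector.mask %mask {
///     vector.transfer_read %src[%c0, %c0], %pad_value
///         : tensor<1x2xf32>, vector<2x4xf32>
///   } : vector<2x4xi1> -> vector<2x4xf32>
///   %write = vector.transfer_write %read, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>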
1714 static LogicalResult
1715 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1716                        ArrayRef<int64_t> inputVectorSizes,
1717                        SmallVectorImpl<Value> &newResults) {
1718   auto padValue = padOp.getConstantPaddingValue();
1719   Location loc = padOp.getLoc();
1720 
1721   // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1722   OpBuilder::InsertionGuard g(rewriter);
1723   rewriter.setInsertionPoint(padOp);
1724 
1725   ReifiedRankedShapedTypeDims reifiedReturnShapes;
1726   LogicalResult status =
1727       cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1728           .reifyResultShapes(rewriter, reifiedReturnShapes);
1729   (void)status; // prevent unused variable warning on non-assert builds
1730   assert(succeeded(status) && "failed to reify result shapes");
1731   auto maskedRead = vector::createReadOrMaskedRead(
1732       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1733       /*useInBoundsInsteadOfMasking=*/false);
1734   Operation *write = createWriteOrMaskedWrite(
1735       rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1736       /*useInBoundsInsteadOfMasking=*/false);
1737   newResults.push_back(write->getResult(0));
1738   return success();
1739 }
1740 
1741 // TODO: probably need some extra checks for reduction followed by consumer
1742 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1743 static LogicalResult reductionPreconditions(LinalgOp op) {
1744   if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1745     LDBG("reduction precondition failed: no reduction iterator\n");
1746     return failure();
1747   }
1748   for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1749     AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1750     if (indexingMap.isPermutation())
1751       continue;
1752 
1753     Operation *reduceOp = matchLinalgReduction(&opOperand);
1754     if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1755       LDBG("reduction precondition failed: reduction detection failed\n");
1756       return failure();
1757     }
1758   }
1759   return success();
1760 }
1761 
1762 static LogicalResult
1763 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1764                                    bool flatten1DDepthwiseConv) {
1765   if (flatten1DDepthwiseConv) {
1766     LDBG("Vectorization of flattened convs with dynamic shapes is not "
1767          "supported\n");
1768     return failure();
1769   }
1770 
1771   if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1772     LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1773     return failure();
1774   }
1775 
1776   // Support dynamic shapes in 1D depthwise convolution, but only in the
1777   // _channel_ dimension.
1778   Value lhs = conv.getDpsInputOperand(0)->get();
1779   ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1780   auto shapeWithoutCh = lhsShape.drop_back(1);
1781   if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1782     LDBG("Dynamically-shaped op vectorization precondition failed: only "
1783          "channel dim can be dynamic\n");
1784     return failure();
1785   }
1786 
1787   return success();
1788 }
1789 
1790 static LogicalResult
1791 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1792                                      bool flatten1DDepthwiseConv) {
1793   if (isa<ConvolutionOpInterface>(op.getOperation()))
1794     return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1795 
1796   if (hasReductionIterator(op))
1797     return reductionPreconditions(op);
1798 
1799   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1800   // linalg.copy ops and ops that implement ContractionOpInterface for now.
1801   if (!isElementwise(op) &&
1802       !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1803           op.getOperation()))
1804     return failure();
1805 
1806   LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1807   return success();
1808 }
1809 
1810 /// Preconditions for tensor::UnPackOp vectorization: the inner tiles must be static/constant.
1811 static LogicalResult
1812 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1813                               ArrayRef<int64_t> inputVectorSizes) {
1814 
1815   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1816         return !getConstantIntValue(res).has_value();
1817       })) {
1818     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1819     return failure();
1820   }
1821   ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1822   bool satisfyEmptyCond = inputVectorSizes.empty() &&
1823                           unpackOp.getDestType().hasStaticShape() &&
1824                           unpackOp.getSourceType().hasStaticShape();
1825   if (!satisfyEmptyCond &&
1826       failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1827     return failure();
1828 
1829   return success();
1830 }
1831 
1832 static LogicalResult vectorizeLinalgOpPrecondition(
1833     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1834     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
1835   // A tensor with a dimension of 0 cannot be vectorized.
1836   if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1837     return failure();
1838   // Check API contract for input vector sizes.
1839   if (!inputVectorSizes.empty() &&
1840       failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1841                                               inputVectorSizes)))
1842     return failure();
1843 
1844   if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1845                                         linalgOp, flatten1DDepthwiseConv))) {
1846     LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1847     return failure();
1848   }
1849 
1850   SmallVector<CustomVectorizationPrecondition> customPreconditions;
1851 
1852   // Register CustomVectorizationPrecondition for extractOp.
1853   customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1854 
1855   // All types in the body should be a supported element type for VectorType.
1856   for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1857     // Check if any custom hook can vectorize the inner op.
1858     if (llvm::any_of(
1859             customPreconditions,
1860             [&](const CustomVectorizationPrecondition &customPrecondition) {
1861               return succeeded(
1862                   customPrecondition(&innerOp, vectorizeNDExtract));
1863             })) {
1864       continue;
1865     }
1866     if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1867           return !VectorType::isValidElementType(type);
1868         })) {
1869       return failure();
1870     }
1871     if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1872           return !VectorType::isValidElementType(type);
1873         })) {
1874       return failure();
1875     }
1876   }
1877   if (isElementwise(linalgOp))
1878     return success();
1879 
1880   // TODO: isaConvolutionOpInterface that can also infer from generic
1881   // features. But we will still need stride/dilation attributes that will be
1882   // annoying to reverse-engineer...
1883   if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1884     return success();
1885   // TODO: the common vector shape is equal to the static loop sizes only when
1886   // all indexing maps are projected permutations. For convs and stencils the
1887   // logic will need to evolve.
1888   if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1889     LDBG("precondition failed: not projected permutations\n");
1890     return failure();
1891   }
1892   if (failed(reductionPreconditions(linalgOp))) {
1893     LDBG("precondition failed: reduction preconditions\n");
1894     return failure();
1895   }
1896   return success();
1897 }
1898 
1899 static LogicalResult
1900 vectorizePackOpPrecondition(tensor::PackOp packOp,
1901                             ArrayRef<int64_t> inputVectorSizes) {
1902   auto padValue = packOp.getPaddingValue();
1903   Attribute cstAttr;
1904   if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1905     LDBG("pad value is not constant: " << packOp << "\n");
1906     return failure();
1907   }
1908   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1909   bool satisfyEmptyCond = true;
1910   if (inputVectorSizes.empty()) {
1911     if (!packOp.getDestType().hasStaticShape() ||
1912         !packOp.getSourceType().hasStaticShape())
1913       satisfyEmptyCond = false;
1914   }
1915 
1916   if (!satisfyEmptyCond &&
1917       failed(vector::isValidMaskedInputVector(
1918           resultTensorShape.take_front(packOp.getSourceRank()),
1919           inputVectorSizes)))
1920     return failure();
1921 
1922   if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1923         return !getConstantIntValue(v).has_value();
1924       })) {
1925     LDBG("inner_tiles must be constant: " << packOp << "\n");
1926     return failure();
1927   }
1928 
1929   return success();
1930 }
1931 
1932 static LogicalResult
1933 vectorizePadOpPrecondition(tensor::PadOp padOp,
1934                            ArrayRef<int64_t> inputVectorSizes) {
1935   auto padValue = padOp.getConstantPaddingValue();
1936   if (!padValue) {
1937     LDBG("pad value is not constant: " << padOp << "\n");
1938     return failure();
1939   }
1940 
1941   ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1942   if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1943                                               inputVectorSizes)))
1944     return failure();
1945 
1946   if (llvm::any_of(padOp.getLow(), [](Value v) {
1947         std::optional<int64_t> res = getConstantIntValue(v);
1948         return !res.has_value() || res.value() != 0;
1949       })) {
1950     LDBG("low pad must all be zero: " << padOp << "\n");
1951     return failure();
1952   }
1953 
1954   return success();
1955 }
1956 
1957 /// Preconditions for scalable vectors. This is quite restrictive - it models
1958 /// the fact that in practice we would only make selected dimensions scalable.
1959 static LogicalResult
1960 vectorizeScalableVectorPrecondition(Operation *op,
1961                                     ArrayRef<int64_t> inputVectorSizes,
1962                                     ArrayRef<bool> inputScalableVecDims) {
1963   assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1964          "Number of input vector sizes and scalable dims doesn't match");
1965 
1966   size_t numOfScalableDims =
1967       llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
1968 
1969   if (numOfScalableDims == 0)
1970     return success();
1971 
1972   auto linalgOp = dyn_cast<LinalgOp>(op);
1973 
1974   // Cond 1: There's been no need for scalable vectorization of
1975   // non-linalg Ops so far.
1976   if (!linalgOp)
1977     return failure();
1978 
1979   // Cond 2: There's been no need for more than 2 scalable dims so far
1980   if (numOfScalableDims > 2)
1981     return failure();
1982 
1983   // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
1984   // it matches one of the supported cases:
1985   //  1. exactly 1 dim is scalable and that's the _last_ parallel dim
1986   //  2. exactly 2 dims are scalable and those are the _last two adjacent_
1987   //     parallel dims
1988   //  3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
1989   // The 2nd restriction above means that only Matmul-like Ops are supported
1990   // when 2 dims are scalable, e.g.:
1991   //    * iterators = [parallel, parallel, reduction]
1992   //    * scalable flags = [true, true, false]
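  //
  // For example (illustrative only): with iterators = [parallel, parallel,
  // reduction], the scalable-flag configurations [false, true, false]
  // (trailing scalable parallel dim) and [true, true, false] (Matmul-like)
  // are accepted, whereas [true, false, false] is rejected because the
  // scalable dim is not the last parallel dim.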
1993 
1994   // Find the first scalable flag
1995   bool seenParallel = false;
1996   auto iterators = linalgOp.getIteratorTypesArray();
1997   SmallVector<bool> scalableFlags(inputScalableVecDims);
1998   while (!scalableFlags.back()) {
1999     seenParallel |= (iterators.back() == utils::IteratorType::parallel);
2000 
2001     iterators.pop_back();
2002     scalableFlags.pop_back();
2003   }
2004 
2005   switch (iterators.back()) {
2006   case utils::IteratorType::reduction: {
2007     // Check 3. above is met.
2008     if (iterators.size() != inputVectorSizes.size()) {
2009       LDBG("Non-trailing reduction dim requested for scalable "
2010            "vectorization\n");
2011       return failure();
2012     }
2013     if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2014       LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2015            "is not supported\n");
2016       return failure();
2017     }
2018     break;
2019   }
2020   case utils::IteratorType::parallel: {
2021     // Check 1. and 2. above are met.
2022     if (seenParallel) {
2023       LDBG("Inner parallel dim not requested for scalable "
2024            "vectorization\n");
2025       return failure();
2026     }
2027     break;
2028   }
2029   }
2030 
2031   // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2032   // supported, for which we expect the following config:
2033   //    * iterators = [parallel, parallel, reduction]
2034   //    * scalable flags = [true, true, false]
2035   if (numOfScalableDims == 2) {
2036     // Disallow the case below, which breaks 3. above:
2037     //    * iterators = [..., parallel, reduction]
2038     //    * scalable flags = [..., true, true]
2039     if (iterators.back() == utils::IteratorType::reduction) {
2040       LDBG("Higher dim than the trailing reduction dim requested for scalable "
2041            "vectorization\n");
2042       return failure();
2043     }
2044     scalableFlags.pop_back();
2045     iterators.pop_back();
2046 
2047     if (!scalableFlags.back() ||
2048         (iterators.back() != utils::IteratorType::parallel))
2049       return failure();
2050   }
2051 
2052   // Cond 4: Only the following ops are supported in the
2053   // presence of scalable vectors
2054   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2055                  isa<linalg::MatmulTransposeAOp>(op) ||
2056                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2057                  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2058 }
2059 
2060 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2061     Operation *op, ArrayRef<int64_t> inputVectorSizes,
2062     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2063     bool flatten1DDepthwiseConv) {
2064   if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2065                                                  inputScalableVecDims)))
2066     return failure();
2067 
2068   return TypeSwitch<Operation *, LogicalResult>(op)
2069       .Case<linalg::LinalgOp>([&](auto linalgOp) {
2070         return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2071                                              vectorizeNDExtract,
2072                                              flatten1DDepthwiseConv);
2073       })
2074       .Case<tensor::PadOp>([&](auto padOp) {
2075         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2076       })
2077       .Case<tensor::PackOp>([&](auto packOp) {
2078         return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2079       })
2080       .Case<tensor::UnPackOp>([&](auto unpackOp) {
2081         return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2082       })
2083       .Default([](auto) { return failure(); });
2084 }
2085 
2086 /// Converts affine.apply Ops to arithmetic operations.
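///
/// E.g. (a sketch only, not verified output), `affine.apply
/// affine_map<(d0) -> (d0 + 4)>(%i)` would be expanded to something like
/// `%c4 = arith.constant 4 : index` followed by `arith.addi %i, %c4 : index`.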
2087 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2088   OpBuilder::InsertionGuard g(rewriter);
2089   auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2090 
2091   for (auto op : make_early_inc_range(toReplace)) {
2092     rewriter.setInsertionPoint(op);
2093     auto expanded = affine::expandAffineExpr(
2094         rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2095         op.getOperands().take_front(op.getAffineMap().getNumDims()),
2096         op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2097     rewriter.replaceOp(op, expanded);
2098   }
2099 }
2100 
2101 /// Emit a suitable vector form for an operation. If provided,
2102 /// `inputVectorSizes` are used to vectorize this operation.
2103 /// `inputVectorSizes` must match the rank of the iteration space of the
2104 /// operation and the input vector sizes must be greater than or equal to
2105 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2106 /// also allows the vectorization of operations with dynamic shapes.
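///
/// Example usage (an illustrative sketch; `candidateOp` is a hypothetical
/// linalg op obtained elsewhere, e.g. inside a rewrite pattern or a transform
/// op implementation):
///
///   if (failed(linalg::vectorize(rewriter, candidateOp,
///                                /*inputVectorSizes=*/{8, 16},
///                                /*inputScalableVecDims=*/{false, false},
///                                /*vectorizeNDExtract=*/false,
///                                /*flatten1DDepthwiseConv=*/false)))
///     return failure();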
2107 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2108                                       ArrayRef<int64_t> inputVectorSizes,
2109                                       ArrayRef<bool> inputScalableVecDims,
2110                                       bool vectorizeNDExtract,
2111                                       bool flatten1DDepthwiseConv) {
2112   LDBG("Attempting to vectorize:\n" << *op << "\n");
2113   LDBG("Input vector sizes: ");
2114   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2115   LLVM_DEBUG(llvm::dbgs() << "\n");
2116   LDBG("Input scalable vector dims: ");
2117   LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2118   LLVM_DEBUG(llvm::dbgs() << "\n");
2119 
2120   if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2121                                      vectorizeNDExtract,
2122                                      flatten1DDepthwiseConv))) {
2123     LDBG("Vectorization pre-conditions failed\n");
2124     return failure();
2125   }
2126 
2127   // Initialize vectorization state.
2128   VectorizationState state(rewriter);
2129   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2130     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2131                                inputScalableVecDims))) {
2132       LDBG("Vectorization state couldn't be initialized\n");
2133       return failure();
2134     }
2135   }
2136 
2137   SmallVector<Value> results;
2138   auto vectorizeResult =
2139       TypeSwitch<Operation *, LogicalResult>(op)
2140           .Case<linalg::LinalgOp>([&](auto linalgOp) {
2141             // TODO: isaConvolutionOpInterface that can also infer from
2142             // generic features. Will require stride/dilation attributes
2143             // inference.
2144             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2145               FailureOr<Operation *> convOr = vectorizeConvolution(
2146                   rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2147                   flatten1DDepthwiseConv);
2148               if (succeeded(convOr)) {
2149                 llvm::append_range(results, (*convOr)->getResults());
2150                 return success();
2151               }
2152 
2153               LDBG("Unsupported convolution can't be vectorized.\n");
2154               return failure();
2155             }
2156 
2157             LDBG("Vectorize generic by broadcasting to the canonical vector "
2158                  "shape\n");
2159 
2160             // Pre-process before proceeding.
2161             convertAffineApply(rewriter, linalgOp);
2162 
2163             // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2164             // to 'OpBuilder' when it is passed over to some methods like
2165             // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2166             // erase an op within these methods, the actual rewriter won't be
2167             // notified and we will end up with read-after-free issues!
2168             return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2169           })
2170           .Case<tensor::PadOp>([&](auto padOp) {
2171             return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2172                                           results);
2173           })
2174           .Case<tensor::PackOp>([&](auto packOp) {
2175             return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2176                                            results);
2177           })
2178           .Case<tensor::UnPackOp>([&](auto unpackOp) {
2179             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2180                                              inputVectorSizes, results);
2181           })
2182           .Default([](auto) { return failure(); });
2183 
2184   if (failed(vectorizeResult)) {
2185     LDBG("Vectorization failed\n");
2186     return failure();
2187   }
2188 
2189   if (!results.empty())
2190     rewriter.replaceOp(op, results);
2191   else
2192     rewriter.eraseOp(op);
2193 
2194   return success();
2195 }
2196 
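/// Vectorize a memref.copy with statically-shaped source and destination into
/// a full-shape transfer_read / transfer_write pair. Illustrative sketch only
/// (hand-written, names hypothetical; %pad is the synthesized padding value):
///
///   %v = vector.transfer_read %src[%c0, %c0], %pad
///       : memref<4x8xf32>, vector<4x8xf32>
///   vector.transfer_write %v, %dst[%c0, %c0]
///       : vector<4x8xf32>, memref<4x8xf32>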
2197 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2198                                           memref::CopyOp copyOp) {
2199   auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2200   auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2201   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2202     return failure();
2203 
2204   auto srcElementType = getElementTypeOrSelf(srcType);
2205   auto dstElementType = getElementTypeOrSelf(dstType);
2206   if (!VectorType::isValidElementType(srcElementType) ||
2207       !VectorType::isValidElementType(dstElementType))
2208     return failure();
2209 
2210   auto readType = VectorType::get(srcType.getShape(), srcElementType);
2211   auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2212 
2213   Location loc = copyOp->getLoc();
2214   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2215   SmallVector<Value> indices(srcType.getRank(), zero);
2216 
2217   Value readValue = rewriter.create<vector::TransferReadOp>(
2218       loc, readType, copyOp.getSource(), indices,
2219       rewriter.getMultiDimIdentityMap(srcType.getRank()));
2220   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2221     readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
2222     readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2223   }
2224   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2225       loc, readValue, copyOp.getTarget(), indices,
2226       rewriter.getMultiDimIdentityMap(srcType.getRank()));
2227   rewriter.replaceOp(copyOp, writeValue->getResults());
2228   return success();
2229 }
2230 
2231 //----------------------------------------------------------------------------//
2232 // Misc. vectorization patterns.
2233 //----------------------------------------------------------------------------//
2234 
2235 /// Helper function that retrieves the value of an IntegerAttr.
2236 static int64_t getIntFromAttr(Attribute attr) {
2237   return cast<IntegerAttr>(attr).getInt();
2238 }
2239 
2240 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
2241 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
2242 /// not supported.
2243 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
2244                                            ArrayRef<OpFoldResult> ofrs) {
2245   SmallVector<Value> result;
2246   for (auto o : ofrs) {
2247     if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
2248       result.push_back(val);
2249     } else {
2250       result.push_back(rewriter.create<arith::ConstantIndexOp>(
2251           loc, getIntFromAttr(o.template get<Attribute>())));
2252     }
2253   }
2254   return result;
2255 }
2256 
2257 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
2258 /// InsertSliceOp. For now, only constant padding values are supported.
2259 /// If there is enough static type information, TransferReadOps and
2260 /// TransferWriteOps may be generated instead of InsertSliceOps.
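///
/// Sketch of the vectorized copy path (illustrative only, not verified
/// output; %fill is the FillOp result the copy is written into): with a
/// statically known result type, copying %src : tensor<?x5xf32> into the
/// tensor<17x5xf32> fill becomes roughly:
///
///   %read  = vector.transfer_read %src[%c0, %c0], %pad_value
///       {in_bounds = [false, true]} : tensor<?x5xf32>, vector<17x5xf32>
///   %write = vector.transfer_write %read, %fill[%c0, %c0]
///       {in_bounds = [true, true]} : vector<17x5xf32>, tensor<17x5xf32>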
2261 struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
2262   GenericPadOpVectorizationPattern(MLIRContext *context,
2263                                    PatternBenefit benefit = 1)
2264       : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
2265   /// Vectorize the copying of a tensor::PadOp's source. This is possible if
2266   /// each dimension size is statically known in the source type or the result
2267   /// type (or both).
2268   static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
2269                                         tensor::PadOp padOp, Value dest) {
2270     auto sourceType = padOp.getSourceType();
2271     auto resultType = padOp.getResultType();
2272     if (!VectorType::isValidElementType(sourceType.getElementType()))
2273       return failure();
2274 
2275     // Copy cannot be vectorized if pad value is non-constant and source shape
2276     // is dynamic. In case of a dynamic source shape, padding must be appended
2277     // by TransferReadOp, but TransferReadOp supports only constant padding.
2278     auto padValue = padOp.getConstantPaddingValue();
2279     if (!padValue) {
2280       if (!sourceType.hasStaticShape())
2281         return failure();
2282       // Create dummy padding value.
2283       auto elemType = sourceType.getElementType();
2284       padValue = rewriter.create<arith::ConstantOp>(
2285           padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2286     }
2287 
2288     SmallVector<int64_t> vecShape;
2289     SmallVector<bool> readInBounds;
2290     SmallVector<bool> writeInBounds;
2291     for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2292       if (!sourceType.isDynamicDim(i)) {
2293         vecShape.push_back(sourceType.getDimSize(i));
2294         // Source shape is statically known: neither read nor write is
2295         // out-of-bounds.
2296         readInBounds.push_back(true);
2297         writeInBounds.push_back(true);
2298       } else if (!resultType.isDynamicDim(i)) {
2299         // Source shape is not statically known, but result shape is.
2300         // Vectorize with size of result shape. This may be larger than the
2301         // source size.
2302         vecShape.push_back(resultType.getDimSize(i));
2303         // Read may be out-of-bounds because the result size could be larger
2304         // than the source size.
2305         readInBounds.push_back(false);
2306         // Write is out-of-bounds if low padding > 0.
2307         writeInBounds.push_back(
2308             getConstantIntValue(padOp.getMixedLowPad()[i]) ==
2309             static_cast<int64_t>(0));
2310       } else {
2311         // Neither source nor result dim of padOp is static. Cannot vectorize
2312         // the copy.
2313         return failure();
2314       }
2315     }
2316     auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2317 
2318     // Generate TransferReadOp.
2319     SmallVector<Value> readIndices(
2320         vecType.getRank(),
2321         rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2322     auto read = rewriter.create<vector::TransferReadOp>(
2323         padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
2324         ArrayRef<bool>{readInBounds});
2325 
2326     // If `dest` is a FillOp and the TransferWriteOp would overwrite the
2327     // entire tensor, write directly to the FillOp's operand.
2328     if (llvm::equal(vecShape, resultType.getShape()) &&
2329         llvm::all_of(writeInBounds, [](bool b) { return b; }))
2330       if (auto fill = dest.getDefiningOp<FillOp>())
2331         dest = fill.output();
2332 
2333     // Generate TransferWriteOp.
2334     auto writeIndices =
2335         ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
2336     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2337         padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
2338 
2339     return success();
2340   }
2341 };
2342 
2343 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2344 /// given operation type OpTy.
2345 template <typename OpTy>
2346 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2347   using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2348 
2349   LogicalResult matchAndRewrite(tensor::PadOp padOp,
2350                                 PatternRewriter &rewriter) const final {
2351     bool changed = false;
2352     // Collect the users in a vector first, because some of them may be
2353     // replaced/removed during the loop.
2353     for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2354       if (auto op = dyn_cast<OpTy>(user))
2355         changed |= rewriteUser(rewriter, padOp, op).succeeded();
2356     return success(changed);
2357   }
2358 
2359 protected:
2360   virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2361                                     tensor::PadOp padOp, OpTy op) const = 0;
2362 };
2363 
2364 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2365 /// ```
2366 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2367 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2368 ///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2369 /// ```
2370 /// is rewritten to:
2371 /// ```
2372 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2373 ///     {in_bounds = [true, true]}
2374 ///     : tensor<?x?xf32>, vector<17x5xf32>
2375 /// ```
2376 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2377 /// sure that the original padding value %cst was never used.
2378 ///
2379 /// This rewrite is possible if:
2380 /// - `xferOp` has no out-of-bounds dims or mask.
2381 /// - Low padding is static 0.
2382 /// - Single, scalar padding value.
2383 struct PadOpVectorizationWithTransferReadPattern
2384     : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2385   using VectorizePadOpUserPattern<
2386       vector::TransferReadOp>::VectorizePadOpUserPattern;
2387 
2388   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2389                             vector::TransferReadOp xferOp) const override {
2390     // Low padding must be static 0.
2391     if (!padOp.hasZeroLowPad())
2392       return failure();
2393     // Pad value must be a constant.
2394     auto padValue = padOp.getConstantPaddingValue();
2395     if (!padValue)
2396       return failure();
2397     // Padding value of existing `xferOp` is unused.
2398     if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2399       return failure();
2400 
2401     rewriter.modifyOpInPlace(xferOp, [&]() {
2402       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2403       xferOp->setAttr(xferOp.getInBoundsAttrName(),
2404                       rewriter.getBoolArrayAttr(inBounds));
2405       xferOp.getSourceMutable().assign(padOp.getSource());
2406       xferOp.getPaddingMutable().assign(padValue);
2407     });
2408 
2409     return success();
2410   }
2411 };
2412 
2413 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2414 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2415 /// value, where the same amount of padding is immediately removed again after
2416 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2417 /// tensor value and apply out-of-bounds masking. E.g.:
2418 /// ```
2419 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2420 ///     : tensor<...> to tensor<?x?xf32>
2421 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2422 /// %2 = vector.transfer_write %vec, %1[...]
2423 ///     : vector<17x5xf32>, tensor<17x5xf32>
2424 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2425 ///     : tensor<17x5xf32> to tensor<?x?xf32>
2426 /// ```
2427 /// is rewritten to:
2428 /// ```
2429 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2430 ///     : tensor<...> to tensor<?x?xf32>
2431 /// %r = vector.transfer_write %vec, %0[...]
2432 ///     : vector<17x5xf32>, tensor<?x?xf32>
2433 /// ```
2434 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
2435 /// TransferWriteOp to the same size as the input of the tensor::PadOp (or an
2436 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2437 /// from %r's old dimensions.
2438 ///
2439 /// This rewrite is possible if:
2440 /// - Low padding is static 0.
2441 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2442 ///   ExtractSliceOp trims the same amount of padding that was added
2443 ///   beforehand.
2444 /// - Single, scalar padding value.
2445 struct PadOpVectorizationWithTransferWritePattern
2446     : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2447   using VectorizePadOpUserPattern<
2448       vector::TransferWriteOp>::VectorizePadOpUserPattern;
2449 
2450   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2451                             vector::TransferWriteOp xferOp) const override {
2452     // TODO: support 0-d corner case.
2453     if (xferOp.getTransferRank() == 0)
2454       return failure();
2455 
2456     // Low padding must be static 0.
2457     if (!padOp.hasZeroLowPad())
2458       return failure();
2459     // Pad value must be a constant.
2460     auto padValue = padOp.getConstantPaddingValue();
2461     if (!padValue)
2462       return failure();
2463     // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2464     if (!xferOp->hasOneUse())
2465       return failure();
2466     auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2467     if (!trimPadding)
2468       return failure();
2469     // Only static zero offsets supported when trimming padding.
2470     if (!trimPadding.hasZeroOffset())
2471       return failure();
2472     // trimPadding must remove the amount of padding that was added earlier.
2473     if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2474       return failure();
2475 
2476     // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2477     rewriter.setInsertionPoint(xferOp);
2478 
2479     SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2480     auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2481         xferOp, padOp.getSource().getType(), xferOp.getVector(),
2482         padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2483         xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2484     rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2485 
2486     return success();
2487   }
2488 
2489   /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2490   /// i.e., same dimensions.
2491   ///
2492   /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2493   /// dimensions, this function tries to infer the (static) tensor size by
2494   /// looking at the defining op and utilizing op-specific knowledge.
2495   ///
2496   /// This is a conservative analysis. In case equal tensor sizes cannot be
2497   /// proven statically, this analysis returns `false` even though the tensor
2498   /// sizes may turn out to be equal at runtime.
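  /// For example (illustrative), if `beforePadding` is produced by
  ///   %0 = tensor.extract_slice %a[0, 0] [%sz, 4] [1, 1] : ...
  /// and `afterTrimming` is
  ///   %1 = tensor.extract_slice %b[0, 0] [%sz, 4] [1, 1] : ...
  /// then the static sizes match and the shared dynamic size %sz is recognized
  /// as equal, so this returns true.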
2499   bool hasSameTensorSize(Value beforePadding,
2500                          tensor::ExtractSliceOp afterTrimming) const {
2501     // If the input to tensor::PadOp is a CastOp, try with both CastOp
2502     // result and CastOp operand.
2503     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2504       if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2505         return true;
2506 
2507     auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2508     auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2509     // Only RankedTensorType supported.
2510     if (!t1 || !t2)
2511       return false;
2512     // Rank of both values must be the same.
2513     if (t1.getRank() != t2.getRank())
2514       return false;
2515 
2516     // All static dimensions must be the same. Mixed cases (e.g., dimension
2517     // static in `t1` but dynamic in `t2`) are not supported.
2518     for (unsigned i = 0; i < t1.getRank(); ++i) {
2519       if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2520         return false;
2521       if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2522         return false;
2523     }
2524 
2525     // Nothing more to check if all dimensions are static.
2526     if (t1.getNumDynamicDims() == 0)
2527       return true;
2528 
2529     // All dynamic sizes must be the same. The only supported case at the
2530     // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2531     // thereof).
2532 
2533     // Apart from CastOp, only ExtractSliceOp is supported.
2534     auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2535     if (!beforeSlice)
2536       return false;
2537 
2538     assert(static_cast<size_t>(t1.getRank()) ==
2539            beforeSlice.getMixedSizes().size());
2540     assert(static_cast<size_t>(t2.getRank()) ==
2541            afterTrimming.getMixedSizes().size());
2542 
2543     for (unsigned i = 0; i < t1.getRank(); ++i) {
2544       // Skip static dimensions.
2545       if (!t1.isDynamicDim(i))
2546         continue;
2547       auto size1 = beforeSlice.getMixedSizes()[i];
2548       auto size2 = afterTrimming.getMixedSizes()[i];
2549 
2550       // Case 1: Same value or same constant int.
2551       if (isEqualConstantIntOrValue(size1, size2))
2552         continue;
2553 
2554       // Other cases: Take a deeper look at defining ops of values.
2555       auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2556       auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2557       if (!v1 || !v2)
2558         return false;
2559 
2560       // Case 2: Both values are identical AffineMinOps. (Should not happen if
2561       // CSE is run.)
2562       auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2563       auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2564       if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2565           minOp1.getOperands() == minOp2.getOperands())
2566         continue;
2567 
2568       // Add additional cases as needed.
2569     }
2570 
2571     // All tests passed.
2572     return true;
2573   }
2574 };
2575 
2576 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2577 /// ```
2578 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2579 /// %r = tensor.insert_slice %0
2580 ///     into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2581 ///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2582 /// ```
2583 /// is rewritten to:
2584 /// ```
2585 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2586 ///     : tensor<?x?xf32>, vector<17x5xf32>
2587 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2588 ///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2589 /// ```
2590 ///
2591 /// This rewrite is possible if:
2592 /// - Low padding is static 0.
2593 /// - `padOp` result shape is static.
2594 /// - The entire padded tensor is inserted.
2595 ///   (Implies that sizes of `insertOp` are all static.)
2596 /// - Only unit strides in `insertOp`.
2597 /// - Single, scalar padding value.
2598 /// - `padOp` result not used as destination.
2599 struct PadOpVectorizationWithInsertSlicePattern
2600     : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2601   using VectorizePadOpUserPattern<
2602       tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2603 
2604   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2605                             tensor::InsertSliceOp insertOp) const override {
2606     // Low padding must be static 0.
2607     if (!padOp.hasZeroLowPad())
2608       return failure();
2609     // Only unit stride supported.
2610     if (!insertOp.hasUnitStride())
2611       return failure();
2612     // Pad value must be a constant.
2613     auto padValue = padOp.getConstantPaddingValue();
2614     if (!padValue)
2615       return failure();
2616     // Dynamic shapes not supported.
2617     if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2618       return failure();
2619     // Pad result not used as destination.
2620     if (insertOp.getDest() == padOp.getResult())
2621       return failure();
2622 
2623     auto vecType = VectorType::get(padOp.getType().getShape(),
2624                                    padOp.getType().getElementType());
2625     unsigned vecRank = vecType.getRank();
2626     unsigned tensorRank = insertOp.getType().getRank();
2627 
2628     // Check if sizes match: the entire tensor is inserted into the most minor
2629     // dims (no permutations allowed).
2630     SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2631     expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2632     if (!llvm::all_of(
2633             llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2634               return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2635             }))
2636       return failure();
2637 
2638     // Insert the TransferReadOp and TransferWriteOp at the position of the
2639     // InsertSliceOp.
2640     rewriter.setInsertionPoint(insertOp);
2641 
2642     // Generate TransferReadOp: Read entire source tensor and add high
2643     // padding.
2644     SmallVector<Value> readIndices(
2645         vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2646     auto read = rewriter.create<vector::TransferReadOp>(
2647         padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2648 
2649     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at the
2650     // specified offsets. The write is fully in-bounds because an
2651     // InsertSliceOp's source must fit into the destination at those offsets.
2652     auto writeIndices =
2653         ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2654     SmallVector<bool> inBounds(vecRank, true);
2655     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2656         insertOp, read, insertOp.getDest(), writeIndices,
2657         ArrayRef<bool>{inBounds});
2658 
2659     return success();
2660   }
2661 };
2662 
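// Illustrative usage sketch (assumes a surrounding pass with some hypothetical
// `funcOp`; not exercised in this file): these patterns are typically gathered
// into a RewritePatternSet and applied with a greedy driver, e.g.:
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));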
2663 void mlir::linalg::populatePadOpVectorizationPatterns(
2664     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2665   patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
2666                                                  baseBenefit);
2667   // Try these specialized patterns before resorting to the generic one.
2668   patterns.add<PadOpVectorizationWithTransferReadPattern,
2669                PadOpVectorizationWithTransferWritePattern,
2670                PadOpVectorizationWithInsertSlicePattern>(
2671       patterns.getContext(), baseBenefit.getBenefit() + 1);
2672 }
2673 
2674 //----------------------------------------------------------------------------//
2675 // Forwarding patterns
2676 //----------------------------------------------------------------------------//
2677 
2678 /// Check whether there is any interleaved use of any `values` between
2679 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2680 /// is in a different block.
2681 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2682                                     ValueRange values) {
2683   if (firstOp->getBlock() != secondOp->getBlock() ||
2684       !firstOp->isBeforeInBlock(secondOp)) {
2685     LDBG("interleavedUses precondition failed, firstOp: "
2686          << *firstOp << ", second op: " << *secondOp << "\n");
2687     return true;
2688   }
2689   for (auto v : values) {
2690     for (auto &u : v.getUses()) {
2691       Operation *owner = u.getOwner();
2692       if (owner == firstOp || owner == secondOp)
2693         continue;
2694       // TODO: this is too conservative, use dominance info in the future.
2695       if (owner->getBlock() == firstOp->getBlock() &&
2696           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2697         continue;
2698       LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2699                                     << ", second op: " << *secondOp << "\n");
2700       return true;
2701     }
2702   }
2703   return false;
2704 }
2705 
2706 /// Return the unique subview use of `v` if it is indeed unique, null
2707 /// otherwise.
2708 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2709   memref::SubViewOp subViewOp;
2710   for (auto &u : v.getUses()) {
2711     if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2712       if (subViewOp)
2713         return memref::SubViewOp();
2714       subViewOp = newSubViewOp;
2715     }
2716   }
2717   return subViewOp;
2718 }
2719 
2720 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2721 /// when available.
2722 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2723     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2724 
2725   // TODO: support mask.
2726   if (xferOp.getMask())
2727     return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2728 
2729   // The transfer reads from `viewOrAlloc`.
2730   Value viewOrAlloc = xferOp.getSource();
2731   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2732       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2733     return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2734 
2735   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2736   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2737   if (!subViewOp)
2738     return rewriter.notifyMatchFailure(xferOp, "no subview found");
2739   Value subView = subViewOp.getResult();
2740 
2741   // Find the copy into `subView` without interleaved uses.
2742   memref::CopyOp copyOp;
2743   for (auto &u : subView.getUses()) {
2744     if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2745       assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2746       if (newCopyOp.getTarget() != subView)
2747         continue;
2748       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2749         continue;
2750       copyOp = newCopyOp;
2751       break;
2752     }
2753   }
2754   if (!copyOp)
2755     return rewriter.notifyMatchFailure(xferOp, "no copy found");
2756 
2757   // Find the fill into `viewOrAlloc` without interleaved uses before the
2758   // copy.
2759   FillOp maybeFillOp;
2760   for (auto &u : viewOrAlloc.getUses()) {
2761     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2762       assert(isa<MemRefType>(newFillOp.output().getType()));
2763       if (newFillOp.output() != viewOrAlloc)
2764         continue;
2765       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2766         continue;
2767       maybeFillOp = newFillOp;
2768       break;
2769     }
2770   }
2771   // Ensure padding matches.
2772   if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2773     return rewriter.notifyMatchFailure(xferOp,
2774                                        "padding value does not match fill");
2775 
2776   // `in` is the subview that memref.copy reads. Replace it.
2777   Value in = copyOp.getSource();
2778 
2779   // memref.copy + linalg.fill can be used to create a padded local buffer.
2780   // An all-true `in_bounds` attribute is only valid on this padded buffer.
2781   // When forwarding to vector.transfer_read, the attribute must be reset
2782   // conservatively.
2783   auto vectorType = xferOp.getVectorType();
2784   Value res = rewriter.create<vector::TransferReadOp>(
2785       xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
2786       xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2787       rewriter.getBoolArrayAttr(
2788           SmallVector<bool>(vectorType.getRank(), false)));
2789 
2790   if (maybeFillOp)
2791     rewriter.eraseOp(maybeFillOp);
2792   rewriter.eraseOp(copyOp);
2793   rewriter.replaceOp(xferOp, res);
2794 
2795   return success();
2796 }
2797 
2798 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2799 /// when available.
2800 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2801     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2802   // TODO: support mask.
2803   if (xferOp.getMask())
2804     return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2805 
2806   // Transfer into `viewOrAlloc`.
2807   Value viewOrAlloc = xferOp.getSource();
2808   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2809       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2810     return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2811 
2812   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2813   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2814   if (!subViewOp)
2815     return rewriter.notifyMatchFailure(xferOp, "no subview found");
2816   Value subView = subViewOp.getResult();
2817 
2818   // Find the copy from `subView` without interleaved uses.
2819   memref::CopyOp copyOp;
2820   for (auto &u : subViewOp.getResult().getUses()) {
2821     if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2822       if (newCopyOp.getSource() != subView)
2823         continue;
2824       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2825         continue;
2826       copyOp = newCopyOp;
2827       break;
2828     }
2829   }
2830   if (!copyOp)
2831     return rewriter.notifyMatchFailure(xferOp, "no copy found");
2832 
2833   // `out` is the subview that the copy writes into; write to it directly.
2834   assert(isa<MemRefType>(copyOp.getTarget().getType()));
2835   Value out = copyOp.getTarget();
2836 
2837   // Forward vector.transfer into copy.
2838   // memref.copy + linalg.fill can be used to create a padded local buffer.
2839   // The `masked` attribute is only valid on this padded buffer.
2840   // An all-true `in_bounds` attribute is only valid on this padded buffer.
2841   // conservatively.
2842   auto vector = xferOp.getVector();
2843   rewriter.create<vector::TransferWriteOp>(
2844       xferOp.getLoc(), vector, out, xferOp.getIndices(),
2845       xferOp.getPermutationMapAttr(), xferOp.getMask(),
2846       rewriter.getBoolArrayAttr(
2847           SmallVector<bool>(vector.getType().getRank(), false)));
2848 
2849   rewriter.eraseOp(copyOp);
2850   rewriter.eraseOp(xferOp);
2851 
2852   return success();
2853 }
2854 
2855 //===----------------------------------------------------------------------===//
2856 // Convolution vectorization patterns
2857 //===----------------------------------------------------------------------===//
2858 
2859 template <int N>
2860 static void bindShapeDims(ShapedType shapedType) {}
2861 
2862 template <int N, typename IntTy, typename... IntTy2>
2863 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2864   val = shapedType.getShape()[N];
2865   bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2866 }
2867 
2868 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
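/// For example (illustrative), given a ShapedType with shape {4, 8, 16}:
///   int64_t n, w, c;
///   bindShapeDims(shapedType, n, w, c); // n = 4, w = 8, c = 16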
2869 template <typename... IntTy>
2870 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2871   bindShapeDims<0>(shapedType, vals...);
2872 }
2873 
2874 namespace {
2875 bool isCastOfBlockArgument(Operation *op) {
2876   return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2877          isa<BlockArgument>(op->getOperand(0));
2878 }
2879 
2880 bool isSupportedPoolKind(vector::CombiningKind kind) {
2881   switch (kind) {
2882   case vector::CombiningKind::ADD:
2883   case vector::CombiningKind::MAXNUMF:
2884   case vector::CombiningKind::MAXIMUMF:
2885   case vector::CombiningKind::MAXSI:
2886   case vector::CombiningKind::MAXUI:
2887   case vector::CombiningKind::MINNUMF:
2888   case vector::CombiningKind::MINIMUMF:
2889   case vector::CombiningKind::MINSI:
2890   case vector::CombiningKind::MINUI:
2891     return true;
2892   default:
2893     return false;
2894   }
2895 }
2896 
2897 /// Generate a vector implementation for either:
2898 /// ```
2899 ///   Op def: (     w,     kw  )
2900 ///    Iters: ({Par(), Red()})
2901 ///   Layout: {{w + kw}, {kw}, {w}}
2902 /// ```
2903 /// kw is unrolled.
2904 ///
2905 /// or
2906 ///
2907 /// ```
2908 ///   Op def: (     n,     w,     c,    kw,    f  )
2909 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
2910 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2911 /// ```
2912 /// kw is unrolled, w is unrolled iff strideW > 1.
2913 ///
2914 /// or
2915 ///
2916 /// ```
2917 ///   Op def: (     n,     c,     w,    f,    kw )
2918 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
2919 ///   Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2920 /// ```
2921 /// kw is unrolled, w is unrolled iff strideW > 1.
2922 ///
2923 /// or
2924 ///
2925 /// ```
2926 ///   Op def: (     n,     w,     c,    kw )
2927 ///    Iters: ({Par(), Par(), Par(), Red()})
2928 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2929 /// ```
2930 /// kw is unrolled, w is unrolled iff strideW > 1.
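///
/// For example (illustrative), the first, non-channeled form corresponds to an
/// op such as:
///   linalg.conv_1d ins(%in, %filter : tensor<6xf32>, tensor<3xf32>)
///                  outs(%out : tensor<4xf32>) -> tensor<4xf32>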
2931 struct Conv1DGenerator
2932     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
2933   Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
2934                   int dilationW)
2935       : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
2936         strideW(strideW), dilationW(dilationW) {
2937     // Determine whether `linalgOp` can be vectorized by this generator.
2938     if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2939       return;
2940     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2941     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2942     resShaped = linalgOp.getDpsInitOperand(0)->get();
2943     lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2944     rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2945     resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2946     if (!lhsShapedType || !rhsShapedType || !resShapedType)
2947       return;
2948     // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2949     // (non-channeled convolution -> LHS and RES both have a single dimension).
2950     if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2951         (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2952       return;
2953 
2954     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
2955     if (!reduceOp)
2956       return;
2957     redOp = reduceOp->getName().getIdentifier();
2958 
2959     if (!setOperKind(reduceOp))
2960       return;
2961     auto maybeKind = getCombinerOpKind(reduceOp);
2962     if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2963                        (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2964       return;
2965     }
2966 
2967     auto rhsRank = rhsShapedType.getRank();
2968     switch (oper) {
2969     case Conv:
2970       if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2971         return;
2972       break;
2973     case Pool:
2974       if (rhsRank != 1)
2975         return;
2976       break;
2977     }
2978     // The op is now known to be valid.
2979     valid = true;
2980   }
2981 
2982   /// Generate a vector implementation for:
2983   /// ```
2984   ///   Op def: (     w,     kw  )
2985   ///    Iters: ({Par(), Red()})
2986   ///   Layout: {{w + kw}, {kw}, {w}}
2987   /// ```
2988   /// kw is always unrolled.
2989   ///
2990   /// or
2991   ///
2992   /// ```
2993   ///   Op def: (     n,     w,     c,    kw,    f  )
2994   ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
2995   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2996   /// ```
2997   /// kw is always unrolled.
2998   /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
2999   /// > 1.
3000   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3001     if (!valid)
3002       return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
3003 
3004     int64_t nSize, wSize, cSize, kwSize, fSize;
3005     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3006     bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3007     switch (conv1DOpOrder) {
3008     case Conv1DOpOrder::W:
3009       // Initialize unused dimensions
3010       nSize = fSize = cSize = 0;
3011       // out{W}
3012       bindShapeDims(resShapedType, wSize);
3013       // kernel{kw}
3014       bindShapeDims(rhsShapedType, kwSize);
3015       lhsShape = {// iw = ow + kw - 1
3016                   //   (i.e. 16 convolved with 3 -> 14)
3017                   (wSize + kwSize - 1)};
3018       rhsShape = {kwSize};
3019       resShape = {wSize};
3020       break;
3021     case Conv1DOpOrder::Nwc:
3022       // out{n, w, f}
3023       bindShapeDims(resShapedType, nSize, wSize, fSize);
3024       switch (oper) {
3025       case Conv:
3026         // kernel{kw, c, f}
3027         bindShapeDims(rhsShapedType, kwSize, cSize);
3028         break;
3029       case Pool:
3030         // kernel{kw}
3031         bindShapeDims(rhsShapedType, kwSize);
3032         cSize = fSize;
3033         break;
3034       }
3035       lhsShape = {nSize,
3036                   // iw = ow * sw + kw *  dw - 1
3037                   //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3038                   // Perform the proper inclusive -> exclusive -> inclusive.
3039                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3040                       1,
3041                   cSize};
3042       switch (oper) {
3043       case Conv:
3044         rhsShape = {kwSize, cSize, fSize};
3045         break;
3046       case Pool:
3047         rhsShape = {kwSize};
3048         break;
3049       }
3050       resShape = {nSize, wSize, fSize};
3051       break;
3052     case Conv1DOpOrder::Ncw:
3053       // out{n, f, w}
3054       bindShapeDims(resShapedType, nSize, fSize, wSize);
3055       switch (oper) {
3056       case Conv:
3057         // kernel{f, c, kw}
3058         bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3059         break;
3060       case Pool:
3061         // kernel{kw}
3062         bindShapeDims(rhsShapedType, kwSize);
3063         cSize = fSize;
3064         break;
3065       }
3066       lhsShape = {nSize, cSize,
3067                   // iw = ow * sw + kw *  dw - 1
3068                   //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3069                   // Perform the proper inclusive -> exclusive -> inclusive.
3070                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3071                       1};
3072       switch (oper) {
3073       case Conv:
3074         rhsShape = {fSize, cSize, kwSize};
3075         break;
3076       case Pool:
3077         rhsShape = {kwSize};
3078         break;
3079       }
3080       resShape = {nSize, fSize, wSize};
3081       break;
3082     }
3083 
3084     vector::TransferWriteOp write;
3085     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3086 
3087     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3088     // When strideW == 1, we can batch the contiguous loads and avoid
3089     // unrolling
3090     int64_t wSizeStep = strideW == 1 ? wSize : 1;
3091 
3092     Type lhsEltType = lhsShapedType.getElementType();
3093     Type rhsEltType = rhsShapedType.getElementType();
3094     Type resEltType = resShapedType.getElementType();
3095     auto lhsType = VectorType::get(lhsShape, lhsEltType);
3096     auto rhsType = VectorType::get(rhsShape, rhsEltType);
3097     auto resType = VectorType::get(resShape, resEltType);
3098     // Zero padding with the corresponding dimensions for lhs, rhs and res.
3099     SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3100     SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3101     SmallVector<Value> resPadding(resShape.size(), zero);
3102 
3103     // Read the whole lhs, rhs and res in one shot (with zero padding).
3104     Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3105                                                         lhsPadding);
3106     // This is needed only for Conv.
3107     Value rhs = nullptr;
3108     if (oper == Conv)
3109       rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3110                                                     rhsPadding);
3111     Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3112                                                         resPadding);
3113 
3114     // The base vectorization case for channeled convolution is input:
3115     // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3116     // vectorization case, we pre-transpose the input, weight, and output.
3117     switch (conv1DOpOrder) {
3118     case Conv1DOpOrder::W:
3119     case Conv1DOpOrder::Nwc:
3120       // Base case, so no transposes necessary.
3121       break;
3122     case Conv1DOpOrder::Ncw: {
3123       // To match base vectorization case, we pre-transpose current case.
3124       // ncw -> nwc
3125       static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3126       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3127       // fcw -> wcf
3128       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3129 
3130       // This is needed only for Conv.
3131       if (oper == Conv)
3132         rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3133       // nfw -> nwf
3134       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3135       res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3136       break;
3137     }
3138     }
3139 
3140     //===------------------------------------------------------------------===//
3141     // Begin vector-only rewrite part
3142     //===------------------------------------------------------------------===//
3143     // Unroll along kw and read slices of lhs and rhs.
3144     SmallVector<Value> lhsVals, rhsVals, resVals;
3145     lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3146                                      kwSize, strideW, dilationW, wSizeStep,
3147                                      isSingleChanneled);
3148     // Not needed for pooling.
3149     if (oper == Conv)
3150       rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3151     resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3152                                       wSizeStep, isSingleChanneled);
3153 
3154     auto linearIndex = [&](int64_t kw, int64_t w) {
3155       return kw * (wSize / wSizeStep) + w;
3156     };
3157 
3158     // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
3159     // perform an outer product for non-channeled convolution, or perform a
3160     // simple arith operation for pooling.
3161     for (int64_t kw = 0; kw < kwSize; ++kw) {
3162       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3163         switch (oper) {
3164         case Conv:
3165           if (isSingleChanneled) {
3166             resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3167                                                    lhsVals[linearIndex(kw, w)],
3168                                                    rhsVals[kw], resVals[w]);
3169           } else {
3170             resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3171                                                   lhsVals[linearIndex(kw, w)],
3172                                                   rhsVals[kw], resVals[w]);
3173           }
3174           break;
3175         case Pool:
3176           resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3177                                    resVals[w]);
3178           break;
3179         }
3180       }
3181     }
3182 
3183     res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3184                                  isSingleChanneled);
3185     //===------------------------------------------------------------------===//
3186     // End vector-only rewrite part
3187     //===------------------------------------------------------------------===//
3188 
3189     // The base vectorization case for channeled convolution is output:
3190     // {n,w,f}. To reuse the result from the base pattern vectorization case,
3191     // we post-transpose the base case result.
3192     switch (conv1DOpOrder) {
3193     case Conv1DOpOrder::W:
3194     case Conv1DOpOrder::Nwc:
3195       // Base case, so no transposes necessary.
3196       break;
3197     case Conv1DOpOrder::Ncw: {
3198       // nwf -> nfw
3199       static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3200       res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3201       break;
3202     }
3203     }
3204 
3205     return rewriter
3206         .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3207         .getOperation();
3208   }
3209 
3210   // Take a value and widen it to have the same element type as `ty`.
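  // For example (illustrative), promoting a value of type vector<4xi8> to a
  // float element type emits arith.sitofp; promoting vector<4xf16> to a wider
  // float element type emits arith.extf.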
3211   Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3212     const Type srcElementType = getElementTypeOrSelf(val.getType());
3213     const Type dstElementType = getElementTypeOrSelf(ty);
3214     assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3215     if (srcElementType == dstElementType)
3216       return val;
3217 
3218     const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3219     const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3220     const Type dstType =
3221         cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3222 
3223     if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3224       return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3225     }
3226 
3227     if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3228         srcWidth < dstWidth)
3229       return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3230 
3231     if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3232         srcWidth < dstWidth)
3233       return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3234 
3235     assert(false && "unhandled promotion case");
3236     return nullptr;
3237   }
3238 
3239   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3240   Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3241                                  Value lhs, Value rhs, Value res) {
3242     vector::IteratorType par = vector::IteratorType::parallel;
3243     vector::IteratorType red = vector::IteratorType::reduction;
3244     AffineExpr n, w, f, c;
3245     bindDims(ctx, n, w, f, c);
3246     lhs = promote(rewriter, loc, lhs, res.getType());
3247     rhs = promote(rewriter, loc, rhs, res.getType());
3248     return rewriter.create<vector::ContractionOp>(
3249         loc, lhs, rhs, res,
3250         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3251         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3252   }
3253 
3254   // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3255   // convolution.
3256   Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3257                                   Value lhs, Value rhs, Value res) {
3258     return rewriter.create<vector::OuterProductOp>(
3259         loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3260   }
3261 
3262   // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3263   Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3264                     Value res) {
3265     if (isPoolExt)
3266       lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3267     return rewriter
3268         .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3269         ->getResult(0);
3270   }
3271 
3272   /// Generate a vector implementation for:
3273   /// ```
3274   ///   Op def: (     n,     w,     c,    kw)
3275   ///    Iters: ({Par(), Par(), Par(), Red()})
3276   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3277   /// ```
3278   /// kw is always unrolled.
3279   /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3280   /// > 1.
3281   FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3282                                        bool channelDimScalableFlag,
3283                                        bool flatten) {
3284     if (!valid)
3285       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3286 
3287     bool scalableChDim = false;
3288     bool useMasking = false;
3289     int64_t nSize, wSize, cSize, kwSize;
3290     // kernel{kw, c}
3291     bindShapeDims(rhsShapedType, kwSize, cSize);
3292     if (ShapedType::isDynamic(cSize)) {
3293       assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3294       cSize = channelDimVecSize;
3295       // Scalable vectors are only used when both conditions are met:
3296       //  1. channel dim is dynamic
3297       //  2. channelDimScalableFlag is set
3298       scalableChDim = channelDimScalableFlag;
3299       useMasking = true;
3300     }
3301 
3302     assert(!(useMasking && flatten) &&
3303            "Unsupported flattened conv with dynamic shapes");
3304 
3305     // out{n, w, c}
3306     bindShapeDims(resShapedType, nSize, wSize);
3307 
3308     vector::TransferWriteOp write;
3309     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3310 
3311     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3312     // When strideW == 1, we can batch the contiguous loads and avoid
3313     // unrolling
3314     int64_t wSizeStep = strideW == 1 ? wSize : 1;
3315 
3316     Type lhsEltType = lhsShapedType.getElementType();
3317     Type rhsEltType = rhsShapedType.getElementType();
3318     Type resEltType = resShapedType.getElementType();
3319     VectorType lhsType = VectorType::get(
3320         {nSize,
3321          // iw = ow * sw + kw *  dw - 1
3322          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3323          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3324          cSize},
3325         lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3326     VectorType rhsType =
3327         VectorType::get({kwSize, cSize}, rhsEltType,
3328                         /*scalableDims=*/{false, scalableChDim});
3329     VectorType resType =
3330         VectorType::get({nSize, wSize, cSize}, resEltType,
3331                         /*scalableDims=*/{false, false, scalableChDim});
3332 
3333     // Masks the input xfer Op along the channel dim, iff masking is required
3334     // (i.e. the channel dim is dynamic).
3335     auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3336                                ArrayRef<bool> scalableDims,
3337                                Operation *opToMask) {
3338       if (!useMasking)
3339         return opToMask;
3340       auto maskType =
3341           VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3342 
3343       SmallVector<bool> inBounds(maskShape.size(), true);
3344       auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3345       xferOp->setAttr(xferOp.getInBoundsAttrName(),
3346                       rewriter.getBoolArrayAttr(inBounds));
3347 
3348       SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3349           cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3350 
3351       Value maskOp =
3352           rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3353 
3354       return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3355     };
3356 
3357     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3358     // 0].
3359     Value lhs = rewriter.create<vector::TransferReadOp>(
3360         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3361     auto maybeMaskedLhs = maybeMaskXferOp(
3362         lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3363 
3364     // Read rhs slice of size {kw, c} @ [0, 0].
3365     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3366                                                         ValueRange{zero, zero});
3367     auto maybeMaskedRhs = maybeMaskXferOp(
3368         rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3369 
3370     // Read res slice of size {n, w, c} @ [0, 0, 0].
3371     Value res = rewriter.create<vector::TransferReadOp>(
3372         loc, resType, resShaped, ValueRange{zero, zero, zero});
3373     auto maybeMaskedRes = maybeMaskXferOp(
3374         resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3375 
3376     //===------------------------------------------------------------------===//
3377     // Begin vector-only rewrite part
3378     //===------------------------------------------------------------------===//
3379     // Unroll along kw and read slices of lhs and rhs.
3380     SmallVector<Value> lhsVals, rhsVals, resVals;
3381     auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3382     auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3383 
3384     // Extract lhs slice of size {n, wSizeStep, c}
3385     //   @ [0, sw * w + dw * kw, 0].
3386     for (int64_t kw = 0; kw < kwSize; ++kw) {
3387       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3388         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3389             loc, maybeMaskedLhs->getResult(0),
3390             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3391             inOutSliceSizes, inOutStrides));
3392       }
3393     }
3394     // Extract rhs slice of size {c} @ [kw].
3395     for (int64_t kw = 0; kw < kwSize; ++kw) {
3396       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3397           loc, maybeMaskedRhs->getResult(0),
3398           /*offsets=*/ArrayRef<int64_t>{kw}));
3399     }
3400     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3401     for (int64_t w = 0; w < wSize; w += wSizeStep) {
3402       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3403           loc, maybeMaskedRes->getResult(0),
3404           /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3405           inOutStrides));
3406     }
3407 
3408     auto linearIndex = [&](int64_t kw, int64_t w) {
3409       return kw * (wSize / wSizeStep) + w;
3410     };
3411 
3412     // Note - the scalable flags are ignored as flattening combined with
3413     // scalable vectorization is not supported.
3414     auto inOutFlattenSliceSizes =
3415         SmallVector<int64_t>{nSize, wSizeStep * cSize};
3416     auto lhsTypeAfterFlattening =
3417         VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3418     auto resTypeAfterFlattening =
3419         VectorType::get(inOutFlattenSliceSizes, resEltType);
3420 
3421     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3422     for (int64_t kw = 0; kw < kwSize; ++kw) {
3423       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3424         Value lhsVal = lhsVals[linearIndex(kw, w)];
3425         Value resVal = resVals[w];
3426         if (flatten) {
3427           // Flatten the input and output vectors (collapse the channel
3428           // dimension)
3429           lhsVal = rewriter.create<vector::ShapeCastOp>(
3430               loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3431           resVal = rewriter.create<vector::ShapeCastOp>(
3432               loc, resTypeAfterFlattening, resVals[w]);
3433         }
3434         resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3435                                                   rhsVals[kw], resVal, flatten);
3436         if (flatten) {
3437           // Un-flatten the output vector (restore the channel dimension)
3438           resVals[w] = rewriter.create<vector::ShapeCastOp>(
3439               loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3440         }
3441       }
3442     }
3443 
3444     // It's possible we failed to create the FMA.
3445     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3446       // Manually revert (in reverse order) to avoid leaving a bad IR state.
3447       for (auto &collection :
3448            {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3449         for (Value v : collection)
3450           rewriter.eraseOp(v.getDefiningOp());
3451       return rewriter.notifyMatchFailure(op, "failed to create FMA");
3452     }
3453 
3454     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3455     // This does not depend on kw.
3456     for (int64_t w = 0; w < wSize; w += wSizeStep) {
3457       maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3458           loc, resVals[w], maybeMaskedRes->getResult(0),
3459           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3460           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3461     }
3462     //===------------------------------------------------------------------===//
3463     // End vector-only rewrite part
3464     //===------------------------------------------------------------------===//
3465 
3466     // Write back res slice of size {n, w, c} @ [0, 0, 0].
3467     Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3468         loc, maybeMaskedRes->getResult(0), resShaped,
3469         ValueRange{zero, zero, zero});
3470     return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3471                            resOut);
3472   }
3473 
3474   /// Lower:
3475   ///   *  lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3476   ///   *  lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3477   /// to MulAcc.
3478   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3479                                      Value lhs, Value rhs, Value res,
3480                                      bool flatten) {
3481     auto rhsTy = cast<ShapedType>(rhs.getType());
3482     auto resTy = cast<ShapedType>(res.getType());
3483 
3484     // TODO(suderman): Change this to use a vector.fma intrinsic.
3485     lhs = promote(rewriter, loc, lhs, resTy);
3486 
3487     if (flatten) {
3488       // NOTE: The following logic won't work for scalable vectors. For this
3489       // reason, "flattening" is not supported when shapes are dynamic (this
3490       // should be captured by one of the pre-conditions).
3491 
3492       // There are two options for handling the filter:
3493       //  * shape_cast(broadcast(filter))
3494       //  * broadcast(shuffle(filter))
3495       // Opt for the option without shape_cast to simplify the codegen.
3496       auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3497       auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3498 
3499       SmallVector<int64_t, 16> indices;
3500       for (int i = 0; i < resSize / rhsSize; ++i) {
3501         for (int j = 0; j < rhsSize; ++j)
3502           indices.push_back(j);
3503       }
3504 
3505       rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3506     }
3507     // Broadcast the filter to match the output vector
3508     rhs = rewriter.create<vector::BroadcastOp>(
3509         loc, resTy.clone(rhsTy.getElementType()), rhs);
3510 
3511     rhs = promote(rewriter, loc, rhs, resTy);
3512 
3513     if (!lhs || !rhs)
3514       return nullptr;
3515 
3516     if (isa<FloatType>(resTy.getElementType()))
3517       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3518 
3519     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3520     return rewriter.create<arith::AddIOp>(loc, mul, res);
3521   }
3522 
3523   /// Entry point for non-channeled convolution:
3524   ///   {{w + kw}, {kw}, {w}}
3525   FailureOr<Operation *> generateNonChanneledConv() {
3526     AffineExpr w, kw;
3527     bindDims(ctx, w, kw);
3528     if (!iters({Par(), Red()}))
3529       return rewriter.notifyMatchFailure(op,
3530                                          "failed to match conv::W 1-par 1-red");
3531 
3532     // No transposition needed.
3533     if (layout({/*lhsIndex*/ {w + kw},
3534                 /*rhsIndex*/ {kw},
3535                 /*resIndex*/ {w}}))
3536       return conv(Conv1DOpOrder::W);
3537 
3538     return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3539   }
3540 
3541   /// Entry point that transposes into the common form:
3542   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3543   FailureOr<Operation *> generateNwcConv() {
3544     AffineExpr n, w, f, kw, c;
3545     bindDims(ctx, n, w, f, kw, c);
3546     if (!iters({Par(), Par(), Par(), Red(), Red()}))
3547       return rewriter.notifyMatchFailure(
3548           op, "failed to match conv::Nwc 3-par 2-red");
3549 
3550     // No transposition needed.
3551     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3552                 /*rhsIndex*/ {kw, c, f},
3553                 /*resIndex*/ {n, w, f}}))
3554       return conv(Conv1DOpOrder::Nwc);
3555 
3556     return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3557   }
3558 
3559   /// Entry point that transposes into the common form:
3560   ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3561   FailureOr<Operation *> generateNcwConv() {
3562     AffineExpr n, w, f, kw, c;
3563     bindDims(ctx, n, f, w, c, kw);
3564     if (!iters({Par(), Par(), Par(), Red(), Red()}))
3565       return rewriter.notifyMatchFailure(
3566           op, "failed to match conv::Ncw 3-par 2-red");
3567 
3568     if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3569                 /*rhsIndex*/ {f, c, kw},
3570                 /*resIndex*/ {n, f, w}}))
3571       return conv(Conv1DOpOrder::Ncw);
3572 
3573     return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3574   }
3575 
3576   /// Entry point that transposes into the common form:
3577   ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3578   FailureOr<Operation *> generateNwcPooling() {
3579     AffineExpr n, w, c, kw;
3580     bindDims(ctx, n, w, c, kw);
3581     if (!iters({Par(), Par(), Par(), Red()}))
3582       return rewriter.notifyMatchFailure(op,
3583                                          "failed to match pooling 3-par 1-red");
3584 
3585     // No transposition needed.
3586     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3587                 /*rhsIndex*/ {kw},
3588                 /*resIndex*/ {n, w, c}}))
3589       return conv(Conv1DOpOrder::Nwc);
3590 
3591     return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3592   }
3593 
3594   /// Entry point that transposes into the common form:
3595   ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3596   FailureOr<Operation *> generateNcwPooling() {
3597     AffineExpr n, w, c, kw;
3598     bindDims(ctx, n, c, w, kw);
3599     if (!iters({Par(), Par(), Par(), Red()}))
3600       return rewriter.notifyMatchFailure(op,
3601                                          "failed to match pooling 3-par 1-red");
3602 
3603     if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3604                 /*rhsIndex*/ {kw},
3605                 /*resIndex*/ {n, c, w}}))
3606       return conv(Conv1DOpOrder::Ncw);
3607 
3608     return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3609   }
3610 
3611   /// Entry point that transposes into the common form:
3612   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3613   FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3614                                              bool vecChDimScalableFlag = false,
3615                                              bool flatten = false) {
3616     AffineExpr n, w, c, kw;
3617     bindDims(ctx, n, w, c, kw);
3618     if (!iters({Par(), Par(), Par(), Red()}))
3619       return rewriter.notifyMatchFailure(
3620           op, "failed to match depthwise::Nwc conv 3-par 1-red");
3621 
3622     // No transposition needed.
3623     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3624                 /*rhsIndex*/ {kw, c},
3625                 /*resIndex*/ {n, w, c}}))
3626       return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3627 
3628     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3629   }
3630 
3631 private:
3632   enum OperKind { Conv, Pool };
3633   bool valid = false;
3634   OperKind oper = Conv;
3635   StringAttr redOp;
3636   StringAttr poolExtOp;
3637   bool isPoolExt = false;
3638   int strideW, dilationW;
3639   Value lhsShaped, rhsShaped, resShaped;
3640   ShapedType lhsShapedType, rhsShapedType, resShapedType;
3641 
3642   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3643   // Returns true iff it is a valid conv/pooling op.
3644   // If the region has 2 ops (reduction + yield) or 3 ops (extension +
3645   // reduction + yield) and rhs is not used, then it is the body of a pooling.
3646   // If conv, check for a single `mul` predecessor. The `mul` operands must be
3647   // block arguments or extensions of block arguments.
3648   // Otherwise, check for one or zero `ext` predecessors. The `ext` operands
3649   // must be block arguments or extensions of block arguments.
3650   bool setOperKind(Operation *reduceOp) {
3651     int numBlockArguments =
3652         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3653     switch (numBlockArguments) {
3654     case 1: {
3655       // Will be a convolution if the feeder is a MulOp.
3656       // Otherwise, it can be pooling.
3657       auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3658                                          llvm::IsaPred<BlockArgument>);
3659       Operation *feedOp = (*feedValIt).getDefiningOp();
3660       if (isCastOfBlockArgument(feedOp)) {
3661         oper = Pool;
3662         isPoolExt = true;
3663         poolExtOp = feedOp->getName().getIdentifier();
3664       } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
3665                    llvm::all_of(feedOp->getOperands(), [](Value v) {
3666                      if (isa<BlockArgument>(v))
3667                        return true;
3668                      if (Operation *op = v.getDefiningOp())
3669                        return isCastOfBlockArgument(op);
3670                      return false;
3671                    }))) {
3672         return false;
3673       }
3674       return true;
3675     }
3676     case 2:
3677       // Must be pooling
3678       oper = Pool;
3679       isPoolExt = false;
3680       return true;
3681     default:
3682       return false;
3683     }
3684   }
3685 };
3686 } // namespace
3687 
3688 /// Helper function to vectorize a LinalgOp with convolution semantics.
3689 // TODO: extend the generic vectorization to support windows and drop this.
3690 static FailureOr<Operation *> vectorizeConvolution(
3691     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3692     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3693   // The ConvolutionOpInterface gives us guarantees of existence for
3694   // strides/dilations. However, we do not need to rely on those; we simply
3695   // use them if present, otherwise use the defaults and let the generic conv
3696   // matcher in the Conv1DGenerator succeed or fail.
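  // For example (illustrative), a linalg.conv_1d_nwc_wcf op might carry
  //   dilations = dense<1> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
  // in which case `stride` below is 2 and `dilation` is 1.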
3697   auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3698   auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3699   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3700   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3701   Conv1DGenerator e(rewriter, op, stride, dilation);
3702   auto res = e.generateNonChanneledConv();
3703   if (succeeded(res))
3704     return res;
3705   res = e.generateNwcConv();
3706   if (succeeded(res))
3707     return res;
3708   res = e.generateNcwConv();
3709   if (succeeded(res))
3710     return res;
3711   res = e.generateNwcPooling();
3712   if (succeeded(res))
3713     return res;
3714   res = e.generateNcwPooling();
3715   if (succeeded(res))
3716     return res;
3717 
3718   // Only depthwise 1D NWC convs are left - these can be vectorized using masks
3719   // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
3720   // masked/scalable) is the channel dim (i.e. the trailing dim).
3721   uint64_t vecChDimSize = ShapedType::kDynamic;
3722   bool vecChDimScalableFlag = false;
3723   if (!inputVecSizes.empty()) {
3724     // Only use the input vector size corresponding to the channel dim. Other
3725     // vector dims will be inferred from the Ops.
3726     assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3727             isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3728            "Not a 1D depthwise conv!");
3729     size_t chDimIdx =
3730         TypeSwitch<Operation *, size_t>(op)
3731             .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3732             .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3733 
3734     vecChDimSize = inputVecSizes[chDimIdx];
3735     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3736   }
3737   return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3738                                flatten1DDepthwiseConv);
3739 }
3740 
3741 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
3742   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3743 
3744   LogicalResult matchAndRewrite(LinalgOp op,
3745                                 PatternRewriter &rewriter) const override {
3746     FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3747     if (failed(resultOrFail))
3748       return failure();
3749     Operation *newOp = *resultOrFail;
3750     if (newOp->getNumResults() == 0) {
3751       rewriter.eraseOp(op.getOperation());
3752       return success();
3753     }
3754     assert(newOp->getNumResults() == 1 && "expected single result");
3755     rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3756     return success();
3757   }
3758 };
3759 
3760 void mlir::linalg::populateConvolutionVectorizationPatterns(
3761     RewritePatternSet &patterns, PatternBenefit benefit) {
3762   patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3763 }
3764