xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (revision 1276ce9e9713b2a0802004676fad7e40980396d5)
1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "mlir/Dialect/Affine/Utils.h"
13 
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Tensor/Utils/Utils.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypeInterfaces.h"
31 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/IR/OpDefinition.h"
33 #include "mlir/IR/PatternMatch.h"
34 #include "mlir/Support/LLVM.h"
35 #include "mlir/Transforms/RegionUtils.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 #include <type_traits>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 #define DEBUG_TYPE "linalg-vectorization"
51 
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 
55 /// Try to vectorize `convOp` as a convolution.
56 static FailureOr<Operation *>
57 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58                      ArrayRef<int64_t> inputVecSizes = {},
59                      ArrayRef<bool> inputVecScalableFlags = {},
60                      bool flatten1DDepthwiseConv = false);
61 
62 /// Return the unique instance of OpType in `block` if it is indeed unique.
63 /// Return null if zero or more than one instance exists.
64 template <typename OpType>
65 static OpType getSingleOpOfType(Block &block) {
66   OpType res;
67   block.walk([&](OpType op) {
68     if (res) {
69       res = nullptr;
70       return WalkResult::interrupt();
71     }
72     res = op;
73     return WalkResult::advance();
74   });
75   return res;
76 }
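
// Illustrative usage (added for exposition; not part of the upstream file, and
// the `body` block below is hypothetical):
//
//   if (auto addOp = getSingleOpOfType<arith::AddFOp>(body))
//     ...; // `body` contains exactly one arith.addf; otherwise `addOp` is null.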
77 
78 /// Helper function to extract the input slices after the filter is unrolled
79 /// along kw.
80 static SmallVector<Value>
81 extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
82                        int64_t nSize, int64_t wSize, int64_t cSize,
83                        int64_t kwSize, int strideW, int dilationW,
84                        int64_t wSizeStep, bool isSingleChanneled) {
85   SmallVector<Value> result;
86   if (isSingleChanneled) {
87     // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
88     // convolution.
89     SmallVector<int64_t> sizes{wSizeStep};
90     SmallVector<int64_t> strides{1};
91     for (int64_t kw = 0; kw < kwSize; ++kw) {
92       for (int64_t w = 0; w < wSize; w += wSizeStep) {
93         result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
94             loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
95       }
96     }
97   } else {
98     // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
99     // for channeled convolution.
100     SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
101     SmallVector<int64_t> strides{1, 1, 1};
102     for (int64_t kw = 0; kw < kwSize; ++kw) {
103       for (int64_t w = 0; w < wSize; w += wSizeStep) {
104         result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
105             loc, input,
106             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
107             sizes, strides));
108       }
109     }
110   }
111   return result;
112 }
113 
114 /// Helper function to extract the filter slices after the filter is unrolled
115 /// along kw.
116 static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
117                                                   Location loc, Value filter,
118                                                   int64_t kwSize) {
119   SmallVector<Value> result;
120   // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
121   // non-channeled convolution] @ [kw].
122   for (int64_t kw = 0; kw < kwSize; ++kw) {
123     result.push_back(rewriter.create<vector::ExtractOp>(
124         loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
125   }
126   return result;
127 }
128 
129 /// Helper function to extract the result slices after the filter is unrolled
130 /// along kw.
131 static SmallVector<Value>
132 extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
133                         int64_t nSize, int64_t wSize, int64_t fSize,
134                         int64_t wSizeStep, bool isSingleChanneled) {
135   SmallVector<Value> result;
136   if (isSingleChanneled) {
137     // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
138     SmallVector<int64_t> sizes{wSizeStep};
139     SmallVector<int64_t> strides{1};
140     for (int64_t w = 0; w < wSize; w += wSizeStep) {
141       result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
142           loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
143     }
144   } else {
145     // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
146     // convolution.
147     SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
148     SmallVector<int64_t> strides{1, 1, 1};
149     for (int64_t w = 0; w < wSize; w += wSizeStep) {
150       result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
151           loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
152     }
153   }
154   return result;
155 }
156 
157 /// Helper function to insert the computed result slices.
158 static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
159                                     Value res, int64_t wSize, int64_t wSizeStep,
160                                     SmallVectorImpl<Value> &resVals,
161                                     bool isSingleChanneled) {
162 
163   if (isSingleChanneled) {
164     // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
165     // This does not depend on kw.
166     SmallVector<int64_t> strides{1};
167     for (int64_t w = 0; w < wSize; w += wSizeStep) {
168       res = rewriter.create<vector::InsertStridedSliceOp>(
169           loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
170     }
171   } else {
172     // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
173     // convolution. This does not depend on kw.
174     SmallVector<int64_t> strides{1, 1, 1};
175     for (int64_t w = 0; w < wSize; w += wSizeStep) {
176       res = rewriter.create<vector::InsertStridedSliceOp>(
177           loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
178           strides);
179     }
180   }
181   return res;
182 }
183 
184 /// Contains the vectorization state and related methods used across the
185 /// vectorization process of a given operation.
186 struct VectorizationState {
187   VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
188 
189   /// Initializes the vectorization state, including the computation of the
190   /// canonical vector shape for vectorization.
191   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
192                           ArrayRef<int64_t> inputVectorSizes,
193                           ArrayRef<bool> inputScalableVecDims);
194 
195   /// Returns the canonical vector shape used to vectorize the iteration space.
196   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
197 
198   /// Returns the vector dimensions that are scalable in the canonical vector
199   /// shape.
200   ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
201 
202   /// Returns a vector type of the provided `elementType` with the canonical
203 /// vector shape and the corresponding fixed/scalable dimension bits. If
204   /// `dimPermutation` is provided, the canonical vector dimensions are permuted
205   /// accordingly.
206   VectorType getCanonicalVecType(
207       Type elementType,
208       std::optional<AffineMap> dimPermutation = std::nullopt) const {
209     SmallVector<int64_t> vectorShape;
210     SmallVector<bool> scalableDims;
211     if (dimPermutation.has_value()) {
212       vectorShape =
213           applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
214       scalableDims =
215           applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
216     } else {
217       vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
218       scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
219     }
220 
221     return VectorType::get(vectorShape, elementType, scalableDims);
222   }
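
  // Illustrative example (added for exposition; not part of the upstream
  // file): with a canonical vector shape of [4, 8, 16] and a permutation map
  //   (d0, d1, d2) -> (d2, d0)
  // getCanonicalVecType(f32, map) would return vector<16x4xf32>, since each
  // map result selects the corresponding canonical dimension.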
223 
224   /// Masks an operation with the canonical vector mask if the operation needs
225   /// masking. Returns the masked operation or the original operation if masking
226   /// is not needed. If provided, the canonical mask for this operation is
227   /// permuted using `maybeIndexingMap`.
228   Operation *
229   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
231 
232 private:
233   /// Initializes the iteration space static sizes using the Linalg op
234   /// information. This may become more complicated in the future.
235   void initIterSpaceStaticSizes(LinalgOp linalgOp) {
236     iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
237   }
238 
239   /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
240   /// all the static and dynamic dimensions of the iteration space to be
241   /// vectorized and stores them in `iterSpaceValueSizes`.
242   LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
243                                               LinalgOp linalgOp);
244 
245   /// Create or retrieve an existing mask value to mask `opToMask` in the
246   /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
247   /// mask is permuted using that permutation map. If a new mask is created, it
248   /// will be cached for future users.
249   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
250                            LinalgOp linalgOp,
251                            std::optional<AffineMap> maybeMaskingMap);
252 
253   /// Holds the compile-time static sizes of the iteration space to vectorize.
254   /// Dynamic dimensions are represented using ShapedType::kDynamic.
255   SmallVector<int64_t> iterSpaceStaticSizes;
256 
257   /// Holds the value sizes of the iteration space to vectorize. Static
258   /// dimensions are represented by 'arith.constant' and dynamic
259   /// dimensions by 'tensor/memref.dim'.
260   SmallVector<Value> iterSpaceValueSizes;
261 
262   /// Holds the canonical vector shape used to vectorize the iteration space.
263   SmallVector<int64_t> canonicalVecShape;
264 
265   /// Holds the vector dimensions that are scalable in the canonical vector
266   /// shape.
267   SmallVector<bool> scalableVecDims;
268 
269   /// Holds the active masks for permutations of the canonical vector iteration
270   /// space.
271   DenseMap<AffineMap, Value> activeMaskCache;
272 
273   /// Global vectorization guard for the incoming rewriter. It's initialized
274   /// when the vectorization state is initialized.
275   OpBuilder::InsertionGuard rewriterGuard;
276 };
277 
278 LogicalResult
279 VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
280                                                   LinalgOp linalgOp) {
281   // TODO: Support 0-d vectors.
282   for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
283     if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
284       // Create constant index op for static dimensions.
285       iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
286           linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
287       continue;
288     }
289 
290     // Find an operand defined on this dimension of the iteration space to
291     // extract the runtime dimension size.
292     Value operand;
293     unsigned operandDimPos;
294     if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
295                                                          operandDimPos)))
296       return failure();
297 
298     Value dynamicDim = linalgOp.hasPureTensorSemantics()
299                            ? (Value)rewriter.create<tensor::DimOp>(
300                                  linalgOp.getLoc(), operand, operandDimPos)
301                            : (Value)rewriter.create<memref::DimOp>(
302                                  linalgOp.getLoc(), operand, operandDimPos);
303     iterSpaceValueSizes.push_back(dynamicDim);
304   }
305 
306   return success();
307 }
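
// Illustrative result (added for exposition; not part of the upstream file):
// for a linalg op whose static loop ranges are (8, ShapedType::kDynamic, 16),
// the loop above populates `iterSpaceValueSizes` with an arith.constant 8, a
// tensor/memref.dim on the operand tied to the dynamic dimension, and an
// arith.constant 16, so that masks can later be built from runtime bounds.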
308 
309 /// Initializes the vectorization state, including the computation of the
310 /// canonical vector shape for vectorization.
311 // TODO: Move this to the constructor when we can remove the failure cases.
312 LogicalResult
313 VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
314                               ArrayRef<int64_t> inputVectorSizes,
315                               ArrayRef<bool> inputScalableVecDims) {
316   // Initialize the insertion point.
317   rewriter.setInsertionPoint(linalgOp);
318 
319   if (!inputVectorSizes.empty()) {
320     // Get the canonical vector shape from the input vector sizes provided. This
321     // path should be taken to vectorize code with dynamic shapes and when using
322     // vector sizes greater than the iteration space sizes.
323     canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
324     scalableVecDims.append(inputScalableVecDims.begin(),
325                            inputScalableVecDims.end());
326   } else {
327     // Compute the canonical vector shape from the operation shape. If there are
328     // dynamic shapes, the operation won't be vectorized. We assume all the
329     // vector dimensions are fixed.
330     canonicalVecShape = linalgOp.getStaticLoopRanges();
331     scalableVecDims.append(linalgOp.getNumLoops(), false);
332   }
333 
334   LDBG("Canonical vector shape: ");
335   LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
336   LLVM_DEBUG(llvm::dbgs() << "\n");
337   LDBG("Scalable vector dims: ");
338   LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
339   LLVM_DEBUG(llvm::dbgs() << "\n");
340 
341   if (ShapedType::isDynamicShape(canonicalVecShape))
342     return failure();
343 
344   // Initialize iteration space static sizes.
345   initIterSpaceStaticSizes(linalgOp);
346 
347   // Generate 'arith.constant' and 'tensor/memref.dim' operations for
348   // all the static and dynamic dimensions of the iteration space, needed to
349   // compute a mask during vectorization.
350   if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
351     return failure();
352 
353   return success();
354 }
355 
356 /// Create or retrieve an existing mask value to mask `opToMask` in the
357 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
358 /// mask is permuted using that permutation map. If a new mask is created, it
359 /// will be cached for future users.
360 Value VectorizationState::getOrCreateMaskFor(
361     RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
362     std::optional<AffineMap> maybeMaskingMap) {
363   // No mask is needed if the operation is not maskable.
364   auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
365   if (!maskableOp)
366     return Value();
367 
368   assert(!maskableOp.isMasked() &&
369          "Masking an operation that is already masked");
370 
371   // If no masking map was provided, use an identity map with the loop dims.
372   assert((!maybeMaskingMap || *maybeMaskingMap) &&
373          "Unexpected null mask permutation map");
374   AffineMap maskingMap =
375       maybeMaskingMap ? *maybeMaskingMap
376                       : AffineMap::getMultiDimIdentityMap(
377                             linalgOp.getNumLoops(), rewriter.getContext());
378 
379   LDBG("Masking map: " << maskingMap << "\n");
380 
381   // Return the active mask for the masking map of this operation if it was
382   // already created.
383   auto activeMaskIt = activeMaskCache.find(maskingMap);
384   if (activeMaskIt != activeMaskCache.end()) {
385     Value mask = activeMaskIt->second;
386     LDBG("Reusing mask: " << mask << "\n");
387     return mask;
388   }
389 
390   // Compute permuted projection of the iteration space to be masked and the
391   // corresponding mask shape. If the resulting iteration space dimensions are
392   // static and identical to the mask shape, masking is not needed for this
393   // operation.
394   // TODO: Improve this check. Only projected permutation indexing maps are
395   // supported.
396   SmallVector<int64_t> permutedStaticSizes =
397       applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
398   auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
399   auto maskShape = maskType.getShape();
400 
401   LDBG("Mask shape: ");
402   LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
403   LLVM_DEBUG(llvm::dbgs() << "\n");
404 
405   if (permutedStaticSizes == maskShape) {
406     LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
407     activeMaskCache[maskingMap] = Value();
408     return Value();
409   }
410 
411   // Permute the iteration space value sizes to compute the mask upper bounds.
412   SmallVector<Value> upperBounds =
413       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
414   assert(!maskShape.empty() && !upperBounds.empty() &&
415          "Masked 0-d vectors are not supported yet");
416 
417   // Create the mask based on the dimension values.
418   Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
419                                                      maskType, upperBounds);
420   LDBG("Creating new mask: " << mask << "\n");
421   activeMaskCache[maskingMap] = mask;
422   return mask;
423 }
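
// Illustrative example (added for exposition; not part of the upstream file,
// and the SSA names are hypothetical): assuming a canonical vector shape of
// [8, 16] over an iteration space (%d0, 16) whose first dimension is dynamic,
// with an identity masking map, the helper emits roughly:
//
//   %mask = vector.create_mask %d0, %c16 : vector<8x16xi1>
//
// and caches it under that masking map for reuse by later operations.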
424 
425 Operation *
426 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
427                                   LinalgOp linalgOp,
428                                   std::optional<AffineMap> maybeIndexingMap) {
429   LDBG("Trying to mask: " << *opToMask << "\n");
430 
431   std::optional<AffineMap> maybeMaskingMap = std::nullopt;
432   // The Operand indexing map may contain "zero" results, e.g.:
433   //    (d0, d1, d2, d3) -> (d0, d1, d2, 0)
434   // When applied to canonical vector shapes like these:
435   //    (1, 16, 16, 4)
436   // we would get:
437   //    (1, 16, 16, 0)
438   // Instead, we should extract the following permutation map for masking:
439   //    (d0, d1, d2, d3) -> (d0, d1, d2)
440   // This way, the corresponding vector/mask type will be:
441   //    vector<1x16x16xty>
442   // rather than:
443   //    vector<1x16x16x0xty>
444   if (maybeIndexingMap)
445     maybeMaskingMap = maybeIndexingMap->dropZeroResults();
446 
447   // Create or retrieve mask for this operation.
448   Value mask =
449       getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
450 
451   if (!mask) {
452     LDBG("No mask required\n");
453     return opToMask;
454   }
455 
456   // Wrap the operation with a new `vector.mask` and update D-U chain.
457   assert(opToMask && "Expected a valid operation to mask");
458   auto maskOp = cast<vector::MaskOp>(
459       mlir::vector::maskOperation(rewriter, opToMask, mask));
460   Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
461 
462   for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
463     rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
464                                   maskOpTerminator);
465 
466   LDBG("Masked operation: " << *maskOp << "\n");
467   return maskOp;
468 }
469 
470 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
471 /// projectedPermutation, compress the unused dimensions to serve as a
472 /// permutation_map for a vector transfer operation.
473 /// For example, given a linalg op such as:
474 ///
475 /// ```
476 ///   %0 = linalg.generic {
477 ///        indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
478 ///                         affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>]
479 ///      }
480 ///     ins(%0 : tensor<2x3x4xf32>)
481 ///    outs(%1 : tensor<5x6xf32>)
482 /// ```
483 ///
484 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
485 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
486 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
487 static AffineMap reindexIndexingMap(AffineMap map) {
488   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
489          "expected projected permutation");
490   auto res = compressUnusedDims(map);
491   assert(res.getNumDims() ==
492              (res.getNumResults() - res.getNumOfZeroResults()) &&
493          "expected reindexed map with same number of dims and results");
494   return res;
495 }
496 
497 /// Helper enum to represent conv1d input traversal order.
498 enum class Conv1DOpOrder {
499   W,   // Corresponds to non-channeled 1D convolution operation.
500   Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
501   Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
502 };
503 
504 /// Helper data structure to represent the result of vectorization.
505 /// In certain specific cases, like terminators, we do not want to propagate.
506 enum VectorizationStatus {
507   /// Op failed to vectorize.
508   Failure = 0,
509   /// Op vectorized and custom function took care of replacement logic
510   NoReplace,
511   /// Op vectorized into a new Op whose results will replace original Op's
512   /// results.
513   NewOp
514   // TODO: support values if Op vectorized to Many-Ops whose results we need to
515   // aggregate for replacement.
516 };
517 struct VectorizationResult {
518   /// Return status from vectorizing the current op.
519   enum VectorizationStatus status = VectorizationStatus::Failure;
520   /// New vectorized operation to replace the current op.
521   /// Replacement behavior is specified by `status`.
522   Operation *newOp;
523 };
524 
525 std::optional<vector::CombiningKind>
526 mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
527   using ::mlir::vector::CombiningKind;
528 
529   if (!combinerOp)
530     return std::nullopt;
531   return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
532       .Case<arith::AddIOp, arith::AddFOp>(
533           [&](auto op) { return CombiningKind::ADD; })
534       .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
535       .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
536       .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
537       .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
538       .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
539       .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
540       .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
541       .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
542       .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
543       .Case<arith::MulIOp, arith::MulFOp>(
544           [&](auto op) { return CombiningKind::MUL; })
545       .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
546       .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
547       .Default([&](auto op) { return std::nullopt; });
548 }
549 
550 /// Check whether `outputOperand` is a reduction with a single combiner
551 /// operation. Return the combiner operation of the reduction. Return
552 /// nullptr otherwise. Multiple reduction operations would impose an
553 /// ordering between reduction dimensions, which is currently unsupported in
554 /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
555 /// max(min(X)).
556 // TODO: use in LinalgOp verification, there is a circular dependency atm.
557 static Operation *matchLinalgReduction(OpOperand *outputOperand) {
558   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
559   unsigned outputPos =
560       outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
561   // Only single combiner operations are supported for now.
562   SmallVector<Operation *, 4> combinerOps;
563   if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
564       combinerOps.size() != 1)
565     return nullptr;
566 
567   // Return the combiner operation.
568   return combinerOps[0];
569 }
570 
571 /// Broadcast `value` to a vector of type `dstType` if possible. Return `value`
572 /// otherwise.
573 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
574   auto dstVecType = dyn_cast<VectorType>(dstType);
575   // If no shape to broadcast to, just return `value`.
576   if (dstVecType.getRank() == 0)
577     return value;
578   if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
579       vector::BroadcastableToResult::Success)
580     return value;
581   Location loc = b.getInsertionPoint()->getLoc();
582   return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
583 }
584 
585 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
586 /// assumes that `reductionOp` has two operands and one of them is the reduction
587 /// initial value.
588 // Note: this is a true builder that notifies the OpBuilder listener.
589 // TODO: Consider moving as a static helper on the ReduceOp.
590 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
591                                       Value valueToReduce, Value acc,
592                                       ArrayRef<bool> dimsToMask) {
593   auto maybeKind = getCombinerOpKind(reduceOp);
594   assert(maybeKind && "Failed precondition: could not get reduction kind");
595   return b.create<vector::MultiDimReductionOp>(
596       reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
597 }
598 
599 static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
600   return llvm::to_vector(
601       llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
602 }
603 
604 /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
605 /// reduction iterator.
606 static bool hasReductionIterator(LinalgOp &op) {
607   return isa<linalg::ReduceOp>(op) ||
608          (isa<linalg::GenericOp>(op) &&
609           llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
610 }
611 
612 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
613 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
614 /// currently being vectorized. If `dest` has rank 0, build a memref.store.
615 /// Return the produced value or null if no value is produced.
616 // Note: this is a true builder that notifies the OpBuilder listener.
617 // TODO: Consider moving as a static helper on the ReduceOp.
618 static Value buildVectorWrite(RewriterBase &rewriter, Value value,
619                               OpOperand *outputOperand,
620                               VectorizationState &state) {
621   Location loc = value.getLoc();
622   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
623   AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
624 
625   // Compute the vector type of the value to store. This type should be an
626   // identity or projection of the canonical vector type without any permutation
627   // applied, given that any permutation in a transfer write happens as part of
628   // the write itself.
629   AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
630       opOperandMap.getContext(), opOperandMap.getNumInputs(),
631       [&](AffineDimExpr dimExpr) -> bool {
632         return llvm::is_contained(opOperandMap.getResults(), dimExpr);
633       });
634   auto vectorType = state.getCanonicalVecType(
635       getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
636 
637   Operation *write;
638   if (vectorType.getRank() > 0) {
639     AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
640     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
641                                rewriter.create<arith::ConstantIndexOp>(loc, 0));
642     value = broadcastIfNeeded(rewriter, value, vectorType);
643     assert(value.getType() == vectorType && "Incorrect type");
644     write = rewriter.create<vector::TransferWriteOp>(
645         loc, value, outputOperand->get(), indices, writeMap);
646   } else {
647     // 0-d case is still special: do not invert the reindexing writeMap.
648     if (!isa<VectorType>(value.getType()))
649       value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
650     assert(value.getType() == vectorType && "Incorrect type");
651     write = rewriter.create<vector::TransferWriteOp>(
652         loc, value, outputOperand->get(), ValueRange{});
653   }
654 
655   write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
656 
657   // If masked, set in-bounds to true. Masking guarantees that the access will
658   // be in-bounds.
659   if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
660     auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
661     SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
662     maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
663   }
664 
665   LDBG("vectorized op: " << *write << "\n");
666   if (!write->getResults().empty())
667     return write->getResult(0);
668   return Value();
669 }
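
// Illustrative output (added for exposition; not part of the upstream file,
// and the SSA names are hypothetical): for an output operand %out of type
// tensor<8x16xf32> with an identity indexing map and canonical vector shape
// [8, 16], the helper builds roughly:
//
//   %w = vector.transfer_write %val, %out[%c0, %c0]
//        : vector<8x16xf32>, tensor<8x16xf32>
//
// Since the static iteration space matches the mask shape, no vector.mask
// wrapper is added around the write.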
670 
671 // Custom vectorization precondition function type. This is intended to be used
672 // with CustomVectorizationHook. Returns success if the corresponding custom
673 // hook can vectorize the op.
674 using CustomVectorizationPrecondition =
675     std::function<LogicalResult(Operation *, bool)>;
676 
677 // Custom vectorization function type. Produce a vector form of Operation*
678 // assuming all its vectorized operands are already in the IRMapping.
679 // Return nullptr if the Operation cannot be vectorized.
680 using CustomVectorizationHook =
681     std::function<VectorizationResult(Operation *, const IRMapping &)>;
682 
683 /// Helper function to vectorize the terminator of a `linalgOp`. New result
684 /// vector values are appended to `newResults`. Return
685 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
686 /// should not try to map produced operations and instead return the results
687 /// using the `newResults` vector making them available to the vectorization
688 /// algorithm for RAUW. This function is meant to be used as a
689 /// CustomVectorizationHook.
690 static VectorizationResult
691 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
692                      const IRMapping &bvm, VectorizationState &state,
693                      LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
694   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
695   if (!yieldOp)
696     return VectorizationResult{VectorizationStatus::Failure, nullptr};
697   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
698     // TODO: Scan for an opportunity for reuse.
699     // TODO: use a map.
700     Value vectorValue = bvm.lookup(output.value());
701     Value newResult =
702         buildVectorWrite(rewriter, vectorValue,
703                          linalgOp.getDpsInitOperand(output.index()), state);
704     if (newResult)
705       newResults.push_back(newResult);
706   }
707 
708   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
709 }
710 
711 /// Helper function to vectorize the index operations of a `linalgOp`. Return
712 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
713 /// should map the produced operations. This function is meant to be used as a
714 /// CustomVectorizationHook.
715 static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
716                                                 VectorizationState &state,
717                                                 Operation *op,
718                                                 LinalgOp linalgOp) {
719   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
720   if (!indexOp)
721     return VectorizationResult{VectorizationStatus::Failure, nullptr};
722   auto loc = indexOp.getLoc();
723   // Compute the static loop sizes of the index op.
724   ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
725   auto dim = indexOp.getDim();
726   // Compute a one-dimensional index vector for the index op dimension.
727   auto indexVectorType =
728       VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
729                       state.getScalableVecDims()[dim]);
730   auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
731   // Return the one-dimensional index vector if it lives in the trailing
732   // dimension of the iteration space since the vectorization algorithm in this
733   // case can handle the broadcast.
734   if (dim == targetShape.size() - 1)
735     return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
736   // Otherwise permute the targetShape to move the index dimension last,
737   // broadcast the one-dimensional index vector to the permuted shape, and
738   // finally transpose the broadcasted index vector to undo the permutation.
739   auto permPattern =
740       llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
741   std::swap(permPattern[dim], permPattern.back());
742   auto permMap =
743       AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
744 
745   auto broadCastOp = rewriter.create<vector::BroadcastOp>(
746       loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
747       indexSteps);
748   SmallVector<int64_t> transposition =
749       llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
750   std::swap(transposition.back(), transposition[dim]);
751   auto transposeOp =
752       rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
753   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
754 }
755 
756 /// Helper function to check if the tensor.extract can be vectorized by the
757 /// custom hook vectorizeTensorExtract.
758 static LogicalResult
759 tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
760   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
761   if (!extractOp)
762     return failure();
763 
764   if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
765     return failure();
766 
767   // Check the index type, but only for non-0-d tensors (for which we do need
768   // access indices).
769   if (!extractOp.getIndices().empty()) {
770     if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
771       return failure();
772   }
773 
774   if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
775         return !VectorType::isValidElementType(type);
776       })) {
777     return failure();
778   }
779 
780   return success();
781 }
782 
783 /// Calculates the offsets (`$index_vec`) for `vector.gather` operations
784 /// generated from `tensor.extract`. The offset is calculated as follows
785 /// (example using scalar values):
786 ///
787 ///    offset = extractOp.indices[0]
788 ///    for (i = 1; i < numIndices; i++)
789 ///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
790 ///
791 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
792 ///  offset = ( ( 1 ) * 80 +  2 ) * 15  + 3
793 static Value calculateGatherOffset(RewriterBase &rewriter,
794                                    VectorizationState &state,
795                                    tensor::ExtractOp extractOp,
796                                    const IRMapping &bvm) {
797   // The vector of indices for GatherOp should be shaped as the output vector.
798   auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
799   auto loc = extractOp.getLoc();
800 
801   Value offset = broadcastIfNeeded(
802       rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
803 
804   const size_t numIndices = extractOp.getIndices().size();
805   for (size_t i = 1; i < numIndices; i++) {
806     Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
807 
808     auto dimSize = broadcastIfNeeded(
809         rewriter,
810         rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
811         indexVecType);
812 
813     offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
814 
815     auto extractOpIndex = broadcastIfNeeded(
816         rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
817 
818     offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
819   }
820 
821   return offset;
822 }
823 
824 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
825 
826 /// Find the index of the trailing non-unit dim in linalgOp. This hook is used
827 /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
828 /// represents a contiguous load operation.
829 ///
830 /// Note that when calling this hook, it is assumed that the output vector is
831 /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
832 /// labelled as a gather load before entering this method.
833 ///
834 /// Following on from the above, it is assumed that:
835 ///   * for statically shaped loops, when no masks are used, only one dim is !=
836 ///   1 (that's what the shape of the output vector is based on).
837 ///   * for dynamically shaped loops, there might be more non-unit dims
838 ///   as the output vector type is user-specified.
839 ///
840 /// TODO: Statically shaped loops + vector masking
841 static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
842   SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
843   assert((linalgOp.hasDynamicShape() ||
844           llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) ==
845               1) &&
846          "For statically shaped Linalg Ops, only one "
847          "non-unit loop dim is expected");
848 
849   size_t idx = loopRanges.size() - 1;
850   for (; idx > 0; idx--)
851     if (loopRanges[idx] != 1)
852       break;
853 
854   return idx;
855 }
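
// Illustrative example (added for exposition; not part of the upstream file):
// for static loop ranges (1, 32, 1) the trailing non-unit loop dim index is 1;
// for (1, 1, 4) it is 2.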
856 
857 /// Checks whether `val` can be used for calculating a loop invariant index.
858 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
859                                VectorType resType) {
860 
861   assert(((llvm::count_if(resType.getShape(),
862                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
863          "n-D vectors are not yet supported");
864 
865   // Blocks outside _this_ linalg.generic are effectively loop invariant.
866   // However, analysing block arguments for _this_ linalg.generic Op is a bit
867   // tricky. Just bail out in the latter case.
868   // TODO: We could try analysing the corresponding affine map here.
869   auto *block = linalgOp.getBlock();
870   if (isa<BlockArgument>(val))
871     return llvm::all_of(block->getArguments(),
872                         [&val](Value v) { return (v != val); });
873 
874   Operation *defOp = val.getDefiningOp();
875   assert(defOp && "This is neither a block argument nor an operation result");
876 
877   // IndexOp is loop invariant as long as its result remains constant across
878   // iterations. Note that for dynamic shapes, the corresponding dim will also
879   // be conservatively treated as != 1.
880   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
881     return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
882   }
883 
884   auto *ancestor = block->findAncestorOpInBlock(*defOp);
885 
886   // Values defined outside `linalgOp` are loop invariant.
887   if (!ancestor)
888     return true;
889 
890   // Values defined inside `linalgOp`, which are constant, are loop invariant.
891   if (isa<arith::ConstantOp>(ancestor))
892     return true;
893 
894   bool result = true;
895   for (auto op : ancestor->getOperands())
896     result &= isLoopInvariantIdx(linalgOp, op, resType);
897 
898   return result;
899 }
900 
901 /// Check whether `val` could be used for calculating the trailing index for a
902 /// contiguous load operation.
903 ///
904 /// There are currently 3 types of values that are allowed here:
905 ///   1. loop-invariant values,
906 ///   2. values that increment by 1 with every loop iteration,
907 ///   3. results of basic arithmetic operations (linear and continuous)
908 ///      involving 1., 2. and 3.
909 /// This method returns true if indeed only such values are used in calculating
910 /// `val`.
911 ///
912 /// Additionally, the trailing index for a contiguous load operation should
913 /// increment by 1 with every loop iteration, i.e. be based on:
914 ///   * `linalg.index <dim>` ,
915 /// where <dim> is the trailing non-unit dim of the iteration space (this way,
916 /// `linalg.index <dim>` increments by 1 with every loop iteration).
917 /// `foundIndexOp` is updated to `true` when such Op is found.
918 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
919                                 bool &foundIndexOp, VectorType resType) {
920 
921   assert(((llvm::count_if(resType.getShape(),
922                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
923          "n-D vectors are not yet supported");
924 
925   // Blocks outside _this_ linalg.generic are effectively loop invariant.
926   // However, analysing block arguments for _this_ linalg.generic Op is a bit
927   // tricky. Just bail out in the latter case.
928   // TODO: We could try analysing the corresponding affine map here.
929   auto *block = linalgOp.getBlock();
930   if (isa<BlockArgument>(val))
931     return llvm::all_of(block->getArguments(),
932                         [&val](Value v) { return (v != val); });
933 
934   Operation *defOp = val.getDefiningOp();
935   assert(defOp && "This is neither a block argument nor an operation result");
936 
937   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
938     auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
939 
940     foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
941     return true;
942   }
943 
944   auto *ancestor = block->findAncestorOpInBlock(*defOp);
945 
946   if (!ancestor)
947     return false;
948 
949   // Conservatively reject Ops that could lead to indices with stride other
950   // than 1.
951   if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
952     return false;
953 
954   bool result = false;
955   for (auto op : ancestor->getOperands())
956     result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
957 
958   return result;
959 }
960 
961 /// Infer the memory access pattern for the input ExtractOp
962 ///
963 /// Based on the ExtractOp result shape and the access indices, decides whether
964 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
965 /// or a gather load. When analysing the ExtractOp indices (to identify
966 /// contiguous loads), this method looks for "loop" invariant indices (e.g.
967 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
968 /// Op).
969 ///
970 /// Note that it is always safe to use gather load operations for contiguous
971 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
972 /// that `extractOp` is a gather load.
973 static VectorMemoryAccessKind
974 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
975                                     LinalgOp &linalgOp, VectorType resType) {
976 
977   auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
978 
979   // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
980   if (inputShape.getShape().empty())
981     return VectorMemoryAccessKind::ScalarBroadcast;
982 
983   // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
984   // otherwise.
985   bool isOutput1DVector =
986       (llvm::count_if(resType.getShape(),
987                       [](int64_t dimSize) { return dimSize > 1; }) == 1);
988   // 1. Assume that it's a gather load when reading non-1D vector.
989   if (!isOutput1DVector)
990     return VectorMemoryAccessKind::Gather;
991 
992   bool leadingIdxsLoopInvariant = true;
993 
994   // 2. Analyze the leading indices of `extractOp`.
995   // Look at the way each index is calculated and decide whether it is suitable
996   // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
997   // gather load.
998   auto indices = extractOp.getIndices();
999   auto leadIndices = indices.drop_back(1);
1000 
1001   for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1002     if (inputShape.getShape()[i] == 1)
1003       continue;
1004 
1005     leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1006   }
1007 
1008   if (!leadingIdxsLoopInvariant) {
1009     LDBG("Found gather load: " << extractOp);
1010     return VectorMemoryAccessKind::Gather;
1011   }
1012 
1013   // 3. Analyze the trailing index for `extractOp`.
1014   // At this point we know that the leading indices are loop invariant. This
1015   // means that it is potentially a scalar or a contiguous load. We can decide
1016   // based on the trailing idx.
1017   auto extractOpTrailingIdx = indices.back();
1018 
1019   // 3a. Scalar broadcast load
1020   // If the trailing index is loop invariant then this is a scalar load.
1021   if (leadingIdxsLoopInvariant &&
1022       isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1023     LDBG("Found scalar broadcast load: " << extractOp);
1024 
1025     return VectorMemoryAccessKind::ScalarBroadcast;
1026   }
1027 
1028   // 3b. Contiguous loads
1029   // The trailing `extractOp` index should increment with every loop iteration.
1030   // This effectively means that it must be based on the trailing loop index.
1031   // This is what the following bool captures.
1032   bool foundIndexOp = false;
1033   bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1034                                               foundIndexOp, resType);
1035   // TODO: Support generating contiguous loads for column vectors - that will
1036   // require adding a permutation map to transfer_read Ops.
1037   bool isRowVector = resType.getShape().back() != 1;
1038   isContiguousLoad &= (foundIndexOp && isRowVector);
1039 
1040   if (isContiguousLoad) {
1041     LDBG("Found contigous load: " << extractOp);
1042     return VectorMemoryAccessKind::Contiguous;
1043   }
1044 
1045   // 4. Fallback case - gather load.
1046   LDBG("Found gather load: " << extractOp);
1047   return VectorMemoryAccessKind::Gather;
1048 }
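
// Illustrative classification (added for exposition; the IR below is
// hypothetical): inside a linalg.generic with static loop ranges (1, 1, 4)
// vectorized to vector<1x1x4xi32>, an access such as
//
//   %i = linalg.index 2 : index
//   %v = tensor.extract %src[%c0, %i] : tensor<8x16xi32>
//
// is classified as Contiguous: the leading index is loop invariant and the
// trailing index follows the trailing non-unit loop dim. Swapping the two
// indices would make the leading index loop-variant, yielding Gather.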
1049 
1050 /// Helper function to vectorize the tensor.extract operations. Returns
1051 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1052 /// should map the produced operations. This function is meant to be used as a
1053 /// CustomVectorizationHook.
1054 static VectorizationResult
1055 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1056                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1057   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1058   if (!extractOp)
1059     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1060   auto loc = extractOp.getLoc();
1061 
1062   // Compute the static loop sizes of the extract op.
1063   auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1064   auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1065       loc,
1066       DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1067                                 /*value=*/true));
1068   auto passThruConstantOp =
1069       rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1070 
1071   // Base indices are currently set to 0. We will need to re-visit if more
1072   // generic scenarios are to be supported.
1073   SmallVector<Value> baseIndices(
1074       extractOp.getIndices().size(),
1075       rewriter.create<arith::ConstantIndexOp>(loc, 0));
1076 
1077   VectorMemoryAccessKind memAccessKind =
1078       getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1079 
1080   // 1. Handle gather access
1081   if (memAccessKind == VectorMemoryAccessKind::Gather) {
1082     Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1083 
1084     // Generate the gather load
1085     Operation *gatherOp = rewriter.create<vector::GatherOp>(
1086         loc, resultType, extractOp.getTensor(), baseIndices, offset,
1087         maskConstantOp, passThruConstantOp);
1088     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1089 
1090     LDBG("Vectorised as gather load: " << extractOp << "\n");
1091     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1092   }
1093 
1094   // 2. Handle:
1095   //  a. scalar loads + broadcast,
1096   //  b. contiguous loads.
1097   // Both cases use vector.transfer_read.
1098 
1099   assert(llvm::count_if(resultType.getShape(),
1100                         [](uint64_t dim) { return dim != 1; }) &&
1101          "Contiguous loads and scalar loads + broadcast only support 1-D "
1102          "vectors ATM!");
1103 
1104   // Collect indices for `vector.transfer_read`. At this point, the indices will
1105   // either be scalars or would have been broadcast to vectors matching the
1106   // result type. For indices that are vectors, there are two options:
1107   //    * for non-trailing indices, all elements are identical (contiguous
1108   //      loads are identified by looking for non-trailing indices that are
1109   //      invariant with respect to the corresponding linalg.generic), or
1110   //    * for trailing indices, the index vector will contain values with stride
1111   //      one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1112   //      needed.
1113   // This means that
1114   //   * for scalar indices - just re-use it,
1115   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1116   //    (0th) element and use that.
1117   SmallVector<Value> transferReadIdxs;
1118   auto zero = rewriter.create<arith::ConstantOp>(
1119       loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
1120   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1121     Value idx = bvm.lookup(extractOp.getIndices()[i]);
1122     if (idx.getType().isIndex()) {
1123       transferReadIdxs.push_back(idx);
1124       continue;
1125     }
1126 
1127     auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1128         loc,
1129         VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1130                         resultType.getScalableDims().back()),
1131         idx);
1132     transferReadIdxs.push_back(
1133         rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1134   }
1135 
1136   // `tensor.extract` is always in-bounds, hence the following holds.
1137   auto dstRank = resultType.getRank();
1138   auto srcRank = extractOp.getTensor().getType().getRank();
1139   SmallVector<bool> inBounds(dstRank, true);
1140 
1141   // 2a. Handle scalar broadcast access.
1142   if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1143     MLIRContext *ctx = rewriter.getContext();
1144     SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1145     auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1146 
1147     auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1148         loc, resultType, extractOp.getTensor(), transferReadIdxs,
1149         permutationMap, inBounds);
1150 
1151     LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1152     return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1153   }
1154 
1155   // 2b. Handle contiguous access.
1156   auto permutationMap = AffineMap::getMinorIdentityMap(
1157       srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1158 
1159   int32_t rankDiff = dstRank - srcRank;
1160   // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1161   // dims so that the ranks match. This is done by extending the map with 0s.
1162   // For example, for dstRank = 3, srcRank = 2, the following map created
1163   // above:
1164   //    (d0, d1) --> (d0, d1)
1165   // is extended as:
1166   //    (d0, d1) --> (0, d0, d1)
1167   while (rankDiff > 0) {
1168     permutationMap = permutationMap.insertResult(
1169         mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1170     rankDiff--;
1171   }
1172 
1173   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1174       loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1175       inBounds);
1176 
1177   LDBG("Vectorised as contiguous load: " << extractOp);
1178   return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1179 }
1180 
1181 /// Emit reduction operations if the shape of the value to reduce is different
1182 /// from the result shape.
1183 // Note: this is a true builder that notifies the OpBuilder listener.
1184 // TODO: Consider moving as a static helper on the ReduceOp.
1185 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1186                                  Value reduceValue, Value initialValue,
1187                                  const IRMapping &bvm) {
1188   Value reduceVec = bvm.lookup(reduceValue);
1189   Value outputVec = bvm.lookup(initialValue);
1190   auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1191   auto outputType = dyn_cast<VectorType>(outputVec.getType());
1192   // Reduce only if needed as the value may already have been reduced for
1193   // contraction vectorization.
1194   if (!reduceType ||
1195       (outputType && reduceType.getShape() == outputType.getShape()))
1196     return nullptr;
1197   SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1198   return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1199 }
1200 
1201 /// Generic vectorization for a single operation `op`, given already vectorized
1202 /// operands carried by `bvm`. Vectorization occurs as follows:
1203 ///   1. Try to apply any of the `customVectorizationHooks` and return its
1204 ///   result on success.
1205 ///   2. Clone any constant in the current scope without vectorization: each
1206 ///   consumer of the constant will later determine the shape to which the
1207 ///   constant needs to be broadcast.
1208 ///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1209 ///   of the `customVectorizationHooks` to cover such cases.
1210 ///   4. Clone `op` in vector form to a vector of shape prescribed by the first
1211 ///   operand of maximal rank. Other operands have smaller rank and are
1212 ///   broadcast accordingly. It is assumed this broadcast is always legal,
1213 ///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
1214 ///
1215 /// This function assumes all operands of `op` have been vectorized and are in
1216 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1217 /// a topologically-sorted list of ops.
1218 /// This function does not update `bvm` but returns a VectorizationStatus that
1219 /// instructs the caller what `bvm` update needs to occur.
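///
/// For illustration only (a hedged sketch; SSA names are placeholders): for an
/// `arith.addf %a, %b : f32` in the body, where `%a` was vectorized to a
/// vector<4x8xf32> and `%b` maps to a scalar f32, step 4 broadcasts `%b` and
/// clones the op, roughly:
///   %b_vec = vector.broadcast %b : f32 to vector<4x8xf32>
///   %add   = arith.addf %a_vec, %b_vec : vector<4x8xf32>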
1220 static VectorizationResult
1221 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1222                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1223                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1224   LDBG("vectorize op " << *op << "\n");
1225 
1226   // 1. Try to apply any CustomVectorizationHook.
1227   if (!customVectorizationHooks.empty()) {
1228     for (auto &customFunc : customVectorizationHooks) {
1229       VectorizationResult result = customFunc(op, bvm);
1230       if (result.status == VectorizationStatus::Failure)
1231         continue;
1232       return result;
1233     }
1234   }
1235 
1236   // 2. Constant ops don't get vectorized but rather broadcast at their users.
1237   // Clone so that the constant is not confined to the linalgOp block.
1238   if (isa<arith::ConstantOp, func::ConstantOp>(op))
1239     return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1240 
1241   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1242   if (!OpTrait::hasElementwiseMappableTraits(op))
1243     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1244 
1245   // 4. Check if the operation is a reduction.
1246   SmallVector<std::pair<Value, Value>> reductionOperands;
1247   for (Value operand : op->getOperands()) {
1248     auto blockArg = dyn_cast<BlockArgument>(operand);
1249     if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1250         blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1251       continue;
1252     SmallVector<Operation *> reductionOps;
1253     Value reduceValue = matchReduction(
1254         linalgOp.getRegionOutputArgs(),
1255         blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1256     if (!reduceValue)
1257       continue;
1258     reductionOperands.push_back(std::make_pair(reduceValue, operand));
1259   }
1260   if (!reductionOperands.empty()) {
1261     assert(reductionOperands.size() == 1);
1262     Operation *reduceOp =
1263         reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1264                        reductionOperands[0].second, bvm);
1265     if (reduceOp)
1266       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1267   }
1268 
1269   // 5. Generic vectorization path for ElementwiseMappable ops.
1270   //   a. Get the first max ranked shape.
1271   VectorType firstMaxRankedType;
1272   for (Value operand : op->getOperands()) {
1273     auto vecOperand = bvm.lookup(operand);
1274     assert(vecOperand && "Vector operand couldn't be found");
1275 
1276     auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1277     if (vecType && (!firstMaxRankedType ||
1278                     firstMaxRankedType.getRank() < vecType.getRank()))
1279       firstMaxRankedType = vecType;
1280   }
1281   //   b. Broadcast each op if needed.
1282   SmallVector<Value> vecOperands;
1283   for (Value scalarOperand : op->getOperands()) {
1284     Value vecOperand = bvm.lookup(scalarOperand);
1285     assert(vecOperand && "Vector operand couldn't be found");
1286 
1287     if (firstMaxRankedType) {
1288       auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1289                                      getElementTypeOrSelf(vecOperand.getType()),
1290                                      firstMaxRankedType.getScalableDims());
1291       vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1292     } else {
1293       vecOperands.push_back(vecOperand);
1294     }
1295   }
1296   //   c. For elementwise ops, the result has the firstMaxRankedType shape.
1297   SmallVector<Type> resultTypes;
1298   for (Type resultType : op->getResultTypes()) {
1299     resultTypes.push_back(
1300         firstMaxRankedType
1301             ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1302                               firstMaxRankedType.getScalableDims())
1303             : resultType);
1304   }
1305   //   d. Build and return the new op.
1306   return VectorizationResult{
1307       VectorizationStatus::NewOp,
1308       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1309                       resultTypes, op->getAttrs())};
1310 }
1311 
1312 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1313 /// vector form. Generic vectorization proceeds as follows:
1314 ///   1. Verify the `linalgOp` has one non-empty region.
1315 ///   2. Values defined above the region are mapped to themselves and will be
1316 ///   broadcasted on a per-need basis by their consumers.
1317 ///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1318 ///   load).
1319 ///   TODO: Reuse opportunities for RAR dependencies.
1320 ///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
1321 ///   4b. Register CustomVectorizationHook for IndexOp to access the
1322 ///   iteration indices.
1323 ///   5. Iteratively call vectorizeOneOp on the region operations.
1324 ///
1325 /// Eager broadcasting is performed to the maximal common vector size implied
1326 /// by the `linalgOp`
1327 /// iteration space. This eager broadcasting is introduced in the
1328 /// permutation_map of the vector.transfer_read operations. The eager
1329 /// broadcasting makes it trivial to determine where broadcasts, transposes and
1330 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1331 /// the absence of good canonicalizations, the amount of work increases.
1332 /// This is not deemed a problem as we expect canonicalizations and foldings to
1333 /// aggressively clean up the useless work.
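///
/// For illustration only (a hedged, simplified sketch, not verbatim output),
/// an elementwise op such as:
///   #map = affine_map<(d0) -> (d0)>
///   %0 = linalg.generic {indexing_maps = [#map, #map],
///                        iterator_types = ["parallel"]}
///       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
///   ^bb0(%a: f32, %b: f32):
///     %1 = arith.addf %a, %a : f32
///     linalg.yield %1 : f32
///   } -> tensor<8xf32>
/// is rewritten, roughly, into:
///   %r = vector.transfer_read %in[%c0], %cst : tensor<8xf32>, vector<8xf32>
///   %add = arith.addf %r, %r : vector<8xf32>
///   %w = vector.transfer_write %add, %out[%c0] : vector<8xf32>, tensor<8xf32>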
1334 static LogicalResult
1335 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1336                          LinalgOp linalgOp,
1337                          SmallVectorImpl<Value> &newResults) {
1338   LDBG("Vectorizing operation as linalg generic\n");
1339   Block *block = linalgOp.getBlock();
1340 
1341   // 2. Values defined above the region can only be broadcast for now. Make them
1342   // map to themselves.
1343   IRMapping bvm;
1344   SetVector<Value> valuesSet;
1345   mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1346   bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1347 
1348   if (linalgOp.getNumDpsInits() == 0)
1349     return failure();
1350 
1351   // 3. Turn all BBArgs into vector.transfer_read / load.
1352   Location loc = linalgOp.getLoc();
1353   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1354   for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1355     BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1356     if (linalgOp.isScalar(opOperand)) {
1357       bvm.map(bbarg, opOperand->get());
1358       continue;
1359     }
1360 
1361     // 3.a. Convert the indexing map for this input/output to a transfer read
1362     // permutation map and masking map.
1363     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1364 
1365     AffineMap readMap;
1366     VectorType readType;
1367     Type elemType = getElementTypeOrSelf(opOperand->get());
1368     if (linalgOp.isDpsInput(opOperand)) {
1369       // 3.a.i. For input reads we use the canonical vector shape.
1370       readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1371       readType = state.getCanonicalVecType(elemType);
1372     } else {
1373       // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1374       // reductions), the vector shape is computed by mapping the canonical
1375       // vector shape to the output domain and back to the canonical domain.
1376       readMap = inversePermutation(reindexIndexingMap(indexingMap));
1377       readType =
1378           state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1379     }
1380 
1381     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1382 
1383     Operation *read = rewriter.create<vector::TransferReadOp>(
1384         loc, readType, opOperand->get(), indices, readMap);
1385     read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1386     Value readValue = read->getResult(0);
1387 
1388     // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1389     // will be in-bounds.
1390     if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1391       SmallVector<bool> inBounds(readType.getRank(), true);
1392       cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1393           .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1394     }
1395 
1396     // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1397     // TODO: remove this.
1398     if (readType.getRank() == 0)
1399       readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1400 
1401     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1402                                  << "\n");
1403     bvm.map(bbarg, readValue);
1404     bvm.map(opOperand->get(), readValue);
1405   }
1406 
1407   SmallVector<CustomVectorizationHook> hooks;
1408   // 4a. Register CustomVectorizationHook for yieldOp.
1409   CustomVectorizationHook vectorizeYield =
1410       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1411     return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1412   };
1413   hooks.push_back(vectorizeYield);
1414 
1415   // 4b. Register CustomVectorizationHook for indexOp.
1416   CustomVectorizationHook vectorizeIndex =
1417       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1418     return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1419   };
1420   hooks.push_back(vectorizeIndex);
1421 
1422   // 4c. Register CustomVectorizationHook for extractOp.
1423   CustomVectorizationHook vectorizeExtract =
1424       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1425     return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1426   };
1427   hooks.push_back(vectorizeExtract);
1428 
1429   // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
1430   for (Operation &op : block->getOperations()) {
1431     VectorizationResult result =
1432         vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1433     if (result.status == VectorizationStatus::Failure) {
1434       LDBG("failed to vectorize: " << op << "\n");
1435       return failure();
1436     }
1437     if (result.status == VectorizationStatus::NewOp) {
1438       Operation *maybeMaskedOp =
1439           state.maskOperation(rewriter, result.newOp, linalgOp);
1440       LDBG("New vector op: " << *maybeMaskedOp << "\n");
1441       bvm.map(op.getResults(), maybeMaskedOp->getResults());
1442     }
1443   }
1444 
1445   return success();
1446 }
1447 
1448 /// Given a tensor::PackOp, return the `dest` shape before any packing
1449 /// permutations.
1450 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1451                                               ArrayRef<int64_t> destShape) {
1452   return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1453 }
1454 
1455 /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
1456 /// create an empty destination tensor and create a TransferWriteOp from the
1457 /// input to the empty tensor. If the destination shape is not the same as the
1458 /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1459 /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1460 /// inBounds attribute of the transfer write op instead of masking.
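///
/// For illustration only (a hedged sketch; SSA names are placeholders): with
/// inputVectorSizes = [8] and a destination of type tensor<?xf32> whose
/// runtime size is %d, the generated IR is roughly:
///   %empty = tensor.empty(%d) : tensor<?xf32>
///   %mask = vector.create_mask %d : vector<8xi1>
///   %write = vector.mask %mask {
///     vector.transfer_write %input, %empty[%c0]
///         {in_bounds = [true]} : vector<8xf32>, tensor<?xf32>
///   } : vector<8xi1> -> tensor<?xf32>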
1461 static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1462                                            Value input,
1463                                            SmallVector<OpFoldResult> destSizes,
1464                                            ArrayRef<int64_t> inputVectorSizes,
1465                                            bool useInBoundsInsteadOfMasking) {
1466 
1467   auto inputType = cast<VectorType>(input.getType());
1468   Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1469                                                inputType.getElementType());
1470   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1471   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1472   auto destShape = cast<ShapedType>(dest.getType()).getShape();
1473   SmallVector<bool> inBoundsVal(rank, true);
1474   if (useInBoundsInsteadOfMasking) {
1475     // Update the inBounds attribute.
1476     for (unsigned i = 0; i < rank; i++)
1477       inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1478                        !ShapedType::isDynamic(destShape[i]);
1479   }
1480   Operation *write = builder.create<vector::TransferWriteOp>(
1481       loc,
1482       /*vector=*/input,
1483       /*source=*/dest,
1484       /*indices=*/SmallVector<Value>(rank, zero),
1485       /*inBounds=*/inBoundsVal);
1486   assert(llvm::none_of(
1487              destShape.drop_front(inputVectorSizes.size()),
1488              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1489          "Only dims aligned with inputVectorSizes may be dynamic");
1490   if (useInBoundsInsteadOfMasking)
1491     return write;
1492   bool needMaskForWrite = !llvm::equal(
1493       inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1494   if (needMaskForWrite) {
1495     SmallVector<int64_t> writeMaskShape;
1496     writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1497     writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1498                           destShape.end());
1499     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1500     Value maskForWrite =
1501         builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1502     write = mlir::vector::maskOperation(builder, write, maskForWrite);
1503   }
1504   return write;
1505 }
1506 
1507 /// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1508 /// padding value and (3) input vector sizes into:
1509 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1510 /// As in the following example:
1511 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1512 ///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1513 ///
1514 /// This pack would be vectorized to:
1515 ///
1516 /// %load = vector.mask %mask {
1517 ///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1518 ///         {in_bounds = [true, true, true]} :
1519 ///         tensor<32x7x16xf32>, vector<32x8x16xf32>
1520 /// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1521 /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1522 ///                                         to vector<32x4x2x1x16xf32>
1523 /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1524 ///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1525 /// %write = vector.transfer_write %transpose,
1526 ///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1527 ///     {in_bounds = [true, true, true, true, true]}
1528 ///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1529 ///
1530 /// If the (3) input vector sizes are not provided, the vector sizes are
1531 /// determined by the result tensor shape. Also, we update the inBounds
1532 /// attribute instead of masking.
1533 static LogicalResult
1534 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1535                         ArrayRef<int64_t> inputVectorSizes,
1536                         SmallVectorImpl<Value> &newResults) {
1537   OpBuilder::InsertionGuard g(rewriter);
1538   rewriter.setInsertionPoint(packOp);
1539 
1540   Location loc = packOp.getLoc();
1541   auto padValue = packOp.getPaddingValue();
1542   if (!padValue) {
1543     padValue = rewriter.create<arith::ConstantOp>(
1544         loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1545   }
1546   ReifiedRankedShapedTypeDims reifiedReturnShapes;
1547   LogicalResult status =
1548       cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1549           .reifyResultShapes(rewriter, reifiedReturnShapes);
1550   (void)status; // prevent unused variable warning on non-assert builds.
1551   assert(succeeded(status) && "failed to reify result shapes");
1552 
1553   // If the input vector sizes are not provided, then the vector sizes are
1554   // determined by the result tensor shape. In that case we also update the
1555   // inBounds attribute instead of masking.
1556   bool useInBoundsInsteadOfMasking = false;
1557   if (inputVectorSizes.empty()) {
1558     ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1559     inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1560     useInBoundsInsteadOfMasking = true;
1561   }
1562 
1563   // Create masked TransferReadOp.
1564   SmallVector<int64_t> inputShape(inputVectorSizes);
1565   auto innerTiles = packOp.getStaticInnerTiles();
1566   auto innerDimsPos = packOp.getInnerDimsPos();
1567   auto outerDimsPerm = packOp.getOuterDimsPerm();
1568   if (!outerDimsPerm.empty())
1569     applyPermutationToVector(inputShape,
1570                              invertPermutationVector(outerDimsPerm));
1571   for (auto [idx, size] : enumerate(innerTiles))
1572     inputShape[innerDimsPos[idx]] *= size;
1573   auto maskedRead = vector::createReadOrMaskedRead(
1574       rewriter, loc, packOp.getSource(), inputShape, padValue,
1575       useInBoundsInsteadOfMasking);
1576 
1577   // Create ShapeCastOp.
1578   SmallVector<int64_t> destShape(inputVectorSizes);
1579   destShape.append(innerTiles.begin(), innerTiles.end());
1580   auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1581                                        packOp.getDestType().getElementType());
1582   auto shapeCastOp =
1583       rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1584 
1585   // Create TransposeOp.
1586   auto destPermutation =
1587       invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1588   auto transposeOp = rewriter.create<vector::TransposeOp>(
1589       loc, shapeCastOp.getResult(), destPermutation);
1590 
1591   // Create TransferWriteOp.
1592   Operation *write = createWriteOrMaskedWrite(
1593       rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1594       inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
1595   newResults.push_back(write->getResult(0));
1596   return success();
1597 }
1598 
1599 /// Vectorize a `tensor::UnPackOp` into this sequence of 4 ops:
1600 ///   vector::TransferReadOp - reads a vector from the source tensor,
1601 ///   vector::TransposeOp - transposes the read vector,
1602 ///   vector::ShapeCastOp - reshapes the data based on the target shape,
1603 ///   vector::TransferWriteOp - writes the result vector back to the
1604 ///   destination tensor.
1605 ///   If the vector sizes are not provided:
1606 ///   * the vector sizes are determined by the input operand and attributes,
1607 ///   * update the inBounds attribute instead of masking.
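///
/// For illustration only (a hedged, simplified sketch, not verbatim output):
///   %unpack = tensor.unpack %src inner_dims_pos = [0, 1]
///       inner_tiles = [32, 16] into %dst
///       : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
/// is vectorized, roughly, into:
///   %read = vector.transfer_read %src[...] :
///       tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
///   %tr = vector.transpose %read, [0, 2, 1, 3]
///       : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
///   %sc = vector.shape_cast %tr
///       : vector<8x32x8x16xf32> to vector<256x128xf32>
///   %write = vector.transfer_write %sc, %empty[...]
///       : vector<256x128xf32>, tensor<256x128xf32>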
1608 static LogicalResult
1609 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1610                           ArrayRef<int64_t> inputVectorSizes,
1611                           SmallVectorImpl<Value> &newResults) {
1612 
1613   OpBuilder::InsertionGuard g(rewriter);
1614   rewriter.setInsertionPoint(unpackOp);
1615 
1616   RankedTensorType unpackTensorType = unpackOp.getSourceType();
1617 
1618   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1619   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1620   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1621   bool useInBoundsInsteadOfMasking = false;
1622   ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1623 
1624   auto destSize = unpackOp.getDestRank();
1625 
1626   if (!inputVectorSizes.empty())
1627     assert(inputVectorSizes.size() == destSize &&
1628            "Incorrect number of input vector sizes");
1629 
1630   // vectorSizes is the shape of the vector that will be used to do the final
1631   // write on the destination tensor. It is set as follows: let the source
1632   // tensor have rank 'M' and the dest tensor rank 'N', where N <= M.
1633   // Thus:
1634   // 1. vectorSizes = sourceShape.take_front(N)
1635   // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
1636   // 3. multiply all the locations in vectorSizes pointed to by innerDimPos
1637   //    by the corresponding innerTiles attribute value.
1638   SmallVector<int64_t> vectorSizes(inputVectorSizes);
1639   if (vectorSizes.empty()) {
1640     llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1641     if (!outerDimsPerm.empty())
1642       applyPermutationToVector(vectorSizes, outerDimsPerm);
1643     for (auto [i, pos] : llvm::enumerate(innerDimPos))
1644       vectorSizes[pos] *= innerTiles[i];
1645 
1646     useInBoundsInsteadOfMasking = true;
1647   }
1648 
1649   // readVectorSizes is the shape of the vector used to read the source
1650   // tensor and to apply the mask. It is set as follows: let the vectorSizes
1651   // (VS) array have size 'N', the sourceShape (SS) have rank 'M' where
1652   // M >= N, and the InnerTileSizes (IT) have size M-N.
1653   // Thus:
1654   // - initially: readVectorSizes = vectorSizes
1655   // - divide all the readVectorSizes locations pointed to by innerDimPos
1656   //   by the corresponding innerTiles attribute value.
1657   // - if outer_dims_perm is present: apply that permutation to
1658   //   readVectorSizes.
1659   // - append the remaining shape from SS.
1660   // E.g. let unpackTensorType.getShape() = <8x8x32x16>, innerDimPos = [0, 1],
1661   // innerTiles = [32, 16], vector_sizes = [512, 128] and outer_dims_perm =
1662   // [1, 0]; then the read shape is computed as:
1663   //   readVectorSizes (initial): [512, 128]
1664   //   after innerDim adjustment: [512/32, 128/16] = [16, 8]
1665   //   after applying outer_dims_perm: [8, 16]
1666   //   after appending the rest of the sourceShape: [8, 16, 32, 16]
1667 
1668   SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1669 
1670   for (auto [index, size] : enumerate(innerTiles)) {
1671     readVectorSizes[innerDimPos[index]] =
1672         llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1673   }
1674   if (!outerDimsPerm.empty()) {
1675     applyPermutationToVector(readVectorSizes, outerDimsPerm);
1676   }
1677   readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1678                          sourceShape.end());
1679 
1680   ReifiedRankedShapedTypeDims reifiedRetShapes;
1681   LogicalResult status =
1682       cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1683           .reifyResultShapes(rewriter, reifiedRetShapes);
1684   if (status.failed()) {
1685     LDBG("Unable to reify result shapes of " << unpackOp);
1686     return failure();
1687   }
1688   Location loc = unpackOp->getLoc();
1689 
1690   auto padValue = rewriter.create<arith::ConstantOp>(
1691       loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1692 
1693   // Read result, mask if necessary. If transferReadOp shape is not equal
1694   // to shape of source, then a mask is necessary.
1695   Value readResult = vector::createReadOrMaskedRead(
1696       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1697       /*useInBoundsInsteadOfMasking=*/false);
1698 
1699   PackingMetadata packMetadata;
1700   SmallVector<int64_t> lastDimToInsertPosPerm =
1701       tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1702   ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1703   SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1704   mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1705   applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1706   RankedTensorType stripMineTensorType =
1707       RankedTensorType::get(stripMineShape, stripMineElemType);
1708   // Transpose the appropriate rows to match output.
1709   vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1710       loc, readResult, lastDimToInsertPosPerm);
1711 
1712   // Collapse the vector to the size required by result.
1713   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1714       stripMineTensorType, packMetadata.reassociations);
1715   mlir::VectorType vecCollapsedType =
1716       VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1717   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1718       loc, vecCollapsedType, transposeOp->getResult(0));
1719 
1720   // writeVectorSizes has to match the shape_cast result shape for dynamic
1721   // sizes, otherwise the verifier complains that the mask size is invalid.
1722   SmallVector<int64_t> writeVectorSizes(
1723       unpackOp.getDestType().hasStaticShape()
1724           ? vectorSizes
1725           : shapeCastOp.getResultVectorType().getShape());
1726   Operation *write = createWriteOrMaskedWrite(
1727       rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1728       writeVectorSizes, useInBoundsInsteadOfMasking);
1729   newResults.push_back(write->getResult(0));
1730   return success();
1731 }
1732 
1733 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1734 /// and (3) all-zero lowPad to
1735 ///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
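///
/// For illustration only (a hedged sketch; SSA names are placeholders): with
/// inputVectorSizes = [8, 8],
///   %pad = tensor.pad %src low[0, 0] high[%h0, %h1] {
///     ^bb0(%i: index, %j: index):
///       tensor.yield %cst : f32
///   } : tensor<?x?xf32> to tensor<8x8xf32>
/// becomes, roughly:
///   %mask = vector.create_mask %d0, %d1 : vector<8x8xi1>
///   %read = vector.mask %mask {
///     vector.transfer_read %src[%c0, %c0], %cst
///         {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x8xf32>
///   } : vector<8x8xi1> -> vector<8x8xf32>
///   %write = vector.transfer_write %read, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32>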
1736 static LogicalResult
1737 vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1738                        ArrayRef<int64_t> inputVectorSizes,
1739                        SmallVectorImpl<Value> &newResults) {
1740   auto padValue = padOp.getConstantPaddingValue();
1741   Location loc = padOp.getLoc();
1742 
1743   // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1744   OpBuilder::InsertionGuard g(rewriter);
1745   rewriter.setInsertionPoint(padOp);
1746 
1747   ReifiedRankedShapedTypeDims reifiedReturnShapes;
1748   LogicalResult status =
1749       cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1750           .reifyResultShapes(rewriter, reifiedReturnShapes);
1751   (void)status; // prevent unused variable warning on non-assert builds
1752   assert(succeeded(status) && "failed to reify result shapes");
1753   auto maskedRead = vector::createReadOrMaskedRead(
1754       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1755       /*useInBoundsInsteadOfMasking=*/false);
1756   Operation *write = createWriteOrMaskedWrite(
1757       rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1758       /*useInBoundsInsteadOfMasking=*/false);
1759   newResults.push_back(write->getResult(0));
1760   return success();
1761 }
1762 
1763 // TODO: probably need some extra checks for reduction followed by consumer
1764 // ops that may not commute (e.g. linear reduction + non-linear instructions).
1765 static LogicalResult reductionPreconditions(LinalgOp op) {
1766   if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1767     LDBG("reduction precondition failed: no reduction iterator\n");
1768     return failure();
1769   }
1770   for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1771     AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1772     if (indexingMap.isPermutation())
1773       continue;
1774 
1775     Operation *reduceOp = matchLinalgReduction(&opOperand);
1776     if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1777       LDBG("reduction precondition failed: reduction detection failed\n");
1778       return failure();
1779     }
1780   }
1781   return success();
1782 }
1783 
1784 static LogicalResult
1785 vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1786                                    bool flatten1DDepthwiseConv) {
1787   if (flatten1DDepthwiseConv) {
1788     LDBG("Vectorization of flattened convs with dynamic shapes is not "
1789          "supported\n");
1790     return failure();
1791   }
1792 
1793   if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1794     LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1795     return failure();
1796   }
1797 
1798   // Support dynamic shapes in 1D depthwise convolution, but only in the
1799   // _channel_ dimension.
1800   Value lhs = conv.getDpsInputOperand(0)->get();
1801   ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
1802   auto shapeWithoutCh = lhsShape.drop_back(1);
1803   if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1804     LDBG("Dynamically-shaped op vectorization precondition failed: only "
1805          "channel dim can be dynamic\n");
1806     return failure();
1807   }
1808 
1809   return success();
1810 }
1811 
1812 static LogicalResult
1813 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1814                                      bool flatten1DDepthwiseConv) {
1815   if (isa<ConvolutionOpInterface>(op.getOperation()))
1816     return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1817 
1818   if (hasReductionIterator(op))
1819     return reductionPreconditions(op);
1820 
1821   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1822   // linalg.copy ops and ops that implement ContractionOpInterface for now.
1823   if (!isElementwise(op) &&
1824       !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1825           op.getOperation()))
1826     return failure();
1827 
1828   LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1829   return success();
1830 }
1831 
1832 /// Need to check if the inner-tiles are static/constant.
1833 static LogicalResult
1834 vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1835                               ArrayRef<int64_t> inputVectorSizes) {
1836 
1837   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
1838         return !getConstantIntValue(res).has_value();
1839       })) {
1840     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1841     return failure();
1842   }
1843   ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1844   bool satisfyEmptyCond = inputVectorSizes.empty() &&
1845                           unpackOp.getDestType().hasStaticShape() &&
1846                           unpackOp.getSourceType().hasStaticShape();
1847   if (!satisfyEmptyCond &&
1848       failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1849     return failure();
1850 
1851   return success();
1852 }
1853 
1854 static LogicalResult vectorizeLinalgOpPrecondition(
1855     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1856     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
1857   // tensor with dimension of 0 cannot be vectorized.
1858   if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1859     return failure();
1860   // Check API contract for input vector sizes.
1861   if (!inputVectorSizes.empty() &&
1862       failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1863                                               inputVectorSizes)))
1864     return failure();
1865 
1866   if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1867                                         linalgOp, flatten1DDepthwiseConv))) {
1868     LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1869     return failure();
1870   }
1871 
1872   SmallVector<CustomVectorizationPrecondition> customPreconditions;
1873 
1874   // Register CustomVectorizationPrecondition for extractOp.
1875   customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1876 
1877   // All types in the body should be a supported element type for VectorType.
1878   for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1879     // Check if any custom hook can vectorize the inner op.
1880     if (llvm::any_of(
1881             customPreconditions,
1882             [&](const CustomVectorizationPrecondition &customPrecondition) {
1883               return succeeded(
1884                   customPrecondition(&innerOp, vectorizeNDExtract));
1885             })) {
1886       continue;
1887     }
1888     if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1889           return !VectorType::isValidElementType(type);
1890         })) {
1891       return failure();
1892     }
1893     if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1894           return !VectorType::isValidElementType(type);
1895         })) {
1896       return failure();
1897     }
1898   }
1899   if (isElementwise(linalgOp))
1900     return success();
1901 
1902   // TODO: isaConvolutionOpInterface that can also infer from generic
1903   // features. But we will still need stride/dilation attributes that will be
1904   // annoying to reverse-engineer...
1905   if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1906     return success();
1907   // TODO: the common vector shape is equal to the static loop sizes only when
1908   // all indexing maps are projected permutations. For convs and stencils the
1909   // logic will need to evolve.
1910   if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1911     LDBG("precondition failed: not projected permutations\n");
1912     return failure();
1913   }
1914   if (failed(reductionPreconditions(linalgOp))) {
1915     LDBG("precondition failed: reduction preconditions\n");
1916     return failure();
1917   }
1918   return success();
1919 }
1920 
1921 static LogicalResult
1922 vectorizePackOpPrecondition(tensor::PackOp packOp,
1923                             ArrayRef<int64_t> inputVectorSizes) {
1924   auto padValue = packOp.getPaddingValue();
1925   Attribute cstAttr;
1926   if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1927     LDBG("pad value is not constant: " << packOp << "\n");
1928     return failure();
1929   }
1930   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1931   bool satisfyEmptyCond = true;
1932   if (inputVectorSizes.empty()) {
1933     if (!packOp.getDestType().hasStaticShape() ||
1934         !packOp.getSourceType().hasStaticShape())
1935       satisfyEmptyCond = false;
1936   }
1937 
1938   if (!satisfyEmptyCond &&
1939       failed(vector::isValidMaskedInputVector(
1940           resultTensorShape.take_front(packOp.getSourceRank()),
1941           inputVectorSizes)))
1942     return failure();
1943 
1944   if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
1945         return !getConstantIntValue(v).has_value();
1946       })) {
1947     LDBG("inner_tiles must be constant: " << packOp << "\n");
1948     return failure();
1949   }
1950 
1951   return success();
1952 }
1953 
1954 static LogicalResult
1955 vectorizePadOpPrecondition(tensor::PadOp padOp,
1956                            ArrayRef<int64_t> inputVectorSizes) {
1957   auto padValue = padOp.getConstantPaddingValue();
1958   if (!padValue) {
1959     LDBG("pad value is not constant: " << padOp << "\n");
1960     return failure();
1961   }
1962 
1963   ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
1964   if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1965                                               inputVectorSizes)))
1966     return failure();
1967 
1968   if (llvm::any_of(padOp.getLow(), [](Value v) {
1969         std::optional<int64_t> res = getConstantIntValue(v);
1970         return !res.has_value() || res.value() != 0;
1971       })) {
1972     LDBG("low pad must all be zero: " << padOp << "\n");
1973     return failure();
1974   }
1975 
1976   return success();
1977 }
1978 
1979 /// Preconditions for scalable vectors. This is quite restrictive - it models
1980 /// the fact that in practice we would only make selected dimensions scalable.
1981 static LogicalResult
1982 vectorizeScalableVectorPrecondition(Operation *op,
1983                                     ArrayRef<int64_t> inputVectorSizes,
1984                                     ArrayRef<bool> inputScalableVecDims) {
1985   assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1986          "Number of input vector sizes and scalable dims doesn't match");
1987 
1988   size_t numOfScalableDims =
1989       llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
1990 
1991   if (numOfScalableDims == 0)
1992     return success();
1993 
1994   auto linalgOp = dyn_cast<LinalgOp>(op);
1995 
1996   // Cond 1: There's been no need for scalable vectorisation of
1997   // non-linalg Ops so far
1998   if (!linalgOp)
1999     return failure();
2000 
2001   // Cond 2: There's been no need for more than 2 scalable dims so far
2002   if (numOfScalableDims > 2)
2003     return failure();
2004 
2005   // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2006   // it matches one of the supported cases:
2007   //  1. exactly 1 dim is scalable and that's the _last_ parallel dim
2008   //  2. exactly 2 dims are scalable and those are the _last two adjacent_
2009   //     parallel dims
2010   //  3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
2011   // The 2nd restriction above means that only Matmul-like Ops are supported
2012   // when 2 dims are scalable, e.g.:
2013   //    * iterators = [parallel, parallel, reduction]
2014   //    * scalable flags = [true, true, false]
2015 
2016   // Find the trailing (innermost) scalable flag.
2017   bool seenParallel = false;
2018   auto iterators = linalgOp.getIteratorTypesArray();
2019   SmallVector<bool> scalableFlags(inputScalableVecDims);
2020   while (!scalableFlags.back()) {
2021     seenParallel |= (iterators.back() == utils::IteratorType::parallel);
2022 
2023     iterators.pop_back();
2024     scalableFlags.pop_back();
2025   }
2026 
2027   switch (iterators.back()) {
2028   case utils::IteratorType::reduction: {
2029     // Check 3. above is met.
2030     if (iterators.size() != inputVectorSizes.size()) {
2031       LDBG("Non-trailing reduction dim requested for scalable "
2032            "vectorization\n");
2033       return failure();
2034     }
2035     if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2036       LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2037            "is not supported\n");
2038       return failure();
2039     }
2040     break;
2041   }
2042   case utils::IteratorType::parallel: {
2043     // Check 1. and 2. above are met.
2044     if (seenParallel) {
2045       LDBG("Inner parallel dim not requested for scalable "
2046            "vectorization\n");
2047       return failure();
2048     }
2049     break;
2050   }
2051   }
2052 
2053   // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2054   // supported, for which we expect the following config:
2055   //    * iterators = [parallel, parallel, reduction]
2056   //    * scalable flags = [true, true, false]
2057   if (numOfScalableDims == 2) {
2058     // Disallow the case below, which breaks 3. above:
2059     //    * iterators = [..., parallel, reduction]
2060     //    * scalable flags = [..., true, true]
2061     if (iterators.back() == utils::IteratorType::reduction) {
2062       LDBG("Higher dim than the trailing reduction dim requested for scalable "
2063            "vectorization\n");
2064       return failure();
2065     }
2066     scalableFlags.pop_back();
2067     iterators.pop_back();
2068 
2069     if (!scalableFlags.back() ||
2070         (iterators.back() != utils::IteratorType::parallel))
2071       return failure();
2072   }
2073 
2074   // Cond 4: Only the following ops are supported in the
2075   // presence of scalable vectors
2076   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2077                  isa<linalg::MatmulTransposeAOp>(op) ||
2078                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2079                  isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2080 }
2081 
2082 LogicalResult mlir::linalg::vectorizeOpPrecondition(
2083     Operation *op, ArrayRef<int64_t> inputVectorSizes,
2084     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2085     bool flatten1DDepthwiseConv) {
2086 
2087   if (!hasVectorizationImpl(op))
2088     return failure();
2089 
2090   if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2091                                                  inputScalableVecDims)))
2092     return failure();
2093 
2094   return TypeSwitch<Operation *, LogicalResult>(op)
2095       .Case<linalg::LinalgOp>([&](auto linalgOp) {
2096         return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2097                                              vectorizeNDExtract,
2098                                              flatten1DDepthwiseConv);
2099       })
2100       .Case<tensor::PadOp>([&](auto padOp) {
2101         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2102       })
2103       .Case<tensor::PackOp>([&](auto packOp) {
2104         return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2105       })
2106       .Case<tensor::UnPackOp>([&](auto unpackOp) {
2107         return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2108       })
2109       .Default([](auto) { return failure(); });
2110 }
2111 
2112 /// Converts affine.apply Ops to arithmetic operations.
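///
/// For illustration only (a hedged sketch): an op in the linalgOp body such as
///   %1 = affine.apply affine_map<(d0) -> (d0 + 8)>(%0)
/// is expanded into arithmetic ops, roughly:
///   %c8 = arith.constant 8 : index
///   %1 = arith.addi %0, %c8 : index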
2113 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2114   OpBuilder::InsertionGuard g(rewriter);
2115   auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2116 
2117   for (auto op : make_early_inc_range(toReplace)) {
2118     rewriter.setInsertionPoint(op);
2119     auto expanded = affine::expandAffineExpr(
2120         rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2121         op.getOperands().take_front(op.getAffineMap().getNumDims()),
2122         op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2123     rewriter.replaceOp(op, expanded);
2124   }
2125 }
2126 
2127 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2128   return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
2129       op);
2130 }
2131 
2132 /// Emit a suitable vector form for an operation. If provided,
2133 /// `inputVectorSizes` are used to vectorize this operation.
2134 /// `inputVectorSizes` must match the rank of the iteration space of the
2135 /// operation and the input vector sizes must be greater than or equal to
2136 /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2137 /// also allows the vectorization of operations with dynamic shapes.
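///
/// For illustration only (a hedged usage sketch from a hypothetical caller;
/// the surrounding pattern/pass boilerplate is assumed):
///   if (failed(linalg::vectorize(rewriter, op,
///                                /*inputVectorSizes=*/{},
///                                /*inputScalableVecDims=*/{},
///                                /*vectorizeNDExtract=*/false,
///                                /*flatten1DDepthwiseConv=*/false)))
///     return failure();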
2138 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2139                                       ArrayRef<int64_t> inputVectorSizes,
2140                                       ArrayRef<bool> inputScalableVecDims,
2141                                       bool vectorizeNDExtract,
2142                                       bool flatten1DDepthwiseConv) {
2143   LDBG("Attempting to vectorize:\n" << *op << "\n");
2144   LDBG("Input vector sizes: ");
2145   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2146   LLVM_DEBUG(llvm::dbgs() << "\n");
2147   LDBG("Input scalable vector dims: ");
2148   LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2149   LLVM_DEBUG(llvm::dbgs() << "\n");
2150 
2151   if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2152                                      vectorizeNDExtract,
2153                                      flatten1DDepthwiseConv))) {
2154     LDBG("Vectorization pre-conditions failed\n");
2155     return failure();
2156   }
2157 
2158   // Initialize vectorization state.
2159   VectorizationState state(rewriter);
2160   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2161     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2162                                inputScalableVecDims))) {
2163       LDBG("Vectorization state couldn't be initialized\n");
2164       return failure();
2165     }
2166   }
2167 
2168   SmallVector<Value> results;
2169   auto vectorizeResult =
2170       TypeSwitch<Operation *, LogicalResult>(op)
2171           .Case<linalg::LinalgOp>([&](auto linalgOp) {
2172             // TODO: isaConvolutionOpInterface that can also infer from
2173             // generic features. Will require stride/dilation attributes
2174             // inference.
2175             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2176               FailureOr<Operation *> convOr = vectorizeConvolution(
2177                   rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2178                   flatten1DDepthwiseConv);
2179               if (succeeded(convOr)) {
2180                 llvm::append_range(results, (*convOr)->getResults());
2181                 return success();
2182               }
2183 
2184               LDBG("Unsupported convolution can't be vectorized.\n");
2185               return failure();
2186             }
2187 
2188             LDBG("Vectorize generic by broadcasting to the canonical vector "
2189                  "shape\n");
2190 
2191             // Pre-process before proceeding.
2192             convertAffineApply(rewriter, linalgOp);
2193 
2194             // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2195             // to 'OpBuilder' when it is passed over to some methods like
2196             // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2197             // erase an op within these methods, the actual rewriter won't be
2198             // notified and we will end up with read-after-free issues!
2199             return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2200           })
2201           .Case<tensor::PadOp>([&](auto padOp) {
2202             return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2203                                           results);
2204           })
2205           .Case<tensor::PackOp>([&](auto packOp) {
2206             return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2207                                            results);
2208           })
2209           .Case<tensor::UnPackOp>([&](auto unpackOp) {
2210             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2211                                              inputVectorSizes, results);
2212           })
2213           .Default([](auto) { return failure(); });
2214 
2215   if (failed(vectorizeResult)) {
2216     LDBG("Vectorization failed\n");
2217     return failure();
2218   }
2219 
2220   if (!results.empty())
2221     rewriter.replaceOp(op, results);
2222   else
2223     rewriter.eraseOp(op);
2224 
2225   return success();
2226 }
2227 
2228 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2229                                           memref::CopyOp copyOp) {
2230   auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2231   auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2232   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2233     return failure();
2234 
2235   auto srcElementType = getElementTypeOrSelf(srcType);
2236   auto dstElementType = getElementTypeOrSelf(dstType);
2237   if (!VectorType::isValidElementType(srcElementType) ||
2238       !VectorType::isValidElementType(dstElementType))
2239     return failure();
2240 
2241   auto readType = VectorType::get(srcType.getShape(), srcElementType);
2242   auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2243 
2244   Location loc = copyOp->getLoc();
2245   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2246   SmallVector<Value> indices(srcType.getRank(), zero);
2247 
2248   Value readValue = rewriter.create<vector::TransferReadOp>(
2249       loc, readType, copyOp.getSource(), indices,
2250       rewriter.getMultiDimIdentityMap(srcType.getRank()));
2251   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2252     readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
2253     readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2254   }
2255   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2256       loc, readValue, copyOp.getTarget(), indices,
2257       rewriter.getMultiDimIdentityMap(srcType.getRank()));
2258   rewriter.replaceOp(copyOp, writeValue->getResults());
2259   return success();
2260 }
2261 
2262 //----------------------------------------------------------------------------//
2263 // Misc. vectorization patterns.
2264 //----------------------------------------------------------------------------//
2265 
2266 /// Helper function that retrieves the value of an IntegerAttr.
2267 static int64_t getIntFromAttr(Attribute attr) {
2268   return cast<IntegerAttr>(attr).getInt();
2269 }
2270 
2271 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
2272 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
2273 /// not supported.
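/// For illustration only (a hedged sketch): for ofrs = [%v, IntegerAttr(3)],
/// the result is [%v, %c3] with %c3 = arith.constant 3 : index.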
2274 static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
2275                                            ArrayRef<OpFoldResult> ofrs) {
2276   SmallVector<Value> result;
2277   for (auto o : ofrs) {
2278     if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
2279       result.push_back(val);
2280     } else {
2281       result.push_back(rewriter.create<arith::ConstantIndexOp>(
2282           loc, getIntFromAttr(o.template get<Attribute>())));
2283     }
2284   }
2285   return result;
2286 }
2287 
2288 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
2289 /// InsertSliceOp. For now, only constant padding values are supported.
2290 /// If there is enough static type information, TransferReadOps and
2291 /// TransferWriteOps may be generated instead of InsertSliceOps.
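///
/// For illustration only (a hedged sketch; SSA names are placeholders):
///   %0 = tensor.pad %src low[0, 0] high[5, 2] {
///     ^bb0(%i: index, %j: index):
///       tensor.yield %cst : f32
///   } : tensor<3x4xf32> to tensor<8x6xf32>
/// may be rewritten, roughly, into:
///   %empty = tensor.empty() : tensor<8x6xf32>
///   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8x6xf32>)
///       -> tensor<8x6xf32>
///   %read = vector.transfer_read %src[%c0, %c0], %cst
///       {in_bounds = [true, true]} : tensor<3x4xf32>, vector<3x4xf32>
///   %res = vector.transfer_write %read, %fill[%c0, %c0]
///       {in_bounds = [true, true]} : vector<3x4xf32>, tensor<8x6xf32>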
2292 struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
2293   GenericPadOpVectorizationPattern(MLIRContext *context,
2294                                    PatternBenefit benefit = 1)
2295       : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
2296   /// Vectorize the copying of a tensor::PadOp's source. This is possible if
2297   /// each dimension size is statically known in the source type or the result
2298   /// type (or both).
2299   static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
2300                                         tensor::PadOp padOp, Value dest) {
2301     auto sourceType = padOp.getSourceType();
2302     auto resultType = padOp.getResultType();
2303     if (!VectorType::isValidElementType(sourceType.getElementType()))
2304       return failure();
2305 
2306     // Copy cannot be vectorized if pad value is non-constant and source shape
2307     // is dynamic. In case of a dynamic source shape, padding must be appended
2308     // by TransferReadOp, but TransferReadOp supports only constant padding.
2309     auto padValue = padOp.getConstantPaddingValue();
2310     if (!padValue) {
2311       if (!sourceType.hasStaticShape())
2312         return failure();
2313       // Create dummy padding value.
2314       auto elemType = sourceType.getElementType();
2315       padValue = rewriter.create<arith::ConstantOp>(
2316           padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2317     }
2318 
2319     SmallVector<int64_t> vecShape;
2320     SmallVector<bool> readInBounds;
2321     SmallVector<bool> writeInBounds;
2322     for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2323       if (!sourceType.isDynamicDim(i)) {
2324         vecShape.push_back(sourceType.getDimSize(i));
2325         // Source shape is statically known: neither read nor write is
2326         // out-of-bounds.
2327         readInBounds.push_back(true);
2328         writeInBounds.push_back(true);
2329       } else if (!resultType.isDynamicDim(i)) {
2330         // Source shape is not statically known, but result shape is.
2331         // Vectorize with size of result shape. This may be larger than the
2332         // source size.
2333         vecShape.push_back(resultType.getDimSize(i));
2334         // Read may be out-of-bounds because the result size could be larger
2335         // than the source size.
2336         readInBounds.push_back(false);
2337         // Write is out-of-bounds if low padding > 0.
2338         writeInBounds.push_back(
2339             getConstantIntValue(padOp.getMixedLowPad()[i]) ==
2340             static_cast<int64_t>(0));
2341       } else {
2342         // Neither source nor result dim of padOp is static. Cannot vectorize
2343         // the copy.
2344         return failure();
2345       }
2346     }
2347     auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2348 
2349     // Generate TransferReadOp.
2350     SmallVector<Value> readIndices(
2351         vecType.getRank(),
2352         rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2353     auto read = rewriter.create<vector::TransferReadOp>(
2354         padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
2355         ArrayRef<bool>{readInBounds});
2356 
2357     // If `dest` is a FillOp and the TransferWriteOp would overwrite the
2358     // entire tensor, write directly to the FillOp's operand.
2359     if (llvm::equal(vecShape, resultType.getShape()) &&
2360         llvm::all_of(writeInBounds, [](bool b) { return b; }))
2361       if (auto fill = dest.getDefiningOp<FillOp>())
2362         dest = fill.output();
2363 
2364     // Generate TransferWriteOp.
2365     auto writeIndices =
2366         ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
2367     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2368         padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
2369 
2370     return success();
2371   }
2372 };
2373 
2374 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2375 /// given operation type OpTy.
2376 template <typename OpTy>
2377 struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2378   using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2379 
2380   LogicalResult matchAndRewrite(tensor::PadOp padOp,
2381                                 PatternRewriter &rewriter) const final {
2382     bool changed = false;
2383     // Collect users in a vector, because some users may be replaced/removed.
2384     for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2385       if (auto op = dyn_cast<OpTy>(user))
2386         changed |= rewriteUser(rewriter, padOp, op).succeeded();
2387     return success(changed);
2388   }
2389 
2390 protected:
2391   virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2392                                     tensor::PadOp padOp, OpTy op) const = 0;
2393 };
2394 
2395 /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2396 /// ```
2397 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2398 /// %r = vector.transfer_read %0[%c0, %c0], %cst
2399 ///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2400 /// ```
2401 /// is rewritten to:
2402 /// ```
2403 /// %r = vector.transfer_read %src[%c0, %c0], %padding
2404 ///     {in_bounds = [true, true]}
2405 ///     : tensor<?x?xf32>, vector<17x5xf32>
2406 /// ```
2407 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2408 /// sure that the original padding value %cst was never used.
2409 ///
2410 /// This rewrite is possible if:
2411 /// - `xferOp` has no out-of-bounds dims or mask.
2412 /// - Low padding is static 0.
2413 /// - Single, scalar padding value.
2414 struct PadOpVectorizationWithTransferReadPattern
2415     : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2416   using VectorizePadOpUserPattern<
2417       vector::TransferReadOp>::VectorizePadOpUserPattern;
2418 
2419   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2420                             vector::TransferReadOp xferOp) const override {
2421     // Low padding must be static 0.
2422     if (!padOp.hasZeroLowPad())
2423       return failure();
2424     // Pad value must be a constant.
2425     auto padValue = padOp.getConstantPaddingValue();
2426     if (!padValue)
2427       return failure();
2428     // Padding value of existing `xferOp` is unused.
2429     if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2430       return failure();
2431 
2432     rewriter.modifyOpInPlace(xferOp, [&]() {
2433       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2434       xferOp->setAttr(xferOp.getInBoundsAttrName(),
2435                       rewriter.getBoolArrayAttr(inBounds));
2436       xferOp.getSourceMutable().assign(padOp.getSource());
2437       xferOp.getPaddingMutable().assign(padValue);
2438     });
2439 
2440     return success();
2441   }
2442 };
2443 
2444 /// Rewrite use of tensor::PadOp result in TransferWriteOp.
2445 /// This pattern rewrites TransferWriteOps that write to a padded tensor
2446 /// value, where the same amount of padding is immediately removed again after
2447 /// the write. In such cases, the TransferWriteOp can write to the non-padded
2448 /// tensor value and apply out-of-bounds masking. E.g.:
2449 /// ```
2450 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2451 ///     : tensor<...> to tensor<?x?xf32>
2452 /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2453 /// %2 = vector.transfer_write %vec, %1[...]
2454 ///     : vector<17x5xf32>, tensor<17x5xf32>
2455 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2456 ///     : tensor<17x5xf32> to tensor<?x?xf32>
2457 /// ```
2458 /// is rewritten to:
2459 /// ```
2460 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2461 ///     : tensor<...> to tensor<?x?xf32>
2462 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2463 /// tensor<?x?xf32>
2464 /// ```
2465 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
/// TransferWriteOp to the same size as the input of the tensor::PadOp (or an
2467 /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2468 /// from %r's old dimensions.
2469 ///
2470 /// This rewrite is possible if:
2471 /// - Low padding is static 0.
2472 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2473 ///   ExtractSliceOp trims the same amount of padding that was added
2474 ///   beforehand.
2475 /// - Single, scalar padding value.
2476 struct PadOpVectorizationWithTransferWritePattern
2477     : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2478   using VectorizePadOpUserPattern<
2479       vector::TransferWriteOp>::VectorizePadOpUserPattern;
2480 
2481   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2482                             vector::TransferWriteOp xferOp) const override {
2483     // TODO: support 0-d corner case.
2484     if (xferOp.getTransferRank() == 0)
2485       return failure();
2486 
2487     // Low padding must be static 0.
2488     if (!padOp.hasZeroLowPad())
2489       return failure();
2490     // Pad value must be a constant.
2491     auto padValue = padOp.getConstantPaddingValue();
2492     if (!padValue)
2493       return failure();
2494     // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2495     if (!xferOp->hasOneUse())
2496       return failure();
2497     auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2498     if (!trimPadding)
2499       return failure();
2500     // Only static zero offsets supported when trimming padding.
2501     if (!trimPadding.hasZeroOffset())
2502       return failure();
2503     // trimPadding must remove the amount of padding that was added earlier.
2504     if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2505       return failure();
2506 
2507     // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2508     rewriter.setInsertionPoint(xferOp);
2509 
2510     SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2511     auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2512         xferOp, padOp.getSource().getType(), xferOp.getVector(),
2513         padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2514         xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2515     rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2516 
2517     return success();
2518   }
2519 
2520   /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2521   /// i.e., same dimensions.
2522   ///
2523   /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2524   /// dimensions, this function tries to infer the (static) tensor size by
2525   /// looking at the defining op and utilizing op-specific knowledge.
2526   ///
2527   /// This is a conservative analysis. In case equal tensor sizes cannot be
2528   /// proven statically, this analysis returns `false` even though the tensor
2529   /// sizes may turn out to be equal at runtime.
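  ///
  /// For illustration, a simplified, hypothetical case this analysis accepts
  /// (names and types are made up): both slices use the same dynamic size
  /// values %s0 and %s1, so equality can be proven statically:
  /// ```
  /// %src = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
  ///     : tensor<...> to tensor<?x?xf32>
  /// %padded = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
  /// ...
  /// %trimmed = tensor.extract_slice ...[0, 0] [%s0, %s1] [1, 1]
  ///     : tensor<17x5xf32> to tensor<?x?xf32>
  /// ```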
2530   bool hasSameTensorSize(Value beforePadding,
2531                          tensor::ExtractSliceOp afterTrimming) const {
2532     // If the input to tensor::PadOp is a CastOp, try with both CastOp
2533     // result and CastOp operand.
2534     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2535       if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2536         return true;
2537 
2538     auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2539     auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2540     // Only RankedTensorType supported.
2541     if (!t1 || !t2)
2542       return false;
2543     // Rank of both values must be the same.
2544     if (t1.getRank() != t2.getRank())
2545       return false;
2546 
2547     // All static dimensions must be the same. Mixed cases (e.g., dimension
2548     // static in `t1` but dynamic in `t2`) are not supported.
2549     for (unsigned i = 0; i < t1.getRank(); ++i) {
2550       if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2551         return false;
2552       if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2553         return false;
2554     }
2555 
2556     // Nothing more to check if all dimensions are static.
2557     if (t1.getNumDynamicDims() == 0)
2558       return true;
2559 
2560     // All dynamic sizes must be the same. The only supported case at the
2561     // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2562     // thereof).
2563 
2564     // Apart from CastOp, only ExtractSliceOp is supported.
2565     auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2566     if (!beforeSlice)
2567       return false;
2568 
2569     assert(static_cast<size_t>(t1.getRank()) ==
2570            beforeSlice.getMixedSizes().size());
2571     assert(static_cast<size_t>(t2.getRank()) ==
2572            afterTrimming.getMixedSizes().size());
2573 
2574     for (unsigned i = 0; i < t1.getRank(); ++i) {
2575       // Skip static dimensions.
2576       if (!t1.isDynamicDim(i))
2577         continue;
2578       auto size1 = beforeSlice.getMixedSizes()[i];
2579       auto size2 = afterTrimming.getMixedSizes()[i];
2580 
2581       // Case 1: Same value or same constant int.
2582       if (isEqualConstantIntOrValue(size1, size2))
2583         continue;
2584 
2585       // Other cases: Take a deeper look at defining ops of values.
2586       auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2587       auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2588       if (!v1 || !v2)
2589         return false;
2590 
2591       // Case 2: Both values are identical AffineMinOps. (Should not happen if
2592       // CSE is run.)
2593       auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2594       auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2595       if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2596           minOp1.getOperands() == minOp2.getOperands())
2597         continue;
2598 
2599       // Add additional cases as needed.
2600     }
2601 
2602     // All tests passed.
2603     return true;
2604   }
2605 };
2606 
2607 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2608 /// ```
2609 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2610 /// %r = tensor.insert_slice %0
2611 ///     into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2612 ///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2613 /// ```
2614 /// is rewritten to:
2615 /// ```
2616 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
2617 ///     : tensor<?x?xf32>, vector<17x5xf32>
2618 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2619 ///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2620 /// ```
2621 ///
2622 /// This rewrite is possible if:
2623 /// - Low padding is static 0.
2624 /// - `padOp` result shape is static.
2625 /// - The entire padded tensor is inserted.
2626 ///   (Implies that sizes of `insertOp` are all static.)
2627 /// - Only unit strides in `insertOp`.
2628 /// - Single, scalar padding value.
2629 /// - `padOp` result not used as destination.
2630 struct PadOpVectorizationWithInsertSlicePattern
2631     : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2632   using VectorizePadOpUserPattern<
2633       tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2634 
2635   LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2636                             tensor::InsertSliceOp insertOp) const override {
2637     // Low padding must be static 0.
2638     if (!padOp.hasZeroLowPad())
2639       return failure();
2640     // Only unit stride supported.
2641     if (!insertOp.hasUnitStride())
2642       return failure();
2643     // Pad value must be a constant.
2644     auto padValue = padOp.getConstantPaddingValue();
2645     if (!padValue)
2646       return failure();
2647     // Dynamic shapes not supported.
2648     if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2649       return failure();
2650     // Pad result not used as destination.
2651     if (insertOp.getDest() == padOp.getResult())
2652       return failure();
2653 
2654     auto vecType = VectorType::get(padOp.getType().getShape(),
2655                                    padOp.getType().getElementType());
2656     unsigned vecRank = vecType.getRank();
2657     unsigned tensorRank = insertOp.getType().getRank();
2658 
2659     // Check if sizes match: Insert the entire tensor into most minor dims.
2660     // (No permutations allowed.)
2661     SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2662     expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2663     if (!llvm::all_of(
2664             llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2665               return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2666             }))
2667       return failure();
2668 
2669     // Insert the TransferReadOp and TransferWriteOp at the position of the
2670     // InsertSliceOp.
2671     rewriter.setInsertionPoint(insertOp);
2672 
2673     // Generate TransferReadOp: Read entire source tensor and add high
2674     // padding.
2675     SmallVector<Value> readIndices(
2676         vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2677     auto read = rewriter.create<vector::TransferReadOp>(
2678         padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2679 
2680     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
    // specified offsets. Write is fully in-bounds because an InsertSliceOp's
2682     // source must fit into the destination at the specified offsets.
2683     auto writeIndices =
2684         ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2685     SmallVector<bool> inBounds(vecRank, true);
2686     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2687         insertOp, read, insertOp.getDest(), writeIndices,
2688         ArrayRef<bool>{inBounds});
2689 
2690     return success();
2691   }
2692 };
2693 
2694 void mlir::linalg::populatePadOpVectorizationPatterns(
2695     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2696   patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
2697                                                  baseBenefit);
  // Try these specialized patterns first before resorting to the generic one;
  // the higher pattern benefit makes the driver prefer them whenever they
  // apply.
2699   patterns.add<PadOpVectorizationWithTransferReadPattern,
2700                PadOpVectorizationWithTransferWritePattern,
2701                PadOpVectorizationWithInsertSlicePattern>(
2702       patterns.getContext(), baseBenefit.getBenefit() + 1);
2703 }
2704 
2705 //----------------------------------------------------------------------------//
2706 // Forwarding patterns
2707 //----------------------------------------------------------------------------//
2708 
2709 /// Check whether there is any interleaved use of any `values` between
2710 /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2711 /// is in a different block.
2712 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2713                                     ValueRange values) {
2714   if (firstOp->getBlock() != secondOp->getBlock() ||
2715       !firstOp->isBeforeInBlock(secondOp)) {
    LDBG("interleavedUses precondition failed, firstOp: "
         << *firstOp << ", second op: " << *secondOp);
2718     return true;
2719   }
2720   for (auto v : values) {
2721     for (auto &u : v.getUses()) {
2722       Operation *owner = u.getOwner();
2723       if (owner == firstOp || owner == secondOp)
2724         continue;
2725       // TODO: this is too conservative, use dominance info in the future.
2726       if (owner->getBlock() == firstOp->getBlock() &&
2727           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2728         continue;
      LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
                                    << ", second op: " << *secondOp);
2731       return true;
2732     }
2733   }
2734   return false;
2735 }
2736 
2737 /// Return the unique subview use of `v` if it is indeed unique, null
2738 /// otherwise.
2739 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2740   memref::SubViewOp subViewOp;
2741   for (auto &u : v.getUses()) {
2742     if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2743       if (subViewOp)
2744         return memref::SubViewOp();
2745       subViewOp = newSubViewOp;
2746     }
2747   }
2748   return subViewOp;
2749 }
2750 
2751 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2752 /// when available.
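///
/// For illustration, a hypothetical, simplified example of the IR this
/// pattern targets (names, types and sizes are made up). The local buffer
/// %alloc is filled with the padding value, partially overwritten by a copy,
/// and then read back as a vector:
/// ```
/// %alloc = memref.alloc() : memref<17x5xf32>
/// linalg.fill ins(%cst : f32) outs(%alloc : memref<17x5xf32>)
/// %sv = memref.subview %alloc[0, 0] [%s0, %s1] [1, 1]
///     : memref<17x5xf32> to memref<?x?xf32, strided<[5, 1]>>
/// memref.copy %in, %sv
///     : memref<?x?xf32> to memref<?x?xf32, strided<[5, 1]>>
/// %r = vector.transfer_read %alloc[%c0, %c0], %cst
///     : memref<17x5xf32>, vector<17x5xf32>
/// ```
/// The fill and copy are erased and the read is forwarded to the copy source:
/// ```
/// %r = vector.transfer_read %in[%c0, %c0], %cst
///     {in_bounds = [false, false]} : memref<?x?xf32>, vector<17x5xf32>
/// ```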
2753 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2754     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2755 
2756   // TODO: support mask.
2757   if (xferOp.getMask())
2758     return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2759 
2760   // Transfer into `view`.
2761   Value viewOrAlloc = xferOp.getSource();
2762   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2763       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2764     return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2765 
2766   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2767   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2768   if (!subViewOp)
2769     return rewriter.notifyMatchFailure(xferOp, "no subview found");
2770   Value subView = subViewOp.getResult();
2771 
2772   // Find the copy into `subView` without interleaved uses.
2773   memref::CopyOp copyOp;
2774   for (auto &u : subView.getUses()) {
2775     if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2776       assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2777       if (newCopyOp.getTarget() != subView)
2778         continue;
2779       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2780         continue;
2781       copyOp = newCopyOp;
2782       break;
2783     }
2784   }
2785   if (!copyOp)
2786     return rewriter.notifyMatchFailure(xferOp, "no copy found");
2787 
2788   // Find the fill into `viewOrAlloc` without interleaved uses before the
2789   // copy.
2790   FillOp maybeFillOp;
2791   for (auto &u : viewOrAlloc.getUses()) {
2792     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2793       assert(isa<MemRefType>(newFillOp.output().getType()));
2794       if (newFillOp.output() != viewOrAlloc)
2795         continue;
2796       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2797         continue;
2798       maybeFillOp = newFillOp;
2799       break;
2800     }
2801   }
2802   // Ensure padding matches.
2803   if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2804     return rewriter.notifyMatchFailure(xferOp,
2805                                        "padding value does not match fill");
2806 
2807   // `in` is the subview that memref.copy reads. Replace it.
2808   Value in = copyOp.getSource();
2809 
2810   // memref.copy + linalg.fill can be used to create a padded local buffer.
2811   // The `masked` attribute is only valid on this padded buffer.
2812   // When forwarding to vector.transfer_read, the attribute must be reset
2813   // conservatively.
2814   auto vectorType = xferOp.getVectorType();
2815   Value res = rewriter.create<vector::TransferReadOp>(
2816       xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
2817       xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2818       rewriter.getBoolArrayAttr(
2819           SmallVector<bool>(vectorType.getRank(), false)));
2820 
2821   if (maybeFillOp)
2822     rewriter.eraseOp(maybeFillOp);
2823   rewriter.eraseOp(copyOp);
2824   rewriter.replaceOp(xferOp, res);
2825 
2826   return success();
2827 }
2828 
2829 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2830 /// when available.
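///
/// For illustration, a hypothetical, simplified example of the IR this
/// pattern targets (names, types and sizes are made up). The vector is
/// written into a local buffer whose subview is then copied out:
/// ```
/// %alloc = memref.alloc() : memref<17x5xf32>
/// %sv = memref.subview %alloc[0, 0] [%s0, %s1] [1, 1]
///     : memref<17x5xf32> to memref<?x?xf32, strided<[5, 1]>>
/// vector.transfer_write %vec, %alloc[%c0, %c0]
///     : vector<17x5xf32>, memref<17x5xf32>
/// memref.copy %sv, %out
///     : memref<?x?xf32, strided<[5, 1]>> to memref<?x?xf32>
/// ```
/// The copy and the original write are erased; the write goes directly to the
/// copy target:
/// ```
/// vector.transfer_write %vec, %out[%c0, %c0]
///     {in_bounds = [false, false]} : vector<17x5xf32>, memref<?x?xf32>
/// ```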
2831 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2832     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2833   // TODO: support mask.
2834   if (xferOp.getMask())
2835     return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2836 
2837   // Transfer into `viewOrAlloc`.
2838   Value viewOrAlloc = xferOp.getSource();
2839   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2840       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2841     return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2842 
2843   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2844   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2845   if (!subViewOp)
2846     return rewriter.notifyMatchFailure(xferOp, "no subview found");
2847   Value subView = subViewOp.getResult();
2848 
2849   // Find the copy from `subView` without interleaved uses.
2850   memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
2852     if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2853       if (newCopyOp.getSource() != subView)
2854         continue;
2855       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2856         continue;
2857       copyOp = newCopyOp;
2858       break;
2859     }
2860   }
2861   if (!copyOp)
2862     return rewriter.notifyMatchFailure(xferOp, "no copy found");
2863 
  // `out` is the copy's target; forward the transfer_write directly into it.
2865   assert(isa<MemRefType>(copyOp.getTarget().getType()));
2866   Value out = copyOp.getTarget();
2867 
2868   // Forward vector.transfer into copy.
2869   // memref.copy + linalg.fill can be used to create a padded local buffer.
2870   // The `masked` attribute is only valid on this padded buffer.
2871   // When forwarding to vector.transfer_write, the attribute must be reset
2872   // conservatively.
2873   auto vector = xferOp.getVector();
2874   rewriter.create<vector::TransferWriteOp>(
2875       xferOp.getLoc(), vector, out, xferOp.getIndices(),
2876       xferOp.getPermutationMapAttr(), xferOp.getMask(),
2877       rewriter.getBoolArrayAttr(
2878           SmallVector<bool>(vector.getType().getRank(), false)));
2879 
2880   rewriter.eraseOp(copyOp);
2881   rewriter.eraseOp(xferOp);
2882 
2883   return success();
2884 }
2885 
2886 //===----------------------------------------------------------------------===//
2887 // Convolution vectorization patterns
2888 //===----------------------------------------------------------------------===//
2889 
2890 template <int N>
2891 static void bindShapeDims(ShapedType shapedType) {}
2892 
2893 template <int N, typename IntTy, typename... IntTy2>
2894 static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2895   val = shapedType.getShape()[N];
2896   bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2897 }
2898 
2899 /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
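/// E.g. `bindShapeDims(t, n, w, c)` binds n, w and c to dims 0, 1 and 2 of
/// the shaped type `t`.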
2900 template <typename... IntTy>
2901 static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2902   bindShapeDims<0>(shapedType, vals...);
2903 }
2904 
2905 namespace {
2906 bool isCastOfBlockArgument(Operation *op) {
2907   return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2908          isa<BlockArgument>(op->getOperand(0));
2909 }
2910 
2911 bool isSupportedPoolKind(vector::CombiningKind kind) {
2912   switch (kind) {
2913   case vector::CombiningKind::ADD:
2914   case vector::CombiningKind::MAXNUMF:
2915   case vector::CombiningKind::MAXIMUMF:
2916   case vector::CombiningKind::MAXSI:
2917   case vector::CombiningKind::MAXUI:
2918   case vector::CombiningKind::MINNUMF:
2919   case vector::CombiningKind::MINIMUMF:
2920   case vector::CombiningKind::MINSI:
2921   case vector::CombiningKind::MINUI:
2922     return true;
2923   default:
2924     return false;
2925   }
2926 }
2927 
2928 /// Generate a vector implementation for either:
2929 /// ```
2930 ///   Op def: (     w,     kw  )
2931 ///    Iters: ({Par(), Red()})
2932 ///   Layout: {{w + kw}, {kw}, {w}}
2933 /// ```
2934 /// kw is unrolled.
2935 ///
2936 /// or
2937 ///
2938 /// ```
2939 ///   Op def: (     n,     w,     c,    kw,    f  )
2940 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
2941 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2942 /// ```
2943 /// kw is unrolled, w is unrolled iff dilationW > 1.
2944 ///
2945 /// or
2946 ///
2947 /// ```
2948 ///   Op def: (     n,     c,     w,    f,    kw )
2949 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
2950 ///   Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2951 /// ```
2952 /// kw is unrolled, w is unrolled iff dilationW > 1.
2953 ///
2954 /// or
2955 ///
2956 /// ```
2957 ///   Op def: (     n,     w,     c,    kw )
2958 ///    Iters: ({Par(), Par(), Par(), Red()})
2959 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2960 /// ```
2961 /// kw is unrolled, w is unrolled iff dilationW > 1.
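///
/// As an illustration, the second (NWC/WCF) case above corresponds to named
/// ops such as linalg.conv_1d_nwc_wcf; a hypothetical example with made-up
/// shapes:
/// ```
/// %r = linalg.conv_1d_nwc_wcf
///        {dilations = dense<1> : tensor<1xi64>,
///         strides = dense<1> : tensor<1xi64>}
///        ins(%input, %filter : tensor<4x6x3xf32>, tensor<2x3x8xf32>)
///        outs(%init : tensor<4x5x8xf32>) -> tensor<4x5x8xf32>
/// ```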
2962 struct Conv1DGenerator
2963     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
2964   Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
2965                   int dilationW)
2966       : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
2967         strideW(strideW), dilationW(dilationW) {
    // Determine whether `linalgOp` can be handled by this generator.
2969     if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2970       return;
2971     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2972     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2973     resShaped = linalgOp.getDpsInitOperand(0)->get();
2974     lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2975     rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2976     resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2977     if (!lhsShapedType || !rhsShapedType || !resShapedType)
2978       return;
2979     // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2980     // (non-channeled convolution -> LHS and RHS both have single dimensions).
2981     if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2982         (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2983       return;
2984 
2985     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
2986     if (!reduceOp)
2987       return;
2988     redOp = reduceOp->getName().getIdentifier();
2989 
2990     if (!setOperKind(reduceOp))
2991       return;
2992     auto maybeKind = getCombinerOpKind(reduceOp);
    // Typically, convolution will have an `Add` CombiningKind, but for the i1
    // type it can be strength-reduced to `OR`, which is also supported. This
    // strength-reduction logic is in the `buildBinaryFn` helper in Linalg.
2996     if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2997                         *maybeKind != vector::CombiningKind::OR) &&
2998                        (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2999       return;
3000     }
3001     reductionKind = maybeKind.value();
3002 
3003     auto rhsRank = rhsShapedType.getRank();
3004     switch (oper) {
3005     case Conv:
3006       if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
3007         return;
3008       break;
3009     case Pool:
3010       if (rhsRank != 1)
3011         return;
3012       break;
3013     }
3014     // The op is now known to be valid.
3015     valid = true;
3016   }
3017 
3018   /// Generate a vector implementation for:
3019   /// ```
3020   ///   Op def: (     w,     kw  )
3021   ///    Iters: ({Par(), Red()})
3022   ///   Layout: {{w + kw}, {kw}, {w}}
3023   /// ```
3024   /// kw is always unrolled.
3025   ///
3026   /// or
3027   ///
3028   /// ```
3029   ///   Op def: (     n,     w,     c,    kw,    f  )
3030   ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
3031   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3032   /// ```
3033   /// kw is always unrolled.
  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is > 1.
3036   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3037     if (!valid)
3038       return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
3039 
3040     int64_t nSize, wSize, cSize, kwSize, fSize;
3041     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3042     bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3043     switch (conv1DOpOrder) {
3044     case Conv1DOpOrder::W:
3045       // Initialize unused dimensions
3046       nSize = fSize = cSize = 0;
3047       // out{W}
3048       bindShapeDims(resShapedType, wSize);
3049       // kernel{kw}
3050       bindShapeDims(rhsShapedType, kwSize);
3051       lhsShape = {// iw = ow + kw - 1
3052                   //   (i.e. 16 convolved with 3 -> 14)
3053                   (wSize + kwSize - 1)};
3054       rhsShape = {kwSize};
3055       resShape = {wSize};
3056       break;
3057     case Conv1DOpOrder::Nwc:
3058       // out{n, w, f}
3059       bindShapeDims(resShapedType, nSize, wSize, fSize);
3060       switch (oper) {
3061       case Conv:
3062         // kernel{kw, c, f}
3063         bindShapeDims(rhsShapedType, kwSize, cSize);
3064         break;
3065       case Pool:
3066         // kernel{kw}
3067         bindShapeDims(rhsShapedType, kwSize);
3068         cSize = fSize;
3069         break;
3070       }
3071       lhsShape = {nSize,
                  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
                  //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3074                   // Perform the proper inclusive -> exclusive -> inclusive.
3075                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3076                       1,
3077                   cSize};
3078       switch (oper) {
3079       case Conv:
3080         rhsShape = {kwSize, cSize, fSize};
3081         break;
3082       case Pool:
3083         rhsShape = {kwSize};
3084         break;
3085       }
3086       resShape = {nSize, wSize, fSize};
3087       break;
3088     case Conv1DOpOrder::Ncw:
3089       // out{n, f, w}
3090       bindShapeDims(resShapedType, nSize, fSize, wSize);
3091       switch (oper) {
3092       case Conv:
3093         // kernel{f, c, kw}
3094         bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3095         break;
3096       case Pool:
3097         // kernel{kw}
3098         bindShapeDims(rhsShapedType, kwSize);
3099         cSize = fSize;
3100         break;
3101       }
3102       lhsShape = {nSize, cSize,
                  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
                  //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3105                   // Perform the proper inclusive -> exclusive -> inclusive.
3106                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3107                       1};
3108       switch (oper) {
3109       case Conv:
3110         rhsShape = {fSize, cSize, kwSize};
3111         break;
3112       case Pool:
3113         rhsShape = {kwSize};
3114         break;
3115       }
3116       resShape = {nSize, fSize, wSize};
3117       break;
3118     }
3119 
3120     vector::TransferWriteOp write;
3121     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3122 
3123     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3124     // When strideW == 1, we can batch the contiguous loads and avoid
    // unrolling.
3126     int64_t wSizeStep = strideW == 1 ? wSize : 1;
3127 
3128     Type lhsEltType = lhsShapedType.getElementType();
3129     Type rhsEltType = rhsShapedType.getElementType();
3130     Type resEltType = resShapedType.getElementType();
3131     auto lhsType = VectorType::get(lhsShape, lhsEltType);
3132     auto rhsType = VectorType::get(rhsShape, rhsEltType);
3133     auto resType = VectorType::get(resShape, resEltType);
3134     // Zero padding with the corresponding dimensions for lhs, rhs and res.
3135     SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3136     SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3137     SmallVector<Value> resPadding(resShape.size(), zero);
3138 
3139     // Read the whole lhs, rhs and res in one shot (with zero padding).
3140     Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3141                                                         lhsPadding);
3142     // This is needed only for Conv.
3143     Value rhs = nullptr;
3144     if (oper == Conv)
3145       rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3146                                                     rhsPadding);
3147     Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3148                                                         resPadding);
3149 
3150     // The base vectorization case for channeled convolution is input:
3151     // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
    // vectorization case, we pre-transpose the input, weight, and output.
3153     switch (conv1DOpOrder) {
3154     case Conv1DOpOrder::W:
3155     case Conv1DOpOrder::Nwc:
3156       // Base case, so no transposes necessary.
3157       break;
3158     case Conv1DOpOrder::Ncw: {
3159       // To match base vectorization case, we pre-transpose current case.
3160       // ncw -> nwc
3161       static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3162       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3163       // fcw -> wcf
3164       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3165 
3166       // This is needed only for Conv.
3167       if (oper == Conv)
3168         rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3169       // nfw -> nwf
3170       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3171       res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3172       break;
3173     }
3174     }
3175 
3176     //===------------------------------------------------------------------===//
3177     // Begin vector-only rewrite part
3178     //===------------------------------------------------------------------===//
3179     // Unroll along kw and read slices of lhs and rhs.
3180     SmallVector<Value> lhsVals, rhsVals, resVals;
3181     lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3182                                      kwSize, strideW, dilationW, wSizeStep,
3183                                      isSingleChanneled);
3184     // Do not do for pooling.
3185     if (oper == Conv)
3186       rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3187     resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3188                                       wSizeStep, isSingleChanneled);
3189 
3190     auto linearIndex = [&](int64_t kw, int64_t w) {
3191       return kw * (wSize / wSizeStep) + w;
3192     };
3193 
3194     // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3195     // or perform outerproduct for non-channeled convolution or perform simple
3196     // arith operation for pooling
3197     for (int64_t kw = 0; kw < kwSize; ++kw) {
3198       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3199         switch (oper) {
3200         case Conv:
3201           if (isSingleChanneled) {
3202             resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3203                                                    lhsVals[linearIndex(kw, w)],
3204                                                    rhsVals[kw], resVals[w]);
3205           } else {
3206             resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3207                                                   lhsVals[linearIndex(kw, w)],
3208                                                   rhsVals[kw], resVals[w]);
3209           }
3210           break;
3211         case Pool:
3212           resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3213                                    resVals[w]);
3214           break;
3215         }
3216       }
3217     }
3218 
3219     res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3220                                  isSingleChanneled);
3221     //===------------------------------------------------------------------===//
3222     // End vector-only rewrite part
3223     //===------------------------------------------------------------------===//
3224 
    // The base vectorization case for channeled convolution produces output
    // {n,w,f}. To reuse the result from the base vectorization case, we
    // post-transpose the base case result.
3228     switch (conv1DOpOrder) {
3229     case Conv1DOpOrder::W:
3230     case Conv1DOpOrder::Nwc:
3231       // Base case, so no transposes necessary.
3232       break;
3233     case Conv1DOpOrder::Ncw: {
3234       // nwf -> nfw
3235       static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3236       res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3237       break;
3238     }
3239     }
3240 
3241     return rewriter
3242         .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3243         .getOperation();
3244   }
3245 
3246   // Take a value and widen to have the same element type as `ty`.
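  // E.g. promoting an i8 vector to match an i32 accumulator emits arith.extsi,
  // i8 -> f32 emits arith.sitofp, and f16 -> f32 emits arith.extf.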
3247   Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3248     const Type srcElementType = getElementTypeOrSelf(val.getType());
3249     const Type dstElementType = getElementTypeOrSelf(ty);
3250     assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3251     if (srcElementType == dstElementType)
3252       return val;
3253 
3254     const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3255     const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3256     const Type dstType =
3257         cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3258 
3259     if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3260       return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3261     }
3262 
3263     if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3264         srcWidth < dstWidth)
3265       return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3266 
3267     if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3268         srcWidth < dstWidth)
3269       return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3270 
3271     assert(false && "unhandled promotion case");
3272     return nullptr;
3273   }
3274 
3275   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
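  // E.g. (made-up shapes; indexing maps and iterator types elided for
  // brevity):
  //   %r = vector.contract {indexing_maps = [...], iterator_types = [...],
  //                         kind = #vector.kind<add>}
  //          %lhs, %rhs, %res
  //          : vector<1x8x4xf32>, vector<4x16xf32> into vector<1x8x16xf32>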
3276   Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3277                                  Value lhs, Value rhs, Value res) {
3278     vector::IteratorType par = vector::IteratorType::parallel;
3279     vector::IteratorType red = vector::IteratorType::reduction;
3280     AffineExpr n, w, f, c;
3281     bindDims(ctx, n, w, f, c);
3282     lhs = promote(rewriter, loc, lhs, res.getType());
3283     rhs = promote(rewriter, loc, rhs, res.getType());
    auto contractionOp = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contractionOp.setKind(reductionKind);
    return contractionOp;
3290   }
3291 
3292   // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3293   // convolution.
3294   Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3295                                   Value lhs, Value rhs, Value res) {
3296     return rewriter.create<vector::OuterProductOp>(
3297         loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3298   }
3299 
3300   // Create a reduction: lhs{n, w, c} -> res{n, w, c}
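  // E.g. for a sum pooling this materializes something like
  //   %r = arith.addf %lhs, %res : vector<1x8x4xf32>
  // (shapes are illustrative only).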
3301   Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3302                     Value res) {
3303     if (isPoolExt)
3304       lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3305     return rewriter
3306         .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3307         ->getResult(0);
3308   }
3309 
3310   /// Generate a vector implementation for:
3311   /// ```
3312   ///   Op def: (     n,     w,     c,    kw)
3313   ///    Iters: ({Par(), Par(), Par(), Red()})
3314   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3315   /// ```
3316   /// kw is always unrolled.
  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is > 1.
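  ///
  /// As an illustration, this corresponds to named ops such as
  /// linalg.depthwise_conv_1d_nwc_wc; a hypothetical example with made-up
  /// shapes:
  /// ```
  /// %r = linalg.depthwise_conv_1d_nwc_wc
  ///        {dilations = dense<1> : tensor<1xi64>,
  ///         strides = dense<1> : tensor<1xi64>}
  ///        ins(%input, %filter : tensor<1x10x8xf32>, tensor<3x8xf32>)
  ///        outs(%init : tensor<1x8x8xf32>) -> tensor<1x8x8xf32>
  /// ```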
3319   FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3320                                        bool channelDimScalableFlag,
3321                                        bool flatten) {
3322     if (!valid)
3323       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3324 
3325     bool scalableChDim = false;
3326     bool useMasking = false;
3327     int64_t nSize, wSize, cSize, kwSize;
3328     // kernel{kw, c}
3329     bindShapeDims(rhsShapedType, kwSize, cSize);
3330     if (ShapedType::isDynamic(cSize)) {
3331       assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3332       cSize = channelDimVecSize;
3333       // Scalable vectors are only used when both conditions are met:
3334       //  1. channel dim is dynamic
3335       //  2. channelDimScalableFlag is set
3336       scalableChDim = channelDimScalableFlag;
3337       useMasking = true;
3338     }
3339 
3340     assert(!(useMasking && flatten) &&
3341            "Unsupported flattened conv with dynamic shapes");
3342 
3343     // out{n, w, c}
3344     bindShapeDims(resShapedType, nSize, wSize);
3345 
3346     vector::TransferWriteOp write;
3347     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3348 
3349     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3350     // When strideW == 1, we can batch the contiguous loads and avoid
    // unrolling.
3352     int64_t wSizeStep = strideW == 1 ? wSize : 1;
3353 
3354     Type lhsEltType = lhsShapedType.getElementType();
3355     Type rhsEltType = rhsShapedType.getElementType();
3356     Type resEltType = resShapedType.getElementType();
3357     VectorType lhsType = VectorType::get(
3358         {nSize,
         // iw = (ow - 1) * sw + (kw - 1) * dw + 1
         //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3361          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3362          cSize},
3363         lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3364     VectorType rhsType =
3365         VectorType::get({kwSize, cSize}, rhsEltType,
3366                         /*scalableDims=*/{false, scalableChDim});
3367     VectorType resType =
3368         VectorType::get({nSize, wSize, cSize}, resEltType,
3369                         /*scalableDims=*/{false, false, scalableChDim});
3370 
    // Masks the input xfer Op along the channel dim when masking was requested
    // above (i.e. the channel dim is dynamic); otherwise returns the op
    // unchanged. The scalable flags only affect the mask's vector type.
3373     auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3374                                ArrayRef<bool> scalableDims,
3375                                Operation *opToMask) {
3376       if (!useMasking)
3377         return opToMask;
3378       auto maskType =
3379           VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3380 
3381       SmallVector<bool> inBounds(maskShape.size(), true);
3382       auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3383       xferOp->setAttr(xferOp.getInBoundsAttrName(),
3384                       rewriter.getBoolArrayAttr(inBounds));
3385 
3386       SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3387           cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3388 
3389       Value maskOp =
3390           rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3391 
3392       return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3393     };
3394 
3395     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3396     // 0].
3397     Value lhs = rewriter.create<vector::TransferReadOp>(
3398         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3399     auto maybeMaskedLhs = maybeMaskXferOp(
3400         lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3401 
3402     // Read rhs slice of size {kw, c} @ [0, 0].
3403     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3404                                                         ValueRange{zero, zero});
3405     auto maybeMaskedRhs = maybeMaskXferOp(
3406         rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3407 
3408     // Read res slice of size {n, w, c} @ [0, 0, 0].
3409     Value res = rewriter.create<vector::TransferReadOp>(
3410         loc, resType, resShaped, ValueRange{zero, zero, zero});
3411     auto maybeMaskedRes = maybeMaskXferOp(
3412         resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3413 
3414     //===------------------------------------------------------------------===//
3415     // Begin vector-only rewrite part
3416     //===------------------------------------------------------------------===//
3417     // Unroll along kw and read slices of lhs and rhs.
3418     SmallVector<Value> lhsVals, rhsVals, resVals;
3419     auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3420     auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3421 
3422     // Extract lhs slice of size {n, wSizeStep, c}
3423     //   @ [0, sw * w + dw * kw, 0].
3424     for (int64_t kw = 0; kw < kwSize; ++kw) {
3425       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3426         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3427             loc, maybeMaskedLhs->getResult(0),
3428             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3429             inOutSliceSizes, inOutStrides));
3430       }
3431     }
3432     // Extract rhs slice of size {c} @ [kw].
3433     for (int64_t kw = 0; kw < kwSize; ++kw) {
3434       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3435           loc, maybeMaskedRhs->getResult(0),
3436           /*offsets=*/ArrayRef<int64_t>{kw}));
3437     }
3438     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3439     for (int64_t w = 0; w < wSize; w += wSizeStep) {
3440       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3441           loc, maybeMaskedRes->getResult(0),
3442           /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3443           inOutStrides));
3444     }
3445 
3446     auto linearIndex = [&](int64_t kw, int64_t w) {
3447       return kw * (wSize / wSizeStep) + w;
3448     };
3449 
3450     // Note - the scalable flags are ignored as flattening combined with
3451     // scalable vectorization is not supported.
3452     auto inOutFlattenSliceSizes =
3453         SmallVector<int64_t>{nSize, wSizeStep * cSize};
3454     auto lhsTypeAfterFlattening =
3455         VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3456     auto resTypeAfterFlattening =
3457         VectorType::get(inOutFlattenSliceSizes, resEltType);
3458 
3459     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3460     for (int64_t kw = 0; kw < kwSize; ++kw) {
3461       for (int64_t w = 0; w < wSize; w += wSizeStep) {
3462         Value lhsVal = lhsVals[linearIndex(kw, w)];
3463         Value resVal = resVals[w];
3464         if (flatten) {
3465           // Flatten the input and output vectors (collapse the channel
3466           // dimension)
3467           lhsVal = rewriter.create<vector::ShapeCastOp>(
3468               loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3469           resVal = rewriter.create<vector::ShapeCastOp>(
3470               loc, resTypeAfterFlattening, resVals[w]);
3471         }
3472         resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3473                                                   rhsVals[kw], resVal, flatten);
3474         if (flatten) {
3475           // Un-flatten the output vector (restore the channel dimension)
3476           resVals[w] = rewriter.create<vector::ShapeCastOp>(
3477               loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3478         }
3479       }
3480     }
3481 
    // It's possible we failed to create the FMA.
3483     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3484       // Manually revert (in reverse order) to avoid leaving a bad IR state.
3485       for (auto &collection :
3486            {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3487         for (Value v : collection)
3488           rewriter.eraseOp(v.getDefiningOp());
3489       return rewriter.notifyMatchFailure(op, "failed to create FMA");
3490     }
3491 
3492     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3493     // This does not depend on kw.
3494     for (int64_t w = 0; w < wSize; w += wSizeStep) {
3495       maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3496           loc, resVals[w], maybeMaskedRes->getResult(0),
3497           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3498           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3499     }
3500     //===------------------------------------------------------------------===//
3501     // End vector-only rewrite part
3502     //===------------------------------------------------------------------===//
3503 
3504     // Write back res slice of size {n, w, c} @ [0, 0, 0].
3505     Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3506         loc, maybeMaskedRes->getResult(0), resShaped,
3507         ValueRange{zero, zero, zero});
3508     return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3509                            resOut);
3510   }
3511 
3512   /// Lower:
3513   ///   *  lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3514   ///   *  lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3515   /// to MulAcc.
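  ///
  /// For the non-flattened floating-point case this amounts to, e.g. (shapes
  /// made up):
  /// ```
  /// %b = vector.broadcast %rhs : vector<4xf32> to vector<1x8x4xf32>
  /// %r = vector.fma %lhs, %b, %res : vector<1x8x4xf32>
  /// ```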
3516   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3517                                      Value lhs, Value rhs, Value res,
3518                                      bool flatten) {
3519     auto rhsTy = cast<ShapedType>(rhs.getType());
3520     auto resTy = cast<ShapedType>(res.getType());
3521 
3522     // TODO(suderman): Change this to use a vector.ima intrinsic.
3523     lhs = promote(rewriter, loc, lhs, resTy);
3524 
3525     if (flatten) {
      // NOTE: The following logic won't work for scalable vectors. For this
3527       // reason, "flattening" is not supported when shapes are dynamic (this
3528       // should be captured by one of the pre-conditions).
3529 
3530       // There are two options for handling the filter:
3531       //  * shape_cast(broadcast(filter))
3532       //  * broadcast(shuffle(filter))
3533       // Opt for the option without shape_cast to simplify the codegen.
3534       auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3535       auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3536 
3537       SmallVector<int64_t, 16> indices;
3538       for (int i = 0; i < resSize / rhsSize; ++i) {
3539         for (int j = 0; j < rhsSize; ++j)
3540           indices.push_back(j);
3541       }
3542 
3543       rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3544     }
3545     // Broadcast the filter to match the output vector
3546     rhs = rewriter.create<vector::BroadcastOp>(
3547         loc, resTy.clone(rhsTy.getElementType()), rhs);
3548 
3549     rhs = promote(rewriter, loc, rhs, resTy);
3550 
3551     if (!lhs || !rhs)
3552       return nullptr;
3553 
3554     if (isa<FloatType>(resTy.getElementType()))
3555       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3556 
3557     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3558     return rewriter.create<arith::AddIOp>(loc, mul, res);
3559   }
3560 
3561   /// Entry point for non-channeled convolution:
3562   ///   {{w + kw}, {kw}, {w}}
3563   FailureOr<Operation *> generateNonChanneledConv() {
3564     AffineExpr w, kw;
3565     bindDims(ctx, w, kw);
3566     if (!iters({Par(), Red()}))
3567       return rewriter.notifyMatchFailure(op,
3568                                          "failed to match conv::W 1-par 1-red");
3569 
3570     // No transposition needed.
3571     if (layout({/*lhsIndex*/ {w + kw},
3572                 /*rhsIndex*/ {kw},
3573                 /*resIndex*/ {w}}))
3574       return conv(Conv1DOpOrder::W);
3575 
3576     return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3577   }
3578 
3579   /// Entry point that transposes into the common form:
3580   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3581   FailureOr<Operation *> generateNwcConv() {
3582     AffineExpr n, w, f, kw, c;
3583     bindDims(ctx, n, w, f, kw, c);
3584     if (!iters({Par(), Par(), Par(), Red(), Red()}))
3585       return rewriter.notifyMatchFailure(
3586           op, "failed to match conv::Nwc 3-par 2-red");
3587 
3588     // No transposition needed.
3589     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3590                 /*rhsIndex*/ {kw, c, f},
3591                 /*resIndex*/ {n, w, f}}))
3592       return conv(Conv1DOpOrder::Nwc);
3593 
3594     return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3595   }
3596 
3597   /// Entry point that transposes into the common form:
3598   ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3599   FailureOr<Operation *> generateNcwConv() {
3600     AffineExpr n, w, f, kw, c;
3601     bindDims(ctx, n, f, w, c, kw);
3602     if (!iters({Par(), Par(), Par(), Red(), Red()}))
3603       return rewriter.notifyMatchFailure(
3604           op, "failed to match conv::Ncw 3-par 2-red");
3605 
3606     if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3607                 /*rhsIndex*/ {f, c, kw},
3608                 /*resIndex*/ {n, f, w}}))
3609       return conv(Conv1DOpOrder::Ncw);
3610 
3611     return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3612   }
3613 
3614   /// Entry point that transposes into the common form:
3615   ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3616   FailureOr<Operation *> generateNwcPooling() {
3617     AffineExpr n, w, c, kw;
3618     bindDims(ctx, n, w, c, kw);
3619     if (!iters({Par(), Par(), Par(), Red()}))
3620       return rewriter.notifyMatchFailure(op,
3621                                          "failed to match pooling 3-par 1-red");
3622 
3623     // No transposition needed.
3624     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3625                 /*rhsIndex*/ {kw},
3626                 /*resIndex*/ {n, w, c}}))
3627       return conv(Conv1DOpOrder::Nwc);
3628 
3629     return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3630   }
3631 
3632   /// Entry point that transposes into the common form:
3633   ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3634   FailureOr<Operation *> generateNcwPooling() {
3635     AffineExpr n, w, c, kw;
3636     bindDims(ctx, n, c, w, kw);
3637     if (!iters({Par(), Par(), Par(), Red()}))
3638       return rewriter.notifyMatchFailure(op,
3639                                          "failed to match pooling 3-par 1-red");
3640 
3641     if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3642                 /*rhsIndex*/ {kw},
3643                 /*resIndex*/ {n, c, w}}))
3644       return conv(Conv1DOpOrder::Ncw);
3645 
3646     return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3647   }
3648 
3649   /// Entry point that transposes into the common form:
3650   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3651   FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3652                                              bool vecChDimScalableFlag = false,
3653                                              bool flatten = false) {
3654     AffineExpr n, w, c, kw;
3655     bindDims(ctx, n, w, c, kw);
3656     if (!iters({Par(), Par(), Par(), Red()}))
3657       return rewriter.notifyMatchFailure(
3658           op, "failed to match depthwise::Nwc conv 3-par 1-red");
3659 
3660     // No transposition needed.
3661     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3662                 /*rhsIndex*/ {kw, c},
3663                 /*resIndex*/ {n, w, c}}))
3664       return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3665 
3666     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3667   }
3668 
3669 private:
3670   enum OperKind { Conv, Pool };
3671   bool valid = false;
3672   OperKind oper = Conv;
3673   StringAttr redOp;
3674   StringAttr poolExtOp;
3675   bool isPoolExt = false;
3676   int strideW, dilationW;
3677   Value lhsShaped, rhsShaped, resShaped;
3678   ShapedType lhsShapedType, rhsShapedType, resShapedType;
3679   vector::CombiningKind reductionKind;
3680 
  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
  // Returns true iff it is a valid conv/pooling op.
  // The decision is based on how many of `reduceOp`'s operands are block
  // arguments:
  // - Exactly one: inspect the defining op of the other operand. A cast of a
  //   block argument means pooling with an extension op. A `mul` (or an i1
  //   `and`, its strength-reduced form) whose operands are block arguments or
  //   casts of block arguments means conv.
  // - Exactly two: pooling without an extension op.
  // - Anything else is not a valid conv/pooling body.
3689   bool setOperKind(Operation *reduceOp) {
3690     int numBlockArguments =
3691         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3692     switch (numBlockArguments) {
3693     case 1: {
3694       // Will be convolution if feeder is a MulOp.
3695       // A strength-reduced version of MulOp for i1 is AndOp, which is also
3696       // supported. Otherwise, it can be pooling. This strength-reduction logic
3697       // lives in the `buildBinaryFn` helper in the Linalg dialect.
3698       auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
3699                                          llvm::IsaPred<BlockArgument>);
3700       Operation *feedOp = (*feedValIt).getDefiningOp();
3701       if (isCastOfBlockArgument(feedOp)) {
3702         oper = Pool;
3703         isPoolExt = true;
3704         poolExtOp = feedOp->getName().getIdentifier();
3705       } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
3706                     (isa<arith::AndIOp>(feedOp) &&
3707                      feedOp->getResultTypes()[0].isInteger(1))) &&
3708                    llvm::all_of(feedOp->getOperands(), [](Value v) {
3709                      if (isa<BlockArgument>(v))
3710                        return true;
3711                      if (Operation *op = v.getDefiningOp())
3712                        return isCastOfBlockArgument(op);
3713                      return false;
3714                    }))) {
3715         return false;
3716       }
3717       return true;
3718     }
3719     case 2:
3720       // Must be pooling
3721       oper = Pool;
3722       isPoolExt = false;
3723       return true;
3724     default:
3725       return false;
3726     }
3727   }
3728 };
3729 } // namespace
3730 
3731 /// Helper function to vectorize a LinalgOp with convolution semantics.
3732 // TODO: extend the generic vectorization to support windows and drop this.
3733 static FailureOr<Operation *> vectorizeConvolution(
3734     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3735     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3736   // The ConvolutionOpInterface gives us guarantees of existence for
3737   // strides/dilations. However, we do not need to rely on those; we simply
3738   // use them if present, otherwise we use the defaults and let the generic
3739   // conv matcher in the Conv1DGenerator succeed or fail.
3740   auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3741   auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3742   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3743   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
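  // For example (illustrative attribute values): a conv op carrying
  // `strides = dense<2> : tensor<1xi64>` and
  // `dilations = dense<1> : tensor<1xi64>` yields stride = 2 and dilation = 1
  // here; ops without these attributes fall back to stride = dilation = 1.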
3744   Conv1DGenerator e(rewriter, op, stride, dilation);
3745   auto res = e.generateNonChanneledConv();
3746   if (succeeded(res))
3747     return res;
3748   res = e.generateNwcConv();
3749   if (succeeded(res))
3750     return res;
3751   res = e.generateNcwConv();
3752   if (succeeded(res))
3753     return res;
3754   res = e.generateNwcPooling();
3755   if (succeeded(res))
3756     return res;
3757   res = e.generateNcwPooling();
3758   if (succeeded(res))
3759     return res;
3760 
3761   // Only depthwise 1D convolutions are left - these can be vectorized using
3762   // masks and scalable vectors. Note that, at the moment, the only dim that
3763   // can be dynamic (i.e. masked/scalable) is the channel dim.
3764   uint64_t vecChDimSize = ShapedType::kDynamic;
3765   bool vecChDimScalableFlag = false;
3766   if (!inputVecSizes.empty()) {
3767     // Only use the input vector size corresponding to the channel dim. Other
3768     // vector dims will be inferred from the Ops.
3769     assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3770             isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3771            "Not a 1D depthwise conv!");
3772     size_t chDimIdx =
3773         TypeSwitch<Operation *, size_t>(op)
3774             .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3775             .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3776 
3777     vecChDimSize = inputVecSizes[chDimIdx];
3778     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3779   }
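  // For example (illustrative sizes): for a `linalg.depthwise_conv_1d_nwc_wc`
  // vectorized with user-provided vector sizes {1, 8, 4} and scalable flags
  // {false, false, true}, chDimIdx is 2, so vecChDimSize = 4 and
  // vecChDimScalableFlag = true, i.e. the channel dim is vectorized with a
  // scalable vector size.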
3780   return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3781                                flatten1DDepthwiseConv);
3782 }
3783 
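/// Rewrite pattern wrapping `vectorizeConvolution`: on a successful match the
/// original LinalgOp is either erased (no results, i.e. buffer semantics) or
/// replaced by the single result of the vectorized computation (tensor
/// semantics).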
3784 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
3785   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3786 
3787   LogicalResult matchAndRewrite(LinalgOp op,
3788                                 PatternRewriter &rewriter) const override {
3789     FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
3790     if (failed(resultOrFail))
3791       return failure();
3792     Operation *newOp = *resultOrFail;
3793     if (newOp->getNumResults() == 0) {
3794       rewriter.eraseOp(op.getOperation());
3795       return success();
3796     }
3797     assert(newOp->getNumResults() == 1 && "expected single result");
3798     rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3799     return success();
3800   }
3801 };
3802 
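// For illustration only (a sketch assuming the usual greedy-driver pass
// boilerplate, which is not part of this file): a client would typically
// apply these patterns as
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();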
3803 void mlir::linalg::populateConvolutionVectorizationPatterns(
3804     RewritePatternSet &patterns, PatternBenefit benefit) {
3805   patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3806 }
3807