1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the operations of the MLIR Vector dialect, including
10 // super-vectorization operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 
16 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Arith/Utils/Utils.h"
19 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/UB/IR/UBOps.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/AffineMap.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/BuiltinAttributes.h"
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/DialectImplementation.h"
32 #include "mlir/IR/IRMapping.h"
33 #include "mlir/IR/OpImplementation.h"
34 #include "mlir/IR/PatternMatch.h"
35 #include "mlir/IR/TypeUtilities.h"
36 #include "mlir/Interfaces/SubsetOpInterface.h"
37 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
38 #include "mlir/Support/LLVM.h"
39 #include "mlir/Transforms/InliningUtils.h"
40 #include "llvm/ADT/ArrayRef.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SmallVector.h"
43 #include "llvm/ADT/StringSet.h"
44 #include "llvm/ADT/TypeSwitch.h"
45 #include "llvm/ADT/bit.h"
46 
47 #include <cassert>
48 #include <cstdint>
49 #include <numeric>
50 
51 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52 // Pull in all enum type and utility function definitions.
53 #include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
54 
55 using namespace mlir;
56 using namespace mlir::vector;
57 
58 /// Helper enum to classify a mask value.
59 enum class MaskFormat {
60   AllTrue = 0,
61   AllFalse = 1,
62   Unknown = 2,
63 };
64 
65 /// Helper method to classify a mask value. Currently, the method
66 /// looks "under the hood" of constant values with dense attributes,
67 /// constant mask operations, and create_mask operations (since the
68 /// client may be called at various stages during progressive lowering).
69 static MaskFormat getMaskFormat(Value mask) {
70   if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
71     // Inspect constant dense values. We count up for bits that
72     // are set, count down for bits that are cleared, and bail
73     // when a mix is detected.
74     if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
75       int64_t val = 0;
76       for (bool b : denseElts.getValues<bool>())
77         if (b && val >= 0)
78           val++;
79         else if (!b && val <= 0)
80           val--;
81         else
82           return MaskFormat::Unknown;
83       if (val > 0)
84         return MaskFormat::AllTrue;
85       if (val < 0)
86         return MaskFormat::AllFalse;
87     }
88   } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
89     // Inspect constant mask index. If the index equals or exceeds the
90     // dimension size, all bits in that dimension are set. If the index
91     // is zero or less, no bits are set.
92     ArrayRef<int64_t> masks = m.getMaskDimSizes();
93     auto shape = m.getType().getShape();
94     bool allTrue = true;
95     bool allFalse = true;
96     for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
97       if (maskIdx < dimSize)
98         allTrue = false;
99       if (maskIdx > 0)
100         allFalse = false;
101     }
102     if (allTrue)
103       return MaskFormat::AllTrue;
104     if (allFalse)
105       return MaskFormat::AllFalse;
106   } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
107     // Finds all-false create_masks. An all-true create_mask requires all
108     // dims to be constants, so that'll be folded to a constant_mask, then
109     // detected in the constant_mask case.
110     auto maskOperands = m.getOperands();
111     for (Value operand : maskOperands) {
112       if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
113         int64_t dimSize =
114             llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
115         if (dimSize <= 0)
116           return MaskFormat::AllFalse;
117       }
118     }
119     return MaskFormat::Unknown;
120   }
121   return MaskFormat::Unknown;
122 }
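// Illustrative examples of the classification above (added for exposition;
// SSA names and shapes are made up, not taken from the upstream source):
//
//   %all  = arith.constant dense<true> : vector<4xi1>  // -> MaskFormat::AllTrue
//   %none = vector.constant_mask [0] : vector<4xi1>    // -> MaskFormat::AllFalse
//   %c0   = arith.constant 0 : index
//   %zero = vector.create_mask %c0 : vector<4xi1>      // -> MaskFormat::AllFalse
//   %dyn  = vector.create_mask %n : vector<4xi1>       // -> MaskFormat::Unknown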
123 
124 /// Default callback to build a region with a 'vector.yield' terminator with no
125 /// arguments.
126 void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
127   builder.create<vector::YieldOp>(loc);
128 }
129 
130 // Helper for verifying combining kinds in contractions and reductions.
131 static bool isSupportedCombiningKind(CombiningKind combiningKind,
132                                      Type elementType) {
133   switch (combiningKind) {
134   case CombiningKind::ADD:
135   case CombiningKind::MUL:
136     return elementType.isIntOrIndexOrFloat();
137   case CombiningKind::MINUI:
138   case CombiningKind::MINSI:
139   case CombiningKind::MAXUI:
140   case CombiningKind::MAXSI:
141   case CombiningKind::AND:
142   case CombiningKind::OR:
143   case CombiningKind::XOR:
144     return elementType.isIntOrIndex();
145   case CombiningKind::MINNUMF:
146   case CombiningKind::MAXNUMF:
147   case CombiningKind::MINIMUMF:
148   case CombiningKind::MAXIMUMF:
149     return llvm::isa<FloatType>(elementType);
150   }
151   return false;
152 }
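// For example (illustrative, not part of the upstream file), the check above
// accepts an integer reduction such as
//   %r = vector.reduction <maxsi>, %v : vector<8xi32> into i32
// but rejects <maxsi> on a float vector and <maximumf> on an integer vector.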
153 
154 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
155                                                     VectorType vectorType) {
156   int64_t elementVectorRank = 0;
157   VectorType elementVectorType =
158       llvm::dyn_cast<VectorType>(shapedType.getElementType());
159   if (elementVectorType)
160     elementVectorRank += elementVectorType.getRank();
161   // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
162   // TODO: replace once we have 0-d vectors.
163   if (shapedType.getRank() == 0 &&
164       vectorType.getShape() == ArrayRef<int64_t>{1})
165     return AffineMap::get(
166         /*numDims=*/0, /*numSymbols=*/0,
167         getAffineConstantExpr(0, shapedType.getContext()));
168   return AffineMap::getMinorIdentityMap(
169       shapedType.getRank(), vectorType.getRank() - elementVectorRank,
170       shapedType.getContext());
171 }
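// Worked example (illustrative): a transfer between memref<8x16x32xf32> and
// vector<16x32xf32> gets the minor identity map (d0, d1, d2) -> (d1, d2); the
// 0-d case, e.g. tensor<f32> paired with vector<1xf32>, gets the constant map
// () -> (0).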
172 
173 /// Check if `write` is of a constant splat and the masked `read` is padded with
174 /// the same splat value -- meaning it could be the same value as the initial
175 /// constant splat.
176 static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
177                                                  vector::TransferReadOp read) {
178   auto readMask = read.getMask();
179   auto writeMask = write.getMask();
180   // Check if the masks are consistent. The splat value could be the same if the
181   // read is masked (and padded with the splat value), and the write is unmasked
182   // or has the same mask. Note this does not allow the case where the write is
183   // masked and the read is unmasked, as then the read could be of more elements
184   // than the write (which may not be the same value).
185   bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
186   if (!couldBeSameSplat)
187     return false;
188   // Check for constant splat (as the source of the write).
189   DenseElementsAttr splatAttr;
190   if (!matchPattern(write.getVector(),
191                     m_Constant<DenseElementsAttr>(&splatAttr)) ||
192       !splatAttr.isSplat()) {
193     return false;
194   }
195   // The padding of the read and the constant splat value must be the same.
196   Attribute padAttr;
197   if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
198     return false;
199   return padAttr == splatAttr.getSplatValue<Attribute>();
200 }
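// Illustrative IR for the case handled above (SSA names are made up and the
// exact transfer attributes are elided):
//
//   %splat = arith.constant dense<0.0> : vector<4xf32>
//   %w = vector.transfer_write %splat, %t[%c0], %mask
//          : vector<4xf32>, tensor<8xf32>
//   %pad = arith.constant 0.0 : f32
//   %r = vector.transfer_read %w[%c0], %pad, %mask
//          : tensor<8xf32>, vector<4xf32>
//
// The write stores a splat of 0.0 and the masked-off read lanes are padded
// with the same 0.0, so the read can be forwarded from the written splat.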
201 
202 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
203                                      vector::TransferReadOp read) {
204   return !defWrite.hasOutOfBoundsDim() &&
205          defWrite.getIndices() == read.getIndices() &&
206          defWrite.getVectorType() == read.getVectorType() &&
207          defWrite.getPermutationMap() == read.getPermutationMap() &&
208          ((!defWrite.getMask() && !read.getMask()) ||
209           isSplatWriteConsistentWithMaskedRead(defWrite, read));
210 }
211 
212 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
213                                      vector::TransferWriteOp priorWrite) {
214   return priorWrite.getIndices() == write.getIndices() &&
215          priorWrite.getMask() == write.getMask() &&
216          priorWrite.getVectorType() == write.getVectorType() &&
217          priorWrite.getPermutationMap() == write.getPermutationMap();
218 }
219 
220 bool mlir::vector::isDisjointTransferIndices(
221     VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
222     bool testDynamicValueUsingBounds) {
223   // For simplicity only look at transfer of same type.
224   if (transferA.getVectorType() != transferB.getVectorType())
225     return false;
226   unsigned rankOffset = transferA.getLeadingShapedRank();
227   for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
228     Value indexA = transferA.getIndices()[i];
229     Value indexB = transferB.getIndices()[i];
230     std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
231     std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
232 
233     if (i < rankOffset) {
234       // For leading dimensions, if we can prove that the indices are different
235       // we know we are accessing disjoint slices.
236       if (cstIndexA.has_value() && cstIndexB.has_value()) {
237         if (*cstIndexA != *cstIndexB)
238           return true;
239         continue;
240       }
241       if (testDynamicValueUsingBounds) {
242         // First try to see if we can fully compose and simplify the affine
243         // expression as a fast track.
244         FailureOr<uint64_t> delta =
245             affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
246         if (succeeded(delta) && *delta != 0)
247           return true;
248 
249         FailureOr<bool> testEqual =
250             ValueBoundsConstraintSet::areEqual(indexA, indexB);
251         if (succeeded(testEqual) && !testEqual.value())
252           return true;
253       }
254     } else {
255       // For this dimension, we slice a part of the memref, so we need to make
256       // sure the intervals accessed don't overlap.
257       int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
258       if (cstIndexA.has_value() && cstIndexB.has_value()) {
259         int64_t distance = std::abs(*cstIndexA - *cstIndexB);
260         if (distance >= vectorDim)
261           return true;
262         continue;
263       }
264       if (testDynamicValueUsingBounds) {
265         // First try to see if we can fully compose and simplify the affine
266         // expression as a fast track.
267         FailureOr<int64_t> delta =
268             affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
269         if (succeeded(delta) && std::abs(*delta) >= vectorDim)
270           return true;
271 
272         FailureOr<int64_t> computeDelta =
273             ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
274         if (succeeded(computeDelta)) {
275           if (std::abs(computeDelta.value()) >= vectorDim)
276             return true;
277         }
278       }
279     }
280   }
281   return false;
282 }
283 
284 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
285                                          VectorTransferOpInterface transferB,
286                                          bool testDynamicValueUsingBounds) {
287   if (transferA.getSource() != transferB.getSource())
288     return false;
289   return isDisjointTransferIndices(transferA, transferB,
290                                    testDynamicValueUsingBounds);
291 }
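// Worked example (illustrative indices and types):
//
//   %r = vector.transfer_read %mem[%i, %c0], %pad : memref<?x?xf32>, vector<4xf32>
//   vector.transfer_write %v, %mem[%i, %c8] : vector<4xf32>, memref<?x?xf32>
//
// The leading indices are identical and the trailing constant indices differ
// by 8 >= 4 (the vector dimension size), so the accessed intervals are
// disjoint and isDisjointTransferSet returns true.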
292 
293 // Helper to iterate over n-D vector slice elements. Calculates the next
294 // `position` in the n-D vector of size `shape`, applying the offsets in
295 // `offsets`. Modifies `position` in place. Returns failure when `position`
296 // reaches the end position.
297 static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
298                                       ArrayRef<int64_t> shape,
299                                       ArrayRef<int64_t> offsets) {
300   for (auto [posInDim, dimSize, offsetInDim] :
301        llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
302     ++posInDim;
303     if (posInDim < dimSize + offsetInDim)
304       return success();
305 
306     // Carry the overflow to the next loop iteration.
307     posInDim = offsetInDim;
308   }
309 
310   return failure();
311 }
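// Illustrative walk-through (not in the upstream file): with shape = [2, 3]
// and offsets = [0, 0], successive calls advance `position` in row-major
// order: [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2], and the
// next call returns failure().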
312 
313 /// Returns the integer numbers in `values`. `values` are expected to be
314 /// defined by constant index operations.
315 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
316   SmallVector<int64_t> ints;
317   llvm::transform(values, std::back_inserter(ints), [](Value value) {
318     auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
319     assert(constOp && "Unexpected non-constant index");
320     return constOp.value();
321   });
322   return ints;
323 }
324 
325 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
326 /// be constant integer attributes.
327 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
328   SmallVector<int64_t> ints;
329   llvm::transform(
330       foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
331         assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
332         return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
333       });
334   return ints;
335 }
336 
337 /// Converts `foldResults` into Values. Integer attributes are materialized as
338 /// constant index ops.
339 SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
340                                        ArrayRef<OpFoldResult> foldResults) {
341   SmallVector<Value> values;
342   llvm::transform(foldResults, std::back_inserter(values),
343                   [&](OpFoldResult foldResult) {
344                     if (auto attr = foldResult.dyn_cast<Attribute>())
345                       return builder
346                           .create<arith::ConstantIndexOp>(
347                               loc, cast<IntegerAttr>(attr).getInt())
348                           .getResult();
349 
350                     return cast<Value>(foldResult);
351                   });
352   return values;
353 }
354 
355 std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
356   if (value.getDefiningOp<vector::VectorScaleOp>())
357     return 1;
358   auto mul = value.getDefiningOp<arith::MulIOp>();
359   if (!mul)
360     return {};
361   auto lhs = mul.getLhs();
362   auto rhs = mul.getRhs();
363   if (lhs.getDefiningOp<vector::VectorScaleOp>())
364     return getConstantIntValue(rhs);
365   if (rhs.getDefiningOp<vector::VectorScaleOp>())
366     return getConstantIntValue(lhs);
367   return {};
368 }
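// Hedged example of what the helper above recognizes (SSA names made up):
//
//   %vs = vector.vscale
//   %c4 = arith.constant 4 : index
//   %n  = arith.muli %vs, %c4 : index
//
// getConstantVscaleMultiplier(%n) returns 4, calling it on %vs directly
// returns 1, and any other defining op yields std::nullopt.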
369 
370 //===----------------------------------------------------------------------===//
371 // CombiningKindAttr
372 //===----------------------------------------------------------------------===//
373 
374 namespace mlir {
375 namespace vector {
376 namespace detail {
377 struct BitmaskEnumStorage : public AttributeStorage {
378   using KeyTy = uint64_t;
379 
380   BitmaskEnumStorage(KeyTy val) : value(val) {}
381 
382   bool operator==(const KeyTy &key) const { return value == key; }
383 
384   static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
385                                        const KeyTy &key) {
386     return new (allocator.allocate<BitmaskEnumStorage>())
387         BitmaskEnumStorage(key);
388   }
389 
390   KeyTy value = 0;
391 };
392 } // namespace detail
393 } // namespace vector
394 } // namespace mlir
395 
396 //===----------------------------------------------------------------------===//
397 // VectorDialect
398 //===----------------------------------------------------------------------===//
399 
400 namespace {
401 /// This class defines the interface for handling inlining with vector dialect
402 /// operations.
403 struct VectorInlinerInterface : public DialectInlinerInterface {
404   using DialectInlinerInterface::DialectInlinerInterface;
405 
406   /// All vector dialect ops can be inlined.
407   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
408     return true;
409   }
410 };
411 } // namespace
412 
413 void VectorDialect::initialize() {
414   addAttributes<
415 #define GET_ATTRDEF_LIST
416 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
417       >();
418 
419   addOperations<
420 #define GET_OP_LIST
421 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
422       >();
423 
424   addInterfaces<VectorInlinerInterface>();
425 
426   declarePromisedInterfaces<bufferization::BufferizableOpInterface,
427                             TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
428                             YieldOp>();
429   declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
430                             TransferWriteOp>();
431   declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
432   declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
433 }
434 
435 /// Materialize a single constant operation from a given attribute value with
436 /// the desired resultant type.
437 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
438                                               Attribute value, Type type,
439                                               Location loc) {
440   return arith::ConstantOp::materialize(builder, value, type, loc);
441 }
442 
443 IntegerType vector::getVectorSubscriptType(Builder &builder) {
444   return builder.getIntegerType(64);
445 }
446 
447 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
448                                          ArrayRef<int64_t> values) {
449   return builder.getI64ArrayAttr(values);
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // MultiDimReductionOp
454 //===----------------------------------------------------------------------===//
455 
456 void vector::MultiDimReductionOp::build(OpBuilder &builder,
457                                         OperationState &result, Value source,
458                                         Value acc, ArrayRef<bool> reductionMask,
459                                         CombiningKind kind) {
460   SmallVector<int64_t> reductionDims;
461   for (const auto &en : llvm::enumerate(reductionMask))
462     if (en.value())
463       reductionDims.push_back(en.index());
464   build(builder, result, kind, source, acc, reductionDims);
465 }
466 
467 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
468   // A single parallel dim means this is a noop.
469   if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
470     return getSource();
471   return {};
472 }
473 
474 std::optional<SmallVector<int64_t, 4>>
475 MultiDimReductionOp::getShapeForUnroll() {
476   return llvm::to_vector<4>(getSourceVectorType().getShape());
477 }
478 
479 LogicalResult MultiDimReductionOp::verify() {
480   SmallVector<int64_t> targetShape;
481   SmallVector<bool> scalableDims;
482   Type inferredReturnType;
483   auto sourceScalableDims = getSourceVectorType().getScalableDims();
484   for (auto [dimIdx, dimSize] :
485        llvm::enumerate(getSourceVectorType().getShape()))
486     if (!llvm::any_of(getReductionDims(),
487                       [dimIdx = dimIdx](int64_t reductionDimIdx) {
488                         return reductionDimIdx == static_cast<int64_t>(dimIdx);
489                       })) {
490       targetShape.push_back(dimSize);
491       scalableDims.push_back(sourceScalableDims[dimIdx]);
492     }
493   // TODO: update to also allow 0-d vectors when available.
494   if (targetShape.empty())
495     inferredReturnType = getSourceVectorType().getElementType();
496   else
497     inferredReturnType = VectorType::get(
498         targetShape, getSourceVectorType().getElementType(), scalableDims);
499   if (getType() != inferredReturnType)
500     return emitOpError() << "destination type " << getType()
501                          << " is incompatible with source type "
502                          << getSourceVectorType();
503 
504   return success();
505 }
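// Example of the type inference checked above (illustrative):
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//          : vector<4x8xf32> to vector<4xf32>
// Reducing dim 1 of vector<4x8xf32> leaves the target shape [4], so any result
// type other than vector<4xf32> is rejected by the verifier.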
506 
507 /// Returns the mask type expected by this operation.
508 Type MultiDimReductionOp::getExpectedMaskType() {
509   auto vecType = getSourceVectorType();
510   return VectorType::get(vecType.getShape(),
511                          IntegerType::get(vecType.getContext(), /*width=*/1),
512                          vecType.getScalableDims());
513 }
514 
515 namespace {
516 // Only unit dimensions that are being reduced are folded. A unit dimension
517 // that is not reduced is kept, so the output type stays the same. If any
518 // reduced dimension is not a unit dimension, this transformation does
519 // nothing. This is just a generalization of ElideSingleElementReduction
520 // for ReduceOp.
521 struct ElideUnitDimsInMultiDimReduction
522     : public OpRewritePattern<MultiDimReductionOp> {
523   using OpRewritePattern::OpRewritePattern;
524 
525   LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
526                                 PatternRewriter &rewriter) const override {
527     ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
528     for (const auto &dim : enumerate(shape)) {
529       if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
530         return failure();
531     }
532 
533     // Vector mask setup.
534     OpBuilder::InsertionGuard guard(rewriter);
535     Operation *rootOp;
536     Value mask;
537     if (reductionOp.isMasked()) {
538       rewriter.setInsertionPoint(reductionOp.getMaskingOp());
539       rootOp = reductionOp.getMaskingOp();
540       mask = reductionOp.getMaskingOp().getMask();
541     } else {
542       rootOp = reductionOp;
543     }
544 
545     Location loc = reductionOp.getLoc();
546     Value acc = reductionOp.getAcc();
547     Value cast;
548     if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
549       if (mask) {
550         VectorType newMaskType =
551             VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
552                             dstVecType.getScalableDims());
553         mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
554       }
555       cast = rewriter.create<vector::ShapeCastOp>(
556           loc, reductionOp.getDestType(), reductionOp.getSource());
557     } else {
558       // This means we are reducing all the dimensions, and all reduction
559       // dimensions are of size 1. So a simple extraction would do.
560       SmallVector<int64_t> zeroIdx(shape.size(), 0);
561       if (mask)
562         mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
563       cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
564                                                 zeroIdx);
565     }
566 
567     Value result =
568         vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
569                                    cast, /*fastmath=*/nullptr, mask);
570     rewriter.replaceOp(rootOp, result);
571     return success();
572   }
573 };
574 } // namespace
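// Illustrative effect of the pattern above (not from the upstream source):
//   %r = vector.multi_reduction <add>, %src, %acc [0]
//          : vector<1x8xf32> to vector<8xf32>
// has only a unit-sized reduced dimension, so it is rewritten into a
// vector.shape_cast of %src to vector<8xf32> combined with %acc via
// makeArithReduction.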
575 
576 void MultiDimReductionOp::getCanonicalizationPatterns(
577     RewritePatternSet &results, MLIRContext *context) {
578   results.add<ElideUnitDimsInMultiDimReduction>(context);
579 }
580 
581 //===----------------------------------------------------------------------===//
582 // ReductionOp
583 //===----------------------------------------------------------------------===//
584 
585 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
586                                 CombiningKind kind, Value vector,
587                                 arith::FastMathFlags fastMathFlags) {
588   build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
589 }
590 
591 void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
592                                 CombiningKind kind, Value vector, Value acc,
593                                 arith::FastMathFlags fastMathFlags) {
594   build(builder, result,
595         llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
596         acc, fastMathFlags);
597 }
598 
599 LogicalResult ReductionOp::verify() {
600   // Verify for 0-D and 1-D vector.
601   int64_t rank = getSourceVectorType().getRank();
602   if (rank > 1)
603     return emitOpError("unsupported reduction rank: ") << rank;
604 
605   // Verify supported reduction kind.
606   Type eltType = getDest().getType();
607   if (!isSupportedCombiningKind(getKind(), eltType))
608     return emitOpError("unsupported reduction type '")
609            << eltType << "' for kind '" << stringifyCombiningKind(getKind())
610            << "'";
611 
612   return success();
613 }
614 
615 // MaskableOpInterface methods.
616 
617 /// Returns the mask type expected by this operation.
618 Type ReductionOp::getExpectedMaskType() {
619   auto vecType = getSourceVectorType();
620   return VectorType::get(vecType.getShape(),
621                          IntegerType::get(vecType.getContext(), /*width=*/1),
622                          vecType.getScalableDims());
623 }
624 
625 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
626                                          OpBuilder &builder, Location loc,
627                                          Value vector) {
628   switch (op) {
629   case arith::AtomicRMWKind::addf:
630   case arith::AtomicRMWKind::addi:
631     return builder.create<vector::ReductionOp>(vector.getLoc(),
632                                                CombiningKind::ADD, vector);
633   case arith::AtomicRMWKind::mulf:
634   case arith::AtomicRMWKind::muli:
635     return builder.create<vector::ReductionOp>(vector.getLoc(),
636                                                CombiningKind::MUL, vector);
637   case arith::AtomicRMWKind::minimumf:
638     return builder.create<vector::ReductionOp>(vector.getLoc(),
639                                                CombiningKind::MINIMUMF, vector);
640   case arith::AtomicRMWKind::mins:
641     return builder.create<vector::ReductionOp>(vector.getLoc(),
642                                                CombiningKind::MINSI, vector);
643   case arith::AtomicRMWKind::minu:
644     return builder.create<vector::ReductionOp>(vector.getLoc(),
645                                                CombiningKind::MINUI, vector);
646   case arith::AtomicRMWKind::maximumf:
647     return builder.create<vector::ReductionOp>(vector.getLoc(),
648                                                CombiningKind::MAXIMUMF, vector);
649   case arith::AtomicRMWKind::maxs:
650     return builder.create<vector::ReductionOp>(vector.getLoc(),
651                                                CombiningKind::MAXSI, vector);
652   case arith::AtomicRMWKind::maxu:
653     return builder.create<vector::ReductionOp>(vector.getLoc(),
654                                                CombiningKind::MAXUI, vector);
655   case arith::AtomicRMWKind::andi:
656     return builder.create<vector::ReductionOp>(vector.getLoc(),
657                                                CombiningKind::AND, vector);
658   case arith::AtomicRMWKind::ori:
659     return builder.create<vector::ReductionOp>(vector.getLoc(),
660                                                CombiningKind::OR, vector);
661   // TODO: Add remaining reduction operations.
662   default:
663     (void)emitOptionalError(loc, "Reduction operation type not supported");
664     break;
665   }
666   return nullptr;
667 }
668 
669 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
670   return llvm::to_vector<4>(getSourceVectorType().getShape());
671 }
672 
673 namespace {
674 struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
675   using OpRewritePattern::OpRewritePattern;
676 
677   LogicalResult matchAndRewrite(ReductionOp reductionOp,
678                                 PatternRewriter &rewriter) const override {
679     // Vector mask setup.
680     OpBuilder::InsertionGuard guard(rewriter);
681     auto maskableOp =
682         cast<vector::MaskableOpInterface>(reductionOp.getOperation());
683     Operation *rootOp;
684     Value mask;
685     if (maskableOp.isMasked()) {
686       rewriter.setInsertionPoint(maskableOp.getMaskingOp());
687       rootOp = maskableOp.getMaskingOp();
688       mask = maskableOp.getMaskingOp().getMask();
689     } else {
690       rootOp = reductionOp;
691     }
692 
693     auto vectorType = reductionOp.getSourceVectorType();
694     if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
695       return failure();
696 
697     Location loc = reductionOp.getLoc();
698     Value result;
699     if (vectorType.getRank() == 0) {
700       if (mask)
701         mask = rewriter.create<ExtractElementOp>(loc, mask);
702       result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
703     } else {
704       if (mask)
705         mask = rewriter.create<ExtractOp>(loc, mask, 0);
706       result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
707     }
708 
709     if (Value acc = reductionOp.getAcc())
710       result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
711                                           result, acc,
712                                           reductionOp.getFastmathAttr(), mask);
713 
714     rewriter.replaceOp(rootOp, result);
715     return success();
716   }
717 };
718 } // namespace
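// Illustrative effect of the pattern above (not from the upstream source):
//   %r = vector.reduction <add>, %v, %acc : vector<1xf32> into f32
// reduces a single-element vector, so it is rewritten into a vector.extract of
// %v at position 0 combined with %acc via makeArithReduction.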
719 
720 void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
721                                               MLIRContext *context) {
722   results.add<ElideSingleElementReduction>(context);
723 }
724 
725 //===----------------------------------------------------------------------===//
726 // ContractionOp
727 //===----------------------------------------------------------------------===//
728 
729 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
730                                   Value lhs, Value rhs, Value acc,
731                                   ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
732                                   ArrayRef<IteratorType> iteratorTypes) {
733   result.addOperands({lhs, rhs, acc});
734   result.addTypes(acc.getType());
735   result.addAttribute(
736       getIndexingMapsAttrName(result.name),
737       builder.getAffineMapArrayAttr(
738           AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
739   result.addAttribute(
740       getIteratorTypesAttrName(result.name),
741       builder.getArrayAttr(llvm::to_vector(llvm::map_range(
742           iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
743             return IteratorTypeAttr::get(builder.getContext(), t);
744           }))));
745 }
746 
747 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
748                                   Value lhs, Value rhs, Value acc,
749                                   ArrayAttr indexingMaps,
750                                   ArrayAttr iteratorTypes) {
751   build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
752         ContractionOp::getDefaultKind());
753 }
754 
755 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
756                                   Value lhs, Value rhs, Value acc,
757                                   ArrayAttr indexingMaps,
758                                   ArrayAttr iteratorTypes, CombiningKind kind) {
759   result.addOperands({lhs, rhs, acc});
760   result.addTypes(acc.getType());
761   result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
762   result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
763   result.addAttribute(getKindAttrName(result.name),
764                       CombiningKindAttr::get(builder.getContext(), kind));
765 }
766 
767 ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
768   OpAsmParser::UnresolvedOperand lhsInfo;
769   OpAsmParser::UnresolvedOperand rhsInfo;
770   OpAsmParser::UnresolvedOperand accInfo;
771   SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
772   SmallVector<Type, 2> types;
773   Type resultType;
774   auto loc = parser.getCurrentLocation();
775   DictionaryAttr dictAttr;
776   // TODO: Unify linalg op attribute parsing.
777   if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
778       parser.parseComma() || parser.parseOperand(rhsInfo) ||
779       parser.parseComma() || parser.parseOperand(accInfo) ||
780       parser.parseTrailingOperandList(masksInfo) ||
781       parser.parseOptionalAttrDict(result.attributes) ||
782       parser.parseColonTypeList(types) ||
783       parser.parseKeywordType("into", resultType) ||
784       parser.resolveOperand(lhsInfo, types[0], result.operands) ||
785       parser.resolveOperand(rhsInfo, types[1], result.operands) ||
786       parser.resolveOperand(accInfo, resultType, result.operands) ||
787       parser.addTypeToList(resultType, result.types))
788     return failure();
789   result.attributes.append(dictAttr.getValue().begin(),
790                            dictAttr.getValue().end());
791 
792   // Convert the array of strings into an array of IteratorType enums. This is
793   // needed because tests still use the old format, in which the 'iterator_types'
794   // attribute is represented as an array of strings.
795   // TODO: Remove this conversion once tests are fixed.
796   ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
797       result.attributes.get(getIteratorTypesAttrName(result.name)));
798 
799   SmallVector<Attribute> iteratorTypeAttrs;
800 
801   for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
802     auto maybeIteratorType = symbolizeIteratorType(s);
803     if (!maybeIteratorType.has_value())
804       return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
805 
806     iteratorTypeAttrs.push_back(
807         IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
808   }
809   result.attributes.set(getIteratorTypesAttrName(result.name),
810                         parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
811 
812   if (!result.attributes.get(getKindAttrName(result.name))) {
813     result.addAttribute(
814         getKindAttrName(result.name),
815         CombiningKindAttr::get(result.getContext(),
816                                ContractionOp::getDefaultKind()));
817   }
818   if (masksInfo.empty())
819     return success();
820   if (masksInfo.size() != 2)
821     return parser.emitError(parser.getNameLoc(),
822                             "expected zero or exactly 2 vector mask operands");
823   auto lhsType = llvm::cast<VectorType>(types[0]);
824   auto rhsType = llvm::cast<VectorType>(types[1]);
825   auto maskElementType = parser.getBuilder().getI1Type();
826   std::array<VectorType, 2> maskTypes = {
827       VectorType::Builder(lhsType).setElementType(maskElementType),
828       VectorType::Builder(rhsType).setElementType(maskElementType)};
829   if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
830     return failure();
831   return success();
832 }
833 
834 void ContractionOp::print(OpAsmPrinter &p) {
835   // TODO: Unify printing code with linalg ops.
836   auto attrNames = getTraitAttrNames();
837   llvm::StringSet<> traitAttrsSet;
838   traitAttrsSet.insert(attrNames.begin(), attrNames.end());
839   SmallVector<NamedAttribute, 8> attrs;
840   for (auto attr : (*this)->getAttrs()) {
841     if (attr.getName() == getIteratorTypesAttrName()) {
842       auto iteratorTypes =
843           llvm::cast<ArrayAttr>(attr.getValue())
844               .getAsValueRange<IteratorTypeAttr, IteratorType>();
845       // Convert IteratorType enums into the string representation. This is
846       // needed because tests still use the old format, in which the
847       // 'iterator_types' attribute is represented as an array of strings.
848       // TODO: Remove this conversion once tests are fixed.
849       SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
850           llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
851             return StringAttr::get(getContext(), stringifyIteratorType(t));
852           }));
853 
854       attrs.emplace_back(getIteratorTypesAttrName(),
855                          ArrayAttr::get(getContext(), iteratorTypeNames));
856     } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
857       attrs.push_back(attr);
858   }
859 
860   auto dictAttr = DictionaryAttr::get(getContext(), attrs);
861   p << " " << dictAttr << " " << getLhs() << ", ";
862   p << getRhs() << ", " << getAcc();
863 
864   p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
865   p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
866     << getResultType();
867 }
868 
869 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
870                          const std::vector<std::pair<int64_t, int64_t>> &map) {
871   for (auto &dimPair : map) {
872     if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
873         dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
874         lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
875       return false;
876   }
877   return true;
878 }
879 
880 static LogicalResult verifyOutputShape(
881     ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
882     Type resType,
883     const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
884     const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
885   DenseSet<int64_t> lhsContractingDimSet;
886   DenseSet<int64_t> rhsContractingDimSet;
887   for (auto &dimPair : contractingDimMap) {
888     lhsContractingDimSet.insert(dimPair.first);
889     rhsContractingDimSet.insert(dimPair.second);
890   }
891   DenseSet<int64_t> rhsBatchDimSet;
892   for (auto &dimPair : batchDimMap)
893     rhsBatchDimSet.insert(dimPair.second);
894 
895   // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
896   SmallVector<int64_t, 4> expectedResultDims;
897   for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
898     if (lhsContractingDimSet.count(i) > 0)
899       continue;
900     expectedResultDims.push_back(lhsType.getDimSize(i));
901   }
902 
903   // Add free dimensions from 'rhsType' to 'expectedResultDims'.
904   for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
905     if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
906       continue;
907     expectedResultDims.push_back(rhsType.getDimSize(i));
908   }
909 
910   // Verify 'expectedResultDims'.
911   if (expectedResultDims.empty()) {
912     // No batch or free dimension implies a scalar result.
913     if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
914       return op.emitOpError("invalid accumulator/result vector shape");
915   } else {
916     // At least one batch or free dimension implies a vector result.
917     auto resVectorType = llvm::dyn_cast<VectorType>(resType);
918     auto accVectorType = llvm::dyn_cast<VectorType>(accType);
919     if (!resVectorType || !accVectorType)
920       return op.emitOpError("invalid accumulator/result vector shape");
921 
922     // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
923     // types fully define the result vector type. This assumes the affine maps
924     // are well-formed, which must have been verified already.
925     MLIRContext *ctx = op.getContext();
926     AffineMap lhsMap = op.getIndexingMapsArray()[0];
927     AffineMap rhsMap = op.getIndexingMapsArray()[1];
928     if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
929       return op.emitOpError(
930           "expected all dimensions to be either a LHS or a RHS dimension");
931     SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
932     for (auto pair :
933          {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
934       VectorType v = pair.first;
935       auto map = pair.second;
936       for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
937         unsigned pos = map.getDimPosition(idx);
938         if (!extents[pos])
939           extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
940       }
941     }
942     if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
943       return op.emitOpError("expected all dimensions to get an extent as "
944                             "either a LHS or a RHS dimension");
945 
946     AffineMap resMap = op.getIndexingMapsArray()[2];
947     auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
948                                      /*symbolCount=*/0, extents, ctx);
949     // Compose the resMap with the extentsMap, which is a constant map.
950     AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
951     assert(llvm::all_of(expectedMap.getResults(),
952                         llvm::IsaPred<AffineConstantExpr>) &&
953            "expected constant extent along all dimensions.");
954     // Extract the expected shape and build the type.
955     auto expectedShape = llvm::to_vector<4>(
956         llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
957           return cast<AffineConstantExpr>(e).getValue();
958         }));
959     auto expected =
960         VectorType::get(expectedShape, resVectorType.getElementType(),
961                         resVectorType.getScalableDims());
962     if (resVectorType != expected || accVectorType != expected)
963       return op.emitOpError(
964                  "invalid accumulator/result vector shape, expected: ")
965              << expected;
966   }
967   return success();
968 }
969 
970 LogicalResult ContractionOp::verify() {
971   VectorType lhsType = getLhsType();
972   VectorType rhsType = getRhsType();
973   Type accType = getAccType();
974   Type resType = getResultType();
975 
976   if (llvm::isa<IntegerType>(lhsType.getElementType())) {
977     if (!lhsType.getElementType().isSignlessInteger())
978       return emitOpError("only supports signless integer types");
979   }
980 
981   // Verify that an indexing map was specified for each vector operand.
982   if (getIndexingMapsArray().size() != 3)
983     return emitOpError("expected an indexing map for each vector operand");
984 
985   // Verify that each index map has 'numIterators' inputs, no symbols, and
986   // that the number of map outputs equals the rank of its associated
987   // vector operand.
988   unsigned numIterators = getIteratorTypes().getValue().size();
989   for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
990     auto index = it.index();
991     auto map = it.value();
992     if (map.getNumSymbols() != 0)
993       return emitOpError("expected indexing map ")
994              << index << " to have no symbols";
995     auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
996     unsigned rank = vectorType ? vectorType.getShape().size() : 0;
997     // Verify that the map has the right number of inputs, outputs, and indices.
998     // This also correctly accounts for (..) -> () for rank-0 results.
999     if (map.getNumDims() != numIterators)
1000       return emitOpError("expected indexing map ")
1001              << index << " to have " << numIterators << " number of inputs";
1002     if (map.getNumResults() != rank)
1003       return emitOpError("expected indexing map ")
1004              << index << " to have " << rank << " number of outputs";
1005     if (!map.isProjectedPermutation())
1006       return emitOpError("expected indexing map ")
1007              << index << " to be a projected permutation of its inputs";
1008   }
1009 
1010   auto contractingDimMap = getContractingDimMap();
1011   auto batchDimMap = getBatchDimMap();
1012 
1013   // Verify at least one contracting dimension pair was specified.
1014   if (contractingDimMap.empty())
1015     return emitOpError("expected at least one contracting dimension pair");
1016 
1017   // Verify contracting dimension map was properly constructed.
1018   if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1019     return emitOpError("invalid contracting dimension map");
1020 
1021   // Verify batch dimension map was properly constructed.
1022   if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1023     return emitOpError("invalid batch dimension map");
1024 
1025   // Verify 'accType' and 'resType' shape.
1026   if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1027                                contractingDimMap, batchDimMap)))
1028     return failure();
1029 
1030   // Verify supported combining kind.
1031   auto vectorType = llvm::dyn_cast<VectorType>(resType);
1032   auto elementType = vectorType ? vectorType.getElementType() : resType;
1033   if (!isSupportedCombiningKind(getKind(), elementType))
1034     return emitOpError("unsupported contraction type");
1035 
1036   return success();
1037 }
1038 
1039 // MaskableOpInterface methods.
1040 
1041 /// Returns the mask type expected by this operation. Mostly used for
1042 /// verification purposes. It requires the operation to be vectorized.
1043 Type ContractionOp::getExpectedMaskType() {
1044   auto indexingMaps = this->getIndexingMapsArray();
1045   AffineMap lhsIdxMap = indexingMaps[0];
1046   AffineMap rhsIdxMap = indexingMaps[1];
1047   VectorType lhsType = this->getLhsType();
1048   VectorType rhsType = this->getRhsType();
1049 
1050   unsigned numVecDims = lhsIdxMap.getNumDims();
1051   SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1052   SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1053 
1054   // Using the information in the indexing maps, extract the size of each
1055   // dimension in the vector.contract operation from the two input operands.
1056   for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1057     maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1058     maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1059         lhsType.getScalableDims()[dimIdx];
1060   }
1061   for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1062     maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1063     maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1064         rhsType.getScalableDims()[dimIdx];
1065   }
1066 
1067   assert(!ShapedType::isDynamicShape(maskShape) &&
1068          "Mask shape couldn't be computed");
1069 
1070   return VectorType::get(maskShape,
1071                          IntegerType::get(lhsType.getContext(), /*width=*/1),
1072                          maskShapeScalableDims);
1073 }
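// Worked example (illustrative): for a matmul-like contraction with
//   LHS map (m, n, k) -> (m, k) on vector<4x8xf32>
//   RHS map (m, n, k) -> (k, n) on vector<8x16xf32>
// the dimension sizes gathered above are (m, n, k) = (4, 16, 8), so the
// expected mask type is vector<4x16x8xi1>.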
1074 
1075 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1076   return SmallVector<StringRef>{getIndexingMapsAttrName(),
1077                                 getIteratorTypesAttrName(), getKindAttrName()};
1078 }
1079 
1080 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1081   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1082     if (targetExpr == map.getResult(i))
1083       return i;
1084   return -1;
1085 }
1086 
1087 static std::vector<std::pair<int64_t, int64_t>>
1088 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1089           IteratorType targetIteratorType, MLIRContext *context) {
1090   std::vector<std::pair<int64_t, int64_t>> dimMap;
1091   for (const auto &it : llvm::enumerate(iteratorTypes)) {
1092     auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1093     if (iteratorType != targetIteratorType)
1094       continue;
1095     // Search lhs/rhs map results for 'targetExpr'.
1096     auto targetExpr = getAffineDimExpr(it.index(), context);
1097     int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1098     int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1099     if (lhsDim >= 0 && rhsDim >= 0)
1100       dimMap.emplace_back(lhsDim, rhsDim);
1101   }
1102   return dimMap;
1103 }
1104 
1105 void ContractionOp::getIterationBounds(
1106     SmallVectorImpl<int64_t> &iterationBounds) {
1107   auto lhsShape = getLhsType().getShape();
1108   auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1109   SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1110   SmallVector<int64_t, 2> iterationShape;
1111   for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1112     // Search lhs/rhs map results for 'targetExpr'.
1113     auto targetExpr = getAffineDimExpr(it.index(), getContext());
1114     auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1115     if (iteratorType == IteratorType::reduction) {
1116       // Get reduction dim size from lhs shape (same size in rhsShape).
1117       int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1118       assert(lhsDimIndex >= 0);
1119       iterationBounds.push_back(lhsShape[lhsDimIndex]);
1120       continue;
1121     }
1122     // Get parallel dimension size from result shape.
1123     int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1124     assert(resDimIndex >= 0);
1125     assert(resVectorType != nullptr);
1126     iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1127   }
1128 }
1129 
1130 void ContractionOp::getIterationIndexMap(
1131     std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1132   unsigned numMaps = getIndexingMapsArray().size();
1133   iterationIndexMap.resize(numMaps);
1134   for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1135     auto index = it.index();
1136     auto map = it.value();
1137     for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1138       auto dim = cast<AffineDimExpr>(map.getResult(i));
1139       iterationIndexMap[index][dim.getPosition()] = i;
1140     }
1141   }
1142 }
1143 
1144 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1145   SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1146   return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1147                    getContext());
1148 }
1149 
1150 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1151   SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1152   return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1153                    getContext());
1154 }
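// For the matmul-like maps (m, n, k) -> (m, k) and (m, n, k) -> (k, n), with k
// the only reduction iterator (illustrative example), getContractingDimMap()
// returns {(1, 0)}: LHS dim 1 and RHS dim 0 both index k. getBatchDimMap() is
// empty because no parallel iterator appears in both the LHS and RHS maps.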
1155 
1156 std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1157   SmallVector<int64_t, 4> shape;
1158   getIterationBounds(shape);
1159   return shape;
1160 }
1161 
1162 /// Return a fused vector::ContractionOp which represents a pattern such as:
1163 ///
1164 /// ```mlir
1165 ///    %c0 = vector.constant 0: ...
1166 ///    %c = vector.contract %a, %b, %c0: ...
1167 ///    %e = add %c, %d: ...
1168 /// ```
1169 ///
1170 /// by:
1171 ///
1172 /// ```mlir
1173 ///    %e = vector.contract %a, %b, %d: ...
1174 /// ```
1175 ///
1176 /// Return null if the canonicalization does not apply.
1177 // TODO: This should be a folding of Add into Contract in core but while they
1178 // live in different dialects, it is not possible without unnatural
1179 // dependencies.
1180 template <typename AddOpType>
1181 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1182   using OpRewritePattern<AddOpType>::OpRewritePattern;
1183 
1184   LogicalResult matchAndRewrite(AddOpType addOp,
1185                                 PatternRewriter &rewriter) const override {
1186     auto canonicalize = [&](Value maybeContraction,
1187                             Value otherOperand) -> vector::ContractionOp {
1188       vector::ContractionOp contractionOp =
1189           dyn_cast_or_null<vector::ContractionOp>(
1190               maybeContraction.getDefiningOp());
1191       if (!contractionOp)
1192         return vector::ContractionOp();
1193       if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1194               contractionOp.getAcc().getDefiningOp())) {
1195         if (maybeZero.getValue() ==
1196             rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1197           IRMapping bvm;
1198           bvm.map(contractionOp.getAcc(), otherOperand);
1199           auto newContraction =
1200               cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1201           rewriter.replaceOp(addOp, newContraction.getResult());
1202           return newContraction;
1203         }
1204       }
1205       return vector::ContractionOp();
1206     };
1207 
1208     Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1209     vector::ContractionOp contract = canonicalize(a, b);
1210     contract = contract ? contract : canonicalize(b, a);
1211     return contract ? success() : failure();
1212   }
1213 };
1214 
1215 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1216                                                 MLIRContext *context) {
1217   results.add<CanonicalizeContractAdd<arith::AddIOp>,
1218               CanonicalizeContractAdd<arith::AddFOp>>(context);
1219 }
1220 
1221 //===----------------------------------------------------------------------===//
1222 // ExtractElementOp
1223 //===----------------------------------------------------------------------===//
1224 
1225 void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1226                                          SetIntRangeFn setResultRanges) {
1227   setResultRanges(getResult(), argRanges.front());
1228 }
1229 
1230 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1231                                      Value source) {
1232   result.addOperands({source});
1233   result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
1234 }
1235 
1236 LogicalResult vector::ExtractElementOp::verify() {
1237   VectorType vectorType = getSourceVectorType();
1238   if (vectorType.getRank() == 0) {
1239     if (getPosition())
1240       return emitOpError("expected position to be empty with 0-D vector");
1241     return success();
1242   }
1243   if (vectorType.getRank() != 1)
1244     return emitOpError("unexpected >1 vector rank");
1245   if (!getPosition())
1246     return emitOpError("expected position for 1-D vector");
1247   return success();
1248 }
1249 
1250 OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
1251   // Skip folding for 0-D vectors, which have no position operand.
1252   if (!adaptor.getPosition())
1253     return {};
1254 
1255   // Fold extractelement (splat X) -> X.
1256   if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1257     return splat.getInput();
1258 
1259   // Fold extractelement(broadcast(X)) -> X.
1260   if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
1261     if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
1262       return broadcast.getSource();
1263 
1264   auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
1265   auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
1266   if (!pos || !src)
1267     return {};
1268 
1269   auto srcElements = src.getValues<Attribute>();
1270 
1271   uint64_t posIdx = pos.getInt();
1272   if (posIdx >= srcElements.size())
1273     return {};
1274 
1275   return srcElements[posIdx];
1276 }
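// Examples of the folds above (illustrative):
//   extractelement(splat(%x), %i)     -> %x
//   extractelement(broadcast(%s), %i) -> %s  (only when %s is a scalar)
//   extractelement(dense<[5, 7]>, 1)  -> 7   (constant source, in-bounds
//                                             constant position)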
1277 
1278 // Returns `true` if `index` is either within [0, maxIndex) or equal to
1279 // `poisonValue`.
1280 static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1281                                          int64_t maxIndex) {
1282   return index == poisonValue || (index >= 0 && index < maxIndex);
1283 }
1284 
1285 //===----------------------------------------------------------------------===//
1286 // ExtractOp
1287 //===----------------------------------------------------------------------===//
1288 
1289 void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1290                                   SetIntRangeFn setResultRanges) {
1291   setResultRanges(getResult(), argRanges.front());
1292 }
1293 
1294 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1295                               Value source, int64_t position) {
1296   build(builder, result, source, ArrayRef<int64_t>{position});
1297 }
1298 
1299 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1300                               Value source, OpFoldResult position) {
1301   build(builder, result, source, ArrayRef<OpFoldResult>{position});
1302 }
1303 
1304 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1305                               Value source, ArrayRef<int64_t> position) {
1306   build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1307         builder.getDenseI64ArrayAttr(position));
1308 }
1309 
1310 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1311                               Value source, ArrayRef<OpFoldResult> position) {
1312   SmallVector<int64_t> staticPos;
1313   SmallVector<Value> dynamicPos;
1314   dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1315   build(builder, result, source, dynamicPos,
1316         builder.getDenseI64ArrayAttr(staticPos));
1317 }
1318 
1319 LogicalResult
1320 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1321                             ExtractOp::Adaptor adaptor,
1322                             SmallVectorImpl<Type> &inferredReturnTypes) {
1323   auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
1324   if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1325       vectorType.getRank()) {
1326     inferredReturnTypes.push_back(vectorType.getElementType());
1327   } else {
1328     auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1329                               vectorType.getRank());
1330     inferredReturnTypes.push_back(VectorType::get(
1331         vectorType.getShape().drop_front(n), vectorType.getElementType(),
1332         vectorType.getScalableDims().drop_front(n)));
1333   }
1334   return success();
1335 }
1336 
1337 bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1338   // Allow extracting 1-element vectors instead of scalars.
1339   auto isCompatible = [](TypeRange l, TypeRange r) {
1340     auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1341     return vectorType && vectorType.getShape().equals({1}) &&
1342            vectorType.getElementType() == r.front();
1343   };
1344   if (l.size() == 1 && r.size() == 1 &&
1345       (isCompatible(l, r) || isCompatible(r, l)))
1346     return true;
1347   return l == r;
1348 }
1349 
1350 LogicalResult vector::ExtractOp::verify() {
1351   // Note: This check must come before getMixedPosition() to prevent a crash.
1352   auto dynamicMarkersCount =
1353       llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1354   if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1355     return emitOpError(
1356         "mismatch between dynamic and static positions (kDynamic marker but no "
1357         "corresponding dynamic position) -- this can only happen due to an "
1358         "incorrect fold/rewrite");
1359   auto position = getMixedPosition();
1360   if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1361     return emitOpError(
1362         "expected position attribute of rank no greater than vector rank");
1363   for (auto [idx, pos] : llvm::enumerate(position)) {
1364     if (auto attr = dyn_cast<Attribute>(pos)) {
1365       int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1366       if (!isValidPositiveIndexOrPoison(
1367               constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1368         return emitOpError("expected position attribute #")
1369                << (idx + 1)
1370                << " to be a non-negative integer smaller than the "
1371                   "corresponding vector dimension or poison (-1)";
1372       }
1373     }
1374   }
1375   return success();
1376 }
1377 
1378 template <typename IntType>
1379 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1380   return llvm::to_vector<4>(llvm::map_range(
1381       arrayAttr.getAsRange<IntegerAttr>(),
1382       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1383 }
1384 
1385 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1386 /// positions.
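     /// Illustrative example (hypothetical values and shapes):
     ///   %1 = vector.extract %0[1] : vector<3x4xf32> from vector<2x3x4xf32>
     ///   %2 = vector.extract %1[2] : vector<4xf32> from vector<3x4xf32>
     /// folds in place to:
     ///   %2 = vector.extract %0[1, 2] : vector<4xf32> from vector<2x3x4xf32>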
1387 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1388   if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1389     return failure();
1390 
1391   // TODO: Canonicalization for dynamic position not implemented yet.
1392   if (extractOp.hasDynamicPosition())
1393     return failure();
1394 
1395   SmallVector<int64_t> globalPosition;
1396   ExtractOp currentOp = extractOp;
1397   ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1398   globalPosition.append(extrPos.rbegin(), extrPos.rend());
1399   while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1400     currentOp = nextOp;
1401     // TODO: Canonicalization for dynamic position not implemented yet.
1402     if (currentOp.hasDynamicPosition())
1403       return failure();
1404     ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1405     globalPosition.append(extrPos.rbegin(), extrPos.rend());
1406   }
1407   extractOp.setOperand(0, currentOp.getVector());
1408   // OpBuilder is only used as a helper to build an I64ArrayAttr.
1409   OpBuilder b(extractOp.getContext());
1410   std::reverse(globalPosition.begin(), globalPosition.end());
1411   extractOp.setStaticPosition(globalPosition);
1412   return success();
1413 }
1414 
1415 namespace {
1416 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1417 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1418 /// Compose TransposeOp permutations as we walk back.
1419 /// This helper class keeps an updated extraction position `extractPosition`
1420 /// with extra trailing sentinels.
1421 /// The sentinels encode the internal transposition status of the result vector.
1422 /// As we iterate, extractPosition is permuted and updated.
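     /// Illustrative example (hypothetical values and shapes):
     ///   %i = vector.insert %s, %v[1, 0] : vector<4xf32> into vector<3x2x4xf32>
     ///   %t = vector.transpose %i, [1, 0, 2]
     ///          : vector<3x2x4xf32> to vector<2x3x4xf32>
     ///   %e = vector.extract %t[0, 1] : vector<4xf32> from vector<2x3x4xf32>
     /// The transpose is composed into the extraction position ([0, 1] becomes
     /// [1, 0]), which then matches the insert position, so %e folds to %s.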
1423 class ExtractFromInsertTransposeChainState {
1424 public:
1425   ExtractFromInsertTransposeChainState(ExtractOp e);
1426 
1427   /// Iterate over producing insert and transpose ops until we find a fold.
1428   Value fold();
1429 
1430 private:
1431   /// Return true if the vector at position `a` is contained within the vector
1432   /// at position `b`. Under insert/extract semantics, this is the same as `a`
1433   /// is a prefix of `b`.
1434   template <typename ContainerA, typename ContainerB>
1435   bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1436     return a.size() <= b.size() &&
1437            std::equal(a.begin(), a.begin() + a.size(), b.begin());
1438   }
1439 
1440   /// Return true if the vector at position `a` intersects the vector at
1441   /// position `b`. Under insert/extract semantics, this is the same as equality
1442   /// of all entries of `a` that are >=0 with the corresponding entries of b.
1443   /// Comparison is on the common prefix (i.e. zip).
1444   template <typename ContainerA, typename ContainerB>
1445   bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1446     for (auto [elemA, elemB] : llvm::zip(a, b)) {
1447       if (elemA < 0 || elemB < 0)
1448         continue;
1449       if (elemA != elemB)
1450         return false;
1451     }
1452     return true;
1453   }
1454 
1455   /// Folding is only possible in the absence of an internal permutation in the
1456   /// result vector.
1457   bool canFold() {
1458     return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1459   }
1460 
1461   // Helper to get the next defining op of interest.
1462   void updateStateForNextIteration(Value v) {
1463     nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1464     nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1465   }
1466 
1467   // Case 1. If we hit a transpose, just compose the map and iterate.
1468   // Invariant: insert + transpose do not change rank, we can always compose.
1469   LogicalResult handleTransposeOp();
1470 
1471   // Case 2: the insert position matches extractPosition exactly, early return.
1472   LogicalResult handleInsertOpWithMatchingPos(Value &res);
1473 
1474   /// Case 3: if the insert position is a prefix of extractPosition, extract a
1475   /// portion of the source of the insert.
1476   /// Example:
1477   /// ```
1478   /// %ins = vector.insert %source, %dest[1]: vector<3x4x5> into vector<2x3x4x5>
1479   /// // extractPosition == [1, 2, 3]
1480   /// %ext = vector.extract %ins[1, 2, 3]: vector<5> from vector<2x3x4x5>
1481   /// // can fold to:
1482   /// %ext = vector.extract %source[2, 3]: vector<5> from vector<3x4x5>
1483   /// ```
1484   /// To traverse through %source, we need to set the leading dims to 0 and
1485   /// drop the extra leading dims.
1486   /// This method updates the internal state.
1487   LogicalResult handleInsertOpWithPrefixPos(Value &res);
1488 
1489   /// Try to fold in place to extract(source, extractPosition) and return the
1490   /// folded result. Return null if folding is not possible (e.g. due to an
1491   /// internal transposition in the result).
1492   Value tryToFoldExtractOpInPlace(Value source);
1493 
1494   ExtractOp extractOp;
1495   int64_t vectorRank;
1496   int64_t extractedRank;
1497 
1498   InsertOp nextInsertOp;
1499   TransposeOp nextTransposeOp;
1500 
1501   /// Sentinel values that encode the internal permutation status of the result.
1502   /// They are set to (-1, ... , -k) at the beginning and appended to
1503   /// `extractPosition`.
1504   /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1505   /// ensure that there is no internal transposition.
1506   /// Internal transposition cannot be accounted for with a folding pattern.
1507   // TODO: We could relax the internal transposition with an extra transposition
1508   // operation in a future canonicalizer.
1509   SmallVector<int64_t> sentinels;
1510   SmallVector<int64_t> extractPosition;
1511 };
1512 } // namespace
1513 
1514 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1515     ExtractOp e)
1516     : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1517       extractedRank(extractOp.getNumIndices()) {
1518   assert(vectorRank >= extractedRank && "Extracted position overflow");
1519   sentinels.reserve(vectorRank - extractedRank);
1520   for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1521     sentinels.push_back(-(i + 1));
1522   extractPosition.assign(extractOp.getStaticPosition().begin(),
1523                          extractOp.getStaticPosition().end());
1524   llvm::append_range(extractPosition, sentinels);
1525 }
1526 
1527 // Case 1. If we hit a transpose, just compose the map and iterate.
1528 // Invariant: insert + transpose do not change rank, we can always compose.
1529 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1530   // TODO: Canonicalization for dynamic position not implemented yet.
1531   if (extractOp.hasDynamicPosition())
1532     return failure();
1533 
1534   if (!nextTransposeOp)
1535     return failure();
1536   AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1537       nextTransposeOp.getPermutation(), extractOp.getContext()));
1538   extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1539   return success();
1540 }
1541 
1542 // Case 2: the insert position matches extractPosition exactly, early return.
1543 LogicalResult
1544 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1545     Value &res) {
1546   // TODO: Canonicalization for dynamic position not implemented yet.
1547   if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1548     return failure();
1549 
1550   ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1551   if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1552     return failure();
1553   // Case 2.a. early-exit fold.
1554   res = nextInsertOp.getSource();
1555   // Case 2.b. if internal transposition is present, canFold will be false.
1556   return success(canFold());
1557 }
1558 
1559 /// Case 3: if inserted position is a prefix of extractPosition,
1560 /// extract a portion of the source of the insertion.
1561 /// This method updates the internal state.
1562 LogicalResult
1563 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1564   // TODO: Canonicalization for dynamic position not implemented yet.
1565   if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1566     return failure();
1567 
1568   ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1569   if (!isContainedWithin(insertedPos, extractPosition))
1570     return failure();
1571   // Set leading dims to zero.
1572   std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1573   // Drop extra leading dims.
1574   extractPosition.erase(extractPosition.begin(),
1575                         extractPosition.begin() + insertedPos.size());
1576   extractedRank = extractPosition.size() - sentinels.size();
1577   // Case 3.a. early-exit fold (break and delegate to post-while path).
1578   res = nextInsertOp.getSource();
1579   // Case 3.b. if internal transposition is present, canFold will be false.
1580   return success();
1581 }
1582 
1583 /// Try to fold in place to extract(source, extractPosition) and return the
1584 /// folded result. Return null if folding is not possible (e.g. due to an
1585 /// internal transposition in the result).
1586 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1587     Value source) {
1588   // TODO: Canonicalization for dynamic position not implemented yet.
1589   if (extractOp.hasDynamicPosition())
1590     return Value();
1591 
1592   // If we can't fold (either internal transposition, or nothing to fold), bail.
1593   bool nothingToFold = (source == extractOp.getVector());
1594   if (nothingToFold || !canFold())
1595     return Value();
1596 
1597   // Otherwise, fold by updating the op inplace and return its result.
1598   OpBuilder b(extractOp.getContext());
1599   extractOp.setStaticPosition(
1600       ArrayRef(extractPosition).take_front(extractedRank));
1601   extractOp.getVectorMutable().assign(source);
1602   return extractOp.getResult();
1603 }
1604 
1605 /// Iterate over producing insert and transpose ops until we find a fold.
1606 Value ExtractFromInsertTransposeChainState::fold() {
1607   // TODO: Canonicalization for dynamic position not implemented yet.
1608   if (extractOp.hasDynamicPosition())
1609     return Value();
1610 
1611   Value valueToExtractFrom = extractOp.getVector();
1612   updateStateForNextIteration(valueToExtractFrom);
1613   while (nextInsertOp || nextTransposeOp) {
1614     // Case 1. If we hit a transpose, just compose the map and iterate.
1615     // Invariant: insert + transpose do not change rank, we can always compose.
1616     if (succeeded(handleTransposeOp())) {
1617       valueToExtractFrom = nextTransposeOp.getVector();
1618       updateStateForNextIteration(valueToExtractFrom);
1619       continue;
1620     }
1621 
1622     Value result;
1623     // Case 2: the position match exactly.
1624     if (succeeded(handleInsertOpWithMatchingPos(result)))
1625       return result;
1626 
1627     // Case 3: if the inserted position is a prefix of extractPosition, we can
1628     // just extract a portion of the source of the insert.
1629     if (succeeded(handleInsertOpWithPrefixPos(result)))
1630       return tryToFoldExtractOpInPlace(result);
1631 
1632     // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1633     // values. This is a more difficult case and we bail.
1634     ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1635     if (isContainedWithin(extractPosition, insertedPos) ||
1636         intersectsWhereNonNegative(extractPosition, insertedPos))
1637       return Value();
1638 
1639     // Case 5: No intersection, we forward the extract to insertOp.dest().
1640     valueToExtractFrom = nextInsertOp.getDest();
1641     updateStateForNextIteration(valueToExtractFrom);
1642   }
1643   // If after all this we can fold, go for it.
1644   return tryToFoldExtractOpInPlace(valueToExtractFrom);
1645 }
1646 
1647 /// Returns true if the operation has a 0-D vector type operand or result.
1648 static bool hasZeroDimVectors(Operation *op) {
1649   auto hasZeroDimVectorType = [](Type type) -> bool {
1650     auto vecType = dyn_cast<VectorType>(type);
1651     return vecType && vecType.getRank() == 0;
1652   };
1653 
1654   return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1655          llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1656 }
1657 
1658 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
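     /// Illustrative example (hypothetical values and shapes):
     ///   %b = vector.broadcast %s : f32 to vector<4xf32>
     ///   %e = vector.extract %b[1] : f32 from vector<4xf32>
     /// folds to %s.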
1659 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1660   // TODO: Canonicalization for dynamic position not implemented yet.
1661   if (extractOp.hasDynamicPosition())
1662     return Value();
1663 
1664   Operation *defOp = extractOp.getVector().getDefiningOp();
1665   if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1666     return Value();
1667 
1668   Value source = defOp->getOperand(0);
1669   if (extractOp.getType() == source.getType())
1670     return source;
1671   auto getRank = [](Type type) {
1672     return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1673                                        : 0;
1674   };
1675 
1676   // If splat or broadcast from a scalar, just return the source scalar.
1677   unsigned broadcastSrcRank = getRank(source.getType());
1678   if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1679     return source;
1680 
1681   unsigned extractResultRank = getRank(extractOp.getType());
1682   if (extractResultRank >= broadcastSrcRank)
1683     return Value();
1684   // Check that the dimensions of the result haven't been broadcasted.
1685   auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1686   auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1687   if (extractVecType && broadcastVecType &&
1688       extractVecType.getShape() !=
1689           broadcastVecType.getShape().take_back(extractResultRank))
1690     return Value();
1691 
1692   auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1693   int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1694 
1695   // Detect all the positions that come from "dim-1" broadcasting.
1696   // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1697   // extract position to `0` when extracting from the source operand.
1698   llvm::SetVector<int64_t> broadcastedUnitDims =
1699       broadcastOp.computeBroadcastedUnitDims();
1700   SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1701   int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1702   for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1703     if (broadcastedUnitDims.contains(i))
1704       extractPos[i] = 0;
1705   // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1706   // matching extract position when extracting from the source operand.
1707   int64_t rankDiff = broadcastSrcRank - extractResultRank;
1708   extractPos.erase(extractPos.begin(),
1709                    std::next(extractPos.begin(), extractPos.size() - rankDiff));
1710   // OpBuilder is only used as a helper to build an I64ArrayAttr.
1711   OpBuilder b(extractOp.getContext());
1712   extractOp.setOperand(0, source);
1713   extractOp.setStaticPosition(extractPos);
1714   return extractOp.getResult();
1715 }
1716 
1717 /// Fold extractOp coming from ShuffleOp.
1718 ///
1719 /// Example:
1720 ///
1721 ///   %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1722 ///     : vector<8xf32>, vector<8xf32>
1723 ///   %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1724 /// ->
1725 ///   %extract = vector.extract %b[7] : f32 from vector<8xf32>
1726 ///
1727 static Value foldExtractFromShuffle(ExtractOp extractOp) {
1728   // Dynamic positions are not folded as the resulting code would be more
1729   // complex than the input code.
1730   if (extractOp.hasDynamicPosition())
1731     return Value();
1732 
1733   auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1734   if (!shuffleOp)
1735     return Value();
1736 
1737   // TODO: 0-D or multi-dimensional vectors not supported yet.
1738   if (shuffleOp.getResultVectorType().getRank() != 1)
1739     return Value();
1740 
1741   int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1742   auto shuffleMask = shuffleOp.getMask();
1743   int64_t extractIdx = extractOp.getStaticPosition()[0];
1744   int64_t shuffleIdx = shuffleMask[extractIdx];
1745 
1746   // Find the shuffled vector to extract from based on the shuffle index.
1747   if (shuffleIdx < inputVecSize) {
1748     extractOp.setOperand(0, shuffleOp.getV1());
1749     extractOp.setStaticPosition({shuffleIdx});
1750   } else {
1751     extractOp.setOperand(0, shuffleOp.getV2());
1752     extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1753   }
1754 
1755   return extractOp.getResult();
1756 }
1757 
1758 /// Fold extractOp with source coming from ShapeCast op.
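     /// Illustrative example (hypothetical values and shapes):
     ///   %sc = vector.shape_cast %v : vector<2x4xf32> to vector<8xf32>
     ///   %e  = vector.extract %sc[5] : f32 from vector<8xf32>
     /// folds in place to:
     ///   %e  = vector.extract %v[1, 1] : f32 from vector<2x4xf32>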
1759 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1760   // TODO: Canonicalization for dynamic position not implemented yet.
1761   if (extractOp.hasDynamicPosition())
1762     return Value();
1763 
1764   auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1765   if (!shapeCastOp)
1766     return Value();
1767 
1768   // Get the nth dimension size starting from lowest dimension.
1769   auto getDimReverse = [](VectorType type, int64_t n) {
1770     return type.getShape().take_back(n + 1).front();
1771   };
1772   int64_t destinationRank =
1773       llvm::isa<VectorType>(extractOp.getType())
1774           ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1775           : 0;
1776   if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1777     return Value();
1778   if (destinationRank > 0) {
1779     auto destinationType =
1780         llvm::cast<VectorType>(extractOp.getResult().getType());
1781     for (int64_t i = 0; i < destinationRank; i++) {
1782       // The lowest dimension of the destination must match the lowest
1783       // dimension of the shapecast op source.
1784       // TODO: This case could be supported in a canonicalization pattern.
1785       if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1786           getDimReverse(destinationType, i))
1787         return Value();
1788     }
1789   }
1790   // Extract the strides associated with the extract op vector source. Then use
1791   // this to calculate a linearized position for the extract.
1792   SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1793   std::reverse(extractedPos.begin(), extractedPos.end());
1794   SmallVector<int64_t, 4> strides;
1795   int64_t stride = 1;
1796   for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1797     strides.push_back(stride);
1798     stride *=
1799         getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1800   }
1801 
1802   int64_t position = linearize(extractedPos, strides);
1803   // Then extract the strides associated to the shapeCast op vector source and
1804   // delinearize the position using those strides.
1805   SmallVector<int64_t, 4> newStrides;
1806   int64_t numDimension =
1807       shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1808   stride = 1;
1809   for (int64_t i = 0; i < numDimension; i++) {
1810     newStrides.push_back(stride);
1811     stride *=
1812         getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1813   }
1814   std::reverse(newStrides.begin(), newStrides.end());
1815   SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1816   // OpBuilder is only used as a helper to build an I64ArrayAttr.
1817   OpBuilder b(extractOp.getContext());
1818   extractOp.setStaticPosition(newPosition);
1819   extractOp.setOperand(0, shapeCastOp.getSource());
1820   return extractOp.getResult();
1821 }
1822 
1823 /// Fold an ExtractOp from ExtractStridedSliceOp.
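     /// Illustrative example (hypothetical values and shapes):
     ///   %s = vector.extract_strided_slice %v
     ///          {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
     ///          : vector<4x8xf32> to vector<2x4xf32>
     ///   %e = vector.extract %s[0, 1] : f32 from vector<2x4xf32>
     /// folds in place to:
     ///   %e = vector.extract %v[1, 3] : f32 from vector<4x8xf32>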
1824 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1825   // TODO: Canonicalization for dynamic position not implemented yet.
1826   if (extractOp.hasDynamicPosition())
1827     return Value();
1828 
1829   auto extractStridedSliceOp =
1830       extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1831   if (!extractStridedSliceOp)
1832     return Value();
1833 
1834   // 0-D vectors not supported.
1835   assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1836   if (hasZeroDimVectors(extractStridedSliceOp))
1837     return Value();
1838 
1839   // Return if 'extractStridedSliceOp' has non-unit strides.
1840   if (extractStridedSliceOp.hasNonUnitStrides())
1841     return Value();
1842 
1843   // Trim offsets for dimensions fully extracted.
1844   auto sliceOffsets =
1845       extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1846   while (!sliceOffsets.empty()) {
1847     size_t lastOffset = sliceOffsets.size() - 1;
1848     if (sliceOffsets.back() != 0 ||
1849         extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1850             extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1851       break;
1852     sliceOffsets.pop_back();
1853   }
1854   unsigned destinationRank = 0;
1855   if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1856     destinationRank = vecType.getRank();
1857   // The dimensions of the result need to be untouched by the
1858   // extractStridedSlice op.
1859   if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1860                             sliceOffsets.size())
1861     return Value();
1862 
1863   SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1864   assert(extractedPos.size() >= sliceOffsets.size());
1865   for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1866     extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1867   extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1868 
1869   // OpBuilder is only used as a helper to build an I64ArrayAttr.
1870   OpBuilder b(extractOp.getContext());
1871   extractOp.setStaticPosition(extractedPos);
1872   return extractOp.getResult();
1873 }
1874 
1875 /// Fold extract_op fed from a chain of insertStridedSlice ops.
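     /// Illustrative example (hypothetical values and shapes):
     ///   %i = vector.insert_strided_slice %src, %dst
     ///          {offsets = [2, 0], strides = [1, 1]}
     ///          : vector<2x4xf32> into vector<8x4xf32>
     ///   %e = vector.extract %i[3, 1] : f32 from vector<8x4xf32>
     /// folds in place to:
     ///   %e = vector.extract %src[1, 1] : f32 from vector<2x4xf32>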
1876 static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1877   // TODO: Canonicalization for dynamic position not implemented yet.
1878   if (extractOp.hasDynamicPosition())
1879     return Value();
1880 
1881   int64_t destinationRank =
1882       llvm::isa<VectorType>(extractOp.getType())
1883           ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1884           : 0;
1885   auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1886   if (!insertOp)
1887     return Value();
1888 
1889   // 0-D vectors not supported.
1890   assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1891   if (hasZeroDimVectors(insertOp))
1892     return Value();
1893 
1894   while (insertOp) {
1895     int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1896                              insertOp.getSourceVectorType().getRank();
1897     if (destinationRank > insertOp.getSourceVectorType().getRank())
1898       return Value();
1899     auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1900     ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1901 
1902     if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1903           return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1904         }))
1905       return Value();
1906     bool disjoint = false;
1907     SmallVector<int64_t, 4> offsetDiffs;
1908     for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1909       int64_t start = insertOffsets[dim];
1910       int64_t size =
1911           (dim < insertRankDiff)
1912               ? 1
1913               : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1914       int64_t end = start + size;
1915       int64_t offset = extractOffsets[dim];
1916       // Check if the start of the extract offset is in the interval inserted.
1917       if (start <= offset && offset < end) {
1918         if (dim >= insertRankDiff)
1919           offsetDiffs.push_back(offset - start);
1920         continue;
1921       }
1922       disjoint = true;
1923       break;
1924     }
1925     // The extracted element chunk overlaps with the vector inserted.
1926     if (!disjoint) {
1927       // If any of the inner dimensions are only partially inserted we have a
1928       // partial overlap.
1929       int64_t srcRankDiff =
1930           insertOp.getSourceVectorType().getRank() - destinationRank;
1931       for (int64_t i = 0; i < destinationRank; i++) {
1932         if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1933             insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1934                                                     insertRankDiff))
1935           return Value();
1936       }
1937       extractOp.getVectorMutable().assign(insertOp.getSource());
1938       // OpBuilder is only used as a helper to build an I64ArrayAttr.
1939       OpBuilder b(extractOp.getContext());
1940       extractOp.setStaticPosition(offsetDiffs);
1941       return extractOp.getResult();
1942     }
1943     // If the chunk extracted is disjoint from the chunk inserted, keep
1944     // looking in the insert chain.
1945     insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1946   }
1947   return Value();
1948 }
1949 
1950 /// Try to fold the extraction of a scalar from a vector defined by
1951 /// vector.from_elements. E.g.:
1952 ///
1953 /// %0 = vector.from_elements %a, %b : vector<2xf32>
1954 /// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1955 /// ==> fold to %a
1956 static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1957   // Dynamic extractions cannot be folded.
1958   if (extractOp.hasDynamicPosition())
1959     return {};
1960 
1961   // Look for extract(from_elements).
1962   auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1963   if (!fromElementsOp)
1964     return {};
1965 
1966   // Scalable vectors are not supported.
1967   auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1968   if (vecType.isScalable())
1969     return {};
1970 
1971   // Only extractions of scalars are supported.
1972   int64_t rank = vecType.getRank();
1973   ArrayRef<int64_t> indices = extractOp.getStaticPosition();
1974   if (extractOp.getType() != vecType.getElementType())
1975     return {};
1976   assert(static_cast<int64_t>(indices.size()) == rank &&
1977          "unexpected number of indices");
1978 
1979   // Compute flattened/linearized index and fold to operand.
1980   int flatIndex = 0;
1981   int stride = 1;
1982   for (int i = rank - 1; i >= 0; --i) {
1983     flatIndex += indices[i] * stride;
1984     stride *= vecType.getDimSize(i);
1985   }
1986   return fromElementsOp.getElements()[flatIndex];
1987 }
1988 
1989 /// Fold an insert or extract operation into a poison value when a poison index
1990 /// is found at any dimension of the static position.
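     /// Illustrative example (hypothetical values and shapes, with -1 denoting
     /// the poison index):
     ///   %e = vector.extract %v[-1, 0] : f32 from vector<4x4xf32>
     /// folds to a ub.poison value.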
1991 static ub::PoisonAttr
1992 foldPoisonIndexInsertExtractOp(MLIRContext *context,
1993                                ArrayRef<int64_t> staticPos, int64_t poisonVal) {
1994   if (!llvm::is_contained(staticPos, poisonVal))
1995     return ub::PoisonAttr();
1996 
1997   return ub::PoisonAttr::get(context);
1998 }
1999 
2000 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2001   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2002   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2003   // mismatch).
2004   if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
2005     return getVector();
2006   if (auto res = foldPoisonIndexInsertExtractOp(
2007           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2008     return res;
2009   if (succeeded(foldExtractOpFromExtractChain(*this)))
2010     return getResult();
2011   if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2012     return res;
2013   if (auto res = foldExtractFromBroadcast(*this))
2014     return res;
2015   if (auto res = foldExtractFromShuffle(*this))
2016     return res;
2017   if (auto res = foldExtractFromShapeCast(*this))
2018     return res;
2019   if (auto val = foldExtractFromExtractStrided(*this))
2020     return val;
2021   if (auto val = foldExtractStridedOpFromInsertChain(*this))
2022     return val;
2023   if (auto val = foldScalarExtractFromFromElements(*this))
2024     return val;
2025   return OpFoldResult();
2026 }
2027 
2028 namespace {
2029 
2030 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
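     // Illustrative example (hypothetical values and shapes):
     //   %b = vector.broadcast %v : vector<4xf32> to vector<2x3x4xf32>
     //   %e = vector.extract %b[1] : vector<3x4xf32> from vector<2x3x4xf32>
     // rewrites to:
     //   %e = vector.broadcast %v : vector<4xf32> to vector<3x4xf32>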
2031 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2032 public:
2033   using OpRewritePattern::OpRewritePattern;
2034 
2035   LogicalResult matchAndRewrite(ExtractOp extractOp,
2036                                 PatternRewriter &rewriter) const override {
2037     Operation *defOp = extractOp.getVector().getDefiningOp();
2038     if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2039       return failure();
2040 
2041     Value source = defOp->getOperand(0);
2042     if (extractOp.getType() == source.getType())
2043       return failure();
2044     auto getRank = [](Type type) {
2045       return llvm::isa<VectorType>(type)
2046                  ? llvm::cast<VectorType>(type).getRank()
2047                  : 0;
2048     };
2049     unsigned broadcastSrcRank = getRank(source.getType());
2050     unsigned extractResultRank = getRank(extractOp.getType());
2051     // We only consider the case where the rank of the source is less than or
2052     // equal to the rank of the extract dst. The other cases are handled in the
2053     // folding patterns.
2054     if (extractResultRank < broadcastSrcRank)
2055       return failure();
2056 
2057     // Special case if broadcast src is a 0D vector.
2058     if (extractResultRank == 0) {
2059       assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
2060       rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
2061       return success();
2062     }
2063     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2064         extractOp, extractOp.getType(), source);
2065     return success();
2066   }
2067 };
2068 
2069 // Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
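     // Illustrative example (hypothetical values and shapes):
     //   %c = arith.constant dense<1.0> : vector<4x4xf32>
     //   %e = vector.extract %c[2] : vector<4xf32> from vector<4x4xf32>
     // rewrites to:
     //   %e = arith.constant dense<1.0> : vector<4xf32>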
2070 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
2071 public:
2072   using OpRewritePattern::OpRewritePattern;
2073 
2074   LogicalResult matchAndRewrite(ExtractOp extractOp,
2075                                 PatternRewriter &rewriter) const override {
2076     // Return if 'ExtractOp' operand is not defined by a splat vector
2077     // ConstantOp.
2078     Value sourceVector = extractOp.getVector();
2079     Attribute vectorCst;
2080     if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2081       return failure();
2082     auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2083     if (!splat)
2084       return failure();
2085     TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
2086     if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
2087       newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2088     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2089     return success();
2090   }
2091 };
2092 
2093 // Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
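     // Illustrative example (hypothetical values and shapes):
     //   %c = arith.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : vector<2x2xf32>
     //   %e = vector.extract %c[1] : vector<2xf32> from vector<2x2xf32>
     // rewrites to:
     //   %e = arith.constant dense<[2.0, 3.0]> : vector<2xf32>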
2094 class ExtractOpNonSplatConstantFolder final
2095     : public OpRewritePattern<ExtractOp> {
2096 public:
2097   using OpRewritePattern::OpRewritePattern;
2098 
2099   LogicalResult matchAndRewrite(ExtractOp extractOp,
2100                                 PatternRewriter &rewriter) const override {
2101     // TODO: Canonicalization for dynamic position not implemented yet.
2102     if (extractOp.hasDynamicPosition())
2103       return failure();
2104 
2105     // Return if 'ExtractOp' operand is not defined by a compatible vector
2106     // ConstantOp.
2107     Value sourceVector = extractOp.getVector();
2108     Attribute vectorCst;
2109     if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
2110       return failure();
2111 
2112     auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
2113     if (vecTy.isScalable())
2114       return failure();
2115 
2116     // The splat case is handled by `ExtractOpSplatConstantFolder`.
2117     auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2118     if (!dense || dense.isSplat())
2119       return failure();
2120 
2121     // Calculate the linearized position of the contiguous chunk of elements to
2122     // extract.
2123     llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2124     copy(extractOp.getStaticPosition(), completePositions.begin());
2125     int64_t elemBeginPosition =
2126         linearize(completePositions, computeStrides(vecTy.getShape()));
2127     auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
2128 
2129     TypedAttr newAttr;
2130     if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
2131       SmallVector<Attribute> elementValues(
2132           denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2133       newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2134     } else {
2135       newAttr = *denseValuesBegin;
2136     }
2137 
2138     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
2139     return success();
2140   }
2141 };
2142 
2143 // Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
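     // Illustrative example (hypothetical values and shapes):
     //   %c2 = arith.constant 2 : index
     //   %c3 = arith.constant 3 : index
     //   %m = vector.create_mask %c2, %c3 : vector<4x8xi1>
     //   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
     // rewrites to:
     //   %e = vector.create_mask %c3 : vector<8xi1>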
2144 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2145 public:
2146   using OpRewritePattern::OpRewritePattern;
2147 
2148   LogicalResult matchAndRewrite(ExtractOp extractOp,
2149                                 PatternRewriter &rewriter) const override {
2150     auto createMaskOp =
2151         extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
2152     if (!createMaskOp)
2153       return failure();
2154 
2155     VectorType extractedMaskType =
2156         llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2157 
2158     if (!extractedMaskType)
2159       return failure();
2160 
2161     auto maskOperands = createMaskOp.getOperands();
2162     ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2163     VectorType maskType = createMaskOp.getVectorType();
2164 
2165     bool containsUnknownDims = false;
2166     bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2167 
2168     for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2169          dimIdx++) {
2170       int64_t pos = extractOpPos[dimIdx];
2171       Value operand = maskOperands[dimIdx];
2172       auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2173       if (!constantOp) {
2174         // Bounds of this dim unknown.
2175         containsUnknownDims = true;
2176         continue;
2177       }
2178 
2179       int64_t createMaskBound =
2180           llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2181 
2182       if (pos != ShapedType::kDynamic) {
2183         // If any position is outside the range from the `create_mask`, then the
2184         // extracted mask will be all-false.
2185         allFalse |= pos >= createMaskBound;
2186       } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2187         // This dim is not all-true and since this is a dynamic index we don't
2188         // know if the extraction is within the true or false region.
2189         // Note: Zero dims have already been handled via getMaskFormat().
2190         containsUnknownDims = true;
2191       }
2192     }
2193 
2194     if (allFalse) {
2195       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2196           extractOp, DenseElementsAttr::get(extractedMaskType, false));
2197     } else if (!containsUnknownDims) {
2198       rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2199           extractOp, extractedMaskType,
2200           maskOperands.drop_front(extractOpPos.size()));
2201     } else {
2202       return failure();
2203     }
2204     return success();
2205   }
2206 };
2207 
2208 // Folds extract(shape_cast(..)) into shape_cast when the total element count
2209 // does not change.
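     // Illustrative example (hypothetical values and shapes):
     //   %sc = vector.shape_cast %v : vector<2x4xf32> to vector<1x8xf32>
     //   %e  = vector.extract %sc[0] : vector<8xf32> from vector<1x8xf32>
     // rewrites to:
     //   %e  = vector.shape_cast %v : vector<2x4xf32> to vector<8xf32>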
2210 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2211                                                   PatternRewriter &rewriter) {
2212   auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2213   if (!castOp)
2214     return failure();
2215 
2216   VectorType sourceType = castOp.getSourceVectorType();
2217   auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2218   if (!targetType)
2219     return failure();
2220 
2221   if (sourceType.getNumElements() != targetType.getNumElements())
2222     return failure();
2223 
2224   rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2225                                                    castOp.getSource());
2226   return success();
2227 }
2228 
2229 /// Try to canonicalize the extraction of a subvector from a vector defined by
2230 /// vector.from_elements. E.g.:
2231 ///
2232 /// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2233 /// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2234 /// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2235 LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2236                                           PatternRewriter &rewriter) {
2237   // Dynamic positions are not supported.
2238   if (extractOp.hasDynamicPosition())
2239     return failure();
2240 
2241   // Scalar extracts are handled by the folder.
2242   auto resultType = dyn_cast<VectorType>(extractOp.getType());
2243   if (!resultType)
2244     return failure();
2245 
2246   // Look for extracts from a from_elements op.
2247   auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2248   if (!fromElementsOp)
2249     return failure();
2250   VectorType inputType = fromElementsOp.getType();
2251 
2252   // Scalable vectors are not supported.
2253   if (resultType.isScalable() || inputType.isScalable())
2254     return failure();
2255 
2256   // Compute the position of first extracted element and flatten/linearize the
2257   // position.
2258   SmallVector<int64_t> firstElementPos =
2259       llvm::to_vector(extractOp.getStaticPosition());
2260   firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2261   int flatIndex = 0;
2262   int stride = 1;
2263   for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2264     flatIndex += firstElementPos[i] * stride;
2265     stride *= inputType.getDimSize(i);
2266   }
2267 
2268   // Replace the op with a smaller from_elements op.
2269   rewriter.replaceOpWithNewOp<FromElementsOp>(
2270       extractOp, resultType,
2271       fromElementsOp.getElements().slice(flatIndex,
2272                                          resultType.getNumElements()));
2273   return success();
2274 }
2275 
2276 /// Fold an insert or extract operation into a poison value when a poison index
2277 /// is found at any dimension of the static position.
2278 template <typename OpTy>
2279 LogicalResult
2280 canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) {
2281   if (auto poisonAttr = foldPoisonIndexInsertExtractOp(
2282           op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) {
2283     rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr);
2284     return success();
2285   }
2286 
2287   return failure();
2288 }
2289 
2290 } // namespace
2291 
2292 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2293                                             MLIRContext *context) {
2294   results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2295               ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2296   results.add(foldExtractFromShapeCastToShapeCast);
2297   results.add(foldExtractFromFromElements);
2298   results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
2299 }
2300 
2301 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2302                                        SmallVectorImpl<int64_t> &results) {
2303   for (auto attr : arrayAttr)
2304     results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2305 }
2306 
2307 //===----------------------------------------------------------------------===//
2308 // FMAOp
2309 //===----------------------------------------------------------------------===//
2310 
2311 std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2312   return llvm::to_vector<4>(getVectorType().getShape());
2313 }
2314 
2315 //===----------------------------------------------------------------------===//
2316 // FromElementsOp
2317 //===----------------------------------------------------------------------===//
2318 
2319 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
2320 /// same SSA value. E.g.:
2321 ///
2322 /// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2323 /// ==> rewrite to vector.splat %a : vector<3xf32>
2324 static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2325                                                 PatternRewriter &rewriter) {
2326   if (!llvm::all_equal(fromElementsOp.getElements()))
2327     return failure();
2328   rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2329                                        fromElementsOp.getElements().front());
2330   return success();
2331 }
2332 
2333 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2334                                                  MLIRContext *context) {
2335   results.add(rewriteFromElementsAsSplat);
2336 }
2337 
2338 //===----------------------------------------------------------------------===//
2339 // BroadcastOp
2340 //===----------------------------------------------------------------------===//
2341 
2342 void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2343                                     SetIntRangeFn setResultRanges) {
2344   setResultRanges(getResult(), argRanges.front());
2345 }
2346 
2347 /// Return the dimensions of the result vector that were formerly ones in the
2348 /// source vector and thus correspond to "dim-1" broadcasting.
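     /// Illustrative example (hypothetical shapes):
     ///   srcShape = [1, 4], dstShape = [2, 3, 4]  ->  returns {1}
     /// (result dim 1 is broadcast from a size-1 source dim; result dim 0 is a
     /// new leading dim and therefore not a "dim-1" broadcast).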
2349 static llvm::SetVector<int64_t>
2350 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2351                            ArrayRef<int64_t> dstShape) {
2352   int64_t rankDiff = dstShape.size() - srcShape.size();
2353   int64_t dstDim = rankDiff;
2354   llvm::SetVector<int64_t> res;
2355   for (auto [s1, s2] :
2356        llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2357     if (s1 != s2) {
2358       assert(s1 == 1 && "expected dim-1 broadcasting");
2359       res.insert(dstDim);
2360     }
2361     ++dstDim;
2362   }
2363   return res;
2364 }
2365 
2366 llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2367   // A broadcast from a scalar involves no unit-dim broadcasting.
2368   auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2369   if (!srcVectorType)
2370     return {};
2371   return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2372                                       getResultVectorType().getShape());
2373 }
2374 
2375 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2376 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
2377 /// This requires (and asserts) that the broadcast is free of dim-1
2378 /// broadcasting.
2379 /// Since vector.broadcast only allows expanding leading dimensions, an extra
2380 /// vector.transpose may be inserted to make the broadcast possible.
2381 /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2382 /// the helper will assert. This means:
2383 ///   1. `dstShape` must not be empty.
2384 ///   2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2385 ///   3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2386 ///      must match the `value` shape.
2387 Value BroadcastOp::createOrFoldBroadcastOp(
2388     OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2389     const llvm::SetVector<int64_t> &broadcastedDims) {
2390   assert(!dstShape.empty() && "unexpected empty dst shape");
2391 
2392   // Well-formedness check.
2393   SmallVector<int64_t> checkShape;
2394   for (int i = 0, e = dstShape.size(); i < e; ++i) {
2395     if (broadcastedDims.contains(i))
2396       continue;
2397     checkShape.push_back(dstShape[i]);
2398   }
2399   assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2400          "ill-formed broadcastedDims contains values not confined to "
2401          "destVectorShape");
2402 
2403   Location loc = value.getLoc();
2404   Type elementType = getElementTypeOrSelf(value.getType());
2405   VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2406   VectorType dstVectorType = VectorType::get(dstShape, elementType);
2407 
2408   // Step 2. If scalar -> dstShape broadcast, just do it.
2409   if (!srcVectorType) {
2410     assert(checkShape.empty() &&
2411            "ill-formed createOrFoldBroadcastOp arguments");
2412     return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2413   }
2414 
2415   assert(srcVectorType.getShape().equals(checkShape) &&
2416          "ill-formed createOrFoldBroadcastOp arguments");
2417 
2418   // Step 3. Since vector.broadcast only allows creating leading dims,
2419   //   vector -> dstShape broadcast may require a transpose.
2420   // Traverse the dims in order and construct:
2421   //   1. The leading entries of the broadcastShape that is guaranteed to be
2422   //      achievable by a simple broadcast.
2423   //   2. The induced permutation for the subsequent vector.transpose that will
2424   //      bring us from `broadcastShape` back to the desired `dstShape`.
2425   // If the induced permutation is not the identity, create a vector.transpose.
2426   SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2427   broadcastShape.reserve(dstShape.size());
2428   // Consider the example:
2429   //   srcShape     = 2x4
2430   //   dstShape     = 1x2x3x4x5
2431   //   broadcastedDims = [0, 2, 4]
2432   //
2433   // We want to build:
2434   //   broadcastShape  = 1x3x5x2x4
2435   //   permutation     = [0, 2, 4,                 1, 3]
2436   //                      ---V---           -----V-----
2437   //            leading broadcast part      src shape part
2438   //
2439   // Note that the trailing dims of broadcastShape are exactly the srcShape
2440   // by construction.
2441   // nextSrcShapeDim is used to keep track of where in the permutation the
2442   // "src shape part" occurs.
2443   int64_t nextSrcShapeDim = broadcastedDims.size();
2444   for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2445     if (broadcastedDims.contains(i)) {
2446       // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2447       // bring it to the head of the broadcastShape.
2448       // It will need to be permuted back from `broadcastShape.size() - 1` into
2449       // position `i`.
2450       broadcastShape.push_back(dstShape[i]);
2451       permutation[i] = broadcastShape.size() - 1;
2452     } else {
2453       // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2454       // shape and needs to be permuted into position `i`.
2455       // Don't touch `broadcastShape` here, the whole srcShape will be
2456       // appended after.
2457       permutation[i] = nextSrcShapeDim++;
2458     }
2459   }
2460   // 3.c. Append the srcShape.
2461   llvm::append_range(broadcastShape, srcVectorType.getShape());
2462 
2463   // Ensure there are no dim-1 broadcasts.
2464   assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2465              .empty() &&
2466          "unexpected dim-1 broadcast");
2467 
2468   VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2469   assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2470              vector::BroadcastableToResult::Success &&
2471          "must be broadcastable");
2472   Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2473   // Step 4. If we find any dimension that indeed needs to be permuted,
2474   // immediately return a new vector.transpose.
2475   for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2476     if (permutation[i] != i)
2477       return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2478   // Otherwise return res.
2479   return res;
2480 }
2481 
2482 BroadcastableToResult mlir::vector::isBroadcastableTo(
2483     Type srcType, VectorType dstVectorType,
2484     std::pair<VectorDim, VectorDim> *mismatchingDims) {
2485   // Broadcast scalar to vector of the same element type.
2486   if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
2487       getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2488     return BroadcastableToResult::Success;
2489   // From now on, only vectors broadcast.
2490   VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2491   if (!srcVectorType)
2492     return BroadcastableToResult::SourceTypeNotAVector;
2493 
2494   int64_t srcRank = srcVectorType.getRank();
2495   int64_t dstRank = dstVectorType.getRank();
2496   if (srcRank > dstRank)
2497     return BroadcastableToResult::SourceRankHigher;
2498   // Source has an exact match or singleton value for all trailing dimensions
2499   // (all leading dimensions are simply duplicated).
2500   int64_t lead = dstRank - srcRank;
2501   for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2502     // Have mismatching dims (in the sense of vector.broadcast semantics) been
2503     // encountered?
2504     bool foundMismatchingDims = false;
2505 
2506     // Check fixed-width dims.
2507     int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2508     int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2509     if (srcDim != 1 && srcDim != dstDim)
2510       foundMismatchingDims = true;
2511 
2512     // Check scalable flags.
2513     bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2514     bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2515     if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2516         // 1 -> [N] is fine, everything else should be rejected when mixing
2517         // fixed-width and scalable dims
2518         (srcDimScalableFlag != dstDimScalableFlag &&
2519          (srcDim != 1 || srcDimScalableFlag)))
2520       foundMismatchingDims = true;
2521 
2522     if (foundMismatchingDims) {
2523       if (mismatchingDims != nullptr) {
2524         mismatchingDims->first.dim = srcDim;
2525         mismatchingDims->first.isScalable = srcDimScalableFlag;
2526 
2527         mismatchingDims->second.dim = dstDim;
2528         mismatchingDims->second.isScalable = dstDimScalableFlag;
2529       }
2530       return BroadcastableToResult::DimensionMismatch;
2531     }
2532   }
2533 
2534   return BroadcastableToResult::Success;
2535 }
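// Illustrative sketch (not part of the op definition): under these rules, an
// f32 scalar or a vector<1x4xf32> broadcasts to vector<3x4xf32>, while
// vector<3xf32> -> vector<3x4xf32> is reported as a DimensionMismatch because
// the trailing dimensions (3 vs. 4) differ.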
2536 
2537 LogicalResult BroadcastOp::verify() {
2538   std::pair<VectorDim, VectorDim> mismatchingDims;
2539   BroadcastableToResult res = isBroadcastableTo(
2540       getSourceType(), getResultVectorType(), &mismatchingDims);
2541   if (res == BroadcastableToResult::Success)
2542     return success();
2543   if (res == BroadcastableToResult::SourceRankHigher)
2544     return emitOpError("source rank higher than destination rank");
2545   if (res == BroadcastableToResult::DimensionMismatch) {
2546     return emitOpError("dimension mismatch (")
2547            << (mismatchingDims.first.isScalable ? "[" : "")
2548            << mismatchingDims.first.dim
2549            << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2550            << (mismatchingDims.second.isScalable ? "[" : "")
2551            << mismatchingDims.second.dim
2552            << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2553   }
2554   if (res == BroadcastableToResult::SourceTypeNotAVector)
2555     return emitOpError("source type is not a vector");
2556   llvm_unreachable("unexpected vector.broadcast op error");
2557 }
2558 
2559 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2560   if (getSourceType() == getResultVectorType())
2561     return getSource();
2562   if (!adaptor.getSource())
2563     return {};
2564   auto vectorType = getResultVectorType();
2565   if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
2566     if (vectorType.getElementType() != attr.getType())
2567       return {};
2568     return DenseElementsAttr::get(vectorType, attr);
2569   }
2570   if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
2571     if (vectorType.getElementType() != attr.getType())
2572       return {};
2573     return DenseElementsAttr::get(vectorType, attr);
2574   }
2575   if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2576     return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2577   return {};
2578 }
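// A hypothetical example of the folds above: broadcasting an arith.constant
// 1.0 : f32 to vector<4xf32> folds to the splat constant dense<1.0> :
// vector<4xf32>, and a broadcast whose source type already equals the result
// type folds to its source.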
2579 
2580 namespace {
2581 
2582 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
2583 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2584   using OpRewritePattern::OpRewritePattern;
2585 
2586   LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2587                                 PatternRewriter &rewriter) const override {
2588     auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2589     if (!srcBroadcast)
2590       return failure();
2591     rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2592                                              broadcastOp.getResultVectorType(),
2593                                              srcBroadcast.getSource());
2594     return success();
2595   }
2596 };
2597 } // namespace
2598 
2599 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2600                                               MLIRContext *context) {
2601   // BroadcastToShapeCast is not a default canonicalization; it is enabled by
2602   // calling `populateCastAwayVectorLeadingOneDimPatterns`.
2603   results.add<BroadcastFolder>(context);
2604 }
2605 
2606 //===----------------------------------------------------------------------===//
2607 // ShuffleOp
2608 //===----------------------------------------------------------------------===//
2609 
2610 LogicalResult ShuffleOp::verify() {
2611   VectorType resultType = getResultVectorType();
2612   VectorType v1Type = getV1VectorType();
2613   VectorType v2Type = getV2VectorType();
2614   // Verify ranks.
2615   int64_t resRank = resultType.getRank();
2616   int64_t v1Rank = v1Type.getRank();
2617   int64_t v2Rank = v2Type.getRank();
2618   bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2619   bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2620   if (!wellFormed0DCase && !wellFormedNDCase)
2621     return emitOpError("rank mismatch");
2622 
2623   // Verify all but leading dimension sizes.
2624   for (int64_t r = 1; r < v1Rank; ++r) {
2625     int64_t resDim = resultType.getDimSize(r);
2626     int64_t v1Dim = v1Type.getDimSize(r);
2627     int64_t v2Dim = v2Type.getDimSize(r);
2628     if (resDim != v1Dim || v1Dim != v2Dim)
2629       return emitOpError("dimension mismatch");
2630   }
2631   // Verify mask length.
2632   ArrayRef<int64_t> mask = getMask();
2633   int64_t maskLength = mask.size();
2634   if (maskLength <= 0)
2635     return emitOpError("invalid mask length");
2636   if (maskLength != resultType.getDimSize(0))
2637     return emitOpError("mask length mismatch");
2638   // Verify all indices.
2639   int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2640                       (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2641   for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2642     if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
2643       return emitOpError("mask index #") << (idx + 1) << " out of range";
2644   }
2645   return success();
2646 }
2647 
2648 LogicalResult
2649 ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2650                             ShuffleOp::Adaptor adaptor,
2651                             SmallVectorImpl<Type> &inferredReturnTypes) {
2652   auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2653   auto v1Rank = v1Type.getRank();
2654   // Construct resulting type: leading dimension matches mask
2655   // length, all trailing dimensions match the operands.
2656   SmallVector<int64_t, 4> shape;
2657   shape.reserve(v1Rank);
2658   shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2659   // In the 0-D case there is no trailing shape to append.
2660   if (v1Rank > 0)
2661     llvm::append_range(shape, v1Type.getShape().drop_front());
2662   inferredReturnTypes.push_back(
2663       VectorType::get(shape, v1Type.getElementType()));
2664   return success();
2665 }
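// For illustration: with v1 of type vector<4x8xi32> and a mask of length 3,
// the inferred result type is vector<3x8xi32> (leading dim = mask length,
// trailing dims copied from v1); for 0-D operands a one-element mask infers
// vector<1xi32>.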
2666 
2667 template <typename T>
2668 static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2669   T expected = begin;
2670   return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2671            return value == expected++;
2672          });
2673 }
2674 
2675 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2676   VectorType v1Type = getV1VectorType();
2677   // For consistency, the 0-D shuffle result type is 1-D; this cannot be handled
2678   // as a fold and must instead be canonicalized into a vector.broadcast.
2679   if (v1Type.getRank() == 0)
2680     return {};
2681 
2682   // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2683   if (!v1Type.isScalable() &&
2684       isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
2685     return getV1();
2686   // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2687   if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2688       isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
2689                        getV2VectorType().getDimSize(0)))
2690     return getV2();
2691 
2692   Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2693   if (!lhs || !rhs)
2694     return {};
2695 
2696   auto lhsType =
2697       llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
2698   // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2699   // manipulation.
2700   if (lhsType.getRank() != 1)
2701     return {};
2702   int64_t lhsSize = lhsType.getDimSize(0);
2703 
2704   SmallVector<Attribute> results;
2705   auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
2706   auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
2707   for (int64_t i : this->getMask()) {
2708     if (i >= lhsSize) {
2709       results.push_back(rhsElements[i - lhsSize]);
2710     } else {
2711       results.push_back(lhsElements[i]);
2712     }
2713   }
2714 
2715   return DenseElementsAttr::get(getResultVectorType(), results);
2716 }
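// As a concrete (hypothetical) instance of the constant folding above:
// shuffling dense<[0, 1]> and dense<[2, 3]> (both vector<2xi32>) with mask
// [1, 3] folds to the constant dense<[1, 3]> : vector<2xi32>.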
2717 
2718 namespace {
2719 
2720 // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2721 // to a broadcast.
2722 struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2723   using OpRewritePattern::OpRewritePattern;
2724 
2725   LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2726                                 PatternRewriter &rewriter) const override {
2727     VectorType v1VectorType = shuffleOp.getV1VectorType();
2728     ArrayRef<int64_t> mask = shuffleOp.getMask();
2729     if (v1VectorType.getRank() > 0)
2730       return failure();
2731     if (mask.size() != 1)
2732       return failure();
2733     VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2734     if (mask[0] == 0)
2735       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2736                                                        shuffleOp.getV1());
2737     else
2738       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2739                                                        shuffleOp.getV2());
2740     return success();
2741   }
2742 };
2743 
2744 /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
2745 class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2746 public:
2747   using OpRewritePattern::OpRewritePattern;
2748 
2749   LogicalResult matchAndRewrite(ShuffleOp op,
2750                                 PatternRewriter &rewriter) const override {
2751     auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2752     auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2753 
2754     if (!v1Splat || !v2Splat)
2755       return failure();
2756 
2757     if (v1Splat.getInput() != v2Splat.getInput())
2758       return failure();
2759 
2760     rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2761     return success();
2762   }
2763 };
2764 
2765 /// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2766 /// vector.interleave.
2767 class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2768 public:
2769   using OpRewritePattern::OpRewritePattern;
2770 
2771   LogicalResult matchAndRewrite(ShuffleOp op,
2772                                 PatternRewriter &rewriter) const override {
2773     VectorType resultType = op.getResultVectorType();
2774     if (resultType.isScalable())
2775       return rewriter.notifyMatchFailure(
2776           op, "ShuffleOp can't represent a scalable interleave");
2777 
2778     if (resultType.getRank() != 1)
2779       return rewriter.notifyMatchFailure(
2780           op, "ShuffleOp can't represent an n-D interleave");
2781 
2782     VectorType sourceType = op.getV1VectorType();
2783     if (sourceType != op.getV2VectorType() ||
2784         sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2785       return rewriter.notifyMatchFailure(
2786           op, "ShuffleOp types don't match an interleave");
2787     }
2788 
2789     ArrayRef<int64_t> shuffleMask = op.getMask();
2790     int64_t resultVectorSize = resultType.getNumElements();
2791     for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2792       int64_t maskValueA = shuffleMask[i * 2];
2793       int64_t maskValueB = shuffleMask[(i * 2) + 1];
2794       if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2795         return rewriter.notifyMatchFailure(op,
2796                                            "ShuffleOp mask not interleaving");
2797     }
2798 
2799     rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2800     return success();
2801   }
2802 };
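// Sketch of the mask shape this pattern matches (operands are hypothetical):
//   vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
//       : vector<4xi32>, vector<4xi32>
// alternates elements of %a and %b and is rewritten to an interleave of %a
// and %b.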
2803 
2804 } // namespace
2805 
2806 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2807                                             MLIRContext *context) {
2808   results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2809       context);
2810 }
2811 
2812 //===----------------------------------------------------------------------===//
2813 // InsertElementOp
2814 //===----------------------------------------------------------------------===//
2815 
2816 void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2817                                         SetIntRangeFn setResultRanges) {
2818   setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2819 }
2820 
2821 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2822                             Value source, Value dest) {
2823   build(builder, result, source, dest, {});
2824 }
2825 
2826 LogicalResult InsertElementOp::verify() {
2827   auto dstVectorType = getDestVectorType();
2828   if (dstVectorType.getRank() == 0) {
2829     if (getPosition())
2830       return emitOpError("expected position to be empty with 0-D vector");
2831     return success();
2832   }
2833   if (dstVectorType.getRank() != 1)
2834     return emitOpError("unexpected >1 vector rank");
2835   if (!getPosition())
2836     return emitOpError("expected position for 1-D vector");
2837   return success();
2838 }
2839 
2840 OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2841   // Skip the 0-D vector here.
2842   if (!adaptor.getPosition())
2843     return {};
2844 
2845   auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2846   auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2847   auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2848   if (!src || !dst || !pos)
2849     return {};
2850 
2851   if (src.getType() != getDestVectorType().getElementType())
2852     return {};
2853 
2854   auto dstElements = dst.getValues<Attribute>();
2855 
2856   SmallVector<Attribute> results(dstElements);
2857 
2858   uint64_t posIdx = pos.getInt();
2859   if (posIdx >= results.size())
2860     return {};
2861   results[posIdx] = src;
2862 
2863   return DenseElementsAttr::get(getDestVectorType(), results);
2864 }
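// For example (hypothetical constants): inserting the constant 5 : i32 at
// position 2 into dense<[0, 1, 2, 3]> : vector<4xi32> folds to
// dense<[0, 1, 5, 3]> : vector<4xi32>.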
2865 
2866 //===----------------------------------------------------------------------===//
2867 // InsertOp
2868 //===----------------------------------------------------------------------===//
2869 
2870 void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2871                                          SetIntRangeFn setResultRanges) {
2872   setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2873 }
2874 
2875 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2876                              Value source, Value dest, int64_t position) {
2877   build(builder, result, source, dest, ArrayRef<int64_t>{position});
2878 }
2879 
2880 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2881                              Value source, Value dest, OpFoldResult position) {
2882   build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2883 }
2884 
2885 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2886                              Value source, Value dest,
2887                              ArrayRef<int64_t> position) {
2888   SmallVector<OpFoldResult> posVals;
2889   posVals.reserve(position.size());
2890   llvm::transform(position, std::back_inserter(posVals),
2891                   [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2892   build(builder, result, source, dest, posVals);
2893 }
2894 
2895 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2896                              Value source, Value dest,
2897                              ArrayRef<OpFoldResult> position) {
2898   SmallVector<int64_t> staticPos;
2899   SmallVector<Value> dynamicPos;
2900   dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2901   build(builder, result, source, dest, dynamicPos,
2902         builder.getDenseI64ArrayAttr(staticPos));
2903 }
2904 
2905 LogicalResult InsertOp::verify() {
2906   SmallVector<OpFoldResult> position = getMixedPosition();
2907   auto destVectorType = getDestVectorType();
2908   if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
2909     return emitOpError(
2910         "expected position attribute of rank no greater than dest vector rank");
2911   auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2912   if (srcVectorType &&
2913       (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2914        static_cast<unsigned>(destVectorType.getRank())))
2915     return emitOpError("expected position attribute rank + source rank to "
2916                        "match dest vector rank");
2917   if (!srcVectorType &&
2918       (position.size() != static_cast<unsigned>(destVectorType.getRank())))
2919     return emitOpError(
2920         "expected position attribute rank to match the dest vector rank");
2921   for (auto [idx, pos] : llvm::enumerate(position)) {
2922     if (auto attr = pos.dyn_cast<Attribute>()) {
2923       int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2924       if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
2925                                         destVectorType.getDimSize(idx))) {
2926         return emitOpError("expected position attribute #")
2927                << (idx + 1)
2928                << " to be a non-negative integer smaller than the "
2929                   "corresponding "
2930                   "dest vector dimension";
2931       }
2932     }
2933   }
2934   return success();
2935 }
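// As an example of these rules (with assumed types): inserting a vector<2xf32>
// into a vector<4x2xf32> takes exactly one position index (1 + source rank 1
// == dest rank 2), and that index must lie in [0, 4) or be the poison index.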
2936 
2937 namespace {
2938 
2939 // If insertOp is only inserting unit dimensions, it can be transformed into a
2940 // broadcast.
2941 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2942 public:
2943   using OpRewritePattern::OpRewritePattern;
2944 
2945   LogicalResult matchAndRewrite(InsertOp insertOp,
2946                                 PatternRewriter &rewriter) const override {
2947     auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2948     if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2949                            srcVecType.getNumElements())
2950       return failure();
2951     rewriter.replaceOpWithNewOp<BroadcastOp>(
2952         insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2953     return success();
2954   }
2955 };
2956 
2957 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
2958 class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2959 public:
2960   using OpRewritePattern::OpRewritePattern;
2961 
2962   LogicalResult matchAndRewrite(InsertOp op,
2963                                 PatternRewriter &rewriter) const override {
2964     auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2965     auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2966 
2967     if (!srcSplat || !dstSplat)
2968       return failure();
2969 
2970     if (srcSplat.getInput() != dstSplat.getInput())
2971       return failure();
2972 
2973     rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2974     return success();
2975   }
2976 };
2977 
2978 // Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
2979 class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2980 public:
2981   using OpRewritePattern::OpRewritePattern;
2982 
2983   // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2984   // unless the destination vector constant has a single use.
2985   static constexpr int64_t vectorSizeFoldThreshold = 256;
2986 
2987   LogicalResult matchAndRewrite(InsertOp op,
2988                                 PatternRewriter &rewriter) const override {
2989     // TODO: Canonicalization for dynamic position not implemented yet.
2990     if (op.hasDynamicPosition())
2991       return failure();
2992 
2993     // Return if 'InsertOp' operand is not defined by a compatible vector
2994     // ConstantOp.
2995     TypedValue<VectorType> destVector = op.getDest();
2996     Attribute vectorDestCst;
2997     if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2998       return failure();
2999     auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
3000     if (!denseDest)
3001       return failure();
3002 
3003     VectorType destTy = destVector.getType();
3004     if (destTy.isScalable())
3005       return failure();
3006 
3007     // Make sure we do not create too many large constants.
3008     if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3009         !destVector.hasOneUse())
3010       return failure();
3011 
3012     Value sourceValue = op.getSource();
3013     Attribute sourceCst;
3014     if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3015       return failure();
3016 
3017     // Calculate the linearized position of the contiguous chunk of elements to
3018     // insert.
3019     llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3020     copy(op.getStaticPosition(), completePositions.begin());
3021     int64_t insertBeginPosition =
3022         linearize(completePositions, computeStrides(destTy.getShape()));
3023 
3024     SmallVector<Attribute> insertedValues;
3025     Type destEltType = destTy.getElementType();
3026 
3027     // The `convertIntegerAttr` method specifically handles the case
3028     // for `llvm.mlir.constant` which can hold an attribute with a
3029     // different type than the return type.
3030     if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
3031       for (auto value : denseSource.getValues<Attribute>())
3032         insertedValues.push_back(convertIntegerAttr(value, destEltType));
3033     } else {
3034       insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
3035     }
3036 
3037     auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
3038     copy(insertedValues, allValues.begin() + insertBeginPosition);
3039     auto newAttr = DenseElementsAttr::get(destTy, allValues);
3040 
3041     rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3042     return success();
3043   }
3044 
3045 private:
3046   /// Converts `attr` to an IntegerAttr of `expectedType` if their types
3047   /// mismatch.
3048   Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
3049     if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3050       if (intAttr.getType() != expectedType)
3051         return IntegerAttr::get(expectedType, intAttr.getInt());
3052     }
3053     return attr;
3054   }
3055 };
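// For example (hypothetical constants): inserting the scalar constant 7 : i32
// at position [1, 0] into dense<[[1, 2], [3, 4]]> : vector<2x2xi32> has
// linearized begin position 1 * 2 + 0 = 2, so the pattern rewrites the insert
// to the constant dense<[[1, 2], [7, 4]]> : vector<2x2xi32>.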
3056 
3057 } // namespace
3058 
3059 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3060                                            MLIRContext *context) {
3061   results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3062               InsertOpConstantFolder>(context);
3063   results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>);
3064 }
3065 
3066 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3067   // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3068   // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3069   // (type mismatch).
3070   if (getNumIndices() == 0 && getSourceType() == getType())
3071     return getSource();
3072   if (auto res = foldPoisonIndexInsertExtractOp(
3073           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3074     return res;
3075 
3076   return {};
3077 }
3078 
3079 //===----------------------------------------------------------------------===//
3080 // InsertStridedSliceOp
3081 //===----------------------------------------------------------------------===//
3082 
3083 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3084                                  Value source, Value dest,
3085                                  ArrayRef<int64_t> offsets,
3086                                  ArrayRef<int64_t> strides) {
3087   result.addOperands({source, dest});
3088   auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3089   auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3090   result.addTypes(dest.getType());
3091   result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3092                       offsetsAttr);
3093   result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3094                       stridesAttr);
3095 }
3096 
3097 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
3098 template <typename OpType>
3099 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3100                                                         ArrayAttr arrayAttr,
3101                                                         ArrayRef<int64_t> shape,
3102                                                         StringRef attrName) {
3103   if (arrayAttr.size() > shape.size())
3104     return op.emitOpError("expected ")
3105            << attrName << " attribute of rank no greater than vector rank";
3106   return success();
3107 }
3108 
3109 // Returns success if all integers in `arrayAttr` are in the admissible
3110 // interval. If `halfOpen` is true then the admissible interval is [min, max).
3111 // Otherwise, the admissible interval is [min, max].
3112 template <typename OpType>
3113 static LogicalResult
3114 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3115                                   int64_t max, StringRef attrName,
3116                                   bool halfOpen = true) {
3117   for (auto attr : arrayAttr) {
3118     auto val = llvm::cast<IntegerAttr>(attr).getInt();
3119     auto upper = max;
3120     if (!halfOpen)
3121       upper += 1;
3122     if (val < min || val >= upper)
3123       return op.emitOpError("expected ") << attrName << " to be confined to ["
3124                                          << min << ", " << upper << ")";
3125   }
3126   return success();
3127 }
3128 
3129 // Returns success if, for each index i, `arrayAttr[i]` is in the admissible
3130 // interval determined by `shape[i]`. If `halfOpen` is true then the admissible
3131 // interval is [min, shape[i]). Otherwise, it is [min, shape[i]].
3132 template <typename OpType>
3133 static LogicalResult
3134 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3135                                   ArrayRef<int64_t> shape, StringRef attrName,
3136                                   bool halfOpen = true, int64_t min = 0) {
3137   for (auto [index, attrDimPair] :
3138        llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3139     int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3140     int64_t max = std::get<1>(attrDimPair);
3141     if (!halfOpen)
3142       max += 1;
3143     if (val < min || val >= max)
3144       return op.emitOpError("expected ")
3145              << attrName << " dimension " << index << " to be confined to ["
3146              << min << ", " << max << ")";
3147   }
3148   return success();
3149 }
3150 
3151 // Returns success if, for each index i, the sum
3152 //   `arrayAttr1[i]` + `arrayAttr2[i]`
3153 // is in the admissible interval determined by `shape[i]`.
3154 // If `halfOpen` is true then the admissible interval is [min, shape[i]).
3155 // Otherwise, the admissible interval is [min, shape[i]].
3156 template <typename OpType>
3157 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3158     OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3159     ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3160     bool halfOpen = true, int64_t min = 1) {
3161   assert(arrayAttr1.size() <= shape.size());
3162   assert(arrayAttr2.size() <= shape.size());
3163   for (auto [index, it] :
3164        llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3165     auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3166     auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3167     int64_t max = std::get<2>(it);
3168     if (!halfOpen)
3169       max += 1;
3170     if (val1 + val2 < 0 || val1 + val2 >= max)
3171       return op.emitOpError("expected sum(")
3172              << attrName1 << ", " << attrName2 << ") dimension " << index
3173              << " to be confined to [" << min << ", " << max << ")";
3174   }
3175   return success();
3176 }
3177 
3178 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3179                                   MLIRContext *context) {
3180   auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3181     return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3182   });
3183   return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3184 }
3185 
3186 LogicalResult InsertStridedSliceOp::verify() {
3187   auto sourceVectorType = getSourceVectorType();
3188   auto destVectorType = getDestVectorType();
3189   auto offsets = getOffsetsAttr();
3190   auto strides = getStridesAttr();
3191   if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3192     return emitOpError(
3193         "expected offsets of same size as destination vector rank");
3194   if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3195     return emitOpError("expected strides of same size as source vector rank");
3196   if (sourceVectorType.getRank() > destVectorType.getRank())
3197     return emitOpError(
3198         "expected source rank to be no greater than destination rank");
3199 
3200   auto sourceShape = sourceVectorType.getShape();
3201   auto destShape = destVectorType.getShape();
3202   SmallVector<int64_t, 4> sourceShapeAsDestShape(
3203       destShape.size() - sourceShape.size(), 0);
3204   sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3205   auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3206   auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3207   if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3208                                                offName)) ||
3209       failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3210                                                /*max=*/1, stridesName,
3211                                                /*halfOpen=*/false)) ||
3212       failed(isSumOfIntegerArrayAttrConfinedToShape(
3213           *this, offsets,
3214           makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3215           offName, "source vector shape",
3216           /*halfOpen=*/false, /*min=*/1)))
3217     return failure();
3218 
3219   unsigned rankDiff = destShape.size() - sourceShape.size();
3220   for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3221     if (sourceVectorType.getScalableDims()[idx] !=
3222         destVectorType.getScalableDims()[idx + rankDiff]) {
3223       return emitOpError("mismatching scalable flags (at source vector idx=")
3224              << idx << ")";
3225     }
3226     if (sourceVectorType.getScalableDims()[idx]) {
3227       auto sourceSize = sourceShape[idx];
3228       auto destSize = destShape[idx + rankDiff];
3229       if (sourceSize != destSize) {
3230         return emitOpError("expected size at idx=")
3231                << idx
3232                << (" to match the corresponding base size from the input "
3233                    "vector (")
3234                << sourceSize << (" vs ") << destSize << (")");
3235       }
3236     }
3237   }
3238 
3239   return success();
3240 }
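// A valid example under these checks (types are illustrative):
//   vector.insert_strided_slice %src, %dst
//       {offsets = [1, 2, 0], strides = [1, 1]}
//       : vector<2x4xf32> into vector<2x8x4xf32>
// Each offset plus the rank-extended source size stays within the destination
// shape, and all strides are 1.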
3241 
3242 namespace {
3243 /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3244 /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
3245 class FoldInsertStridedSliceSplat final
3246     : public OpRewritePattern<InsertStridedSliceOp> {
3247 public:
3248   using OpRewritePattern::OpRewritePattern;
3249 
3250   LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3251                                 PatternRewriter &rewriter) const override {
3252     auto srcSplatOp =
3253         insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
3254     auto destSplatOp =
3255         insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3256 
3257     if (!srcSplatOp || !destSplatOp)
3258       return failure();
3259 
3260     if (srcSplatOp.getInput() != destSplatOp.getInput())
3261       return failure();
3262 
3263     rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3264     return success();
3265   }
3266 };
3267 
3268 /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3269 /// to dst.
3270 class FoldInsertStridedSliceOfExtract final
3271     : public OpRewritePattern<InsertStridedSliceOp> {
3272 public:
3273   using OpRewritePattern::OpRewritePattern;
3274 
3275   LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3276                                 PatternRewriter &rewriter) const override {
3277     auto extractStridedSliceOp =
3278         insertStridedSliceOp.getSource()
3279             .getDefiningOp<vector::ExtractStridedSliceOp>();
3280 
3281     if (!extractStridedSliceOp)
3282       return failure();
3283 
3284     if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3285       return failure();
3286 
3287     // Check if they have the same strides and offsets.
3288     if (extractStridedSliceOp.getStrides() !=
3289             insertStridedSliceOp.getStrides() ||
3290         extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3291       return failure();
3292 
3293     rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3294     return success();
3295   }
3296 };
3297 
3298 // Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3299 // ConstantOp.
3300 class InsertStridedSliceConstantFolder final
3301     : public OpRewritePattern<InsertStridedSliceOp> {
3302 public:
3303   using OpRewritePattern::OpRewritePattern;
3304 
3305   // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3306   // unless the destination vector constant has a single use.
3307   static constexpr int64_t vectorSizeFoldThreshold = 256;
3308 
3309   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3310                                 PatternRewriter &rewriter) const override {
3311     // Return if 'InsertStridedSliceOp' operand is not defined by a compatible vector
3312     // ConstantOp.
3313     TypedValue<VectorType> destVector = op.getDest();
3314     Attribute vectorDestCst;
3315     if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3316       return failure();
3317 
3318     VectorType destTy = destVector.getType();
3319     if (destTy.isScalable())
3320       return failure();
3321 
3322     // Make sure we do not create too many large constants.
3323     if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3324         !destVector.hasOneUse())
3325       return failure();
3326 
3327     auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3328 
3329     TypedValue<VectorType> sourceValue = op.getSource();
3330     Attribute sourceCst;
3331     if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3332       return failure();
3333 
3334     // TODO: Handle non-unit strides when they become available.
3335     if (op.hasNonUnitStrides())
3336       return failure();
3337 
3338     VectorType sliceVecTy = sourceValue.getType();
3339     ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3340     int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3341     SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3342     SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3343 
3344     // Calculate the destination element indices by enumerating all slice
3345     // positions within the destination and linearizing them. The enumeration
3346     // order is lexicographic, which yields a sequence of monotonically
3347     // increasing linearized position indices.
3348     // Because the destination may have higher dimensionality than the slice,
3349     // we keep track of two overlapping sets of positions and offsets.
3350     auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3351     auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3352     auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3353     SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3354     MutableArrayRef<int64_t> currSlicePosition(
3355         currDestPosition.begin() + rankDifference, currDestPosition.end());
3356     ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3357                                    offsets.end());
3358     do {
3359       int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3360       assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3361       assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3362              "Invalid slice element");
3363       newValues[linearizedPosition] = *sliceValuesIt;
3364       ++sliceValuesIt;
3365     } while (succeeded(
3366         incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3367 
3368     auto newAttr = DenseElementsAttr::get(destTy, newValues);
3369     rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3370     return success();
3371   }
3372 };
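// Worked example of the folding above (constants are hypothetical): inserting
// dense<[7, 8]> : vector<2xi32> at offsets [1, 0] with unit strides into
// dense<[[1, 2], [3, 4]]> : vector<2x2xi32> starts at linearized position 2
// and produces dense<[[1, 2], [7, 8]]> : vector<2x2xi32>.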
3373 
3374 } // namespace
3375 
3376 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3377     RewritePatternSet &results, MLIRContext *context) {
3378   results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3379               InsertStridedSliceConstantFolder>(context);
3380 }
3381 
3382 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3383   if (getSourceVectorType() == getDestVectorType())
3384     return getSource();
3385   return {};
3386 }
3387 
3388 //===----------------------------------------------------------------------===//
3389 // OuterProductOp
3390 //===----------------------------------------------------------------------===//
3391 
3392 /// Build an op without a mask; use the type of `acc` as the return type.
3393 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3394                            Value lhs, Value rhs, Value acc) {
3395   result.addOperands({lhs, rhs, acc});
3396   result.addTypes(acc.getType());
3397 }
3398 
3399 void OuterProductOp::print(OpAsmPrinter &p) {
3400   p << " " << getLhs() << ", " << getRhs();
3401   if (getAcc()) {
3402     p << ", " << getAcc();
3403     p.printOptionalAttrDict((*this)->getAttrs());
3404   }
3405   p << " : " << getLhs().getType() << ", " << getRhs().getType();
3406 }
3407 
3408 ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3409   SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3410   Type tLHS, tRHS;
3411   if (parser.parseOperandList(operandsInfo) ||
3412       parser.parseOptionalAttrDict(result.attributes) ||
3413       parser.parseColonType(tLHS) || parser.parseComma() ||
3414       parser.parseType(tRHS))
3415     return failure();
3416   if (operandsInfo.size() < 2)
3417     return parser.emitError(parser.getNameLoc(),
3418                             "expected at least 2 operands");
3419   VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3420   VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3421   if (!vLHS)
3422     return parser.emitError(parser.getNameLoc(),
3423                             "expected vector type for operand #1");
3424 
3425   VectorType resType;
3426   if (vRHS) {
3427     SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3428                                       vRHS.getScalableDims()[0]};
3429     resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3430                               vLHS.getElementType(), scalableDimsRes);
3431   } else {
3432     // Scalar RHS operand
3433     SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3434     resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3435                               scalableDimsRes);
3436   }
3437 
3438   if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3439     result.attributes.append(
3440         OuterProductOp::getKindAttrName(result.name),
3441         CombiningKindAttr::get(result.getContext(),
3442                                OuterProductOp::getDefaultKind()));
3443   }
3444 
3445   return failure(
3446       parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3447       parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3448       (operandsInfo.size() > 2 &&
3449        parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3450       parser.addTypeToList(resType, result.types));
3451 }
3452 
3453 LogicalResult OuterProductOp::verify() {
3454   Type tRHS = getOperandTypeRHS();
3455   VectorType vLHS = getOperandVectorTypeLHS(),
3456              vRHS = llvm::dyn_cast<VectorType>(tRHS),
3457              vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3458 
3459   if (vLHS.getRank() != 1)
3460     return emitOpError("expected 1-d vector for operand #1");
3461 
3462   if (vRHS) {
3463     // Proper OUTER operation.
3464     if (vRHS.getRank() != 1)
3465       return emitOpError("expected 1-d vector for operand #2");
3466     if (vRES.getRank() != 2)
3467       return emitOpError("expected 2-d vector result");
3468     if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3469       return emitOpError("expected #1 operand dim to match result dim #1");
3470     if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3471       return emitOpError("expected #2 operand dim to match result dim #2");
3472     if (vLHS.isScalable() && !vRHS.isScalable()) {
3473       // This restriction reflects what's currently supported in terms of
3474       // scalable vectors. However, we could relax this if there's a use case.
3475       return emitOpError(
3476           "expected either both or only #2 operand dim to be scalable");
3477     }
3478   } else {
3479     // An AXPY operation.
3480     if (vRES.getRank() != 1)
3481       return emitOpError("expected 1-d vector result");
3482     if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3483       return emitOpError("expected #1 operand dim to match result dim #1");
3484   }
3485 
3486   if (vACC && vACC != vRES)
3487     return emitOpError("expected operand #3 of same type as result type");
3488 
3489   // Verify supported combining kind.
3490   if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3491     return emitOpError("unsupported outerproduct type");
3492 
3493   return success();
3494 }
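// For illustration (operand types are hypothetical): an outer product of
// vector<2xf32> and vector<3xf32> with a vector<2x3xf32> accumulator yields
// vector<2x3xf32>; with a scalar f32 second operand the op is an AXPY and
// yields vector<2xf32>.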
3495 
3496 // MaskableOpInterface methods.
3497 
3498 /// Returns the mask type expected by this operation. Mostly used for
3499 /// verification purposes. It requires the operation to be vectorized.
3500 Type OuterProductOp::getExpectedMaskType() {
3501   auto vecType = this->getResultVectorType();
3502   return VectorType::get(vecType.getShape(),
3503                          IntegerType::get(vecType.getContext(), /*width=*/1),
3504                          vecType.getScalableDims());
3505 }
3506 
3507 //===----------------------------------------------------------------------===//
3508 // ExtractStridedSliceOp
3509 //===----------------------------------------------------------------------===//
3510 
3511 // Inference works as follows:
3512 //   1. Add 'sizes' from prefix of dims in 'offsets'.
3513 //   2. Add sizes from 'vectorType' for remaining dims.
3514 // Scalable flags are inherited from 'vectorType'.
3515 static Type inferStridedSliceOpResultType(VectorType vectorType,
3516                                           ArrayAttr offsets, ArrayAttr sizes,
3517                                           ArrayAttr strides) {
3518   assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3519   SmallVector<int64_t, 4> shape;
3520   shape.reserve(vectorType.getRank());
3521   unsigned idx = 0;
3522   for (unsigned e = offsets.size(); idx < e; ++idx)
3523     shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3524   for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3525     shape.push_back(vectorType.getShape()[idx]);
3526 
3527   return VectorType::get(shape, vectorType.getElementType(),
3528                          vectorType.getScalableDims());
3529 }
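// For example, slicing a vector<4x8x16xf32> with offsets/sizes/strides of rank
// 2 and sizes [2, 4] infers vector<2x4x16xf32>: the first two dims come from
// 'sizes' and the remaining dim is copied from the source type.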
3530 
3531 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3532                                   Value source, ArrayRef<int64_t> offsets,
3533                                   ArrayRef<int64_t> sizes,
3534                                   ArrayRef<int64_t> strides) {
3535   result.addOperands(source);
3536   auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3537   auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3538   auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3539   result.addTypes(
3540       inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3541                                     offsetsAttr, sizesAttr, stridesAttr));
3542   result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3543                       offsetsAttr);
3544   result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3545                       sizesAttr);
3546   result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3547                       stridesAttr);
3548 }
3549 
3550 LogicalResult ExtractStridedSliceOp::verify() {
3551   auto type = getSourceVectorType();
3552   auto offsets = getOffsetsAttr();
3553   auto sizes = getSizesAttr();
3554   auto strides = getStridesAttr();
3555   if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3556     return emitOpError(
3557         "expected offsets, sizes and strides attributes of same size");
3558 
3559   auto shape = type.getShape();
3560   auto offName = getOffsetsAttrName();
3561   auto sizesName = getSizesAttrName();
3562   auto stridesName = getStridesAttrName();
3563   if (failed(
3564           isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3565       failed(
3566           isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3567       failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3568                                                 stridesName)) ||
3569       failed(
3570           isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3571       failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3572                                                /*halfOpen=*/false,
3573                                                /*min=*/1)) ||
3574       failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3575                                                /*max=*/1, stridesName,
3576                                                /*halfOpen=*/false)) ||
3577       failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3578                                                     shape, offName, sizesName,
3579                                                     /*halfOpen=*/false)))
3580     return failure();
3581 
3582   auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3583                                                   offsets, sizes, strides);
3584   if (getResult().getType() != resultType)
3585     return emitOpError("expected result type to be ") << resultType;
3586 
3587   for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3588     if (type.getScalableDims()[idx]) {
3589       auto inputDim = type.getShape()[idx];
3590       auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3591       if (inputDim != inputSize)
3592         return emitOpError("expected size at idx=")
3593                << idx
3594                << (" to match the corresponding base size from the input "
3595                    "vector (")
3596                << inputSize << (" vs ") << inputDim << (")");
3597     }
3598   }
3599 
3600   return success();
3601 }
3602 
3603 // When the source of an ExtractStridedSliceOp comes from a chain of
3604 // InsertStridedSliceOps, try to use the source of those inserts if we can
3605 // detect that the extracted vector is a subset of one of the inserted vectors.
3606 static LogicalResult
3607 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3608   // Helper to extract integer out of ArrayAttr.
3609   auto getElement = [](ArrayAttr array, int idx) {
3610     return llvm::cast<IntegerAttr>(array[idx]).getInt();
3611   };
3612   ArrayAttr extractOffsets = op.getOffsets();
3613   ArrayAttr extractStrides = op.getStrides();
3614   ArrayAttr extractSizes = op.getSizes();
3615   auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3616   while (insertOp) {
3617     if (op.getSourceVectorType().getRank() !=
3618         insertOp.getSourceVectorType().getRank())
3619       return failure();
3620     ArrayAttr insertOffsets = insertOp.getOffsets();
3621     ArrayAttr insertStrides = insertOp.getStrides();
3622     // If the rank of extract is greater than the rank of insert, we are likely
3623     // extracting a partial chunk of the vector inserted.
3624     if (extractOffsets.size() > insertOffsets.size())
3625       return failure();
3626     bool partialOverlap = false;
3627     bool disjoint = false;
3628     SmallVector<int64_t, 4> offsetDiffs;
3629     for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3630       if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3631         return failure();
3632       int64_t start = getElement(insertOffsets, dim);
3633       int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3634       int64_t offset = getElement(extractOffsets, dim);
3635       int64_t size = getElement(extractSizes, dim);
3636       // Check if the start of the extract offset is in the interval inserted.
3637       if (start <= offset && offset < end) {
3638         // If the extract interval overlaps but is not fully included, we may
3639         // have a partial overlap that will prevent any folding.
3640         if (offset + size > end)
3641           partialOverlap = true;
3642         offsetDiffs.push_back(offset - start);
3643         continue;
3644       }
3645       disjoint = true;
3646       break;
3647     }
3648     // The extract element chunk is a subset of the insert element.
3649     if (!disjoint && !partialOverlap) {
3650       op.setOperand(insertOp.getSource());
3651       // OpBuilder is only used as a helper to build an I64ArrayAttr.
3652       OpBuilder b(op.getContext());
3653       op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3654       return success();
3655     }
3656     // If the chunk extracted is disjoint from the chunk inserted, keep looking
3657     // in the insert chain.
3658     if (disjoint)
3659       insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3660     else {
3661       // The extracted vector partially overlaps the inserted vector, so we cannot
3662       // fold.
3663       return failure();
3664     }
3665   }
3666   return failure();
3667 }
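// A sketch of the fold (values are hypothetical): if %v was produced by
// inserting a vector<2xf32> at offsets [2] into a vector<8xf32>, then
// extracting offsets [2], sizes [2], strides [1] from %v is rewritten to read
// directly from the inserted vector with offsets [0].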
3668 
3669 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3670   if (getSourceVectorType() == getResult().getType())
3671     return getVector();
3672   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3673     return getResult();
3674   return {};
3675 }
3676 
3677 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3678   populateFromInt64AttrArray(getOffsets(), results);
3679 }
3680 
3681 namespace {
3682 
3683 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3684 // ConstantMaskOp.
3685 class StridedSliceConstantMaskFolder final
3686     : public OpRewritePattern<ExtractStridedSliceOp> {
3687 public:
3688   using OpRewritePattern::OpRewritePattern;
3689 
3690   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3691                                 PatternRewriter &rewriter) const override {
3692     // Return if 'extractStridedSliceOp' operand is not defined by a
3693     // ConstantMaskOp.
3694     auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3695     auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3696     if (!constantMaskOp)
3697       return failure();
3698     // Return if 'extractStridedSliceOp' has non-unit strides.
3699     if (extractStridedSliceOp.hasNonUnitStrides())
3700       return failure();
3701     // Gather constant mask dimension sizes.
3702     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
3703     // Gather strided slice offsets and sizes.
3704     SmallVector<int64_t, 4> sliceOffsets;
3705     populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3706                                sliceOffsets);
3707     SmallVector<int64_t, 4> sliceSizes;
3708     populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3709 
3710     // Compute slice of vector mask region.
3711     SmallVector<int64_t, 4> sliceMaskDimSizes;
3712     sliceMaskDimSizes.reserve(maskDimSizes.size());
3713     for (auto [maskDimSize, sliceOffset, sliceSize] :
3714          llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3715       int64_t sliceMaskDimSize = std::max(
3716           static_cast<int64_t>(0),
3717           std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3718       sliceMaskDimSizes.push_back(sliceMaskDimSize);
3719     }
3720     // Add unchanged dimensions.
3721     if (sliceMaskDimSizes.size() < maskDimSizes.size())
3722       for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3723         sliceMaskDimSizes.push_back(maskDimSizes[i]);
3724     // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3725     // region is a conjunction of mask dim intervals).
3726     if (llvm::is_contained(sliceMaskDimSizes, 0))
3727       sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3728 
3729     // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3730     // region.
3731     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3732         extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3733         sliceMaskDimSizes);
3734     return success();
3735   }
3736 };
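// For instance (assumed shapes): extracting offsets [2], sizes [4] from
// vector.constant_mask [3] : vector<8xi1> yields a sliced mask region of size
// min(2 + 4, 3) - 2 = 1, so the pattern produces vector.constant_mask [1] :
// vector<4xi1>.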
3737 
3738 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3739 class StridedSliceSplatConstantFolder final
3740     : public OpRewritePattern<ExtractStridedSliceOp> {
3741 public:
3742   using OpRewritePattern::OpRewritePattern;
3743 
3744   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3745                                 PatternRewriter &rewriter) const override {
3746     // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3747     // ConstantOp.
3748     Value sourceVector = extractStridedSliceOp.getVector();
3749     Attribute vectorCst;
3750     if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3751       return failure();
3752 
3753     auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3754     if (!splat)
3755       return failure();
3756 
3757     auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3758                                           splat.getSplatValue<Attribute>());
3759     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3760                                                    newAttr);
3761     return success();
3762   }
3763 };
3764 
3765 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3766 // ConstantOp.
3767 class StridedSliceNonSplatConstantFolder final
3768     : public OpRewritePattern<ExtractStridedSliceOp> {
3769 public:
3770   using OpRewritePattern::OpRewritePattern;
3771 
3772   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3773                                 PatternRewriter &rewriter) const override {
3774     // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3775     // ConstantOp.
3776     Value sourceVector = extractStridedSliceOp.getVector();
3777     Attribute vectorCst;
3778     if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3779       return failure();
3780 
3781     // The splat case is handled by `StridedSliceSplatConstantFolder`.
3782     auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3783     if (!dense || dense.isSplat())
3784       return failure();
3785 
3786     // TODO: Handle non-unit strides when they become available.
3787     if (extractStridedSliceOp.hasNonUnitStrides())
3788       return failure();
3789 
3790     auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3791     ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3792     SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3793 
3794     VectorType sliceVecTy = extractStridedSliceOp.getType();
3795     ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3796     int64_t sliceRank = sliceVecTy.getRank();
3797 
3798     // Expand offsets and sizes to match the vector rank.
3799     SmallVector<int64_t, 4> offsets(sliceRank, 0);
3800     copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3801 
3802     SmallVector<int64_t, 4> sizes(sourceShape);
3803     copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3804 
3805     // Calculate the slice elements by enumerating all slice positions and
3806     // linearizing them. The enumeration order is lexicographic which yields a
3807     // sequence of monotonically increasing linearized position indices.
3808     auto denseValuesBegin = dense.value_begin<Attribute>();
3809     SmallVector<Attribute> sliceValues;
3810     sliceValues.reserve(sliceVecTy.getNumElements());
3811     SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3812     do {
3813       int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3814       assert(linearizedPosition < sourceVecTy.getNumElements() &&
3815              "Invalid index");
3816       sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3817     } while (
3818         succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3819 
3820     assert(static_cast<int64_t>(sliceValues.size()) ==
3821                sliceVecTy.getNumElements() &&
3822            "Invalid number of slice elements");
3823     auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3824     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3825                                                    newAttr);
3826     return success();
3827   }
3828 };
3829 
3830 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3831 // BroadcastOp(ExtractStridedSliceOp).
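// Illustrative example (hypothetical IR, not taken from a test):
//
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x4xf32>
//   %1 = vector.extract_strided_slice %b
//          {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]}
//        : vector<2x4xf32> to vector<1x2xf32>
//
// rewrites to
//
//   %e = vector.extract_strided_slice %src
//          {offsets = [0], sizes = [2], strides = [1]}
//        : vector<4xf32> to vector<2xf32>
//   %1 = vector.broadcast %e : vector<2xf32> to vector<1x2xf32>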
3832 class StridedSliceBroadcast final
3833     : public OpRewritePattern<ExtractStridedSliceOp> {
3834 public:
3835   using OpRewritePattern::OpRewritePattern;
3836 
3837   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3838                                 PatternRewriter &rewriter) const override {
3839     auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3840     if (!broadcast)
3841       return failure();
3842     auto srcVecType =
3843         llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3844     unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3845     auto dstVecType = llvm::cast<VectorType>(op.getType());
3846     unsigned dstRank = dstVecType.getRank();
3847     unsigned rankDiff = dstRank - srcRank;
3848     // Check if the innermost dimensions of the source of the broadcast are the
3849     // same as the destination of the extract. If this is the case, we can just
3850     // use a broadcast, as the original dimensions are untouched.
3851     bool lowerDimMatch = true;
3852     for (unsigned i = 0; i < srcRank; i++) {
3853       if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3854         lowerDimMatch = false;
3855         break;
3856       }
3857     }
3858     Value source = broadcast.getSource();
3859     // If the inner dimensions don't match, we need to extract from the source
3860     // of the original broadcast and then broadcast the extracted value.
3861     // We also need to handle degenerate cases where the source is effectively
3862     // just a single scalar.
3863     bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3864     if (!lowerDimMatch && !isScalarSrc) {
3865       source = rewriter.create<ExtractStridedSliceOp>(
3866           op->getLoc(), source,
3867           getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3868           getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3869           getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3870     }
3871     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3872     return success();
3873   }
3874 };
3875 
3876 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
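/// Illustrative example (hypothetical IR, not taken from a test):
///
///   %s = vector.splat %f : vector<8xf32>
///   %1 = vector.extract_strided_slice %s
///          {offsets = [2], sizes = [4], strides = [1]}
///        : vector<8xf32> to vector<4xf32>
///
/// rewrites to
///
///   %1 = vector.splat %f : vector<4xf32>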
3877 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3878 public:
3879   using OpRewritePattern::OpRewritePattern;
3880 
3881   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3882                                 PatternRewriter &rewriter) const override {
3883     auto splat = op.getVector().getDefiningOp<SplatOp>();
3884     if (!splat)
3885       return failure();
3886     rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3887     return success();
3888   }
3889 };
3890 
3891 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
3892 /// slice is contiguous, into extract and shape_cast.
3893 ///
3894 /// Example:
3895 ///     Before:
3896 ///         %1 = vector.extract_strided_slice %arg0 {
3897 ///                offsets = [0, 0, 0, 0, 0],
3898 ///                sizes = [1, 1, 1, 1, 8],
3899 ///                strides = [1, 1, 1, 1, 1]
3900 ///              } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
3901 ///     After:
3902 ///         %0 = vector.extract %arg0[0, 0, 0, 0]
3903 ///                : vector<8xi8> from vector<8x1x1x2x8xi8>
3904 ///         %1 = vector.shape_cast %0
3905 ///                : vector<8xi8> to vector<1x1x1x1x8xi8>
3906 ///
3907 class ContiguousExtractStridedSliceToExtract final
3908     : public OpRewritePattern<ExtractStridedSliceOp> {
3909 public:
3910   using OpRewritePattern::OpRewritePattern;
3911 
3912   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3913                                 PatternRewriter &rewriter) const override {
3914     if (op.hasNonUnitStrides())
3915       return failure();
3916     Value source = op.getOperand();
3917     auto sourceType = cast<VectorType>(source.getType());
3918     if (sourceType.isScalable() || sourceType.getRank() == 0)
3919       return failure();
3920 
3921     // Compute the number of offsets to pass to ExtractOp::build. That is the
3922     // difference between the source rank and the desired slice rank. We walk
3923     // the dimensions from innermost out, and stop when the next slice dimension
3924     // is not full-size.
3925     SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
3926     int numOffsets;
3927     for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
3928       if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
3929         break;
3930     }
3931 
3932     // If the created extract op would have no offsets, then this whole
3933     // extract_strided_slice is the identity and should have been handled by
3934     // other canonicalizations.
3935     if (numOffsets == 0)
3936       return failure();
3937 
3938     // If not even the inner-most dimension is full-size, this op can't be
3939     // rewritten as an ExtractOp.
3940     if (numOffsets == sourceType.getRank() &&
3941         static_cast<int>(sizes.size()) == sourceType.getRank())
3942       return failure();
3943 
3944     // The outer dimensions must have unit size.
3945     for (int i = 0; i < numOffsets; ++i) {
3946       if (sizes[i] != 1)
3947         return failure();
3948     }
3949 
3950     // Avoid generating slices that have leading unit dimensions. The shape_cast
3951     // op that we create below would otherwise be lowered through the slow
3952     // generic fallback pattern (ShapeCastOpRewritePattern).
3953     while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
3954            sizes[numOffsets] == 1) {
3955       ++numOffsets;
3956     }
3957 
3958     SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
3959     auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
3960     Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
3961                                                        extractOffsets);
3962     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
3963     return success();
3964   }
3965 };
3966 
3967 } // namespace
3968 
3969 void ExtractStridedSliceOp::getCanonicalizationPatterns(
3970     RewritePatternSet &results, MLIRContext *context) {
3971   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
3972   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
3973   results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3974               StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3975               StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
3976       context);
3977 }
3978 
3979 //===----------------------------------------------------------------------===//
3980 // TransferReadOp
3981 //===----------------------------------------------------------------------===//
3982 
3983 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
3984 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3985                            VectorType vectorType, Value source,
3986                            ValueRange indices, AffineMapAttr permutationMapAttr,
3987                            /*optional*/ ArrayAttr inBoundsAttr) {
3988   Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3989   Value padding = builder.create<arith::ConstantOp>(
3990       result.location, elemType, builder.getZeroAttr(elemType));
3991   build(builder, result, vectorType, source, indices, permutationMapAttr,
3992         padding, /*mask=*/Value(), inBoundsAttr);
3993 }
3994 
3995 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
3996 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3997                            VectorType vectorType, Value source,
3998                            ValueRange indices, AffineMap permutationMap,
3999                            std::optional<ArrayRef<bool>> inBounds) {
4000   auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4001   auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4002                           ? builder.getBoolArrayAttr(inBounds.value())
4003                           : builder.getBoolArrayAttr(
4004                                 SmallVector<bool>(vectorType.getRank(), false));
4005   build(builder, result, vectorType, source, indices, permutationMapAttr,
4006         inBoundsAttr);
4007 }
4008 
4009 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4010 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4011                            VectorType vectorType, Value source,
4012                            ValueRange indices, Value padding,
4013                            std::optional<ArrayRef<bool>> inBounds) {
4014   AffineMap permutationMap = getTransferMinorIdentityMap(
4015       llvm::cast<ShapedType>(source.getType()), vectorType);
4016   auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4017   auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4018                           ? builder.getBoolArrayAttr(inBounds.value())
4019                           : builder.getBoolArrayAttr(
4020                                 SmallVector<bool>(vectorType.getRank(), false));
4021   build(builder, result, vectorType, source, indices, permutationMapAttr,
4022         padding,
4023         /*mask=*/Value(), inBoundsAttr);
4024 }
4025 
4026 /// 4. Builder that sets padding to zero and permutation map to
4027 /// 'getMinorIdentityMap'.
4028 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4029                            VectorType vectorType, Value source,
4030                            ValueRange indices,
4031                            std::optional<ArrayRef<bool>> inBounds) {
4032   Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4033   Value padding = builder.create<arith::ConstantOp>(
4034       result.location, elemType, builder.getZeroAttr(elemType));
4035   build(builder, result, vectorType, source, indices, padding, inBounds);
4036 }
4037 
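/// Verify that `permutationMap` is a projected permutation: every result must
/// be either a unique AffineDimExpr or the constant zero. Illustrative
/// examples (hypothetical maps):
///
///   affine_map<(d0, d1, d2) -> (d2, 0, d0)>   // valid
///   affine_map<(d0, d1) -> (d0, d0)>          // invalid: d0 appears twice
///   affine_map<(d0, d1) -> (d0 + d1)>         // invalid: not a dim or zero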
4038 template <typename EmitFun>
4039 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4040                                           EmitFun emitOpError) {
4041   SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4042   for (auto expr : permutationMap.getResults()) {
4043     auto dim = dyn_cast<AffineDimExpr>(expr);
4044     auto zero = dyn_cast<AffineConstantExpr>(expr);
4045     if (zero) {
4046       if (zero.getValue() != 0) {
4047         return emitOpError(
4048             "requires a projected permutation_map (at most one dim or the zero "
4049             "constant can appear in each result)");
4050       }
4051       continue;
4052     }
4053     if (!dim) {
4054       return emitOpError("requires a projected permutation_map (at most one "
4055                          "dim or the zero constant can appear in each result)");
4056     }
4057     if (seen[dim.getPosition()]) {
4058       return emitOpError(
4059           "requires a permutation_map that is a permutation (found one dim "
4060           "used more than once)");
4061     }
4062     seen[dim.getPosition()] = true;
4063   }
4064   return success();
4065 }
4066 
4067 static LogicalResult
4068 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4069                  VectorType vectorType, VectorType maskType,
4070                  VectorType inferredMaskType, AffineMap permutationMap,
4071                  ArrayAttr inBounds) {
4072   if (op->hasAttr("masked")) {
4073     return op->emitOpError("masked attribute has been removed. "
4074                            "Use in_bounds instead.");
4075   }
4076 
4077   if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4078     return op->emitOpError(
4079         "requires source to be a memref or ranked tensor type");
4080 
4081   auto elementType = shapedType.getElementType();
4082   DataLayout dataLayout = DataLayout::closest(op);
4083   if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4084     // Memref or tensor has vector element type.
4085     unsigned sourceVecSize =
4086         dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4087         vectorElementType.getShape().back();
4088     unsigned resultVecSize =
4089         dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4090         vectorType.getShape().back();
4091     if (resultVecSize % sourceVecSize != 0)
4092       return op->emitOpError(
4093           "requires the bitwidth of the minor 1-D vector to be an integral "
4094           "multiple of the bitwidth of the minor 1-D vector of the source");
4095 
4096     unsigned sourceVecEltRank = vectorElementType.getRank();
4097     unsigned resultVecRank = vectorType.getRank();
4098     if (sourceVecEltRank > resultVecRank)
4099       return op->emitOpError(
4100           "requires source vector element and vector result ranks to match.");
4101     unsigned rankOffset = resultVecRank - sourceVecEltRank;
4102     // Check that permutation map results match 'rankOffset' of vector type.
4103     if (permutationMap.getNumResults() != rankOffset)
4104       return op->emitOpError("requires a permutation_map with result dims of "
4105                              "the same rank as the vector type");
4106 
4107     if (maskType)
4108       return op->emitOpError("does not support masks with vector element type");
4109   } else {
4110     // Memref or tensor has scalar element type.
4111     unsigned minorSize =
4112         vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4113     unsigned resultVecSize =
4114         dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4115     if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4116       return op->emitOpError(
4117           "requires the bitwidth of the minor 1-D vector to be an integral "
4118           "multiple of the bitwidth of the source element type");
4119 
4120     // Check that permutation map results match rank of vector type.
4121     if (permutationMap.getNumResults() != vectorType.getRank())
4122       return op->emitOpError("requires a permutation_map with result dims of "
4123                              "the same rank as the vector type");
4124   }
4125 
4126   if (permutationMap.getNumSymbols() != 0)
4127     return op->emitOpError("requires permutation_map without symbols");
4128 
4129   if (permutationMap.getNumInputs() != shapedType.getRank())
4130     return op->emitOpError("requires a permutation_map with input dims of the "
4131                            "same rank as the source type");
4132 
4133   if (maskType && maskType != inferredMaskType)
4134     return op->emitOpError("inferred mask type (")
4135            << inferredMaskType << ") and mask operand type (" << maskType
4136            << ") don't match";
4137 
4138   if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4139     return op->emitOpError("expects the in_bounds attr of same rank "
4140                            "as permutation_map results: ")
4141            << AffineMapAttr::get(permutationMap)
4142            << " vs inBounds of size: " << inBounds.size();
4143 
4144   return success();
4145 }
4146 
4147 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4148   SmallVector<StringRef, 3> elidedAttrs;
4149   elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4150   if (op.getPermutationMap().isMinorIdentity())
4151     elidedAttrs.push_back(op.getPermutationMapAttrName());
4152   // Elide in_bounds attribute if all dims are out-of-bounds.
4153   if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4154     elidedAttrs.push_back(op.getInBoundsAttrName());
4155   p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4156 }
4157 
4158 void TransferReadOp::print(OpAsmPrinter &p) {
4159   p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
4160   if (getMask())
4161     p << ", " << getMask();
4162   printTransferAttrs(p, *this);
4163   p << " : " << getShapedType() << ", " << getVectorType();
4164 }
4165 
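// Illustrative example (hypothetical types): with permutation map
//   affine_map<(d0, d1, d2) -> (d2, d1)>
// and vector type vector<4x8xi32>, the inferred mask type is vector<8x4xi1>;
// mask dimensions follow the order of the accessed source dimensions.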
4166 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4167                                                  AffineMap permMap) {
4168   auto i1Type = IntegerType::get(permMap.getContext(), 1);
4169   AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4170   assert(invPermMap && "Inverse permutation map couldn't be computed");
4171   SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4172 
4173   // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4174   // 0-D mask into a single-element 1-D mask.
4175   if (maskShape.empty())
4176     maskShape.push_back(1);
4177 
4178   SmallVector<bool> scalableDims =
4179       applyPermutationMap(invPermMap, vecType.getScalableDims());
4180 
4181   return VectorType::get(maskShape, i1Type, scalableDims);
4182 }
4183 
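// The custom assembly form parsed here looks like (illustrative example with
// hypothetical SSA names):
//
//   %r = vector.transfer_read %src[%i, %j], %pad, %mask {in_bounds = [true]}
//        : memref<?x?xf32>, vector<8xf32>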
4184 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4185   auto &builder = parser.getBuilder();
4186   SMLoc typesLoc;
4187   OpAsmParser::UnresolvedOperand sourceInfo;
4188   SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4189   OpAsmParser::UnresolvedOperand paddingInfo;
4190   SmallVector<Type, 2> types;
4191   OpAsmParser::UnresolvedOperand maskInfo;
4192   // Parsing with support for paddingValue.
4193   if (parser.parseOperand(sourceInfo) ||
4194       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4195       parser.parseComma() || parser.parseOperand(paddingInfo))
4196     return failure();
4197   ParseResult hasMask = parser.parseOptionalComma();
4198   if (hasMask.succeeded()) {
4199     if (parser.parseOperand(maskInfo))
4200       return failure();
4201   }
4202   if (parser.parseOptionalAttrDict(result.attributes) ||
4203       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4204     return failure();
4205   if (types.size() != 2)
4206     return parser.emitError(typesLoc, "requires two types");
4207   auto indexType = builder.getIndexType();
4208   auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4209   if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4210     return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4211   VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4212   if (!vectorType)
4213     return parser.emitError(typesLoc, "requires vector type");
4214   auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4215   Attribute permMapAttr = result.attributes.get(permMapAttrName);
4216   AffineMap permMap;
4217   if (!permMapAttr) {
4218     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4219     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4220   } else {
4221     permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4222   }
4223   auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4224   Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4225   if (!inBoundsAttr) {
4226     result.addAttribute(inBoundsAttrName,
4227                         builder.getBoolArrayAttr(
4228                             SmallVector<bool>(permMap.getNumResults(), false)));
4229   }
4230   if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4231       parser.resolveOperands(indexInfo, indexType, result.operands) ||
4232       parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4233                             result.operands))
4234     return failure();
4235   if (hasMask.succeeded()) {
4236     if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4237       return parser.emitError(
4238           maskInfo.location, "does not support masks with vector element type");
4239     if (vectorType.getRank() != permMap.getNumResults()) {
4240       return parser.emitError(typesLoc,
4241                               "expected the same rank for the vector and the "
4242                               "results of the permutation map");
4243     }
4244     // Instead of adding the mask type as an op type, compute it based on the
4245     // vector type and the permutation map (to keep the type signature small).
4246     auto maskType = inferTransferOpMaskType(vectorType, permMap);
4247     if (parser.resolveOperand(maskInfo, maskType, result.operands))
4248       return failure();
4249   }
4250   result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4251                       builder.getDenseI32ArrayAttr(
4252                           {1, static_cast<int32_t>(indexInfo.size()), 1,
4253                            static_cast<int32_t>(hasMask.succeeded())}));
4254   return parser.addTypeToList(vectorType, result.types);
4255 }
4256 
4257 LogicalResult TransferReadOp::verify() {
4258   // Consistency of elemental types in source and vector.
4259   ShapedType shapedType = getShapedType();
4260   VectorType vectorType = getVectorType();
4261   VectorType maskType = getMaskType();
4262   auto paddingType = getPadding().getType();
4263   auto permutationMap = getPermutationMap();
4264   VectorType inferredMaskType =
4265       maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4266                : VectorType();
4267   auto sourceElementType = shapedType.getElementType();
4268 
4269   if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4270     return emitOpError("requires ") << shapedType.getRank() << " indices";
4271 
4272   if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4273                               shapedType, vectorType, maskType,
4274                               inferredMaskType, permutationMap, getInBounds())))
4275     return failure();
4276 
4277   if (auto sourceVectorElementType =
4278           llvm::dyn_cast<VectorType>(sourceElementType)) {
4279     // Source has vector element type.
4280     // Check that 'sourceVectorElementType' and 'paddingType' types match.
4281     if (sourceVectorElementType != paddingType)
4282       return emitOpError(
4283           "requires source element type and padding type to match.");
4284 
4285   } else {
4286     // Check that 'paddingType' is valid to store in a vector type.
4287     if (!VectorType::isValidElementType(paddingType))
4288       return emitOpError("requires valid padding vector elemental type");
4289 
4290     // Check that padding type and vector element types match.
4291     if (paddingType != sourceElementType)
4292       return emitOpError(
4293           "requires formal padding and source of the same elemental type");
4294   }
4295 
4296   return verifyPermutationMap(permutationMap,
4297                               [&](Twine t) { return emitOpError(t); });
4298 }
4299 
4300 // MaskableOpInterface methods.
4301 
4302 /// Returns the mask type expected by this operation. Mostly used for
4303 /// verification purposes. It requires the operation to be vectorized.
4304 Type TransferReadOp::getExpectedMaskType() {
4305   return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4306 }
4307 
4308 template <typename TransferOp>
4309 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4310   // TODO: support more aggressive createOrFold on:
4311   // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4312   if (op.getShapedType().isDynamicDim(indicesIdx))
4313     return false;
4314   Value index = op.getIndices()[indicesIdx];
4315   std::optional<int64_t> cstOp = getConstantIntValue(index);
4316   if (!cstOp.has_value())
4317     return false;
4318 
4319   int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4320   int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4321 
4322   return cstOp.value() + vectorSize <= sourceSize;
4323 }
4324 
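// Illustrative example (hypothetical IR): given
//
//   %r = vector.transfer_read %m[%c0], %pad : memref<8xf32>, vector<4xf32>
//
// the in_bounds attribute is tightened to {in_bounds = [true]} because
// 0 + 4 <= 8 is statically known.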
4325 template <typename TransferOp>
4326 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4327   // TODO: support 0-d corner case.
4328   // TODO: Be less conservative.
4329   if (op.getTransferRank() == 0)
4330     return failure();
4331   AffineMap permutationMap = op.getPermutationMap();
4332   bool changed = false;
4333   SmallVector<bool, 4> newInBounds;
4334   newInBounds.reserve(op.getTransferRank());
4335   // Indices of non-broadcast dims - used when analyzing broadcast dims.
4336   SmallVector<unsigned> nonBcastDims;
4337 
4338   // 1. Process non-broadcast dims
4339   for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4340     // 1.1. Already marked as in-bounds, nothing to see here.
4341     if (op.isDimInBounds(i)) {
4342       newInBounds.push_back(true);
4343       continue;
4344     }
4345     // 1.2. Currently out-of-bounds, check whether we can statically determine
4346     // it is inBounds.
4347     bool inBounds = false;
4348     auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4349     if (dimExpr) {
4350       inBounds = isInBounds(op, /*resultIdx=*/i,
4351                             /*indicesIdx=*/dimExpr.getPosition());
4352       nonBcastDims.push_back(i);
4353     }
4354 
4355     newInBounds.push_back(inBounds);
4356     // We commit the pattern if it is "more inbounds".
4357     changed |= inBounds;
4358   }
4359 
4360   // 2. Handle broadcast dims
4361   // If all non-broadcast dims are "in bounds", then all bcast dims should be
4362   // "in bounds" as well.
4363   bool allNonBcastDimsInBounds = llvm::all_of(
4364       nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4365   if (allNonBcastDimsInBounds) {
4366     for (size_t idx : permutationMap.getBroadcastDims()) {
4367       changed |= !newInBounds[idx];
4368       newInBounds[idx] = true;
4369     }
4370   }
4371 
4372   if (!changed)
4373     return failure();
4374   // OpBuilder is only used as a helper to build a BoolArrayAttr.
4375   OpBuilder b(op.getContext());
4376   op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4377   return success();
4378 }
4379 
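// Illustrative example (hypothetical IR): a transfer whose mask is
//
//   %mask = vector.constant_mask [4] : vector<4xi1>
//
// is all-true, so the mask operand is simply dropped.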
4380 template <typename TransferOp>
4381 static LogicalResult foldTransferFullMask(TransferOp op) {
4382   auto mask = op.getMask();
4383   if (!mask)
4384     return failure();
4385 
4386   if (getMaskFormat(mask) != MaskFormat::AllTrue)
4387     return failure();
4388 
4389   op.getMaskMutable().clear();
4390   return success();
4391 }
4392 
4393 /// Fold a read-after-write (RAW): a transfer_read that reads back the value
/// just written by a matching transfer_write folds to that vector:
///  ```
4394 ///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4395 ///    : vector<1x4xf32>, tensor<4x4xf32>
4396 ///  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4397 ///    : tensor<4x4xf32>, vector<1x4xf32>
4398 ///  ```
4399 ///  -> Folds into
4400 ///  ```
4401 ///  %v0
4402 ///  ```
4403 static Value foldRAW(TransferReadOp readOp) {
4404   if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4405     return {};
4406   auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4407   while (defWrite) {
4408     if (checkSameValueRAW(defWrite, readOp))
4409       return defWrite.getVector();
4410     if (!isDisjointTransferIndices(
4411             cast<VectorTransferOpInterface>(defWrite.getOperation()),
4412             cast<VectorTransferOpInterface>(readOp.getOperation())))
4413       break;
4414     defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4415   }
4416   return {};
4417 }
4418 
4419 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4420   if (Value vec = foldRAW(*this))
4421     return vec;
4422   if (succeeded(foldTransferInBoundsAttribute(*this)))
4423     return getResult();
4424   if (succeeded(foldTransferFullMask(*this)))
4425     return getResult();
4426   // transfer_read(memrefcast) -> transfer_read
4427   if (succeeded(memref::foldMemRefCast(*this)))
4428     return getResult();
4429   if (succeeded(tensor::foldTensorCast(*this)))
4430     return getResult();
4431   return OpFoldResult();
4432 }
4433 
4434 std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4435   return llvm::to_vector<4>(getVectorType().getShape());
4436 }
4437 
4438 void TransferReadOp::getEffects(
4439     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4440         &effects) {
4441   if (llvm::isa<MemRefType>(getShapedType()))
4442     effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
4443                          SideEffects::DefaultResource::get());
4444 }
4445 
4446 Speculation::Speculatability TransferReadOp::getSpeculatability() {
4447   if (hasPureTensorSemantics())
4448     return Speculation::Speculatable;
4449   return Speculation::NotSpeculatable;
4450 }
4451 
4452 namespace {
4453 /// Store to load forwarding for transfer operations with permutation maps.
4454 /// Even if the permutation maps are different we can still propagate the store
4455 /// into the load if the size of the dimensions read and written match. Then we
4456 /// can replace the transfer_read + transfer_write by vector.broadcast and
4457 /// vector.transpose.
4458 /// Example:
4459 /// ```
4460 /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4461 ///  {in_bounds = [true, true],
4462 ///   permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4463 ///   vector<4x1xf32>, tensor<4x4x4xf32>
4464 ///  %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4465 ///   {in_bounds = [true, true, true, true],
4466 ///   permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4467 ///   tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4468 /// ```
4469 /// To:
4470 /// ```
4471 /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
4472 /// %r = vector.transpose %0, [3, 0, 2, 1] :
4473 ///   vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4474 /// ```
4475 struct TransferReadAfterWriteToBroadcast
4476     : public OpRewritePattern<TransferReadOp> {
4477   using OpRewritePattern::OpRewritePattern;
4478 
4479   LogicalResult matchAndRewrite(TransferReadOp readOp,
4480                                 PatternRewriter &rewriter) const override {
4481     if (readOp.hasOutOfBoundsDim() ||
4482         !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4483       return failure();
4484     auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4485     if (!defWrite)
4486       return failure();
4487     // TODO: If the written transfer chunk is a superset of the read transfer
4488     // chunk we could do an extract_strided_slice.
4489     if (readOp.getTransferChunkAccessed() !=
4490         defWrite.getTransferChunkAccessed())
4491       return failure();
4492     // TODO: Support cases where a dim is explicitly written but implicitly
4493     // read (i.e., a unit dim that is rank reduced).
4494     if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4495         getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4496       return failure();
4497     if (readOp.getIndices() != defWrite.getIndices() ||
4498         readOp.getMask() != defWrite.getMask())
4499       return failure();
4500     Value vec = defWrite.getVector();
4501     // TODO: loop through the chain of transfer_write if we can prove that they
4502     // don't overlap with the transfer_read. This requires improving
4503     // `isDisjointTransferIndices` helper.
4504     AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4505     AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4506     AffineMap map = readMap.compose(writeMap);
4507     if (map.getNumResults() == 0)
4508       return failure();
4509     // Calculate the permutation to apply to go from the vector stored to the
4510     // vector read.
4511     SmallVector<unsigned> permutation;
4512     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4513       return failure();
4514 
4515     Location loc = readOp.getLoc();
4516     // Calculate the broadcast shape by applying the reverse permutation to the
4517     // final shape we want.
4518     ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4519     SmallVector<int64_t> broadcastShape(destShape.size());
4520     SmallVector<bool> broadcastScalableFlags(destShape.size());
4521     for (const auto &pos : llvm::enumerate(permutation)) {
4522       broadcastShape[pos.value()] = destShape[pos.index()];
4523       broadcastScalableFlags[pos.value()] =
4524           readOp.getVectorType().getScalableDims()[pos.index()];
4525     }
4526     VectorType broadcastedType = VectorType::get(
4527         broadcastShape, defWrite.getVectorType().getElementType(),
4528         broadcastScalableFlags);
4529     vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4530     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4531     rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4532                                                      transposePerm);
4533     return success();
4534   }
4535 };
4536 } // namespace
4537 
4538 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4539                                                  MLIRContext *context) {
4540   results.add<TransferReadAfterWriteToBroadcast>(context);
4541 }
4542 
4543 //===----------------------------------------------------------------------===//
4544 // TransferWriteOp
4545 //===----------------------------------------------------------------------===//
4546 
4547 /// 1. Builder with type inference.
4548 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4549                             Value vector, Value dest, ValueRange indices,
4550                             AffineMapAttr permutationMapAttr,
4551                             /*optional*/ Value mask,
4552                             /*optional*/ ArrayAttr inBoundsAttr) {
4553   Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4554   build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4555         mask, inBoundsAttr);
4556 }
4557 
4558 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
4559 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4560                             Value vector, Value dest, ValueRange indices,
4561                             AffineMapAttr permutationMapAttr,
4562                             /*optional*/ ArrayAttr inBoundsAttr) {
4563   build(builder, result, vector, dest, indices, permutationMapAttr,
4564         /*mask=*/Value(), inBoundsAttr);
4565 }
4566 
4567 /// 3. Builder with type inference that sets an empty mask (variant without
4568 /// attrs).
4569 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4570                             Value vector, Value dest, ValueRange indices,
4571                             AffineMap permutationMap,
4572                             std::optional<ArrayRef<bool>> inBounds) {
4573   auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4574   auto inBoundsAttr =
4575       (inBounds && !inBounds.value().empty())
4576           ? builder.getBoolArrayAttr(inBounds.value())
4577           : builder.getBoolArrayAttr(SmallVector<bool>(
4578                 llvm::cast<VectorType>(vector.getType()).getRank(), false));
4579   build(builder, result, vector, dest, indices, permutationMapAttr,
4580         /*mask=*/Value(), inBoundsAttr);
4581 }
4582 
4583 /// 4. Builder with type inference that sets an empty mask and sets permutation
4584 ///    map to 'getMinorIdentityMap'.
4585 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4586                             Value vector, Value dest, ValueRange indices,
4587                             std::optional<ArrayRef<bool>> inBounds) {
4588   auto vectorType = llvm::cast<VectorType>(vector.getType());
4589   AffineMap permutationMap = getTransferMinorIdentityMap(
4590       llvm::cast<ShapedType>(dest.getType()), vectorType);
4591   build(builder, result, vector, dest, indices, permutationMap, inBounds);
4592 }
4593 
4594 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4595                                    OperationState &result) {
4596   auto &builder = parser.getBuilder();
4597   SMLoc typesLoc;
4598   OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4599   SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4600   SmallVector<Type, 2> types;
4601   OpAsmParser::UnresolvedOperand maskInfo;
4602   if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4603       parser.parseOperand(sourceInfo) ||
4604       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4605     return failure();
4606   ParseResult hasMask = parser.parseOptionalComma();
4607   if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4608     return failure();
4609   if (parser.parseOptionalAttrDict(result.attributes) ||
4610       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4611     return failure();
4612   if (types.size() != 2)
4613     return parser.emitError(typesLoc, "requires two types");
4614   auto indexType = builder.getIndexType();
4615   VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4616   if (!vectorType)
4617     return parser.emitError(typesLoc, "requires vector type");
4618   ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4619   if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4620     return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4621   auto permMapAttrName =
4622       TransferWriteOp::getPermutationMapAttrName(result.name);
4623   auto permMapAttr = result.attributes.get(permMapAttrName);
4624   AffineMap permMap;
4625   if (!permMapAttr) {
4626     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4627     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4628   } else {
4629     permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4630   }
4631   auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
4632   Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4633   if (!inBoundsAttr) {
4634     result.addAttribute(inBoundsAttrName,
4635                         builder.getBoolArrayAttr(
4636                             SmallVector<bool>(permMap.getNumResults(), false)));
4637   }
4638   if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4639       parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4640       parser.resolveOperands(indexInfo, indexType, result.operands))
4641     return failure();
4642   if (hasMask.succeeded()) {
4643     if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4644       return parser.emitError(
4645           maskInfo.location, "does not support masks with vector element type");
4646     if (vectorType.getRank() != permMap.getNumResults()) {
4647       return parser.emitError(typesLoc,
4648                               "expected the same rank for the vector and the "
4649                               "results of the permutation map");
4650     }
4651     auto maskType = inferTransferOpMaskType(vectorType, permMap);
4652     if (parser.resolveOperand(maskInfo, maskType, result.operands))
4653       return failure();
4654   }
4655   result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4656                       builder.getDenseI32ArrayAttr(
4657                           {1, 1, static_cast<int32_t>(indexInfo.size()),
4658                            static_cast<int32_t>(hasMask.succeeded())}));
4659   return failure(llvm::isa<RankedTensorType>(shapedType) &&
4660                  parser.addTypeToList(shapedType, result.types));
4661 }
4662 
4663 void TransferWriteOp::print(OpAsmPrinter &p) {
4664   p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
4665   if (getMask())
4666     p << ", " << getMask();
4667   printTransferAttrs(p, *this);
4668   p << " : " << getVectorType() << ", " << getShapedType();
4669 }
4670 
4671 LogicalResult TransferWriteOp::verify() {
4672   // Consistency of elemental types in shape and vector.
4673   ShapedType shapedType = getShapedType();
4674   VectorType vectorType = getVectorType();
4675   VectorType maskType = getMaskType();
4676   auto permutationMap = getPermutationMap();
4677   VectorType inferredMaskType =
4678       maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4679                : VectorType();
4680 
4681   if (llvm::size(getIndices()) != shapedType.getRank())
4682     return emitOpError("requires ") << shapedType.getRank() << " indices";
4683 
4684   // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4685   // as the semantics is unclear. This can be revisited later if necessary.
4686   if (hasBroadcastDim())
4687     return emitOpError("should not have broadcast dimensions");
4688 
4689   if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4690                               shapedType, vectorType, maskType,
4691                               inferredMaskType, permutationMap, getInBounds())))
4692     return failure();
4693 
4694   return verifyPermutationMap(permutationMap,
4695                               [&](Twine t) { return emitOpError(t); });
4696 }
4697 
4698 // MaskableOpInterface methods.
4699 
4700 /// Returns the mask type expected by this operation. Mostly used for
4701 /// verification purposes.
4702 Type TransferWriteOp::getExpectedMaskType() {
4703   return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4704 }
4705 
4706 /// Fold:
4707 /// ```
4708 ///    %t1 = ...
4709 ///    %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4710 ///      tensor<static_sizesxf32>, vector<static_sizesxf32>
4711 ///    %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4712 ///      vector<static_sizesxf32>, tensor<static_sizesxf32>
4713 /// ```
4714 ///
4715 /// into:
4716 ///
4717 /// ```
4718 ///    %t0
4719 /// ```
4720 ///
4721 /// The producer of t1 may or may not be DCE'd depending on whether it is a
4722 /// block argument or has side effects.
4723 static LogicalResult foldReadInitWrite(TransferWriteOp write,
4724                                        ArrayRef<Attribute>,
4725                                        SmallVectorImpl<OpFoldResult> &results) {
4726   // TODO: support 0-d corner case.
4727   if (write.getTransferRank() == 0)
4728     return failure();
4729   auto rankedTensorType =
4730       llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4731   // If not operating on tensors, bail.
4732   if (!rankedTensorType)
4733     return failure();
4734   // If no read, bail.
4735   auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4736   if (!read)
4737     return failure();
4738   // TODO: support 0-d corner case.
4739   if (read.getTransferRank() == 0)
4740     return failure();
4741   // For now, only accept minor identity maps. Future: accept maps whose
4742   // composition is a minor identity.
4742   if (!read.getPermutationMap().isMinorIdentity() ||
4743       !write.getPermutationMap().isMinorIdentity())
4744     return failure();
4745   // Bail on mismatching ranks.
4746   if (read.getTransferRank() != write.getTransferRank())
4747     return failure();
4748   // Bail on potential out-of-bounds accesses.
4749   if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4750     return failure();
4751   // Tensor types must be the same.
4752   if (read.getSource().getType() != rankedTensorType)
4753     return failure();
4754   // Vector types must be the same.
4755   if (read.getVectorType() != write.getVectorType())
4756     return failure();
4757   // Vector and Tensor shapes must match.
4758   if (read.getVectorType().getShape() != rankedTensorType.getShape())
4759     return failure();
4760   // Bail if any read or write index is nonzero.
4761   auto isNotConstantZero = [](Value v) {
4762     auto cstOp = getConstantIntValue(v);
4763     return !cstOp.has_value() || cstOp.value() != 0;
4764   };
4765   if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4766       llvm::any_of(write.getIndices(), isNotConstantZero))
4767     return failure();
4768   // Success.
4769   results.push_back(read.getSource());
4770   return success();
4771 }
4772 
4773 static bool checkSameValueWAR(vector::TransferReadOp read,
4774                               vector::TransferWriteOp write) {
4775   return read.getSource() == write.getSource() &&
4776          read.getIndices() == write.getIndices() &&
4777          read.getPermutationMap() == write.getPermutationMap() &&
4778          read.getVectorType() == write.getVectorType() && !read.getMask() &&
4779          !write.getMask();
4780 }
4781 /// Fold transfer_write write after read:
4782 /// ```
4783 ///    %t0 = ...
4784 ///    %v = vector.transfer_read %t0[%c0...] :
4785 ///      tensor<static_sizesxf32>, vector<static_sizesxf32>
4786 ///    %t1 = vector.transfer_write %v, %t0[%c0...] :
4787 ///      vector<static_sizesxf32>, tensor<static_sizesxf32>
4788 /// ```
4789 ///
4790 /// into:
4791 ///
4792 /// ```
4793 ///    %t0
4794 /// ```
4795 static LogicalResult foldWAR(TransferWriteOp write,
4796                              SmallVectorImpl<OpFoldResult> &results) {
4797   if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4798     return failure();
4799   auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4800   if (!read)
4801     return failure();
4802 
4803   if (!checkSameValueWAR(read, write))
4804     return failure();
4805   results.push_back(read.getSource());
4806   return success();
4807 }
4808 
4809 LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4810                                     SmallVectorImpl<OpFoldResult> &results) {
4811   if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4812     return success();
4813   if (succeeded(foldWAR(*this, results)))
4814     return success();
4815   if (succeeded(foldTransferInBoundsAttribute(*this)))
4816     return success();
4817   if (succeeded(foldTransferFullMask(*this)))
4818     return success();
4819   return memref::foldMemRefCast(*this);
4820 }
4821 
4822 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4823   return llvm::to_vector<4>(getVectorType().getShape());
4824 }
4825 
4826 void TransferWriteOp::getEffects(
4827     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4828         &effects) {
4829   if (llvm::isa<MemRefType>(getShapedType()))
4830     effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
4831                          SideEffects::DefaultResource::get());
4832 }
4833 
4834 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
4835   if (hasPureTensorSemantics())
4836     return Speculation::Speculatable;
4837   return Speculation::NotSpeculatable;
4838 }
4839 
4840 namespace {
4841 /// Remove a dead transfer write from the SSA chain so that it can be
4842 /// eliminated by DCE:
4843 /// ```
4844 ///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4845 ///    : vector<1x4xf32>, tensor<4x4xf32>
4846 ///  %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4847 ///    : vector<1x4xf32>, tensor<4x4xf32>
4848 ///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4849 ///    : vector<1x4xf32>, tensor<4x4xf32>
4850 /// ```
4851 ///
4852 /// into:
4853 ///
4854 /// ```
4855 ///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4856 ///    : vector<1x4xf32>, tensor<4x4xf32>
4857 ///  %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4858 ///    : vector<1x4xf32>, tensor<4x4xf32>
4859 ///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4860 ///    : vector<1x4xf32>, tensor<4x4xf32>
4861 /// ```
4862 ///
4863 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4864 /// any other uses.
4865 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4866 public:
4867   using OpRewritePattern::OpRewritePattern;
4868   LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4869                                 PatternRewriter &rewriter) const override {
4870     if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4871       return failure();
4872     vector::TransferWriteOp writeToModify = writeOp;
4873 
4874     auto defWrite =
4875         writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4876     while (defWrite) {
4877       if (checkSameValueWAW(writeOp, defWrite)) {
4878         rewriter.modifyOpInPlace(writeToModify, [&]() {
4879           writeToModify.getSourceMutable().assign(defWrite.getSource());
4880         });
4881         return success();
4882       }
4883       if (!isDisjointTransferIndices(
4884               cast<VectorTransferOpInterface>(defWrite.getOperation()),
4885               cast<VectorTransferOpInterface>(writeOp.getOperation())))
4886         break;
4887       // If the previous write op doesn't have any other use, we can safely look
4888       // at the previous store to see if it can be removed.
4889       if (!defWrite->hasOneUse())
4890         break;
4891       writeToModify = defWrite;
4892       defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4893     }
4894     return failure();
4895   }
4896 };
4897 
4898 /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4899 /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4900 /// overwritten and inserted into another tensor. After this rewrite, the
4901 /// operations bufferize in-place since all of them work on the same slice.
4902 ///
4903 /// For example:
4904 /// ```mlir
4905 ///   %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4906 ///        : vector<8x16xf32>, tensor<8x16xf32>
4907 ///   %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4908 ///        : tensor<8x16xf32> to tensor<?x?xf32>
4909 ///   %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4910 ///        : tensor<?x?xf32> into tensor<27x37xf32>
4911 /// ```
4912 /// folds to
4913 /// ```mlir
4914 ///   %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4915 ///        : tensor<27x37xf32> to tensor<?x?xf32>
4916 ///   %1 = vector.transfer_write %vec, %0[%c0, %c0]
4917 ///        : vector<8x16xf32>, tensor<?x?xf32>
4918 ///   %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4919 ///        : tensor<?x?xf32> into tensor<27x37xf32>
4920 /// ```
4921 struct SwapExtractSliceOfTransferWrite
4922     : public OpRewritePattern<tensor::InsertSliceOp> {
4923 public:
4924   using OpRewritePattern::OpRewritePattern;
4925 
4926   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4927                                 PatternRewriter &rewriter) const override {
4928     if (!insertOp.hasUnitStride())
4929       return failure();
4930     auto extractOp =
4931         insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4932     if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4933       return failure();
4934     auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4935     if (!transferOp || !transferOp->hasOneUse())
4936       return failure();
4937 
4938     // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4939     // rank-reducing.
4940     if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4941       return rewriter.notifyMatchFailure(insertOp,
4942                                          "use-def chain is rank-reducing");
4943     }
4944 
4945     // Fail if tensor::ExtractSliceOp has non-zero offset.
4946     if (!extractOp.hasZeroOffset()) {
4947       return rewriter.notifyMatchFailure(insertOp,
4948                                          "ExtractSliceOp has non-zero offset");
4949     }
4950 
4951     // Fail if vector::TransferWriteOp has non-zero offset.
4952     if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4953           return getConstantIntValue(value) == static_cast<int64_t>(0);
4954         })) {
4955       return rewriter.notifyMatchFailure(insertOp,
4956                                          "TranferWriteOp has non-zero offset");
4957     }
4958 
4959     // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
4960     if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4961       return rewriter.notifyMatchFailure(
4962           insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
4963     }
4964 
4965     for (auto [insertSize, extractSize] :
4966          llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4967       if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
4968         return rewriter.notifyMatchFailure(
4969             insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
4970       }
4971     }
4972 
4973     // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
4974     assert(transferOp.getVectorType().hasStaticShape() &&
4975            "expected vector to have a static shape");
4976     ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
4977     SmallVector<int64_t> resultShape = applyPermutationMap(
4978         transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4979     if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
4980       return rewriter.notifyMatchFailure(
4981           insertOp, "TransferWriteOp may not write the full tensor.");
4982     }
4983 
4984     // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
4985     // Set all in_bounds to false and let the folder infer them.
4986     SmallVector<bool> newInBounds(vectorShape.size(), false);
4987     auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
4988         extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4989         insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4990         insertOp.getMixedStrides());
4991     auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
4992         transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4993         transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4994         rewriter.getBoolArrayAttr(newInBounds));
4995     rewriter.modifyOpInPlace(insertOp, [&]() {
4996       insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4997     });
4998     return success();
4999   }
5000 };
5001 
5002 } // namespace
5003 
5004 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5005                                                   MLIRContext *context) {
5006   results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5007 }
5008 
5009 //===----------------------------------------------------------------------===//
5010 // LoadOp
5011 //===----------------------------------------------------------------------===//
5012 
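/// Verifies the memref layout requirements shared by vector.load and
/// vector.store: unless the vector is equivalent to a single scalar element,
/// the most minor memref dimension must have unit stride. For example
/// (illustrative), loading a vector<4xf32> from
/// memref<8x4xf32, strided<[8, 2]>> is rejected because the innermost stride
/// is 2.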
5013 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5014                                                  VectorType vecTy,
5015                                                  MemRefType memRefTy) {
5016   // If rank==0 or size==1 it's equivalent to a scalar load/store, so we don't
5017   // need any stride limitations.
5018   if (!vecTy.isScalable() &&
5019       (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5020     return success();
5021 
5022   if (!memRefTy.isLastDimUnitStride())
5023     return op->emitOpError("most minor memref dim must have unit stride");
5024   return success();
5025 }
5026 
5027 LogicalResult vector::LoadOp::verify() {
5028   VectorType resVecTy = getVectorType();
5029   MemRefType memRefTy = getMemRefType();
5030 
5031   if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5032     return failure();
5033 
5034   // Checks for vector memrefs.
5035   Type memElemTy = memRefTy.getElementType();
5036   if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5037     if (memVecTy != resVecTy)
5038       return emitOpError("base memref and result vector types should match");
5039     memElemTy = memVecTy.getElementType();
5040   }
5041 
5042   if (resVecTy.getElementType() != memElemTy)
5043     return emitOpError("base and result element types should match");
5044   if (llvm::size(getIndices()) != memRefTy.getRank())
5045     return emitOpError("requires ") << memRefTy.getRank() << " indices";
5046   return success();
5047 }
5048 
5049 OpFoldResult LoadOp::fold(FoldAdaptor) {
5050   if (succeeded(memref::foldMemRefCast(*this)))
5051     return getResult();
5052   return OpFoldResult();
5053 }
5054 
5055 //===----------------------------------------------------------------------===//
5056 // StoreOp
5057 //===----------------------------------------------------------------------===//
5058 
5059 LogicalResult vector::StoreOp::verify() {
5060   VectorType valueVecTy = getVectorType();
5061   MemRefType memRefTy = getMemRefType();
5062 
5063   if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5064     return failure();
5065 
5066   // Checks for vector memrefs.
5067   Type memElemTy = memRefTy.getElementType();
5068   if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5069     if (memVecTy != valueVecTy)
5070       return emitOpError(
5071           "base memref and valueToStore vector types should match");
5072     memElemTy = memVecTy.getElementType();
5073   }
5074 
5075   if (valueVecTy.getElementType() != memElemTy)
5076     return emitOpError("base and valueToStore element type should match");
5077   if (llvm::size(getIndices()) != memRefTy.getRank())
5078     return emitOpError("requires ") << memRefTy.getRank() << " indices";
5079   return success();
5080 }
5081 
5082 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5083                             SmallVectorImpl<OpFoldResult> &results) {
5084   return memref::foldMemRefCast(*this);
5085 }
5086 
5087 //===----------------------------------------------------------------------===//
5088 // MaskedLoadOp
5089 //===----------------------------------------------------------------------===//
5090 
5091 LogicalResult MaskedLoadOp::verify() {
5092   VectorType maskVType = getMaskVectorType();
5093   VectorType passVType = getPassThruVectorType();
5094   VectorType resVType = getVectorType();
5095   MemRefType memType = getMemRefType();
5096 
5097   if (resVType.getElementType() != memType.getElementType())
5098     return emitOpError("base and result element type should match");
5099   if (llvm::size(getIndices()) != memType.getRank())
5100     return emitOpError("requires ") << memType.getRank() << " indices";
5101   if (resVType.getShape() != maskVType.getShape())
5102     return emitOpError("expected result shape to match mask shape");
5103   if (resVType != passVType)
5104     return emitOpError("expected pass_thru of same type as result type");
5105   return success();
5106 }
5107 
5108 namespace {
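/// Folds a MaskedLoadOp whose mask is a compile-time constant:
///   * all-true mask:  replaced by an unmasked vector.load;
///   * all-false mask: replaced by its pass_thru value.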
5109 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5110 public:
5111   using OpRewritePattern::OpRewritePattern;
5112   LogicalResult matchAndRewrite(MaskedLoadOp load,
5113                                 PatternRewriter &rewriter) const override {
5114     switch (getMaskFormat(load.getMask())) {
5115     case MaskFormat::AllTrue:
5116       rewriter.replaceOpWithNewOp<vector::LoadOp>(
5117           load, load.getType(), load.getBase(), load.getIndices());
5118       return success();
5119     case MaskFormat::AllFalse:
5120       rewriter.replaceOp(load, load.getPassThru());
5121       return success();
5122     case MaskFormat::Unknown:
5123       return failure();
5124     }
5125     llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5126   }
5127 };
5128 } // namespace
5129 
5130 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5131                                                MLIRContext *context) {
5132   results.add<MaskedLoadFolder>(context);
5133 }
5134 
5135 OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5136   if (succeeded(memref::foldMemRefCast(*this)))
5137     return getResult();
5138   return OpFoldResult();
5139 }
5140 
5141 //===----------------------------------------------------------------------===//
5142 // MaskedStoreOp
5143 //===----------------------------------------------------------------------===//
5144 
5145 LogicalResult MaskedStoreOp::verify() {
5146   VectorType maskVType = getMaskVectorType();
5147   VectorType valueVType = getVectorType();
5148   MemRefType memType = getMemRefType();
5149 
5150   if (valueVType.getElementType() != memType.getElementType())
5151     return emitOpError("base and valueToStore element type should match");
5152   if (llvm::size(getIndices()) != memType.getRank())
5153     return emitOpError("requires ") << memType.getRank() << " indices";
5154   if (valueVType.getShape() != maskVType.getShape())
5155     return emitOpError("expected valueToStore shape to match mask shape");
5156   return success();
5157 }
5158 
5159 namespace {
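/// Folds a MaskedStoreOp whose mask is a compile-time constant:
///   * all-true mask:  replaced by an unmasked vector.store;
///   * all-false mask: erased, since no elements are written.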
5160 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5161 public:
5162   using OpRewritePattern::OpRewritePattern;
5163   LogicalResult matchAndRewrite(MaskedStoreOp store,
5164                                 PatternRewriter &rewriter) const override {
5165     switch (getMaskFormat(store.getMask())) {
5166     case MaskFormat::AllTrue:
5167       rewriter.replaceOpWithNewOp<vector::StoreOp>(
5168           store, store.getValueToStore(), store.getBase(), store.getIndices());
5169       return success();
5170     case MaskFormat::AllFalse:
5171       rewriter.eraseOp(store);
5172       return success();
5173     case MaskFormat::Unknown:
5174       return failure();
5175     }
5176     llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5177   }
5178 };
5179 } // namespace
5180 
5181 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5182                                                 MLIRContext *context) {
5183   results.add<MaskedStoreFolder>(context);
5184 }
5185 
5186 LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5187                                   SmallVectorImpl<OpFoldResult> &results) {
5188   return memref::foldMemRefCast(*this);
5189 }
5190 
5191 //===----------------------------------------------------------------------===//
5192 // GatherOp
5193 //===----------------------------------------------------------------------===//
5194 
5195 LogicalResult GatherOp::verify() {
5196   VectorType indVType = getIndexVectorType();
5197   VectorType maskVType = getMaskVectorType();
5198   VectorType resVType = getVectorType();
5199   ShapedType baseType = getBaseType();
5200 
5201   if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5202     return emitOpError("requires base to be a memref or ranked tensor type");
5203 
5204   if (resVType.getElementType() != baseType.getElementType())
5205     return emitOpError("base and result element type should match");
5206   if (llvm::size(getIndices()) != baseType.getRank())
5207     return emitOpError("requires ") << baseType.getRank() << " indices";
5208   if (resVType.getShape() != indVType.getShape())
5209     return emitOpError("expected result dim to match indices dim");
5210   if (resVType.getShape() != maskVType.getShape())
5211     return emitOpError("expected result dim to match mask dim");
5212   if (resVType != getPassThruVectorType())
5213     return emitOpError("expected pass_thru of same type as result type");
5214   return success();
5215 }
5216 
5217 // MaskableOpInterface methods.
5218 
5219 /// Returns the mask type expected by this operation. Mostly used for
5220 /// verification purposes. It requires the operation to be vectorized.
5221 Type GatherOp::getExpectedMaskType() {
5222   auto vecType = this->getIndexVectorType();
5223   return VectorType::get(vecType.getShape(),
5224                          IntegerType::get(vecType.getContext(), /*width=*/1),
5225                          vecType.getScalableDims());
5226 }
5227 
5228 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5229   return llvm::to_vector<4>(getVectorType().getShape());
5230 }
5231 
5232 /// Check if `indexVec` is a constant 1-D vector of consecutive values [0, 1, 2, ...]
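/// (e.g. the result of `vector.step` or an illustrative `arith.constant` with
/// value `dense<[0, 1, 2, 3]>`).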
5233 static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5234   auto vecType = dyn_cast<VectorType>(indexVec.getType());
5235   if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5236     return failure();
5237 
5238   if (indexVec.getDefiningOp<StepOp>())
5239     return success();
5240 
5241   DenseIntElementsAttr elements;
5242   if (!matchPattern(indexVec, m_Constant(&elements)))
5243     return failure();
5244 
5245   return success(
5246       llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5247 }
5248 
5249 namespace {
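/// Folds a GatherOp whose mask is all-false to its pass_thru value. An
/// all-true mask is left untouched, as there is no unmasked gather
/// equivalent.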
5250 class GatherFolder final : public OpRewritePattern<GatherOp> {
5251 public:
5252   using OpRewritePattern::OpRewritePattern;
5253   LogicalResult matchAndRewrite(GatherOp gather,
5254                                 PatternRewriter &rewriter) const override {
5255     switch (getMaskFormat(gather.getMask())) {
5256     case MaskFormat::AllTrue:
5257       return failure(); // no unmasked equivalent
5258     case MaskFormat::AllFalse:
5259       rewriter.replaceOp(gather, gather.getPassThru());
5260       return success();
5261     case MaskFormat::Unknown:
5262       return failure();
5263     }
5264     llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5265   }
5266 };
5267 
5268 /// Fold gathers with consecutive offsets [0, 1, 2, ...] into a contiguous
5269 /// maskedload. Only 1-D fixed vectors are supported for now.
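///
/// For example (illustrative; SSA names are placeholders):
///
///   %idxs = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
///   %0 = vector.gather %base[%c0][%idxs], %mask, %pass_thru
///       : memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
///         into vector<4xf32>
/// becomes:
///   %0 = vector.maskedload %base[%c0], %mask, %pass_thru
///       : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>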
5270 class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
5271 public:
5272   using OpRewritePattern::OpRewritePattern;
5273   LogicalResult matchAndRewrite(GatherOp op,
5274                                 PatternRewriter &rewriter) const override {
5275     if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5276       return failure();
5277 
5278     rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
5279                                               op.getIndices(), op.getMask(),
5280                                               op.getPassThru());
5281     return success();
5282   }
5283 };
5284 } // namespace
5285 
5286 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5287                                            MLIRContext *context) {
5288   results.add<GatherFolder, FoldContiguousGather>(context);
5289 }
5290 
5291 //===----------------------------------------------------------------------===//
5292 // ScatterOp
5293 //===----------------------------------------------------------------------===//
5294 
5295 LogicalResult ScatterOp::verify() {
5296   VectorType indVType = getIndexVectorType();
5297   VectorType maskVType = getMaskVectorType();
5298   VectorType valueVType = getVectorType();
5299   MemRefType memType = getMemRefType();
5300 
5301   if (valueVType.getElementType() != memType.getElementType())
5302     return emitOpError("base and valueToStore element type should match");
5303   if (llvm::size(getIndices()) != memType.getRank())
5304     return emitOpError("requires ") << memType.getRank() << " indices";
5305   if (valueVType.getDimSize(0) != indVType.getDimSize(0))
5306     return emitOpError("expected valueToStore dim to match indices dim");
5307   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5308     return emitOpError("expected valueToStore dim to match mask dim");
5309   return success();
5310 }
5311 
5312 namespace {
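/// Erases a ScatterOp whose mask is all-false (nothing is stored). An
/// all-true mask is left untouched, as there is no unmasked scatter
/// equivalent.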
5313 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5314 public:
5315   using OpRewritePattern::OpRewritePattern;
5316   LogicalResult matchAndRewrite(ScatterOp scatter,
5317                                 PatternRewriter &rewriter) const override {
5318     switch (getMaskFormat(scatter.getMask())) {
5319     case MaskFormat::AllTrue:
5320       return failure(); // no unmasked equivalent
5321     case MaskFormat::AllFalse:
5322       rewriter.eraseOp(scatter);
5323       return success();
5324     case MaskFormat::Unknown:
5325       return failure();
5326     }
5327     llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5328   }
5329 };
5330 
5331 /// Fold scatters with consecutive offsets [0, 1, 2, ...] into a contiguous
5332 /// maskedstore. Only 1-D fixed vectors are supported for now.
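///
/// For example (illustrative; SSA names are placeholders):
///
///   %idxs = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
///   vector.scatter %base[%c0][%idxs], %mask, %value
///       : memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
/// becomes:
///   vector.maskedstore %base[%c0], %mask, %value
///       : memref<?xf32>, vector<4xi1>, vector<4xf32>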
5333 class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
5334 public:
5335   using OpRewritePattern::OpRewritePattern;
5336   LogicalResult matchAndRewrite(ScatterOp op,
5337                                 PatternRewriter &rewriter) const override {
5338     if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5339       return failure();
5340 
5341     rewriter.replaceOpWithNewOp<MaskedStoreOp>(
5342         op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
5343     return success();
5344   }
5345 };
5346 } // namespace
5347 
5348 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5349                                             MLIRContext *context) {
5350   results.add<ScatterFolder, FoldContiguousScatter>(context);
5351 }
5352 
5353 //===----------------------------------------------------------------------===//
5354 // ExpandLoadOp
5355 //===----------------------------------------------------------------------===//
5356 
5357 LogicalResult ExpandLoadOp::verify() {
5358   VectorType maskVType = getMaskVectorType();
5359   VectorType passVType = getPassThruVectorType();
5360   VectorType resVType = getVectorType();
5361   MemRefType memType = getMemRefType();
5362 
5363   if (resVType.getElementType() != memType.getElementType())
5364     return emitOpError("base and result element type should match");
5365   if (llvm::size(getIndices()) != memType.getRank())
5366     return emitOpError("requires ") << memType.getRank() << " indices";
5367   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
5368     return emitOpError("expected result dim to match mask dim");
5369   if (resVType != passVType)
5370     return emitOpError("expected pass_thru of same type as result type");
5371   return success();
5372 }
5373 
5374 namespace {
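/// Folds an ExpandLoadOp whose mask is a compile-time constant: an all-true
/// mask becomes an unmasked vector.load and an all-false mask folds to the
/// pass_thru value, mirroring MaskedLoadFolder above.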
5375 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5376 public:
5377   using OpRewritePattern::OpRewritePattern;
5378   LogicalResult matchAndRewrite(ExpandLoadOp expand,
5379                                 PatternRewriter &rewriter) const override {
5380     switch (getMaskFormat(expand.getMask())) {
5381     case MaskFormat::AllTrue:
5382       rewriter.replaceOpWithNewOp<vector::LoadOp>(
5383           expand, expand.getType(), expand.getBase(), expand.getIndices());
5384       return success();
5385     case MaskFormat::AllFalse:
5386       rewriter.replaceOp(expand, expand.getPassThru());
5387       return success();
5388     case MaskFormat::Unknown:
5389       return failure();
5390     }
5391     llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5392   }
5393 };
5394 } // namespace
5395 
5396 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5397                                                MLIRContext *context) {
5398   results.add<ExpandLoadFolder>(context);
5399 }
5400 
5401 //===----------------------------------------------------------------------===//
5402 // CompressStoreOp
5403 //===----------------------------------------------------------------------===//
5404 
5405 LogicalResult CompressStoreOp::verify() {
5406   VectorType maskVType = getMaskVectorType();
5407   VectorType valueVType = getVectorType();
5408   MemRefType memType = getMemRefType();
5409 
5410   if (valueVType.getElementType() != memType.getElementType())
5411     return emitOpError("base and valueToStore element type should match");
5412   if (llvm::size(getIndices()) != memType.getRank())
5413     return emitOpError("requires ") << memType.getRank() << " indices";
5414   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5415     return emitOpError("expected valueToStore dim to match mask dim");
5416   return success();
5417 }
5418 
5419 namespace {
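/// Folds a CompressStoreOp whose mask is a compile-time constant: an all-true
/// mask becomes an unmasked vector.store and an all-false mask erases the op,
/// mirroring MaskedStoreFolder above.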
5420 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5421 public:
5422   using OpRewritePattern::OpRewritePattern;
5423   LogicalResult matchAndRewrite(CompressStoreOp compress,
5424                                 PatternRewriter &rewriter) const override {
5425     switch (getMaskFormat(compress.getMask())) {
5426     case MaskFormat::AllTrue:
5427       rewriter.replaceOpWithNewOp<vector::StoreOp>(
5428           compress, compress.getValueToStore(), compress.getBase(),
5429           compress.getIndices());
5430       return success();
5431     case MaskFormat::AllFalse:
5432       rewriter.eraseOp(compress);
5433       return success();
5434     case MaskFormat::Unknown:
5435       return failure();
5436     }
5437     llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5438   }
5439 };
5440 } // namespace
5441 
5442 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5443                                                   MLIRContext *context) {
5444   results.add<CompressStoreFolder>(context);
5445 }
5446 
5447 //===----------------------------------------------------------------------===//
5448 // ShapeCastOp
5449 //===----------------------------------------------------------------------===//
5450 
5451 void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5452                                     SetIntRangeFn setResultRanges) {
5453   setResultRanges(getResult(), argRanges.front());
5454 }
5455 
5456 /// Returns true if each element of 'a' is equal to the product of a contiguous
5457 /// sequence of the elements of 'b'. Returns false otherwise.
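/// E.g. isValidShapeCast({4, 6}, {2, 2, 6}) is true (4 = 2 * 2 and 6 = 6),
/// while isValidShapeCast({4, 6}, {2, 3, 4}) is false.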
5458 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5459   unsigned rankA = a.size();
5460   unsigned rankB = b.size();
5461   assert(rankA < rankB);
5462 
5463   auto isOne = [](int64_t v) { return v == 1; };
5464 
5465   // Special case for n-D to 0-D shape cast: 'b' must be all ones to be shape
5466   // cast to a 0-D vector.
5467   if (rankA == 0 && llvm::all_of(b, isOne))
5468     return true;
5469 
5470   unsigned i = 0;
5471   unsigned j = 0;
5472   while (i < rankA && j < rankB) {
5473     int64_t dimA = a[i];
5474     int64_t dimB = 1;
5475     while (dimB < dimA && j < rankB)
5476       dimB *= b[j++];
5477     if (dimA != dimB)
5478       break;
5479     ++i;
5480 
5481     // Handle the case when trailing dimensions are of size 1.
5482     // Include them into the contiguous sequence.
5483     if (i < rankA && llvm::all_of(a.slice(i), isOne))
5484       i = rankA;
5485     if (j < rankB && llvm::all_of(b.slice(j), isOne))
5486       j = rankB;
5487   }
5488 
5489   return i == rankA && j == rankB;
5490 }
5491 
5492 static LogicalResult verifyVectorShapeCast(Operation *op,
5493                                            VectorType sourceVectorType,
5494                                            VectorType resultVectorType) {
5495   // Check that element type is the same.
5496   if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5497     return op->emitOpError("source/result vectors must have same element type");
5498   auto sourceShape = sourceVectorType.getShape();
5499   auto resultShape = resultVectorType.getShape();
5500 
5501   // Check that product of source dim sizes matches product of result dim sizes.
5502   int64_t sourceDimProduct = std::accumulate(
5503       sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5504   int64_t resultDimProduct = std::accumulate(
5505       resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5506   if (sourceDimProduct != resultDimProduct)
5507     return op->emitOpError("source/result number of elements must match");
5508 
5509   // Check the expanding/contracting rank cases.
5510   unsigned sourceRank = sourceVectorType.getRank();
5511   unsigned resultRank = resultVectorType.getRank();
5512   if (sourceRank < resultRank) {
5513     if (!isValidShapeCast(sourceShape, resultShape))
5514       return op->emitOpError("invalid shape cast");
5515   } else if (sourceRank > resultRank) {
5516     if (!isValidShapeCast(resultShape, sourceShape))
5517       return op->emitOpError("invalid shape cast");
5518   }
5519 
5520   // Check that (non-)scalability is preserved
5521   int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5522   int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5523   if (sourceNScalableDims != resultNScalableDims)
5524     return op->emitOpError("different number of scalable dims at source (")
5525            << sourceNScalableDims << ") and result (" << resultNScalableDims
5526            << ")";
5528 
5529   return success();
5530 }
5531 
5532 LogicalResult ShapeCastOp::verify() {
5533   auto sourceVectorType =
5534       llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5535   auto resultVectorType =
5536       llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5537 
5538   // Check if source/result are of vector type.
5539   if (sourceVectorType && resultVectorType)
5540     return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5541 
5542   return success();
5543 }
5544 
5545 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5546   // No-op shape cast.
5547   if (getSource().getType() == getResult().getType())
5548     return getSource();
5549 
5550   // Canceling shape casts.
5551   if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5552     if (getResult().getType() == otherOp.getSource().getType())
5553       return otherOp.getSource();
5554 
5555     // Only allows valid transitive folding.
5556     VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5557     VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5558     if (srcType.getRank() < resultType.getRank()) {
5559       if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5560         return {};
5561     } else if (srcType.getRank() > resultType.getRank()) {
5562       if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5563         return {};
5564     } else {
5565       return {};
5566     }
5567 
5568     setOperand(otherOp.getSource());
5569     return getResult();
5570   }
5571 
5572   // Canceling broadcast and shape cast ops.
5573   if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5574     if (bcastOp.getSourceType() == getType())
5575       return bcastOp.getSource();
5576   }
5577 
5578   return {};
5579 }
5580 
5581 namespace {
5582 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
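// E.g. (illustrative):
//   %cst = arith.constant dense<1.0> : vector<2x4xf32>
//   %0 = vector.shape_cast %cst : vector<2x4xf32> to vector<8xf32>
// becomes:
//   %0 = arith.constant dense<1.0> : vector<8xf32>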
5583 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5584 public:
5585   using OpRewritePattern::OpRewritePattern;
5586 
5587   LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5588                                 PatternRewriter &rewriter) const override {
5589     auto constantOp =
5590         shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5591     if (!constantOp)
5592       return failure();
5593     // Only handle splat for now.
5594     auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5595     if (!dense)
5596       return failure();
5597     auto newAttr =
5598         DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5599                                dense.getSplatValue<Attribute>());
5600     rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5601     return success();
5602   }
5603 };
5604 
5605 /// Helper function that computes a new vector type based on the input vector
5606 /// type by removing the trailing unit dims:
5607 ///
5608 ///   vector<4x1x1xi1> --> vector<4xi1>
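///   vector<4x[1]x1xi1> --> vector<4x[1]xi1> (scalable unit dims are kept)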
5609 ///
5610 static VectorType trimTrailingOneDims(VectorType oldType) {
5611   ArrayRef<int64_t> oldShape = oldType.getShape();
5612   ArrayRef<int64_t> newShape = oldShape;
5613 
5614   ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5615   ArrayRef<bool> newScalableDims = oldScalableDims;
5616 
5617   while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5618     newShape = newShape.drop_back(1);
5619     newScalableDims = newScalableDims.drop_back(1);
5620   }
5621 
5622   // Make sure we have at least 1 dimension.
5623   // TODO: Add support for 0-D vectors.
5624   if (newShape.empty()) {
5625     newShape = oldShape.take_back();
5626     newScalableDims = oldScalableDims.take_back();
5627   }
5628 
5629   return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5630 }
5631 
5632 /// Folds qualifying shape_cast(create_mask) into a new create_mask
5633 ///
5634 /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5635 /// dimension. If the input vector comes from `vector.create_mask` for which
5636 /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5637 /// to fold shape_cast into create_mask.
5638 ///
5639 /// BEFORE:
5640 ///    %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5641 ///    %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5642 /// AFTER:
5643 ///    %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5644 class ShapeCastCreateMaskFolderTrailingOneDim final
5645     : public OpRewritePattern<ShapeCastOp> {
5646 public:
5647   using OpRewritePattern::OpRewritePattern;
5648 
5649   LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5650                                 PatternRewriter &rewriter) const override {
5651     Value shapeOpSrc = shapeOp->getOperand(0);
5652     auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5653     auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5654     if (!createMaskOp && !constantMaskOp)
5655       return failure();
5656 
5657     VectorType shapeOpResTy = shapeOp.getResultVectorType();
5658     VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5659 
5660     VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5661     if (newVecType != shapeOpResTy)
5662       return failure();
5663 
5664     auto numDimsToDrop =
5665         shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5666 
5667     // No unit dims to drop
5668     if (!numDimsToDrop)
5669       return failure();
5670 
5671     if (createMaskOp) {
5672       auto maskOperands = createMaskOp.getOperands();
5673       auto numMaskOperands = maskOperands.size();
5674 
5675       // Check every mask dim size to see whether it can be dropped
5676       for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5677            --i) {
5678         auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5679         if (!constant || (constant.value() != 1))
5680           return failure();
5681       }
5682       SmallVector<Value> newMaskOperands =
5683           maskOperands.drop_back(numDimsToDrop);
5684 
5685       rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5686                                                         newMaskOperands);
5687       return success();
5688     }
5689 
5690     if (constantMaskOp) {
5691       auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5692       auto numMaskOperands = maskDimSizes.size();
5693 
5694       // Check every mask dim size to see whether it can be dropped
5695       for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5696            --i) {
5697         if (maskDimSizes[i] != 1)
5698           return failure();
5699       }
5700 
5701       auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5702       rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5703                                                           newMaskOperands);
5704       return success();
5705     }
5706 
5707     return failure();
5708   }
5709 };
5710 
5711 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5712 /// This only applies when the shape of the broadcast source
5713 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
5714 ///    reshape is expressive enough to capture the result in a single op), or
5715 /// 2. has the same element count as the shape cast result.
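///
/// E.g. for case 1 (illustrative):
///   %b = vector.broadcast %v : vector<4xf32> to vector<2x3x4xf32>
///   %0 = vector.shape_cast %b : vector<2x3x4xf32> to vector<6x4xf32>
/// becomes:
///   %0 = vector.broadcast %v : vector<4xf32> to vector<6x4xf32>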
5716 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5717 public:
5718   using OpRewritePattern::OpRewritePattern;
5719 
5720   LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5721                                 PatternRewriter &rewriter) const override {
5722     auto broadcastOp =
5723         shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5724     if (!broadcastOp)
5725       return failure();
5726 
5727     ArrayRef<int64_t> broadcastSourceShape;
5728     if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5729       broadcastSourceShape = srcType.getShape();
5730     ArrayRef<int64_t> shapeCastTargetShape =
5731         shapeCastOp.getResultVectorType().getShape();
5732 
5733     // If `broadcastSourceShape` is a suffix of the result, we can just replace
5734     // with a broadcast to the final shape.
5735     if (broadcastSourceShape ==
5736         shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5737       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5738           shapeCastOp, shapeCastOp.getResultVectorType(),
5739           broadcastOp.getSource());
5740       return success();
5741     }
5742 
5743     // Otherwise, if the final result has the same element count, we can replace
5744     // with a shape cast.
5745     if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5746       if (srcType.getNumElements() ==
5747           shapeCastOp.getResultVectorType().getNumElements()) {
5748         rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5749             shapeCastOp, shapeCastOp.getResultVectorType(),
5750             broadcastOp.getSource());
5751         return success();
5752       }
5753     }
5754 
5755     return failure();
5756   }
5757 };
5758 
5759 } // namespace
5760 
5761 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5762                                               MLIRContext *context) {
5763   results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5764               ShapeCastBroadcastFolder>(context);
5765 }
5766 
5767 //===----------------------------------------------------------------------===//
5768 // VectorBitCastOp
5769 //===----------------------------------------------------------------------===//
5770 
5771 LogicalResult BitCastOp::verify() {
5772   auto sourceVectorType = getSourceVectorType();
5773   auto resultVectorType = getResultVectorType();
5774 
5775   for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5776     if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5777       return emitOpError("dimension size mismatch at: ") << i;
5778   }
5779 
5780   DataLayout dataLayout = DataLayout::closest(*this);
5781   auto sourceElementBits =
5782       dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5783   auto resultElementBits =
5784       dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5785 
5786   if (sourceVectorType.getRank() == 0) {
5787     if (sourceElementBits != resultElementBits)
5788       return emitOpError("source/result bitwidth of the 0-D vector element "
5789                          "types must be equal");
5790   } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5791              resultElementBits * resultVectorType.getShape().back()) {
5792     return emitOpError(
5793         "source/result bitwidth of the minor 1-D vectors must be equal");
5794   }
5795 
5796   return success();
5797 }
5798 
5799 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
5800   // Nop cast.
5801   if (getSource().getType() == getResult().getType())
5802     return getSource();
5803 
5804   // Canceling bitcasts.
5805   if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5806     if (getResult().getType() == otherOp.getSource().getType())
5807       return otherOp.getSource();
5808 
5809     setOperand(otherOp.getSource());
5810     return getResult();
5811   }
5812 
5813   Attribute sourceConstant = adaptor.getSource();
5814   if (!sourceConstant)
5815     return {};
5816 
5817   Type srcElemType = getSourceVectorType().getElementType();
5818   Type dstElemType = getResultVectorType().getElementType();
5819 
5820   if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5821     if (floatPack.isSplat()) {
5822       auto splat = floatPack.getSplatValue<FloatAttr>();
5823 
5824       // Casting fp16 into fp32.
5825       if (srcElemType.isF16() && dstElemType.isF32()) {
5826         uint32_t bits = static_cast<uint32_t>(
5827             splat.getValue().bitcastToAPInt().getZExtValue());
5828         // Duplicate the 16-bit pattern.
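        // E.g. (illustrative) the f16 splat 1.0 (bit pattern 0x3C00) folds to
        // the f32 bit pattern 0x3C003C00.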
5829         bits = (bits << 16) | (bits & 0xffff);
5830         APInt intBits(32, bits);
5831         APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5832         return DenseElementsAttr::get(getResultVectorType(), floatBits);
5833       }
5834     }
5835   }
5836 
5837   if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5838     if (intPack.isSplat()) {
5839       auto splat = intPack.getSplatValue<IntegerAttr>();
5840 
5841       if (llvm::isa<IntegerType>(dstElemType)) {
5842         uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
5843         uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
5844 
5845         // Casting to a larger integer bit width.
5846         if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5847           APInt intBits = splat.getValue().zext(dstBitWidth);
5848 
5849           // Duplicate the lower width element.
5850           for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5851             intBits = (intBits << srcBitWidth) | intBits;
5852           return DenseElementsAttr::get(getResultVectorType(), intBits);
5853         }
5854       }
5855     }
5856   }
5857 
5858   return {};
5859 }
5860 
5861 //===----------------------------------------------------------------------===//
5862 // TypeCastOp
5863 //===----------------------------------------------------------------------===//
5864 
5865 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
5866   auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5867   SmallVector<int64_t, 8> res(memRefType.getShape());
5868   if (vectorType)
5869     res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5870   return res;
5871 }
5872 
5873 /// Build the canonical memRefType with a single vector.
5874 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
5875 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
5876                        Value source) {
5877   result.addOperands(source);
5878   MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
5879   VectorType vectorType =
5880       VectorType::get(extractShape(memRefType),
5881                       getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
5882   result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
5883                                   memRefType.getMemorySpace()));
5884 }
5885 
5886 LogicalResult TypeCastOp::verify() {
5887   MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
5888   if (!canonicalType.getLayout().isIdentity())
5889     return emitOpError("expects operand to be a memref with identity layout");
5890   if (!getResultMemRefType().getLayout().isIdentity())
5891     return emitOpError("expects result to be a memref with identity layout");
5892   if (getResultMemRefType().getMemorySpace() !=
5893       getMemRefType().getMemorySpace())
5894     return emitOpError("expects result in same memory space");
5895 
5896   auto sourceType = getMemRefType();
5897   auto resultType = getResultMemRefType();
5898   if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
5899       getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5900     return emitOpError(
5901                "expects result and operand with same underlying scalar type: ")
5902            << resultType;
5903   if (extractShape(sourceType) != extractShape(resultType))
5904     return emitOpError(
5905                "expects concatenated result and operand shapes to be equal: ")
5906            << resultType;
5907   return success();
5908 }
5909 
5910 //===----------------------------------------------------------------------===//
5911 // TransposeOp
5912 //===----------------------------------------------------------------------===//
5913 
5914 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5915                                 Value vector, ArrayRef<int64_t> permutation) {
5916   VectorType vt = llvm::cast<VectorType>(vector.getType());
5917   SmallVector<int64_t, 4> transposedShape(vt.getRank());
5918   SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5919   for (unsigned i = 0; i < permutation.size(); ++i) {
5920     transposedShape[i] = vt.getShape()[permutation[i]];
5921     transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5922   }
5923 
5924   result.addOperands(vector);
5925   result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
5926                                   transposedScalableDims));
5927   result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
5928                       builder.getDenseI64ArrayAttr(permutation));
5929 }
5930 
5931 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5932   // Eliminate splat constant transpose ops.
5933   if (auto attr =
5934           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5935     if (attr.isSplat())
5936       return attr.reshape(getResultVectorType());
5937 
5938   // Eliminate identity transpose ops. This happens when the dimensions of the
5939   // input vector remain in their original order after the transpose operation.
5940   ArrayRef<int64_t> perm = getPermutation();
5941 
5942   // Check if the permutation of the dimensions contains sequential values:
5943   // {0, 1, 2, ...}.
5944   for (int64_t i = 0, e = perm.size(); i < e; i++) {
5945     if (perm[i] != i)
5946       return {};
5947   }
5948 
5949   return getVector();
5950 }
5951 
5952 LogicalResult vector::TransposeOp::verify() {
5953   VectorType vectorType = getSourceVectorType();
5954   VectorType resultType = getResultVectorType();
5955   int64_t rank = resultType.getRank();
5956   if (vectorType.getRank() != rank)
5957     return emitOpError("vector result rank mismatch: ") << rank;
5958   // Verify transposition array.
5959   ArrayRef<int64_t> perm = getPermutation();
5960   int64_t size = perm.size();
5961   if (rank != size)
5962     return emitOpError("transposition length mismatch: ") << size;
5963   SmallVector<bool, 8> seen(rank, false);
5964   for (const auto &ta : llvm::enumerate(perm)) {
5965     if (ta.value() < 0 || ta.value() >= rank)
5966       return emitOpError("transposition index out of range: ") << ta.value();
5967     if (seen[ta.value()])
5968       return emitOpError("duplicate position index: ") << ta.value();
5969     seen[ta.value()] = true;
5970     if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
5971       return emitOpError("dimension size mismatch at: ") << ta.value();
5972   }
5973   return success();
5974 }
5975 
5976 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5977   return llvm::to_vector<4>(getResultVectorType().getShape());
5978 }
5979 
5980 namespace {
5981 
5982 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
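// E.g. transposing by [1, 0] twice composes to the identity permutation
// [0, 1], which the TransposeOp folder then eliminates entirely.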
5983 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
5984 public:
5985   using OpRewritePattern::OpRewritePattern;
5986 
5987   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5988                                 PatternRewriter &rewriter) const override {
5989     // Composes two permutations: result[i] = permutation1[permutation2[i]].
5990     auto composePermutations = [](ArrayRef<int64_t> permutation1,
5991                                   ArrayRef<int64_t> permutation2) {
5992       SmallVector<int64_t, 4> result;
5993       for (auto index : permutation2)
5994         result.push_back(permutation1[index]);
5995       return result;
5996     };
5997 
5998     // Return if the input of 'transposeOp' is not defined by another transpose.
5999     vector::TransposeOp parentTransposeOp =
6000         transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6001     if (!parentTransposeOp)
6002       return failure();
6003 
6004     SmallVector<int64_t, 4> permutation = composePermutations(
6005         parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6006     // Replace 'transposeOp' with a new transpose operation.
6007     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6008         transposeOp, transposeOp.getResult().getType(),
6009         parentTransposeOp.getVector(), permutation);
6010     return success();
6011   }
6012 };
6013 
6014 // Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
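// E.g. (illustrative):
//   %b = vector.broadcast %s : f32 to vector<4x2xf32>
//   %t = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
// becomes:
//   %t = vector.broadcast %s : f32 to vector<2x4xf32>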
6015 struct FoldTransposedScalarBroadcast final
6016     : public OpRewritePattern<vector::TransposeOp> {
6017   using OpRewritePattern::OpRewritePattern;
6018 
6019   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6020                                 PatternRewriter &rewriter) const override {
6021     auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
6022     if (!bcastOp)
6023       return failure();
6024 
6025     auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
6026     if (!srcVectorType || srcVectorType.getNumElements() == 1) {
6027       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6028           transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
6029       return success();
6030     }
6031 
6032     return failure();
6033   }
6034 };
6035 
6036 // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
6037 class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6038 public:
6039   using OpRewritePattern::OpRewritePattern;
6040 
6041   LogicalResult matchAndRewrite(TransposeOp transposeOp,
6042                                 PatternRewriter &rewriter) const override {
6043     auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
6044     if (!splatOp)
6045       return failure();
6046 
6047     rewriter.replaceOpWithNewOp<vector::SplatOp>(
6048         transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
6049     return success();
6050   }
6051 };
6052 
6053 /// Folds transpose(create_mask) into a new transposed create_mask.
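///
/// E.g. (illustrative; SSA names are placeholders):
///   %m = vector.create_mask %a, %b : vector<4x8xi1>
///   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
/// becomes:
///   %t = vector.create_mask %b, %a : vector<8x4xi1>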
6054 class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6055 public:
6056   using OpRewritePattern::OpRewritePattern;
6057 
6058   LogicalResult matchAndRewrite(TransposeOp transpOp,
6059                                 PatternRewriter &rewriter) const override {
6060     Value transposeSrc = transpOp.getVector();
6061     auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
6062     auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
6063     if (!createMaskOp && !constantMaskOp)
6064       return failure();
6065 
6066     // Get the transpose permutation and apply it to the vector.create_mask or
6067     // vector.constant_mask operands.
6068     ArrayRef<int64_t> permutation = transpOp.getPermutation();
6069 
6070     if (createMaskOp) {
6071       auto maskOperands = createMaskOp.getOperands();
6072       SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
6073       applyPermutationToVector(newOperands, permutation);
6074 
6075       rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
6076           transpOp, transpOp.getResultVectorType(), newOperands);
6077       return success();
6078     }
6079 
6080     // ConstantMaskOp case.
6081     auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6082     auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
6083 
6084     rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
6085         transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6086     return success();
6087   }
6088 };
6089 
6090 } // namespace
6091 
6092 void vector::TransposeOp::getCanonicalizationPatterns(
6093     RewritePatternSet &results, MLIRContext *context) {
6094   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
6095               TransposeFolder, FoldTransposeSplat>(context);
6096 }
6097 
6098 //===----------------------------------------------------------------------===//
6099 // ConstantMaskOp
6100 //===----------------------------------------------------------------------===//
6101 
6102 void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
6103                            VectorType type, ConstantMaskKind kind) {
6104   assert(kind == ConstantMaskKind::AllTrue ||
6105          kind == ConstantMaskKind::AllFalse);
6106   build(builder, result, type,
6107         kind == ConstantMaskKind::AllTrue
6108             ? type.getShape()
6109             : SmallVector<int64_t>(type.getRank(), 0));
6110 }
6111 
6112 LogicalResult ConstantMaskOp::verify() {
6113   auto resultType = llvm::cast<VectorType>(getResult().getType());
6114   // Check the corner case of 0-D vectors first.
6115   if (resultType.getRank() == 0) {
6116     if (getMaskDimSizes().size() != 1)
6117       return emitError("array attr must have length 1 for 0-D vectors");
6118     auto dim = getMaskDimSizes()[0];
6119     if (dim != 0 && dim != 1)
6120       return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
6121     return success();
6122   }
6123 
6124   // Verify that array attr size matches the rank of the vector result.
6125   if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
6126     return emitOpError(
6127         "must specify array attr of size equal vector result rank");
6128   // Verify that each array attr element is in bounds of corresponding vector
6129   // result dimension size.
6130   auto resultShape = resultType.getShape();
6131   auto resultScalableDims = resultType.getScalableDims();
6132   ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
6133   for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
6134     if (maskDimSize < 0 || maskDimSize > resultShape[index])
6135       return emitOpError(
6136           "array attr of size out of bounds of vector result dimension size");
6137     if (resultScalableDims[index] && maskDimSize != 0 &&
6138         maskDimSize != resultShape[index])
6139       return emitOpError(
6140           "only supports 'none set' or 'all set' scalable dimensions");
6141   }
6142   // Verify that if any mask dim size is zero, all of them are zero (because
6143   // the mask region is the conjunction of each mask dimension's interval).
6144   bool anyZeros = llvm::is_contained(maskDimSizes, 0);
6145   bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
6146   if (anyZeros && !allZeros)
6147     return emitOpError("expected all mask dim sizes to be zeros, "
6148                        "as a result of conjunction with zero mask dim");
6149   return success();
6150 }
6151 
6152 bool ConstantMaskOp::isAllOnesMask() {
6153   auto resultType = getVectorType();
6154   // Check the corner case of 0-D vectors first.
6155   if (resultType.getRank() == 0) {
6156     assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
6157     return getMaskDimSizes()[0] == 1;
6158   }
6159   for (const auto [resultSize, maskDimSize] :
6160        llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
6161     if (maskDimSize < resultSize)
6162       return false;
6163   }
6164   return true;
6165 }
6166 
6167 //===----------------------------------------------------------------------===//
6168 // CreateMaskOp
6169 //===----------------------------------------------------------------------===//
6170 
6171 void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
6172                          VectorType type,
6173                          ArrayRef<OpFoldResult> mixedOperands) {
6174   SmallVector<Value> operands =
6175       getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
6176   build(builder, result, type, operands);
6177 }
6178 
6179 LogicalResult CreateMaskOp::verify() {
6180   auto vectorType = llvm::cast<VectorType>(getResult().getType());
6181   // Verify that an operand was specified for each result vector dimension.
6182   if (vectorType.getRank() == 0) {
6183     if (getNumOperands() != 1)
6184       return emitOpError(
6185           "must specify exactly one operand for 0-D create_mask");
6186   } else if (getNumOperands() !=
6187              llvm::cast<VectorType>(getResult().getType()).getRank()) {
6188     return emitOpError(
6189         "must specify an operand for each result vector dimension");
6190   }
6191   return success();
6192 }
6193 
6194 namespace {
6195 
6196 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
6197 ///
6198 /// Ex 1:
6199 ///   %c2 = arith.constant 2 : index
6200 ///   %c3 = arith.constant 3 : index
6201 ///   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
6202 /// Becomes:
6203 ///    vector.constant_mask [3, 2] : vector<4x3xi1>
6204 ///
6205 /// Ex 2:
6206 ///   %c_neg_1 = arith.constant -1 : index
6207 ///   %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
6208 /// becomes:
6209 ///   vector.constant_mask [0] : vector<[8]xi1>
6210 ///
6211 /// Ex 3:
6212 ///   %c8 = arith.constant 8 : index
6213 ///   %c16 = arith.constant 16 : index
6214 ///   %0 = vector.vscale
6215 ///   %1 = arith.muli %0, %c16 : index
6216 ///   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
6217 /// becomes:
6218 ///   %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
6219 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
6220 public:
6221   using OpRewritePattern::OpRewritePattern;
6222 
6223   LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
6224                                 PatternRewriter &rewriter) const override {
6225     VectorType maskType = createMaskOp.getVectorType();
6226     ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
6227     ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
6228 
6229     // Special case: Rank zero shape.
6230     constexpr std::array<int64_t, 1> rankZeroShape{1};
6231     constexpr std::array<bool, 1> rankZeroScalableDims{false};
6232     if (maskType.getRank() == 0) {
6233       maskTypeDimSizes = rankZeroShape;
6234       maskTypeDimScalableFlags = rankZeroScalableDims;
6235     }
6236 
6237     // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
6238     // collect the `constantDims` (for the ConstantMaskOp).
6239     SmallVector<int64_t, 4> constantDims;
6240     for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
6241       if (auto intSize = getConstantIntValue(dimSize)) {
6242         // Constant value.
6243         // If the mask dim is non-scalable this can be any value.
6244         // If the mask dim is scalable, only negative values (all-false) fold here.
6245         if (maskTypeDimScalableFlags[i] && intSize >= 0)
6246           return failure();
6247         constantDims.push_back(*intSize);
6248       } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
6249         // Constant vscale multiple (e.g. 4 x vscale).
6250         // Must be all-true to fold to a ConstantMask.
6251         if (vscaleMultiplier < maskTypeDimSizes[i])
6252           return failure();
6253         constantDims.push_back(*vscaleMultiplier);
6254       } else {
6255         return failure();
6256       }
6257     }
6258 
6259     // Clamp values to constant_mask bounds.
6260     for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
6261       value = std::clamp<int64_t>(value, 0, maskDimSize);
6262 
6263     // If one of dim sizes is zero, set all dims to zero.
6264     if (llvm::is_contained(constantDims, 0))
6265       constantDims.assign(constantDims.size(), 0);
6266 
6267     // Replace 'createMaskOp' with ConstantMaskOp.
6268     rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
6269                                                 constantDims);
6270     return success();
6271   }
6272 };
6273 
6274 } // namespace
6275 
6276 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6277                                                MLIRContext *context) {
6278   results.add<CreateMaskFolder>(context);
6279 }
6280 
6281 //===----------------------------------------------------------------------===//
6282 // MaskOp
6283 //===----------------------------------------------------------------------===//
6284 
6285 void MaskOp::build(
6286     OpBuilder &builder, OperationState &result, Value mask,
6287     Operation *maskableOp,
6288     function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6289   assert(maskRegionBuilder &&
6290          "builder callback for 'maskRegion' must be present");
6291 
6292   result.addOperands(mask);
6293   OpBuilder::InsertionGuard guard(builder);
6294   Region *maskRegion = result.addRegion();
6295   builder.createBlock(maskRegion);
6296   maskRegionBuilder(builder, maskableOp);
6297 }
6298 
6299 void MaskOp::build(
6300     OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6301     Value mask, Operation *maskableOp,
6302     function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6303   build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
6304         maskRegionBuilder);
6305 }
6306 
6307 void MaskOp::build(
6308     OpBuilder &builder, OperationState &result, TypeRange resultTypes,
6309     Value mask, Value passthru, Operation *maskableOp,
6310     function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
6311   build(builder, result, mask, maskableOp, maskRegionBuilder);
6312   if (passthru)
6313     result.addOperands(passthru);
6314   result.addTypes(resultTypes);
6315 }
6316 
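/// The custom assembly form handled by the parse/print methods below looks
/// like (illustrative):
///
///   %0 = vector.mask %mask {
///     vector.transfer_read %base[%idx], %pad {in_bounds = [true]}
///         : memref<?xf32>, vector<8xf32>
///   } : vector<8xi1> -> vector<8xf32>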
6317 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
6318   // Create the op region.
6319   result.regions.reserve(1);
6320   Region &maskRegion = *result.addRegion();
6321 
6322   auto &builder = parser.getBuilder();
6323 
6324   // Parse all the operands.
6325   OpAsmParser::UnresolvedOperand mask;
6326   if (parser.parseOperand(mask))
6327     return failure();
6328 
6329   // Optional passthru operand.
6330   OpAsmParser::UnresolvedOperand passthru;
6331   ParseResult parsePassthru = parser.parseOptionalComma();
6332   if (parsePassthru.succeeded() && parser.parseOperand(passthru))
6333     return failure();
6334 
6335   // Parse op region.
6336   if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
6337     return failure();
6338 
6339   MaskOp::ensureTerminator(maskRegion, builder, result.location);
6340 
6341   // Parse the optional attribute list.
6342   if (parser.parseOptionalAttrDict(result.attributes))
6343     return failure();
6344 
6345   // Parse all the types.
6346   Type maskType;
6347   if (parser.parseColonType(maskType))
6348     return failure();
6349 
6350   SmallVector<Type> resultTypes;
6351   if (parser.parseOptionalArrowTypeList(resultTypes))
6352     return failure();
6353   result.types.append(resultTypes);
6354 
6355   // Resolve operands.
6356   if (parser.resolveOperand(mask, maskType, result.operands))
6357     return failure();
6358 
6359   if (parsePassthru.succeeded())
6360     if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
6361       return failure();
6362 
6363   return success();
6364 }
6365 
6366 void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
6367   p << " " << getMask();
6368   if (getPassthru())
6369     p << ", " << getPassthru();
6370 
6371   // Print single masked operation and skip terminator.
6372   p << " { ";
6373   Block *singleBlock = &getMaskRegion().getBlocks().front();
6374   if (singleBlock && !singleBlock->getOperations().empty())
6375     p.printCustomOrGenericOp(&singleBlock->front());
6376   p << " }";
6377 
6378   p.printOptionalAttrDict(getOperation()->getAttrs());
6379 
6380   p << " : " << getMask().getType();
6381   if (getNumResults() > 0)
6382     p << " -> " << getResultTypes();
6383 }
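
// For reference, the custom assembly form handled by the parser and printer
// above looks roughly as follows (illustrative names; the passthru operand and
// the result types are optional):
//
//   %0 = vector.mask %m, %passthru {
//     vector.transfer_read %src[%i], %pad {in_bounds = [true]}
//       : memref<?xf32>, vector<8xf32>
//   } : vector<8xi1> -> vector<8xf32>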
6384 
6385 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
6386   OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6387       MaskOp>::ensureTerminator(region, builder, loc);
6388   // Keep the default yield terminator if the number of masked operations is
6389   // not the expected one. This case will trigger a verification failure.
6390   Block &block = region.front();
6391   if (block.getOperations().size() != 2)
6392     return;
6393 
6394   // Replace default yield terminator with a new one that returns the results
6395   // from the masked operation.
6396   OpBuilder opBuilder(builder.getContext());
6397   Operation *maskedOp = &block.front();
6398   Operation *oldYieldOp = &block.back();
6399   assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
6400 
6401   // Empty vector.mask op.
6402   if (maskedOp == oldYieldOp)
6403     return;
6404 
6405   opBuilder.setInsertionPoint(oldYieldOp);
6406   opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
6407   oldYieldOp->dropAllReferences();
6408   oldYieldOp->erase();
6409 }
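
// As a result, the single block of a non-empty vector.mask region holds the
// masked operation followed by a terminator forwarding its results, e.g.
// (illustrative):
//
//   %r = vector.reduction <add>, %v : vector<8xf32> into f32
//   vector.yield %r : f32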
6410 
6411 LogicalResult MaskOp::verify() {
6412   // Structural checks.
6413   Block &block = getMaskRegion().getBlocks().front();
6414   if (block.getOperations().empty())
6415     return emitOpError("expects a terminator within the mask region");
6416 
6417   unsigned numMaskRegionOps = block.getOperations().size();
6418   if (numMaskRegionOps > 2)
6419     return emitOpError("expects only one operation to mask");
6420 
6421   // Terminator checks.
6422   auto terminator = dyn_cast<vector::YieldOp>(block.back());
6423   if (!terminator)
6424     return emitOpError("expects a terminator within the mask region");
6425 
6426   if (terminator->getNumOperands() != getNumResults())
6427     return emitOpError(
6428         "expects number of results to match mask region yielded values");
6429 
6430   // Empty vector.mask. Nothing else to check.
6431   if (numMaskRegionOps == 1)
6432     return success();
6433 
6434   auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
6435   if (!maskableOp)
6436     return emitOpError("expects a MaskableOpInterface within the mask region");
6437 
6438   // Result checks.
6439   if (maskableOp->getNumResults() != getNumResults())
6440     return emitOpError("expects number of results to match maskable operation "
6441                        "number of results");
6442 
6443   if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
6444     return emitOpError(
6445         "expects result type to match maskable operation result type");
6446 
6447   if (llvm::count_if(maskableOp->getResultTypes(),
6448                      [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
6449     return emitOpError("multiple vector results not supported");
6450 
6451   // Mask checks.
6452   Type expectedMaskType = maskableOp.getExpectedMaskType();
6453   if (getMask().getType() != expectedMaskType)
6454     return emitOpError("expects a ")
6455            << expectedMaskType << " mask for the maskable operation";
6456 
6457   // Passthru checks.
6458   Value passthru = getPassthru();
6459   if (passthru) {
6460     if (!maskableOp.supportsPassthru())
6461       return emitOpError(
6462           "doesn't expect a passthru argument for this maskable operation");
6463 
6464     if (maskableOp->getNumResults() != 1)
6465       return emitOpError("expects result when passthru argument is provided");
6466 
6467     if (passthru.getType() != maskableOp->getResultTypes()[0])
6468       return emitOpError("expects passthru type to match result type");
6469   }
6470 
6471   return success();
6472 }
6473 
6474 /// Folds vector.mask ops with an all-true mask.
6475 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6476                            SmallVectorImpl<OpFoldResult> &results) {
6477   MaskFormat maskFormat = getMaskFormat(getMask());
6478   if (isEmpty())
6479     return failure();
6480 
6481   if (maskFormat != MaskFormat::AllTrue)
6482     return failure();
6483 
6484   // Move maskable operation outside of the `vector.mask` region.
6485   Operation *maskableOp = getMaskableOp();
6486   maskableOp->dropAllUses();
6487   maskableOp->moveBefore(getOperation());
6488 
6489   llvm::append_range(results, maskableOp->getResults());
6490   return success();
6491 }
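
// Illustrative example of the fold above. With an all-true mask such as:
//
//   %m = vector.constant_mask [8] : vector<8xi1>
//   %0 = vector.mask %m { vector.reduction <add>, %v : vector<8xf32> into f32 }
//     : vector<8xi1> -> f32
//
// the masked operation is hoisted out of the region and used directly:
//
//   %0 = vector.reduction <add>, %v : vector<8xf32> into f32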
6492 
6493 // Elides empty vector.mask operations with or without return values. Propagates
6494 // the values yielded by the vector.yield terminator, if any; otherwise, erases
6495 // the op.
6496 class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
6497   using OpRewritePattern::OpRewritePattern;
6498 
6499   LogicalResult matchAndRewrite(MaskOp maskOp,
6500                                 PatternRewriter &rewriter) const override {
6501     auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6502     if (maskingOp.getMaskableOp())
6503       return failure();
6504 
6505     if (!maskOp.isEmpty())
6506       return failure();
6507 
6508     Block *block = maskOp.getMaskBlock();
6509     auto terminator = cast<vector::YieldOp>(block->front());
6510     if (terminator.getNumOperands() == 0)
6511       rewriter.eraseOp(maskOp);
6512     else
6513       rewriter.replaceOp(maskOp, terminator.getOperands());
6514 
6515     return success();
6516   }
6517 };
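
// Illustrative examples of the elision above (names are arbitrary):
//
//   vector.mask %m { vector.yield } : vector<8xi1>
//     --> erased
//
//   %0 = vector.mask %m { vector.yield %a : vector<8xf32> }
//     : vector<8xi1> -> vector<8xf32>
//     --> all uses of %0 replaced by %a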
6518 
6519 void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6520                                          MLIRContext *context) {
6521   results.add<ElideEmptyMaskOp>(context);
6522 }
6523 
6524 // MaskingOpInterface definitions.
6525 
6526 /// Returns the operation masked by this 'vector.mask'.
6527 Operation *MaskOp::getMaskableOp() {
6528   Block *block = getMaskBlock();
6529   if (block->getOperations().size() < 2)
6530     return nullptr;
6531 
6532   return &block->front();
6533 }
6534 
6535 /// Returns true if 'vector.mask' has a passthru value.
6536 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
6537 
6538 //===----------------------------------------------------------------------===//
6539 // ScanOp
6540 //===----------------------------------------------------------------------===//
6541 
6542 LogicalResult ScanOp::verify() {
6543   VectorType srcType = getSourceType();
6544   VectorType initialType = getInitialValueType();
6545   // Check reduction dimension < rank.
6546   int64_t srcRank = srcType.getRank();
6547   int64_t reductionDim = getReductionDim();
6548   if (reductionDim >= srcRank)
6549     return emitOpError("reduction dimension ")
6550            << reductionDim << " has to be less than " << srcRank;
6551 
6552   // Check that rank(initial_value) = rank(src) - 1.
6553   int64_t initialValueRank = initialType.getRank();
6554   if (initialValueRank != srcRank - 1)
6555     return emitOpError("initial value rank ")
6556            << initialValueRank << " has to be equal to " << srcRank - 1;
6557 
6558   // Check shapes of initial value and src.
6559   ArrayRef<int64_t> srcShape = srcType.getShape();
6560   ArrayRef<int64_t> initialValueShapes = initialType.getShape();
6561   SmallVector<int64_t> expectedShape;
6562   for (int i = 0; i < srcRank; i++) {
6563     if (i != reductionDim)
6564       expectedShape.push_back(srcShape[i]);
6565   }
6566   if (!llvm::equal(initialValueShapes, expectedShape)) {
6567     return emitOpError("incompatible input/initial value shapes");
6568   }
6569 
6570   // Verify supported reduction kind.
6571   Type eltType = getDestType().getElementType();
6572   if (!isSupportedCombiningKind(getKind(), eltType))
6573     return emitOpError("unsupported reduction type ")
6574            << eltType << " for kind '" << stringifyCombiningKind(getKind())
6575            << "'";
6576 
6577   return success();
6578 }
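
// For example, the following scan satisfies the checks above: the initial
// value drops the reduction dimension (dim 1) from the source shape
// (illustrative types):
//
//   %dest:2 = vector.scan <add>, %src, %init
//     {inclusive = true, reduction_dim = 1 : i64}
//     : vector<4x8x16xf32>, vector<4x16xf32>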
6579 
6580 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
6581     RewritePatternSet &patterns, PatternBenefit benefit) {
6582   patterns
6583       .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6584            ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6585            StridedSliceConstantMaskFolder, TransposeFolder>(
6586           patterns.getContext(), benefit);
6587 }
6588 
6589 //===----------------------------------------------------------------------===//
6590 // SplatOp
6591 //===----------------------------------------------------------------------===//
6592 
6593 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
6594   auto constOperand = adaptor.getInput();
6595   if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
6596     return {};
6597 
6598   // SplatElementsAttr::get treats a single value for the second arg as a splat.
6599   return SplatElementsAttr::get(getType(), {constOperand});
6600 }
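
// Illustrative example of the fold above:
//
//   %c = arith.constant 1.0 : f32
//   %v = vector.splat %c : vector<4xf32>
//
// %v folds to the splat attribute dense<1.000000e+00> : vector<4xf32>.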
6601 
6602 void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6603                                 SetIntRangeFn setResultRanges) {
6604   setResultRanges(getResult(), argRanges.front());
6605 }
6606 
6607 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
6608                                        CombiningKind kind, Value v1, Value acc,
6609                                        arith::FastMathFlagsAttr fastmath,
6610                                        Value mask) {
6611   Type t1 = getElementTypeOrSelf(v1.getType());
6612   Type tAcc = getElementTypeOrSelf(acc.getType());
6613   Value result;
6614 
6615   switch (kind) {
6616   case CombiningKind::ADD:
6617     if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6618       result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
6619     else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6620       result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6621     else
6622       llvm_unreachable("invalid value types for ADD reduction");
6623     break;
6624   case CombiningKind::AND:
6625     assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6626     result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
6627     break;
6628   case CombiningKind::MAXNUMF:
6629     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6630            "expected float values");
6631     result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6632     break;
6633   case CombiningKind::MAXIMUMF:
6634     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6635            "expected float values");
6636     result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6637     break;
6638   case CombiningKind::MINNUMF:
6639     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6640            "expected float values");
6641     result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6642     break;
6643   case CombiningKind::MINIMUMF:
6644     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6645            "expected float values");
6646     result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6647     break;
6648   case CombiningKind::MAXSI:
6649     assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6650     result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
6651     break;
6652   case CombiningKind::MINSI:
6653     assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6654     result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
6655     break;
6656   case CombiningKind::MAXUI:
6657     assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6658     result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
6659     break;
6660   case CombiningKind::MINUI:
6661     assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6662     result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
6663     break;
6664   case CombiningKind::MUL:
6665     if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6666       result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
6667     else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6668       result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6669     else
6670       llvm_unreachable("invalid value types for MUL reduction");
6671     break;
6672   case CombiningKind::OR:
6673     assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6674     result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
6675     break;
6676   case CombiningKind::XOR:
6677     assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6678     result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
6679     break;
6680   };
6681 
6682   assert(result && "unknown CombiningKind");
6683   return selectPassthru(b, mask, result, acc);
6684 }
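
// For instance, for CombiningKind::ADD on f32 values with a mask provided, the
// helper materializes roughly the following (illustrative SSA names):
//
//   %sum = arith.addf %v1, %acc : f32
//   %res = arith.select %mask, %sum, %acc : f32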
6685 
6686 //===----------------------------------------------------------------------===//
6687 // Vector Masking Utilities
6688 //===----------------------------------------------------------------------===//
6689 
6690 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
6691 /// as the masked operation.
6692 void mlir::vector::createMaskOpRegion(OpBuilder &builder,
6693                                       Operation *maskableOp) {
6694   assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
6695   Block *insBlock = builder.getInsertionBlock();
6696   // Create a block and move the op to that block.
6697   insBlock->getOperations().splice(
6698       insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
6699   builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
6700 }
6701 
6702 /// Creates a vector.mask operation around a maskable operation. Returns the
6703 /// vector.mask operation if the mask provided is valid. Otherwise, returns
6704 /// the maskable operation itself.
6705 Operation *mlir::vector::maskOperation(OpBuilder &builder,
6706                                        Operation *maskableOp, Value mask,
6707                                        Value passthru) {
6708   if (!mask)
6709     return maskableOp;
6710   if (passthru)
6711     return builder.create<MaskOp>(maskableOp->getLoc(),
6712                                   maskableOp->getResultTypes(), mask, passthru,
6713                                   maskableOp, createMaskOpRegion);
6714   return builder.create<MaskOp>(maskableOp->getLoc(),
6715                                 maskableOp->getResultTypes(), mask, maskableOp,
6716                                 createMaskOpRegion);
6717 }
6718 
6719 /// Creates a vector select operation that picks values from `newValue` or
6720 /// `passthru` for each result vector lane based on `mask`. This utility is used
6721 /// to propagate the pass-thru value of vector.mask or for cases where only the
6722 /// pass-thru value propagation is needed. VP intrinsics do not support
6723 /// pass-thru values and every masked-out lane is set to poison. LLVM backends
6724 /// are usually able to match op + select patterns and fold them into native
6725 /// target instructions.
6726 Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
6727                                    Value newValue, Value passthru) {
6728   if (!mask)
6729     return newValue;
6730 
6731   return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
6732                                          mask, newValue, passthru);
6733 }
6734 
6735 //===----------------------------------------------------------------------===//
6736 // TableGen'd op method definitions
6737 //===----------------------------------------------------------------------===//
6738 
6739 #define GET_ATTRDEF_CLASSES
6740 #include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6741 
6742 #define GET_OP_CLASSES
6743 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
6744