xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (revision 5270df3d177ae8849e5f1a2cec28b193ab209ef2)
1 //===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.contract' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Utils/Utils.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26 #include "mlir/IR/BuiltinAttributeInterfaces.h"
27 #include "mlir/IR/BuiltinTypes.h"
28 #include "mlir/IR/ImplicitLocOpBuilder.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Interfaces/VectorInterfaces.h"
34 #include "mlir/Support/LogicalResult.h"
35 
36 #define DEBUG_TYPE "vector-contract-lowering"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 //===----------------------------------------------------------------------===//
42 // Helper functions
43 //===----------------------------------------------------------------------===//
44 
45 // Helper to find an index in an affine map.
46 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
47   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
48     int64_t idx = map.getDimPosition(i);
49     if (idx == index)
50       return i;
51   }
52   return std::nullopt;
53 }
54 
55 // Helper to construct iterator types with one index removed.
56 static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
57                                          int64_t index) {
58   SmallVector<Attribute> results;
59   for (const auto &it : llvm::enumerate(iteratorTypes)) {
60     int64_t idx = it.index();
61     if (idx == index)
62       continue;
63     results.push_back(it.value());
64   }
65   return results;
66 }
67 
68 // Helper to construct an affine map with one index removed.
69 static AffineMap adjustMap(AffineMap map, int64_t index,
70                            PatternRewriter &rewriter) {
71   auto *ctx = rewriter.getContext();
72   SmallVector<AffineExpr> results;
73   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
74     int64_t idx = map.getDimPosition(i);
75     if (idx == index)
76       continue;
77     // Re-insert remaining indices, but renamed when occurring
78     // after the removed index.
79     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
80     results.push_back(targetExpr);
81   }
82   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
83 }
84 
// Helper method to possibly drop a dimension in a load.
// Recursively peels leading dimensions of `val` until recursion depth
// `index` is reached, where position `pos` is extracted, yielding a value
// whose type is `type` with dimension `index` dropped. An `index` of -1
// means no dimension is dropped and `val` is returned unchanged.
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions.
  // vType: element type one level down; resType: result with `index` dropped.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  // Start from an all-zero vector and insert the recursively reshaped slices.
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}
109 
// Helper method to possibly drop a dimension in a store.
// Dual of reshapeLoad: recursively peels leading dimensions of `val` and
// `result` until depth `index`, where `val` is inserted into `result` at
// position `pos`. An `index` of -1 means no dimension is dropped and `val`
// is returned unchanged.
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}
132 
/// Helper to create arithmetic operation associated with a kind of contraction:
/// a multiply of `x` and `y` combined into `acc` with `kind`. Returns
/// std::nullopt when `kind` is incompatible with the element type (e.g. a
/// float-only combiner on integers, or vice versa). When `mask` is present,
/// masked-out lanes of the reduction preserve the previous `acc` value.
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(rewriter, mask, fma, acc);
      return fma;
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }

  // Without an accumulator the bare product is the result.
  if (!acc)
    return std::optional<Value>(mul);

  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
}
172 
173 /// Return the positions of the reductions in the given map.
174 static SmallVector<int64_t> getReductionIndex(AffineMap map,
175                                               ArrayAttr iteratorTypes) {
176   SmallVector<int64_t> dimsIdx;
177   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
178     if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
179       dimsIdx.push_back(i);
180   }
181   return dimsIdx;
182 }
183 
184 /// Look for a given dimension in an affine map and return its position. Return
185 /// std::nullopt if the dimension is not in the map results.
186 static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
187   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
188     if (map.getDimPosition(i) == dim)
189       return i;
190   }
191   return std::nullopt;
192 }
193 
194 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
195 /// operands `x` and `y`.
196 static Value createAdd(Location loc, Value x, Value y, bool isInt,
197                        PatternRewriter &rewriter) {
198   if (isInt)
199     return rewriter.create<arith::AddIOp>(loc, x, y);
200   return rewriter.create<arith::AddFOp>(loc, x, y);
201 }
202 
203 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
204 /// operands `x and `y`.
205 static Value createMul(Location loc, Value x, Value y, bool isInt,
206                        PatternRewriter &rewriter) {
207   if (isInt)
208     return rewriter.create<arith::MulIOp>(loc, x, y);
209   return rewriter.create<arith::MulFOp>(loc, x, y);
210 }
211 
212 namespace {
213 
214 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
215 /// semantics to:
216 /// ```
217 ///    %flattened_a = vector.shape_cast %a
218 ///    %flattened_b = vector.shape_cast %b
219 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %d = vector.shape_cast %flattened_d
221 ///    %e = add %c, %d
222 /// ```
223 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
224 //
225 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
226 /// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Callback that lets clients restrict which contractions are rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// User-provided predicate deciding whether a given op is rewritten.
  FilterConstraintType filter;
};
255 
256 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
257 /// semantics to a reduction_size-unrolled sequence:
258 /// ```
259 ///    %at = vector.transpose %a, [1, 0]
260 ///    %bRow0 = vector.extract %b[0]
261 ///    %atRow0 = vector.extract %at[0]
262 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
263 ///    ...
264 ///    %bRowK = vector.extract %b[K]
265 ///    %atRowK = vector.extract %at[K]
266 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
267 /// ```
268 ///
269 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
270 /// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Callback that lets clients restrict which contractions are rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// User-provided predicate deciding whether a given op is rewritten.
  FilterConstraintType filter;
};
299 
300 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
301 /// semantics to an output-size-unrolled sequence:
302 /// ```
303 ///    %out = arith.constant ... : vector<MxNxelt_type>
304 ///    %bt = vector.transpose %b, [1, 0]
305 ///    %aRow0 = vector.extract %a[0]
306 ///    %btRow0 = vector.extract %bt[0]
///    %c00 = vector.reduce %aRow0, %btRow0
///    %out00 = vector.insert %c00, %out[0, 0]
///    ...
///    %aRowLast = vector.extract %a[M-1]
///    %btRowLast = vector.extract %bt[N-1]
///    %cLastLast = vector.reduce %aRowLast, %btRowLast
///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
314 /// ```
315 ///
316 /// This only kicks in when VectorTransformsOptions is set to Dot and
317 /// the vector.contract op is a row-major matmul or matvec.
318 class ContractionOpToDotLowering
319     : public OpRewritePattern<vector::ContractionOp> {
320 public:
321   using OpRewritePattern::OpRewritePattern;
322 
323   using FilterConstraintType =
324       std::function<LogicalResult(vector::ContractionOp op)>;
325 
326   static LogicalResult defaultFilter(vector::ContractionOp op) {
327     return success();
328   }
329 
330   ContractionOpToDotLowering(
331       vector::VectorTransformsOptions vectorTransformOptions,
332       MLIRContext *context, PatternBenefit benefit = 1,
333       const FilterConstraintType &constraint = defaultFilter)
334       : OpRewritePattern<vector::ContractionOp>(context, benefit),
335         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
336 
337   LogicalResult matchAndRewrite(vector::ContractionOp op,
338                                 PatternRewriter &rewriter) const override;
339 
340 private:
341   /// Options to control the vector patterns.
342   vector::VectorTransformsOptions vectorTransformOptions;
343   FilterConstraintType filter;
344 };
345 
346 /// Progressive lowering of ContractionOp.
347 ///
348 /// One:
349 ///   %x = vector.contract with at least one free/batch dimension
350 /// is replaced by:
351 ///   %a = vector.contract with one less free/batch dimension
352 ///   %b = vector.contract with one less free/batch dimension
353 ///   ..
354 ///   %x = combine %a %b ..
355 /// until a pure contraction is reached (no free/batch dimensions),
356 /// which is replaced by a dot-product.
357 ///
358 /// This only kicks in when either VectorTransformsOptions is set
359 /// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  /// Callback that lets clients restrict which contractions are rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                        MLIRContext *context, PatternBenefit benefit = 1,
                        FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// User-provided predicate deciding whether a given op is rewritten.
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};
392 
393 /// Generate a vector implementation for matmat, matvec and tmatvec.
394 /// This unrolls outer-products along the reduction dimension.
/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    // Capture the mask when the contraction is wrapped in a vector.mask op.
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  /// Transpose `v` with permutation `perm`. A null `v` (e.g. an absent mask)
  /// is passed through unchanged.
  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  /// Promote `v` (scalar or vector) to element type `dstElementType` via
  /// arith.extf (float) or arith.extsi (int). No-op when types already match.
  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  /// Unroll the reduction dimension (of static size
  /// `lhsType.getDimSize(reductionDim)`) into a chain of vector.outerproduct
  /// ops accumulating into `res`. Callers pre-transpose `lhs`/`rhs` so the
  /// k-th slice is extracted along the leading dimension. Fails for scalable
  /// reduction dims, and when the op is masked but no mask was supplied.
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionDim,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Unrolling a scalable dimension would be incorrect - bail out.
    if (lhsType.getScalableDims()[reductionDim])
      return failure();

    int reductionSize = lhsType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");

    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    // The mask's reduction dim must become outermost for outerproduct.
    Value transposedMask = t(mask, {2, 0, 1});
    // Classical row-major matmul:  Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
                       transposedMask);
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
                       transposedMask);
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
                       transposedMask);
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
                       transposedMask);
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);
    Value transposedMask = t(mask);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
                       transposedMask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
                       transposedMask);
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
    return failure();
  }

private:
  vector::CombiningKind kind;
  // Contraction operands/accumulator; `mask` is null when the op is unmasked.
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};
567 
568 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
569 /// semantics to a reduction_size-unrolled sequence:
570 /// ```
571 ///    %at = vector.transpose %a, [1, 0]
572 ///    %bRow0 = vector.extract %b[0]
573 ///    %atRow0 = vector.extract %at[0]
574 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
575 ///    ...
576 ///    %bRowK = vector.extract %b[K]
577 ///    %atRowK = vector.extract %at[K]
578 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
579 /// ```
580 ///
581 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
582 /// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // Only fire when the client requested the outer-product lowering.
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  // Vector mask setup: when the contraction is wrapped in a vector.mask op,
  // that op is both the insertion point and the root op to replace.
  OpBuilder::InsertionGuard guard(rewriter);
  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
  Operation *rootOp;
  if (maskableOp.isMasked()) {
    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
    rootOp = maskableOp.getMaskingOp();
  } else {
    rootOp = op;
  }

  // Try each supported contraction flavor in turn; the first that matches
  // the op's iterator structure and layout wins.
  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(rootOp, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(rootOp, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(rootOp, *tmatvecRes);
    return success();
  }

  return failure();
}
622 
LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  if (failed(filter(op)))
    return failure();

  // Only fire when the client requested the dot-product lowering.
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    // Normalize operands toward the {m, k} x {n, k} -> {m, n} form.
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed output: swap RHS and LHS; the new lhs gets permuted.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  // Each result element is a mul + reduce (dot product) of a lhs row with a
  // rhs row (rhs was normalized above so its rows are the reduction vectors).
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  // Fold the accumulator in last, after all dot products are materialized.
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}
738 
739 /// Lower vector.contract with all size one reduction dimensions to
740 /// elementwise ops when possible.
741 struct ContractOpToElementwise
742     : public OpRewritePattern<vector::ContractionOp> {
743   using OpRewritePattern::OpRewritePattern;
744   using FilterConstraintType =
745       std::function<LogicalResult(vector::ContractionOp op)>;
746   static LogicalResult defaultFilter(vector::ContractionOp op) {
747     return success();
748   }
749   ContractOpToElementwise(
750       vector::VectorTransformsOptions vectorTransformOptions,
751       MLIRContext *context, PatternBenefit benefit = 1,
752       const FilterConstraintType &constraint = defaultFilter)
753       : OpRewritePattern<vector::ContractionOp>(context, benefit),
754         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
755 
756   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
757                                 PatternRewriter &rewriter) const override {
758     // TODO: Support vector.mask.
759     auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
760     if (maskableOp.isMasked())
761       return failure();
762 
763     if (failed(filter(contractOp)))
764       return failure();
765 
766     if (vectorTransformOptions.vectorContractLowering !=
767         vector::VectorContractLowering::ParallelArith)
768       return failure();
769 
770     ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
771     ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
772     AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
773     AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
774     SmallVector<int64_t> lhsReductionDims =
775         getReductionIndex(lhsMap, contractOp.getIteratorTypes());
776     SmallVector<int64_t> rhsReductionDims =
777         getReductionIndex(rhsMap, contractOp.getIteratorTypes());
778     // All the reduction dimensions must be a size 1.
779     for (int64_t dim : lhsReductionDims) {
780       if (lhsShape[dim] != 1)
781         return failure();
782     }
783     for (int64_t dim : rhsReductionDims) {
784       if (rhsShape[dim] != 1)
785         return failure();
786     }
787     AffineMap accMap = contractOp.getIndexingMapsArray()[2];
788     unsigned numParallelDims = accMap.getNumResults();
789     unsigned numLhsDimToBroadcast =
790         numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
791     unsigned numRhsDimToBroadcast =
792         numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
793     SmallVector<int64_t> lhsDims;
794     SmallVector<int64_t> lhsTranspose;
795     SmallVector<int64_t> rhsDims;
796     SmallVector<int64_t> rhsTranspose;
797     for (int64_t dim : lhsReductionDims)
798       lhsTranspose.push_back(numLhsDimToBroadcast + dim);
799     for (int64_t dim : rhsReductionDims)
800       rhsTranspose.push_back(numRhsDimToBroadcast + dim);
801     // Loop through the parallel dimensions to calculate the dimensions to
802     // broadcast and to permute in order to extract only parallel dimensions.
803     for (unsigned i = 0; i < numParallelDims; i++) {
804       std::optional<unsigned> lhsDim =
805           getDimPosition(lhsMap, accMap.getDimPosition(i));
806       if (lhsDim) {
807         lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
808       } else {
809         // If the parallel dimension doesn't exist we will have to broadcast it.
810         lhsDims.push_back(
811             cast<VectorType>(contractOp.getResultType()).getDimSize(i));
812         lhsTranspose.push_back(lhsDims.size() - 1);
813       }
814       std::optional<unsigned> rhsDim =
815           getDimPosition(rhsMap, accMap.getDimPosition(i));
816       if (rhsDim) {
817         rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
818       } else {
819         // If the parallel dimension doesn't exist we will have to broadcast it.
820         rhsDims.push_back(
821             cast<VectorType>(contractOp.getResultType()).getDimSize(i));
822         rhsTranspose.push_back(rhsDims.size() - 1);
823       }
824     }
825     Value newLhs = contractOp.getLhs();
826     Value newRhs = contractOp.getRhs();
827     Location loc = contractOp.getLoc();
828     if (!lhsDims.empty()) {
829       lhsDims.append(lhsShape.begin(), lhsShape.end());
830       auto expandedType =
831           VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
832       newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
833     }
834     if (!rhsDims.empty()) {
835       rhsDims.append(rhsShape.begin(), rhsShape.end());
836       auto expandedType =
837           VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
838       newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
839     }
840     bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
841     newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
842     newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
843     SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
844     SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
845     newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
846     newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
847     std::optional<Value> result =
848         createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
849                               contractOp.getKind(), rewriter, isInt);
850     rewriter.replaceOp(contractOp, {*result});
851     return success();
852   }
853 
854 private:
855   /// Options to control the vector patterns.
856   vector::VectorTransformsOptions vectorTransformOptions;
857   FilterConstraintType filter;
858 };
859 
860 /// Progressive lowering of ContractionOp.
861 /// One:
862 ///   %x = vector.contract with at least one free/batch dimension
863 /// is replaced by:
864 ///   %a = vector.contract with one less free/batch dimension
865 ///   %b = vector.contract with one less free/batch dimension
866 ///   ..
867 ///   %x = combine %a %b ..
868 /// until a pure contraction is reached (no free/batch dimensions),
869 /// which is replaced by a dot-product.
870 ///
871 /// This only kicks in when either VectorTransformsOptions is set
872 /// to DOT or when other contraction patterns fail.
873 //
874 // TODO: break down into transpose/reshape/cast ops
875 //               when they become available to avoid code dup
876 // TODO: investigate lowering order impact on performance
LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                       PatternRewriter &rewriter) const {
  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  // Bail out unless LHS, RHS and the accumulator share one element type.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: the code below assumes the default contraction, make sure it supports
  // other kinds before enabling this lowering.
  if (op.getKind() != vector::CombiningKind::ADD) {
    return rewriter.notifyMatchFailure(
        op, "contractions other than 'add' not supported");
  }

  // TODO: implement benefits, cost models.
  // Try the more specific lowerings first; only if all of them fail does the
  // recursive dimension-by-dimension unrolling below kick in.
  MLIRContext *ctx = op.getContext();
  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
    return success();
  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
    return success();

  // Vector mask setup. For a masked contraction the mask op is the root to
  // replace, and new ops are inserted right before it.
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *rootOp = op;
  Value mask;
  if (op.isMasked()) {
    rewriter.setInsertionPoint(op.getMaskingOp());
    rootOp = op.getMaskingOp();
    mask = op.getMaskingOp().getMask();
  }

  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
    if (failed(newOp))
      return failure();
    rewriter.replaceOp(rootOp, *newOp);
    return success();
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  // A free dimension is any operand dimension that is not contracting.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
      if (failed(newOp))
        return failure();
      rewriter.replaceOp(rootOp, *newOp);
      return success();
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
      if (failed(newOp))
        return failure();
      rewriter.replaceOp(rootOp, *newOp);
      return success();
    }
  }

  // Lower the first remaining reduction dimension.
  if (!contractingDimMap.empty()) {
    auto newOp = lowerReduction(rewriter, op, mask);
    if (failed(newOp))
      return failure();
    rewriter.replaceOp(rootOp, *newOp);
    return success();
  }

  return failure();
}
978 
979 // Lower one parallel dimension.
980 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
981 // TODO: consider reusing existing contract unrolling
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index. A negative lhsIndex or
  // rhsIndex means the dimension does not appear on that operand; at least
  // one of the two must be nonnegative.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    // Unrolling a scalable dimension is impossible: its size is unknown.
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewrite is currently calling
  // lowerParallel also for the case of unit-size reduction dims appearing only
  // on one of LHS or RHS, not both. At the moment, such cases are created by
  // CastAwayContractionLeadingOneDim, so we need to either support that or
  // modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute, with the
  // unrolled dimension removed from each indexing map.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  // Start from a zero result and insert each slice's contraction into it.
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    // Slice out position d of the unrolled dimension from each operand.
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    // If masked, slice the mask along the same dimension as well.
    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}
1066 
1067 // Lower one reduction dimension.
1068 FailureOr<Value> ContractionOpLowering::lowerReduction(
1069     PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
1070   auto loc = op.getLoc();
1071   VectorType lhsType = op.getLhsType();
1072   VectorType rhsType = op.getRhsType();
1073   Type resType = op.getResultType();
1074   if (isa<VectorType>(resType))
1075     return rewriter.notifyMatchFailure(op,
1076                                        "did not expect a VectorType result");
1077   bool isInt = isa<IntegerType>(resType);
1078   // Use iterator index 0.
1079   int64_t iterIndex = 0;
1080   SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1081   std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1082   std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1083   if (!lookupLhs.has_value())
1084     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1085       diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1086     });
1087   if (!lookupRhs.has_value())
1088     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1089       diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1090     });
1091   int64_t lhsIndex = *lookupLhs;
1092   int64_t rhsIndex = *lookupRhs;
1093   int64_t dimSize = lhsType.getDimSize(lhsIndex);
1094   if (dimSize != rhsType.getDimSize(rhsIndex))
1095     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1096       diag << "expect LHS dimension " << lhsIndex
1097            << " to have the same size as RHS dimension " << rhsIndex;
1098     });
1099   // Base case.
1100   if (lhsType.getRank() == 1) {
1101     if (rhsType.getRank() != 1)
1102       return rewriter.notifyMatchFailure(
1103           op, "When LHS has rank 1, expected also RHS to have rank 1");
1104     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1105     auto kind = vector::CombiningKind::ADD;
1106 
1107     Value acc = op.getAcc();
1108     Operation *reductionOp =
1109         acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1110             : rewriter.create<vector::ReductionOp>(loc, kind, m);
1111     return maskOperation(rewriter, reductionOp, mask)->getResult(0);
1112   }
1113   // Construct new iterator types and affine map array attribute.
1114   std::array<AffineMap, 3> lowIndexingMaps = {
1115       adjustMap(iMap[0], iterIndex, rewriter),
1116       adjustMap(iMap[1], iterIndex, rewriter),
1117       adjustMap(iMap[2], iterIndex, rewriter)};
1118   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1119   auto lowIter =
1120       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1121   // Unroll into a series of lower dimensional vector.contract ops.
1122   // By feeding the initial accumulator into the first contraction,
1123   // and the result of each contraction into the next, eventually
1124   // the sum of all reductions is computed.
1125   Value result = op.getAcc();
1126   for (int64_t d = 0; d < dimSize; ++d) {
1127     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1128     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1129     Value newMask;
1130     if (mask)
1131       newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1132                             iterIndex, d, rewriter);
1133 
1134     Operation *newContract = rewriter.create<vector::ContractionOp>(
1135         loc, lhs, rhs, result, lowAffine, lowIter);
1136     result = maskOperation(rewriter, newContract, newMask)->getResult(0);
1137   }
1138   return result;
1139 }
1140 
1141 /// Progressive lowering of OuterProductOp.
1142 /// One:
1143 ///   %x = vector.outerproduct %lhs, %rhs, %acc
1144 /// is replaced by:
1145 ///   %z = zero-result
1146 ///   %0 = vector.extract %lhs[0]
1147 ///   %1 = vector.broadcast %0
1148 ///   %2 = vector.extract %acc[0]
1149 ///   %3 = vector.fma %1, %rhs, %2
1150 ///   %4 = vector.insert %3, %z[0]
1151 ///   ..
1152 ///   %x = vector.insert %.., %..[N-1]
1153 ///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resType = op.getResultVectorType();
    // Unrolling along dim 0 below requires a fixed trip count, so bail out
    // when every result dimension is scalable (rank >= 2).
    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
      return failure();

    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    // RHS may be a scalar (AXPY case), in which case this cast yields null.
    VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
    Type eltType = resType.getElementType();
    bool isInt = isa<IntegerType, IndexType>(eltType);
    Value acc = op.getAcc();
    vector::CombiningKind kind = op.getKind();

    // Vector mask setup. For a masked outerproduct the mask op is the root to
    // replace, and new ops are inserted right before it.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = op;
    }

    if (!rhsType) {
      // Special case: AXPY operation. Broadcast the scalar RHS to the LHS
      // vector type and combine in a single step.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      std::optional<Value> mult = createContractArithOp(
          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(rootOp, *mult);
      return success();
    }

    // General case: unroll along dim 0, producing one broadcast + combine
    // (FMA for floats — see the class comment) per LHS element, and insert
    // each row into an initially-zero result.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, acc, d);
      Value extrMask;
      if (mask)
        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);

      std::optional<Value> m = createContractArithOp(
          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
      if (!m.has_value())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
1220 
1221 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1222 /// semantics to:
1223 /// ```
1224 ///    %mta = maybe_transpose
1225 ///    %mtb = maybe_transpose
1226 ///    %flattened_a = vector.shape_cast %mta
1227 ///    %flattened_b = vector.shape_cast %mtb
1228 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1229 ///    %mtd = vector.shape_cast %flattened_d
1230 ///    %d = maybe_untranspose %mtd
1231 ///    %e = add %c, %d
1232 /// ```
1233 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1234 //
1235 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1236 /// vector.transpose operations are inserted if the vector.contract op is not a
1237 /// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  // This pattern only applies when the Matmul lowering was requested.
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  // Expect exactly (parallel, parallel, reduction) iterators, i.e. a plain
  // (m, n, k) matmul-shaped contraction.
  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Mixed element types between operands and result are not supported here.
  Type dstElementType = op.getType();
  if (auto vecType = dyn_cast<VectorType>(dstElementType))
    dstElementType = vecType.getElementType();
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  // vector.matmul operates on flattened 1-D vectors, so shape_cast both
  // operands down to rank 1.
  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  // Cast the flat matmul result back to the 2-D (m, n) shape.
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m). Any other map would have been rejected
  // by the contraction verifier, hence the unreachable below.
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  // Finally add the accumulator (integer or float add based on element type).
  Value res =
      isa<IntegerType>(elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  rew.replaceOp(op, res);
  return success();
}
1329 } // namespace
1330 
1331 void mlir::vector::populateVectorContractLoweringPatterns(
1332     RewritePatternSet &patterns, VectorTransformsOptions options,
1333     PatternBenefit benefit, bool disableOuterProductLowering) {
1334   if (!disableOuterProductLowering)
1335     patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1336   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
1337                ContractionOpToOuterProductOpLowering>(
1338       options, patterns.getContext(), benefit);
1339 }
1340 
1341 void mlir::vector::populateVectorOuterProductLoweringPatterns(
1342     RewritePatternSet &patterns, PatternBenefit benefit) {
1343   patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1344 }
1345