xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (revision 16b75cd2bb439633d29c99a7663f2586e4068ecf)
1 //===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.contract' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Utils/Utils.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26 #include "mlir/IR/BuiltinAttributeInterfaces.h"
27 #include "mlir/IR/BuiltinTypes.h"
28 #include "mlir/IR/ImplicitLocOpBuilder.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Interfaces/VectorInterfaces.h"
34 #include "mlir/Support/LogicalResult.h"
35 
36 #define DEBUG_TYPE "vector-contract-lowering"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 //===----------------------------------------------------------------------===//
42 // Helper functions
43 //===----------------------------------------------------------------------===//
44 
45 // Helper to find an index in an affine map.
46 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
47   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
48     int64_t idx = map.getDimPosition(i);
49     if (idx == index)
50       return i;
51   }
52   return std::nullopt;
53 }
54 
55 // Helper to construct iterator types with one index removed.
56 static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
57                                          int64_t index) {
58   SmallVector<Attribute> results;
59   for (const auto &it : llvm::enumerate(iteratorTypes)) {
60     int64_t idx = it.index();
61     if (idx == index)
62       continue;
63     results.push_back(it.value());
64   }
65   return results;
66 }
67 
68 // Helper to construct an affine map with one index removed.
69 static AffineMap adjustMap(AffineMap map, int64_t index,
70                            PatternRewriter &rewriter) {
71   auto *ctx = rewriter.getContext();
72   SmallVector<AffineExpr> results;
73   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
74     int64_t idx = map.getDimPosition(i);
75     if (idx == index)
76       continue;
77     // Re-insert remaining indices, but renamed when occurring
78     // after the removed index.
79     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
80     results.push_back(targetExpr);
81   }
82   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
83 }
84 
// Helper method to possibly drop a dimension in a load.
// Recursively rewrites `val` (of vector type `type`) so that dimension
// `index` is replaced by the single slice at position `pos`, producing a
// value of one lower rank. An `index` of -1 means "do not reshape".
// TODO: document/assert the preconditions on `index` and `pos`.
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // Vector type with the leading dimension stripped.
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, pos);
  // Unroll leading dimensions.
  VectorType vType = cast<VectorType>(lowType);
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = cast<VectorType>(resType);
  // Start from an all-zero vector and insert the reshaped sub-vectors.
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, d);
    // Recurse on each slice; the dimension to drop shifts down by one.
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result =
        rewriter.create<vector::InsertOp>(loc, resVectorType, load, result, d);
  }
  return result;
}
110 
// Helper method to possibly drop a dimension in a store.
// Inverse of reshapeLoad: recursively inserts `val` into `result` (of vector
// type `type`) at position `pos` along dimension `index`, re-adding the
// dimension that was dropped on the load side. An `index` of -1 means "do
// not reshape".
// TODO: document/assert the preconditions on `index` and `pos`.
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, type, val, result, pos);
  // Unroll leading dimensions.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = cast<VectorType>(lowType);
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    // Walk matching slices of the destination and the value being stored,
    // recursing until the insertion dimension is reached.
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, d);
  }
  return result;
}
134 
/// Helper to create arithmetic operation associated with a kind of contraction.
/// Multiplies `x` and `y` and folds the product into `acc` with combiner
/// `kind`. Returns std::nullopt when `kind` is incompatible with the element
/// type (a float-only combiner on integers or vice versa). When `acc` is
/// null, the bare product is returned. `mask`, when present, preserves the
/// previous `acc` lanes that are masked out.
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(rewriter, mask, fma, acc);
      return fma;
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }

  // No accumulator: the (unreduced) product is the result.
  if (!acc)
    return std::optional<Value>(mul);

  // Combine the product into the accumulator according to `kind`.
  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
}
173 
174 /// Return the positions of the reductions in the given map.
175 static SmallVector<int64_t> getReductionIndex(AffineMap map,
176                                               ArrayAttr iteratorTypes) {
177   SmallVector<int64_t> dimsIdx;
178   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
179     if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
180       dimsIdx.push_back(i);
181   }
182   return dimsIdx;
183 }
184 
185 /// Look for a given dimension in an affine map and return its position. Return
186 /// std::nullopt if the dimension is not in the map results.
187 static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
188   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
189     if (map.getDimPosition(i) == dim)
190       return i;
191   }
192   return std::nullopt;
193 }
194 
195 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
196 /// operands `x` and `y`.
197 static Value createAdd(Location loc, Value x, Value y, bool isInt,
198                        PatternRewriter &rewriter) {
199   if (isInt)
200     return rewriter.create<arith::AddIOp>(loc, x, y);
201   return rewriter.create<arith::AddFOp>(loc, x, y);
202 }
203 
204 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
205 /// operands `x and `y`.
206 static Value createMul(Location loc, Value x, Value y, bool isInt,
207                        PatternRewriter &rewriter) {
208   if (isInt)
209     return rewriter.create<arith::MulIOp>(loc, x, y);
210   return rewriter.create<arith::MulFOp>(loc, x, y);
211 }
212 
213 namespace {
214 
215 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
216 /// semantics to:
217 /// ```
218 ///    %flattened_a = vector.shape_cast %a
219 ///    %flattened_b = vector.shape_cast %b
220 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %d = vector.shape_cast %flattened_d
222 ///    %e = add %c, %d
223 /// ```
224 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to Matmul and
227 /// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Predicate deciding whether a given contraction should be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  /// `vectorTransformOptions` selects the lowering strategy; the pattern only
  /// fires for the strategy it implements. `constraint` lets callers restrict
  /// which ops are rewritten.
  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// User-provided constraint on which ops to rewrite.
  FilterConstraintType filter;
};
256 
257 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
258 /// semantics to a reduction_size-unrolled sequence:
259 /// ```
260 ///    %at = vector.transpose %a, [1, 0]
261 ///    %bRow0 = vector.extract %b[0]
262 ///    %atRow0 = vector.extract %at[0]
263 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
264 ///    ...
265 ///    %bRowK = vector.extract %b[K]
266 ///    %atRowK = vector.extract %at[K]
267 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
268 /// ```
269 ///
270 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
271 /// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Predicate deciding whether a given contraction should be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  /// `vectorTransformOptions` selects the lowering strategy; the pattern only
  /// fires for the strategy it implements. `constraint` lets callers restrict
  /// which ops are rewritten.
  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// User-provided constraint on which ops to rewrite.
  FilterConstraintType filter;
};
300 
301 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
302 /// semantics to an output-size-unrolled sequence:
303 /// ```
304 ///    %out = arith.constant ... : vector<MxNxelt_type>
305 ///    %bt = vector.transpose %b, [1, 0]
306 ///    %aRow0 = vector.extract %a[0]
307 ///    %btRow0 = vector.extract %bt[0]
///    %c00 = vector.reduce %aRow0, %btRow0
///    %out00 = vector.insert %c00, %out[0, 0]
///    ...
///    %aRowLast = vector.extract %a[M-1]
///    %btRowLast = vector.extract %bt[N-1]
///    %cLastLast = vector.reduce %aRowLast, %btRowLast
///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
315 /// ```
316 ///
317 /// This only kicks in when VectorTransformsOptions is set to Dot and
318 /// the vector.contract op is a row-major matmul or matvec.
319 class ContractionOpToDotLowering
320     : public OpRewritePattern<vector::ContractionOp> {
321 public:
322   using OpRewritePattern::OpRewritePattern;
323 
324   using FilterConstraintType =
325       std::function<LogicalResult(vector::ContractionOp op)>;
326 
327   static LogicalResult defaultFilter(vector::ContractionOp op) {
328     return success();
329   }
330 
331   ContractionOpToDotLowering(
332       vector::VectorTransformsOptions vectorTransformOptions,
333       MLIRContext *context, PatternBenefit benefit = 1,
334       const FilterConstraintType &constraint = defaultFilter)
335       : OpRewritePattern<vector::ContractionOp>(context, benefit),
336         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
337 
338   LogicalResult matchAndRewrite(vector::ContractionOp op,
339                                 PatternRewriter &rewriter) const override;
340 
341 private:
342   /// Options to control the vector patterns.
343   vector::VectorTransformsOptions vectorTransformOptions;
344   FilterConstraintType filter;
345 };
346 
347 /// Progressive lowering of ContractionOp.
348 ///
349 /// One:
350 ///   %x = vector.contract with at least one free/batch dimension
351 /// is replaced by:
352 ///   %a = vector.contract with one less free/batch dimension
353 ///   %b = vector.contract with one less free/batch dimension
354 ///   ..
355 ///   %x = combine %a %b ..
356 /// until a pure contraction is reached (no free/batch dimensions),
357 /// which is replaced by a dot-product.
358 ///
359 /// This only kicks in when either VectorTransformsOptions is set
360 /// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  /// Predicate deciding whether a given contraction should be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  /// `vectorTransformOptions` selects the lowering strategy; `constraint`
  /// lets callers restrict which ops are rewritten.
  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                        MLIRContext *context, PatternBenefit benefit = 1,
                        FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// User-provided constraint on which ops to rewrite.
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};
393 
/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    // Capture the mask, if any, so the unrolled outer products can be masked.
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  /// Transpose `v` with permutation `perm` (default: 2-D transpose). A null
  /// value passes through unchanged, which lets an absent mask flow through
  /// the layout-normalization calls below.
  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  /// Extend `v` (scalar or vector) to `dstElementType`: extf for float
  /// destinations, extsi otherwise. No-op when the element type already
  /// matches.
  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = VectorType::get(vecType.getShape(), promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  /// Emit a chain of `reductionSize` outer products accumulating into `res`.
  /// `lhs` and `rhs` must already have the reduction dimension outermost.
  /// `maybeMask`, when it holds a value, is the contraction mask (already
  /// permuted to match) whose leading-dim slices mask each outer product.
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    assert(reductionSize > 0);
    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      // Slice out step k of the reduction from both operands.
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    // Classical row-major matmul:  Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
                       t(mask, {2, 0, 1}));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One outer parallel, one inner reduction (matvec flavor)
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor)
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

private:
  /// Combining kind of the contraction (add, mul, min, ...).
  vector::CombiningKind kind;
  /// Operands, accumulator/result and optional mask of the contraction.
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};
542 
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // Bail out unless this lowering strategy was requested.
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  // Vector mask setup.
  OpBuilder::InsertionGuard guard(rewriter);
  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
  Operation *rootOp;
  if (maskableOp.isMasked()) {
    // When masked, the enclosing vector.mask op is the op to replace, and new
    // code must be created before it.
    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
    rootOp = maskableOp.getMaskingOp();
  } else {
    rootOp = op;
  }

  // Try each contraction flavor in turn; the first whose iterator/layout
  // structure matches produces the replacement value.
  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(rootOp, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(rootOp, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(rootOp, *tmatvecRes);
    return success();
  }

  return failure();
}
597 
/// Lower a vector.contract to an unrolled sequence of elementwise multiplies
/// and vector.reduction ops, one per element of the (rank-1 or rank-2)
/// result. Operands are first transposed/swapped so the reduction dimension
/// is innermost on both sides.
LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  if (failed(filter(op)))
    return failure();

  // Bail out unless this lowering strategy was requested.
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // Transpose the rhs so its rows also have k innermost.
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed output: swap the operands and transpose the new lhs so
      // both sides have k innermost.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      // For a rank-1 result the rhs is itself the single "column" vector.
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      // Elementwise multiply, then add-reduce into one scalar of the result.
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  // Fold the accumulator in last, with a single elementwise add.
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}
713 
714 /// Lower vector.contract with all size one reduction dimensions to
715 /// elementwise ops when possible.
716 struct ContractOpToElementwise
717     : public OpRewritePattern<vector::ContractionOp> {
718   using OpRewritePattern::OpRewritePattern;
719   using FilterConstraintType =
720       std::function<LogicalResult(vector::ContractionOp op)>;
721   static LogicalResult defaultFilter(vector::ContractionOp op) {
722     return success();
723   }
724   ContractOpToElementwise(
725       vector::VectorTransformsOptions vectorTransformOptions,
726       MLIRContext *context, PatternBenefit benefit = 1,
727       const FilterConstraintType &constraint = defaultFilter)
728       : OpRewritePattern<vector::ContractionOp>(context, benefit),
729         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
730 
731   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
732                                 PatternRewriter &rewriter) const override {
733     // TODO: Support vector.mask.
734     auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
735     if (maskableOp.isMasked())
736       return failure();
737 
738     if (failed(filter(contractOp)))
739       return failure();
740 
741     if (vectorTransformOptions.vectorContractLowering !=
742         vector::VectorContractLowering::ParallelArith)
743       return failure();
744 
745     ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
746     ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
747     AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
748     AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
749     SmallVector<int64_t> lhsReductionDims =
750         getReductionIndex(lhsMap, contractOp.getIteratorTypes());
751     SmallVector<int64_t> rhsReductionDims =
752         getReductionIndex(rhsMap, contractOp.getIteratorTypes());
753     // All the reduction dimensions must be a size 1.
754     for (int64_t dim : lhsReductionDims) {
755       if (lhsShape[dim] != 1)
756         return failure();
757     }
758     for (int64_t dim : rhsReductionDims) {
759       if (rhsShape[dim] != 1)
760         return failure();
761     }
762     AffineMap accMap = contractOp.getIndexingMapsArray()[2];
763     unsigned numParallelDims = accMap.getNumResults();
764     unsigned numLhsDimToBroadcast =
765         numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
766     unsigned numRhsDimToBroadcast =
767         numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
768     SmallVector<int64_t> lhsDims;
769     SmallVector<int64_t> lhsTranspose;
770     SmallVector<int64_t> rhsDims;
771     SmallVector<int64_t> rhsTranspose;
772     for (int64_t dim : lhsReductionDims)
773       lhsTranspose.push_back(numLhsDimToBroadcast + dim);
774     for (int64_t dim : rhsReductionDims)
775       rhsTranspose.push_back(numRhsDimToBroadcast + dim);
776     // Loop through the parallel dimensions to calculate the dimensions to
777     // broadcast and to permute in order to extract only parallel dimensions.
778     for (unsigned i = 0; i < numParallelDims; i++) {
779       std::optional<unsigned> lhsDim =
780           getDimPosition(lhsMap, accMap.getDimPosition(i));
781       if (lhsDim) {
782         lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
783       } else {
784         // If the parallel dimension doesn't exist we will have to broadcast it.
785         lhsDims.push_back(
786             cast<VectorType>(contractOp.getResultType()).getDimSize(i));
787         lhsTranspose.push_back(lhsDims.size() - 1);
788       }
789       std::optional<unsigned> rhsDim =
790           getDimPosition(rhsMap, accMap.getDimPosition(i));
791       if (rhsDim) {
792         rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
793       } else {
794         // If the parallel dimension doesn't exist we will have to broadcast it.
795         rhsDims.push_back(
796             cast<VectorType>(contractOp.getResultType()).getDimSize(i));
797         rhsTranspose.push_back(rhsDims.size() - 1);
798       }
799     }
800     Value newLhs = contractOp.getLhs();
801     Value newRhs = contractOp.getRhs();
802     Location loc = contractOp.getLoc();
803     if (!lhsDims.empty()) {
804       lhsDims.append(lhsShape.begin(), lhsShape.end());
805       auto expandedType =
806           VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
807       newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
808     }
809     if (!rhsDims.empty()) {
810       rhsDims.append(rhsShape.begin(), rhsShape.end());
811       auto expandedType =
812           VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
813       newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
814     }
815     bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
816     newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
817     newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
818     SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
819     SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
820     newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
821     newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
822     std::optional<Value> result =
823         createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
824                               contractOp.getKind(), rewriter, isInt);
825     rewriter.replaceOp(contractOp, {*result});
826     return success();
827   }
828 
829 private:
830   /// Options to control the vector patterns.
831   vector::VectorTransformsOptions vectorTransformOptions;
832   FilterConstraintType filter;
833 };
834 
835 /// Progressive lowering of ContractionOp.
836 /// One:
837 ///   %x = vector.contract with at least one free/batch dimension
838 /// is replaced by:
839 ///   %a = vector.contract with one less free/batch dimension
840 ///   %b = vector.contract with one less free/batch dimension
841 ///   ..
842 ///   %x = combine %a %b ..
843 /// until a pure contraction is reached (no free/batch dimensions),
844 /// which is replaced by a dot-product.
845 ///
846 /// This only kicks in when either VectorTransformsOptions is set
847 /// to DOT or when other contraction patterns fail.
848 //
849 // TODO: break down into transpose/reshape/cast ops
850 //               when they become available to avoid code dup
851 // TODO: investigate lowering order impact on performance
852 LogicalResult
853 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
854                                        PatternRewriter &rewriter) const {
855   if (failed(filter(op)))
856     return failure();
857 
858   // TODO: support mixed mode contract lowering.
859   if (op.getLhsType().getElementType() !=
860           getElementTypeOrSelf(op.getAccType()) ||
861       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
862     return failure();
863 
864   // TODO: the code below assumes the default contraction, make sure it supports
865   // other kinds before enabling this lowering.
866   if (op.getKind() != vector::CombiningKind::ADD) {
867     return rewriter.notifyMatchFailure(
868         op, "contractions other than 'add' not supported");
869   }
870 
871   // TODO: implement benefits, cost models.
872   MLIRContext *ctx = op.getContext();
873   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
874   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
875     return success();
876   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
877   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
878     return success();
879   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
880   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
881     return success();
882   ContractOpToElementwise pat4(vectorTransformOptions, ctx);
883   if (succeeded(pat4.matchAndRewrite(op, rewriter)))
884     return success();
885 
886   // Vector mask setup.
887   OpBuilder::InsertionGuard guard(rewriter);
888   Operation *rootOp = op;
889   Value mask;
890   if (op.isMasked()) {
891     rewriter.setInsertionPoint(op.getMaskingOp());
892     rootOp = op.getMaskingOp();
893     mask = op.getMaskingOp().getMask();
894   }
895 
896   // Find first batch dimension in LHS/RHS, and lower when found.
897   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
898   if (!batchDimMap.empty()) {
899     int64_t lhsIndex = batchDimMap[0].first;
900     int64_t rhsIndex = batchDimMap[0].second;
901     auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
902     if (failed(newOp))
903       return failure();
904     rewriter.replaceOp(rootOp, *newOp);
905     return success();
906   }
907 
908   // Collect contracting dimensions.
909   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
910       op.getContractingDimMap();
911   DenseSet<int64_t> lhsContractingDimSet;
912   DenseSet<int64_t> rhsContractingDimSet;
913   for (auto &dimPair : contractingDimMap) {
914     lhsContractingDimSet.insert(dimPair.first);
915     rhsContractingDimSet.insert(dimPair.second);
916   }
917 
918   // Find first free dimension in LHS, and lower when found.
919   VectorType lhsType = op.getLhsType();
920   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
921     if (lhsContractingDimSet.count(lhsIndex) == 0) {
922       auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
923       if (failed(newOp))
924         return failure();
925       rewriter.replaceOp(rootOp, *newOp);
926       return success();
927     }
928   }
929 
930   // Find first free dimension in RHS, and lower when found.
931   VectorType rhsType = op.getRhsType();
932   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
933     if (rhsContractingDimSet.count(rhsIndex) == 0) {
934       auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
935       if (failed(newOp))
936         return failure();
937       rewriter.replaceOp(rootOp, *newOp);
938       return success();
939     }
940   }
941 
942   // Lower the first remaining reduction dimension.
943   if (!contractingDimMap.empty()) {
944     auto newOp = lowerReduction(rewriter, op, mask);
945     if (failed(newOp))
946       return failure();
947     rewriter.replaceOp(rootOp, *newOp);
948     return success();
949   }
950 
951   return failure();
952 }
953 
// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
//
// Unrolls the iteration-space dimension referenced by lhsIndex/rhsIndex
// (either may be -1 when the dim does not appear on that operand) into a
// series of rank-reduced vector.contract ops, re-assembling the results
// with reshapeStore. Returns the final assembled value.
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    // When present on both operands, LHS and RHS must agree on which
    // iteration-space dim this is.
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewrite is currently calling
  // lowerParallel also for the case of unit-size reduction dims appearing only
  // on one of LHS or RHS, not both. At the moment, such cases are created by
  // CastAwayContractionLeadingOneDim, so we need to either support that or
  // modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  // adjustMap/adjustIter drop the unrolled dim so the inner contracts have
  // one fewer iteration-space dimension.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  // Start from a zero vector and insert each slice's result into it.
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    // Slice the mask along the same dim, if present.
    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}
1031 
1032 // Lower one reduction dimension.
1033 FailureOr<Value> ContractionOpLowering::lowerReduction(
1034     PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
1035   auto loc = op.getLoc();
1036   VectorType lhsType = op.getLhsType();
1037   VectorType rhsType = op.getRhsType();
1038   Type resType = op.getResultType();
1039   if (isa<VectorType>(resType))
1040     return rewriter.notifyMatchFailure(op,
1041                                        "did not expect a VectorType result");
1042   bool isInt = isa<IntegerType>(resType);
1043   // Use iterator index 0.
1044   int64_t iterIndex = 0;
1045   SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1046   std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1047   std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1048   if (!lookupLhs.has_value())
1049     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1050       diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1051     });
1052   if (!lookupRhs.has_value())
1053     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1054       diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1055     });
1056   int64_t lhsIndex = *lookupLhs;
1057   int64_t rhsIndex = *lookupRhs;
1058   int64_t dimSize = lhsType.getDimSize(lhsIndex);
1059   if (dimSize != rhsType.getDimSize(rhsIndex))
1060     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1061       diag << "expect LHS dimension " << lhsIndex
1062            << " to have the same size as RHS dimension " << rhsIndex;
1063     });
1064   // Base case.
1065   if (lhsType.getRank() == 1) {
1066     if (rhsType.getRank() != 1)
1067       return rewriter.notifyMatchFailure(
1068           op, "When LHS has rank 1, expected also RHS to have rank 1");
1069     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1070     auto kind = vector::CombiningKind::ADD;
1071 
1072     Value acc = op.getAcc();
1073     Operation *reductionOp =
1074         acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1075             : rewriter.create<vector::ReductionOp>(loc, kind, m);
1076     return maskOperation(rewriter, reductionOp, mask)->getResult(0);
1077   }
1078   // Construct new iterator types and affine map array attribute.
1079   std::array<AffineMap, 3> lowIndexingMaps = {
1080       adjustMap(iMap[0], iterIndex, rewriter),
1081       adjustMap(iMap[1], iterIndex, rewriter),
1082       adjustMap(iMap[2], iterIndex, rewriter)};
1083   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1084   auto lowIter =
1085       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1086   // Unroll into a series of lower dimensional vector.contract ops.
1087   // By feeding the initial accumulator into the first contraction,
1088   // and the result of each contraction into the next, eventually
1089   // the sum of all reductions is computed.
1090   Value result = op.getAcc();
1091   for (int64_t d = 0; d < dimSize; ++d) {
1092     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1093     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1094     Value newMask;
1095     if (mask)
1096       newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1097                             iterIndex, d, rewriter);
1098 
1099     Operation *newContract = rewriter.create<vector::ContractionOp>(
1100         loc, lhs, rhs, result, lowAffine, lowIter);
1101     result = maskOperation(rewriter, newContract, newMask)->getResult(0);
1102   }
1103   return result;
1104 }
1105 
1106 /// Progressive lowering of OuterProductOp.
1107 /// One:
1108 ///   %x = vector.outerproduct %lhs, %rhs, %acc
1109 /// is replaced by:
1110 ///   %z = zero-result
1111 ///   %0 = vector.extract %lhs[0]
1112 ///   %1 = vector.broadcast %0
1113 ///   %2 = vector.extract %acc[0]
1114 ///   %3 = vector.fma %1, %rhs, %2
1115 ///   %4 = vector.insert %3, %z[0]
1116 ///   ..
1117 ///   %x = vector.insert %.., %..[N-1]
1118 ///
1119 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1120 public:
1121   using OpRewritePattern::OpRewritePattern;
1122 
1123   LogicalResult matchAndRewrite(vector::OuterProductOp op,
1124                                 PatternRewriter &rewriter) const override {
1125     auto loc = op.getLoc();
1126 
1127     VectorType lhsType = op.getOperandVectorTypeLHS();
1128     VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1129     VectorType resType = op.getResultVectorType();
1130     Type eltType = resType.getElementType();
1131     bool isInt = isa<IntegerType, IndexType>(eltType);
1132     Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
1133     vector::CombiningKind kind = op.getKind();
1134 
1135     // Vector mask setup.
1136     OpBuilder::InsertionGuard guard(rewriter);
1137     auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1138     Operation *rootOp;
1139     Value mask;
1140     if (maskableOp.isMasked()) {
1141       rewriter.setInsertionPoint(maskableOp.getMaskingOp());
1142       rootOp = maskableOp.getMaskingOp();
1143       mask = maskableOp.getMaskingOp().getMask();
1144     } else {
1145       rootOp = op;
1146     }
1147 
1148     if (!rhsType) {
1149       // Special case: AXPY operation.
1150       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
1151       std::optional<Value> mult = createContractArithOp(
1152           loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
1153       if (!mult.has_value())
1154         return failure();
1155       rewriter.replaceOp(rootOp, *mult);
1156       return success();
1157     }
1158 
1159     Value result = rewriter.create<arith::ConstantOp>(
1160         loc, resType, rewriter.getZeroAttr(resType));
1161     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1162       Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
1163       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
1164       Value r = nullptr;
1165       if (acc)
1166         r = rewriter.create<vector::ExtractOp>(loc, acc, d);
1167       Value extrMask;
1168       if (mask)
1169         extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
1170 
1171       std::optional<Value> m = createContractArithOp(
1172           loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1173       if (!m.has_value())
1174         return failure();
1175       result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, d);
1176     }
1177 
1178     rewriter.replaceOp(rootOp, result);
1179     return success();
1180   }
1181 };
1182 
1183 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1184 /// semantics to:
1185 /// ```
1186 ///    %mta = maybe_transpose
1187 ///    %mtb = maybe_transpose
1188 ///    %flattened_a = vector.shape_cast %mta
1189 ///    %flattened_b = vector.shape_cast %mtb
1190 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1191 ///    %mtd = vector.shape_cast %flattened_d
1192 ///    %d = maybe_untranspose %mtd
1193 ///    %e = add %c, %d
1194 /// ```
1195 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1196 //
1197 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1198 /// vector.transpose operations are inserted if the vector.contract op is not a
1199 /// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  // This pattern only fires for the Matmul lowering strategy.
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  // Require the canonical (parallel, parallel, reduction) iterator structure
  // of a plain matmul.
  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Mixed element types (e.g. widening accumulation) are not supported here.
  Type dstElementType = op.getType();
  if (auto vecType = dyn_cast<VectorType>(dstElementType))
    dstElementType = vecType.getElementType();
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  // vector.matmul operates on flattened 1-D vectors; shape-cast both sides.
  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  // Restore the 2-D (m, n) shape of the product.
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  // Add the accumulator using the element-type-appropriate arith op.
  Value res =
      isa<IntegerType>(elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  rew.replaceOp(op, res);
  return success();
}
1291 } // namespace
1292 
void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options,
    PatternBenefit benefit, bool disableOuterProductLowering) {
  // OuterProductOpLowering can optionally be disabled, e.g. when a client
  // wants to keep vector.outerproduct ops intact.
  if (!disableOuterProductLowering)
    patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(
      options, patterns.getContext(), benefit);
}
1302 
// Registers only the vector.outerproduct lowering, without the
// vector.contract patterns above.
void mlir::vector::populateVectorOuterProductLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
}
1307