xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (revision fe8a62c46365f5ef0c15df2265bbf0026d0a4047)
1 //===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.contract' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Utils/Utils.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26 #include "mlir/IR/BuiltinAttributeInterfaces.h"
27 #include "mlir/IR/BuiltinTypes.h"
28 #include "mlir/IR/ImplicitLocOpBuilder.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Interfaces/VectorInterfaces.h"
34 #include "mlir/Support/LogicalResult.h"
35 
36 #define DEBUG_TYPE "vector-contract-lowering"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 //===----------------------------------------------------------------------===//
42 // Helper functions
43 //===----------------------------------------------------------------------===//
44 
45 // Helper to find an index in an affine map.
46 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
47   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
48     int64_t idx = map.getDimPosition(i);
49     if (idx == index)
50       return i;
51   }
52   return std::nullopt;
53 }
54 
55 // Helper to construct iterator types with one index removed.
56 static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
57                                          int64_t index) {
58   SmallVector<Attribute> results;
59   for (const auto &it : llvm::enumerate(iteratorTypes)) {
60     int64_t idx = it.index();
61     if (idx == index)
62       continue;
63     results.push_back(it.value());
64   }
65   return results;
66 }
67 
68 // Helper to construct an affine map with one index removed.
69 static AffineMap adjustMap(AffineMap map, int64_t index,
70                            PatternRewriter &rewriter) {
71   auto *ctx = rewriter.getContext();
72   SmallVector<AffineExpr> results;
73   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
74     int64_t idx = map.getDimPosition(i);
75     if (idx == index)
76       continue;
77     // Re-insert remaining indices, but renamed when occurring
78     // after the removed index.
79     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
80     results.push_back(targetExpr);
81   }
82   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
83 }
84 
// Helper method to possibly drop a dimension in a load.
// Extracts position `pos` from dimension `index` of vector `val` (of type
// `type`), producing a value with one fewer dimension. `index == -1` means no
// dimension is dropped and `val` is returned unchanged. Dimensions in front
// of `index` are peeled off one by one via extract/insert pairs and a
// recursive call on each sub-vector.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions: extract each sub-vector, reshape it
  // recursively, and insert it into a zero-initialized result of the
  // reduced type.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}
109 
// Helper method to possibly drop a dimension in a store.
// Inserts `val` at position `pos` of dimension `index` into `result` (of
// vector type `type`). `index == -1` means no dimension is dropped and `val`
// itself is returned. Dimensions in front of `index` are peeled off via
// extract/insert pairs and a recursive call on each sub-vector.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions: recurse on matching sub-vectors of `result`
  // and `val`, re-inserting each updated sub-vector.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}
132 
/// Helper to create arithmetic operation associated with a kind of contraction.
/// Multiplies `x` and `y` (integer or float ops selected by `isInt`) and
/// combines the product into `acc` using combining kind `kind`. Returns
/// std::nullopt when `kind` does not apply to the element type (float-only
/// kinds on integers and vice versa). When `acc` is null, the bare product is
/// returned. A non-null `mask` is threaded into the reduction so masked-out
/// lanes preserve their `acc` value.
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(rewriter, mask, fma, acc);
      return fma;
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }

  // No accumulator: the plain product is the result.
  if (!acc)
    return std::optional<Value>(mul);

  return makeArithReduction(rewriter, loc, kind, mul, acc,
                            /*fastmath=*/nullptr, mask);
}
173 
174 /// Return the positions of the reductions in the given map.
175 static SmallVector<int64_t> getReductionIndex(AffineMap map,
176                                               ArrayAttr iteratorTypes) {
177   SmallVector<int64_t> dimsIdx;
178   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
179     if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
180       dimsIdx.push_back(i);
181   }
182   return dimsIdx;
183 }
184 
185 /// Look for a given dimension in an affine map and return its position. Return
186 /// std::nullopt if the dimension is not in the map results.
187 static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
188   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
189     if (map.getDimPosition(i) == dim)
190       return i;
191   }
192   return std::nullopt;
193 }
194 
195 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
196 /// operands `x` and `y`.
197 static Value createAdd(Location loc, Value x, Value y, bool isInt,
198                        PatternRewriter &rewriter) {
199   if (isInt)
200     return rewriter.create<arith::AddIOp>(loc, x, y);
201   return rewriter.create<arith::AddFOp>(loc, x, y);
202 }
203 
204 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
205 /// operands `x and `y`.
206 static Value createMul(Location loc, Value x, Value y, bool isInt,
207                        PatternRewriter &rewriter) {
208   if (isInt)
209     return rewriter.create<arith::MulIOp>(loc, x, y);
210   return rewriter.create<arith::MulFOp>(loc, x, y);
211 }
212 
213 namespace {
214 
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %flattened_a = vector.shape_cast %a
///    %flattened_b = vector.shape_cast %b
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %d = vector.shape_cast %flattened_d
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
/// NOTE(review): "OuterProduct" above looks like a copy/paste slip — this
/// pattern emits vector.matmul, so the Matmul lowering option seems intended;
/// confirm against the out-of-view matchAndRewrite definition.
class ContractionOpToMatmulOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Extra constraint a contraction must satisfy before being rewritten.
  FilterConstraintType filter;
};
256 
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Extra constraint a contraction must satisfy before being rewritten.
  FilterConstraintType filter;
};
300 
301 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
302 /// semantics to an output-size-unrolled sequence:
303 /// ```
304 ///    %out = arith.constant ... : vector<MxNxelt_type>
305 ///    %bt = vector.transpose %b, [1, 0]
306 ///    %aRow0 = vector.extract %a[0]
307 ///    %btRow0 = vector.extract %bt[0]
308 ///    %c00 = vector.reduce %atRow0, %bRow0
309 ///    %out00 = vector.insert %c00, %out[0, 0]
310 ///    ...
311 ///    %aRowLast = vector.extract %at[M-1]
312 ///    %btRowLast = vector.extract %b[N-1]
313 ///    %cLastLast = vector.reduce %atRowLast, %bRowLast
314 ///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
315 /// ```
316 ///
317 /// This only kicks in when VectorTransformsOptions is set to Dot and
318 /// the vector.contract op is a row-major matmul or matvec.
319 class ContractionOpToDotLowering
320     : public OpRewritePattern<vector::ContractionOp> {
321 public:
322   using OpRewritePattern::OpRewritePattern;
323 
324   using FilterConstraintType =
325       std::function<LogicalResult(vector::ContractionOp op)>;
326 
327   static LogicalResult defaultFilter(vector::ContractionOp op) {
328     return success();
329   }
330 
331   ContractionOpToDotLowering(
332       vector::VectorTransformsOptions vectorTransformOptions,
333       MLIRContext *context, PatternBenefit benefit = 1,
334       const FilterConstraintType &constraint = defaultFilter)
335       : OpRewritePattern<vector::ContractionOp>(context, benefit),
336         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
337 
338   LogicalResult matchAndRewrite(vector::ContractionOp op,
339                                 PatternRewriter &rewriter) const override;
340 
341 private:
342   /// Options to control the vector patterns.
343   vector::VectorTransformsOptions vectorTransformOptions;
344   FilterConstraintType filter;
345 };
346 
/// Progressive lowering of ContractionOp.
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                        MLIRContext *context, PatternBenefit benefit = 1,
                        FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Extra constraint a contraction must satisfy before being rewritten.
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};
393 
/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    // Capture the mask when the contraction is wrapped in a masking op.
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  /// Transpose `v` with permutation `perm` (2-D swap by default). A null
  /// value passes through unchanged so optional masks can be forwarded.
  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  /// Promote `v` (scalar or vector) to element type `dstElementType` using
  /// extf for float destinations and extsi otherwise. No-op when the element
  /// types already match.
  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  /// Emit `reductionSize` vector.outerproduct ops that accumulate into `res`,
  /// extracting slice `k` of `lhs`, `rhs` (and of the mask, when provided) at
  /// each step. Fails when the contraction is masked but no (transposed) mask
  /// was supplied by the caller.
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      // Outer product accumulates in the result's element type.
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);

    // Classical row-major matmul:  Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when then pattern returns "failure".
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask);
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask);
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 1))
        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
    return failure();
  }

private:
  /// Combining kind of the contraction being lowered.
  vector::CombiningKind kind;
  /// Contraction operands, accumulator, and optional mask (null if unmasked).
  Value lhs, rhs, res, mask;
  /// Type of the LHS operand, used to query reduction-dimension sizes.
  VectorType lhsType;
};
621 
622 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
623 /// semantics to a reduction_size-unrolled sequence:
624 /// ```
625 ///    %at = vector.transpose %a, [1, 0]
626 ///    %bRow0 = vector.extract %b[0]
627 ///    %atRow0 = vector.extract %at[0]
628 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
629 ///    ...
630 ///    %bRowK = vector.extract %b[K]
631 ///    %atRowK = vector.extract %at[K]
632 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
633 /// ```
634 ///
635 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
636 /// otherwise supports any layout permutation of the matrix-multiply.
637 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
638     vector::ContractionOp op, PatternRewriter &rewriter) const {
639   if (vectorTransformOptions.vectorContractLowering !=
640       vector::VectorContractLowering::OuterProduct)
641     return failure();
642 
643   if (failed(filter(op)))
644     return failure();
645 
646   // Vector mask setup.
647   OpBuilder::InsertionGuard guard(rewriter);
648   auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
649   Operation *rootOp;
650   if (maskableOp.isMasked()) {
651     rewriter.setInsertionPoint(maskableOp.getMaskingOp());
652     rootOp = maskableOp.getMaskingOp();
653   } else {
654     rootOp = op;
655   }
656 
657   UnrolledOuterProductGenerator e(rewriter, op);
658   FailureOr<Value> matmatRes = e.matmat();
659   if (succeeded(matmatRes)) {
660     rewriter.replaceOp(rootOp, *matmatRes);
661     return success();
662   }
663   FailureOr<Value> matvecRes = e.matvec();
664   if (succeeded(matvecRes)) {
665     rewriter.replaceOp(rootOp, *matvecRes);
666     return success();
667   }
668   FailureOr<Value> tmatvecRes = e.tmatvec();
669   if (succeeded(tmatvecRes)) {
670     rewriter.replaceOp(rootOp, *tmatvecRes);
671     return success();
672   }
673 
674   return failure();
675 }
676 
/// Lower a row-major matmul/matvec `vector.contract` to explicit per-result
/// `vector.reduction <add>` dot products. Canonicalizes the operands so the
/// reduction dimension is innermost, then unrolls over the output elements.
LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, op.getContext());
  };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // This is the classical row-major matmul. Just permute the lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  // Rank-1 results (matvec) have a single logical "column".
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      // Elementwise multiply then reduce to a scalar: one dot product per
      // output element. (Note: `m` here shadows the AffineExpr above.)
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  // Fold in the accumulator, if any, after all dot products are materialized.
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}
794 
795 /// Lower vector.contract with all size one reduction dimensions to
796 /// elementwise ops when possible.
797 struct ContractOpToElementwise
798     : public OpRewritePattern<vector::ContractionOp> {
799   using OpRewritePattern::OpRewritePattern;
800   using FilterConstraintType =
801       std::function<LogicalResult(vector::ContractionOp op)>;
802   static LogicalResult defaultFilter(vector::ContractionOp op) {
803     return success();
804   }
805   ContractOpToElementwise(
806       vector::VectorTransformsOptions vectorTransformOptions,
807       MLIRContext *context, PatternBenefit benefit = 1,
808       const FilterConstraintType &constraint = defaultFilter)
809       : OpRewritePattern<vector::ContractionOp>(context, benefit),
810         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
811 
812   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
813                                 PatternRewriter &rewriter) const override {
814     // TODO: Support vector.mask.
815     auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
816     if (maskableOp.isMasked())
817       return failure();
818 
819     if (failed(filter(contractOp)))
820       return failure();
821 
822     if (vectorTransformOptions.vectorContractLowering !=
823         vector::VectorContractLowering::ParallelArith)
824       return failure();
825 
826     ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
827     ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
828     AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
829     AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
830     SmallVector<int64_t> lhsReductionDims =
831         getReductionIndex(lhsMap, contractOp.getIteratorTypes());
832     SmallVector<int64_t> rhsReductionDims =
833         getReductionIndex(rhsMap, contractOp.getIteratorTypes());
834     // All the reduction dimensions must be a size 1.
835     for (int64_t dim : lhsReductionDims) {
836       if (lhsShape[dim] != 1)
837         return failure();
838     }
839     for (int64_t dim : rhsReductionDims) {
840       if (rhsShape[dim] != 1)
841         return failure();
842     }
843     AffineMap accMap = contractOp.getIndexingMapsArray()[2];
844     unsigned numParallelDims = accMap.getNumResults();
845     unsigned numLhsDimToBroadcast =
846         numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
847     unsigned numRhsDimToBroadcast =
848         numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
849     SmallVector<int64_t> lhsDims;
850     SmallVector<int64_t> lhsTranspose;
851     SmallVector<int64_t> rhsDims;
852     SmallVector<int64_t> rhsTranspose;
853     for (int64_t dim : lhsReductionDims)
854       lhsTranspose.push_back(numLhsDimToBroadcast + dim);
855     for (int64_t dim : rhsReductionDims)
856       rhsTranspose.push_back(numRhsDimToBroadcast + dim);
857     // Loop through the parallel dimensions to calculate the dimensions to
858     // broadcast and to permute in order to extract only parallel dimensions.
859     for (unsigned i = 0; i < numParallelDims; i++) {
860       std::optional<unsigned> lhsDim =
861           getDimPosition(lhsMap, accMap.getDimPosition(i));
862       if (lhsDim) {
863         lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
864       } else {
865         // If the parallel dimension doesn't exist we will have to broadcast it.
866         lhsDims.push_back(
867             cast<VectorType>(contractOp.getResultType()).getDimSize(i));
868         lhsTranspose.push_back(lhsDims.size() - 1);
869       }
870       std::optional<unsigned> rhsDim =
871           getDimPosition(rhsMap, accMap.getDimPosition(i));
872       if (rhsDim) {
873         rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
874       } else {
875         // If the parallel dimension doesn't exist we will have to broadcast it.
876         rhsDims.push_back(
877             cast<VectorType>(contractOp.getResultType()).getDimSize(i));
878         rhsTranspose.push_back(rhsDims.size() - 1);
879       }
880     }
881     Value newLhs = contractOp.getLhs();
882     Value newRhs = contractOp.getRhs();
883     Location loc = contractOp.getLoc();
884     if (!lhsDims.empty()) {
885       lhsDims.append(lhsShape.begin(), lhsShape.end());
886       auto expandedType =
887           VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
888       newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
889     }
890     if (!rhsDims.empty()) {
891       rhsDims.append(rhsShape.begin(), rhsShape.end());
892       auto expandedType =
893           VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
894       newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
895     }
896     bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
897     newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
898     newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
899     SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
900     SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
901     newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
902     newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
903     std::optional<Value> result =
904         createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
905                               contractOp.getKind(), rewriter, isInt);
906     rewriter.replaceOp(contractOp, {*result});
907     return success();
908   }
909 
910 private:
911   /// Options to control the vector patterns.
912   vector::VectorTransformsOptions vectorTransformOptions;
913   FilterConstraintType filter;
914 };
915 
916 /// Progressive lowering of ContractionOp.
917 /// One:
918 ///   %x = vector.contract with at least one free/batch dimension
919 /// is replaced by:
920 ///   %a = vector.contract with one less free/batch dimension
921 ///   %b = vector.contract with one less free/batch dimension
922 ///   ..
923 ///   %x = combine %a %b ..
924 /// until a pure contraction is reached (no free/batch dimensions),
925 /// which is replaced by a dot-product.
926 ///
927 /// This only kicks in when either VectorTransformsOptions is set
928 /// to DOT or when other contraction patterns fail.
929 //
930 // TODO: break down into transpose/reshape/cast ops
931 //               when they become available to avoid code dup
932 // TODO: investigate lowering order impact on performance
LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                       PatternRewriter &rewriter) const {
  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  // Both operand element types must match the accumulator element type.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: the code below assumes the default contraction, make sure it supports
  // other kinds before enabling this lowering.
  if (op.getKind() != vector::CombiningKind::ADD) {
    return rewriter.notifyMatchFailure(
        op, "contractions other than 'add' not supported");
  }

  // TODO: implement benefits, cost models.
  // Try the specialized lowerings first, in fixed priority order; the
  // progressive dimension-peeling below only runs when all of them fail.
  MLIRContext *ctx = op.getContext();
  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
    return success();
  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
    return success();

  // Vector mask setup.
  // If the contract is masked, the enclosing vector.mask op is the op that
  // gets replaced, and the mask value is threaded through the lowering
  // helpers below.
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *rootOp = op;
  Value mask;
  if (op.isMasked()) {
    rewriter.setInsertionPoint(op.getMaskingOp());
    rootOp = op.getMaskingOp();
    mask = op.getMaskingOp().getMask();
  }

  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
    if (failed(newOp))
      return failure();
    rewriter.replaceOp(rootOp, *newOp);
    return success();
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  // A "free" dimension is one that is not contracting (rhsIndex=-1 signals
  // the dimension only appears on the LHS side).
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
      if (failed(newOp))
        return failure();
      rewriter.replaceOp(rootOp, *newOp);
      return success();
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
      if (failed(newOp))
        return failure();
      rewriter.replaceOp(rootOp, *newOp);
      return success();
    }
  }

  // Lower the first remaining reduction dimension.
  // At this point no batch or free dimensions are left.
  if (!contractingDimMap.empty()) {
    auto newOp = lowerReduction(rewriter, op, mask);
    if (failed(newOp))
      return failure();
    rewriter.replaceOp(rootOp, *newOp);
    return success();
  }

  return failure();
}
1034 
// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
//
// Unrolls the dimension identified by `lhsIndex` (a position in the LHS
// indexing map) and/or `rhsIndex` (a position in the RHS map); a negative
// index means the dimension does not appear on that side. `mask` is an
// optional mask to propagate to the generated lower-rank contractions
// (null when the op is unmasked).
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    // Map the LHS dimension to its iteration-space index; if the dimension
    // is also given on the RHS, both must refer to the same iterator.
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    // Cannot unroll a scalable (runtime-sized) dimension.
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewrite is currently calling
  // lowerParallel also for the case of unit-size reduction dims appearing only
  // on one of LHS or RHS, not both. At the moment, such cases are created by
  // CastAwayContractionLeadingOneDim, so we need to either support that or
  // modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  // These are the original maps/iterators with the unrolled iterator
  // removed, so each generated contraction is one rank lower.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    // Extract slice d of each operand (and of the mask, if present), emit a
    // lower-rank contraction, and insert its result back at position d.
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}
1122 
1123 // Lower one reduction dimension.
1124 FailureOr<Value> ContractionOpLowering::lowerReduction(
1125     PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
1126   auto loc = op.getLoc();
1127   VectorType lhsType = op.getLhsType();
1128   VectorType rhsType = op.getRhsType();
1129   Type resType = op.getResultType();
1130   if (isa<VectorType>(resType))
1131     return rewriter.notifyMatchFailure(op,
1132                                        "did not expect a VectorType result");
1133   bool isInt = isa<IntegerType>(resType);
1134   // Use iterator index 0.
1135   int64_t iterIndex = 0;
1136   SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1137   std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1138   std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1139   if (!lookupLhs.has_value())
1140     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1141       diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1142     });
1143   if (!lookupRhs.has_value())
1144     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1145       diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1146     });
1147   int64_t lhsIndex = *lookupLhs;
1148   int64_t rhsIndex = *lookupRhs;
1149   int64_t dimSize = lhsType.getDimSize(lhsIndex);
1150   if (dimSize != rhsType.getDimSize(rhsIndex))
1151     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1152       diag << "expect LHS dimension " << lhsIndex
1153            << " to have the same size as RHS dimension " << rhsIndex;
1154     });
1155   // Base case.
1156   if (lhsType.getRank() == 1) {
1157     if (rhsType.getRank() != 1)
1158       return rewriter.notifyMatchFailure(
1159           op, "When LHS has rank 1, expected also RHS to have rank 1");
1160     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1161     auto kind = vector::CombiningKind::ADD;
1162 
1163     Value acc = op.getAcc();
1164     Operation *reductionOp =
1165         acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1166             : rewriter.create<vector::ReductionOp>(loc, kind, m);
1167     return maskOperation(rewriter, reductionOp, mask)->getResult(0);
1168   }
1169   // Construct new iterator types and affine map array attribute.
1170   std::array<AffineMap, 3> lowIndexingMaps = {
1171       adjustMap(iMap[0], iterIndex, rewriter),
1172       adjustMap(iMap[1], iterIndex, rewriter),
1173       adjustMap(iMap[2], iterIndex, rewriter)};
1174   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1175   auto lowIter =
1176       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1177   // Unroll into a series of lower dimensional vector.contract ops.
1178   // By feeding the initial accumulator into the first contraction,
1179   // and the result of each contraction into the next, eventually
1180   // the sum of all reductions is computed.
1181   Value result = op.getAcc();
1182   for (int64_t d = 0; d < dimSize; ++d) {
1183     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1184     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1185     Value newMask;
1186     if (mask)
1187       newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1188                             iterIndex, d, rewriter);
1189 
1190     Operation *newContract = rewriter.create<vector::ContractionOp>(
1191         loc, lhs, rhs, result, lowAffine, lowIter);
1192     result = maskOperation(rewriter, newContract, newMask)->getResult(0);
1193   }
1194   return result;
1195 }
1196 
1197 /// Progressive lowering of OuterProductOp.
1198 /// One:
1199 ///   %x = vector.outerproduct %lhs, %rhs, %acc
1200 /// is replaced by:
1201 ///   %z = zero-result
1202 ///   %0 = vector.extract %lhs[0]
1203 ///   %1 = vector.broadcast %0
1204 ///   %2 = vector.extract %acc[0]
1205 ///   %3 = vector.fma %1, %rhs, %2
1206 ///   %4 = vector.insert %3, %z[0]
1207 ///   ..
1208 ///   %x = vector.insert %.., %..[N-1]
1209 ///
1210 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1211 public:
1212   using OpRewritePattern::OpRewritePattern;
1213 
1214   LogicalResult matchAndRewrite(vector::OuterProductOp op,
1215                                 PatternRewriter &rewriter) const override {
1216     VectorType resType = op.getResultVectorType();
1217     if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1218       return failure();
1219 
1220     auto loc = op.getLoc();
1221 
1222     VectorType lhsType = op.getOperandVectorTypeLHS();
1223     VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1224     Type eltType = resType.getElementType();
1225     bool isInt = isa<IntegerType, IndexType>(eltType);
1226     Value acc = op.getAcc();
1227     vector::CombiningKind kind = op.getKind();
1228 
1229     // Vector mask setup.
1230     OpBuilder::InsertionGuard guard(rewriter);
1231     auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1232     Operation *rootOp;
1233     Value mask;
1234     if (maskableOp.isMasked()) {
1235       rewriter.setInsertionPoint(maskableOp.getMaskingOp());
1236       rootOp = maskableOp.getMaskingOp();
1237       mask = maskableOp.getMaskingOp().getMask();
1238     } else {
1239       rootOp = op;
1240     }
1241 
1242     if (!rhsType) {
1243       // Special case: AXPY operation.
1244       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
1245       std::optional<Value> mult = createContractArithOp(
1246           loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
1247       if (!mult.has_value())
1248         return failure();
1249       rewriter.replaceOp(rootOp, *mult);
1250       return success();
1251     }
1252 
1253     Value result = rewriter.create<arith::ConstantOp>(
1254         loc, resType, rewriter.getZeroAttr(resType));
1255     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1256       Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
1257       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
1258       Value r = nullptr;
1259       if (acc)
1260         r = rewriter.create<vector::ExtractOp>(loc, acc, d);
1261       Value extrMask;
1262       if (mask)
1263         extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
1264 
1265       std::optional<Value> m = createContractArithOp(
1266           loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1267       if (!m.has_value())
1268         return failure();
1269       result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
1270     }
1271 
1272     rewriter.replaceOp(rootOp, result);
1273     return success();
1274   }
1275 };
1276 
1277 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1278 /// semantics to:
1279 /// ```
1280 ///    %mta = maybe_transpose
1281 ///    %mtb = maybe_transpose
1282 ///    %flattened_a = vector.shape_cast %mta
1283 ///    %flattened_b = vector.shape_cast %mtb
1284 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1285 ///    %mtd = vector.shape_cast %flattened_d
1286 ///    %d = maybe_untranspose %mtd
1287 ///    %e = add %c, %d
1288 /// ```
1289 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1290 //
1291 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1292 /// vector.transpose operations are inserted if the vector.contract op is not a
1293 /// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
  if (maskableOp.isMasked())
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  // Only the canonical 3-iterator matmul form is handled: two parallel
  // iterators (m, n) followed by one reduction iterator (k).
  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Mixed operand/result element types are not supported.
  Type dstElementType = op.getType();
  if (auto vecType = dyn_cast<VectorType>(dstElementType))
    dstElementType = vecType.getElementType();
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  // Flatten the 2-D operands to 1-D: vector.matmul takes row-major-flattened
  // 1-D operands.
  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  // Cast the flat matmul result back to the 2-D (m, n) shape.
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    // Any other accumulator map should have been rejected by the op
    // verifier before this pattern runs.
    llvm_unreachable("invalid contraction semantics");

  // Add the accumulator to the matmul result.
  Value res =
      isa<IntegerType>(elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  rew.replaceOp(op, res);
  return success();
}
1385 } // namespace
1386 
1387 void mlir::vector::populateVectorContractLoweringPatterns(
1388     RewritePatternSet &patterns, VectorTransformsOptions options,
1389     PatternBenefit benefit, bool disableOuterProductLowering) {
1390   if (!disableOuterProductLowering)
1391     patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1392   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
1393                ContractionOpToOuterProductOpLowering>(
1394       options, patterns.getContext(), benefit);
1395 }
1396 
1397 void mlir::vector::populateVectorOuterProductLoweringPatterns(
1398     RewritePatternSet &patterns, PatternBenefit benefit) {
1399   patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1400 }
1401