//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.contract' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#define DEBUG_TYPE "vector-contract-lowering"

using namespace mlir;
using namespace mlir::vector;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
// Helper to find an index in an affine map.
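// For example (illustrative): in the map (d0, d1, d2) -> (d2, d0), index 0 is
// found at result position 1, while index 1 is absent and yields std::nullopt.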
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return std::nullopt;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
                                         int64_t index) {
  SmallVector<Attribute> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
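// For example (illustrative): removing index 1 from (d0, d1, d2) -> (d0, d2)
// yields (d0, d1) -> (d0, d1), since d2 is renumbered to d1 once dimension 1
// is gone.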
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
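// For example (an illustrative sketch): given %val : vector<2x3x4xf32>,
// index = 1 and pos = p, this unrolls the leading dimension and, for each of
// its two entries, extracts row p of the vector<3x4xf32> slice, reassembling
// the results into a vector<2x4xf32>.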
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
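// For example (an illustrative sketch): given %result : vector<2x3x4xf32>,
// index = 1 and pos = p, this inserts each vector<4xf32> row of
// %val : vector<2x4xf32> at position p of the corresponding vector<3x4xf32>
// slice of %result.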
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}

/// Helper to create arithmetic operation associated with a kind of contraction.
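/// For example (illustrative): for `kind` = ADD on float vectors with a
/// non-null vector `acc`, this emits a single `vector.fma %x, %y, %acc`;
/// for integers it emits `arith.muli %x, %y` and then combines the product
/// with `acc` through `makeArithReduction`.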
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(rewriter, mask, fma, acc);
      return fma;
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }

  if (!acc)
    return std::optional<Value>(mul);

  return makeArithReduction(rewriter, loc, kind, mul, acc,
                            /*fastmath=*/nullptr, mask);
}

/// Return the positions of the reductions in the given map.
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> dimsIdx;
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
      dimsIdx.push_back(i);
  }
  return dimsIdx;
}

/// Look for a given dimension in an affine map and return its position. Return
/// std::nullopt if the dimension is not in the map results.
static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (map.getDimPosition(i) == dim)
      return i;
  }
  return std::nullopt;
}

/// Creates an arith::AddIOp if `isInt` is true, otherwise creates an
/// arith::AddFOp, using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise creates an
/// arith::MulFOp, using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace {

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %flattened_a = vector.shape_cast %a
///    %flattened_b = vector.shape_cast %b
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %d = vector.shape_cast %flattened_d
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
///    %out = arith.constant ... : vector<MxNxelt_type>
///    %bt = vector.transpose %b, [1, 0]
///    %aRow0 = vector.extract %a[0]
///    %btRow0 = vector.extract %bt[0]
///    %c00 = vector.reduce %aRow0, %btRow0
///    %out00 = vector.insert %c00, %out[0, 0]
///    ...
///    %aRowLast = vector.extract %a[M-1]
///    %btRowLast = vector.extract %bt[N-1]
///    %cLastLast = vector.reduce %aRowLast, %btRowLast
///    %outLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToDotLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of ContractionOp.
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
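///
/// For example (an illustrative sketch): a matvec-style contraction
///   %x = vector.contract {...} %a, %b, %c
///     : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
/// is first unrolled along the free dimension of size 2 into two rank-reduced
/// contractions, each of which then lowers to an elementwise multiply followed
/// by a `vector.reduction <add>`.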
class ContractionOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                        MLIRContext *context, PatternBenefit benefit = 1,
                        FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

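  /// Promote `v` to the element type of the accumulator if needed. For
  /// example (illustrative): promoting a vector<4xf16> operand to an f32
  /// accumulator emits
  ///   %0 = arith.extf %v : vector<4xf16> to vector<4xf32>
  /// whereas integer operands are sign-extended with arith.extsi.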
  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
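  /// For example (illustrative): for vector<[4]x8xf32>, `reductionDim` = 0 is
  /// scalable and yields std::nullopt, while `reductionDim` = 1 yields 8.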
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);

    // Classical row-major matmul: just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when the pattern returns "failure".
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask);
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask);
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 1))
        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    return matmatRes;
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    return matvecRes;
  }

  FailureOr<Value> tmatvecRes = e.tmatvec();
  return tmatvecRes;
}

FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, op.getContext());
  };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // This is the classical row-major matmul. Just permute the lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  return res;
}

/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
struct ContractOpToElementwise
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: Support vector.mask.
    if (maskOp)
      return failure();

    if (failed(filter(contractOp)))
      return failure();

    if (vectorTransformOptions.vectorContractLowering !=
        vector::VectorContractLowering::ParallelArith)
      return failure();

    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
    SmallVector<int64_t> lhsReductionDims =
        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
    SmallVector<int64_t> rhsReductionDims =
        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
    // All reduction dimensions must have size 1.
    for (int64_t dim : lhsReductionDims) {
      if (lhsShape[dim] != 1)
        return failure();
    }
    for (int64_t dim : rhsReductionDims) {
      if (rhsShape[dim] != 1)
        return failure();
    }
    AffineMap accMap = contractOp.getIndexingMapsArray()[2];
    unsigned numParallelDims = accMap.getNumResults();
    unsigned numLhsDimToBroadcast =
        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
    unsigned numRhsDimToBroadcast =
        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
    SmallVector<int64_t> lhsDims;
    SmallVector<int64_t> lhsTranspose;
    SmallVector<int64_t> rhsDims;
    SmallVector<int64_t> rhsTranspose;
    for (int64_t dim : lhsReductionDims)
      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
    for (int64_t dim : rhsReductionDims)
      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
    // Loop through the parallel dimensions to calculate the dimensions to
    // broadcast and to permute in order to extract only parallel dimensions.
    for (unsigned i = 0; i < numParallelDims; i++) {
      std::optional<unsigned> lhsDim =
          getDimPosition(lhsMap, accMap.getDimPosition(i));
      if (lhsDim) {
        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        lhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        lhsTranspose.push_back(lhsDims.size() - 1);
      }
      std::optional<unsigned> rhsDim =
          getDimPosition(rhsMap, accMap.getDimPosition(i));
      if (rhsDim) {
        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        rhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        rhsTranspose.push_back(rhsDims.size() - 1);
      }
    }
    Value newLhs = contractOp.getLhs();
    Value newRhs = contractOp.getRhs();
    Location loc = contractOp.getLoc();
    if (!lhsDims.empty()) {
      lhsDims.append(lhsShape.begin(), lhsShape.end());
      auto expandedType =
          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
    }
    if (!rhsDims.empty()) {
      rhsDims.append(rhsShape.begin(), rhsShape.end());
      auto expandedType =
          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
    }
    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
    SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
    SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
    newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
    newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
    std::optional<Value> result =
        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                              contractOp.getKind(), rewriter, isInt);
    if (result)
      return *result;

    return failure();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
//
// TODO: break down into transpose/reshape/cast ops
//       when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: the code below assumes the default contraction, make sure it supports
  // other kinds before enabling this lowering.
  if (op.getKind() != vector::CombiningKind::ADD) {
    return rewriter.notifyMatchFailure(
        op, "contractions other than 'add' not supported");
  }

  // TODO: implement benefits, cost models.
  MLIRContext *ctx = op.getContext();

  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
  FailureOr<Value> newVal1 =
      pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal1))
    return newVal1;

  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
  FailureOr<Value> newVal2 =
      pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal2))
    return newVal2;

  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
  FailureOr<Value> newVal3 =
      pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal3))
    return newVal3;

  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
  FailureOr<Value> newVal4 =
      pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal4))
    return newVal4;

  // Vector mask setup.

  Value mask;
  if (maskOp)
    mask = maskOp.getMask();
  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
    if (failed(newOp))
      return failure();
    return newOp;
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
      if (failed(newOp))
        return failure();
      return newOp;
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
      if (failed(newOp))
        return failure();
      return newOp;
    }
  }

  // Lower the first remaining reduction dimension.
  if (!contractingDimMap.empty()) {
    auto newOp = lowerReduction(rewriter, op, mask);
    if (failed(newOp))
      return failure();
    return newOp;
  }

  return failure();
}

// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
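// For example (an illustrative sketch): for a batched matmul whose LHS is
// vector<4x2x3xf32> and whose batch dimension maps to lhsIndex = 0 and
// rhsIndex = 0, this emits 4 lower-dimensional vector.contract ops (one per
// batch entry), slicing the operands with reshapeLoad and reassembling the
// results with reshapeStore.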
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewrite is currently calling
  // lowerParallel also for the case of unit-size reduction dims appearing only
  // on one of LHS or RHS, not both. At the moment, such cases are created by
  // CastAwayContractionLeadingOneDim, so we need to either support that or
  // modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}

// Lower one reduction dimension.
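// For example (an illustrative sketch): in the rank-1 base case below,
//   %x = vector.contract {...} %a, %b, %c
//     : vector<3xf32>, vector<3xf32> into f32
// becomes
//   %m = arith.mulf %a, %b : vector<3xf32>
//   %x = vector.reduction <add>, %m, %c : vector<3xf32> into f32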
FailureOr<Value> ContractionOpLowering::lowerReduction(
    PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
  auto loc = op.getLoc();
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  Type resType = op.getResultType();
  if (isa<VectorType>(resType))
    return rewriter.notifyMatchFailure(op,
                                       "did not expect a VectorType result");
  bool isInt = isa<IntegerType>(resType);
  // Use iterator index 0.
  int64_t iterIndex = 0;
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  if (!lookupLhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a LHS dimension";
    });
  if (!lookupRhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a RHS dimension";
    });
  int64_t lhsIndex = *lookupLhs;
  int64_t rhsIndex = *lookupRhs;
  int64_t dimSize = lhsType.getDimSize(lhsIndex);
  if (dimSize != rhsType.getDimSize(rhsIndex))
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expect LHS dimension " << lhsIndex
           << " to have the same size as RHS dimension " << rhsIndex;
    });
  // Base case.
  if (lhsType.getRank() == 1) {
    if (rhsType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "when LHS has rank 1, expected RHS to also have rank 1");
    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
    auto kind = vector::CombiningKind::ADD;

    Value acc = op.getAcc();
    Operation *reductionOp =
        acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
            : rewriter.create<vector::ReductionOp>(loc, kind, m);
    return maskOperation(rewriter, reductionOp, mask)->getResult(0);
  }
  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  // By feeding the initial accumulator into the first contraction,
  // and the result of each contraction into the next, eventually
  // the sum of all reductions is computed.
  Value result = op.getAcc();
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    Value newMask;
    if (mask)
      newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *newContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, result, lowAffine, lowIter);
    result = maskOperation(rewriter, newContract, newMask)->getResult(0);
  }
  return result;
}

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resType = op.getResultVectorType();
    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
      return failure();

    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
    Type eltType = resType.getElementType();
    bool isInt = isa<IntegerType, IndexType>(eltType);
    Value acc = op.getAcc();
    vector::CombiningKind kind = op.getKind();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = op;
    }

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      std::optional<Value> mult = createContractArithOp(
          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(rootOp, *mult);
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, acc, d);
      Value extrMask;
      if (mask)
        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);

      std::optional<Value> m = createContractArithOp(
          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
      if (!m.has_value())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
///
/// Scalable vectors are not supported.
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type opResType = op.getType();
  VectorType vecType = dyn_cast<VectorType>(opResType);
  if (vecType && vecType.isScalable()) {
    // Note - this is sufficient to reject all cases with scalable vectors.
    return failure();
  }

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  Type dstElementType = vecType ? vecType.getElementType() : opResType;
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      isa<IntegerType>(elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  return res;
}
} // namespace

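/// Typical usage (an illustrative sketch; the pass context `ctx`, the target
/// `op`, and the greedy driver call are assumptions, not part of this file):
/// ```
///   RewritePatternSet patterns(ctx);
///   vector::VectorTransformsOptions options;
///   options.vectorContractLowering = vector::VectorContractLowering::OuterProduct;
///   vector::populateVectorContractLoweringPatterns(patterns, options);
///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
/// ```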
void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options,
    PatternBenefit benefit, bool disableOuterProductLowering) {
  if (!disableOuterProductLowering)
    patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(
      options, patterns.getContext(), benefit);
}

void mlir::vector::populateVectorOuterProductLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
}