//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.contract' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#define DEBUG_TYPE "vector-contract-lowering"

using namespace mlir;
using namespace mlir::vector;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
// Helper to find an index in an affine map.
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return std::nullopt;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
                                         int64_t index) {
  SmallVector<Attribute> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
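// For example (illustrative): adjustMap((d0, d1, d2) -> (d0, d2), /*index=*/1)
// returns (d0, d1) -> (d0, d1): d0 is kept unchanged and d2 is renumbered to
// d1 once dimension d1 has been dropped.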
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
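// For example (illustrative): with a vector<2x3x4xf32> value, index = 1 and a
// given pos, this extracts element `pos` along dimension 1 of every leading
// slice and produces a vector<2x4xf32>.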
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
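// For example (illustrative): with `result` of type vector<2x3x4xf32>,
// index = 1 and a given pos, each vector<4xf32> slice of `val` (of type
// vector<2x4xf32>) is inserted at position `pos` along dimension 1 of the
// matching slice of `result`.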
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}

/// Helper to create arithmetic operation associated with a kind of contraction.
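/// For example (illustrative): for floating-point operands with kind == ADD
/// and a vector accumulator this emits a single vector.fma (re-selecting the
/// previous `acc` for masked-off lanes when a mask is given); otherwise it
/// emits the element-wise multiply followed by makeArithReduction to combine
/// the product with `acc`.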
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(rewriter, mask, fma, acc);
      return fma;
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }

  if (!acc)
    return std::optional<Value>(mul);

  return makeArithReduction(rewriter, loc, kind, mul, acc,
                            /*fastmath=*/nullptr, mask);
}

/// Return the positions of the reductions in the given map.
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> dimsIdx;
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
      dimsIdx.push_back(i);
  }
  return dimsIdx;
}

/// Look for a given dimension in an affine map and return its position. Return
/// std::nullopt if the dimension is not in the map results.
static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (map.getDimPosition(i) == dim)
      return i;
  }
  return std::nullopt;
}

/// Creates an arith::AddIOp if `isInt` is true, otherwise an arith::AddFOp,
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise an arith::MulFOp,
/// using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace {

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %flattened_a = vector.shape_cast %a
///    %flattened_b = vector.shape_cast %b
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %d = vector.shape_cast %flattened_d
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
///    %out = arith.constant ... : vector<MxNxelt_type>
///    %bt = vector.transpose %b, [1, 0]
///    %aRow0 = vector.extract %a[0]
///    %btRow0 = vector.extract %bt[0]
///    %c00 = vector.reduce %aRow0, %btRow0
///    %out00 = vector.insert %c00, %out[0, 0]
///    ...
///    %aRowLast = vector.extract %a[M-1]
///    %btRowLast = vector.extract %bt[N-1]
///    %cLastLast = vector.reduce %aRowLast, %btRowLast
///    %outLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToDotLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of ContractionOp.
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
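///
/// For example (illustrative): a matmul-like contraction of vector<2x4xf32>
/// with vector<4x3xf32> is rewritten into two contractions of vector<4xf32>
/// with vector<4x3xf32> (one per lhs row), whose vector<3xf32> results are
/// inserted back into the vector<2x3xf32> accumulator.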
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                        MLIRContext *context, PatternBenefit benefit = 1,
                        FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
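///
/// For example (illustrative): a matmat contraction with reduction size K is
/// rewritten into K chained vector.outerproduct ops, each combining one
/// (transposed) lhs slice with the matching rhs slice and the previous
/// accumulator value.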
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);

    // Classical row-major matmul:  Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when the pattern returns "failure".
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask);
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask);
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 1))
        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    return matmatRes;
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    return matvecRes;
  }

  FailureOr<Value> tmatvecRes = e.tmatvec();
  return tmatvecRes;
}

FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, op.getContext());
  };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed output: swap the operands and transpose the new lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing; we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  return res;
}

/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
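///
/// For example (illustrative): when every reduction dimension of the operands
/// has static size 1, the contraction reduces to broadcasting/transposing the
/// operands into the accumulator's shape followed by an elementwise
/// multiply-add, so no reduction op is needed.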
struct ContractOpToElementwise
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: Support vector.mask.
    if (maskOp)
      return failure();

    if (failed(filter(contractOp)))
      return failure();

    if (vectorTransformOptions.vectorContractLowering !=
        vector::VectorContractLowering::ParallelArith)
      return failure();

    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
    SmallVector<int64_t> lhsReductionDims =
        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
    SmallVector<int64_t> rhsReductionDims =
        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
    // All the reduction dimensions must have size 1.
8222bc4c3e9SNicolas Vasilache     for (int64_t dim : lhsReductionDims) {
8232bc4c3e9SNicolas Vasilache       if (lhsShape[dim] != 1)
8242bc4c3e9SNicolas Vasilache         return failure();
8252bc4c3e9SNicolas Vasilache     }
8262bc4c3e9SNicolas Vasilache     for (int64_t dim : rhsReductionDims) {
8272bc4c3e9SNicolas Vasilache       if (rhsShape[dim] != 1)
8282bc4c3e9SNicolas Vasilache         return failure();
8292bc4c3e9SNicolas Vasilache     }
8302bc4c3e9SNicolas Vasilache     AffineMap accMap = contractOp.getIndexingMapsArray()[2];
8312bc4c3e9SNicolas Vasilache     unsigned numParallelDims = accMap.getNumResults();
8322bc4c3e9SNicolas Vasilache     unsigned numLhsDimToBroadcast =
8332bc4c3e9SNicolas Vasilache         numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
8342bc4c3e9SNicolas Vasilache     unsigned numRhsDimToBroadcast =
8352bc4c3e9SNicolas Vasilache         numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
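    // Worked example (illustrative numbers only): if the accumulator map has 2
    // parallel results and the LHS map has 2 results of which 1 is a unit
    // reduction dimension, the LHS covers only 1 parallel dimension, so
    // numLhsDimToBroadcast = 2 - (2 - 1) = 1, i.e. one parallel dimension must
    // be broadcast into the LHS.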
8362bc4c3e9SNicolas Vasilache     SmallVector<int64_t> lhsDims;
8372bc4c3e9SNicolas Vasilache     SmallVector<int64_t> lhsTranspose;
8382bc4c3e9SNicolas Vasilache     SmallVector<int64_t> rhsDims;
8392bc4c3e9SNicolas Vasilache     SmallVector<int64_t> rhsTranspose;
8402bc4c3e9SNicolas Vasilache     for (int64_t dim : lhsReductionDims)
8412bc4c3e9SNicolas Vasilache       lhsTranspose.push_back(numLhsDimToBroadcast + dim);
8422bc4c3e9SNicolas Vasilache     for (int64_t dim : rhsReductionDims)
8432bc4c3e9SNicolas Vasilache       rhsTranspose.push_back(numRhsDimToBroadcast + dim);
8442bc4c3e9SNicolas Vasilache     // Loop through the parallel dimensions to compute the dimensions to
8452bc4c3e9SNicolas Vasilache     // broadcast and the permutation needed to isolate the parallel dimensions.
8462bc4c3e9SNicolas Vasilache     for (unsigned i = 0; i < numParallelDims; i++) {
8472bc4c3e9SNicolas Vasilache       std::optional<unsigned> lhsDim =
8482bc4c3e9SNicolas Vasilache           getDimPosition(lhsMap, accMap.getDimPosition(i));
8492bc4c3e9SNicolas Vasilache       if (lhsDim) {
8502bc4c3e9SNicolas Vasilache         lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
8512bc4c3e9SNicolas Vasilache       } else {
8522bc4c3e9SNicolas Vasilache         // If the parallel dimension doesn't exist, we will have to broadcast it.
8532bc4c3e9SNicolas Vasilache         lhsDims.push_back(
8545550c821STres Popp             cast<VectorType>(contractOp.getResultType()).getDimSize(i));
8552bc4c3e9SNicolas Vasilache         lhsTranspose.push_back(lhsDims.size() - 1);
8562bc4c3e9SNicolas Vasilache       }
8572bc4c3e9SNicolas Vasilache       std::optional<unsigned> rhsDim =
8582bc4c3e9SNicolas Vasilache           getDimPosition(rhsMap, accMap.getDimPosition(i));
8592bc4c3e9SNicolas Vasilache       if (rhsDim) {
8602bc4c3e9SNicolas Vasilache         rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
8612bc4c3e9SNicolas Vasilache       } else {
8622bc4c3e9SNicolas Vasilache         // If the parallel dimension doesn't exist, we will have to broadcast it.
8632bc4c3e9SNicolas Vasilache         rhsDims.push_back(
8645550c821STres Popp             cast<VectorType>(contractOp.getResultType()).getDimSize(i));
8652bc4c3e9SNicolas Vasilache         rhsTranspose.push_back(rhsDims.size() - 1);
8662bc4c3e9SNicolas Vasilache       }
8672bc4c3e9SNicolas Vasilache     }
8682bc4c3e9SNicolas Vasilache     Value newLhs = contractOp.getLhs();
8692bc4c3e9SNicolas Vasilache     Value newRhs = contractOp.getRhs();
8702bc4c3e9SNicolas Vasilache     Location loc = contractOp.getLoc();
8712bc4c3e9SNicolas Vasilache     if (!lhsDims.empty()) {
8722bc4c3e9SNicolas Vasilache       lhsDims.append(lhsShape.begin(), lhsShape.end());
8732bc4c3e9SNicolas Vasilache       auto expandedType =
8742bc4c3e9SNicolas Vasilache           VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
8752bc4c3e9SNicolas Vasilache       newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
8762bc4c3e9SNicolas Vasilache     }
8772bc4c3e9SNicolas Vasilache     if (!rhsDims.empty()) {
8782bc4c3e9SNicolas Vasilache       rhsDims.append(rhsShape.begin(), rhsShape.end());
8792bc4c3e9SNicolas Vasilache       auto expandedType =
8802bc4c3e9SNicolas Vasilache           VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
8812bc4c3e9SNicolas Vasilache       newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
8822bc4c3e9SNicolas Vasilache     }
8832bc4c3e9SNicolas Vasilache     bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
8842bc4c3e9SNicolas Vasilache     newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
8852bc4c3e9SNicolas Vasilache     newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
8862bc4c3e9SNicolas Vasilache     SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
8872bc4c3e9SNicolas Vasilache     SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
88816b75cd2SMatthias Springer     newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
88916b75cd2SMatthias Springer     newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
8902bc4c3e9SNicolas Vasilache     std::optional<Value> result =
8912bc4c3e9SNicolas Vasilache         createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
8922bc4c3e9SNicolas Vasilache                               contractOp.getKind(), rewriter, isInt);
893b7324b6aSAndrzej Warzyński     if (result)
894b7324b6aSAndrzej Warzyński       return *result;
895b7324b6aSAndrzej Warzyński 
896b7324b6aSAndrzej Warzyński     return failure();
8972bc4c3e9SNicolas Vasilache   }
8982bc4c3e9SNicolas Vasilache 
8992bc4c3e9SNicolas Vasilache private:
9002bc4c3e9SNicolas Vasilache   /// Options to control the vector patterns.
9012bc4c3e9SNicolas Vasilache   vector::VectorTransformsOptions vectorTransformOptions;
9022bc4c3e9SNicolas Vasilache   FilterConstraintType filter;
9032bc4c3e9SNicolas Vasilache };
9042bc4c3e9SNicolas Vasilache 
9052bc4c3e9SNicolas Vasilache /// Progressive lowering of ContractionOp.
9062bc4c3e9SNicolas Vasilache /// One:
9072bc4c3e9SNicolas Vasilache ///   %x = vector.contract with at least one free/batch dimension
9082bc4c3e9SNicolas Vasilache /// is replaced by:
9092bc4c3e9SNicolas Vasilache ///   %a = vector.contract with one less free/batch dimension
9102bc4c3e9SNicolas Vasilache ///   %b = vector.contract with one less free/batch dimension
9112bc4c3e9SNicolas Vasilache ///   ..
9122bc4c3e9SNicolas Vasilache ///   %x = combine %a %b ..
9132bc4c3e9SNicolas Vasilache /// until a pure contraction is reached (no free/batch dimensions),
9142bc4c3e9SNicolas Vasilache /// which is replaced by a dot-product.
9152bc4c3e9SNicolas Vasilache ///
9162bc4c3e9SNicolas Vasilache /// This only kicks in when VectorTransformsOptions is set to DOT or
9172bc4c3e9SNicolas Vasilache /// when the other contraction patterns fail.
9182bc4c3e9SNicolas Vasilache //
9192bc4c3e9SNicolas Vasilache // TODO: break down into transpose/reshape/cast ops
9202bc4c3e9SNicolas Vasilache //               when they become available to avoid code dup
9212bc4c3e9SNicolas Vasilache // TODO: investigate lowering order impact on performance
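//
// For illustration (example shapes only): a matvec-like contraction
//   vector<2x3xf32> (lhs), vector<3xf32> (rhs) -> vector<2xf32>
// is unrolled along the free dimension of size 2 into two rank-reduced
// contractions vector<3xf32>, vector<3xf32> -> f32 (each eventually lowered
// to a multiply plus vector.reduction <add>), whose scalar results are
// inserted back into the vector<2xf32> result.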
922b7324b6aSAndrzej Warzyński FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
923b7324b6aSAndrzej Warzyński     vector::ContractionOp op, MaskingOpInterface maskOp,
9242bc4c3e9SNicolas Vasilache     PatternRewriter &rewriter) const {
9252bc4c3e9SNicolas Vasilache   if (failed(filter(op)))
9262bc4c3e9SNicolas Vasilache     return failure();
9272bc4c3e9SNicolas Vasilache 
9282bc4c3e9SNicolas Vasilache   // TODO: support mixed mode contract lowering.
9292bc4c3e9SNicolas Vasilache   if (op.getLhsType().getElementType() !=
9302bc4c3e9SNicolas Vasilache           getElementTypeOrSelf(op.getAccType()) ||
9312bc4c3e9SNicolas Vasilache       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
9322bc4c3e9SNicolas Vasilache     return failure();
9332bc4c3e9SNicolas Vasilache 
9342bc4c3e9SNicolas Vasilache   // TODO: the code below assumes the default ('add') contraction; make sure it
9352bc4c3e9SNicolas Vasilache   // supports other kinds before enabling this lowering.
9362bc4c3e9SNicolas Vasilache   if (op.getKind() != vector::CombiningKind::ADD) {
9372bc4c3e9SNicolas Vasilache     return rewriter.notifyMatchFailure(
9382bc4c3e9SNicolas Vasilache         op, "contractions other than 'add' not supported");
9392bc4c3e9SNicolas Vasilache   }
9402bc4c3e9SNicolas Vasilache 
9412bc4c3e9SNicolas Vasilache   // TODO: implement benefits, cost models.
9422bc4c3e9SNicolas Vasilache   MLIRContext *ctx = op.getContext();
943b7324b6aSAndrzej Warzyński 
9442bc4c3e9SNicolas Vasilache   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
945b7324b6aSAndrzej Warzyński   FailureOr<Value> newVal1 =
946b7324b6aSAndrzej Warzyński       pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
947b7324b6aSAndrzej Warzyński   if (!failed(newVal1))
948b7324b6aSAndrzej Warzyński     return newVal1;
949b7324b6aSAndrzej Warzyński 
9502bc4c3e9SNicolas Vasilache   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
951b7324b6aSAndrzej Warzyński   FailureOr<Value> newVal2 =
952b7324b6aSAndrzej Warzyński       pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
953b7324b6aSAndrzej Warzyński   if (!failed(newVal2))
954b7324b6aSAndrzej Warzyński     return newVal2;
955b7324b6aSAndrzej Warzyński 
9562bc4c3e9SNicolas Vasilache   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
957b7324b6aSAndrzej Warzyński   FailureOr<Value> newVal3 =
958b7324b6aSAndrzej Warzyński       pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
959b7324b6aSAndrzej Warzyński   if (!failed(newVal3))
960b7324b6aSAndrzej Warzyński     return newVal3;
961b7324b6aSAndrzej Warzyński 
9622bc4c3e9SNicolas Vasilache   ContractOpToElementwise pat4(vectorTransformOptions, ctx);
963b7324b6aSAndrzej Warzyński   FailureOr<Value> newVal4 =
964b7324b6aSAndrzej Warzyński       pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
965b7324b6aSAndrzej Warzyński   if (!failed(newVal4))
966b7324b6aSAndrzej Warzyński     return newVal4;
9672bc4c3e9SNicolas Vasilache 
9682bc4c3e9SNicolas Vasilache   // Vector mask setup.
9692bc4c3e9SNicolas Vasilache 
970b7324b6aSAndrzej Warzyński   Value mask;
971b7324b6aSAndrzej Warzyński   if (maskOp)
972b7324b6aSAndrzej Warzyński     mask = maskOp.getMask();
9732bc4c3e9SNicolas Vasilache   // Find first batch dimension in LHS/RHS, and lower when found.
9742bc4c3e9SNicolas Vasilache   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
9752bc4c3e9SNicolas Vasilache   if (!batchDimMap.empty()) {
9762bc4c3e9SNicolas Vasilache     int64_t lhsIndex = batchDimMap[0].first;
9772bc4c3e9SNicolas Vasilache     int64_t rhsIndex = batchDimMap[0].second;
9782bc4c3e9SNicolas Vasilache     auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
9792bc4c3e9SNicolas Vasilache     if (failed(newOp))
9802bc4c3e9SNicolas Vasilache       return failure();
981b7324b6aSAndrzej Warzyński     return newOp;
9822bc4c3e9SNicolas Vasilache   }
9832bc4c3e9SNicolas Vasilache 
9842bc4c3e9SNicolas Vasilache   // Collect contracting dimensions.
9852bc4c3e9SNicolas Vasilache   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
9862bc4c3e9SNicolas Vasilache       op.getContractingDimMap();
9872bc4c3e9SNicolas Vasilache   DenseSet<int64_t> lhsContractingDimSet;
9882bc4c3e9SNicolas Vasilache   DenseSet<int64_t> rhsContractingDimSet;
9892bc4c3e9SNicolas Vasilache   for (auto &dimPair : contractingDimMap) {
9902bc4c3e9SNicolas Vasilache     lhsContractingDimSet.insert(dimPair.first);
9912bc4c3e9SNicolas Vasilache     rhsContractingDimSet.insert(dimPair.second);
9922bc4c3e9SNicolas Vasilache   }
9932bc4c3e9SNicolas Vasilache 
9942bc4c3e9SNicolas Vasilache   // Find first free dimension in LHS, and lower when found.
9952bc4c3e9SNicolas Vasilache   VectorType lhsType = op.getLhsType();
9962bc4c3e9SNicolas Vasilache   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
9972bc4c3e9SNicolas Vasilache     if (lhsContractingDimSet.count(lhsIndex) == 0) {
9982bc4c3e9SNicolas Vasilache       auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
9992bc4c3e9SNicolas Vasilache       if (failed(newOp))
10002bc4c3e9SNicolas Vasilache         return failure();
1001b7324b6aSAndrzej Warzyński       return newOp;
10022bc4c3e9SNicolas Vasilache     }
10032bc4c3e9SNicolas Vasilache   }
10042bc4c3e9SNicolas Vasilache 
10052bc4c3e9SNicolas Vasilache   // Find first free dimension in RHS, and lower when found.
10062bc4c3e9SNicolas Vasilache   VectorType rhsType = op.getRhsType();
10072bc4c3e9SNicolas Vasilache   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
10082bc4c3e9SNicolas Vasilache     if (rhsContractingDimSet.count(rhsIndex) == 0) {
10092bc4c3e9SNicolas Vasilache       auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
10102bc4c3e9SNicolas Vasilache       if (failed(newOp))
10112bc4c3e9SNicolas Vasilache         return failure();
1012b7324b6aSAndrzej Warzyński       return newOp;
10132bc4c3e9SNicolas Vasilache     }
10142bc4c3e9SNicolas Vasilache   }
10152bc4c3e9SNicolas Vasilache 
10162bc4c3e9SNicolas Vasilache   // Lower the first remaining reduction dimension.
10172bc4c3e9SNicolas Vasilache   if (!contractingDimMap.empty()) {
10182bc4c3e9SNicolas Vasilache     auto newOp = lowerReduction(rewriter, op, mask);
10192bc4c3e9SNicolas Vasilache     if (failed(newOp))
10202bc4c3e9SNicolas Vasilache       return failure();
1021b7324b6aSAndrzej Warzyński     return newOp;
10222bc4c3e9SNicolas Vasilache   }
10232bc4c3e9SNicolas Vasilache 
10242bc4c3e9SNicolas Vasilache   return failure();
10252bc4c3e9SNicolas Vasilache }
10262bc4c3e9SNicolas Vasilache 
10272bc4c3e9SNicolas Vasilache // Lower one parallel dimension.
10282bc4c3e9SNicolas Vasilache // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
10292bc4c3e9SNicolas Vasilache // TODO: consider reusing existing contract unrolling
10302bc4c3e9SNicolas Vasilache FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
10312bc4c3e9SNicolas Vasilache                                                       vector::ContractionOp op,
10322bc4c3e9SNicolas Vasilache                                                       int64_t lhsIndex,
10332bc4c3e9SNicolas Vasilache                                                       int64_t rhsIndex,
10342bc4c3e9SNicolas Vasilache                                                       Value mask) const {
10352bc4c3e9SNicolas Vasilache   VectorType lhsType = op.getLhsType();
10362bc4c3e9SNicolas Vasilache   VectorType rhsType = op.getRhsType();
10375550c821STres Popp   VectorType resType = cast<VectorType>(op.getResultType());
10382bc4c3e9SNicolas Vasilache   // Find the iterator type index and result index.
10392bc4c3e9SNicolas Vasilache   SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
10402bc4c3e9SNicolas Vasilache   int64_t iterIndex = -1;
10412bc4c3e9SNicolas Vasilache   int64_t dimSize = -1;
10422bc4c3e9SNicolas Vasilache   if (lhsIndex >= 0) {
10432bc4c3e9SNicolas Vasilache     iterIndex = iMap[0].getDimPosition(lhsIndex);
10442bc4c3e9SNicolas Vasilache     if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
10452bc4c3e9SNicolas Vasilache       return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
10462bc4c3e9SNicolas Vasilache         diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
10472bc4c3e9SNicolas Vasilache              << " to map to the same dimension";
10482bc4c3e9SNicolas Vasilache       });
1049c91d3b0bSAndrzej Warzynski     if (lhsType.getScalableDims()[lhsIndex])
1050c91d3b0bSAndrzej Warzynski       return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1051e01c8673SAndrzej Warzyński         diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
1052c91d3b0bSAndrzej Warzynski              << ") is not supported yet";
1053c91d3b0bSAndrzej Warzynski       });
10542bc4c3e9SNicolas Vasilache     dimSize = lhsType.getDimSize(lhsIndex);
10552bc4c3e9SNicolas Vasilache   } else if (rhsIndex >= 0) {
10562bc4c3e9SNicolas Vasilache     iterIndex = iMap[1].getDimPosition(rhsIndex);
1057c91d3b0bSAndrzej Warzynski     if (rhsType.getScalableDims()[rhsIndex])
1058c91d3b0bSAndrzej Warzynski       return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1059e01c8673SAndrzej Warzyński         diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
1060c91d3b0bSAndrzej Warzynski              << ") is not supported yet";
1061c91d3b0bSAndrzej Warzynski       });
10622bc4c3e9SNicolas Vasilache     dimSize = rhsType.getDimSize(rhsIndex);
10632bc4c3e9SNicolas Vasilache   }
10642bc4c3e9SNicolas Vasilache   if (iterIndex < 0)
10652bc4c3e9SNicolas Vasilache     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
10662bc4c3e9SNicolas Vasilache       diag << "expected either lhsIndex=" << lhsIndex
10672bc4c3e9SNicolas Vasilache            << " or rhsIndex=" << rhsIndex << " to be nonnegative";
10682bc4c3e9SNicolas Vasilache     });
10692bc4c3e9SNicolas Vasilache   // value_or(-1) means that we tolerate a dimension not appearing
10702bc4c3e9SNicolas Vasilache   // in the result map. That can't happen for actual parallel iterators, but
10712bc4c3e9SNicolas Vasilache   // the caller ContractionOpLowering::matchAndRewrite is currently calling
10722bc4c3e9SNicolas Vasilache   // ContractionOpLowering::matchAndRewriteMaskableOp is currently calling
10732bc4c3e9SNicolas Vasilache   // on one of LHS or RHS, not both. At the moment, such cases are created by
10742bc4c3e9SNicolas Vasilache   // CastAwayContractionLeadingOneDim, so we need to either support that or
10752bc4c3e9SNicolas Vasilache   // modify that pattern.
10762bc4c3e9SNicolas Vasilache   int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
10772bc4c3e9SNicolas Vasilache   if (resIndex == -1 && dimSize != 1)
10782bc4c3e9SNicolas Vasilache     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
10792bc4c3e9SNicolas Vasilache       diag << "expected the dimension for iterIndex=" << iterIndex
10802bc4c3e9SNicolas Vasilache            << " to either appear in the result map, or to be a unit dimension";
10812bc4c3e9SNicolas Vasilache     });
10822bc4c3e9SNicolas Vasilache 
10832bc4c3e9SNicolas Vasilache   // Construct new iterator types and affine map array attribute.
10842bc4c3e9SNicolas Vasilache   std::array<AffineMap, 3> lowIndexingMaps = {
10852bc4c3e9SNicolas Vasilache       adjustMap(iMap[0], iterIndex, rewriter),
10862bc4c3e9SNicolas Vasilache       adjustMap(iMap[1], iterIndex, rewriter),
10872bc4c3e9SNicolas Vasilache       adjustMap(iMap[2], iterIndex, rewriter)};
10882bc4c3e9SNicolas Vasilache   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
10892bc4c3e9SNicolas Vasilache   auto lowIter =
10902bc4c3e9SNicolas Vasilache       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
10912bc4c3e9SNicolas Vasilache   // Unroll into a series of lower dimensional vector.contract ops.
10922bc4c3e9SNicolas Vasilache   Location loc = op.getLoc();
10932bc4c3e9SNicolas Vasilache   Value result = rewriter.create<arith::ConstantOp>(
10942bc4c3e9SNicolas Vasilache       loc, resType, rewriter.getZeroAttr(resType));
10952bc4c3e9SNicolas Vasilache 
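  // Note on the loop below: for each d in [0, dimSize), the d-th slice of
  // lhs/rhs/acc (and of the mask, when present) is extracted via reshapeLoad,
  // contracted with a rank-reduced vector.contract, and the partial result is
  // written back into `result` at position d via reshapeStore.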
10962bc4c3e9SNicolas Vasilache   for (int64_t d = 0; d < dimSize; ++d) {
10972bc4c3e9SNicolas Vasilache     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
10982bc4c3e9SNicolas Vasilache     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
10992bc4c3e9SNicolas Vasilache     auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
11002bc4c3e9SNicolas Vasilache 
11012bc4c3e9SNicolas Vasilache     Value lowMask;
11022bc4c3e9SNicolas Vasilache     if (mask)
11032bc4c3e9SNicolas Vasilache       lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
11042bc4c3e9SNicolas Vasilache                             iterIndex, d, rewriter);
11052bc4c3e9SNicolas Vasilache 
11062bc4c3e9SNicolas Vasilache     Operation *lowContract = rewriter.create<vector::ContractionOp>(
11072bc4c3e9SNicolas Vasilache         loc, lhs, rhs, acc, lowAffine, lowIter);
11082bc4c3e9SNicolas Vasilache     lowContract = maskOperation(rewriter, lowContract, lowMask);
11092bc4c3e9SNicolas Vasilache     result = reshapeStore(loc, lowContract->getResult(0), result, resType,
11102bc4c3e9SNicolas Vasilache                           resIndex, d, rewriter);
11112bc4c3e9SNicolas Vasilache   }
11122bc4c3e9SNicolas Vasilache   return result;
11132bc4c3e9SNicolas Vasilache }
11142bc4c3e9SNicolas Vasilache 
11152bc4c3e9SNicolas Vasilache // Lower one reduction dimension.
11162bc4c3e9SNicolas Vasilache FailureOr<Value> ContractionOpLowering::lowerReduction(
11172bc4c3e9SNicolas Vasilache     PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
11182bc4c3e9SNicolas Vasilache   auto loc = op.getLoc();
11192bc4c3e9SNicolas Vasilache   VectorType lhsType = op.getLhsType();
11202bc4c3e9SNicolas Vasilache   VectorType rhsType = op.getRhsType();
11212bc4c3e9SNicolas Vasilache   Type resType = op.getResultType();
11225550c821STres Popp   if (isa<VectorType>(resType))
11232bc4c3e9SNicolas Vasilache     return rewriter.notifyMatchFailure(op,
11242bc4c3e9SNicolas Vasilache                                        "did not expect a VectorType result");
11255550c821STres Popp   bool isInt = isa<IntegerType>(resType);
11262bc4c3e9SNicolas Vasilache   // Use iterator index 0.
11272bc4c3e9SNicolas Vasilache   int64_t iterIndex = 0;
11282bc4c3e9SNicolas Vasilache   SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
11292bc4c3e9SNicolas Vasilache   std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
11302bc4c3e9SNicolas Vasilache   std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
11312bc4c3e9SNicolas Vasilache   if (!lookupLhs.has_value())
11322bc4c3e9SNicolas Vasilache     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
11332bc4c3e9SNicolas Vasilache       diag << "expected iterIndex=" << iterIndex << " to map to a LHS dimension";
11342bc4c3e9SNicolas Vasilache     });
11352bc4c3e9SNicolas Vasilache   if (!lookupRhs.has_value())
11362bc4c3e9SNicolas Vasilache     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
11372bc4c3e9SNicolas Vasilache       diag << "expected iterIndex=" << iterIndex << " to map to a RHS dimension";
11382bc4c3e9SNicolas Vasilache     });
11392bc4c3e9SNicolas Vasilache   int64_t lhsIndex = *lookupLhs;
11402bc4c3e9SNicolas Vasilache   int64_t rhsIndex = *lookupRhs;
11412bc4c3e9SNicolas Vasilache   int64_t dimSize = lhsType.getDimSize(lhsIndex);
11422bc4c3e9SNicolas Vasilache   if (dimSize != rhsType.getDimSize(rhsIndex))
11432bc4c3e9SNicolas Vasilache     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
11442bc4c3e9SNicolas Vasilache       diag << "expected LHS dimension " << lhsIndex
11452bc4c3e9SNicolas Vasilache            << " to have the same size as RHS dimension " << rhsIndex;
11462bc4c3e9SNicolas Vasilache     });
11472bc4c3e9SNicolas Vasilache   // Base case.
11482bc4c3e9SNicolas Vasilache   if (lhsType.getRank() == 1) {
11492bc4c3e9SNicolas Vasilache     if (rhsType.getRank() != 1)
11502bc4c3e9SNicolas Vasilache       return rewriter.notifyMatchFailure(
11512bc4c3e9SNicolas Vasilache           op, "when LHS has rank 1, expected RHS to also have rank 1");
11522bc4c3e9SNicolas Vasilache     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
11532bc4c3e9SNicolas Vasilache     auto kind = vector::CombiningKind::ADD;
11542bc4c3e9SNicolas Vasilache 
11552bc4c3e9SNicolas Vasilache     Value acc = op.getAcc();
11562bc4c3e9SNicolas Vasilache     Operation *reductionOp =
11572bc4c3e9SNicolas Vasilache         acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
11582bc4c3e9SNicolas Vasilache             : rewriter.create<vector::ReductionOp>(loc, kind, m);
11592bc4c3e9SNicolas Vasilache     return maskOperation(rewriter, reductionOp, mask)->getResult(0);
11602bc4c3e9SNicolas Vasilache   }
11612bc4c3e9SNicolas Vasilache   // Construct new iterator types and affine map array attribute.
11622bc4c3e9SNicolas Vasilache   std::array<AffineMap, 3> lowIndexingMaps = {
11632bc4c3e9SNicolas Vasilache       adjustMap(iMap[0], iterIndex, rewriter),
11642bc4c3e9SNicolas Vasilache       adjustMap(iMap[1], iterIndex, rewriter),
11652bc4c3e9SNicolas Vasilache       adjustMap(iMap[2], iterIndex, rewriter)};
11662bc4c3e9SNicolas Vasilache   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
11672bc4c3e9SNicolas Vasilache   auto lowIter =
11682bc4c3e9SNicolas Vasilache       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
11692bc4c3e9SNicolas Vasilache   // Unroll into a series of lower dimensional vector.contract ops.
11702bc4c3e9SNicolas Vasilache   // By feeding the initial accumulator into the first contraction,
11712bc4c3e9SNicolas Vasilache   // and the result of each contraction into the next, eventually
11722bc4c3e9SNicolas Vasilache   // the sum of all reductions is computed.
11732bc4c3e9SNicolas Vasilache   Value result = op.getAcc();
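  // For example (illustrative), with dimSize == 3 the chain below computes
  //   acc1 = contract(lhs0, rhs0, acc)
  //   acc2 = contract(lhs1, rhs1, acc1)
  //   res  = contract(lhs2, rhs2, acc2)
  // where lhsD/rhsD denote the d-th slices along the reduction dimension.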
11742bc4c3e9SNicolas Vasilache   for (int64_t d = 0; d < dimSize; ++d) {
11752bc4c3e9SNicolas Vasilache     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
11762bc4c3e9SNicolas Vasilache     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
11772bc4c3e9SNicolas Vasilache     Value newMask;
11782bc4c3e9SNicolas Vasilache     if (mask)
11792bc4c3e9SNicolas Vasilache       newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
11802bc4c3e9SNicolas Vasilache                             iterIndex, d, rewriter);
11812bc4c3e9SNicolas Vasilache 
11822bc4c3e9SNicolas Vasilache     Operation *newContract = rewriter.create<vector::ContractionOp>(
11832bc4c3e9SNicolas Vasilache         loc, lhs, rhs, result, lowAffine, lowIter);
11842bc4c3e9SNicolas Vasilache     result = maskOperation(rewriter, newContract, newMask)->getResult(0);
11852bc4c3e9SNicolas Vasilache   }
11862bc4c3e9SNicolas Vasilache   return result;
11872bc4c3e9SNicolas Vasilache }
11882bc4c3e9SNicolas Vasilache 
11892bc4c3e9SNicolas Vasilache /// Progressive lowering of OuterProductOp.
11902bc4c3e9SNicolas Vasilache /// One:
11912bc4c3e9SNicolas Vasilache ///   %x = vector.outerproduct %lhs, %rhs, %acc
11922bc4c3e9SNicolas Vasilache /// is replaced by:
11932bc4c3e9SNicolas Vasilache ///   %z = zero-result
11942bc4c3e9SNicolas Vasilache ///   %0 = vector.extract %lhs[0]
11952bc4c3e9SNicolas Vasilache ///   %1 = vector.broadcast %0
11962bc4c3e9SNicolas Vasilache ///   %2 = vector.extract %acc[0]
11972bc4c3e9SNicolas Vasilache ///   %3 = vector.fma %1, %rhs, %2
11982bc4c3e9SNicolas Vasilache ///   %4 = vector.insert %3, %z[0]
11992bc4c3e9SNicolas Vasilache ///   ..
12002bc4c3e9SNicolas Vasilache ///   %x = vector.insert %.., %..[N-1]
12012bc4c3e9SNicolas Vasilache ///
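/// For example (illustrative types only): with %lhs : vector<2xf32>,
/// %rhs : vector<3xf32> and %acc : vector<2x3xf32>, each row d in [0, 2)
/// extracts the scalar %lhs[d], broadcasts it to vector<3xf32>, extracts the
/// row %acc[d], combines them with %rhs (roughly a vector.fma for a
/// floating-point 'add' kind, otherwise a multiply followed by the combining
/// op), and inserts the resulting row into the vector<2x3xf32> result.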
12022bc4c3e9SNicolas Vasilache class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
12032bc4c3e9SNicolas Vasilache public:
12042bc4c3e9SNicolas Vasilache   using OpRewritePattern::OpRewritePattern;
12052bc4c3e9SNicolas Vasilache 
12062bc4c3e9SNicolas Vasilache   LogicalResult matchAndRewrite(vector::OuterProductOp op,
12072bc4c3e9SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
1208f75d46a7SCullen Rhodes     VectorType resType = op.getResultVectorType();
1209f75d46a7SCullen Rhodes     if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1210f75d46a7SCullen Rhodes       return failure();
1211f75d46a7SCullen Rhodes 
12122bc4c3e9SNicolas Vasilache     auto loc = op.getLoc();
12132bc4c3e9SNicolas Vasilache 
12142bc4c3e9SNicolas Vasilache     VectorType lhsType = op.getOperandVectorTypeLHS();
12155550c821STres Popp     VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
12162bc4c3e9SNicolas Vasilache     Type eltType = resType.getElementType();
12175550c821STres Popp     bool isInt = isa<IntegerType, IndexType>(eltType);
1218067bd7d0SCullen Rhodes     Value acc = op.getAcc();
12192bc4c3e9SNicolas Vasilache     vector::CombiningKind kind = op.getKind();
12202bc4c3e9SNicolas Vasilache 
12212bc4c3e9SNicolas Vasilache     // Vector mask setup.
12222bc4c3e9SNicolas Vasilache     OpBuilder::InsertionGuard guard(rewriter);
12232bc4c3e9SNicolas Vasilache     auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
12242bc4c3e9SNicolas Vasilache     Operation *rootOp;
12252bc4c3e9SNicolas Vasilache     Value mask;
12262bc4c3e9SNicolas Vasilache     if (maskableOp.isMasked()) {
12272bc4c3e9SNicolas Vasilache       rewriter.setInsertionPoint(maskableOp.getMaskingOp());
12282bc4c3e9SNicolas Vasilache       rootOp = maskableOp.getMaskingOp();
12292bc4c3e9SNicolas Vasilache       mask = maskableOp.getMaskingOp().getMask();
12302bc4c3e9SNicolas Vasilache     } else {
12312bc4c3e9SNicolas Vasilache       rootOp = op;
12322bc4c3e9SNicolas Vasilache     }
12332bc4c3e9SNicolas Vasilache 
12342bc4c3e9SNicolas Vasilache     if (!rhsType) {
12352bc4c3e9SNicolas Vasilache       // Special case: AXPY operation, i.e. the RHS is a scalar.
12362bc4c3e9SNicolas Vasilache       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
12372bc4c3e9SNicolas Vasilache       std::optional<Value> mult = createContractArithOp(
12382bc4c3e9SNicolas Vasilache           loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
12392bc4c3e9SNicolas Vasilache       if (!mult.has_value())
12402bc4c3e9SNicolas Vasilache         return failure();
12412bc4c3e9SNicolas Vasilache       rewriter.replaceOp(rootOp, *mult);
12422bc4c3e9SNicolas Vasilache       return success();
12432bc4c3e9SNicolas Vasilache     }
12442bc4c3e9SNicolas Vasilache 
12452bc4c3e9SNicolas Vasilache     Value result = rewriter.create<arith::ConstantOp>(
12462bc4c3e9SNicolas Vasilache         loc, resType, rewriter.getZeroAttr(resType));
12472bc4c3e9SNicolas Vasilache     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
124816b75cd2SMatthias Springer       Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
12492bc4c3e9SNicolas Vasilache       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
12502bc4c3e9SNicolas Vasilache       Value r = nullptr;
12512bc4c3e9SNicolas Vasilache       if (acc)
125216b75cd2SMatthias Springer         r = rewriter.create<vector::ExtractOp>(loc, acc, d);
12532bc4c3e9SNicolas Vasilache       Value extrMask;
12542bc4c3e9SNicolas Vasilache       if (mask)
125516b75cd2SMatthias Springer         extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
12562bc4c3e9SNicolas Vasilache 
12572bc4c3e9SNicolas Vasilache       std::optional<Value> m = createContractArithOp(
12582bc4c3e9SNicolas Vasilache           loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
12592bc4c3e9SNicolas Vasilache       if (!m.has_value())
12602bc4c3e9SNicolas Vasilache         return failure();
126198f6289aSDiego Caballero       result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
12622bc4c3e9SNicolas Vasilache     }
12632bc4c3e9SNicolas Vasilache 
12642bc4c3e9SNicolas Vasilache     rewriter.replaceOp(rootOp, result);
12652bc4c3e9SNicolas Vasilache     return success();
12662bc4c3e9SNicolas Vasilache   }
12672bc4c3e9SNicolas Vasilache };
12682bc4c3e9SNicolas Vasilache 
12692bc4c3e9SNicolas Vasilache /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
12702bc4c3e9SNicolas Vasilache /// semantics to:
12712bc4c3e9SNicolas Vasilache /// ```
12722bc4c3e9SNicolas Vasilache ///    %mta = maybe_transpose
12732bc4c3e9SNicolas Vasilache ///    %mtb = maybe_transpose
12742bc4c3e9SNicolas Vasilache ///    %flattened_a = vector.shape_cast %mta
12752bc4c3e9SNicolas Vasilache ///    %flattened_b = vector.shape_cast %mtb
12762bc4c3e9SNicolas Vasilache ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
12772bc4c3e9SNicolas Vasilache ///    %mtd = vector.shape_cast %flattened_d
12782bc4c3e9SNicolas Vasilache ///    %d = maybe_untranspose %mtd
12792bc4c3e9SNicolas Vasilache ///    %e = add %c, %d
12802bc4c3e9SNicolas Vasilache /// ```
12812bc4c3e9SNicolas Vasilache /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
12832bc4c3e9SNicolas Vasilache ///
12832bc4c3e9SNicolas Vasilache /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
12842bc4c3e9SNicolas Vasilache /// vector.transpose operations are inserted if the vector.contract op is not a
12852bc4c3e9SNicolas Vasilache /// row-major matrix multiply.
1286*cb89457fSAndrzej Warzyński ///
1287*cb89457fSAndrzej Warzyński /// Scalable vectors are not supported.
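///
/// For example (illustrative shapes): a contraction with LHS vector<2x3xf32>,
/// RHS vector<3x4xf32> and ACC vector<2x4xf32> is flattened into vector<6xf32>
/// and vector<12xf32>, multiplied via vector.matmul (lhs_rows = 2,
/// lhs_columns = 3, rhs_columns = 4) into a flat vector<8xf32>, shape_cast
/// back to vector<2x4xf32>, and added to the accumulator.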
1288b7324b6aSAndrzej Warzyński FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
1289b7324b6aSAndrzej Warzyński     vector::ContractionOp op, MaskingOpInterface maskOp,
12902bc4c3e9SNicolas Vasilache     PatternRewriter &rew) const {
12912bc4c3e9SNicolas Vasilache   // TODO: Support vector.mask.
1292b7324b6aSAndrzej Warzyński   if (maskOp)
12932bc4c3e9SNicolas Vasilache     return failure();
12942bc4c3e9SNicolas Vasilache 
12952bc4c3e9SNicolas Vasilache   if (vectorTransformOptions.vectorContractLowering !=
12962bc4c3e9SNicolas Vasilache       vector::VectorContractLowering::Matmul)
12972bc4c3e9SNicolas Vasilache     return failure();
12982bc4c3e9SNicolas Vasilache   if (failed(filter(op)))
12992bc4c3e9SNicolas Vasilache     return failure();
13002bc4c3e9SNicolas Vasilache 
13012bc4c3e9SNicolas Vasilache   auto iteratorTypes = op.getIteratorTypes().getValue();
13022bc4c3e9SNicolas Vasilache   if (!isParallelIterator(iteratorTypes[0]) ||
13032bc4c3e9SNicolas Vasilache       !isParallelIterator(iteratorTypes[1]) ||
13042bc4c3e9SNicolas Vasilache       !isReductionIterator(iteratorTypes[2]))
13052bc4c3e9SNicolas Vasilache     return failure();
13062bc4c3e9SNicolas Vasilache 
1307*cb89457fSAndrzej Warzyński   Type opResType = op.getType();
1308*cb89457fSAndrzej Warzyński   VectorType vecType = dyn_cast<VectorType>(opResType);
1309*cb89457fSAndrzej Warzyński   if (vecType && vecType.isScalable()) {
1310*cb89457fSAndrzej Warzyński     // Note - this is sufficient to reject all cases with scalable vectors.
1311*cb89457fSAndrzej Warzyński     return failure();
1312*cb89457fSAndrzej Warzyński   }
1313*cb89457fSAndrzej Warzyński 
13142bc4c3e9SNicolas Vasilache   Type elementType = op.getLhsType().getElementType();
13152bc4c3e9SNicolas Vasilache   if (!elementType.isIntOrFloat())
13162bc4c3e9SNicolas Vasilache     return failure();
13172bc4c3e9SNicolas Vasilache 
1318*cb89457fSAndrzej Warzyński   Type dstElementType = vecType ? vecType.getElementType() : opResType;
13192bc4c3e9SNicolas Vasilache   if (elementType != dstElementType)
13202bc4c3e9SNicolas Vasilache     return failure();
13212bc4c3e9SNicolas Vasilache 
13222bc4c3e9SNicolas Vasilache   // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
13232bc4c3e9SNicolas Vasilache   // Bail out if the contraction cannot be put in this form.
13242bc4c3e9SNicolas Vasilache   MLIRContext *ctx = op.getContext();
13252bc4c3e9SNicolas Vasilache   Location loc = op.getLoc();
13262bc4c3e9SNicolas Vasilache   AffineExpr m, n, k;
13272bc4c3e9SNicolas Vasilache   bindDims(rew.getContext(), m, n, k);
13282bc4c3e9SNicolas Vasilache   // LHS must be A(m, k) or A(k, m).
13292bc4c3e9SNicolas Vasilache   Value lhs = op.getLhs();
13302bc4c3e9SNicolas Vasilache   auto lhsMap = op.getIndexingMapsArray()[0];
13312bc4c3e9SNicolas Vasilache   if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
13322bc4c3e9SNicolas Vasilache     lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
13332bc4c3e9SNicolas Vasilache   else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
13342bc4c3e9SNicolas Vasilache     return failure();
13352bc4c3e9SNicolas Vasilache 
13362bc4c3e9SNicolas Vasilache   // RHS must be B(k, n) or B(n, k).
13372bc4c3e9SNicolas Vasilache   Value rhs = op.getRhs();
13382bc4c3e9SNicolas Vasilache   auto rhsMap = op.getIndexingMapsArray()[1];
13392bc4c3e9SNicolas Vasilache   if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
13402bc4c3e9SNicolas Vasilache     rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
13412bc4c3e9SNicolas Vasilache   else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
13422bc4c3e9SNicolas Vasilache     return failure();
13432bc4c3e9SNicolas Vasilache 
13442bc4c3e9SNicolas Vasilache   // At this point lhs and rhs are in row-major.
13455550c821STres Popp   VectorType lhsType = cast<VectorType>(lhs.getType());
13465550c821STres Popp   VectorType rhsType = cast<VectorType>(rhs.getType());
13472bc4c3e9SNicolas Vasilache   int64_t lhsRows = lhsType.getDimSize(0);
13482bc4c3e9SNicolas Vasilache   int64_t lhsColumns = lhsType.getDimSize(1);
13492bc4c3e9SNicolas Vasilache   int64_t rhsColumns = rhsType.getDimSize(1);
13502bc4c3e9SNicolas Vasilache 
13512bc4c3e9SNicolas Vasilache   Type flattenedLHSType =
13522bc4c3e9SNicolas Vasilache       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
13532bc4c3e9SNicolas Vasilache   lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
13542bc4c3e9SNicolas Vasilache 
13552bc4c3e9SNicolas Vasilache   Type flattenedRHSType =
13562bc4c3e9SNicolas Vasilache       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
13572bc4c3e9SNicolas Vasilache   rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
13582bc4c3e9SNicolas Vasilache 
13592bc4c3e9SNicolas Vasilache   Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
13602bc4c3e9SNicolas Vasilache                                            rhsColumns);
13612bc4c3e9SNicolas Vasilache   mul = rew.create<vector::ShapeCastOp>(
13622bc4c3e9SNicolas Vasilache       loc,
13632bc4c3e9SNicolas Vasilache       VectorType::get({lhsRows, rhsColumns},
13642bc4c3e9SNicolas Vasilache                       getElementTypeOrSelf(op.getAcc().getType())),
13652bc4c3e9SNicolas Vasilache       mul);
13662bc4c3e9SNicolas Vasilache 
13672bc4c3e9SNicolas Vasilache   // ACC must be C(m, n) or C(n, m).
13682bc4c3e9SNicolas Vasilache   auto accMap = op.getIndexingMapsArray()[2];
13692bc4c3e9SNicolas Vasilache   if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
13702bc4c3e9SNicolas Vasilache     mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
13712bc4c3e9SNicolas Vasilache   else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
13722bc4c3e9SNicolas Vasilache     llvm_unreachable("invalid contraction semantics");
13732bc4c3e9SNicolas Vasilache 
13742bc4c3e9SNicolas Vasilache   Value res =
13755550c821STres Popp       isa<IntegerType>(elementType)
13762bc4c3e9SNicolas Vasilache           ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
13772bc4c3e9SNicolas Vasilache           : static_cast<Value>(
13782bc4c3e9SNicolas Vasilache                 rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
13792bc4c3e9SNicolas Vasilache 
1380b7324b6aSAndrzej Warzyński   return res;
13812bc4c3e9SNicolas Vasilache }
13822bc4c3e9SNicolas Vasilache } // namespace
13832bc4c3e9SNicolas Vasilache 
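// Example usage (sketch only; `ctx` is an MLIRContext*, `rootOp` is the op to
// transform, and the lowering strategy shown is arbitrary):
//
//   RewritePatternSet patterns(ctx);
//   vector::VectorTransformsOptions options;
//   options.vectorContractLowering = vector::VectorContractLowering::OuterProduct;
//   vector::populateVectorContractLoweringPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));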
13842bc4c3e9SNicolas Vasilache void mlir::vector::populateVectorContractLoweringPatterns(
13852bc4c3e9SNicolas Vasilache     RewritePatternSet &patterns, VectorTransformsOptions options,
13862bc4c3e9SNicolas Vasilache     PatternBenefit benefit, bool disableOuterProductLowering) {
13872bc4c3e9SNicolas Vasilache   if (!disableOuterProductLowering)
13882bc4c3e9SNicolas Vasilache     patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
13892bc4c3e9SNicolas Vasilache   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
13902bc4c3e9SNicolas Vasilache                ContractionOpToOuterProductOpLowering>(
13912bc4c3e9SNicolas Vasilache       options, patterns.getContext(), benefit);
13922bc4c3e9SNicolas Vasilache }
13938b513407SNicolas Vasilache 
13948b513407SNicolas Vasilache void mlir::vector::populateVectorOuterProductLoweringPatterns(
13958b513407SNicolas Vasilache     RewritePatternSet &patterns, PatternBenefit benefit) {
13968b513407SNicolas Vasilache   patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
13978b513407SNicolas Vasilache }
1398