12bc4c3e9SNicolas Vasilache //===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===// 22bc4c3e9SNicolas Vasilache // 32bc4c3e9SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42bc4c3e9SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information. 52bc4c3e9SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62bc4c3e9SNicolas Vasilache // 72bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===// 82bc4c3e9SNicolas Vasilache // 92bc4c3e9SNicolas Vasilache // This file implements target-independent rewrites and utilities to lower the 102bc4c3e9SNicolas Vasilache // 'vector.contract' operation. 112bc4c3e9SNicolas Vasilache // 122bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===// 132bc4c3e9SNicolas Vasilache 142bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h" 152bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Arith/IR/Arith.h" 162bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Arith/Utils/Utils.h" 172bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Linalg/IR/Linalg.h" 182bc4c3e9SNicolas Vasilache #include "mlir/Dialect/MemRef/IR/MemRef.h" 192bc4c3e9SNicolas Vasilache #include "mlir/Dialect/SCF/IR/SCF.h" 202bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Tensor/IR/Tensor.h" 212bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Utils/IndexingUtils.h" 222bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 232bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/IR/VectorOps.h" 242bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 252bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 262bc4c3e9SNicolas Vasilache #include "mlir/IR/BuiltinAttributeInterfaces.h" 272bc4c3e9SNicolas Vasilache #include "mlir/IR/BuiltinTypes.h" 282bc4c3e9SNicolas 
Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h" 292bc4c3e9SNicolas Vasilache #include "mlir/IR/Location.h" 302bc4c3e9SNicolas Vasilache #include "mlir/IR/Matchers.h" 312bc4c3e9SNicolas Vasilache #include "mlir/IR/PatternMatch.h" 322bc4c3e9SNicolas Vasilache #include "mlir/IR/TypeUtilities.h" 332bc4c3e9SNicolas Vasilache #include "mlir/Interfaces/VectorInterfaces.h" 342bc4c3e9SNicolas Vasilache 352bc4c3e9SNicolas Vasilache #define DEBUG_TYPE "vector-contract-lowering" 362bc4c3e9SNicolas Vasilache 372bc4c3e9SNicolas Vasilache using namespace mlir; 382bc4c3e9SNicolas Vasilache using namespace mlir::vector; 392bc4c3e9SNicolas Vasilache 402bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===// 412bc4c3e9SNicolas Vasilache // Helper functions 422bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===// 432bc4c3e9SNicolas Vasilache // Helper to find an index in an affine map. 442bc4c3e9SNicolas Vasilache static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) { 452bc4c3e9SNicolas Vasilache for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 462bc4c3e9SNicolas Vasilache int64_t idx = map.getDimPosition(i); 472bc4c3e9SNicolas Vasilache if (idx == index) 482bc4c3e9SNicolas Vasilache return i; 492bc4c3e9SNicolas Vasilache } 502bc4c3e9SNicolas Vasilache return std::nullopt; 512bc4c3e9SNicolas Vasilache } 522bc4c3e9SNicolas Vasilache 532bc4c3e9SNicolas Vasilache // Helper to construct iterator types with one index removed. 
542bc4c3e9SNicolas Vasilache static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes, 552bc4c3e9SNicolas Vasilache int64_t index) { 562bc4c3e9SNicolas Vasilache SmallVector<Attribute> results; 572bc4c3e9SNicolas Vasilache for (const auto &it : llvm::enumerate(iteratorTypes)) { 582bc4c3e9SNicolas Vasilache int64_t idx = it.index(); 592bc4c3e9SNicolas Vasilache if (idx == index) 602bc4c3e9SNicolas Vasilache continue; 612bc4c3e9SNicolas Vasilache results.push_back(it.value()); 622bc4c3e9SNicolas Vasilache } 632bc4c3e9SNicolas Vasilache return results; 642bc4c3e9SNicolas Vasilache } 652bc4c3e9SNicolas Vasilache 662bc4c3e9SNicolas Vasilache // Helper to construct an affine map with one index removed. 672bc4c3e9SNicolas Vasilache static AffineMap adjustMap(AffineMap map, int64_t index, 682bc4c3e9SNicolas Vasilache PatternRewriter &rewriter) { 692bc4c3e9SNicolas Vasilache auto *ctx = rewriter.getContext(); 702bc4c3e9SNicolas Vasilache SmallVector<AffineExpr> results; 712bc4c3e9SNicolas Vasilache for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 722bc4c3e9SNicolas Vasilache int64_t idx = map.getDimPosition(i); 732bc4c3e9SNicolas Vasilache if (idx == index) 742bc4c3e9SNicolas Vasilache continue; 752bc4c3e9SNicolas Vasilache // Re-insert remaining indices, but renamed when occurring 762bc4c3e9SNicolas Vasilache // after the removed index. 772bc4c3e9SNicolas Vasilache auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx); 782bc4c3e9SNicolas Vasilache results.push_back(targetExpr); 792bc4c3e9SNicolas Vasilache } 802bc4c3e9SNicolas Vasilache return AffineMap::get(map.getNumDims() - 1, 0, results, ctx); 812bc4c3e9SNicolas Vasilache } 822bc4c3e9SNicolas Vasilache 832bc4c3e9SNicolas Vasilache // Helper method to possibly drop a dimension in a load. 
// Recursively descends `index` levels into `val`, unrolling each leading
// dimension with extract/insert; at the target level it extracts element
// `pos`, thereby dropping that dimension. `index == -1` means "nothing to
// drop" and returns `val` unchanged.
// NOTE(review): `type` is presumably the vector type of `val` — confirm at
// call sites.
// TODO: flesh out this contract in a proper doc comment upstream.
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions: peel dim 0, recurse one level deeper, and
  // reassemble the results into a zero-initialized vector of the result type.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}

// Inverse of reshapeLoad: recursively descends `index` levels and inserts
// `val` into `result` at `pos`, i.e. re-introduces the dropped dimension.
// `index == -1` means "nothing to insert" and returns `val` unchanged.
// TODO: flesh out this contract in a proper doc comment upstream.
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}

/// Helper to create arithmetic operation associated with a kind of contraction.
/// Emits `x * y` (integer or float per `isInt`) combined into `acc` with the
/// combining kind `kind`. Returns std::nullopt when `kind` is invalid for the
/// element type (e.g. a float-only kind with `isInt`). When `acc` is null,
/// returns the bare product.
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(rewriter, mask, fma, acc);
      return fma;
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }

  if (!acc)
    return std::optional<Value>(mul);

  return makeArithReduction(rewriter, loc, kind, mul, acc,
                            /*fastmath=*/nullptr, mask);
}

/// Return the positions of the reductions in the given map.
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> dimsIdx;
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    // A result is a reduction position iff the dim it refers to is a
    // reduction iterator.
    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
      dimsIdx.push_back(i);
  }
  return dimsIdx;
}

/// Look for a given dimension in an affine map and return its position. Return
/// std::nullopt if the dimension is not in the map results.
static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (map.getDimPosition(i) == dim)
      return i;
  }
  return std::nullopt;
}

/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
/// operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace {

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
/// %d = vector.shape_cast %flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  /// Predicate deciding whether a given contraction op may be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Caller-provided gate on the rewrite (defaults to accepting all ops).
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
/// %at = vector.transpose %a, [1, 0]
/// %bRow0 = vector.extract %b[0]
/// %atRow0 = vector.extract %at[0]
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
/// ...
/// %bRowK = vector.extract %b[K]
/// %atRowK = vector.extract %at[K]
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  /// Predicate deciding whether a given contraction op may be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Caller-provided gate on the rewrite (defaults to accepting all ops).
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
/// %out = arith.constant ... : vector<MxNxelt_type>
/// %bt = vector.transpose %b, [1, 0]
/// %aRow0 = vector.extract %a[0]
/// %btRow0 = vector.extract %bt[0]
/// %c00 = vector.reduce %aRow0, %btRow0
/// %out00 = vector.insert %c00, %out[0, 0]
/// ...
/// %aRowLast = vector.extract %a[M-1]
/// %btRowLast = vector.extract %bt[N-1]
/// %cLastLast = vector.reduce %aRowLast, %btRowLast
/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
3192bc4c3e9SNicolas Vasilache class ContractionOpToDotLowering 320b7324b6aSAndrzej Warzyński : public MaskableOpRewritePattern<vector::ContractionOp> { 3212bc4c3e9SNicolas Vasilache public: 322b7324b6aSAndrzej Warzyński using MaskableOpRewritePattern::MaskableOpRewritePattern; 3232bc4c3e9SNicolas Vasilache 3242bc4c3e9SNicolas Vasilache using FilterConstraintType = 3252bc4c3e9SNicolas Vasilache std::function<LogicalResult(vector::ContractionOp op)>; 3262bc4c3e9SNicolas Vasilache 3272bc4c3e9SNicolas Vasilache static LogicalResult defaultFilter(vector::ContractionOp op) { 3282bc4c3e9SNicolas Vasilache return success(); 3292bc4c3e9SNicolas Vasilache } 3302bc4c3e9SNicolas Vasilache 3312bc4c3e9SNicolas Vasilache ContractionOpToDotLowering( 3322bc4c3e9SNicolas Vasilache vector::VectorTransformsOptions vectorTransformOptions, 3332bc4c3e9SNicolas Vasilache MLIRContext *context, PatternBenefit benefit = 1, 3342bc4c3e9SNicolas Vasilache const FilterConstraintType &constraint = defaultFilter) 335b7324b6aSAndrzej Warzyński : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit), 3362bc4c3e9SNicolas Vasilache vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} 3372bc4c3e9SNicolas Vasilache 338b7324b6aSAndrzej Warzyński FailureOr<Value> 339b7324b6aSAndrzej Warzyński matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, 3402bc4c3e9SNicolas Vasilache PatternRewriter &rewriter) const override; 3412bc4c3e9SNicolas Vasilache 3422bc4c3e9SNicolas Vasilache private: 3432bc4c3e9SNicolas Vasilache /// Options to control the vector patterns. 3442bc4c3e9SNicolas Vasilache vector::VectorTransformsOptions vectorTransformOptions; 3452bc4c3e9SNicolas Vasilache FilterConstraintType filter; 3462bc4c3e9SNicolas Vasilache }; 3472bc4c3e9SNicolas Vasilache 3482bc4c3e9SNicolas Vasilache /// Progressive lowering of ContractionOp. 
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  /// Predicate deciding whether a given contraction op may be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                        MLIRContext *context, PatternBenefit benefit = 1,
                        FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Caller-provided gate on the rewrite (defaults to accepting all ops).
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    // If the contract is masked, remember the mask so outerProd can thread it
    // through the generated ops.
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  /// Transpose `v` with permutation `perm` (default: 2-D transpose).
  /// Null values pass through unchanged, so this is safe on an absent mask.
  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  /// Widen `v` (scalar or vector) to element type `dstElementType` via extf
  /// for float destinations or extsi otherwise. No-op when the element type
  /// already matches.
  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  /// Emit `reductionSize` chained vector.outerproduct ops, accumulating into
  /// `res`. Operands are promoted to the result element type; when a mask is
  /// present, each step is masked with the k-th slice of `maybeMask`.
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking: bail out if the op is masked but the
    // caller did not supply a transposed mask.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  /// Dispatches on the (lhs, rhs, acc) indexing-map layout; each case
  /// transposes operands (and the mask, to {2, 0, 1}) as needed so the
  /// reduction dimension is outermost, then unrolls via outerProd.
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);

    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when the pattern returns "failure".
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
// One outer parallel, one inner reduction (matvec flavor).
// Mask needs to be transposed everywhere to turn the reduction dimension
// outermost as required by outerproduct.
//
// Like `matmat` above, each `layout` check matches one permutation of the
// indexing maps and transposes/swaps the operands so that the reduction
// dimension `k` is outermost before handing off to `outerProd`.
FailureOr<Value> matvec() {
  if (!iters({Par(), Red()}))
    return failure();
  AffineExpr m, k;
  bindDims(rewriter.getContext(), m, k);

  // Case mat-vec: transpose.
  if (layout({{m, k}, {k}, {m}})) {
    if (auto reductionSize = getReductionSize(lhsType, 1)) {
      // `t` creates IR; keep it nested so a failing match creates nothing.
      Value tLhs = t(lhs);
      Value tMask = t(mask);
      return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
    }
  }
  // Case mat-trans-vec: ready to go.
  if (layout({{k, m}, {k}, {m}})) {
    if (auto reductionSize = getReductionSize(lhsType, 0)) {
      Value tMask = t(mask);
      return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
    }
  }
  // Case vec-mat: swap and transpose.
  if (layout({{k}, {m, k}, {m}})) {
    if (auto reductionSize = getReductionSize(lhsType, 0)) {
      Value tRhs = t(rhs);
      Value tMask = t(mask);
      return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
    }
  }
  // Case vec-mat-trans: swap and ready to go.
  if (layout({{k}, {k, m}, {m}})) {
    if (auto reductionSize = getReductionSize(lhsType, 0)) {
      Value tMask = t(mask);
      return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
    }
  }
  return failure();
}

//
// One outer reduction, one inner parallel (tmatvec flavor).
// Mask already has the shape of the outer product.
//
// Unlike `matvec`, no mask transposition is needed here, so the raw `mask`
// member is forwarded directly to `outerProd`.
FailureOr<Value> tmatvec() {
  if (!iters({Red(), Par()}))
    return failure();
  AffineExpr k, m;
  bindDims(rewriter.getContext(), k, m);

  // Case mat-vec: transpose.
  if (layout({{m, k}, {k}, {m}}))
    if (auto reductionSize = getReductionSize(lhsType, 1))
      return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
  // Case mat-trans-vec: ready to go.
  if (layout({{k, m}, {k}, {m}}))
    if (auto reductionSize = getReductionSize(lhsType, 0))
      return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
  // Case vec-mat: swap and transpose.
  if (layout({{k}, {m, k}, {m}}))
    if (auto reductionSize = getReductionSize(lhsType, 0))
      return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
  // Case vec-mat-trans: swap and ready to go.
  if (layout({{k}, {k, m}, {m}}))
    if (auto reductionSize = getReductionSize(lhsType, 0))
      return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
  return failure();
}

private:
  // Combining kind of the contraction being lowered.
  vector::CombiningKind kind;
  // Operands, accumulator and (optional, may be null) mask of the contract op.
  Value lhs, rhs, res, mask;
  // Type of the LHS operand; reduction sizes are always read off this type.
  VectorType lhsType;
};

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  // Only fire when this lowering strategy was explicitly requested.
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  // Try the three supported iterator structures in turn: matmat (2 parallel,
  // 1 reduction), matvec (parallel, reduction), tmatvec (reduction, parallel).
  // A non-matching flavor returns failure without having created IR.
  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    return matmatRes;
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    return matvecRes;
  }

  FailureOr<Value> tmatvecRes = e.tmatvec();
  return tmatvecRes;
}

/// Lower a vector.contract to a sequence of elementwise multiplies and
/// vector.reduction ops, one scalar result element at a time. Only fires when
/// VectorContractLowering::Dot was requested.
FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  // Permutation used by every 2-D transpose below.
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, op.getContext());
  };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  // After normalization the operands are expected in the form
  // lhs: {m, k}, rhs: {n, k} (reduction innermost on both sides).
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // This is the classical row-major matmul. Just permute the lhs.
      // Transposed result: also swap the operands.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      // Transposed result, operands already normalized: just swap them.
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      // Transposed result: swap the operands and permute both.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      // Transposed result: swap the operands and permute the new rhs.
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  // Rank-1 results (matvec) have a single conceptual column.
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  // Each result element is computed as an elementwise mul of one lhs row and
  // one rhs row, reduced with ADD, then inserted into the zero-initialized
  // result vector.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  // Fold the accumulator in last, if present.
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  return res;
}

/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
7832bc4c3e9SNicolas Vasilache struct ContractOpToElementwise 784b7324b6aSAndrzej Warzyński : public MaskableOpRewritePattern<vector::ContractionOp> { 785b7324b6aSAndrzej Warzyński using MaskableOpRewritePattern::MaskableOpRewritePattern; 7862bc4c3e9SNicolas Vasilache using FilterConstraintType = 7872bc4c3e9SNicolas Vasilache std::function<LogicalResult(vector::ContractionOp op)>; 7882bc4c3e9SNicolas Vasilache static LogicalResult defaultFilter(vector::ContractionOp op) { 7892bc4c3e9SNicolas Vasilache return success(); 7902bc4c3e9SNicolas Vasilache } 7912bc4c3e9SNicolas Vasilache ContractOpToElementwise( 7922bc4c3e9SNicolas Vasilache vector::VectorTransformsOptions vectorTransformOptions, 7932bc4c3e9SNicolas Vasilache MLIRContext *context, PatternBenefit benefit = 1, 7942bc4c3e9SNicolas Vasilache const FilterConstraintType &constraint = defaultFilter) 795b7324b6aSAndrzej Warzyński : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit), 7962bc4c3e9SNicolas Vasilache vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} 7972bc4c3e9SNicolas Vasilache 798b7324b6aSAndrzej Warzyński FailureOr<Value> 799b7324b6aSAndrzej Warzyński matchAndRewriteMaskableOp(vector::ContractionOp contractOp, 800b7324b6aSAndrzej Warzyński MaskingOpInterface maskOp, 8012bc4c3e9SNicolas Vasilache PatternRewriter &rewriter) const override { 8022bc4c3e9SNicolas Vasilache // TODO: Support vector.mask. 
803b7324b6aSAndrzej Warzyński if (maskOp) 8042bc4c3e9SNicolas Vasilache return failure(); 8052bc4c3e9SNicolas Vasilache 8062bc4c3e9SNicolas Vasilache if (failed(filter(contractOp))) 8072bc4c3e9SNicolas Vasilache return failure(); 8082bc4c3e9SNicolas Vasilache 8092bc4c3e9SNicolas Vasilache if (vectorTransformOptions.vectorContractLowering != 8102bc4c3e9SNicolas Vasilache vector::VectorContractLowering::ParallelArith) 8112bc4c3e9SNicolas Vasilache return failure(); 8122bc4c3e9SNicolas Vasilache 8132bc4c3e9SNicolas Vasilache ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape(); 8142bc4c3e9SNicolas Vasilache ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape(); 8152bc4c3e9SNicolas Vasilache AffineMap lhsMap = contractOp.getIndexingMapsArray()[0]; 8162bc4c3e9SNicolas Vasilache AffineMap rhsMap = contractOp.getIndexingMapsArray()[1]; 8172bc4c3e9SNicolas Vasilache SmallVector<int64_t> lhsReductionDims = 8182bc4c3e9SNicolas Vasilache getReductionIndex(lhsMap, contractOp.getIteratorTypes()); 8192bc4c3e9SNicolas Vasilache SmallVector<int64_t> rhsReductionDims = 8202bc4c3e9SNicolas Vasilache getReductionIndex(rhsMap, contractOp.getIteratorTypes()); 8212bc4c3e9SNicolas Vasilache // All the reduction dimensions must be a size 1. 
8222bc4c3e9SNicolas Vasilache for (int64_t dim : lhsReductionDims) { 8232bc4c3e9SNicolas Vasilache if (lhsShape[dim] != 1) 8242bc4c3e9SNicolas Vasilache return failure(); 8252bc4c3e9SNicolas Vasilache } 8262bc4c3e9SNicolas Vasilache for (int64_t dim : rhsReductionDims) { 8272bc4c3e9SNicolas Vasilache if (rhsShape[dim] != 1) 8282bc4c3e9SNicolas Vasilache return failure(); 8292bc4c3e9SNicolas Vasilache } 8302bc4c3e9SNicolas Vasilache AffineMap accMap = contractOp.getIndexingMapsArray()[2]; 8312bc4c3e9SNicolas Vasilache unsigned numParallelDims = accMap.getNumResults(); 8322bc4c3e9SNicolas Vasilache unsigned numLhsDimToBroadcast = 8332bc4c3e9SNicolas Vasilache numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size()); 8342bc4c3e9SNicolas Vasilache unsigned numRhsDimToBroadcast = 8352bc4c3e9SNicolas Vasilache numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size()); 8362bc4c3e9SNicolas Vasilache SmallVector<int64_t> lhsDims; 8372bc4c3e9SNicolas Vasilache SmallVector<int64_t> lhsTranspose; 8382bc4c3e9SNicolas Vasilache SmallVector<int64_t> rhsDims; 8392bc4c3e9SNicolas Vasilache SmallVector<int64_t> rhsTranspose; 8402bc4c3e9SNicolas Vasilache for (int64_t dim : lhsReductionDims) 8412bc4c3e9SNicolas Vasilache lhsTranspose.push_back(numLhsDimToBroadcast + dim); 8422bc4c3e9SNicolas Vasilache for (int64_t dim : rhsReductionDims) 8432bc4c3e9SNicolas Vasilache rhsTranspose.push_back(numRhsDimToBroadcast + dim); 8442bc4c3e9SNicolas Vasilache // Loop through the parallel dimensions to calculate the dimensions to 8452bc4c3e9SNicolas Vasilache // broadcast and to permute in order to extract only parallel dimensions. 
8462bc4c3e9SNicolas Vasilache for (unsigned i = 0; i < numParallelDims; i++) { 8472bc4c3e9SNicolas Vasilache std::optional<unsigned> lhsDim = 8482bc4c3e9SNicolas Vasilache getDimPosition(lhsMap, accMap.getDimPosition(i)); 8492bc4c3e9SNicolas Vasilache if (lhsDim) { 8502bc4c3e9SNicolas Vasilache lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim); 8512bc4c3e9SNicolas Vasilache } else { 8522bc4c3e9SNicolas Vasilache // If the parallel dimension doesn't exist we will have to broadcast it. 8532bc4c3e9SNicolas Vasilache lhsDims.push_back( 8545550c821STres Popp cast<VectorType>(contractOp.getResultType()).getDimSize(i)); 8552bc4c3e9SNicolas Vasilache lhsTranspose.push_back(lhsDims.size() - 1); 8562bc4c3e9SNicolas Vasilache } 8572bc4c3e9SNicolas Vasilache std::optional<unsigned> rhsDim = 8582bc4c3e9SNicolas Vasilache getDimPosition(rhsMap, accMap.getDimPosition(i)); 8592bc4c3e9SNicolas Vasilache if (rhsDim) { 8602bc4c3e9SNicolas Vasilache rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim); 8612bc4c3e9SNicolas Vasilache } else { 8622bc4c3e9SNicolas Vasilache // If the parallel dimension doesn't exist we will have to broadcast it. 
8632bc4c3e9SNicolas Vasilache rhsDims.push_back( 8645550c821STres Popp cast<VectorType>(contractOp.getResultType()).getDimSize(i)); 8652bc4c3e9SNicolas Vasilache rhsTranspose.push_back(rhsDims.size() - 1); 8662bc4c3e9SNicolas Vasilache } 8672bc4c3e9SNicolas Vasilache } 8682bc4c3e9SNicolas Vasilache Value newLhs = contractOp.getLhs(); 8692bc4c3e9SNicolas Vasilache Value newRhs = contractOp.getRhs(); 8702bc4c3e9SNicolas Vasilache Location loc = contractOp.getLoc(); 8712bc4c3e9SNicolas Vasilache if (!lhsDims.empty()) { 8722bc4c3e9SNicolas Vasilache lhsDims.append(lhsShape.begin(), lhsShape.end()); 8732bc4c3e9SNicolas Vasilache auto expandedType = 8742bc4c3e9SNicolas Vasilache VectorType::get(lhsDims, contractOp.getLhsType().getElementType()); 8752bc4c3e9SNicolas Vasilache newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs); 8762bc4c3e9SNicolas Vasilache } 8772bc4c3e9SNicolas Vasilache if (!rhsDims.empty()) { 8782bc4c3e9SNicolas Vasilache rhsDims.append(rhsShape.begin(), rhsShape.end()); 8792bc4c3e9SNicolas Vasilache auto expandedType = 8802bc4c3e9SNicolas Vasilache VectorType::get(rhsDims, contractOp.getRhsType().getElementType()); 8812bc4c3e9SNicolas Vasilache newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs); 8822bc4c3e9SNicolas Vasilache } 8832bc4c3e9SNicolas Vasilache bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex(); 8842bc4c3e9SNicolas Vasilache newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose); 8852bc4c3e9SNicolas Vasilache newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose); 8862bc4c3e9SNicolas Vasilache SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0); 8872bc4c3e9SNicolas Vasilache SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0); 88816b75cd2SMatthias Springer newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets); 88916b75cd2SMatthias Springer newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets); 
8902bc4c3e9SNicolas Vasilache std::optional<Value> result = 8912bc4c3e9SNicolas Vasilache createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(), 8922bc4c3e9SNicolas Vasilache contractOp.getKind(), rewriter, isInt); 893b7324b6aSAndrzej Warzyński if (result) 894b7324b6aSAndrzej Warzyński return *result; 895b7324b6aSAndrzej Warzyński 896b7324b6aSAndrzej Warzyński return failure(); 8972bc4c3e9SNicolas Vasilache } 8982bc4c3e9SNicolas Vasilache 8992bc4c3e9SNicolas Vasilache private: 9002bc4c3e9SNicolas Vasilache /// Options to control the vector patterns. 9012bc4c3e9SNicolas Vasilache vector::VectorTransformsOptions vectorTransformOptions; 9022bc4c3e9SNicolas Vasilache FilterConstraintType filter; 9032bc4c3e9SNicolas Vasilache }; 9042bc4c3e9SNicolas Vasilache 9052bc4c3e9SNicolas Vasilache /// Progressive lowering of ContractionOp. 9062bc4c3e9SNicolas Vasilache /// One: 9072bc4c3e9SNicolas Vasilache /// %x = vector.contract with at least one free/batch dimension 9082bc4c3e9SNicolas Vasilache /// is replaced by: 9092bc4c3e9SNicolas Vasilache /// %a = vector.contract with one less free/batch dimension 9102bc4c3e9SNicolas Vasilache /// %b = vector.contract with one less free/batch dimension 9112bc4c3e9SNicolas Vasilache /// .. 9122bc4c3e9SNicolas Vasilache /// %x = combine %a %b .. 9132bc4c3e9SNicolas Vasilache /// until a pure contraction is reached (no free/batch dimensions), 9142bc4c3e9SNicolas Vasilache /// which is replaced by a dot-product. 9152bc4c3e9SNicolas Vasilache /// 9162bc4c3e9SNicolas Vasilache /// This only kicks in when either VectorTransformsOptions is set 9172bc4c3e9SNicolas Vasilache /// to DOT or when other contraction patterns fail. 
9182bc4c3e9SNicolas Vasilache // 9192bc4c3e9SNicolas Vasilache // TODO: break down into transpose/reshape/cast ops 9202bc4c3e9SNicolas Vasilache // when they become available to avoid code dup 9212bc4c3e9SNicolas Vasilache // TODO: investigate lowering order impact on performance 922b7324b6aSAndrzej Warzyński FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp( 923b7324b6aSAndrzej Warzyński vector::ContractionOp op, MaskingOpInterface maskOp, 9242bc4c3e9SNicolas Vasilache PatternRewriter &rewriter) const { 9252bc4c3e9SNicolas Vasilache if (failed(filter(op))) 9262bc4c3e9SNicolas Vasilache return failure(); 9272bc4c3e9SNicolas Vasilache 9282bc4c3e9SNicolas Vasilache // TODO: support mixed mode contract lowering. 9292bc4c3e9SNicolas Vasilache if (op.getLhsType().getElementType() != 9302bc4c3e9SNicolas Vasilache getElementTypeOrSelf(op.getAccType()) || 9312bc4c3e9SNicolas Vasilache op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) 9322bc4c3e9SNicolas Vasilache return failure(); 9332bc4c3e9SNicolas Vasilache 9342bc4c3e9SNicolas Vasilache // TODO: the code below assumes the default contraction, make sure it supports 9352bc4c3e9SNicolas Vasilache // other kinds before enabling this lowering. 9362bc4c3e9SNicolas Vasilache if (op.getKind() != vector::CombiningKind::ADD) { 9372bc4c3e9SNicolas Vasilache return rewriter.notifyMatchFailure( 9382bc4c3e9SNicolas Vasilache op, "contractions other than 'add' not supported"); 9392bc4c3e9SNicolas Vasilache } 9402bc4c3e9SNicolas Vasilache 9412bc4c3e9SNicolas Vasilache // TODO: implement benefits, cost models. 
9422bc4c3e9SNicolas Vasilache MLIRContext *ctx = op.getContext(); 943b7324b6aSAndrzej Warzyński 9442bc4c3e9SNicolas Vasilache ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); 945b7324b6aSAndrzej Warzyński FailureOr<Value> newVal1 = 946b7324b6aSAndrzej Warzyński pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter); 947b7324b6aSAndrzej Warzyński if (!failed(newVal1)) 948b7324b6aSAndrzej Warzyński return newVal1; 949b7324b6aSAndrzej Warzyński 9502bc4c3e9SNicolas Vasilache ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); 951b7324b6aSAndrzej Warzyński FailureOr<Value> newVal2 = 952b7324b6aSAndrzej Warzyński pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter); 953b7324b6aSAndrzej Warzyński if (!failed(newVal2)) 954b7324b6aSAndrzej Warzyński return newVal2; 955b7324b6aSAndrzej Warzyński 9562bc4c3e9SNicolas Vasilache ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); 957b7324b6aSAndrzej Warzyński FailureOr<Value> newVal3 = 958b7324b6aSAndrzej Warzyński pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter); 959b7324b6aSAndrzej Warzyński if (!failed(newVal3)) 960b7324b6aSAndrzej Warzyński return newVal3; 961b7324b6aSAndrzej Warzyński 9622bc4c3e9SNicolas Vasilache ContractOpToElementwise pat4(vectorTransformOptions, ctx); 963b7324b6aSAndrzej Warzyński FailureOr<Value> newVal4 = 964b7324b6aSAndrzej Warzyński pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter); 965b7324b6aSAndrzej Warzyński if (!failed(newVal4)) 966b7324b6aSAndrzej Warzyński return newVal4; 9672bc4c3e9SNicolas Vasilache 9682bc4c3e9SNicolas Vasilache // Vector mask setup. 9692bc4c3e9SNicolas Vasilache 970b7324b6aSAndrzej Warzyński Value mask; 971b7324b6aSAndrzej Warzyński if (maskOp) 972b7324b6aSAndrzej Warzyński mask = maskOp.getMask(); 9732bc4c3e9SNicolas Vasilache // Find first batch dimension in LHS/RHS, and lower when found. 
9742bc4c3e9SNicolas Vasilache std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); 9752bc4c3e9SNicolas Vasilache if (!batchDimMap.empty()) { 9762bc4c3e9SNicolas Vasilache int64_t lhsIndex = batchDimMap[0].first; 9772bc4c3e9SNicolas Vasilache int64_t rhsIndex = batchDimMap[0].second; 9782bc4c3e9SNicolas Vasilache auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask); 9792bc4c3e9SNicolas Vasilache if (failed(newOp)) 9802bc4c3e9SNicolas Vasilache return failure(); 981b7324b6aSAndrzej Warzyński return newOp; 9822bc4c3e9SNicolas Vasilache } 9832bc4c3e9SNicolas Vasilache 9842bc4c3e9SNicolas Vasilache // Collect contracting dimensions. 9852bc4c3e9SNicolas Vasilache std::vector<std::pair<int64_t, int64_t>> contractingDimMap = 9862bc4c3e9SNicolas Vasilache op.getContractingDimMap(); 9872bc4c3e9SNicolas Vasilache DenseSet<int64_t> lhsContractingDimSet; 9882bc4c3e9SNicolas Vasilache DenseSet<int64_t> rhsContractingDimSet; 9892bc4c3e9SNicolas Vasilache for (auto &dimPair : contractingDimMap) { 9902bc4c3e9SNicolas Vasilache lhsContractingDimSet.insert(dimPair.first); 9912bc4c3e9SNicolas Vasilache rhsContractingDimSet.insert(dimPair.second); 9922bc4c3e9SNicolas Vasilache } 9932bc4c3e9SNicolas Vasilache 9942bc4c3e9SNicolas Vasilache // Find first free dimension in LHS, and lower when found. 
9952bc4c3e9SNicolas Vasilache VectorType lhsType = op.getLhsType(); 9962bc4c3e9SNicolas Vasilache for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { 9972bc4c3e9SNicolas Vasilache if (lhsContractingDimSet.count(lhsIndex) == 0) { 9982bc4c3e9SNicolas Vasilache auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask); 9992bc4c3e9SNicolas Vasilache if (failed(newOp)) 10002bc4c3e9SNicolas Vasilache return failure(); 1001b7324b6aSAndrzej Warzyński return newOp; 10022bc4c3e9SNicolas Vasilache } 10032bc4c3e9SNicolas Vasilache } 10042bc4c3e9SNicolas Vasilache 10052bc4c3e9SNicolas Vasilache // Find first free dimension in RHS, and lower when found. 10062bc4c3e9SNicolas Vasilache VectorType rhsType = op.getRhsType(); 10072bc4c3e9SNicolas Vasilache for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { 10082bc4c3e9SNicolas Vasilache if (rhsContractingDimSet.count(rhsIndex) == 0) { 10092bc4c3e9SNicolas Vasilache auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask); 10102bc4c3e9SNicolas Vasilache if (failed(newOp)) 10112bc4c3e9SNicolas Vasilache return failure(); 1012b7324b6aSAndrzej Warzyński return newOp; 10132bc4c3e9SNicolas Vasilache } 10142bc4c3e9SNicolas Vasilache } 10152bc4c3e9SNicolas Vasilache 10162bc4c3e9SNicolas Vasilache // Lower the first remaining reduction dimension. 10172bc4c3e9SNicolas Vasilache if (!contractingDimMap.empty()) { 10182bc4c3e9SNicolas Vasilache auto newOp = lowerReduction(rewriter, op, mask); 10192bc4c3e9SNicolas Vasilache if (failed(newOp)) 10202bc4c3e9SNicolas Vasilache return failure(); 1021b7324b6aSAndrzej Warzyński return newOp; 10222bc4c3e9SNicolas Vasilache } 10232bc4c3e9SNicolas Vasilache 10242bc4c3e9SNicolas Vasilache return failure(); 10252bc4c3e9SNicolas Vasilache } 10262bc4c3e9SNicolas Vasilache 10272bc4c3e9SNicolas Vasilache // Lower one parallel dimension. 
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
//
// Unrolls the dimension selected by `lhsIndex` (a dim of the LHS type) and/or
// `rhsIndex` (a dim of the RHS type): for each slice `d` along that dimension
// a lower-rank vector.contract is created (masked with the matching slice of
// `mask`, when a mask is provided) and its result is inserted back into an
// all-zero accumulator of the original result type.
//
// Returns the value that replaces `op`, or failure (with a match-failure
// diagnostic) when the indices are inconsistent, the chosen dim is scalable,
// or the dim does not appear in the result map while not being unit-sized.
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    // When the dim is present on both operands, both indices must refer to
    // the same iteration-space dimension.
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    // Unrolling below iterates `dimSize` times, which requires a static size.
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewrite is currently calling
  // lowerParallel also for the case of unit-size reduction dims appearing only
  // on one of LHS or RHS, not both. At the moment, such cases are created by
  // CastAwayContractionLeadingOneDim, so we need to either support that or
  // modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  // adjustMap/adjustIter drop `iterIndex` so the emitted contracts are one
  // iteration-space dimension smaller than `op`.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  // Start from an all-zero accumulator; each iteration stores its slice back.
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    Value lowMask;
    if (mask)
      // Note: the mask is sliced along the iteration-space dim (iterIndex),
      // not along an operand dim like lhs/rhs/acc above.
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}

// Lower one reduction dimension.
11162bc4c3e9SNicolas Vasilache FailureOr<Value> ContractionOpLowering::lowerReduction( 11172bc4c3e9SNicolas Vasilache PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const { 11182bc4c3e9SNicolas Vasilache auto loc = op.getLoc(); 11192bc4c3e9SNicolas Vasilache VectorType lhsType = op.getLhsType(); 11202bc4c3e9SNicolas Vasilache VectorType rhsType = op.getRhsType(); 11212bc4c3e9SNicolas Vasilache Type resType = op.getResultType(); 11225550c821STres Popp if (isa<VectorType>(resType)) 11232bc4c3e9SNicolas Vasilache return rewriter.notifyMatchFailure(op, 11242bc4c3e9SNicolas Vasilache "did not expect a VectorType result"); 11255550c821STres Popp bool isInt = isa<IntegerType>(resType); 11262bc4c3e9SNicolas Vasilache // Use iterator index 0. 11272bc4c3e9SNicolas Vasilache int64_t iterIndex = 0; 11282bc4c3e9SNicolas Vasilache SmallVector<AffineMap> iMap = op.getIndexingMapsArray(); 11292bc4c3e9SNicolas Vasilache std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex); 11302bc4c3e9SNicolas Vasilache std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex); 11312bc4c3e9SNicolas Vasilache if (!lookupLhs.has_value()) 11322bc4c3e9SNicolas Vasilache return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 11332bc4c3e9SNicolas Vasilache diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension"; 11342bc4c3e9SNicolas Vasilache }); 11352bc4c3e9SNicolas Vasilache if (!lookupRhs.has_value()) 11362bc4c3e9SNicolas Vasilache return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 11372bc4c3e9SNicolas Vasilache diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension"; 11382bc4c3e9SNicolas Vasilache }); 11392bc4c3e9SNicolas Vasilache int64_t lhsIndex = *lookupLhs; 11402bc4c3e9SNicolas Vasilache int64_t rhsIndex = *lookupRhs; 11412bc4c3e9SNicolas Vasilache int64_t dimSize = lhsType.getDimSize(lhsIndex); 11422bc4c3e9SNicolas Vasilache if (dimSize != rhsType.getDimSize(rhsIndex)) 11432bc4c3e9SNicolas 
Vasilache return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 11442bc4c3e9SNicolas Vasilache diag << "expect LHS dimension " << lhsIndex 11452bc4c3e9SNicolas Vasilache << " to have the same size as RHS dimension " << rhsIndex; 11462bc4c3e9SNicolas Vasilache }); 11472bc4c3e9SNicolas Vasilache // Base case. 11482bc4c3e9SNicolas Vasilache if (lhsType.getRank() == 1) { 11492bc4c3e9SNicolas Vasilache if (rhsType.getRank() != 1) 11502bc4c3e9SNicolas Vasilache return rewriter.notifyMatchFailure( 11512bc4c3e9SNicolas Vasilache op, "When LHS has rank 1, expected also RHS to have rank 1"); 11522bc4c3e9SNicolas Vasilache Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter); 11532bc4c3e9SNicolas Vasilache auto kind = vector::CombiningKind::ADD; 11542bc4c3e9SNicolas Vasilache 11552bc4c3e9SNicolas Vasilache Value acc = op.getAcc(); 11562bc4c3e9SNicolas Vasilache Operation *reductionOp = 11572bc4c3e9SNicolas Vasilache acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc) 11582bc4c3e9SNicolas Vasilache : rewriter.create<vector::ReductionOp>(loc, kind, m); 11592bc4c3e9SNicolas Vasilache return maskOperation(rewriter, reductionOp, mask)->getResult(0); 11602bc4c3e9SNicolas Vasilache } 11612bc4c3e9SNicolas Vasilache // Construct new iterator types and affine map array attribute. 11622bc4c3e9SNicolas Vasilache std::array<AffineMap, 3> lowIndexingMaps = { 11632bc4c3e9SNicolas Vasilache adjustMap(iMap[0], iterIndex, rewriter), 11642bc4c3e9SNicolas Vasilache adjustMap(iMap[1], iterIndex, rewriter), 11652bc4c3e9SNicolas Vasilache adjustMap(iMap[2], iterIndex, rewriter)}; 11662bc4c3e9SNicolas Vasilache auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 11672bc4c3e9SNicolas Vasilache auto lowIter = 11682bc4c3e9SNicolas Vasilache rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); 11692bc4c3e9SNicolas Vasilache // Unroll into a series of lower dimensional vector.contract ops. 
11702bc4c3e9SNicolas Vasilache // By feeding the initial accumulator into the first contraction, 11712bc4c3e9SNicolas Vasilache // and the result of each contraction into the next, eventually 11722bc4c3e9SNicolas Vasilache // the sum of all reductions is computed. 11732bc4c3e9SNicolas Vasilache Value result = op.getAcc(); 11742bc4c3e9SNicolas Vasilache for (int64_t d = 0; d < dimSize; ++d) { 11752bc4c3e9SNicolas Vasilache auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 11762bc4c3e9SNicolas Vasilache auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 11772bc4c3e9SNicolas Vasilache Value newMask; 11782bc4c3e9SNicolas Vasilache if (mask) 11792bc4c3e9SNicolas Vasilache newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()), 11802bc4c3e9SNicolas Vasilache iterIndex, d, rewriter); 11812bc4c3e9SNicolas Vasilache 11822bc4c3e9SNicolas Vasilache Operation *newContract = rewriter.create<vector::ContractionOp>( 11832bc4c3e9SNicolas Vasilache loc, lhs, rhs, result, lowAffine, lowIter); 11842bc4c3e9SNicolas Vasilache result = maskOperation(rewriter, newContract, newMask)->getResult(0); 11852bc4c3e9SNicolas Vasilache } 11862bc4c3e9SNicolas Vasilache return result; 11872bc4c3e9SNicolas Vasilache } 11882bc4c3e9SNicolas Vasilache 11892bc4c3e9SNicolas Vasilache /// Progressive lowering of OuterProductOp. 11902bc4c3e9SNicolas Vasilache /// One: 11912bc4c3e9SNicolas Vasilache /// %x = vector.outerproduct %lhs, %rhs, %acc 11922bc4c3e9SNicolas Vasilache /// is replaced by: 11932bc4c3e9SNicolas Vasilache /// %z = zero-result 11942bc4c3e9SNicolas Vasilache /// %0 = vector.extract %lhs[0] 11952bc4c3e9SNicolas Vasilache /// %1 = vector.broadcast %0 11962bc4c3e9SNicolas Vasilache /// %2 = vector.extract %acc[0] 11972bc4c3e9SNicolas Vasilache /// %3 = vector.fma %1, %rhs, %2 11982bc4c3e9SNicolas Vasilache /// %4 = vector.insert %3, %z[0] 11992bc4c3e9SNicolas Vasilache /// .. 
12002bc4c3e9SNicolas Vasilache /// %x = vector.insert %.., %..[N-1] 12012bc4c3e9SNicolas Vasilache /// 12022bc4c3e9SNicolas Vasilache class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> { 12032bc4c3e9SNicolas Vasilache public: 12042bc4c3e9SNicolas Vasilache using OpRewritePattern::OpRewritePattern; 12052bc4c3e9SNicolas Vasilache 12062bc4c3e9SNicolas Vasilache LogicalResult matchAndRewrite(vector::OuterProductOp op, 12072bc4c3e9SNicolas Vasilache PatternRewriter &rewriter) const override { 1208f75d46a7SCullen Rhodes VectorType resType = op.getResultVectorType(); 1209f75d46a7SCullen Rhodes if ((resType.getShape().size() >= 2) && resType.allDimsScalable()) 1210f75d46a7SCullen Rhodes return failure(); 1211f75d46a7SCullen Rhodes 12122bc4c3e9SNicolas Vasilache auto loc = op.getLoc(); 12132bc4c3e9SNicolas Vasilache 12142bc4c3e9SNicolas Vasilache VectorType lhsType = op.getOperandVectorTypeLHS(); 12155550c821STres Popp VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS()); 12162bc4c3e9SNicolas Vasilache Type eltType = resType.getElementType(); 12175550c821STres Popp bool isInt = isa<IntegerType, IndexType>(eltType); 1218067bd7d0SCullen Rhodes Value acc = op.getAcc(); 12192bc4c3e9SNicolas Vasilache vector::CombiningKind kind = op.getKind(); 12202bc4c3e9SNicolas Vasilache 12212bc4c3e9SNicolas Vasilache // Vector mask setup. 
12222bc4c3e9SNicolas Vasilache OpBuilder::InsertionGuard guard(rewriter); 12232bc4c3e9SNicolas Vasilache auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation()); 12242bc4c3e9SNicolas Vasilache Operation *rootOp; 12252bc4c3e9SNicolas Vasilache Value mask; 12262bc4c3e9SNicolas Vasilache if (maskableOp.isMasked()) { 12272bc4c3e9SNicolas Vasilache rewriter.setInsertionPoint(maskableOp.getMaskingOp()); 12282bc4c3e9SNicolas Vasilache rootOp = maskableOp.getMaskingOp(); 12292bc4c3e9SNicolas Vasilache mask = maskableOp.getMaskingOp().getMask(); 12302bc4c3e9SNicolas Vasilache } else { 12312bc4c3e9SNicolas Vasilache rootOp = op; 12322bc4c3e9SNicolas Vasilache } 12332bc4c3e9SNicolas Vasilache 12342bc4c3e9SNicolas Vasilache if (!rhsType) { 12352bc4c3e9SNicolas Vasilache // Special case: AXPY operation. 12362bc4c3e9SNicolas Vasilache Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs()); 12372bc4c3e9SNicolas Vasilache std::optional<Value> mult = createContractArithOp( 12382bc4c3e9SNicolas Vasilache loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask); 12392bc4c3e9SNicolas Vasilache if (!mult.has_value()) 12402bc4c3e9SNicolas Vasilache return failure(); 12412bc4c3e9SNicolas Vasilache rewriter.replaceOp(rootOp, *mult); 12422bc4c3e9SNicolas Vasilache return success(); 12432bc4c3e9SNicolas Vasilache } 12442bc4c3e9SNicolas Vasilache 12452bc4c3e9SNicolas Vasilache Value result = rewriter.create<arith::ConstantOp>( 12462bc4c3e9SNicolas Vasilache loc, resType, rewriter.getZeroAttr(resType)); 12472bc4c3e9SNicolas Vasilache for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { 124816b75cd2SMatthias Springer Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d); 12492bc4c3e9SNicolas Vasilache Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x); 12502bc4c3e9SNicolas Vasilache Value r = nullptr; 12512bc4c3e9SNicolas Vasilache if (acc) 125216b75cd2SMatthias Springer r = rewriter.create<vector::ExtractOp>(loc, acc, d); 
12532bc4c3e9SNicolas Vasilache Value extrMask; 12542bc4c3e9SNicolas Vasilache if (mask) 125516b75cd2SMatthias Springer extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d); 12562bc4c3e9SNicolas Vasilache 12572bc4c3e9SNicolas Vasilache std::optional<Value> m = createContractArithOp( 12582bc4c3e9SNicolas Vasilache loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); 12592bc4c3e9SNicolas Vasilache if (!m.has_value()) 12602bc4c3e9SNicolas Vasilache return failure(); 126198f6289aSDiego Caballero result = rewriter.create<vector::InsertOp>(loc, *m, result, d); 12622bc4c3e9SNicolas Vasilache } 12632bc4c3e9SNicolas Vasilache 12642bc4c3e9SNicolas Vasilache rewriter.replaceOp(rootOp, result); 12652bc4c3e9SNicolas Vasilache return success(); 12662bc4c3e9SNicolas Vasilache } 12672bc4c3e9SNicolas Vasilache }; 12682bc4c3e9SNicolas Vasilache 12692bc4c3e9SNicolas Vasilache /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul 12702bc4c3e9SNicolas Vasilache /// semantics to: 12712bc4c3e9SNicolas Vasilache /// ``` 12722bc4c3e9SNicolas Vasilache /// %mta = maybe_transpose 12732bc4c3e9SNicolas Vasilache /// %mtb = maybe_transpose 12742bc4c3e9SNicolas Vasilache /// %flattened_a = vector.shape_cast %mta 12752bc4c3e9SNicolas Vasilache /// %flattened_b = vector.shape_cast %mtb 12762bc4c3e9SNicolas Vasilache /// %flattened_d = vector.matmul %flattened_a, %flattened_b 12772bc4c3e9SNicolas Vasilache /// %mtd = vector.shape_cast %flattened_d 12782bc4c3e9SNicolas Vasilache /// %d = maybe_untranspose %mtd 12792bc4c3e9SNicolas Vasilache /// %e = add %c, %d 12802bc4c3e9SNicolas Vasilache /// ``` 12812bc4c3e9SNicolas Vasilache /// `vector.matmul` later lowers to `llvm.matrix.multiply`. 12822bc4c3e9SNicolas Vasilache // 12832bc4c3e9SNicolas Vasilache /// This only kicks in when VectorTransformsOptions is set to `Matmul`. 
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
///
/// Scalable vectors are not supported.
///
/// Matches only unmasked contractions with iterator types (parallel,
/// parallel, reduction), int-or-float elements, and matching source/result
/// element types; returns the replacement value or failure otherwise.
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  // Only active when the user explicitly selected the Matmul strategy.
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type opResType = op.getType();
  VectorType vecType = dyn_cast<VectorType>(opResType);
  if (vecType && vecType.isScalable()) {
    // Note - this is sufficient to reject all cases with scalable vectors.
    return failure();
  }

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Result may be a scalar; compare against the LHS element type either way.
  Type dstElementType = vecType ? vecType.getElementType() : opResType;
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  // Flatten both operands to 1-D vectors, as vector.matmul requires.
  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  // Restore the 2-D (m, n) shape, using the accumulator's element type.
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  // Final accumulation: integer or floating-point add depending on elements.
  Value res =
      isa<IntegerType>(elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  return res;
}
} // namespace

/// Populates `patterns` with the contract-lowering patterns declared above,
/// optionally excluding OuterProductOpLowering.
void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options,
    PatternBenefit benefit, bool disableOuterProductLowering) {
  if (!disableOuterProductLowering)
    patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(
      options, patterns.getContext(), benefit);
}

/// Populates `patterns` with only the outer-product lowering pattern.
void mlir::vector::populateVectorOuterProductLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
}