1cb6ff746SKojo Acquah //===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// 2cb6ff746SKojo Acquah // 3cb6ff746SKojo Acquah // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4cb6ff746SKojo Acquah // See https://llvm.org/LICENSE.txt for license information. 5cb6ff746SKojo Acquah // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6cb6ff746SKojo Acquah // 7cb6ff746SKojo Acquah //===----------------------------------------------------------------------===// 8cb6ff746SKojo Acquah // 9cb6ff746SKojo Acquah // This file implements lowering patterns from vector.contract to 10cb6ff746SKojo Acquah // arm_neon.intr.smmla 11cb6ff746SKojo Acquah // 12cb6ff746SKojo Acquah //===--- 13cb6ff746SKojo Acquah 14cb6ff746SKojo Acquah #include "mlir/Dialect/Arith/IR/Arith.h" 15cb6ff746SKojo Acquah #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" 16cb6ff746SKojo Acquah #include "mlir/Dialect/ArmNeon/Transforms.h" 17cb6ff746SKojo Acquah #include "mlir/Dialect/Func/IR/FuncOps.h" 18cb6ff746SKojo Acquah #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19fe84369cSKojo Acquah #include "mlir/Dialect/Utils/IndexingUtils.h" 20cb6ff746SKojo Acquah #include "mlir/Dialect/Vector/IR/VectorOps.h" 21fe84369cSKojo Acquah #include "mlir/IR/AffineMap.h" 22cb6ff746SKojo Acquah #include "mlir/IR/PatternMatch.h" 23cb6ff746SKojo Acquah #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24cb6ff746SKojo Acquah 25cb6ff746SKojo Acquah #define DEBUG_TYPE "lower-contract-to-arm-neon" 26cb6ff746SKojo Acquah 27cb6ff746SKojo Acquah using namespace mlir; 28cb6ff746SKojo Acquah using namespace mlir::arm_neon; 29cb6ff746SKojo Acquah 30cb6ff746SKojo Acquah namespace { 31cb6ff746SKojo Acquah 32cb6ff746SKojo Acquah /// Return the shaped type with new element type. 33cb6ff746SKojo Acquah static Type matchContainerType(Type element, Type container) { 34cb6ff746SKojo Acquah if (auto shapedTy = dyn_cast<ShapedType>(container)) { 35cb6ff746SKojo Acquah return shapedTy.clone(element); 36cb6ff746SKojo Acquah } 37cb6ff746SKojo Acquah return element; 38cb6ff746SKojo Acquah } 39cb6ff746SKojo Acquah 40fe84369cSKojo Acquah /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile 41fe84369cSKojo Acquah /// any vector.contract into multiple smmla instructions with unrolling so long 42c511c906SKojo Acquah /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM 43c511c906SKojo Acquah /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is 44c511c906SKojo Acquah /// necessary, a single smmla instruction is emitted. 45cb6ff746SKojo Acquah class LowerContractionToSMMLAPattern 46cb6ff746SKojo Acquah : public OpRewritePattern<vector::ContractionOp> { 47cb6ff746SKojo Acquah public: 48cb6ff746SKojo Acquah using OpRewritePattern::OpRewritePattern; 49cb6ff746SKojo Acquah LogicalResult matchAndRewrite(vector::ContractionOp op, 50cb6ff746SKojo Acquah PatternRewriter &rewriter) const override { 51cb6ff746SKojo Acquah Location loc = op.getLoc(); 52c511c906SKojo Acquah // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim. 53c511c906SKojo Acquah // Note: RHS is not transposed. 54cb6ff746SKojo Acquah mlir::VectorType lhsType = op.getLhsType(); 55cb6ff746SKojo Acquah mlir::VectorType rhsType = op.getRhsType(); 5604bf1a40SKojo Acquah // Avoid 0-D vectors and 1-D rhs: 5704bf1a40SKojo Acquah if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2) 5804bf1a40SKojo Acquah return failure(); 59c511c906SKojo Acquah auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0); 60cb6ff746SKojo Acquah auto dimN = rhsType.getDimSize(0); 61c511c906SKojo Acquah auto dimK = rhsType.getDimSize(1); 62c511c906SKojo Acquah bool isVecmat = dimM == 1 ? true : false; 63c511c906SKojo Acquah if (lhsType.getDimSize(lhsType.getRank() - 1) != 64c511c906SKojo Acquah rhsType.getDimSize(rhsType.getRank() - 1)) { 65c511c906SKojo Acquah return failure(); // dimK mismatch 66c511c906SKojo Acquah } 67fe84369cSKojo Acquah // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for 68fe84369cSKojo Acquah // tiling. 69c511c906SKojo Acquah if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) { 70c511c906SKojo Acquah return failure(); 71c511c906SKojo Acquah } 72c511c906SKojo Acquah 73c511c906SKojo Acquah // Check iterator types for contract. All iterators except inner-most 74c511c906SKojo Acquah // dimension must be parallel. 75c511c906SKojo Acquah auto iteratorTypes = op.getIteratorTypesArray(); 76c511c906SKojo Acquah if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] != 77c511c906SKojo Acquah vector::IteratorType::reduction) { 78c511c906SKojo Acquah return failure(); 79c511c906SKojo Acquah } 80c511c906SKojo Acquah if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1), 81c511c906SKojo Acquah [](vector::IteratorType iteratorType) { 82c511c906SKojo Acquah return iteratorType != vector::IteratorType::parallel; 83c511c906SKojo Acquah })) { 84cb6ff746SKojo Acquah return failure(); 85cb6ff746SKojo Acquah } 86cb6ff746SKojo Acquah 87cb6ff746SKojo Acquah // Check two extsi inputs Rhs Lhs for contract. 88cb6ff746SKojo Acquah arith::ExtSIOp origLhsExtOp = 89fe84369cSKojo Acquah dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp()); 90cb6ff746SKojo Acquah arith::ExtSIOp origRhsExtOp = 91fe84369cSKojo Acquah dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp()); 92cb6ff746SKojo Acquah if (!origLhsExtOp || !origRhsExtOp) { 93cb6ff746SKojo Acquah return failure(); 94cb6ff746SKojo Acquah } 95cb6ff746SKojo Acquah 96cb6ff746SKojo Acquah // Match any iX to i32 for X<8 then turn into an i8 output. Feed into 97cb6ff746SKojo Acquah // following neon instruction. Check inputs for extsi are <=i8 98cb6ff746SKojo Acquah Value extsiLhs; 99cb6ff746SKojo Acquah Value extsiRhs; 100cb6ff746SKojo Acquah if (auto lhsExtInType = 101a5757c5bSChristian Sigg dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) { 102cb6ff746SKojo Acquah if (lhsExtInType.getElementTypeBitWidth() <= 8) { 103cb6ff746SKojo Acquah Type targetLhsExtTy = 104cb6ff746SKojo Acquah matchContainerType(rewriter.getI8Type(), lhsExtInType); 105cb6ff746SKojo Acquah extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy, 106cb6ff746SKojo Acquah origLhsExtOp.getIn()); 107cb6ff746SKojo Acquah } 108cb6ff746SKojo Acquah } 109cb6ff746SKojo Acquah if (auto rhsExtInType = 110a5757c5bSChristian Sigg dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) { 111cb6ff746SKojo Acquah if (rhsExtInType.getElementTypeBitWidth() <= 8) { 112cb6ff746SKojo Acquah Type targetRhsExtTy = 113cb6ff746SKojo Acquah matchContainerType(rewriter.getI8Type(), rhsExtInType); 114cb6ff746SKojo Acquah extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy, 115cb6ff746SKojo Acquah origRhsExtOp.getIn()); 116cb6ff746SKojo Acquah } 117cb6ff746SKojo Acquah } 118cb6ff746SKojo Acquah 119cb6ff746SKojo Acquah if (!extsiLhs || !extsiRhs) { 120cb6ff746SKojo Acquah return failure(); 121cb6ff746SKojo Acquah } 122cb6ff746SKojo Acquah 123fe84369cSKojo Acquah // Initial accumulator for the final result. This is the un-tiled result if 124fe84369cSKojo Acquah // tiling is done. 125fe84369cSKojo Acquah Value result = rewriter.create<arith::ConstantOp>( 126fe84369cSKojo Acquah loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType())); 127cb6ff746SKojo Acquah 128fe84369cSKojo Acquah SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll(); 129*9cbc1f29SHan-Chung Wang SmallVector<int64_t> smmlaShape = {2, 8}; 130*9cbc1f29SHan-Chung Wang SmallVector<int64_t> loopOrder = {0, 1}; 131c511c906SKojo Acquah if (unrolledSize.size() == 3) { 132c511c906SKojo Acquah smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2); 133c511c906SKojo Acquah loopOrder.push_back(2); 134c511c906SKojo Acquah } 135b859a9fcSKojo Acquah 136b859a9fcSKojo Acquah // Keep track of the previous accumulator when tiling over K. 137b859a9fcSKojo Acquah Value kAcc; 138fe84369cSKojo Acquah for (SmallVector<int64_t> offsets : 139fe84369cSKojo Acquah StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) { 140fe84369cSKojo Acquah // Helper to compute the new shape of each operand and extract the slice. 141fe84369cSKojo Acquah auto extractOperand = [&](Value operand, AffineMap permutationMap, 142fe84369cSKojo Acquah ArrayRef<int64_t> operandOffsets) { 143fe84369cSKojo Acquah SmallVector<int64_t> operandShape = 144fe84369cSKojo Acquah applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape)); 145fe84369cSKojo Acquah SmallVector<int64_t> operandStrides(operandOffsets.size(), 1); 146fe84369cSKojo Acquah return rewriter.createOrFold<vector::ExtractStridedSliceOp>( 147fe84369cSKojo Acquah loc, operand, operandOffsets, operandShape, operandStrides); 148fe84369cSKojo Acquah }; 149fe84369cSKojo Acquah 150fe84369cSKojo Acquah // Extract tiled lhs, rhs, and acc 151fe84369cSKojo Acquah AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0]; 152fe84369cSKojo Acquah SmallVector<int64_t> lhsOffsets = 153fe84369cSKojo Acquah applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets)); 154fe84369cSKojo Acquah Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets); 155fe84369cSKojo Acquah AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1]; 156fe84369cSKojo Acquah SmallVector<int64_t> rhsOffsets = 157fe84369cSKojo Acquah applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets)); 158fe84369cSKojo Acquah Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets); 159fe84369cSKojo Acquah AffineMap accPermutationMap = op.getIndexingMapsArray()[2]; 160fe84369cSKojo Acquah SmallVector<int64_t> accOffsets = 161fe84369cSKojo Acquah applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets)); 162fe84369cSKojo Acquah Value tiledAcc = 163fe84369cSKojo Acquah extractOperand(op.getAcc(), accPermutationMap, accOffsets); 164fe84369cSKojo Acquah 165c511c906SKojo Acquah auto inputElementType = 166a5757c5bSChristian Sigg cast<ShapedType>(tiledLhs.getType()).getElementType(); 167c511c906SKojo Acquah auto accElementType = 168a5757c5bSChristian Sigg cast<ShapedType>(tiledAcc.getType()).getElementType(); 169c511c906SKojo Acquah auto inputExpandedType = VectorType::get({2, 8}, inputElementType); 170c511c906SKojo Acquah auto outputExpandedType = VectorType::get({2, 2}, accElementType); 171c511c906SKojo Acquah 172c511c906SKojo Acquah // With vecmat, tiled LHS and ACC will contain only one of 2 necessary 173c511c906SKojo Acquah // rows along dimM. Expand their shapes to match the smmla op. 174c511c906SKojo Acquah if (isVecmat) { 175c511c906SKojo Acquah auto expandForSMMLA = [&](Value tiledOperand, 176c511c906SKojo Acquah VectorType expandedTypeType) { 177c511c906SKojo Acquah auto emptyOperand = rewriter.create<arith::ConstantOp>( 178c511c906SKojo Acquah loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType)); 179c511c906SKojo Acquah SmallVector<int64_t> offsets( 180a5757c5bSChristian Sigg cast<ShapedType>(emptyOperand.getType()).getRank(), 0); 181c511c906SKojo Acquah SmallVector<int64_t> strides( 182a5757c5bSChristian Sigg cast<ShapedType>(tiledOperand.getType()).getRank(), 1); 183c511c906SKojo Acquah return rewriter.createOrFold<vector::InsertStridedSliceOp>( 184c511c906SKojo Acquah loc, tiledOperand, emptyOperand, offsets, strides); 185c511c906SKojo Acquah }; 186c511c906SKojo Acquah tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType); 187c511c906SKojo Acquah tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType); 188c511c906SKojo Acquah } 189c511c906SKojo Acquah 190fe84369cSKojo Acquah // Collapse tiled operands to 1D vectors required by smmla intrinsic 191c511c906SKojo Acquah auto collapsedInputType = 192c511c906SKojo Acquah VectorType::get(inputExpandedType.getNumElements(), inputElementType); 193fe84369cSKojo Acquah auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>( 194fe84369cSKojo Acquah tiledLhs.getLoc(), collapsedInputType, tiledLhs); 195fe84369cSKojo Acquah auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>( 196fe84369cSKojo Acquah tiledRhs.getLoc(), collapsedInputType, tiledRhs); 197c511c906SKojo Acquah auto collapsedOutputType = 198c511c906SKojo Acquah VectorType::get(outputExpandedType.getNumElements(), accElementType); 199b859a9fcSKojo Acquah 200b859a9fcSKojo Acquah bool initialKAcc = offsets.back() == 0; 201b859a9fcSKojo Acquah Value collapsedRes; 202b859a9fcSKojo Acquah if (!initialKAcc) { 203b859a9fcSKojo Acquah collapsedRes = kAcc; 204b859a9fcSKojo Acquah } else { 205b859a9fcSKojo Acquah collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>( 206fe84369cSKojo Acquah tiledAcc.getLoc(), collapsedOutputType, tiledAcc); 207b859a9fcSKojo Acquah } 208fe84369cSKojo Acquah 209fe84369cSKojo Acquah // Insert contract op 210b859a9fcSKojo Acquah kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>( 211cb6ff746SKojo Acquah op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs, 212cb6ff746SKojo Acquah collapsedRhs); 213cb6ff746SKojo Acquah 214cb6ff746SKojo Acquah // Reshape output back to 2D 215fe84369cSKojo Acquah Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>( 216b859a9fcSKojo Acquah kAcc.getLoc(), tiledAcc.getType(), kAcc); 217fe84369cSKojo Acquah 218b859a9fcSKojo Acquah // With vecmat, only one row of tiled ACC can be inserted into file result 219c511c906SKojo Acquah if (isVecmat) { 220c511c906SKojo Acquah tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0); 221c511c906SKojo Acquah } 222c511c906SKojo Acquah 223fe84369cSKojo Acquah // Insert the tiled result back into the non tiled result of the 224fe84369cSKojo Acquah // contract op. 225fe84369cSKojo Acquah SmallVector<int64_t> strides( 226a5757c5bSChristian Sigg cast<ShapedType>(tiledRes.getType()).getRank(), 1); 227fe84369cSKojo Acquah result = rewriter.createOrFold<vector::InsertStridedSliceOp>( 228fe84369cSKojo Acquah loc, tiledRes, result, accOffsets, strides); 229fe84369cSKojo Acquah } 230fe84369cSKojo Acquah 231fe84369cSKojo Acquah rewriter.replaceOp(op, result); 232cb6ff746SKojo Acquah return success(); 233cb6ff746SKojo Acquah } 234cb6ff746SKojo Acquah }; 235cb6ff746SKojo Acquah 236cb6ff746SKojo Acquah } // namespace 237cb6ff746SKojo Acquah 238cb6ff746SKojo Acquah void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns( 239cb6ff746SKojo Acquah RewritePatternSet &patterns) { 240cb6ff746SKojo Acquah MLIRContext *context = patterns.getContext(); 241cb6ff746SKojo Acquah patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1); 242cb6ff746SKojo Acquah } 243