//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to
// arm_neon.intr.smmla
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "lower-contract-to-arm-neon"

using namespace mlir;
using namespace mlir::arm_neon;

namespace {

/// Return the container type with the given element type, or the element type
/// itself if the container is not a shaped type.
static Type matchContainerType(Type element, Type container) {
  if (auto shapedTy = dyn_cast<ShapedType>(container)) {
    return shapedTy.clone(element);
  }
  return element;
}

/// Lowering from vector::ContractionOp to the Arm Neon smmla intrinsic. This
/// tiles any vector.contract into multiple smmla instructions with unrolling
/// as long as [2, 2, 8] divides its shape. It can also process vecmats with
/// dimM = 1 (either explicitly or inferred when the LHS has only dimK). If no
/// unrolling is necessary, a single smmla instruction is emitted.
class LowerContractionToSMMLAPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Infer tile sizes from the operands. For vecmat, the LHS may have only
    // one dimension. Note: the RHS is not transposed.
    mlir::VectorType lhsType = op.getLhsType();
    mlir::VectorType rhsType = op.getRhsType();
    // Avoid 0-D vectors and 1-D rhs:
    if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
      return failure();
    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
    auto dimN = rhsType.getDimSize(0);
    auto dimK = rhsType.getDimSize(1);
    bool isVecmat = dimM == 1;
    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
        rhsType.getDimSize(rhsType.getRank() - 1)) {
      return failure(); // dimK mismatch
    }
    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of the
    // input dimensions for tiling.
    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
      return failure();
    }

    // Check the iterator types of the contract. All iterators except the
    // innermost one must be parallel; the innermost must be a reduction.
    auto iteratorTypes = op.getIteratorTypesArray();
    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
                                        vector::IteratorType::reduction) {
      return failure();
    }
    if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
                     [](vector::IteratorType iteratorType) {
                       return iteratorType != vector::IteratorType::parallel;
                     })) {
      return failure();
    }

    // Check that the LHS and RHS of the contract are each produced by an
    // extsi op.
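    // The matched IR looks roughly as follows (illustrative MLIR only; the
    // %-names are invented and the shapes assume a single [2, 2, 8] tile):
    //
    //   %lhs = arith.extsi %a : vector<2x8xi8> to vector<2x8xi32>
    //   %rhs = arith.extsi %b : vector<2x8xi8> to vector<2x8xi32>
    //   %res = vector.contract ... %lhs, %rhs, %acc
    //            : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>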
    arith::ExtSIOp origLhsExtOp =
        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
    arith::ExtSIOp origRhsExtOp =
        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
    if (!origLhsExtOp || !origRhsExtOp) {
      return failure();
    }

    // Match any extension from iX (X <= 8) and rewrite it as an extension to
    // i8, which feeds the neon smmla instruction directly. The fold step
    // removes the cast when the input is already i8.
    Value extsiLhs;
    Value extsiRhs;
    if (auto lhsExtInType =
            dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
      if (lhsExtInType.getElementTypeBitWidth() <= 8) {
        Type targetLhsExtTy =
            matchContainerType(rewriter.getI8Type(), lhsExtInType);
        extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
                                                         origLhsExtOp.getIn());
      }
    }
    if (auto rhsExtInType =
            dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
      if (rhsExtInType.getElementTypeBitWidth() <= 8) {
        Type targetRhsExtTy =
            matchContainerType(rewriter.getI8Type(), rhsExtInType);
        extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
                                                         origRhsExtOp.getIn());
      }
    }

    if (!extsiLhs || !extsiRhs) {
      return failure();
    }

    // Initial accumulator for the final result. This is the un-tiled result
    // into which the tiled partial results are inserted.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));

    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
    SmallVector<int64_t> smmlaShape = {2, 8};
    SmallVector<int64_t> loopOrder = {0, 1};
    if (unrolledSize.size() == 3) {
      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
      loopOrder.push_back(2);
    }

    // Keep track of the previous accumulator when tiling over K.
    Value kAcc;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
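      // Tiles are visited with the K dimension innermost (loopOrder above
      // lists K last), so the partial products along K for a given (M, N)
      // tile accumulate in kAcc before being written back to the result.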
      // Helper to compute the new shape of each operand and extract the
      // slice.
      auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape =
            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the tiled lhs, rhs, and acc.
      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledAcc =
          extractOperand(op.getAcc(), accPermutationMap, accOffsets);

      auto inputElementType =
          cast<ShapedType>(tiledLhs.getType()).getElementType();
      auto accElementType =
          cast<ShapedType>(tiledAcc.getType()).getElementType();
      auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
      auto outputExpandedType = VectorType::get({2, 2}, accElementType);

      // With vecmat, the tiled LHS and ACC contain only one of the two rows
      // required along dimM. Expand their shapes to match the smmla op.
      if (isVecmat) {
        auto expandForSMMLA = [&](Value tiledOperand,
                                  VectorType expandedType) {
          auto emptyOperand = rewriter.create<arith::ConstantOp>(
              loc, expandedType, rewriter.getZeroAttr(expandedType));
          SmallVector<int64_t> offsets(
              cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
          SmallVector<int64_t> strides(
              cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
              loc, tiledOperand, emptyOperand, offsets, strides);
        };
        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
        tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
      }

      // Collapse the tiled operands to the 1-D vectors required by the smmla
      // intrinsic.
      auto collapsedInputType =
          VectorType::get(inputExpandedType.getNumElements(), inputElementType);
      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
      auto collapsedOutputType =
          VectorType::get(outputExpandedType.getNumElements(), accElementType);

      // For the first tile along K, the accumulator comes from the tiled ACC
      // operand; for later K tiles it is chained from the previous smmla.
      bool initialKAcc = offsets.back() == 0;
      Value collapsedRes;
      if (!initialKAcc) {
        collapsedRes = kAcc;
      } else {
        collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
            tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
      }
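
      // arm_neon.smmla implements the Armv8.6-A SMMLA instruction: a signed
      // 8-bit integer matrix multiply-accumulate of a 2x8 LHS with the
      // transpose of a 2x8 RHS into a 2x2 i32 accumulator, all three passed
      // as flattened 1-D vectors.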
      // Emit the smmla intrinsic op.
      kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
          collapsedRhs);

      // Reshape the output back to 2-D.
      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
          kAcc.getLoc(), tiledAcc.getType(), kAcc);

      // With vecmat, only one row of the tiled ACC can be inserted into the
      // final result.
      if (isVecmat) {
        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
      }

      // Insert the tiled result back into the non-tiled result of the
      // contract op.
      SmallVector<int64_t> strides(
          cast<ShapedType>(tiledRes.getType()).getRank(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tiledRes, result, accOffsets, strides);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
}
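
// A minimal sketch of how a pass might apply these patterns (the pass class
// here is hypothetical; applyPatternsAndFoldGreedily comes from the
// GreedyPatternRewriteDriver.h header included above):
//
//   void LowerToSMMLAPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }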