xref: /llvm-project/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (revision 9cbc1f29cabc01c02a523c11d098c00650f6955c)
1 //===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering patterns from vector.contract to
10 // arm_neon.intr.smmla
11 //
//===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
16 #include "mlir/Dialect/ArmNeon/Transforms.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/Utils/IndexingUtils.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 #define DEBUG_TYPE "lower-contract-to-arm-neon"
26 
27 using namespace mlir;
28 using namespace mlir::arm_neon;
29 
30 namespace {
31 
32 /// Return the shaped type with new element type.
33 static Type matchContainerType(Type element, Type container) {
34   if (auto shapedTy = dyn_cast<ShapedType>(container)) {
35     return shapedTy.clone(element);
36   }
37   return element;
38 }
39 
40 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
41 /// any vector.contract into multiple smmla instructions with unrolling so long
42 /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
43 /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
44 /// necessary, a single smmla instruction is emitted.
45 class LowerContractionToSMMLAPattern
46     : public OpRewritePattern<vector::ContractionOp> {
47 public:
48   using OpRewritePattern::OpRewritePattern;
49   LogicalResult matchAndRewrite(vector::ContractionOp op,
50                                 PatternRewriter &rewriter) const override {
51     Location loc = op.getLoc();
52     // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
53     // Note: RHS is not transposed.
54     mlir::VectorType lhsType = op.getLhsType();
55     mlir::VectorType rhsType = op.getRhsType();
56     // Avoid 0-D vectors and 1-D rhs:
57     if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
58       return failure();
59     auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
60     auto dimN = rhsType.getDimSize(0);
61     auto dimK = rhsType.getDimSize(1);
62     bool isVecmat = dimM == 1 ? true : false;
63     if (lhsType.getDimSize(lhsType.getRank() - 1) !=
64         rhsType.getDimSize(rhsType.getRank() - 1)) {
65       return failure(); // dimK mismatch
66     }
67     // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
68     // tiling.
69     if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
70       return failure();
71     }
72 
73     // Check iterator types for contract. All iterators except inner-most
74     // dimension must be parallel.
75     auto iteratorTypes = op.getIteratorTypesArray();
76     if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
77                                         vector::IteratorType::reduction) {
78       return failure();
79     }
80     if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
81                      [](vector::IteratorType iteratorType) {
82                        return iteratorType != vector::IteratorType::parallel;
83                      })) {
84       return failure();
85     }
86 
87     // Check two extsi inputs Rhs Lhs for contract.
88     arith::ExtSIOp origLhsExtOp =
89         dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
90     arith::ExtSIOp origRhsExtOp =
91         dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
92     if (!origLhsExtOp || !origRhsExtOp) {
93       return failure();
94     }
95 
96     // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
97     // following neon instruction. Check inputs for extsi are <=i8
98     Value extsiLhs;
99     Value extsiRhs;
100     if (auto lhsExtInType =
101             dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
102       if (lhsExtInType.getElementTypeBitWidth() <= 8) {
103         Type targetLhsExtTy =
104             matchContainerType(rewriter.getI8Type(), lhsExtInType);
105         extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
106                                                          origLhsExtOp.getIn());
107       }
108     }
109     if (auto rhsExtInType =
110             dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
111       if (rhsExtInType.getElementTypeBitWidth() <= 8) {
112         Type targetRhsExtTy =
113             matchContainerType(rewriter.getI8Type(), rhsExtInType);
114         extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
115                                                          origRhsExtOp.getIn());
116       }
117     }
118 
119     if (!extsiLhs || !extsiRhs) {
120       return failure();
121     }
122 
123     // Initial accumulator for the final result. This is the un-tiled result if
124     // tiling is done.
125     Value result = rewriter.create<arith::ConstantOp>(
126         loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
127 
128     SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
129     SmallVector<int64_t> smmlaShape = {2, 8};
130     SmallVector<int64_t> loopOrder = {0, 1};
131     if (unrolledSize.size() == 3) {
132       smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
133       loopOrder.push_back(2);
134     }
135 
136     // Keep track of the previous accumulator when tiling over K.
137     Value kAcc;
138     for (SmallVector<int64_t> offsets :
139          StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
140       // Helper to compute the new shape of each operand and extract the slice.
141       auto extractOperand = [&](Value operand, AffineMap permutationMap,
142                                 ArrayRef<int64_t> operandOffsets) {
143         SmallVector<int64_t> operandShape =
144             applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
145         SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
146         return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
147             loc, operand, operandOffsets, operandShape, operandStrides);
148       };
149 
150       // Extract tiled lhs, rhs, and acc
151       AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
152       SmallVector<int64_t> lhsOffsets =
153           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
154       Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
155       AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
156       SmallVector<int64_t> rhsOffsets =
157           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
158       Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
159       AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
160       SmallVector<int64_t> accOffsets =
161           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
162       Value tiledAcc =
163           extractOperand(op.getAcc(), accPermutationMap, accOffsets);
164 
165       auto inputElementType =
166           cast<ShapedType>(tiledLhs.getType()).getElementType();
167       auto accElementType =
168           cast<ShapedType>(tiledAcc.getType()).getElementType();
169       auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
170       auto outputExpandedType = VectorType::get({2, 2}, accElementType);
171 
172       // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
173       // rows along dimM. Expand their shapes to match the smmla op.
174       if (isVecmat) {
175         auto expandForSMMLA = [&](Value tiledOperand,
176                                   VectorType expandedTypeType) {
177           auto emptyOperand = rewriter.create<arith::ConstantOp>(
178               loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
179           SmallVector<int64_t> offsets(
180               cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
181           SmallVector<int64_t> strides(
182               cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
183           return rewriter.createOrFold<vector::InsertStridedSliceOp>(
184               loc, tiledOperand, emptyOperand, offsets, strides);
185         };
186         tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
187         tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
188       }
189 
190       // Collapse tiled operands to 1D vectors required by smmla intrinsic
191       auto collapsedInputType =
192           VectorType::get(inputExpandedType.getNumElements(), inputElementType);
193       auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
194           tiledLhs.getLoc(), collapsedInputType, tiledLhs);
195       auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
196           tiledRhs.getLoc(), collapsedInputType, tiledRhs);
197       auto collapsedOutputType =
198           VectorType::get(outputExpandedType.getNumElements(), accElementType);
199 
200       bool initialKAcc = offsets.back() == 0;
201       Value collapsedRes;
202       if (!initialKAcc) {
203         collapsedRes = kAcc;
204       } else {
205         collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
206             tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
207       }
208 
209       // Insert contract op
210       kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
211           op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
212           collapsedRhs);
213 
214       // Reshape output back to 2D
215       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
216           kAcc.getLoc(), tiledAcc.getType(), kAcc);
217 
218       // With vecmat, only one row of tiled ACC can be inserted into file result
219       if (isVecmat) {
220         tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
221       }
222 
223       // Insert the tiled result back into the non tiled result of the
224       // contract op.
225       SmallVector<int64_t> strides(
226           cast<ShapedType>(tiledRes.getType()).getRank(), 1);
227       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
228           loc, tiledRes, result, accOffsets, strides);
229     }
230 
231     rewriter.replaceOp(op, result);
232     return success();
233   }
234 };
235 
236 } // namespace
237 
238 void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
239     RewritePatternSet &patterns) {
240   MLIRContext *context = patterns.getContext();
241   patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
242 }
243