xref: /llvm-project/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (revision 9cbc1f29cabc01c02a523c11d098c00650f6955c)
1cb6ff746SKojo Acquah //===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
2cb6ff746SKojo Acquah //
3cb6ff746SKojo Acquah // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4cb6ff746SKojo Acquah // See https://llvm.org/LICENSE.txt for license information.
5cb6ff746SKojo Acquah // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6cb6ff746SKojo Acquah //
7cb6ff746SKojo Acquah //===----------------------------------------------------------------------===//
8cb6ff746SKojo Acquah //
9cb6ff746SKojo Acquah // This file implements lowering patterns from vector.contract to
10cb6ff746SKojo Acquah // arm_neon.intr.smmla
11cb6ff746SKojo Acquah //
12cb6ff746SKojo Acquah //===---
13cb6ff746SKojo Acquah 
14cb6ff746SKojo Acquah #include "mlir/Dialect/Arith/IR/Arith.h"
15cb6ff746SKojo Acquah #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
16cb6ff746SKojo Acquah #include "mlir/Dialect/ArmNeon/Transforms.h"
17cb6ff746SKojo Acquah #include "mlir/Dialect/Func/IR/FuncOps.h"
18cb6ff746SKojo Acquah #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19fe84369cSKojo Acquah #include "mlir/Dialect/Utils/IndexingUtils.h"
20cb6ff746SKojo Acquah #include "mlir/Dialect/Vector/IR/VectorOps.h"
21fe84369cSKojo Acquah #include "mlir/IR/AffineMap.h"
22cb6ff746SKojo Acquah #include "mlir/IR/PatternMatch.h"
23cb6ff746SKojo Acquah #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24cb6ff746SKojo Acquah 
25cb6ff746SKojo Acquah #define DEBUG_TYPE "lower-contract-to-arm-neon"
26cb6ff746SKojo Acquah 
27cb6ff746SKojo Acquah using namespace mlir;
28cb6ff746SKojo Acquah using namespace mlir::arm_neon;
29cb6ff746SKojo Acquah 
30cb6ff746SKojo Acquah namespace {
31cb6ff746SKojo Acquah 
32cb6ff746SKojo Acquah /// Return the shaped type with new element type.
33cb6ff746SKojo Acquah static Type matchContainerType(Type element, Type container) {
34cb6ff746SKojo Acquah   if (auto shapedTy = dyn_cast<ShapedType>(container)) {
35cb6ff746SKojo Acquah     return shapedTy.clone(element);
36cb6ff746SKojo Acquah   }
37cb6ff746SKojo Acquah   return element;
38cb6ff746SKojo Acquah }
39cb6ff746SKojo Acquah 
40fe84369cSKojo Acquah /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
41fe84369cSKojo Acquah /// any vector.contract into multiple smmla instructions with unrolling so long
42c511c906SKojo Acquah /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
43c511c906SKojo Acquah /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
44c511c906SKojo Acquah /// necessary, a single smmla instruction is emitted.
45cb6ff746SKojo Acquah class LowerContractionToSMMLAPattern
46cb6ff746SKojo Acquah     : public OpRewritePattern<vector::ContractionOp> {
47cb6ff746SKojo Acquah public:
48cb6ff746SKojo Acquah   using OpRewritePattern::OpRewritePattern;
49cb6ff746SKojo Acquah   LogicalResult matchAndRewrite(vector::ContractionOp op,
50cb6ff746SKojo Acquah                                 PatternRewriter &rewriter) const override {
51cb6ff746SKojo Acquah     Location loc = op.getLoc();
52c511c906SKojo Acquah     // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
53c511c906SKojo Acquah     // Note: RHS is not transposed.
54cb6ff746SKojo Acquah     mlir::VectorType lhsType = op.getLhsType();
55cb6ff746SKojo Acquah     mlir::VectorType rhsType = op.getRhsType();
5604bf1a40SKojo Acquah     // Avoid 0-D vectors and 1-D rhs:
5704bf1a40SKojo Acquah     if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
5804bf1a40SKojo Acquah       return failure();
59c511c906SKojo Acquah     auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
60cb6ff746SKojo Acquah     auto dimN = rhsType.getDimSize(0);
61c511c906SKojo Acquah     auto dimK = rhsType.getDimSize(1);
62c511c906SKojo Acquah     bool isVecmat = dimM == 1 ? true : false;
63c511c906SKojo Acquah     if (lhsType.getDimSize(lhsType.getRank() - 1) !=
64c511c906SKojo Acquah         rhsType.getDimSize(rhsType.getRank() - 1)) {
65c511c906SKojo Acquah       return failure(); // dimK mismatch
66c511c906SKojo Acquah     }
67fe84369cSKojo Acquah     // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
68fe84369cSKojo Acquah     // tiling.
69c511c906SKojo Acquah     if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
70c511c906SKojo Acquah       return failure();
71c511c906SKojo Acquah     }
72c511c906SKojo Acquah 
73c511c906SKojo Acquah     // Check iterator types for contract. All iterators except inner-most
74c511c906SKojo Acquah     // dimension must be parallel.
75c511c906SKojo Acquah     auto iteratorTypes = op.getIteratorTypesArray();
76c511c906SKojo Acquah     if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
77c511c906SKojo Acquah                                         vector::IteratorType::reduction) {
78c511c906SKojo Acquah       return failure();
79c511c906SKojo Acquah     }
80c511c906SKojo Acquah     if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
81c511c906SKojo Acquah                      [](vector::IteratorType iteratorType) {
82c511c906SKojo Acquah                        return iteratorType != vector::IteratorType::parallel;
83c511c906SKojo Acquah                      })) {
84cb6ff746SKojo Acquah       return failure();
85cb6ff746SKojo Acquah     }
86cb6ff746SKojo Acquah 
87cb6ff746SKojo Acquah     // Check two extsi inputs Rhs Lhs for contract.
88cb6ff746SKojo Acquah     arith::ExtSIOp origLhsExtOp =
89fe84369cSKojo Acquah         dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
90cb6ff746SKojo Acquah     arith::ExtSIOp origRhsExtOp =
91fe84369cSKojo Acquah         dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
92cb6ff746SKojo Acquah     if (!origLhsExtOp || !origRhsExtOp) {
93cb6ff746SKojo Acquah       return failure();
94cb6ff746SKojo Acquah     }
95cb6ff746SKojo Acquah 
96cb6ff746SKojo Acquah     // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
97cb6ff746SKojo Acquah     // following neon instruction. Check inputs for extsi are <=i8
98cb6ff746SKojo Acquah     Value extsiLhs;
99cb6ff746SKojo Acquah     Value extsiRhs;
100cb6ff746SKojo Acquah     if (auto lhsExtInType =
101a5757c5bSChristian Sigg             dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
102cb6ff746SKojo Acquah       if (lhsExtInType.getElementTypeBitWidth() <= 8) {
103cb6ff746SKojo Acquah         Type targetLhsExtTy =
104cb6ff746SKojo Acquah             matchContainerType(rewriter.getI8Type(), lhsExtInType);
105cb6ff746SKojo Acquah         extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
106cb6ff746SKojo Acquah                                                          origLhsExtOp.getIn());
107cb6ff746SKojo Acquah       }
108cb6ff746SKojo Acquah     }
109cb6ff746SKojo Acquah     if (auto rhsExtInType =
110a5757c5bSChristian Sigg             dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
111cb6ff746SKojo Acquah       if (rhsExtInType.getElementTypeBitWidth() <= 8) {
112cb6ff746SKojo Acquah         Type targetRhsExtTy =
113cb6ff746SKojo Acquah             matchContainerType(rewriter.getI8Type(), rhsExtInType);
114cb6ff746SKojo Acquah         extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
115cb6ff746SKojo Acquah                                                          origRhsExtOp.getIn());
116cb6ff746SKojo Acquah       }
117cb6ff746SKojo Acquah     }
118cb6ff746SKojo Acquah 
119cb6ff746SKojo Acquah     if (!extsiLhs || !extsiRhs) {
120cb6ff746SKojo Acquah       return failure();
121cb6ff746SKojo Acquah     }
122cb6ff746SKojo Acquah 
123fe84369cSKojo Acquah     // Initial accumulator for the final result. This is the un-tiled result if
124fe84369cSKojo Acquah     // tiling is done.
125fe84369cSKojo Acquah     Value result = rewriter.create<arith::ConstantOp>(
126fe84369cSKojo Acquah         loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
127cb6ff746SKojo Acquah 
128fe84369cSKojo Acquah     SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
129*9cbc1f29SHan-Chung Wang     SmallVector<int64_t> smmlaShape = {2, 8};
130*9cbc1f29SHan-Chung Wang     SmallVector<int64_t> loopOrder = {0, 1};
131c511c906SKojo Acquah     if (unrolledSize.size() == 3) {
132c511c906SKojo Acquah       smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
133c511c906SKojo Acquah       loopOrder.push_back(2);
134c511c906SKojo Acquah     }
135b859a9fcSKojo Acquah 
136b859a9fcSKojo Acquah     // Keep track of the previous accumulator when tiling over K.
137b859a9fcSKojo Acquah     Value kAcc;
138fe84369cSKojo Acquah     for (SmallVector<int64_t> offsets :
139fe84369cSKojo Acquah          StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
140fe84369cSKojo Acquah       // Helper to compute the new shape of each operand and extract the slice.
141fe84369cSKojo Acquah       auto extractOperand = [&](Value operand, AffineMap permutationMap,
142fe84369cSKojo Acquah                                 ArrayRef<int64_t> operandOffsets) {
143fe84369cSKojo Acquah         SmallVector<int64_t> operandShape =
144fe84369cSKojo Acquah             applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
145fe84369cSKojo Acquah         SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
146fe84369cSKojo Acquah         return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
147fe84369cSKojo Acquah             loc, operand, operandOffsets, operandShape, operandStrides);
148fe84369cSKojo Acquah       };
149fe84369cSKojo Acquah 
150fe84369cSKojo Acquah       // Extract tiled lhs, rhs, and acc
151fe84369cSKojo Acquah       AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
152fe84369cSKojo Acquah       SmallVector<int64_t> lhsOffsets =
153fe84369cSKojo Acquah           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
154fe84369cSKojo Acquah       Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
155fe84369cSKojo Acquah       AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
156fe84369cSKojo Acquah       SmallVector<int64_t> rhsOffsets =
157fe84369cSKojo Acquah           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
158fe84369cSKojo Acquah       Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
159fe84369cSKojo Acquah       AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
160fe84369cSKojo Acquah       SmallVector<int64_t> accOffsets =
161fe84369cSKojo Acquah           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
162fe84369cSKojo Acquah       Value tiledAcc =
163fe84369cSKojo Acquah           extractOperand(op.getAcc(), accPermutationMap, accOffsets);
164fe84369cSKojo Acquah 
165c511c906SKojo Acquah       auto inputElementType =
166a5757c5bSChristian Sigg           cast<ShapedType>(tiledLhs.getType()).getElementType();
167c511c906SKojo Acquah       auto accElementType =
168a5757c5bSChristian Sigg           cast<ShapedType>(tiledAcc.getType()).getElementType();
169c511c906SKojo Acquah       auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
170c511c906SKojo Acquah       auto outputExpandedType = VectorType::get({2, 2}, accElementType);
171c511c906SKojo Acquah 
172c511c906SKojo Acquah       // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
173c511c906SKojo Acquah       // rows along dimM. Expand their shapes to match the smmla op.
174c511c906SKojo Acquah       if (isVecmat) {
175c511c906SKojo Acquah         auto expandForSMMLA = [&](Value tiledOperand,
176c511c906SKojo Acquah                                   VectorType expandedTypeType) {
177c511c906SKojo Acquah           auto emptyOperand = rewriter.create<arith::ConstantOp>(
178c511c906SKojo Acquah               loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
179c511c906SKojo Acquah           SmallVector<int64_t> offsets(
180a5757c5bSChristian Sigg               cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
181c511c906SKojo Acquah           SmallVector<int64_t> strides(
182a5757c5bSChristian Sigg               cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
183c511c906SKojo Acquah           return rewriter.createOrFold<vector::InsertStridedSliceOp>(
184c511c906SKojo Acquah               loc, tiledOperand, emptyOperand, offsets, strides);
185c511c906SKojo Acquah         };
186c511c906SKojo Acquah         tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
187c511c906SKojo Acquah         tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
188c511c906SKojo Acquah       }
189c511c906SKojo Acquah 
190fe84369cSKojo Acquah       // Collapse tiled operands to 1D vectors required by smmla intrinsic
191c511c906SKojo Acquah       auto collapsedInputType =
192c511c906SKojo Acquah           VectorType::get(inputExpandedType.getNumElements(), inputElementType);
193fe84369cSKojo Acquah       auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
194fe84369cSKojo Acquah           tiledLhs.getLoc(), collapsedInputType, tiledLhs);
195fe84369cSKojo Acquah       auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
196fe84369cSKojo Acquah           tiledRhs.getLoc(), collapsedInputType, tiledRhs);
197c511c906SKojo Acquah       auto collapsedOutputType =
198c511c906SKojo Acquah           VectorType::get(outputExpandedType.getNumElements(), accElementType);
199b859a9fcSKojo Acquah 
200b859a9fcSKojo Acquah       bool initialKAcc = offsets.back() == 0;
201b859a9fcSKojo Acquah       Value collapsedRes;
202b859a9fcSKojo Acquah       if (!initialKAcc) {
203b859a9fcSKojo Acquah         collapsedRes = kAcc;
204b859a9fcSKojo Acquah       } else {
205b859a9fcSKojo Acquah         collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
206fe84369cSKojo Acquah             tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
207b859a9fcSKojo Acquah       }
208fe84369cSKojo Acquah 
209fe84369cSKojo Acquah       // Insert contract op
210b859a9fcSKojo Acquah       kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
211cb6ff746SKojo Acquah           op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
212cb6ff746SKojo Acquah           collapsedRhs);
213cb6ff746SKojo Acquah 
214cb6ff746SKojo Acquah       // Reshape output back to 2D
215fe84369cSKojo Acquah       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
216b859a9fcSKojo Acquah           kAcc.getLoc(), tiledAcc.getType(), kAcc);
217fe84369cSKojo Acquah 
218b859a9fcSKojo Acquah       // With vecmat, only one row of tiled ACC can be inserted into file result
219c511c906SKojo Acquah       if (isVecmat) {
220c511c906SKojo Acquah         tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
221c511c906SKojo Acquah       }
222c511c906SKojo Acquah 
223fe84369cSKojo Acquah       // Insert the tiled result back into the non tiled result of the
224fe84369cSKojo Acquah       // contract op.
225fe84369cSKojo Acquah       SmallVector<int64_t> strides(
226a5757c5bSChristian Sigg           cast<ShapedType>(tiledRes.getType()).getRank(), 1);
227fe84369cSKojo Acquah       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
228fe84369cSKojo Acquah           loc, tiledRes, result, accOffsets, strides);
229fe84369cSKojo Acquah     }
230fe84369cSKojo Acquah 
231fe84369cSKojo Acquah     rewriter.replaceOp(op, result);
232cb6ff746SKojo Acquah     return success();
233cb6ff746SKojo Acquah   }
234cb6ff746SKojo Acquah };
235cb6ff746SKojo Acquah 
236cb6ff746SKojo Acquah } // namespace
237cb6ff746SKojo Acquah 
238cb6ff746SKojo Acquah void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
239cb6ff746SKojo Acquah     RewritePatternSet &patterns) {
240cb6ff746SKojo Acquah   MLIRContext *context = patterns.getContext();
241cb6ff746SKojo Acquah   patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
242cb6ff746SKojo Acquah }
243