xref: /llvm-project/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td (revision b214ca82daeece1568268ebc0fbcc2eaa649425b)
//===-- ArmNeon.td - ArmNeon dialect op definitions --------*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the basic operations for the ArmNeon dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef ARMNEON_OPS
14#define ARMNEON_OPS
15
16include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
17include "mlir/Interfaces/SideEffectInterfaces.td"
18include "mlir/IR/OpBase.td"
19
20//===----------------------------------------------------------------------===//
21// ArmNeon dialect definition
22//===----------------------------------------------------------------------===//
23
// The `arm_neon` dialect. Its C++ namespace is `::mlir::arm_neon`.
def ArmNeon_Dialect : Dialect {
  let cppNamespace = "::mlir::arm_neon";
  let name = "arm_neon";

  // No dependency on LLVMDialect is declared here. That remains valid only
  // while nothing in this dialect (canonicalization patterns, for instance)
  // creates ops or types that belong to the LLVMDialect.
}
32
33//===----------------------------------------------------------------------===//
34// ArmNeon type definition
35//===----------------------------------------------------------------------===//
36
// Type constraint for a `::mlir::VectorType` holding `elementType` values whose
// shape satisfies `IsVectorOfShape<[length]>` and which also satisfies
// `IsFixedVectorOfAnyRankTypePred`. The diagnostic description reads
// "a vector with length <length>".
class NeonVectorOfLength<int length, Type elementType>
    : ShapedContainerType<
          [elementType],
          And<[IsVectorOfShape<[length]>, IsFixedVectorOfAnyRankTypePred]>,
          "a vector with length " # length,
          "::mlir::VectorType">;
41
42//===----------------------------------------------------------------------===//
43// ArmNeon op definitions
44//===----------------------------------------------------------------------===//
45
// Base class for ArmNeon ops that correspond to (and are convertible to) an
// LLVM IR intrinsic. The op name is `intr.<mnemonic>` and the intrinsic enum
// name is `aarch64_neon_<mnemonic>` with every `.` in the mnemonic replaced by
// `_`. All remaining parameters are forwarded unchanged to `LLVM_IntrOpBase`.
class ArmNeon_IntrOp<string mnemonic,
                     list<int> overloadedResults,
                     list<int> overloadedOperands,
                     int numResults,
                     list<Trait> traits = [],
                     bit requiresAccessGroup = 0,
                     bit requiresAliasAnalysis = 0>
    : LLVM_IntrOpBase<
          /*dialect=*/ArmNeon_Dialect,
          /*opName=*/"intr." # mnemonic,
          /*enumName=*/"aarch64_neon_" # !subst(".", "_", mnemonic),
          /*overloadedResults=*/overloadedResults,
          /*overloadedOperands=*/overloadedOperands,
          /*traits=*/traits,
          /*numResults=*/numResults,
          /*requiresAccessGroup=*/requiresAccessGroup,
          /*requiresAliasAnalysis=*/requiresAliasAnalysis>;
61
// Intrinsic-backed ArmNeon op with exactly one result. Result #0 is the only
// overloaded type; no operand is overloaded.
class ArmNeon_OverloadedOneResultIntrOp<string mnemonic,
                                        list<Trait> traits = []>
    : ArmNeon_IntrOp<mnemonic,
                     /*overloadedResults=*/[0],
                     /*overloadedOperands=*/[],
                     /*numResults=*/1,
                     traits>;
67
// Intrinsic-backed ArmNeon op with exactly one result. Result #0 is overloaded,
// and so are the operands whose indices are listed in `overloadedOperands`.
class ArmNeon_OverloadedOperandsWithOneResultIntrOp<
    string mnemonic, list<int> overloadedOperands, list<Trait> traits = []>
    : ArmNeon_IntrOp<mnemonic,
                     /*overloadedResults=*/[0],
                     /*overloadedOperands=*/overloadedOperands,
                     /*numResults=*/1,
                     traits>;
74
// Widening signed multiply (`intr.smull`). Both operands share one type, and
// the result type is derived from `a` by doubling the element bitwidth.
def SMullOp
    : ArmNeon_OverloadedOneResultIntrOp<"smull", [
        Pure,
        AllTypesMatch<["a", "b"]>,
        TypesMatchWith<
          "res has same vector shape and element bitwidth scaled by 2 as a",
          "a", "res",
          "::llvm::cast<VectorType>($_self).scaleElementBitwidth(2)">
      ]> {
  let summary = "smull roundscale op";
  let description = [{
    Signed Multiply Long (vector). This instruction multiplies corresponding
    signed integer values in the lower or upper half of the vectors of the two
    source SIMD&FP registers, places the results in a vector, and writes the
    vector to the destination SIMD&FP register.

    Source:
    https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics
  }];

  // Intended signatures:
  //   (vector<8xi8>,  vector<8xi8>)  -> vector<8xi16>
  //   (vector<4xi16>, vector<4xi16>) -> vector<4xi32>
  //   (vector<2xi32>, vector<2xi32>) -> vector<2xi64>
  // NOTE(review): lengths and element types are listed independently in the
  // constraints below, so other length/element pairings may also be accepted
  // by them - confirm against VectorOfLengthAndType.
  let arguments = (ins
    VectorOfLengthAndType<[8, 4, 2], [I8, I16, I32]>:$a,
    VectorOfLengthAndType<[8, 4, 2], [I8, I16, I32]>:$b
  );
  let results = (outs
    VectorOfLengthAndType<[8, 4, 2], [I16, I32, I64]>:$res
  );
  let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `to` type($res)";
}
103
// Signed dot-product-accumulate (`intr.sdot`). Operand #1 (`b`) is the
// overloaded operand. `b` and `c` share one type, as do `a` and `res`.
def SdotOp
    : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot", [1], [
        Pure,
        AllTypesMatch<["b", "c"]>,
        AllTypesMatch<["a", "res"]>,
        // The type derived from `b` is a 1-D vector of i32 whose length is the
        // leading dimension of `b` divided by 4.
        TypesMatchWith<
          "res has the same number of elements as operand b",
          "b", "res",
          "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] / 4},"
          "IntegerType::get($_self.getContext(), 32))">
      ]> {
  let summary = "sdot op";
  let description = [{
    Signed integer addition of dot product (vector). This instruction performs
    the following operation on signed integer vectors: res = dot(b, c) + a,
    where vector operands are partitioned into groups of four elements.

    Source:
    https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics
  }];

  // Intended signatures:
  //   (vector<2xi32>, vector<8xi8>,  vector<8xi8>)  -> vector<2xi32>
  //   (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> vector<4xi32>
  let arguments = (ins
    VectorOfLengthAndType<[4, 2], [I32]>:$a,
    VectorOfLengthAndType<[16, 8], [I8]>:$b,
    VectorOfLengthAndType<[16, 8], [I8]>:$c
  );
  let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res);
  let assemblyFormat =
    "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
}
131
// Signed 8-bit matrix multiply-accumulate (`intr.smmla`). Operand #1 (`src1`)
// is the overloaded operand. Both sources share one type, and the result has
// the accumulator's type.
def SmmlaOp
    : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla", [1], [
        Pure,
        AllTypesMatch<["src1", "src2"]>,
        AllTypesMatch<["acc", "res"]>
      ]> {
  let summary = "Matrix-matrix multiply and accumulate op";
  let description = [{
    SMMLA: Signed integer matrix multiply-accumulate.

    Signed 8-bit integer matrix multiply-accumulate. This instruction multiplies
    the 2x8 matrix of signed 8-bit integer values in the first source vector by
    the 8x2 matrix of signed 8-bit integer values in the second source vector.
    The resulting 2x2 32-bit integer matrix product is destructively added to
    the 32-bit integer matrix accumulator in the destination vector. This is
    equivalent to performing an 8-way dot product per destination element.

    Source:
    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=smmla
  }];

  // Signature:
  //   (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> vector<4xi32>
  let arguments = (ins NeonVectorOfLength<4, I32>:$acc,
                       NeonVectorOfLength<16, I8>:$src1,
                       NeonVectorOfLength<16, I8>:$src2);
  let results = (outs NeonVectorOfLength<4, I32>:$res);
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}
161
// Unsigned 8-bit matrix multiply-accumulate (`intr.ummla`). Operand #1
// (`src1`) is the overloaded operand. Both sources share one type, and the
// result has the accumulator's type.
def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
                Pure,
                AllTypesMatch<["src1", "src2"]>,
                AllTypesMatch<["acc", "res"]>,
              ]> {
  let summary = "Unsigned matrix-matrix multiply and accumulate op";
  let description = [{
    UMMLA: Unsigned integer matrix multiply-accumulate.

    Unsigned 8-bit integer matrix multiply-accumulate. This instruction
    multiplies the 2x8 matrix of unsigned 8-bit integer values in the first
    source vector by the 8x2 matrix of unsigned 8-bit integer values in the
    second source vector. The resulting 2x2 32-bit integer matrix product is
    destructively added to the 32-bit integer matrix accumulator in the
    destination vector. This is equivalent to performing an 8-way dot product
    per destination element.

    Source:
    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=ummla
  }];
  // Supports (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
  let arguments = (ins
          NeonVectorOfLength<4, I32>:$acc,
          NeonVectorOfLength<16, I8>:$src1,
          NeonVectorOfLength<16, I8>:$src2
  );
  let results = (outs NeonVectorOfLength<4, I32>:$res);
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}
192
// Mixed-sign 8-bit matrix multiply-accumulate (`intr.usmmla`): the first
// source holds unsigned values, the second holds signed values. Operand #1
// (`src1`) is the overloaded operand. Both sources share one type, and the
// result has the accumulator's type.
def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
                Pure,
                AllTypesMatch<["src1", "src2"]>,
                AllTypesMatch<["acc", "res"]>,
              ]> {
  let summary = "Unsigned and signed matrix-matrix multiply and accumulate op";
  let description = [{
    USMMLA: Unsigned and signed integer matrix multiply-accumulate.

    Unsigned and signed 8-bit integer matrix multiply-accumulate. This
    instruction multiplies the 2x8 matrix of unsigned 8-bit integer values in
    the first source vector by the 8x2 matrix of signed 8-bit integer values in
    the second source vector. The resulting 2x2 32-bit integer matrix product is
    destructively added to the 32-bit integer matrix accumulator in the
    destination vector. This is equivalent to performing an 8-way dot product
    per destination element.

    Source:
    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=usmmla
  }];
  // Supports (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
  let arguments = (ins
          NeonVectorOfLength<4, I32>:$acc,
          NeonVectorOfLength<16, I8>:$src1,
          NeonVectorOfLength<16, I8>:$src2
  );
  let results = (outs NeonVectorOfLength<4, I32>:$res);
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}
224
// Base class for ArmNeon ops named `2d.<mnemonic>`. Unlike `ArmNeon_IntrOp`,
// it derives directly from `Op` rather than from `LLVM_IntrOpBase`.
class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
    : Op<ArmNeon_Dialect, "2d." # mnemonic, traits>;
229
// 2-D form of the signed dot-product-accumulate op (`2d.sdot`). `b` and `c`
// share one type, as do `a` and `res`. The predicate traits below additionally
// require `a` to be rank 1 and `b` to be rank 2 with 4 columns and as many rows
// as `a` has elements.
def Sdot2dOp
    : ArmNeon_2dOp<"sdot", [
        Pure,
        AllTypesMatch<["b", "c"]>,
        AllTypesMatch<["a", "res"]>,
        PredOpTrait<"operand `a` should be 1-dimensional",
          CPred<"::llvm::cast<VectorType>(getA().getType()).getShape().size() == 1">>,
        PredOpTrait<"operand `b` should be 2-dimensional",
          CPred<"::llvm::cast<VectorType>(getB().getType()).getShape().size() == 2">>,
        PredOpTrait<"operand `b` should have 4 columns",
          CPred<"::llvm::cast<VectorType>(getB().getType()).getShape()[1] == 4">>,
        PredOpTrait<
          "operand `b` should have as many rows as the size of operand `a`",
          CPred<"::llvm::cast<VectorType>(getB().getType()).getShape()[0] == "
                "::llvm::cast<VectorType>(getA().getType()).getShape()[0]">>
      ]> {
  let summary = "sdot op";
  let description = [{
    The two input vectors `b` and `c` have a 2D shape, consisting of either 2
    or 4 rows, each row having length 4. This operation computes the pair-wise
    dot-products of the rows of `b` and `c` and accumulates them with the
    corresponding entry of `a`:

    ```
    res[i] := a[i] + dot_product(b[i, ...], c[i, ...])
    ```

  }];

  // Intended signatures:
  //   (vector<2xi32>, vector<2x4xi8>, vector<2x4xi8>) -> vector<2xi32>
  //   (vector<4xi32>, vector<4x4xi8>, vector<4x4xi8>) -> vector<4xi32>
  // TODO: how do we express 2D shape requirements here?
  let arguments = (ins
    VectorOfLengthAndType<[4, 2], [I32]>:$a,
    VectorOfLengthAndType<[16, 8], [I8]>:$b,
    VectorOfLengthAndType<[16, 8], [I8]>:$c
  );
  let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res);
  let assemblyFormat =
    "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";

  // Exposes the per-row reduction length as a constant on the generated op
  // class.
  let extraClassDeclaration = [{
    static constexpr int kReductionSize = 4;
  }];
}
278
279#endif // ARMNEON_OPS
280