//===-- ArmNeonOps.td - ArmNeon dialect op definitions -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the basic operations for the ArmNeon dialect.
//
//===----------------------------------------------------------------------===//

#ifndef ARMNEON_OPS
#define ARMNEON_OPS

include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// ArmNeon dialect definition
//===----------------------------------------------------------------------===//

def ArmNeon_Dialect : Dialect {
  let name = "arm_neon";
  let cppNamespace = "::mlir::arm_neon";

  // Note: this does not need to depend on LLVMDialect as long as functions in
  // this dialect (such as canonicalization) do not produce entities belonging
  // to the LLVMDialect (ops or types).
}

//===----------------------------------------------------------------------===//
// ArmNeon type definition
//===----------------------------------------------------------------------===//

// A fixed-length (non-scalable) 1-D vector of exactly `length` elements whose
// element type is `elementType`; e.g. NeonVectorOfLength<16, I8> accepts
// vector<16xi8>.
//
// The vector-kind predicate is listed first on purpose: `And<>` is emitted as
// a short-circuiting, left-to-right `&&`, and `IsVectorOfShape` casts `$_self`
// to `VectorType` unconditionally. Checking the kind first keeps a non-vector
// operand (e.g. a tensor in the generic op form) from reaching that cast.
class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
  [elementType],
  And<[IsFixedVectorOfAnyRankTypePred, IsVectorOfShape<[length]>]>,
  "a vector with length " # length,
  "::mlir::VectorType">;

//===----------------------------------------------------------------------===//
// ArmNeon op definitions
//===----------------------------------------------------------------------===//

// ArmNeon dialect op that corresponds to (and is convertible to) an LLVM IR
// intrinsic.
// The op is registered as `intr.<mnemonic>` and is tied to the LLVM intrinsic
// enum `aarch64_neon_<mnemonic>`, with every `.` in the mnemonic replaced by
// `_`. `overloadedResults` / `overloadedOperands` give the result / operand
// positions on which the LLVM intrinsic is overloaded.
class ArmNeon_IntrOp<string mnemonic, list<int> overloadedResults,
                     list<int> overloadedOperands, int numResults,
                     list<Trait> traits = [], bit requiresAccessGroup = 0,
                     bit requiresAliasAnalysis = 0>
    : LLVM_IntrOpBase</*dialect=*/ArmNeon_Dialect,
                      /*opName=*/"intr." # mnemonic,
                      /*enumName=*/"aarch64_neon_" # !subst(".", "_", mnemonic),
                      /*overloadedResults=*/overloadedResults,
                      /*overloadedOperands=*/overloadedOperands,
                      /*traits=*/traits,
                      /*numResults=*/numResults,
                      /*requiresAccessGroup=*/requiresAccessGroup,
                      /*requiresAliasAnalysis=*/requiresAliasAnalysis>;

// ArmNeon dialect op that corresponds to an LLVM IR intrinsic with one
// overloaded result.
class ArmNeon_OverloadedOneResultIntrOp<string mnemonic,
                                        list<Trait> traits = []>
    : ArmNeon_IntrOp<mnemonic, [0], [], 1, traits>;

// ArmNeon dialect op that corresponds to an LLVM IR intrinsic with one
// overloaded result and overloaded operands list.
class ArmNeon_OverloadedOperandsWithOneResultIntrOp<string mnemonic,
                                                    list<int> overloadedOperands,
                                                    list<Trait> traits = []>
    : ArmNeon_IntrOp<mnemonic, [0], overloadedOperands, 1, traits>;

def SMullOp : ArmNeon_OverloadedOneResultIntrOp<"smull", [
    Pure,
    AllTypesMatch<["a", "b"]>,
    TypesMatchWith<
      "res has same vector shape and element bitwidth scaled by 2 as a",
      "a", "res", "::llvm::cast<VectorType>($_self).scaleElementBitwidth(2)">
  ]> {
  let summary = "smull (signed multiply long) op";
  let description = [{
    Signed Multiply Long (vector). This instruction multiplies corresponding
    signed integer values in the lower or upper half of the vectors of the two
    source SIMD&FP registers, places the results in a vector, and writes the
    vector to the destination SIMD&FP register.

    Source:
    https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics
  }];

  // Supports either:
  //   (vector<8xi8>, vector<8xi8>) -> (vector<8xi16>)
  //   (vector<4xi16>, vector<4xi16>) -> (vector<4xi32>)
  //   (vector<2xi32>, vector<2xi32>) -> (vector<2xi64>)
  let arguments = (ins VectorOfLengthAndType<[8, 4, 2], [I8, I16, I32]>:$a,
                       VectorOfLengthAndType<[8, 4, 2], [I8, I16, I32]>:$b);
  let results = (outs VectorOfLengthAndType<[8, 4, 2], [I16, I32, I64]>:$res);
  let assemblyFormat =
    "$a `,` $b attr-dict `:` type($a) `to` type($res)";
}

def SdotOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"sdot",[1], [
    Pure,
    AllTypesMatch<["b", "c"]>,
    AllTypesMatch<["a", "res"]>,
    // Each group of four i8 lanes in `b` produces one i32 lane in `res`.
    TypesMatchWith<"res has one i32 element per four elements of operand b",
      "b", "res",
      "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] / 4},"
      "IntegerType::get($_self.getContext(), 32))">]> {
  let summary = "sdot op";
  let description = [{
    Signed integer addition of dot product (vector). This instruction performs
    the following operation on signed integer vectors: res = dot(b, c) + a,
    where vector operands are partitioned into groups of four elements.

    Source:
    https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics
  }];
  // Supports either:
  //   (vector<2xi32>, vector<8xi8>, vector<8xi8>) -> vector<2xi32>
  //   (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> vector<4xi32>
  let arguments = (ins VectorOfLengthAndType<[4, 2], [I32]>:$a,
                       VectorOfLengthAndType<[16, 8], [I8]>:$b,
                       VectorOfLengthAndType<[16, 8], [I8]>:$c);
  let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res);
  let assemblyFormat =
    "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
}

def SmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"smmla",[1], [
    Pure,
    AllTypesMatch<["src1", "src2"]>,
    AllTypesMatch<["acc", "res"]>,
  ]> {
  let summary = "Matrix-matrix multiply and accumulate op";
  let description = [{
    SMMLA: Signed integer matrix multiply-accumulate.

    Signed 8-bit integer matrix multiply-accumulate. This instruction multiplies
    the 2x8 matrix of signed 8-bit integer values in the first source vector by
    the 8x2 matrix of signed 8-bit integer values in the second source vector.
    The resulting 2x2 32-bit integer matrix product is destructively added to
    the 32-bit integer matrix accumulator in the destination vector. This is
    equivalent to performing an 8-way dot product per destination element.

    Source:
    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=smmla
  }];
  // Supports (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
  let arguments = (ins
          NeonVectorOfLength<4, I32>:$acc,
          NeonVectorOfLength<16, I8>:$src1,
          NeonVectorOfLength<16, I8>:$src2
  );
  let results = (outs NeonVectorOfLength<4, I32>:$res);
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}

def UmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"ummla",[1], [
    Pure,
    AllTypesMatch<["src1", "src2"]>,
    AllTypesMatch<["acc", "res"]>,
  ]> {
  let summary = "Unsigned matrix-matrix multiply and accumulate op";
  let description = [{
    UMMLA: Unsigned integer matrix multiply-accumulate.

    Unsigned 8-bit integer matrix multiply-accumulate. This instruction
    multiplies the 2x8 matrix of unsigned 8-bit integer values in the first
    source vector by the 8x2 matrix of unsigned 8-bit integer values in the
    second source vector. The resulting 2x2 32-bit integer matrix product is
    destructively added to the 32-bit integer matrix accumulator in the
    destination vector. This is equivalent to performing an 8-way dot product
    per destination element.

    Source:
    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=ummla
  }];
  // Supports (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
  let arguments = (ins
          NeonVectorOfLength<4, I32>:$acc,
          NeonVectorOfLength<16, I8>:$src1,
          NeonVectorOfLength<16, I8>:$src2
  );
  let results = (outs NeonVectorOfLength<4, I32>:$res);
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}

def UsmmlaOp : ArmNeon_OverloadedOperandsWithOneResultIntrOp<"usmmla",[1], [
    Pure,
    AllTypesMatch<["src1", "src2"]>,
    AllTypesMatch<["acc", "res"]>,
  ]> {
  let summary = "Unsigned and signed matrix-matrix multiply and accumulate op";
  let description = [{
    USMMLA: Unsigned and signed integer matrix multiply-accumulate.

    Unsigned and signed 8-bit integer matrix multiply-accumulate. This
    instruction multiplies the 2x8 matrix of unsigned 8-bit integer values in
    the first source vector by the 8x2 matrix of signed 8-bit integer values in
    the second source vector. The resulting 2x2 32-bit integer matrix product is
    destructively added to the 32-bit integer matrix accumulator in the
    destination vector. This is equivalent to performing an 8-way dot product
    per destination element.

    Source:
    https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon]&q=usmmla
  }];
  // Supports (vector<4xi32>, vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
  let arguments = (ins
          NeonVectorOfLength<4, I32>:$acc,
          NeonVectorOfLength<16, I8>:$src1,
          NeonVectorOfLength<16, I8>:$src2
  );
  let results = (outs NeonVectorOfLength<4, I32>:$res);
  let assemblyFormat =
    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}

// ArmNeon dialect op named `2d.<mnemonic>` that operates on 2-D vector
// operands. Unlike the ops above it derives from plain `Op` rather than
// `ArmNeon_IntrOp`, so it carries no LLVM intrinsic mapping of its own.
// NOTE(review): presumably rewritten into the matching `intr.` op by a
// lowering pass elsewhere — confirm.
class ArmNeon_2dOp<string mnemonic, list<Trait> traits = []>
    : Op</*dialect=*/ArmNeon_Dialect,
         /*opName=*/"2d." # mnemonic,
         /*traits=*/traits>;

def Sdot2dOp : ArmNeon_2dOp<"sdot", [
    Pure,
    AllTypesMatch<["b", "c"]>,
    AllTypesMatch<["a", "res"]>,
    // The shape checks below are ordered so that the rank checks run before
    // any getShape()[i] indexing that depends on them.
    PredOpTrait<
      "operand `a` should be 1-dimensional",
      CPred<"::llvm::cast<VectorType>(getA().getType()).getShape().size() == 1">
    >,
    PredOpTrait<
      "operand `b` should be 2-dimensional",
      CPred<"::llvm::cast<VectorType>(getB().getType()).getShape().size() == 2">
    >,
    PredOpTrait<
      "operand `b` should have 4 columns",
      CPred<"::llvm::cast<VectorType>(getB().getType()).getShape()[1] == 4">
    >,
    PredOpTrait<
      "operand `b` should have as many rows as the size of operand `a`",
      CPred<"::llvm::cast<VectorType>(getB().getType()).getShape()[0] == ::llvm::cast<VectorType>(getA().getType()).getShape()[0]">
    >,
    ]
  > {
  let summary = "sdot op";
  let description = [{
    The two input vectors `b` and `c` have a 2D shape, consisting of either 2
    or 4 rows, each row having length 4. This operation computes the pair-wise
    dot-products of the rows of `b` and `c` and accumulates them with the
    corresponding entry of `a`:

    ```
    res[i] := a[i] + dot_product(b[i, ...], c[i, ...])
    ```

  }];
  // Supports either:
  //   (vector<2xi32>, vector<2x4xi8>, vector<2x4xi8>) -> vector<2xi32>
  //   (vector<4xi32>, vector<4x4xi8>, vector<4x4xi8>) -> vector<4xi32>
  // TODO: how do we express 2D shape requirements here?
  let arguments = (ins VectorOfLengthAndType<[4, 2], [I32]>:$a,
                       VectorOfLengthAndType<[16, 8], [I8]>:$b,
                       VectorOfLengthAndType<[16, 8], [I8]>:$c);
  let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res);
  let assemblyFormat =
    "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
  let extraClassDeclaration = [{
    // Number of columns of `b`/`c` reduced into each element of `res`.
    static constexpr int kReductionSize = 4;
  }];
}

#endif // ARMNEON_OPS