xref: /llvm-project/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td (revision 1dfb104eac73863b06751bea225ffa6ef589577f)
//===-- ArmSMEIntrinsicOps.td ------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of the intrinsic ops for the ArmSME dialect.
//
//===----------------------------------------------------------------------===//
12
13#ifndef ARMSME_INTRINSIC_OPS
14#define ARMSME_INTRINSIC_OPS
15
16include "ArmSME.td"
17
//===----------------------------------------------------------------------===//
// ArmSME intrinsic op definitions
//===----------------------------------------------------------------------===//
21
// Predicate operand type for the SME MOP intrinsic ops below: a rank-1
// scalable vector of i1 whose scalable length is 16, 8, 4 or 2.
def MOPPredicate : ScalableVectorOfRankAndLengthAndType<[1], [16, 8, 4, 2], [I1]>
{
  let summary = "a vector type that is a supported predicate for the SME MOP instructions";
  let description = [{
    Possible vector types:

    * `vector<[16]xi1>`
    * `vector<[8]xi1>`
    * `vector<[4]xi1>`
    * `vector<[2]xi1>`
  }];
}
34
// Vector operand type for the SME MOP intrinsic ops below: a rank-1 scalable
// vector with scalable length 16, 8, 4 or 2 and element type i8, i16, bf16,
// f16, f32 or f64.
//
// FIXME: The allowed lengths and allowed element types are given as two
// independent lists, so any pairing of them is accepted. This includes
// combinations that are not SVE vectors, e.g. vector<[16]xf32>. Only the
// types enumerated in the description below are intended to be supported.
def MOPVector : ScalableVectorOfRankAndLengthAndType<[1], [16, 8, 4, 2],
                                              [I8, I16, BF16, F16, F32, F64]>
{
  let summary = "a vector type that is a supported input for the SME MOP instructions";
  let description = [{
    Possible vector types:

    Integer elements:

    * `vector<[16]xi8>`
    * `vector<[8]xi16>`

    Floating point elements:

    * `vector<[8]xf16>`
    * `vector<[8]xbf16>`
    * `vector<[4]xf32>`
    * `vector<[2]xf64>`
  }];
}
56
// Base class for every ArmSME LLVM-intrinsic op in this file. It forwards to
// LLVM_IntrOpBase and derives both names from `mnemonic`:
//   * op name:        "intr." # mnemonic
//                     (e.g. "mopa.wide" -> "intr.mopa.wide")
//   * intrinsic enum: "aarch64_sme_" # mnemonic with every '.' replaced by '_'
//                     (e.g. "mopa.wide" -> "aarch64_sme_mopa_wide")
//
// Template arguments:
//   mnemonic           - intrinsic suffix; may contain '.' separators.
//   immArgPositions    - operand positions passed as immediate arguments. In
//                        every use in this file, each entry is paired with the
//                        attribute name at the same index in immArgAttrNames.
//   immArgAttrNames    - names of the attributes that hold those immediates.
//   overloadedOperands - forwarded as LLVM_IntrOpBase's overloadedOperands.
//   traits             - extra op traits.
//   numResults         - number of op results (0 by default).
//   overloadedResults  - forwarded as LLVM_IntrOpBase's overloadedResults.
//
// The access-group, alias-analysis, fastmath and op-bundle flags of
// LLVM_IntrOpBase are all passed as 0.
class ArmSME_IntrOp<string mnemonic,
                    list<int> immArgPositions = [],
                    list<string> immArgAttrNames = [],
                    list<int> overloadedOperands = [],
                    list<Trait> traits = [], int numResults = 0,
                    list<int> overloadedResults = []>
    : LLVM_IntrOpBase<
          /*Dialect dialect=*/ArmSME_Dialect,
          /*string opName=*/"intr." # mnemonic,
          // Intrinsic enum names cannot contain '.', so map it to '_'.
          /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
          /*list<int> overloadedResults=*/overloadedResults,
          /*list<int> overloadedOperands=*/overloadedOperands,
          /*list<Trait> traits=*/traits,
          /*int numResults=*/numResults,
          /*bit requiresAccessGroup=*/0,
          /*bit requiresAliasAnalysis=*/0,
          /*bit requiresFastmath=*/0,
          /*bit requiresOpBundles=*/0,
          /*list<int> immArgPositions=*/immArgPositions,
          /*list<string> immArgAttrNames=*/immArgAttrNames>;
77
// Zero: `intr.zero`, mapped to the `aarch64_sme_zero` intrinsic. Its only
// argument is the I32 `tile_mask` attribute, which is passed to the intrinsic
// as the immediate at position 0. It has no results.
def LLVM_aarch64_sme_zero
   : ArmSME_IntrOp<"zero",
                   /*immArgPositions=*/[0],
                   /*immArgAttrNames=*/["tile_mask"]>,
     Arguments<(ins Arg<I32Attr, "Tile mask">:$tile_mask)>;
84
// MOPs
// Shared definition of the MOP intrinsic ops. They have no results. Operands,
// in intrinsic-argument order, are:
//   0: tile_id       - I32 attribute passed as an immediate (immArgPositions).
//   1: lhs_predicate - MOPPredicate.
//   2: rhs_predicate - MOPPredicate.
//   3: lhs_vector    - MOPVector.
//   4: rhs_vector    - MOPVector; listed in overloadedOperands, so the
//                      intrinsic is overloaded on this operand's type.
// The argument order below must stay in sync with the positions above.
class ArmSME_IntrMopOverloadedOp<string mnemonic>
    : ArmSME_IntrOp<mnemonic,
                    /*immArgPositions=*/[0],
                    /*immArgAttrNames=*/["tile_id"],
                    /*overloadedOperands=*/[4]>,
      Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
                 Arg<MOPPredicate, "LHS predicate">:$lhs_predicate,
                 Arg<MOPPredicate, "RHS predicate">:$rhs_predicate,
                 Arg<MOPVector, "LHS vector operand">:$lhs_vector,
                 Arg<MOPVector, "RHS vector operand">:$rhs_vector)>;
96
// Instantiate one MOP op per mnemonic. Each record is named
// LLVM_aarch64_sme_<mnemonic with '.' replaced by '_'>. For example,
// "smopa.za32" defines LLVM_aarch64_sme_smopa_za32.
foreach mop = ["mopa", "mops",
               "mopa.wide", "mops.wide",
               "smopa.wide", "smops.wide",
               "umopa.wide", "umops.wide",
               "sumopa.wide", "sumops.wide",
               "usmopa.wide", "usmops.wide",
               "smopa.za32", "umopa.za32",
               "smops.za32", "umops.za32"] in
  def LLVM_aarch64_sme_ # !subst(".", "_", mop)
      : ArmSME_IntrMopOverloadedOp<mop>;
113
// Shared base for the ZA tile-slice load/store intrinsic ops. They have no
// results, and the operand at intrinsic-argument position 2 is passed as the
// immediate `tile_id` attribute. Both subclasses declare `tile_id` as their
// third argument to match.
class ArmSME_IntrLoadStoreOp<string mnemonic>
    : ArmSME_IntrOp<mnemonic,
                    /*immArgPositions=*/[2],
                    /*immArgAttrNames=*/["tile_id"]>;
118
// Loads (from memory to ZA tile slice)
// Operands, in intrinsic-argument order: predicate, load_address, tile_id
// (immediate, position 2) and tile_slice_index.
// NOTE(review): unlike the store ops, `load_address` carries no memory-effect
// decorator. Adding only [MemRead] would leave this result-less op with a
// single read effect, which may let it be treated as trivially dead. Confirm
// the intended side-effect modelling before changing it.
class ArmSME_IntrLoadOp<string mnemonic>
    : ArmSME_IntrLoadStoreOp<mnemonic>,
      Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
                 Arg<LLVM_AnyPointer, "Load address">:$load_address,
                 Arg<I32Attr, "Virtual tile ID">:$tile_id,
                 Arg<I32, "Tile slice">:$tile_slice_index)>;
126
// Instantiate one load op per direction (horiz, vert) and suffix
// (b, h, w, d, q). For example, direction "vert" with suffix "w" defines
// LLVM_aarch64_sme_ld1w_vert for the "ld1w.vert" intrinsic.
foreach dir = ["horiz", "vert"] in
  foreach suffix = ["b", "h", "w", "d", "q"] in
    def LLVM_aarch64_sme_ld1 # suffix # "_" # dir
        : ArmSME_IntrLoadOp<"ld1" # suffix # "." # dir>;
137
// Stores (ZA tile slice to memory)
// Operands, in intrinsic-argument order: predicate, store_address (declared
// with a [MemWrite] effect), tile_id (immediate, position 2) and
// tile_slice_index.
class ArmSME_IntrStoreOp<string mnemonic>
    : ArmSME_IntrLoadStoreOp<mnemonic>,
      Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
                 Arg<I32Attr, "Virtual tile ID">:$tile_id,
                 Arg<I32, "Tile slice">:$tile_slice_index)>;
145
// Instantiate one store op per direction (horiz, vert) and suffix
// (b, h, w, d, q). For example, direction "horiz" with suffix "d" defines
// LLVM_aarch64_sme_st1d_horiz for the "st1d.horiz" intrinsic.
foreach dir = ["horiz", "vert"] in
  foreach suffix = ["b", "h", "w", "d", "q"] in
    def LLVM_aarch64_sme_st1 # suffix # "_" # dir
        : ArmSME_IntrStoreOp<"st1" # suffix # "." # dir>;
156
// `intr.str`, mapped to the `aarch64_sme_str` intrinsic. It has no immediate
// arguments and no results. Operands: an I32 index, a store address
// (declared with a [MemWrite] effect) and an I32 offset.
def LLVM_aarch64_sme_str
    : ArmSME_IntrOp<"str">,
      Arguments<(ins Arg<I32, "Index">:$index,
                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
                 Arg<I32, "Offset">:$offset)>;
162
// Vector to tile slice
// `intr.write.<direction>` has no results. Operands, in intrinsic-argument
// order:
//   0: tile_id          - I32 attribute passed as an immediate.
//   1: tile_slice_index - I32.
//   2: predicate        - SVE predicate; must have the same shape as `vector`.
//   3: vector           - SVE vector; the intrinsic is overloaded on its type.
class LLVM_aarch64_sme_write<string direction>
    : ArmSME_IntrOp<"write." # direction,
                    /*immArgPositions=*/[0],
                    /*immArgAttrNames=*/["tile_id"],
                    /*overloadedOperands=*/[3],
                    /*traits=*/[AllShapesMatch<["predicate", "vector"]>]>,
      Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
                     Arg<I32, "Tile slice">:$tile_slice_index,
                     Arg<SVEPredicate, "Vector predicate">:$predicate,
                     Arg<SVEVector, "Vector operand">:$vector)>;
174
// Tile slice to vector
// `intr.read.<direction>` produces a single result `res`, and the intrinsic
// is overloaded on its type. Operands, in intrinsic-argument order:
//   0: vector           - SVE vector.
//   1: predicate        - SVE predicate.
//   2: tile_id          - I32 attribute passed as an immediate.
//   3: tile_slice_index - I32.
// `vector`, `predicate` and `res` must have the same shape, and `vector` and
// `res` must have the same element type.
class LLVM_aarch64_sme_read<string direction>
    : ArmSME_IntrOp<"read." # direction,
                    /*immArgPositions=*/[2],
                    /*immArgAttrNames=*/["tile_id"],
                    /*overloadedOperands=*/[],
                    /*traits=*/[AllShapesMatch<["vector", "predicate", "res"]>,
                     AllElementTypesMatch<["vector", "res"]>],
                    /*numResults=*/1, /*overloadedResults=*/[0]>,
      Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
                     Arg<SVEPredicate, "Vector predicate">:$predicate,
                     Arg<I32Attr, "Virtual tile ID">:$tile_id,
                     Arg<I32, "Tile slice">:$tile_slice_index)>;
188
// Instantiate the tile-slice write ops for both directions:
// LLVM_aarch64_sme_write_horiz and LLVM_aarch64_sme_write_vert.
foreach dir = ["horiz", "vert"] in
  def LLVM_aarch64_sme_write_ # dir : LLVM_aarch64_sme_write<dir>;

// Instantiate the tile-slice read ops for both directions:
// LLVM_aarch64_sme_read_horiz and LLVM_aarch64_sme_read_vert.
foreach dir = ["horiz", "vert"] in
  def LLVM_aarch64_sme_read_ # dir : LLVM_aarch64_sme_read<dir>;
194
// `intr.<mnemonic>` count ops. They take no operands and produce one result,
// `res`, which the PredOpTrait constrains to i64. Neither operands nor
// results are overloaded.
class ArmSME_IntrCountOp<string mnemonic>
    : ArmSME_IntrOp<mnemonic,
                    /*immArgPositions=*/[],
                    /*immArgAttrNames=*/[],
                    /*overloadedOperands=*/[],
                    /*traits=*/[PredOpTrait<"`res` is i64", TypeIsPred<"res", I64>>],
                    /*numResults=*/1, /*overloadedResults=*/[]>;
202
// Instantiate the count ops cntsb, cntsh, cntsw and cntsd. Each is named
// LLVM_aarch64_sme_cnts<suffix>.
foreach suffix = ["b", "h", "w", "d"] in
  def LLVM_aarch64_sme_cnts # suffix : ArmSME_IntrCountOp<"cnts" # suffix>;
207
208#endif // ARMSME_INTRINSIC_OPS
209