1*12c85518Srobert //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
2*12c85518Srobert //
3*12c85518Srobert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*12c85518Srobert // See https://llvm.org/LICENSE.txt for license information.
5*12c85518Srobert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*12c85518Srobert //
7*12c85518Srobert //===----------------------------------------------------------------------===//
8*12c85518Srobert
9*12c85518Srobert #include "clang/Support/RISCVVIntrinsicUtils.h"
10*12c85518Srobert #include "llvm/ADT/ArrayRef.h"
11*12c85518Srobert #include "llvm/ADT/SmallSet.h"
12*12c85518Srobert #include "llvm/ADT/StringExtras.h"
13*12c85518Srobert #include "llvm/ADT/StringMap.h"
14*12c85518Srobert #include "llvm/ADT/StringSet.h"
15*12c85518Srobert #include "llvm/ADT/Twine.h"
16*12c85518Srobert #include "llvm/Support/ErrorHandling.h"
17*12c85518Srobert #include "llvm/Support/raw_ostream.h"
18*12c85518Srobert #include <numeric>
19*12c85518Srobert #include <optional>
20*12c85518Srobert
21*12c85518Srobert using namespace llvm;
22*12c85518Srobert
23*12c85518Srobert namespace clang {
24*12c85518Srobert namespace RISCV {
25*12c85518Srobert
26*12c85518Srobert const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor(
27*12c85518Srobert BaseTypeModifier::Vector, VectorTypeModifier::MaskVector);
28*12c85518Srobert const PrototypeDescriptor PrototypeDescriptor::VL =
29*12c85518Srobert PrototypeDescriptor(BaseTypeModifier::SizeT);
30*12c85518Srobert const PrototypeDescriptor PrototypeDescriptor::Vector =
31*12c85518Srobert PrototypeDescriptor(BaseTypeModifier::Vector);
32*12c85518Srobert
33*12c85518Srobert //===----------------------------------------------------------------------===//
34*12c85518Srobert // Type implementation
35*12c85518Srobert //===----------------------------------------------------------------------===//
36*12c85518Srobert
LMULType(int NewLog2LMUL)37*12c85518Srobert LMULType::LMULType(int NewLog2LMUL) {
38*12c85518Srobert // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
39*12c85518Srobert assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
40*12c85518Srobert Log2LMUL = NewLog2LMUL;
41*12c85518Srobert }
42*12c85518Srobert
str() const43*12c85518Srobert std::string LMULType::str() const {
44*12c85518Srobert if (Log2LMUL < 0)
45*12c85518Srobert return "mf" + utostr(1ULL << (-Log2LMUL));
46*12c85518Srobert return "m" + utostr(1ULL << Log2LMUL);
47*12c85518Srobert }
48*12c85518Srobert
getScale(unsigned ElementBitwidth) const49*12c85518Srobert VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
50*12c85518Srobert int Log2ScaleResult = 0;
51*12c85518Srobert switch (ElementBitwidth) {
52*12c85518Srobert default:
53*12c85518Srobert break;
54*12c85518Srobert case 8:
55*12c85518Srobert Log2ScaleResult = Log2LMUL + 3;
56*12c85518Srobert break;
57*12c85518Srobert case 16:
58*12c85518Srobert Log2ScaleResult = Log2LMUL + 2;
59*12c85518Srobert break;
60*12c85518Srobert case 32:
61*12c85518Srobert Log2ScaleResult = Log2LMUL + 1;
62*12c85518Srobert break;
63*12c85518Srobert case 64:
64*12c85518Srobert Log2ScaleResult = Log2LMUL;
65*12c85518Srobert break;
66*12c85518Srobert }
67*12c85518Srobert // Illegal vscale result would be less than 1
68*12c85518Srobert if (Log2ScaleResult < 0)
69*12c85518Srobert return std::nullopt;
70*12c85518Srobert return 1 << Log2ScaleResult;
71*12c85518Srobert }
72*12c85518Srobert
MulLog2LMUL(int log2LMUL)73*12c85518Srobert void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
74*12c85518Srobert
RVVType(BasicType BT,int Log2LMUL,const PrototypeDescriptor & prototype)75*12c85518Srobert RVVType::RVVType(BasicType BT, int Log2LMUL,
76*12c85518Srobert const PrototypeDescriptor &prototype)
77*12c85518Srobert : BT(BT), LMUL(LMULType(Log2LMUL)) {
78*12c85518Srobert applyBasicType();
79*12c85518Srobert applyModifier(prototype);
80*12c85518Srobert Valid = verifyType();
81*12c85518Srobert if (Valid) {
82*12c85518Srobert initBuiltinStr();
83*12c85518Srobert initTypeStr();
84*12c85518Srobert if (isVector()) {
85*12c85518Srobert initClangBuiltinStr();
86*12c85518Srobert }
87*12c85518Srobert }
88*12c85518Srobert }
89*12c85518Srobert
90*12c85518Srobert // clang-format off
// Boolean (mask) types are encoded by the ratio N = SEW/LMUL, i.e. vbool<N>_t.
// SEW/LMUL | 64        | 32        | 16        | 8        | 4        | 2        | 1
93*12c85518Srobert // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
94*12c85518Srobert // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1
95*12c85518Srobert
96*12c85518Srobert // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8
97*12c85518Srobert // -------- |------ | -------- | ------- | ------- | -------- | -------- | --------
98*12c85518Srobert // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64
99*12c85518Srobert // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32
100*12c85518Srobert // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
101*12c85518Srobert // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8
102*12c85518Srobert // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
103*12c85518Srobert // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
104*12c85518Srobert // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
105*12c85518Srobert // clang-format on
106*12c85518Srobert
verifyType() const107*12c85518Srobert bool RVVType::verifyType() const {
108*12c85518Srobert if (ScalarType == Invalid)
109*12c85518Srobert return false;
110*12c85518Srobert if (isScalar())
111*12c85518Srobert return true;
112*12c85518Srobert if (!Scale)
113*12c85518Srobert return false;
114*12c85518Srobert if (isFloat() && ElementBitwidth == 8)
115*12c85518Srobert return false;
116*12c85518Srobert unsigned V = *Scale;
117*12c85518Srobert switch (ElementBitwidth) {
118*12c85518Srobert case 1:
119*12c85518Srobert case 8:
120*12c85518Srobert // Check Scale is 1,2,4,8,16,32,64
121*12c85518Srobert return (V <= 64 && isPowerOf2_32(V));
122*12c85518Srobert case 16:
123*12c85518Srobert // Check Scale is 1,2,4,8,16,32
124*12c85518Srobert return (V <= 32 && isPowerOf2_32(V));
125*12c85518Srobert case 32:
126*12c85518Srobert // Check Scale is 1,2,4,8,16
127*12c85518Srobert return (V <= 16 && isPowerOf2_32(V));
128*12c85518Srobert case 64:
129*12c85518Srobert // Check Scale is 1,2,4,8
130*12c85518Srobert return (V <= 8 && isPowerOf2_32(V));
131*12c85518Srobert }
132*12c85518Srobert return false;
133*12c85518Srobert }
134*12c85518Srobert
initBuiltinStr()135*12c85518Srobert void RVVType::initBuiltinStr() {
136*12c85518Srobert assert(isValid() && "RVVType is invalid");
137*12c85518Srobert switch (ScalarType) {
138*12c85518Srobert case ScalarTypeKind::Void:
139*12c85518Srobert BuiltinStr = "v";
140*12c85518Srobert return;
141*12c85518Srobert case ScalarTypeKind::Size_t:
142*12c85518Srobert BuiltinStr = "z";
143*12c85518Srobert if (IsImmediate)
144*12c85518Srobert BuiltinStr = "I" + BuiltinStr;
145*12c85518Srobert if (IsPointer)
146*12c85518Srobert BuiltinStr += "*";
147*12c85518Srobert return;
148*12c85518Srobert case ScalarTypeKind::Ptrdiff_t:
149*12c85518Srobert BuiltinStr = "Y";
150*12c85518Srobert return;
151*12c85518Srobert case ScalarTypeKind::UnsignedLong:
152*12c85518Srobert BuiltinStr = "ULi";
153*12c85518Srobert return;
154*12c85518Srobert case ScalarTypeKind::SignedLong:
155*12c85518Srobert BuiltinStr = "Li";
156*12c85518Srobert return;
157*12c85518Srobert case ScalarTypeKind::Boolean:
158*12c85518Srobert assert(ElementBitwidth == 1);
159*12c85518Srobert BuiltinStr += "b";
160*12c85518Srobert break;
161*12c85518Srobert case ScalarTypeKind::SignedInteger:
162*12c85518Srobert case ScalarTypeKind::UnsignedInteger:
163*12c85518Srobert switch (ElementBitwidth) {
164*12c85518Srobert case 8:
165*12c85518Srobert BuiltinStr += "c";
166*12c85518Srobert break;
167*12c85518Srobert case 16:
168*12c85518Srobert BuiltinStr += "s";
169*12c85518Srobert break;
170*12c85518Srobert case 32:
171*12c85518Srobert BuiltinStr += "i";
172*12c85518Srobert break;
173*12c85518Srobert case 64:
174*12c85518Srobert BuiltinStr += "Wi";
175*12c85518Srobert break;
176*12c85518Srobert default:
177*12c85518Srobert llvm_unreachable("Unhandled ElementBitwidth!");
178*12c85518Srobert }
179*12c85518Srobert if (isSignedInteger())
180*12c85518Srobert BuiltinStr = "S" + BuiltinStr;
181*12c85518Srobert else
182*12c85518Srobert BuiltinStr = "U" + BuiltinStr;
183*12c85518Srobert break;
184*12c85518Srobert case ScalarTypeKind::Float:
185*12c85518Srobert switch (ElementBitwidth) {
186*12c85518Srobert case 16:
187*12c85518Srobert BuiltinStr += "x";
188*12c85518Srobert break;
189*12c85518Srobert case 32:
190*12c85518Srobert BuiltinStr += "f";
191*12c85518Srobert break;
192*12c85518Srobert case 64:
193*12c85518Srobert BuiltinStr += "d";
194*12c85518Srobert break;
195*12c85518Srobert default:
196*12c85518Srobert llvm_unreachable("Unhandled ElementBitwidth!");
197*12c85518Srobert }
198*12c85518Srobert break;
199*12c85518Srobert default:
200*12c85518Srobert llvm_unreachable("ScalarType is invalid!");
201*12c85518Srobert }
202*12c85518Srobert if (IsImmediate)
203*12c85518Srobert BuiltinStr = "I" + BuiltinStr;
204*12c85518Srobert if (isScalar()) {
205*12c85518Srobert if (IsConstant)
206*12c85518Srobert BuiltinStr += "C";
207*12c85518Srobert if (IsPointer)
208*12c85518Srobert BuiltinStr += "*";
209*12c85518Srobert return;
210*12c85518Srobert }
211*12c85518Srobert BuiltinStr = "q" + utostr(*Scale) + BuiltinStr;
212*12c85518Srobert // Pointer to vector types. Defined for segment load intrinsics.
213*12c85518Srobert // segment load intrinsics have pointer type arguments to store the loaded
214*12c85518Srobert // vector values.
215*12c85518Srobert if (IsPointer)
216*12c85518Srobert BuiltinStr += "*";
217*12c85518Srobert }
218*12c85518Srobert
initClangBuiltinStr()219*12c85518Srobert void RVVType::initClangBuiltinStr() {
220*12c85518Srobert assert(isValid() && "RVVType is invalid");
221*12c85518Srobert assert(isVector() && "Handle Vector type only");
222*12c85518Srobert
223*12c85518Srobert ClangBuiltinStr = "__rvv_";
224*12c85518Srobert switch (ScalarType) {
225*12c85518Srobert case ScalarTypeKind::Boolean:
226*12c85518Srobert ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t";
227*12c85518Srobert return;
228*12c85518Srobert case ScalarTypeKind::Float:
229*12c85518Srobert ClangBuiltinStr += "float";
230*12c85518Srobert break;
231*12c85518Srobert case ScalarTypeKind::SignedInteger:
232*12c85518Srobert ClangBuiltinStr += "int";
233*12c85518Srobert break;
234*12c85518Srobert case ScalarTypeKind::UnsignedInteger:
235*12c85518Srobert ClangBuiltinStr += "uint";
236*12c85518Srobert break;
237*12c85518Srobert default:
238*12c85518Srobert llvm_unreachable("ScalarTypeKind is invalid");
239*12c85518Srobert }
240*12c85518Srobert ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
241*12c85518Srobert }
242*12c85518Srobert
initTypeStr()243*12c85518Srobert void RVVType::initTypeStr() {
244*12c85518Srobert assert(isValid() && "RVVType is invalid");
245*12c85518Srobert
246*12c85518Srobert if (IsConstant)
247*12c85518Srobert Str += "const ";
248*12c85518Srobert
249*12c85518Srobert auto getTypeString = [&](StringRef TypeStr) {
250*12c85518Srobert if (isScalar())
251*12c85518Srobert return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
252*12c85518Srobert return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
253*12c85518Srobert .str();
254*12c85518Srobert };
255*12c85518Srobert
256*12c85518Srobert switch (ScalarType) {
257*12c85518Srobert case ScalarTypeKind::Void:
258*12c85518Srobert Str = "void";
259*12c85518Srobert return;
260*12c85518Srobert case ScalarTypeKind::Size_t:
261*12c85518Srobert Str = "size_t";
262*12c85518Srobert if (IsPointer)
263*12c85518Srobert Str += " *";
264*12c85518Srobert return;
265*12c85518Srobert case ScalarTypeKind::Ptrdiff_t:
266*12c85518Srobert Str = "ptrdiff_t";
267*12c85518Srobert return;
268*12c85518Srobert case ScalarTypeKind::UnsignedLong:
269*12c85518Srobert Str = "unsigned long";
270*12c85518Srobert return;
271*12c85518Srobert case ScalarTypeKind::SignedLong:
272*12c85518Srobert Str = "long";
273*12c85518Srobert return;
274*12c85518Srobert case ScalarTypeKind::Boolean:
275*12c85518Srobert if (isScalar())
276*12c85518Srobert Str += "bool";
277*12c85518Srobert else
278*12c85518Srobert // Vector bool is special case, the formulate is
279*12c85518Srobert // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
280*12c85518Srobert Str += "vbool" + utostr(64 / *Scale) + "_t";
281*12c85518Srobert break;
282*12c85518Srobert case ScalarTypeKind::Float:
283*12c85518Srobert if (isScalar()) {
284*12c85518Srobert if (ElementBitwidth == 64)
285*12c85518Srobert Str += "double";
286*12c85518Srobert else if (ElementBitwidth == 32)
287*12c85518Srobert Str += "float";
288*12c85518Srobert else if (ElementBitwidth == 16)
289*12c85518Srobert Str += "_Float16";
290*12c85518Srobert else
291*12c85518Srobert llvm_unreachable("Unhandled floating type.");
292*12c85518Srobert } else
293*12c85518Srobert Str += getTypeString("float");
294*12c85518Srobert break;
295*12c85518Srobert case ScalarTypeKind::SignedInteger:
296*12c85518Srobert Str += getTypeString("int");
297*12c85518Srobert break;
298*12c85518Srobert case ScalarTypeKind::UnsignedInteger:
299*12c85518Srobert Str += getTypeString("uint");
300*12c85518Srobert break;
301*12c85518Srobert default:
302*12c85518Srobert llvm_unreachable("ScalarType is invalid!");
303*12c85518Srobert }
304*12c85518Srobert if (IsPointer)
305*12c85518Srobert Str += " *";
306*12c85518Srobert }
307*12c85518Srobert
initShortStr()308*12c85518Srobert void RVVType::initShortStr() {
309*12c85518Srobert switch (ScalarType) {
310*12c85518Srobert case ScalarTypeKind::Boolean:
311*12c85518Srobert assert(isVector());
312*12c85518Srobert ShortStr = "b" + utostr(64 / *Scale);
313*12c85518Srobert return;
314*12c85518Srobert case ScalarTypeKind::Float:
315*12c85518Srobert ShortStr = "f" + utostr(ElementBitwidth);
316*12c85518Srobert break;
317*12c85518Srobert case ScalarTypeKind::SignedInteger:
318*12c85518Srobert ShortStr = "i" + utostr(ElementBitwidth);
319*12c85518Srobert break;
320*12c85518Srobert case ScalarTypeKind::UnsignedInteger:
321*12c85518Srobert ShortStr = "u" + utostr(ElementBitwidth);
322*12c85518Srobert break;
323*12c85518Srobert default:
324*12c85518Srobert llvm_unreachable("Unhandled case!");
325*12c85518Srobert }
326*12c85518Srobert if (isVector())
327*12c85518Srobert ShortStr += LMUL.str();
328*12c85518Srobert }
329*12c85518Srobert
applyBasicType()330*12c85518Srobert void RVVType::applyBasicType() {
331*12c85518Srobert switch (BT) {
332*12c85518Srobert case BasicType::Int8:
333*12c85518Srobert ElementBitwidth = 8;
334*12c85518Srobert ScalarType = ScalarTypeKind::SignedInteger;
335*12c85518Srobert break;
336*12c85518Srobert case BasicType::Int16:
337*12c85518Srobert ElementBitwidth = 16;
338*12c85518Srobert ScalarType = ScalarTypeKind::SignedInteger;
339*12c85518Srobert break;
340*12c85518Srobert case BasicType::Int32:
341*12c85518Srobert ElementBitwidth = 32;
342*12c85518Srobert ScalarType = ScalarTypeKind::SignedInteger;
343*12c85518Srobert break;
344*12c85518Srobert case BasicType::Int64:
345*12c85518Srobert ElementBitwidth = 64;
346*12c85518Srobert ScalarType = ScalarTypeKind::SignedInteger;
347*12c85518Srobert break;
348*12c85518Srobert case BasicType::Float16:
349*12c85518Srobert ElementBitwidth = 16;
350*12c85518Srobert ScalarType = ScalarTypeKind::Float;
351*12c85518Srobert break;
352*12c85518Srobert case BasicType::Float32:
353*12c85518Srobert ElementBitwidth = 32;
354*12c85518Srobert ScalarType = ScalarTypeKind::Float;
355*12c85518Srobert break;
356*12c85518Srobert case BasicType::Float64:
357*12c85518Srobert ElementBitwidth = 64;
358*12c85518Srobert ScalarType = ScalarTypeKind::Float;
359*12c85518Srobert break;
360*12c85518Srobert default:
361*12c85518Srobert llvm_unreachable("Unhandled type code!");
362*12c85518Srobert }
363*12c85518Srobert assert(ElementBitwidth != 0 && "Bad element bitwidth!");
364*12c85518Srobert }
365*12c85518Srobert
366*12c85518Srobert std::optional<PrototypeDescriptor>
parsePrototypeDescriptor(llvm::StringRef PrototypeDescriptorStr)367*12c85518Srobert PrototypeDescriptor::parsePrototypeDescriptor(
368*12c85518Srobert llvm::StringRef PrototypeDescriptorStr) {
369*12c85518Srobert PrototypeDescriptor PD;
370*12c85518Srobert BaseTypeModifier PT = BaseTypeModifier::Invalid;
371*12c85518Srobert VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
372*12c85518Srobert
373*12c85518Srobert if (PrototypeDescriptorStr.empty())
374*12c85518Srobert return PD;
375*12c85518Srobert
376*12c85518Srobert // Handle base type modifier
377*12c85518Srobert auto PType = PrototypeDescriptorStr.back();
378*12c85518Srobert switch (PType) {
379*12c85518Srobert case 'e':
380*12c85518Srobert PT = BaseTypeModifier::Scalar;
381*12c85518Srobert break;
382*12c85518Srobert case 'v':
383*12c85518Srobert PT = BaseTypeModifier::Vector;
384*12c85518Srobert break;
385*12c85518Srobert case 'w':
386*12c85518Srobert PT = BaseTypeModifier::Vector;
387*12c85518Srobert VTM = VectorTypeModifier::Widening2XVector;
388*12c85518Srobert break;
389*12c85518Srobert case 'q':
390*12c85518Srobert PT = BaseTypeModifier::Vector;
391*12c85518Srobert VTM = VectorTypeModifier::Widening4XVector;
392*12c85518Srobert break;
393*12c85518Srobert case 'o':
394*12c85518Srobert PT = BaseTypeModifier::Vector;
395*12c85518Srobert VTM = VectorTypeModifier::Widening8XVector;
396*12c85518Srobert break;
397*12c85518Srobert case 'm':
398*12c85518Srobert PT = BaseTypeModifier::Vector;
399*12c85518Srobert VTM = VectorTypeModifier::MaskVector;
400*12c85518Srobert break;
401*12c85518Srobert case '0':
402*12c85518Srobert PT = BaseTypeModifier::Void;
403*12c85518Srobert break;
404*12c85518Srobert case 'z':
405*12c85518Srobert PT = BaseTypeModifier::SizeT;
406*12c85518Srobert break;
407*12c85518Srobert case 't':
408*12c85518Srobert PT = BaseTypeModifier::Ptrdiff;
409*12c85518Srobert break;
410*12c85518Srobert case 'u':
411*12c85518Srobert PT = BaseTypeModifier::UnsignedLong;
412*12c85518Srobert break;
413*12c85518Srobert case 'l':
414*12c85518Srobert PT = BaseTypeModifier::SignedLong;
415*12c85518Srobert break;
416*12c85518Srobert default:
417*12c85518Srobert llvm_unreachable("Illegal primitive type transformers!");
418*12c85518Srobert }
419*12c85518Srobert PD.PT = static_cast<uint8_t>(PT);
420*12c85518Srobert PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();
421*12c85518Srobert
422*12c85518Srobert // Compute the vector type transformers, it can only appear one time.
423*12c85518Srobert if (PrototypeDescriptorStr.startswith("(")) {
424*12c85518Srobert assert(VTM == VectorTypeModifier::NoModifier &&
425*12c85518Srobert "VectorTypeModifier should only have one modifier");
426*12c85518Srobert size_t Idx = PrototypeDescriptorStr.find(')');
427*12c85518Srobert assert(Idx != StringRef::npos);
428*12c85518Srobert StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
429*12c85518Srobert PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
430*12c85518Srobert assert(!PrototypeDescriptorStr.contains('(') &&
431*12c85518Srobert "Only allow one vector type modifier");
432*12c85518Srobert
433*12c85518Srobert auto ComplexTT = ComplexType.split(":");
434*12c85518Srobert if (ComplexTT.first == "Log2EEW") {
435*12c85518Srobert uint32_t Log2EEW;
436*12c85518Srobert if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
437*12c85518Srobert llvm_unreachable("Invalid Log2EEW value!");
438*12c85518Srobert return std::nullopt;
439*12c85518Srobert }
440*12c85518Srobert switch (Log2EEW) {
441*12c85518Srobert case 3:
442*12c85518Srobert VTM = VectorTypeModifier::Log2EEW3;
443*12c85518Srobert break;
444*12c85518Srobert case 4:
445*12c85518Srobert VTM = VectorTypeModifier::Log2EEW4;
446*12c85518Srobert break;
447*12c85518Srobert case 5:
448*12c85518Srobert VTM = VectorTypeModifier::Log2EEW5;
449*12c85518Srobert break;
450*12c85518Srobert case 6:
451*12c85518Srobert VTM = VectorTypeModifier::Log2EEW6;
452*12c85518Srobert break;
453*12c85518Srobert default:
454*12c85518Srobert llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
455*12c85518Srobert return std::nullopt;
456*12c85518Srobert }
457*12c85518Srobert } else if (ComplexTT.first == "FixedSEW") {
458*12c85518Srobert uint32_t NewSEW;
459*12c85518Srobert if (ComplexTT.second.getAsInteger(10, NewSEW)) {
460*12c85518Srobert llvm_unreachable("Invalid FixedSEW value!");
461*12c85518Srobert return std::nullopt;
462*12c85518Srobert }
463*12c85518Srobert switch (NewSEW) {
464*12c85518Srobert case 8:
465*12c85518Srobert VTM = VectorTypeModifier::FixedSEW8;
466*12c85518Srobert break;
467*12c85518Srobert case 16:
468*12c85518Srobert VTM = VectorTypeModifier::FixedSEW16;
469*12c85518Srobert break;
470*12c85518Srobert case 32:
471*12c85518Srobert VTM = VectorTypeModifier::FixedSEW32;
472*12c85518Srobert break;
473*12c85518Srobert case 64:
474*12c85518Srobert VTM = VectorTypeModifier::FixedSEW64;
475*12c85518Srobert break;
476*12c85518Srobert default:
477*12c85518Srobert llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
478*12c85518Srobert return std::nullopt;
479*12c85518Srobert }
480*12c85518Srobert } else if (ComplexTT.first == "LFixedLog2LMUL") {
481*12c85518Srobert int32_t Log2LMUL;
482*12c85518Srobert if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
483*12c85518Srobert llvm_unreachable("Invalid LFixedLog2LMUL value!");
484*12c85518Srobert return std::nullopt;
485*12c85518Srobert }
486*12c85518Srobert switch (Log2LMUL) {
487*12c85518Srobert case -3:
488*12c85518Srobert VTM = VectorTypeModifier::LFixedLog2LMULN3;
489*12c85518Srobert break;
490*12c85518Srobert case -2:
491*12c85518Srobert VTM = VectorTypeModifier::LFixedLog2LMULN2;
492*12c85518Srobert break;
493*12c85518Srobert case -1:
494*12c85518Srobert VTM = VectorTypeModifier::LFixedLog2LMULN1;
495*12c85518Srobert break;
496*12c85518Srobert case 0:
497*12c85518Srobert VTM = VectorTypeModifier::LFixedLog2LMUL0;
498*12c85518Srobert break;
499*12c85518Srobert case 1:
500*12c85518Srobert VTM = VectorTypeModifier::LFixedLog2LMUL1;
501*12c85518Srobert break;
502*12c85518Srobert case 2:
503*12c85518Srobert VTM = VectorTypeModifier::LFixedLog2LMUL2;
504*12c85518Srobert break;
505*12c85518Srobert case 3:
506*12c85518Srobert VTM = VectorTypeModifier::LFixedLog2LMUL3;
507*12c85518Srobert break;
508*12c85518Srobert default:
509*12c85518Srobert llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
510*12c85518Srobert return std::nullopt;
511*12c85518Srobert }
512*12c85518Srobert } else if (ComplexTT.first == "SFixedLog2LMUL") {
513*12c85518Srobert int32_t Log2LMUL;
514*12c85518Srobert if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
515*12c85518Srobert llvm_unreachable("Invalid SFixedLog2LMUL value!");
516*12c85518Srobert return std::nullopt;
517*12c85518Srobert }
518*12c85518Srobert switch (Log2LMUL) {
519*12c85518Srobert case -3:
520*12c85518Srobert VTM = VectorTypeModifier::SFixedLog2LMULN3;
521*12c85518Srobert break;
522*12c85518Srobert case -2:
523*12c85518Srobert VTM = VectorTypeModifier::SFixedLog2LMULN2;
524*12c85518Srobert break;
525*12c85518Srobert case -1:
526*12c85518Srobert VTM = VectorTypeModifier::SFixedLog2LMULN1;
527*12c85518Srobert break;
528*12c85518Srobert case 0:
529*12c85518Srobert VTM = VectorTypeModifier::SFixedLog2LMUL0;
530*12c85518Srobert break;
531*12c85518Srobert case 1:
532*12c85518Srobert VTM = VectorTypeModifier::SFixedLog2LMUL1;
533*12c85518Srobert break;
534*12c85518Srobert case 2:
535*12c85518Srobert VTM = VectorTypeModifier::SFixedLog2LMUL2;
536*12c85518Srobert break;
537*12c85518Srobert case 3:
538*12c85518Srobert VTM = VectorTypeModifier::SFixedLog2LMUL3;
539*12c85518Srobert break;
540*12c85518Srobert default:
541*12c85518Srobert llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
542*12c85518Srobert return std::nullopt;
543*12c85518Srobert }
544*12c85518Srobert
545*12c85518Srobert } else {
546*12c85518Srobert llvm_unreachable("Illegal complex type transformers!");
547*12c85518Srobert }
548*12c85518Srobert }
549*12c85518Srobert PD.VTM = static_cast<uint8_t>(VTM);
550*12c85518Srobert
551*12c85518Srobert // Compute the remain type transformers
552*12c85518Srobert TypeModifier TM = TypeModifier::NoModifier;
553*12c85518Srobert for (char I : PrototypeDescriptorStr) {
554*12c85518Srobert switch (I) {
555*12c85518Srobert case 'P':
556*12c85518Srobert if ((TM & TypeModifier::Const) == TypeModifier::Const)
557*12c85518Srobert llvm_unreachable("'P' transformer cannot be used after 'C'");
558*12c85518Srobert if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
559*12c85518Srobert llvm_unreachable("'P' transformer cannot be used twice");
560*12c85518Srobert TM |= TypeModifier::Pointer;
561*12c85518Srobert break;
562*12c85518Srobert case 'C':
563*12c85518Srobert TM |= TypeModifier::Const;
564*12c85518Srobert break;
565*12c85518Srobert case 'K':
566*12c85518Srobert TM |= TypeModifier::Immediate;
567*12c85518Srobert break;
568*12c85518Srobert case 'U':
569*12c85518Srobert TM |= TypeModifier::UnsignedInteger;
570*12c85518Srobert break;
571*12c85518Srobert case 'I':
572*12c85518Srobert TM |= TypeModifier::SignedInteger;
573*12c85518Srobert break;
574*12c85518Srobert case 'F':
575*12c85518Srobert TM |= TypeModifier::Float;
576*12c85518Srobert break;
577*12c85518Srobert case 'S':
578*12c85518Srobert TM |= TypeModifier::LMUL1;
579*12c85518Srobert break;
580*12c85518Srobert default:
581*12c85518Srobert llvm_unreachable("Illegal non-primitive type transformer!");
582*12c85518Srobert }
583*12c85518Srobert }
584*12c85518Srobert PD.TM = static_cast<uint8_t>(TM);
585*12c85518Srobert
586*12c85518Srobert return PD;
587*12c85518Srobert }
588*12c85518Srobert
applyModifier(const PrototypeDescriptor & Transformer)589*12c85518Srobert void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
590*12c85518Srobert // Handle primitive type transformer
591*12c85518Srobert switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
592*12c85518Srobert case BaseTypeModifier::Scalar:
593*12c85518Srobert Scale = 0;
594*12c85518Srobert break;
595*12c85518Srobert case BaseTypeModifier::Vector:
596*12c85518Srobert Scale = LMUL.getScale(ElementBitwidth);
597*12c85518Srobert break;
598*12c85518Srobert case BaseTypeModifier::Void:
599*12c85518Srobert ScalarType = ScalarTypeKind::Void;
600*12c85518Srobert break;
601*12c85518Srobert case BaseTypeModifier::SizeT:
602*12c85518Srobert ScalarType = ScalarTypeKind::Size_t;
603*12c85518Srobert break;
604*12c85518Srobert case BaseTypeModifier::Ptrdiff:
605*12c85518Srobert ScalarType = ScalarTypeKind::Ptrdiff_t;
606*12c85518Srobert break;
607*12c85518Srobert case BaseTypeModifier::UnsignedLong:
608*12c85518Srobert ScalarType = ScalarTypeKind::UnsignedLong;
609*12c85518Srobert break;
610*12c85518Srobert case BaseTypeModifier::SignedLong:
611*12c85518Srobert ScalarType = ScalarTypeKind::SignedLong;
612*12c85518Srobert break;
613*12c85518Srobert case BaseTypeModifier::Invalid:
614*12c85518Srobert ScalarType = ScalarTypeKind::Invalid;
615*12c85518Srobert return;
616*12c85518Srobert }
617*12c85518Srobert
618*12c85518Srobert switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
619*12c85518Srobert case VectorTypeModifier::Widening2XVector:
620*12c85518Srobert ElementBitwidth *= 2;
621*12c85518Srobert LMUL.MulLog2LMUL(1);
622*12c85518Srobert Scale = LMUL.getScale(ElementBitwidth);
623*12c85518Srobert break;
624*12c85518Srobert case VectorTypeModifier::Widening4XVector:
625*12c85518Srobert ElementBitwidth *= 4;
626*12c85518Srobert LMUL.MulLog2LMUL(2);
627*12c85518Srobert Scale = LMUL.getScale(ElementBitwidth);
628*12c85518Srobert break;
629*12c85518Srobert case VectorTypeModifier::Widening8XVector:
630*12c85518Srobert ElementBitwidth *= 8;
631*12c85518Srobert LMUL.MulLog2LMUL(3);
632*12c85518Srobert Scale = LMUL.getScale(ElementBitwidth);
633*12c85518Srobert break;
634*12c85518Srobert case VectorTypeModifier::MaskVector:
635*12c85518Srobert ScalarType = ScalarTypeKind::Boolean;
636*12c85518Srobert Scale = LMUL.getScale(ElementBitwidth);
637*12c85518Srobert ElementBitwidth = 1;
638*12c85518Srobert break;
639*12c85518Srobert case VectorTypeModifier::Log2EEW3:
640*12c85518Srobert applyLog2EEW(3);
641*12c85518Srobert break;
642*12c85518Srobert case VectorTypeModifier::Log2EEW4:
643*12c85518Srobert applyLog2EEW(4);
644*12c85518Srobert break;
645*12c85518Srobert case VectorTypeModifier::Log2EEW5:
646*12c85518Srobert applyLog2EEW(5);
647*12c85518Srobert break;
648*12c85518Srobert case VectorTypeModifier::Log2EEW6:
649*12c85518Srobert applyLog2EEW(6);
650*12c85518Srobert break;
651*12c85518Srobert case VectorTypeModifier::FixedSEW8:
652*12c85518Srobert applyFixedSEW(8);
653*12c85518Srobert break;
654*12c85518Srobert case VectorTypeModifier::FixedSEW16:
655*12c85518Srobert applyFixedSEW(16);
656*12c85518Srobert break;
657*12c85518Srobert case VectorTypeModifier::FixedSEW32:
658*12c85518Srobert applyFixedSEW(32);
659*12c85518Srobert break;
660*12c85518Srobert case VectorTypeModifier::FixedSEW64:
661*12c85518Srobert applyFixedSEW(64);
662*12c85518Srobert break;
663*12c85518Srobert case VectorTypeModifier::LFixedLog2LMULN3:
664*12c85518Srobert applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
665*12c85518Srobert break;
666*12c85518Srobert case VectorTypeModifier::LFixedLog2LMULN2:
667*12c85518Srobert applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
668*12c85518Srobert break;
669*12c85518Srobert case VectorTypeModifier::LFixedLog2LMULN1:
670*12c85518Srobert applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
671*12c85518Srobert break;
672*12c85518Srobert case VectorTypeModifier::LFixedLog2LMUL0:
673*12c85518Srobert applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
674*12c85518Srobert break;
675*12c85518Srobert case VectorTypeModifier::LFixedLog2LMUL1:
676*12c85518Srobert applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
677*12c85518Srobert break;
678*12c85518Srobert case VectorTypeModifier::LFixedLog2LMUL2:
679*12c85518Srobert applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
680*12c85518Srobert break;
681*12c85518Srobert case VectorTypeModifier::LFixedLog2LMUL3:
682*12c85518Srobert applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
683*12c85518Srobert break;
684*12c85518Srobert case VectorTypeModifier::SFixedLog2LMULN3:
685*12c85518Srobert applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
686*12c85518Srobert break;
687*12c85518Srobert case VectorTypeModifier::SFixedLog2LMULN2:
688*12c85518Srobert applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
689*12c85518Srobert break;
690*12c85518Srobert case VectorTypeModifier::SFixedLog2LMULN1:
691*12c85518Srobert applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
692*12c85518Srobert break;
693*12c85518Srobert case VectorTypeModifier::SFixedLog2LMUL0:
694*12c85518Srobert applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
695*12c85518Srobert break;
696*12c85518Srobert case VectorTypeModifier::SFixedLog2LMUL1:
697*12c85518Srobert applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
698*12c85518Srobert break;
699*12c85518Srobert case VectorTypeModifier::SFixedLog2LMUL2:
700*12c85518Srobert applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
701*12c85518Srobert break;
702*12c85518Srobert case VectorTypeModifier::SFixedLog2LMUL3:
703*12c85518Srobert applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
704*12c85518Srobert break;
705*12c85518Srobert case VectorTypeModifier::NoModifier:
706*12c85518Srobert break;
707*12c85518Srobert }
708*12c85518Srobert
709*12c85518Srobert for (unsigned TypeModifierMaskShift = 0;
710*12c85518Srobert TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
711*12c85518Srobert ++TypeModifierMaskShift) {
712*12c85518Srobert unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
713*12c85518Srobert if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
714*12c85518Srobert TypeModifierMask)
715*12c85518Srobert continue;
716*12c85518Srobert switch (static_cast<TypeModifier>(TypeModifierMask)) {
717*12c85518Srobert case TypeModifier::Pointer:
718*12c85518Srobert IsPointer = true;
719*12c85518Srobert break;
720*12c85518Srobert case TypeModifier::Const:
721*12c85518Srobert IsConstant = true;
722*12c85518Srobert break;
723*12c85518Srobert case TypeModifier::Immediate:
724*12c85518Srobert IsImmediate = true;
725*12c85518Srobert IsConstant = true;
726*12c85518Srobert break;
727*12c85518Srobert case TypeModifier::UnsignedInteger:
728*12c85518Srobert ScalarType = ScalarTypeKind::UnsignedInteger;
729*12c85518Srobert break;
730*12c85518Srobert case TypeModifier::SignedInteger:
731*12c85518Srobert ScalarType = ScalarTypeKind::SignedInteger;
732*12c85518Srobert break;
733*12c85518Srobert case TypeModifier::Float:
734*12c85518Srobert ScalarType = ScalarTypeKind::Float;
735*12c85518Srobert break;
736*12c85518Srobert case TypeModifier::LMUL1:
737*12c85518Srobert LMUL = LMULType(0);
738*12c85518Srobert // Update ElementBitwidth need to update Scale too.
739*12c85518Srobert Scale = LMUL.getScale(ElementBitwidth);
740*12c85518Srobert break;
741*12c85518Srobert default:
742*12c85518Srobert llvm_unreachable("Unknown type modifier mask!");
743*12c85518Srobert }
744*12c85518Srobert }
745*12c85518Srobert }
746*12c85518Srobert
applyLog2EEW(unsigned Log2EEW)747*12c85518Srobert void RVVType::applyLog2EEW(unsigned Log2EEW) {
748*12c85518Srobert // update new elmul = (eew/sew) * lmul
749*12c85518Srobert LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
750*12c85518Srobert // update new eew
751*12c85518Srobert ElementBitwidth = 1 << Log2EEW;
752*12c85518Srobert ScalarType = ScalarTypeKind::SignedInteger;
753*12c85518Srobert Scale = LMUL.getScale(ElementBitwidth);
754*12c85518Srobert }
755*12c85518Srobert
applyFixedSEW(unsigned NewSEW)756*12c85518Srobert void RVVType::applyFixedSEW(unsigned NewSEW) {
757*12c85518Srobert // Set invalid type if src and dst SEW are same.
758*12c85518Srobert if (ElementBitwidth == NewSEW) {
759*12c85518Srobert ScalarType = ScalarTypeKind::Invalid;
760*12c85518Srobert return;
761*12c85518Srobert }
762*12c85518Srobert // Update new SEW
763*12c85518Srobert ElementBitwidth = NewSEW;
764*12c85518Srobert Scale = LMUL.getScale(ElementBitwidth);
765*12c85518Srobert }
766*12c85518Srobert
applyFixedLog2LMUL(int Log2LMUL,enum FixedLMULType Type)767*12c85518Srobert void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
768*12c85518Srobert switch (Type) {
769*12c85518Srobert case FixedLMULType::LargerThan:
770*12c85518Srobert if (Log2LMUL < LMUL.Log2LMUL) {
771*12c85518Srobert ScalarType = ScalarTypeKind::Invalid;
772*12c85518Srobert return;
773*12c85518Srobert }
774*12c85518Srobert break;
775*12c85518Srobert case FixedLMULType::SmallerThan:
776*12c85518Srobert if (Log2LMUL > LMUL.Log2LMUL) {
777*12c85518Srobert ScalarType = ScalarTypeKind::Invalid;
778*12c85518Srobert return;
779*12c85518Srobert }
780*12c85518Srobert break;
781*12c85518Srobert }
782*12c85518Srobert
783*12c85518Srobert // Update new LMUL
784*12c85518Srobert LMUL = LMULType(Log2LMUL);
785*12c85518Srobert Scale = LMUL.getScale(ElementBitwidth);
786*12c85518Srobert }
787*12c85518Srobert
788*12c85518Srobert std::optional<RVVTypes>
computeTypes(BasicType BT,int Log2LMUL,unsigned NF,ArrayRef<PrototypeDescriptor> Prototype)789*12c85518Srobert RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
790*12c85518Srobert ArrayRef<PrototypeDescriptor> Prototype) {
791*12c85518Srobert // LMUL x NF must be less than or equal to 8.
792*12c85518Srobert if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
793*12c85518Srobert return std::nullopt;
794*12c85518Srobert
795*12c85518Srobert RVVTypes Types;
796*12c85518Srobert for (const PrototypeDescriptor &Proto : Prototype) {
797*12c85518Srobert auto T = computeType(BT, Log2LMUL, Proto);
798*12c85518Srobert if (!T)
799*12c85518Srobert return std::nullopt;
800*12c85518Srobert // Record legal type index
801*12c85518Srobert Types.push_back(*T);
802*12c85518Srobert }
803*12c85518Srobert return Types;
804*12c85518Srobert }
805*12c85518Srobert
806*12c85518Srobert // Compute the hash value of RVVType, used for cache the result of computeType.
computeRVVTypeHashValue(BasicType BT,int Log2LMUL,PrototypeDescriptor Proto)807*12c85518Srobert static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
808*12c85518Srobert PrototypeDescriptor Proto) {
809*12c85518Srobert // Layout of hash value:
810*12c85518Srobert // 0 8 16 24 32 40
811*12c85518Srobert // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM |
812*12c85518Srobert assert(Log2LMUL >= -3 && Log2LMUL <= 3);
813*12c85518Srobert return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 |
814*12c85518Srobert ((uint64_t)(Proto.PT & 0xff) << 16) |
815*12c85518Srobert ((uint64_t)(Proto.TM & 0xff) << 24) |
816*12c85518Srobert ((uint64_t)(Proto.VTM & 0xff) << 32);
817*12c85518Srobert }
818*12c85518Srobert
computeType(BasicType BT,int Log2LMUL,PrototypeDescriptor Proto)819*12c85518Srobert std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL,
820*12c85518Srobert PrototypeDescriptor Proto) {
821*12c85518Srobert uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
822*12c85518Srobert // Search first
823*12c85518Srobert auto It = LegalTypes.find(Idx);
824*12c85518Srobert if (It != LegalTypes.end())
825*12c85518Srobert return &(It->second);
826*12c85518Srobert
827*12c85518Srobert if (IllegalTypes.count(Idx))
828*12c85518Srobert return std::nullopt;
829*12c85518Srobert
830*12c85518Srobert // Compute type and record the result.
831*12c85518Srobert RVVType T(BT, Log2LMUL, Proto);
832*12c85518Srobert if (T.isValid()) {
833*12c85518Srobert // Record legal type index and value.
834*12c85518Srobert std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool>
835*12c85518Srobert InsertResult = LegalTypes.insert({Idx, T});
836*12c85518Srobert return &(InsertResult.first->second);
837*12c85518Srobert }
838*12c85518Srobert // Record illegal type index.
839*12c85518Srobert IllegalTypes.insert(Idx);
840*12c85518Srobert return std::nullopt;
841*12c85518Srobert }
842*12c85518Srobert
843*12c85518Srobert //===----------------------------------------------------------------------===//
844*12c85518Srobert // RVVIntrinsic implementation
845*12c85518Srobert //===----------------------------------------------------------------------===//
RVVIntrinsic(StringRef NewName,StringRef Suffix,StringRef NewOverloadedName,StringRef OverloadedSuffix,StringRef IRName,bool IsMasked,bool HasMaskedOffOperand,bool HasVL,PolicyScheme Scheme,bool SupportOverloading,bool HasBuiltinAlias,StringRef ManualCodegen,const RVVTypes & OutInTypes,const std::vector<int64_t> & NewIntrinsicTypes,const std::vector<StringRef> & RequiredFeatures,unsigned NF,Policy NewPolicyAttrs)846*12c85518Srobert RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix,
847*12c85518Srobert StringRef NewOverloadedName,
848*12c85518Srobert StringRef OverloadedSuffix, StringRef IRName,
849*12c85518Srobert bool IsMasked, bool HasMaskedOffOperand, bool HasVL,
850*12c85518Srobert PolicyScheme Scheme, bool SupportOverloading,
851*12c85518Srobert bool HasBuiltinAlias, StringRef ManualCodegen,
852*12c85518Srobert const RVVTypes &OutInTypes,
853*12c85518Srobert const std::vector<int64_t> &NewIntrinsicTypes,
854*12c85518Srobert const std::vector<StringRef> &RequiredFeatures,
855*12c85518Srobert unsigned NF, Policy NewPolicyAttrs)
856*12c85518Srobert : IRName(IRName), IsMasked(IsMasked),
857*12c85518Srobert HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme),
858*12c85518Srobert SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias),
859*12c85518Srobert ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) {
860*12c85518Srobert
861*12c85518Srobert // Init BuiltinName, Name and OverloadedName
862*12c85518Srobert BuiltinName = NewName.str();
863*12c85518Srobert Name = BuiltinName;
864*12c85518Srobert if (NewOverloadedName.empty())
865*12c85518Srobert OverloadedName = NewName.split("_").first.str();
866*12c85518Srobert else
867*12c85518Srobert OverloadedName = NewOverloadedName.str();
868*12c85518Srobert if (!Suffix.empty())
869*12c85518Srobert Name += "_" + Suffix.str();
870*12c85518Srobert if (!OverloadedSuffix.empty())
871*12c85518Srobert OverloadedName += "_" + OverloadedSuffix.str();
872*12c85518Srobert
873*12c85518Srobert updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName,
874*12c85518Srobert PolicyAttrs);
875*12c85518Srobert
876*12c85518Srobert // Init OutputType and InputTypes
877*12c85518Srobert OutputType = OutInTypes[0];
878*12c85518Srobert InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
879*12c85518Srobert
880*12c85518Srobert // IntrinsicTypes is unmasked TA version index. Need to update it
881*12c85518Srobert // if there is merge operand (It is always in first operand).
882*12c85518Srobert IntrinsicTypes = NewIntrinsicTypes;
883*12c85518Srobert if ((IsMasked && hasMaskedOffOperand()) ||
884*12c85518Srobert (!IsMasked && hasPassthruOperand())) {
885*12c85518Srobert for (auto &I : IntrinsicTypes) {
886*12c85518Srobert if (I >= 0)
887*12c85518Srobert I += NF;
888*12c85518Srobert }
889*12c85518Srobert }
890*12c85518Srobert }
891*12c85518Srobert
getBuiltinTypeStr() const892*12c85518Srobert std::string RVVIntrinsic::getBuiltinTypeStr() const {
893*12c85518Srobert std::string S;
894*12c85518Srobert S += OutputType->getBuiltinStr();
895*12c85518Srobert for (const auto &T : InputTypes) {
896*12c85518Srobert S += T->getBuiltinStr();
897*12c85518Srobert }
898*12c85518Srobert return S;
899*12c85518Srobert }
900*12c85518Srobert
getSuffixStr(RVVTypeCache & TypeCache,BasicType Type,int Log2LMUL,llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors)901*12c85518Srobert std::string RVVIntrinsic::getSuffixStr(
902*12c85518Srobert RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL,
903*12c85518Srobert llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
904*12c85518Srobert SmallVector<std::string> SuffixStrs;
905*12c85518Srobert for (auto PD : PrototypeDescriptors) {
906*12c85518Srobert auto T = TypeCache.computeType(Type, Log2LMUL, PD);
907*12c85518Srobert SuffixStrs.push_back((*T)->getShortStr());
908*12c85518Srobert }
909*12c85518Srobert return join(SuffixStrs, "_");
910*12c85518Srobert }
911*12c85518Srobert
computeBuiltinTypes(llvm::ArrayRef<PrototypeDescriptor> Prototype,bool IsMasked,bool HasMaskedOffOperand,bool HasVL,unsigned NF,PolicyScheme DefaultScheme,Policy PolicyAttrs)912*12c85518Srobert llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes(
913*12c85518Srobert llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked,
914*12c85518Srobert bool HasMaskedOffOperand, bool HasVL, unsigned NF,
915*12c85518Srobert PolicyScheme DefaultScheme, Policy PolicyAttrs) {
916*12c85518Srobert SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(),
917*12c85518Srobert Prototype.end());
918*12c85518Srobert bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand;
919*12c85518Srobert if (IsMasked) {
920*12c85518Srobert // If HasMaskedOffOperand, insert result type as first input operand if
921*12c85518Srobert // need.
922*12c85518Srobert if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) {
923*12c85518Srobert if (NF == 1) {
924*12c85518Srobert NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]);
925*12c85518Srobert } else if (NF > 1) {
926*12c85518Srobert // Convert
927*12c85518Srobert // (void, op0 address, op1 address, ...)
928*12c85518Srobert // to
929*12c85518Srobert // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
930*12c85518Srobert PrototypeDescriptor MaskoffType = NewPrototype[1];
931*12c85518Srobert MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
932*12c85518Srobert NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
933*12c85518Srobert }
934*12c85518Srobert }
935*12c85518Srobert if (HasMaskedOffOperand && NF > 1) {
936*12c85518Srobert // Convert
937*12c85518Srobert // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
938*12c85518Srobert // to
939*12c85518Srobert // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
940*12c85518Srobert // ...)
941*12c85518Srobert NewPrototype.insert(NewPrototype.begin() + NF + 1,
942*12c85518Srobert PrototypeDescriptor::Mask);
943*12c85518Srobert } else {
944*12c85518Srobert // If IsMasked, insert PrototypeDescriptor:Mask as first input operand.
945*12c85518Srobert NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask);
946*12c85518Srobert }
947*12c85518Srobert } else {
948*12c85518Srobert if (NF == 1) {
949*12c85518Srobert if (PolicyAttrs.isTUPolicy() && HasPassthruOp)
950*12c85518Srobert NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]);
951*12c85518Srobert } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) {
952*12c85518Srobert // NF > 1 cases for segment load operations.
953*12c85518Srobert // Convert
954*12c85518Srobert // (void, op0 address, op1 address, ...)
955*12c85518Srobert // to
956*12c85518Srobert // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
957*12c85518Srobert PrototypeDescriptor MaskoffType = Prototype[1];
958*12c85518Srobert MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
959*12c85518Srobert NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
960*12c85518Srobert }
961*12c85518Srobert }
962*12c85518Srobert
963*12c85518Srobert // If HasVL, append PrototypeDescriptor:VL to last operand
964*12c85518Srobert if (HasVL)
965*12c85518Srobert NewPrototype.push_back(PrototypeDescriptor::VL);
966*12c85518Srobert return NewPrototype;
967*12c85518Srobert }
968*12c85518Srobert
getSupportedUnMaskedPolicies()969*12c85518Srobert llvm::SmallVector<Policy> RVVIntrinsic::getSupportedUnMaskedPolicies() {
970*12c85518Srobert return {Policy(Policy::PolicyType::Undisturbed)}; // TU
971*12c85518Srobert }
972*12c85518Srobert
973*12c85518Srobert llvm::SmallVector<Policy>
getSupportedMaskedPolicies(bool HasTailPolicy,bool HasMaskPolicy)974*12c85518Srobert RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy,
975*12c85518Srobert bool HasMaskPolicy) {
976*12c85518Srobert if (HasTailPolicy && HasMaskPolicy)
977*12c85518Srobert return {Policy(Policy::PolicyType::Undisturbed,
978*12c85518Srobert Policy::PolicyType::Agnostic), // TUM
979*12c85518Srobert Policy(Policy::PolicyType::Undisturbed,
980*12c85518Srobert Policy::PolicyType::Undisturbed), // TUMU
981*12c85518Srobert Policy(Policy::PolicyType::Agnostic,
982*12c85518Srobert Policy::PolicyType::Undisturbed)}; // MU
983*12c85518Srobert if (HasTailPolicy && !HasMaskPolicy)
984*12c85518Srobert return {Policy(Policy::PolicyType::Undisturbed,
985*12c85518Srobert Policy::PolicyType::Agnostic)}; // TU
986*12c85518Srobert if (!HasTailPolicy && HasMaskPolicy)
987*12c85518Srobert return {Policy(Policy::PolicyType::Agnostic,
988*12c85518Srobert Policy::PolicyType::Undisturbed)}; // MU
989*12c85518Srobert llvm_unreachable("An RVV instruction should not be without both tail policy "
990*12c85518Srobert "and mask policy");
991*12c85518Srobert }
992*12c85518Srobert
updateNamesAndPolicy(bool IsMasked,bool HasPolicy,std::string & Name,std::string & BuiltinName,std::string & OverloadedName,Policy & PolicyAttrs)993*12c85518Srobert void RVVIntrinsic::updateNamesAndPolicy(bool IsMasked, bool HasPolicy,
994*12c85518Srobert std::string &Name,
995*12c85518Srobert std::string &BuiltinName,
996*12c85518Srobert std::string &OverloadedName,
997*12c85518Srobert Policy &PolicyAttrs) {
998*12c85518Srobert
999*12c85518Srobert auto appendPolicySuffix = [&](const std::string &suffix) {
1000*12c85518Srobert Name += suffix;
1001*12c85518Srobert BuiltinName += suffix;
1002*12c85518Srobert OverloadedName += suffix;
1003*12c85518Srobert };
1004*12c85518Srobert
1005*12c85518Srobert // This follows the naming guideline under riscv-c-api-doc to add the
1006*12c85518Srobert // `__riscv_` suffix for all RVV intrinsics.
1007*12c85518Srobert Name = "__riscv_" + Name;
1008*12c85518Srobert OverloadedName = "__riscv_" + OverloadedName;
1009*12c85518Srobert
1010*12c85518Srobert if (IsMasked) {
1011*12c85518Srobert if (PolicyAttrs.isTUMUPolicy())
1012*12c85518Srobert appendPolicySuffix("_tumu");
1013*12c85518Srobert else if (PolicyAttrs.isTUMAPolicy())
1014*12c85518Srobert appendPolicySuffix("_tum");
1015*12c85518Srobert else if (PolicyAttrs.isTAMUPolicy())
1016*12c85518Srobert appendPolicySuffix("_mu");
1017*12c85518Srobert else if (PolicyAttrs.isTAMAPolicy()) {
1018*12c85518Srobert Name += "_m";
1019*12c85518Srobert if (HasPolicy)
1020*12c85518Srobert BuiltinName += "_tama";
1021*12c85518Srobert else
1022*12c85518Srobert BuiltinName += "_m";
1023*12c85518Srobert } else
1024*12c85518Srobert llvm_unreachable("Unhandled policy condition");
1025*12c85518Srobert } else {
1026*12c85518Srobert if (PolicyAttrs.isTUPolicy())
1027*12c85518Srobert appendPolicySuffix("_tu");
1028*12c85518Srobert else if (PolicyAttrs.isTAPolicy()) {
1029*12c85518Srobert if (HasPolicy)
1030*12c85518Srobert BuiltinName += "_ta";
1031*12c85518Srobert } else
1032*12c85518Srobert llvm_unreachable("Unhandled policy condition");
1033*12c85518Srobert }
1034*12c85518Srobert }
1035*12c85518Srobert
parsePrototypes(StringRef Prototypes)1036*12c85518Srobert SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
1037*12c85518Srobert SmallVector<PrototypeDescriptor> PrototypeDescriptors;
1038*12c85518Srobert const StringRef Primaries("evwqom0ztul");
1039*12c85518Srobert while (!Prototypes.empty()) {
1040*12c85518Srobert size_t Idx = 0;
1041*12c85518Srobert // Skip over complex prototype because it could contain primitive type
1042*12c85518Srobert // character.
1043*12c85518Srobert if (Prototypes[0] == '(')
1044*12c85518Srobert Idx = Prototypes.find_first_of(')');
1045*12c85518Srobert Idx = Prototypes.find_first_of(Primaries, Idx);
1046*12c85518Srobert assert(Idx != StringRef::npos);
1047*12c85518Srobert auto PD = PrototypeDescriptor::parsePrototypeDescriptor(
1048*12c85518Srobert Prototypes.slice(0, Idx + 1));
1049*12c85518Srobert if (!PD)
1050*12c85518Srobert llvm_unreachable("Error during parsing prototype.");
1051*12c85518Srobert PrototypeDescriptors.push_back(*PD);
1052*12c85518Srobert Prototypes = Prototypes.drop_front(Idx + 1);
1053*12c85518Srobert }
1054*12c85518Srobert return PrototypeDescriptors;
1055*12c85518Srobert }
1056*12c85518Srobert
operator <<(raw_ostream & OS,const RVVIntrinsicRecord & Record)1057*12c85518Srobert raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) {
1058*12c85518Srobert OS << "{";
1059*12c85518Srobert OS << "\"" << Record.Name << "\",";
1060*12c85518Srobert if (Record.OverloadedName == nullptr ||
1061*12c85518Srobert StringRef(Record.OverloadedName).empty())
1062*12c85518Srobert OS << "nullptr,";
1063*12c85518Srobert else
1064*12c85518Srobert OS << "\"" << Record.OverloadedName << "\",";
1065*12c85518Srobert OS << Record.PrototypeIndex << ",";
1066*12c85518Srobert OS << Record.SuffixIndex << ",";
1067*12c85518Srobert OS << Record.OverloadedSuffixIndex << ",";
1068*12c85518Srobert OS << (int)Record.PrototypeLength << ",";
1069*12c85518Srobert OS << (int)Record.SuffixLength << ",";
1070*12c85518Srobert OS << (int)Record.OverloadedSuffixSize << ",";
1071*12c85518Srobert OS << (int)Record.RequiredExtensions << ",";
1072*12c85518Srobert OS << (int)Record.TypeRangeMask << ",";
1073*12c85518Srobert OS << (int)Record.Log2LMULMask << ",";
1074*12c85518Srobert OS << (int)Record.NF << ",";
1075*12c85518Srobert OS << (int)Record.HasMasked << ",";
1076*12c85518Srobert OS << (int)Record.HasVL << ",";
1077*12c85518Srobert OS << (int)Record.HasMaskedOffOperand << ",";
1078*12c85518Srobert OS << (int)Record.HasTailPolicy << ",";
1079*12c85518Srobert OS << (int)Record.HasMaskPolicy << ",";
1080*12c85518Srobert OS << (int)Record.UnMaskedPolicyScheme << ",";
1081*12c85518Srobert OS << (int)Record.MaskedPolicyScheme << ",";
1082*12c85518Srobert OS << "},\n";
1083*12c85518Srobert return OS;
1084*12c85518Srobert }
1085*12c85518Srobert
1086*12c85518Srobert } // end namespace RISCV
1087*12c85518Srobert } // end namespace clang
1088