1 //===- RISCVVEmitter.cpp - Generate riscv_vector.h for use with clang -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This tablegen backend is responsible for emitting riscv_vector.h which 10 // includes a declaration and definition of each intrinsic functions specified 11 // in https://github.com/riscv/rvv-intrinsic-doc. 12 // 13 // See also the documentation in include/clang/Basic/riscv_vector.td. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/ADT/SmallSet.h" 19 #include "llvm/ADT/StringExtras.h" 20 #include "llvm/ADT/StringMap.h" 21 #include "llvm/ADT/StringSet.h" 22 #include "llvm/ADT/Twine.h" 23 #include "llvm/TableGen/Error.h" 24 #include "llvm/TableGen/Record.h" 25 #include <numeric> 26 27 using namespace llvm; 28 using BasicType = char; 29 using VScaleVal = Optional<unsigned>; 30 31 namespace { 32 33 // Exponential LMUL 34 struct LMULType { 35 int Log2LMUL; 36 LMULType(int Log2LMUL); 37 // Return the C/C++ string representation of LMUL 38 std::string str() const; 39 Optional<unsigned> getScale(unsigned ElementBitwidth) const; 40 void MulLog2LMUL(int Log2LMUL); 41 LMULType &operator*=(uint32_t RHS); 42 }; 43 44 // This class is compact representation of a valid and invalid RVVType. 45 class RVVType { 46 enum ScalarTypeKind : uint32_t { 47 Void, 48 Size_t, 49 Ptrdiff_t, 50 UnsignedLong, 51 SignedLong, 52 Boolean, 53 SignedInteger, 54 UnsignedInteger, 55 Float, 56 Invalid, 57 }; 58 BasicType BT; 59 ScalarTypeKind ScalarType = Invalid; 60 LMULType LMUL; 61 bool IsPointer = false; 62 // IsConstant indices are "int", but have the constant expression. 63 bool IsImmediate = false; 64 // Const qualifier for pointer to const object or object of const type. 65 bool IsConstant = false; 66 unsigned ElementBitwidth = 0; 67 VScaleVal Scale = 0; 68 bool Valid; 69 70 std::string BuiltinStr; 71 std::string ClangBuiltinStr; 72 std::string Str; 73 std::string ShortStr; 74 75 public: 76 RVVType() : RVVType(BasicType(), 0, StringRef()) {} 77 RVVType(BasicType BT, int Log2LMUL, StringRef prototype); 78 79 // Return the string representation of a type, which is an encoded string for 80 // passing to the BUILTIN() macro in Builtins.def. 81 const std::string &getBuiltinStr() const { return BuiltinStr; } 82 83 // Return the clang builtin type for RVV vector type which are used in the 84 // riscv_vector.h header file. 85 const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; } 86 87 // Return the C/C++ string representation of a type for use in the 88 // riscv_vector.h header file. 89 const std::string &getTypeStr() const { return Str; } 90 91 // Return the short name of a type for C/C++ name suffix. 92 const std::string &getShortStr() { 93 // Not all types are used in short name, so compute the short name by 94 // demanded. 95 if (ShortStr.empty()) 96 initShortStr(); 97 return ShortStr; 98 } 99 100 bool isValid() const { return Valid; } 101 bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; } 102 bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; } 103 bool isFloat() const { return ScalarType == ScalarTypeKind::Float; } 104 bool isSignedInteger() const { 105 return ScalarType == ScalarTypeKind::SignedInteger; 106 } 107 bool isFloatVector(unsigned Width) const { 108 return isVector() && isFloat() && ElementBitwidth == Width; 109 } 110 bool isFloat(unsigned Width) const { 111 return isFloat() && ElementBitwidth == Width; 112 } 113 114 private: 115 // Verify RVV vector type and set Valid. 116 bool verifyType() const; 117 118 // Creates a type based on basic types of TypeRange 119 void applyBasicType(); 120 121 // Applies a prototype modifier to the current type. The result maybe an 122 // invalid type. 123 void applyModifier(StringRef prototype); 124 125 // Compute and record a string for legal type. 126 void initBuiltinStr(); 127 // Compute and record a builtin RVV vector type string. 128 void initClangBuiltinStr(); 129 // Compute and record a type string for used in the header. 130 void initTypeStr(); 131 // Compute and record a short name of a type for C/C++ name suffix. 132 void initShortStr(); 133 }; 134 135 using RVVTypePtr = RVVType *; 136 using RVVTypes = std::vector<RVVTypePtr>; 137 138 enum RISCVExtension : uint8_t { 139 Basic = 0, 140 F = 1 << 1, 141 D = 1 << 2, 142 Zfh = 1 << 3, 143 Zvlsseg = 1 << 4, 144 }; 145 146 // TODO refactor RVVIntrinsic class design after support all intrinsic 147 // combination. This represents an instantiation of an intrinsic with a 148 // particular type and prototype 149 class RVVIntrinsic { 150 151 private: 152 std::string BuiltinName; // Builtin name 153 std::string Name; // C intrinsic name. 154 std::string MangledName; 155 std::string IRName; 156 bool IsMask; 157 bool HasVL; 158 bool HasPolicy; 159 bool HasNoMaskedOverloaded; 160 bool HasAutoDef; // There is automiatic definition in header 161 std::string ManualCodegen; 162 RVVTypePtr OutputType; // Builtin output type 163 RVVTypes InputTypes; // Builtin input types 164 // The types we use to obtain the specific LLVM intrinsic. They are index of 165 // InputTypes. -1 means the return type. 166 std::vector<int64_t> IntrinsicTypes; 167 uint8_t RISCVExtensions = 0; 168 unsigned NF = 1; 169 170 public: 171 RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName, 172 StringRef MangledSuffix, StringRef IRName, bool IsMask, 173 bool HasMaskedOffOperand, bool HasVL, bool HasPolicy, 174 bool HasNoMaskedOverloaded, bool HasAutoDef, 175 StringRef ManualCodegen, const RVVTypes &Types, 176 const std::vector<int64_t> &IntrinsicTypes, 177 StringRef RequiredExtension, unsigned NF); 178 ~RVVIntrinsic() = default; 179 180 StringRef getBuiltinName() const { return BuiltinName; } 181 StringRef getName() const { return Name; } 182 StringRef getMangledName() const { return MangledName; } 183 bool hasVL() const { return HasVL; } 184 bool hasPolicy() const { return HasPolicy; } 185 bool hasNoMaskedOverloaded() const { return HasNoMaskedOverloaded; } 186 bool hasManualCodegen() const { return !ManualCodegen.empty(); } 187 bool hasAutoDef() const { return HasAutoDef; } 188 bool isMask() const { return IsMask; } 189 StringRef getIRName() const { return IRName; } 190 StringRef getManualCodegen() const { return ManualCodegen; } 191 uint8_t getRISCVExtensions() const { return RISCVExtensions; } 192 unsigned getNF() const { return NF; } 193 const std::vector<int64_t> &getIntrinsicTypes() const { 194 return IntrinsicTypes; 195 } 196 197 // Return the type string for a BUILTIN() macro in Builtins.def. 198 std::string getBuiltinTypeStr() const; 199 200 // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should 201 // init the RVVIntrinsic ID and IntrinsicTypes. 202 void emitCodeGenSwitchBody(raw_ostream &o) const; 203 204 // Emit the macros for mapping C/C++ intrinsic function to builtin functions. 205 void emitIntrinsicFuncDef(raw_ostream &o) const; 206 207 // Emit the mangled function definition. 208 void emitMangledFuncDef(raw_ostream &o) const; 209 }; 210 211 class RVVEmitter { 212 private: 213 RecordKeeper &Records; 214 std::string HeaderCode; 215 // Concat BasicType, LMUL and Proto as key 216 StringMap<RVVType> LegalTypes; 217 StringSet<> IllegalTypes; 218 219 public: 220 RVVEmitter(RecordKeeper &R) : Records(R) {} 221 222 /// Emit riscv_vector.h 223 void createHeader(raw_ostream &o); 224 225 /// Emit all the __builtin prototypes and code needed by Sema. 226 void createBuiltins(raw_ostream &o); 227 228 /// Emit all the information needed to map builtin -> LLVM IR intrinsic. 229 void createCodeGen(raw_ostream &o); 230 231 std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes); 232 233 private: 234 /// Create all intrinsics and add them to \p Out 235 void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out); 236 /// Create Headers and add them to \p Out 237 void createRVVHeaders(raw_ostream &OS); 238 /// Compute output and input types by applying different config (basic type 239 /// and LMUL with type transformers). It also record result of type in legal 240 /// or illegal set to avoid compute the same config again. The result maybe 241 /// have illegal RVVType. 242 Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF, 243 ArrayRef<std::string> PrototypeSeq); 244 Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto); 245 246 /// Emit Acrh predecessor definitions and body, assume the element of Defs are 247 /// sorted by extension. 248 void emitArchMacroAndBody( 249 std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &o, 250 std::function<void(raw_ostream &, const RVVIntrinsic &)>); 251 252 // Emit the architecture preprocessor definitions. Return true when emits 253 // non-empty string. 254 bool emitExtDefStr(uint8_t Extensions, raw_ostream &o); 255 // Slice Prototypes string into sub prototype string and process each sub 256 // prototype string individually in the Handler. 257 void parsePrototypes(StringRef Prototypes, 258 std::function<void(StringRef)> Handler); 259 }; 260 261 } // namespace 262 263 //===----------------------------------------------------------------------===// 264 // Type implementation 265 //===----------------------------------------------------------------------===// 266 267 LMULType::LMULType(int NewLog2LMUL) { 268 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3 269 assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!"); 270 Log2LMUL = NewLog2LMUL; 271 } 272 273 std::string LMULType::str() const { 274 if (Log2LMUL < 0) 275 return "mf" + utostr(1ULL << (-Log2LMUL)); 276 return "m" + utostr(1ULL << Log2LMUL); 277 } 278 279 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { 280 int Log2ScaleResult = 0; 281 switch (ElementBitwidth) { 282 default: 283 break; 284 case 8: 285 Log2ScaleResult = Log2LMUL + 3; 286 break; 287 case 16: 288 Log2ScaleResult = Log2LMUL + 2; 289 break; 290 case 32: 291 Log2ScaleResult = Log2LMUL + 1; 292 break; 293 case 64: 294 Log2ScaleResult = Log2LMUL; 295 break; 296 } 297 // Illegal vscale result would be less than 1 298 if (Log2ScaleResult < 0) 299 return None; 300 return 1 << Log2ScaleResult; 301 } 302 303 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; } 304 305 LMULType &LMULType::operator*=(uint32_t RHS) { 306 assert(isPowerOf2_32(RHS)); 307 this->Log2LMUL = this->Log2LMUL + Log2_32(RHS); 308 return *this; 309 } 310 311 RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype) 312 : BT(BT), LMUL(LMULType(Log2LMUL)) { 313 applyBasicType(); 314 applyModifier(prototype); 315 Valid = verifyType(); 316 if (Valid) { 317 initBuiltinStr(); 318 initTypeStr(); 319 if (isVector()) { 320 initClangBuiltinStr(); 321 } 322 } 323 } 324 325 // clang-format off 326 // boolean type are encoded the ratio of n (SEW/LMUL) 327 // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64 328 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t 329 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1 330 331 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 332 // -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- 333 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 334 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 335 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 336 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 337 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 338 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 339 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 340 // clang-format on 341 342 bool RVVType::verifyType() const { 343 if (ScalarType == Invalid) 344 return false; 345 if (isScalar()) 346 return true; 347 if (!Scale.hasValue()) 348 return false; 349 if (isFloat() && ElementBitwidth == 8) 350 return false; 351 unsigned V = Scale.getValue(); 352 switch (ElementBitwidth) { 353 case 1: 354 case 8: 355 // Check Scale is 1,2,4,8,16,32,64 356 return (V <= 64 && isPowerOf2_32(V)); 357 case 16: 358 // Check Scale is 1,2,4,8,16,32 359 return (V <= 32 && isPowerOf2_32(V)); 360 case 32: 361 // Check Scale is 1,2,4,8,16 362 return (V <= 16 && isPowerOf2_32(V)); 363 case 64: 364 // Check Scale is 1,2,4,8 365 return (V <= 8 && isPowerOf2_32(V)); 366 } 367 return false; 368 } 369 370 void RVVType::initBuiltinStr() { 371 assert(isValid() && "RVVType is invalid"); 372 switch (ScalarType) { 373 case ScalarTypeKind::Void: 374 BuiltinStr = "v"; 375 return; 376 case ScalarTypeKind::Size_t: 377 BuiltinStr = "z"; 378 if (IsImmediate) 379 BuiltinStr = "I" + BuiltinStr; 380 if (IsPointer) 381 BuiltinStr += "*"; 382 return; 383 case ScalarTypeKind::Ptrdiff_t: 384 BuiltinStr = "Y"; 385 return; 386 case ScalarTypeKind::UnsignedLong: 387 BuiltinStr = "ULi"; 388 return; 389 case ScalarTypeKind::SignedLong: 390 BuiltinStr = "Li"; 391 return; 392 case ScalarTypeKind::Boolean: 393 assert(ElementBitwidth == 1); 394 BuiltinStr += "b"; 395 break; 396 case ScalarTypeKind::SignedInteger: 397 case ScalarTypeKind::UnsignedInteger: 398 switch (ElementBitwidth) { 399 case 8: 400 BuiltinStr += "c"; 401 break; 402 case 16: 403 BuiltinStr += "s"; 404 break; 405 case 32: 406 BuiltinStr += "i"; 407 break; 408 case 64: 409 BuiltinStr += "Wi"; 410 break; 411 default: 412 llvm_unreachable("Unhandled ElementBitwidth!"); 413 } 414 if (isSignedInteger()) 415 BuiltinStr = "S" + BuiltinStr; 416 else 417 BuiltinStr = "U" + BuiltinStr; 418 break; 419 case ScalarTypeKind::Float: 420 switch (ElementBitwidth) { 421 case 16: 422 BuiltinStr += "x"; 423 break; 424 case 32: 425 BuiltinStr += "f"; 426 break; 427 case 64: 428 BuiltinStr += "d"; 429 break; 430 default: 431 llvm_unreachable("Unhandled ElementBitwidth!"); 432 } 433 break; 434 default: 435 llvm_unreachable("ScalarType is invalid!"); 436 } 437 if (IsImmediate) 438 BuiltinStr = "I" + BuiltinStr; 439 if (isScalar()) { 440 if (IsConstant) 441 BuiltinStr += "C"; 442 if (IsPointer) 443 BuiltinStr += "*"; 444 return; 445 } 446 BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr; 447 // Pointer to vector types. Defined for Zvlsseg load intrinsics. 448 // Zvlsseg load intrinsics have pointer type arguments to store the loaded 449 // vector values. 450 if (IsPointer) 451 BuiltinStr += "*"; 452 } 453 454 void RVVType::initClangBuiltinStr() { 455 assert(isValid() && "RVVType is invalid"); 456 assert(isVector() && "Handle Vector type only"); 457 458 ClangBuiltinStr = "__rvv_"; 459 switch (ScalarType) { 460 case ScalarTypeKind::Boolean: 461 ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t"; 462 return; 463 case ScalarTypeKind::Float: 464 ClangBuiltinStr += "float"; 465 break; 466 case ScalarTypeKind::SignedInteger: 467 ClangBuiltinStr += "int"; 468 break; 469 case ScalarTypeKind::UnsignedInteger: 470 ClangBuiltinStr += "uint"; 471 break; 472 default: 473 llvm_unreachable("ScalarTypeKind is invalid"); 474 } 475 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t"; 476 } 477 478 void RVVType::initTypeStr() { 479 assert(isValid() && "RVVType is invalid"); 480 481 if (IsConstant) 482 Str += "const "; 483 484 auto getTypeString = [&](StringRef TypeStr) { 485 if (isScalar()) 486 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); 487 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t") 488 .str(); 489 }; 490 491 switch (ScalarType) { 492 case ScalarTypeKind::Void: 493 Str = "void"; 494 return; 495 case ScalarTypeKind::Size_t: 496 Str = "size_t"; 497 if (IsPointer) 498 Str += " *"; 499 return; 500 case ScalarTypeKind::Ptrdiff_t: 501 Str = "ptrdiff_t"; 502 return; 503 case ScalarTypeKind::UnsignedLong: 504 Str = "unsigned long"; 505 return; 506 case ScalarTypeKind::SignedLong: 507 Str = "long"; 508 return; 509 case ScalarTypeKind::Boolean: 510 if (isScalar()) 511 Str += "bool"; 512 else 513 // Vector bool is special case, the formulate is 514 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1 515 Str += "vbool" + utostr(64 / Scale.getValue()) + "_t"; 516 break; 517 case ScalarTypeKind::Float: 518 if (isScalar()) { 519 if (ElementBitwidth == 64) 520 Str += "double"; 521 else if (ElementBitwidth == 32) 522 Str += "float"; 523 else if (ElementBitwidth == 16) 524 Str += "_Float16"; 525 else 526 llvm_unreachable("Unhandled floating type."); 527 } else 528 Str += getTypeString("float"); 529 break; 530 case ScalarTypeKind::SignedInteger: 531 Str += getTypeString("int"); 532 break; 533 case ScalarTypeKind::UnsignedInteger: 534 Str += getTypeString("uint"); 535 break; 536 default: 537 llvm_unreachable("ScalarType is invalid!"); 538 } 539 if (IsPointer) 540 Str += " *"; 541 } 542 543 void RVVType::initShortStr() { 544 switch (ScalarType) { 545 case ScalarTypeKind::Boolean: 546 assert(isVector()); 547 ShortStr = "b" + utostr(64 / Scale.getValue()); 548 return; 549 case ScalarTypeKind::Float: 550 ShortStr = "f" + utostr(ElementBitwidth); 551 break; 552 case ScalarTypeKind::SignedInteger: 553 ShortStr = "i" + utostr(ElementBitwidth); 554 break; 555 case ScalarTypeKind::UnsignedInteger: 556 ShortStr = "u" + utostr(ElementBitwidth); 557 break; 558 default: 559 PrintFatalError("Unhandled case!"); 560 } 561 if (isVector()) 562 ShortStr += LMUL.str(); 563 } 564 565 void RVVType::applyBasicType() { 566 switch (BT) { 567 case 'c': 568 ElementBitwidth = 8; 569 ScalarType = ScalarTypeKind::SignedInteger; 570 break; 571 case 's': 572 ElementBitwidth = 16; 573 ScalarType = ScalarTypeKind::SignedInteger; 574 break; 575 case 'i': 576 ElementBitwidth = 32; 577 ScalarType = ScalarTypeKind::SignedInteger; 578 break; 579 case 'l': 580 ElementBitwidth = 64; 581 ScalarType = ScalarTypeKind::SignedInteger; 582 break; 583 case 'x': 584 ElementBitwidth = 16; 585 ScalarType = ScalarTypeKind::Float; 586 break; 587 case 'f': 588 ElementBitwidth = 32; 589 ScalarType = ScalarTypeKind::Float; 590 break; 591 case 'd': 592 ElementBitwidth = 64; 593 ScalarType = ScalarTypeKind::Float; 594 break; 595 default: 596 PrintFatalError("Unhandled type code!"); 597 } 598 assert(ElementBitwidth != 0 && "Bad element bitwidth!"); 599 } 600 601 void RVVType::applyModifier(StringRef Transformer) { 602 if (Transformer.empty()) 603 return; 604 // Handle primitive type transformer 605 auto PType = Transformer.back(); 606 switch (PType) { 607 case 'e': 608 Scale = 0; 609 break; 610 case 'v': 611 Scale = LMUL.getScale(ElementBitwidth); 612 break; 613 case 'w': 614 ElementBitwidth *= 2; 615 LMUL *= 2; 616 Scale = LMUL.getScale(ElementBitwidth); 617 break; 618 case 'q': 619 ElementBitwidth *= 4; 620 LMUL *= 4; 621 Scale = LMUL.getScale(ElementBitwidth); 622 break; 623 case 'o': 624 ElementBitwidth *= 8; 625 LMUL *= 8; 626 Scale = LMUL.getScale(ElementBitwidth); 627 break; 628 case 'm': 629 ScalarType = ScalarTypeKind::Boolean; 630 Scale = LMUL.getScale(ElementBitwidth); 631 ElementBitwidth = 1; 632 break; 633 case '0': 634 ScalarType = ScalarTypeKind::Void; 635 break; 636 case 'z': 637 ScalarType = ScalarTypeKind::Size_t; 638 break; 639 case 't': 640 ScalarType = ScalarTypeKind::Ptrdiff_t; 641 break; 642 case 'u': 643 ScalarType = ScalarTypeKind::UnsignedLong; 644 break; 645 case 'l': 646 ScalarType = ScalarTypeKind::SignedLong; 647 break; 648 default: 649 PrintFatalError("Illegal primitive type transformers!"); 650 } 651 Transformer = Transformer.drop_back(); 652 653 // Extract and compute complex type transformer. It can only appear one time. 654 if (Transformer.startswith("(")) { 655 size_t Idx = Transformer.find(')'); 656 assert(Idx != StringRef::npos); 657 StringRef ComplexType = Transformer.slice(1, Idx); 658 Transformer = Transformer.drop_front(Idx + 1); 659 assert(!Transformer.contains('(') && 660 "Only allow one complex type transformer"); 661 662 auto UpdateAndCheckComplexProto = [&]() { 663 Scale = LMUL.getScale(ElementBitwidth); 664 const StringRef VectorPrototypes("vwqom"); 665 if (!VectorPrototypes.contains(PType)) 666 PrintFatalError("Complex type transformer only supports vector type!"); 667 if (Transformer.find_first_of("PCKWS") != StringRef::npos) 668 PrintFatalError( 669 "Illegal type transformer for Complex type transformer"); 670 }; 671 auto ComputeFixedLog2LMUL = 672 [&](StringRef Value, 673 std::function<bool(const int32_t &, const int32_t &)> Compare) { 674 int32_t Log2LMUL; 675 Value.getAsInteger(10, Log2LMUL); 676 if (!Compare(Log2LMUL, LMUL.Log2LMUL)) { 677 ScalarType = Invalid; 678 return false; 679 } 680 // Update new LMUL 681 LMUL = LMULType(Log2LMUL); 682 UpdateAndCheckComplexProto(); 683 return true; 684 }; 685 auto ComplexTT = ComplexType.split(":"); 686 if (ComplexTT.first == "Log2EEW") { 687 uint32_t Log2EEW; 688 ComplexTT.second.getAsInteger(10, Log2EEW); 689 // update new elmul = (eew/sew) * lmul 690 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); 691 // update new eew 692 ElementBitwidth = 1 << Log2EEW; 693 ScalarType = ScalarTypeKind::SignedInteger; 694 UpdateAndCheckComplexProto(); 695 } else if (ComplexTT.first == "FixedSEW") { 696 uint32_t NewSEW; 697 ComplexTT.second.getAsInteger(10, NewSEW); 698 // Set invalid type if src and dst SEW are same. 699 if (ElementBitwidth == NewSEW) { 700 ScalarType = Invalid; 701 return; 702 } 703 // Update new SEW 704 ElementBitwidth = NewSEW; 705 UpdateAndCheckComplexProto(); 706 } else if (ComplexTT.first == "LFixedLog2LMUL") { 707 // New LMUL should be larger than old 708 if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>())) 709 return; 710 } else if (ComplexTT.first == "SFixedLog2LMUL") { 711 // New LMUL should be smaller than old 712 if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>())) 713 return; 714 } else { 715 PrintFatalError("Illegal complex type transformers!"); 716 } 717 } 718 719 // Compute the remain type transformers 720 for (char I : Transformer) { 721 switch (I) { 722 case 'P': 723 if (IsConstant) 724 PrintFatalError("'P' transformer cannot be used after 'C'"); 725 if (IsPointer) 726 PrintFatalError("'P' transformer cannot be used twice"); 727 IsPointer = true; 728 break; 729 case 'C': 730 if (IsConstant) 731 PrintFatalError("'C' transformer cannot be used twice"); 732 IsConstant = true; 733 break; 734 case 'K': 735 IsImmediate = true; 736 break; 737 case 'U': 738 ScalarType = ScalarTypeKind::UnsignedInteger; 739 break; 740 case 'I': 741 ScalarType = ScalarTypeKind::SignedInteger; 742 break; 743 case 'F': 744 ScalarType = ScalarTypeKind::Float; 745 break; 746 case 'S': 747 LMUL = LMULType(0); 748 // Update ElementBitwidth need to update Scale too. 749 Scale = LMUL.getScale(ElementBitwidth); 750 break; 751 default: 752 PrintFatalError("Illegal non-primitive type transformer!"); 753 } 754 } 755 } 756 757 //===----------------------------------------------------------------------===// 758 // RVVIntrinsic implementation 759 //===----------------------------------------------------------------------===// 760 RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix, 761 StringRef NewMangledName, StringRef MangledSuffix, 762 StringRef IRName, bool IsMask, 763 bool HasMaskedOffOperand, bool HasVL, bool HasPolicy, 764 bool HasNoMaskedOverloaded, bool HasAutoDef, 765 StringRef ManualCodegen, const RVVTypes &OutInTypes, 766 const std::vector<int64_t> &NewIntrinsicTypes, 767 StringRef RequiredExtension, unsigned NF) 768 : IRName(IRName), IsMask(IsMask), HasVL(HasVL), HasPolicy(HasPolicy), 769 HasNoMaskedOverloaded(HasNoMaskedOverloaded), HasAutoDef(HasAutoDef), 770 ManualCodegen(ManualCodegen.str()), NF(NF) { 771 772 // Init BuiltinName, Name and MangledName 773 BuiltinName = NewName.str(); 774 Name = BuiltinName; 775 if (NewMangledName.empty()) 776 MangledName = NewName.split("_").first.str(); 777 else 778 MangledName = NewMangledName.str(); 779 if (!Suffix.empty()) 780 Name += "_" + Suffix.str(); 781 if (!MangledSuffix.empty()) 782 MangledName += "_" + MangledSuffix.str(); 783 if (IsMask) { 784 BuiltinName += "_m"; 785 Name += "_m"; 786 } 787 788 // Init RISC-V extensions 789 for (const auto &T : OutInTypes) { 790 if (T->isFloatVector(16) || T->isFloat(16)) 791 RISCVExtensions |= RISCVExtension::Zfh; 792 else if (T->isFloatVector(32) || T->isFloat(32)) 793 RISCVExtensions |= RISCVExtension::F; 794 else if (T->isFloatVector(64) || T->isFloat(64)) 795 RISCVExtensions |= RISCVExtension::D; 796 } 797 if (RequiredExtension == "Zvlsseg") 798 RISCVExtensions |= RISCVExtension::Zvlsseg; 799 800 // Init OutputType and InputTypes 801 OutputType = OutInTypes[0]; 802 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end()); 803 804 // IntrinsicTypes is nonmasked version index. Need to update it 805 // if there is maskedoff operand (It is always in first operand). 806 IntrinsicTypes = NewIntrinsicTypes; 807 if (IsMask && HasMaskedOffOperand) { 808 for (auto &I : IntrinsicTypes) { 809 if (I >= 0) 810 I += NF; 811 } 812 } 813 } 814 815 std::string RVVIntrinsic::getBuiltinTypeStr() const { 816 std::string S; 817 S += OutputType->getBuiltinStr(); 818 for (const auto &T : InputTypes) { 819 S += T->getBuiltinStr(); 820 } 821 return S; 822 } 823 824 void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const { 825 if (!getIRName().empty()) 826 OS << " ID = Intrinsic::riscv_" + getIRName() + ";\n"; 827 if (NF >= 2) 828 OS << " NF = " + utostr(getNF()) + ";\n"; 829 if (hasManualCodegen()) { 830 OS << ManualCodegen; 831 OS << "break;\n"; 832 return; 833 } 834 835 if (isMask()) { 836 if (hasVL()) { 837 OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n"; 838 if (hasPolicy()) 839 OS << " Ops.push_back(ConstantInt::get(Ops.back()->getType()," 840 " TAIL_UNDISTURBED));\n"; 841 } else { 842 OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n"; 843 } 844 } 845 846 OS << " IntrinsicTypes = {"; 847 ListSeparator LS; 848 for (const auto &Idx : IntrinsicTypes) { 849 if (Idx == -1) 850 OS << LS << "ResultType"; 851 else 852 OS << LS << "Ops[" << Idx << "]->getType()"; 853 } 854 855 // VL could be i64 or i32, need to encode it in IntrinsicTypes. VL is 856 // always last operand. 857 if (hasVL()) 858 OS << ", Ops.back()->getType()"; 859 OS << "};\n"; 860 OS << " break;\n"; 861 } 862 863 void RVVIntrinsic::emitIntrinsicFuncDef(raw_ostream &OS) const { 864 OS << "__attribute__((__clang_builtin_alias__("; 865 OS << "__builtin_rvv_" << getBuiltinName() << ")))\n"; 866 OS << OutputType->getTypeStr() << " " << getName() << "("; 867 // Emit function arguments 868 if (!InputTypes.empty()) { 869 ListSeparator LS; 870 for (unsigned i = 0; i < InputTypes.size(); ++i) 871 OS << LS << InputTypes[i]->getTypeStr(); 872 } 873 OS << ");\n"; 874 } 875 876 void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const { 877 OS << "__attribute__((__clang_builtin_alias__("; 878 OS << "__builtin_rvv_" << getBuiltinName() << ")))\n"; 879 OS << OutputType->getTypeStr() << " " << getMangledName() << "("; 880 // Emit function arguments 881 if (!InputTypes.empty()) { 882 ListSeparator LS; 883 for (unsigned i = 0; i < InputTypes.size(); ++i) 884 OS << LS << InputTypes[i]->getTypeStr(); 885 } 886 OS << ");\n"; 887 } 888 889 //===----------------------------------------------------------------------===// 890 // RVVEmitter implementation 891 //===----------------------------------------------------------------------===// 892 void RVVEmitter::createHeader(raw_ostream &OS) { 893 894 OS << "/*===---- riscv_vector.h - RISC-V V-extension RVVIntrinsics " 895 "-------------------===\n" 896 " *\n" 897 " *\n" 898 " * Part of the LLVM Project, under the Apache License v2.0 with LLVM " 899 "Exceptions.\n" 900 " * See https://llvm.org/LICENSE.txt for license information.\n" 901 " * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n" 902 " *\n" 903 " *===-----------------------------------------------------------------" 904 "------===\n" 905 " */\n\n"; 906 907 OS << "#ifndef __RISCV_VECTOR_H\n"; 908 OS << "#define __RISCV_VECTOR_H\n\n"; 909 910 OS << "#include <stdint.h>\n"; 911 OS << "#include <stddef.h>\n\n"; 912 913 OS << "#ifndef __riscv_vector\n"; 914 OS << "#error \"Vector intrinsics require the vector extension.\"\n"; 915 OS << "#endif\n\n"; 916 917 OS << "#ifdef __cplusplus\n"; 918 OS << "extern \"C\" {\n"; 919 OS << "#endif\n\n"; 920 921 createRVVHeaders(OS); 922 923 std::vector<std::unique_ptr<RVVIntrinsic>> Defs; 924 createRVVIntrinsics(Defs); 925 926 // Print header code 927 if (!HeaderCode.empty()) { 928 OS << HeaderCode; 929 } 930 931 auto printType = [&](auto T) { 932 OS << "typedef " << T->getClangBuiltinStr() << " " << T->getTypeStr() 933 << ";\n"; 934 }; 935 936 constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3}; 937 // Print RVV boolean types. 938 for (int Log2LMUL : Log2LMULs) { 939 auto T = computeType('c', Log2LMUL, "m"); 940 if (T.hasValue()) 941 printType(T.getValue()); 942 } 943 // Print RVV int/float types. 944 for (char I : StringRef("csil")) { 945 for (int Log2LMUL : Log2LMULs) { 946 auto T = computeType(I, Log2LMUL, "v"); 947 if (T.hasValue()) { 948 printType(T.getValue()); 949 auto UT = computeType(I, Log2LMUL, "Uv"); 950 printType(UT.getValue()); 951 } 952 } 953 } 954 OS << "#if defined(__riscv_zfh)\n"; 955 for (int Log2LMUL : Log2LMULs) { 956 auto T = computeType('x', Log2LMUL, "v"); 957 if (T.hasValue()) 958 printType(T.getValue()); 959 } 960 OS << "#endif\n"; 961 962 OS << "#if defined(__riscv_f)\n"; 963 for (int Log2LMUL : Log2LMULs) { 964 auto T = computeType('f', Log2LMUL, "v"); 965 if (T.hasValue()) 966 printType(T.getValue()); 967 } 968 OS << "#endif\n"; 969 970 OS << "#if defined(__riscv_d)\n"; 971 for (int Log2LMUL : Log2LMULs) { 972 auto T = computeType('d', Log2LMUL, "v"); 973 if (T.hasValue()) 974 printType(T.getValue()); 975 } 976 OS << "#endif\n\n"; 977 978 // The same extension include in the same arch guard marco. 979 llvm::stable_sort(Defs, [](const std::unique_ptr<RVVIntrinsic> &A, 980 const std::unique_ptr<RVVIntrinsic> &B) { 981 return A->getRISCVExtensions() < B->getRISCVExtensions(); 982 }); 983 984 OS << "#define __rvv_ai static __inline__\n"; 985 986 // Print intrinsic functions with macro 987 emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) { 988 OS << "__rvv_ai "; 989 Inst.emitIntrinsicFuncDef(OS); 990 }); 991 992 OS << "#undef __rvv_ai\n\n"; 993 994 OS << "#define __riscv_v_intrinsic_overloading 1\n"; 995 996 // Print Overloaded APIs 997 OS << "#define __rvv_aio static __inline__ " 998 "__attribute__((__overloadable__))\n"; 999 1000 emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) { 1001 if (!Inst.isMask() && !Inst.hasNoMaskedOverloaded()) 1002 return; 1003 OS << "__rvv_aio "; 1004 Inst.emitMangledFuncDef(OS); 1005 }); 1006 1007 OS << "#undef __rvv_aio\n"; 1008 1009 OS << "\n#ifdef __cplusplus\n"; 1010 OS << "}\n"; 1011 OS << "#endif // __cplusplus\n"; 1012 OS << "#endif // __RISCV_VECTOR_H\n"; 1013 } 1014 1015 void RVVEmitter::createBuiltins(raw_ostream &OS) { 1016 std::vector<std::unique_ptr<RVVIntrinsic>> Defs; 1017 createRVVIntrinsics(Defs); 1018 1019 // Map to keep track of which builtin names have already been emitted. 1020 StringMap<RVVIntrinsic *> BuiltinMap; 1021 1022 OS << "#if defined(TARGET_BUILTIN) && !defined(RISCVV_BUILTIN)\n"; 1023 OS << "#define RISCVV_BUILTIN(ID, TYPE, ATTRS) TARGET_BUILTIN(ID, TYPE, " 1024 "ATTRS, \"experimental-v\")\n"; 1025 OS << "#endif\n"; 1026 for (auto &Def : Defs) { 1027 auto P = 1028 BuiltinMap.insert(std::make_pair(Def->getBuiltinName(), Def.get())); 1029 if (!P.second) { 1030 // Verify that this would have produced the same builtin definition. 1031 if (P.first->second->hasAutoDef() != Def->hasAutoDef()) { 1032 PrintFatalError("Builtin with same name has different hasAutoDef"); 1033 } else if (!Def->hasAutoDef() && P.first->second->getBuiltinTypeStr() != 1034 Def->getBuiltinTypeStr()) { 1035 PrintFatalError("Builtin with same name has different type string"); 1036 } 1037 continue; 1038 } 1039 1040 OS << "RISCVV_BUILTIN(__builtin_rvv_" << Def->getBuiltinName() << ",\""; 1041 if (!Def->hasAutoDef()) 1042 OS << Def->getBuiltinTypeStr(); 1043 OS << "\", \"n\")\n"; 1044 } 1045 OS << "#undef RISCVV_BUILTIN\n"; 1046 } 1047 1048 void RVVEmitter::createCodeGen(raw_ostream &OS) { 1049 std::vector<std::unique_ptr<RVVIntrinsic>> Defs; 1050 createRVVIntrinsics(Defs); 1051 // IR name could be empty, use the stable sort preserves the relative order. 1052 llvm::stable_sort(Defs, [](const std::unique_ptr<RVVIntrinsic> &A, 1053 const std::unique_ptr<RVVIntrinsic> &B) { 1054 return A->getIRName() < B->getIRName(); 1055 }); 1056 1057 // Map to keep track of which builtin names have already been emitted. 1058 StringMap<RVVIntrinsic *> BuiltinMap; 1059 1060 // Print switch body when the ir name or ManualCodegen changes from previous 1061 // iteration. 1062 RVVIntrinsic *PrevDef = Defs.begin()->get(); 1063 for (auto &Def : Defs) { 1064 StringRef CurIRName = Def->getIRName(); 1065 if (CurIRName != PrevDef->getIRName() || 1066 (Def->getManualCodegen() != PrevDef->getManualCodegen())) { 1067 PrevDef->emitCodeGenSwitchBody(OS); 1068 } 1069 PrevDef = Def.get(); 1070 1071 auto P = 1072 BuiltinMap.insert(std::make_pair(Def->getBuiltinName(), Def.get())); 1073 if (P.second) { 1074 OS << "case RISCVVector::BI__builtin_rvv_" << Def->getBuiltinName() 1075 << ":\n"; 1076 continue; 1077 } 1078 1079 if (P.first->second->getIRName() != Def->getIRName()) 1080 PrintFatalError("Builtin with same name has different IRName"); 1081 else if (P.first->second->getManualCodegen() != Def->getManualCodegen()) 1082 PrintFatalError("Builtin with same name has different ManualCodegen"); 1083 else if (P.first->second->getNF() != Def->getNF()) 1084 PrintFatalError("Builtin with same name has different NF"); 1085 else if (P.first->second->isMask() != Def->isMask()) 1086 PrintFatalError("Builtin with same name has different isMask"); 1087 else if (P.first->second->hasVL() != Def->hasVL()) 1088 PrintFatalError("Builtin with same name has different HasPolicy"); 1089 else if (P.first->second->hasPolicy() != Def->hasPolicy()) 1090 PrintFatalError("Builtin with same name has different HasPolicy"); 1091 else if (P.first->second->getIntrinsicTypes() != Def->getIntrinsicTypes()) 1092 PrintFatalError("Builtin with same name has different IntrinsicTypes"); 1093 } 1094 Defs.back()->emitCodeGenSwitchBody(OS); 1095 OS << "\n"; 1096 } 1097 1098 void RVVEmitter::parsePrototypes(StringRef Prototypes, 1099 std::function<void(StringRef)> Handler) { 1100 const StringRef Primaries("evwqom0ztul"); 1101 while (!Prototypes.empty()) { 1102 size_t Idx = 0; 1103 // Skip over complex prototype because it could contain primitive type 1104 // character. 1105 if (Prototypes[0] == '(') 1106 Idx = Prototypes.find_first_of(')'); 1107 Idx = Prototypes.find_first_of(Primaries, Idx); 1108 assert(Idx != StringRef::npos); 1109 Handler(Prototypes.slice(0, Idx + 1)); 1110 Prototypes = Prototypes.drop_front(Idx + 1); 1111 } 1112 } 1113 1114 std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL, 1115 StringRef Prototypes) { 1116 SmallVector<std::string> SuffixStrs; 1117 parsePrototypes(Prototypes, [&](StringRef Proto) { 1118 auto T = computeType(Type, Log2LMUL, Proto); 1119 SuffixStrs.push_back(T.getValue()->getShortStr()); 1120 }); 1121 return join(SuffixStrs, "_"); 1122 } 1123 1124 void RVVEmitter::createRVVIntrinsics( 1125 std::vector<std::unique_ptr<RVVIntrinsic>> &Out) { 1126 std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin"); 1127 for (auto *R : RV) { 1128 StringRef Name = R->getValueAsString("Name"); 1129 StringRef SuffixProto = R->getValueAsString("Suffix"); 1130 StringRef MangledName = R->getValueAsString("MangledName"); 1131 StringRef MangledSuffixProto = R->getValueAsString("MangledSuffix"); 1132 StringRef Prototypes = R->getValueAsString("Prototype"); 1133 StringRef TypeRange = R->getValueAsString("TypeRange"); 1134 bool HasMask = R->getValueAsBit("HasMask"); 1135 bool HasMaskedOffOperand = R->getValueAsBit("HasMaskedOffOperand"); 1136 bool HasVL = R->getValueAsBit("HasVL"); 1137 bool HasPolicy = R->getValueAsBit("HasPolicy"); 1138 bool HasNoMaskedOverloaded = R->getValueAsBit("HasNoMaskedOverloaded"); 1139 std::vector<int64_t> Log2LMULList = R->getValueAsListOfInts("Log2LMUL"); 1140 StringRef ManualCodegen = R->getValueAsString("ManualCodegen"); 1141 StringRef ManualCodegenMask = R->getValueAsString("ManualCodegenMask"); 1142 std::vector<int64_t> IntrinsicTypes = 1143 R->getValueAsListOfInts("IntrinsicTypes"); 1144 StringRef RequiredExtension = R->getValueAsString("RequiredExtension"); 1145 StringRef IRName = R->getValueAsString("IRName"); 1146 StringRef IRNameMask = R->getValueAsString("IRNameMask"); 1147 unsigned NF = R->getValueAsInt("NF"); 1148 1149 StringRef HeaderCodeStr = R->getValueAsString("HeaderCode"); 1150 bool HasAutoDef = HeaderCodeStr.empty(); 1151 if (!HeaderCodeStr.empty()) { 1152 HeaderCode += HeaderCodeStr.str(); 1153 } 1154 // Parse prototype and create a list of primitive type with transformers 1155 // (operand) in ProtoSeq. ProtoSeq[0] is output operand. 1156 SmallVector<std::string> ProtoSeq; 1157 parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) { 1158 ProtoSeq.push_back(Proto.str()); 1159 }); 1160 1161 // Compute Builtin types 1162 SmallVector<std::string> ProtoMaskSeq = ProtoSeq; 1163 if (HasMask) { 1164 // If HasMaskedOffOperand, insert result type as first input operand. 1165 if (HasMaskedOffOperand) { 1166 if (NF == 1) { 1167 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, ProtoSeq[0]); 1168 } else { 1169 // Convert 1170 // (void, op0 address, op1 address, ...) 1171 // to 1172 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) 1173 for (unsigned I = 0; I < NF; ++I) 1174 ProtoMaskSeq.insert( 1175 ProtoMaskSeq.begin() + NF + 1, 1176 ProtoSeq[1].substr(1)); // Use substr(1) to skip '*' 1177 } 1178 } 1179 if (HasMaskedOffOperand && NF > 1) { 1180 // Convert 1181 // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...) 1182 // to 1183 // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1, 1184 // ...) 1185 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m"); 1186 } else { 1187 // If HasMask, insert 'm' as first input operand. 1188 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m"); 1189 } 1190 } 1191 // If HasVL, append 'z' to last operand 1192 if (HasVL) { 1193 ProtoSeq.push_back("z"); 1194 ProtoMaskSeq.push_back("z"); 1195 } 1196 1197 // Create Intrinsics for each type and LMUL. 1198 for (char I : TypeRange) { 1199 for (int Log2LMUL : Log2LMULList) { 1200 Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq); 1201 // Ignored to create new intrinsic if there are any illegal types. 1202 if (!Types.hasValue()) 1203 continue; 1204 1205 auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto); 1206 auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto); 1207 // Create a non-mask intrinsic 1208 Out.push_back(std::make_unique<RVVIntrinsic>( 1209 Name, SuffixStr, MangledName, MangledSuffixStr, IRName, 1210 /*IsMask=*/false, /*HasMaskedOffOperand=*/false, HasVL, HasPolicy, 1211 HasNoMaskedOverloaded, HasAutoDef, ManualCodegen, Types.getValue(), 1212 IntrinsicTypes, RequiredExtension, NF)); 1213 if (HasMask) { 1214 // Create a mask intrinsic 1215 Optional<RVVTypes> MaskTypes = 1216 computeTypes(I, Log2LMUL, NF, ProtoMaskSeq); 1217 Out.push_back(std::make_unique<RVVIntrinsic>( 1218 Name, SuffixStr, MangledName, MangledSuffixStr, IRNameMask, 1219 /*IsMask=*/true, HasMaskedOffOperand, HasVL, HasPolicy, 1220 HasNoMaskedOverloaded, HasAutoDef, ManualCodegenMask, 1221 MaskTypes.getValue(), IntrinsicTypes, RequiredExtension, NF)); 1222 } 1223 } // end for Log2LMULList 1224 } // end for TypeRange 1225 } 1226 } 1227 1228 void RVVEmitter::createRVVHeaders(raw_ostream &OS) { 1229 std::vector<Record *> RVVHeaders = 1230 Records.getAllDerivedDefinitions("RVVHeader"); 1231 for (auto *R : RVVHeaders) { 1232 StringRef HeaderCodeStr = R->getValueAsString("HeaderCode"); 1233 OS << HeaderCodeStr.str(); 1234 } 1235 } 1236 1237 Optional<RVVTypes> 1238 RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF, 1239 ArrayRef<std::string> PrototypeSeq) { 1240 // LMUL x NF must be less than or equal to 8. 1241 if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8) 1242 return llvm::None; 1243 1244 RVVTypes Types; 1245 for (const std::string &Proto : PrototypeSeq) { 1246 auto T = computeType(BT, Log2LMUL, Proto); 1247 if (!T.hasValue()) 1248 return llvm::None; 1249 // Record legal type index 1250 Types.push_back(T.getValue()); 1251 } 1252 return Types; 1253 } 1254 1255 Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL, 1256 StringRef Proto) { 1257 std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str(); 1258 // Search first 1259 auto It = LegalTypes.find(Idx); 1260 if (It != LegalTypes.end()) 1261 return &(It->second); 1262 if (IllegalTypes.count(Idx)) 1263 return llvm::None; 1264 // Compute type and record the result. 1265 RVVType T(BT, Log2LMUL, Proto); 1266 if (T.isValid()) { 1267 // Record legal type index and value. 1268 LegalTypes.insert({Idx, T}); 1269 return &(LegalTypes[Idx]); 1270 } 1271 // Record illegal type index. 1272 IllegalTypes.insert(Idx); 1273 return llvm::None; 1274 } 1275 1276 void RVVEmitter::emitArchMacroAndBody( 1277 std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS, 1278 std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) { 1279 uint8_t PrevExt = (*Defs.begin())->getRISCVExtensions(); 1280 bool NeedEndif = emitExtDefStr(PrevExt, OS); 1281 for (auto &Def : Defs) { 1282 uint8_t CurExt = Def->getRISCVExtensions(); 1283 if (CurExt != PrevExt) { 1284 if (NeedEndif) 1285 OS << "#endif\n\n"; 1286 NeedEndif = emitExtDefStr(CurExt, OS); 1287 PrevExt = CurExt; 1288 } 1289 if (Def->hasAutoDef()) 1290 PrintBody(OS, *Def); 1291 } 1292 if (NeedEndif) 1293 OS << "#endif\n\n"; 1294 } 1295 1296 bool RVVEmitter::emitExtDefStr(uint8_t Extents, raw_ostream &OS) { 1297 if (Extents == RISCVExtension::Basic) 1298 return false; 1299 OS << "#if "; 1300 ListSeparator LS(" && "); 1301 if (Extents & RISCVExtension::F) 1302 OS << LS << "defined(__riscv_f)"; 1303 if (Extents & RISCVExtension::D) 1304 OS << LS << "defined(__riscv_d)"; 1305 if (Extents & RISCVExtension::Zfh) 1306 OS << LS << "defined(__riscv_zfh)"; 1307 if (Extents & RISCVExtension::Zvlsseg) 1308 OS << LS << "defined(__riscv_zvlsseg)"; 1309 OS << "\n"; 1310 return true; 1311 } 1312 1313 namespace clang { 1314 void EmitRVVHeader(RecordKeeper &Records, raw_ostream &OS) { 1315 RVVEmitter(Records).createHeader(OS); 1316 } 1317 1318 void EmitRVVBuiltins(RecordKeeper &Records, raw_ostream &OS) { 1319 RVVEmitter(Records).createBuiltins(OS); 1320 } 1321 1322 void EmitRVVBuiltinCG(RecordKeeper &Records, raw_ostream &OS) { 1323 RVVEmitter(Records).createCodeGen(OS); 1324 } 1325 1326 } // End namespace clang 1327