1 //===- RISCVVEmitter.cpp - Generate riscv_vector.h for use with clang -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This tablegen backend is responsible for emitting riscv_vector.h which 10 // includes a declaration and definition of each intrinsic functions specified 11 // in https://github.com/riscv/rvv-intrinsic-doc. 12 // 13 // See also the documentation in include/clang/Basic/riscv_vector.td. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/ADT/SmallSet.h" 19 #include "llvm/ADT/StringExtras.h" 20 #include "llvm/ADT/StringMap.h" 21 #include "llvm/ADT/StringSet.h" 22 #include "llvm/ADT/Twine.h" 23 #include "llvm/TableGen/Error.h" 24 #include "llvm/TableGen/Record.h" 25 #include <numeric> 26 27 using namespace llvm; 28 using BasicType = char; 29 using VScaleVal = Optional<unsigned>; 30 31 namespace { 32 33 // Exponential LMUL 34 struct LMULType { 35 int Log2LMUL; 36 LMULType(int Log2LMUL); 37 // Return the C/C++ string representation of LMUL 38 std::string str() const; 39 Optional<unsigned> getScale(unsigned ElementBitwidth) const; 40 void MulLog2LMUL(int Log2LMUL); 41 LMULType &operator*=(uint32_t RHS); 42 }; 43 44 // This class is compact representation of a valid and invalid RVVType. 45 class RVVType { 46 enum ScalarTypeKind : uint32_t { 47 Void, 48 Size_t, 49 Ptrdiff_t, 50 UnsignedLong, 51 SignedLong, 52 Boolean, 53 SignedInteger, 54 UnsignedInteger, 55 Float, 56 Invalid, 57 }; 58 BasicType BT; 59 ScalarTypeKind ScalarType = Invalid; 60 LMULType LMUL; 61 bool IsPointer = false; 62 // IsConstant indices are "int", but have the constant expression. 63 bool IsImmediate = false; 64 // Const qualifier for pointer to const object or object of const type. 65 bool IsConstant = false; 66 unsigned ElementBitwidth = 0; 67 VScaleVal Scale = 0; 68 bool Valid; 69 70 std::string BuiltinStr; 71 std::string ClangBuiltinStr; 72 std::string Str; 73 std::string ShortStr; 74 75 public: 76 RVVType() : RVVType(BasicType(), 0, StringRef()) {} 77 RVVType(BasicType BT, int Log2LMUL, StringRef prototype); 78 79 // Return the string representation of a type, which is an encoded string for 80 // passing to the BUILTIN() macro in Builtins.def. 81 const std::string &getBuiltinStr() const { return BuiltinStr; } 82 83 // Return the clang buitlin type for RVV vector type which are used in the 84 // riscv_vector.h header file. 85 const std::string &getClangBuiltinStr() const { return ClangBuiltinStr; } 86 87 // Return the C/C++ string representation of a type for use in the 88 // riscv_vector.h header file. 89 const std::string &getTypeStr() const { return Str; } 90 91 // Return the short name of a type for C/C++ name suffix. 92 const std::string &getShortStr() { 93 // Not all types are used in short name, so compute the short name by 94 // demanded. 95 if (ShortStr.empty()) 96 initShortStr(); 97 return ShortStr; 98 } 99 100 bool isValid() const { return Valid; } 101 bool isScalar() const { return Scale.hasValue() && Scale.getValue() == 0; } 102 bool isVector() const { return Scale.hasValue() && Scale.getValue() != 0; } 103 bool isFloat() const { return ScalarType == ScalarTypeKind::Float; } 104 bool isSignedInteger() const { 105 return ScalarType == ScalarTypeKind::SignedInteger; 106 } 107 bool isFloatVector(unsigned Width) const { 108 return isVector() && isFloat() && ElementBitwidth == Width; 109 } 110 bool isFloat(unsigned Width) const { 111 return isFloat() && ElementBitwidth == Width; 112 } 113 114 private: 115 // Verify RVV vector type and set Valid. 116 bool verifyType() const; 117 118 // Creates a type based on basic types of TypeRange 119 void applyBasicType(); 120 121 // Applies a prototype modifier to the current type. The result maybe an 122 // invalid type. 123 void applyModifier(StringRef prototype); 124 125 // Compute and record a string for legal type. 126 void initBuiltinStr(); 127 // Compute and record a builtin RVV vector type string. 128 void initClangBuiltinStr(); 129 // Compute and record a type string for used in the header. 130 void initTypeStr(); 131 // Compute and record a short name of a type for C/C++ name suffix. 132 void initShortStr(); 133 }; 134 135 using RVVTypePtr = RVVType *; 136 using RVVTypes = std::vector<RVVTypePtr>; 137 138 enum RISCVExtension : uint8_t { 139 Basic = 0, 140 F = 1 << 1, 141 D = 1 << 2, 142 Zfh = 1 << 3, 143 Zvamo = 1 << 4, 144 }; 145 146 // TODO refactor RVVIntrinsic class design after support all intrinsic 147 // combination. This represents an instantiation of an intrinsic with a 148 // particular type and prototype 149 class RVVIntrinsic { 150 151 private: 152 std::string Name; // Builtin name 153 std::string MangledName; 154 std::string IRName; 155 bool HasSideEffects; 156 bool IsMask; 157 bool HasMaskedOffOperand; 158 bool HasVL; 159 bool HasNoMaskedOverloaded; 160 bool HasAutoDef; // There is automiatic definition in header 161 std::string ManualCodegen; 162 RVVTypePtr OutputType; // Builtin output type 163 RVVTypes InputTypes; // Builtin input types 164 // The types we use to obtain the specific LLVM intrinsic. They are index of 165 // InputTypes. -1 means the return type. 166 std::vector<int64_t> IntrinsicTypes; 167 uint8_t RISCVExtensions = 0; 168 169 public: 170 RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName, 171 StringRef IRName, bool HasSideEffects, bool IsMask, 172 bool HasMaskedOffOperand, bool HasVL, bool HasNoMaskedOverloaded, 173 bool HasAutoDef, StringRef ManualCodegen, const RVVTypes &Types, 174 const std::vector<int64_t> &IntrinsicTypes, 175 StringRef RequiredExtension); 176 ~RVVIntrinsic() = default; 177 178 StringRef getName() const { return Name; } 179 StringRef getMangledName() const { return MangledName; } 180 bool hasSideEffects() const { return HasSideEffects; } 181 bool hasMaskedOffOperand() const { return HasMaskedOffOperand; } 182 bool hasVL() const { return HasVL; } 183 bool hasNoMaskedOverloaded() const { return HasNoMaskedOverloaded; } 184 bool hasManualCodegen() const { return !ManualCodegen.empty(); } 185 bool hasAutoDef() const { return HasAutoDef; } 186 bool isMask() const { return IsMask; } 187 StringRef getIRName() const { return IRName; } 188 StringRef getManualCodegen() const { return ManualCodegen; } 189 uint8_t getRISCVExtensions() const { return RISCVExtensions; } 190 191 // Return the type string for a BUILTIN() macro in Builtins.def. 192 std::string getBuiltinTypeStr() const; 193 194 // Emit the code block for switch body in EmitRISCVBuiltinExpr, it should 195 // init the RVVIntrinsic ID and IntrinsicTypes. 196 void emitCodeGenSwitchBody(raw_ostream &o) const; 197 198 // Emit the macros for mapping C/C++ intrinsic function to builtin functions. 199 void emitIntrinsicMacro(raw_ostream &o) const; 200 201 // Emit the mangled function definition. 202 void emitMangledFuncDef(raw_ostream &o) const; 203 }; 204 205 class RVVEmitter { 206 private: 207 RecordKeeper &Records; 208 std::string HeaderCode; 209 // Concat BasicType, LMUL and Proto as key 210 StringMap<RVVType> LegalTypes; 211 StringSet<> IllegalTypes; 212 213 public: 214 RVVEmitter(RecordKeeper &R) : Records(R) {} 215 216 /// Emit riscv_vector.h 217 void createHeader(raw_ostream &o); 218 219 /// Emit all the __builtin prototypes and code needed by Sema. 220 void createBuiltins(raw_ostream &o); 221 222 /// Emit all the information needed to map builtin -> LLVM IR intrinsic. 223 void createCodeGen(raw_ostream &o); 224 225 std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes); 226 227 private: 228 /// Create all intrinsics and add them to \p Out 229 void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out); 230 /// Compute output and input types by applying different config (basic type 231 /// and LMUL with type transformers). It also record result of type in legal 232 /// or illegal set to avoid compute the same config again. The result maybe 233 /// have illegal RVVType. 234 Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, 235 ArrayRef<std::string> PrototypeSeq); 236 Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto); 237 238 /// Emit Acrh predecessor definitions and body, assume the element of Defs are 239 /// sorted by extension. 240 void emitArchMacroAndBody( 241 std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &o, 242 std::function<void(raw_ostream &, const RVVIntrinsic &)>); 243 244 // Emit the architecture preprocessor definitions. Return true when emits 245 // non-empty string. 246 bool emitExtDefStr(uint8_t Extensions, raw_ostream &o); 247 // Slice Prototypes string into sub prototype string and process each sub 248 // prototype string individually in the Handler. 249 void parsePrototypes(StringRef Prototypes, 250 std::function<void(StringRef)> Handler); 251 }; 252 253 } // namespace 254 255 //===----------------------------------------------------------------------===// 256 // Type implementation 257 //===----------------------------------------------------------------------===// 258 259 LMULType::LMULType(int NewLog2LMUL) { 260 // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3 261 assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!"); 262 Log2LMUL = NewLog2LMUL; 263 } 264 265 std::string LMULType::str() const { 266 if (Log2LMUL < 0) 267 return "mf" + utostr(1ULL << (-Log2LMUL)); 268 return "m" + utostr(1ULL << Log2LMUL); 269 } 270 271 VScaleVal LMULType::getScale(unsigned ElementBitwidth) const { 272 int Log2ScaleResult = 0; 273 switch (ElementBitwidth) { 274 default: 275 break; 276 case 8: 277 Log2ScaleResult = Log2LMUL + 3; 278 break; 279 case 16: 280 Log2ScaleResult = Log2LMUL + 2; 281 break; 282 case 32: 283 Log2ScaleResult = Log2LMUL + 1; 284 break; 285 case 64: 286 Log2ScaleResult = Log2LMUL; 287 break; 288 } 289 // Illegal vscale result would be less than 1 290 if (Log2ScaleResult < 0) 291 return None; 292 return 1 << Log2ScaleResult; 293 } 294 295 void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; } 296 297 LMULType &LMULType::operator*=(uint32_t RHS) { 298 assert(isPowerOf2_32(RHS)); 299 this->Log2LMUL = this->Log2LMUL + Log2_32(RHS); 300 return *this; 301 } 302 303 RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype) 304 : BT(BT), LMUL(LMULType(Log2LMUL)) { 305 applyBasicType(); 306 applyModifier(prototype); 307 Valid = verifyType(); 308 if (Valid) { 309 initBuiltinStr(); 310 initTypeStr(); 311 if (isVector()) { 312 initClangBuiltinStr(); 313 } 314 } 315 } 316 317 // clang-format off 318 // boolean type are encoded the ratio of n (SEW/LMUL) 319 // SEW/LMUL | 1 | 2 | 4 | 8 | 16 | 32 | 64 320 // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t 321 // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1 322 323 // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8 324 // -------- |------ | -------- | ------- | ------- | -------- | -------- | -------- 325 // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64 326 // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32 327 // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16 328 // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8 329 // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64 330 // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32 331 // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16 332 // clang-format on 333 334 bool RVVType::verifyType() const { 335 if (ScalarType == Invalid) 336 return false; 337 if (isScalar()) 338 return true; 339 if (!Scale.hasValue()) 340 return false; 341 if (isFloat() && ElementBitwidth == 8) 342 return false; 343 unsigned V = Scale.getValue(); 344 switch (ElementBitwidth) { 345 case 1: 346 case 8: 347 // Check Scale is 1,2,4,8,16,32,64 348 return (V <= 64 && isPowerOf2_32(V)); 349 case 16: 350 // Check Scale is 1,2,4,8,16,32 351 return (V <= 32 && isPowerOf2_32(V)); 352 case 32: 353 // Check Scale is 1,2,4,8,16 354 return (V <= 16 && isPowerOf2_32(V)); 355 case 64: 356 // Check Scale is 1,2,4,8 357 return (V <= 8 && isPowerOf2_32(V)); 358 } 359 return false; 360 } 361 362 void RVVType::initBuiltinStr() { 363 assert(isValid() && "RVVType is invalid"); 364 switch (ScalarType) { 365 case ScalarTypeKind::Void: 366 BuiltinStr = "v"; 367 return; 368 case ScalarTypeKind::Size_t: 369 BuiltinStr = "z"; 370 if (IsImmediate) 371 BuiltinStr = "I" + BuiltinStr; 372 if (IsPointer) 373 BuiltinStr += "*"; 374 return; 375 case ScalarTypeKind::Ptrdiff_t: 376 BuiltinStr = "Y"; 377 return; 378 case ScalarTypeKind::UnsignedLong: 379 BuiltinStr = "ULi"; 380 return; 381 case ScalarTypeKind::SignedLong: 382 BuiltinStr = "Li"; 383 return; 384 case ScalarTypeKind::Boolean: 385 assert(ElementBitwidth == 1); 386 BuiltinStr += "b"; 387 break; 388 case ScalarTypeKind::SignedInteger: 389 case ScalarTypeKind::UnsignedInteger: 390 switch (ElementBitwidth) { 391 case 8: 392 BuiltinStr += "c"; 393 break; 394 case 16: 395 BuiltinStr += "s"; 396 break; 397 case 32: 398 BuiltinStr += "i"; 399 break; 400 case 64: 401 BuiltinStr += "Wi"; 402 break; 403 default: 404 llvm_unreachable("Unhandled ElementBitwidth!"); 405 } 406 if (isSignedInteger()) 407 BuiltinStr = "S" + BuiltinStr; 408 else 409 BuiltinStr = "U" + BuiltinStr; 410 break; 411 case ScalarTypeKind::Float: 412 switch (ElementBitwidth) { 413 case 16: 414 BuiltinStr += "h"; 415 break; 416 case 32: 417 BuiltinStr += "f"; 418 break; 419 case 64: 420 BuiltinStr += "d"; 421 break; 422 default: 423 llvm_unreachable("Unhandled ElementBitwidth!"); 424 } 425 break; 426 default: 427 llvm_unreachable("ScalarType is invalid!"); 428 } 429 if (IsImmediate) 430 BuiltinStr = "I" + BuiltinStr; 431 if (isScalar()) { 432 if (IsConstant) 433 BuiltinStr += "C"; 434 if (IsPointer) 435 BuiltinStr += "*"; 436 return; 437 } 438 BuiltinStr = "q" + utostr(Scale.getValue()) + BuiltinStr; 439 } 440 441 void RVVType::initClangBuiltinStr() { 442 assert(isValid() && "RVVType is invalid"); 443 assert(isVector() && "Handle Vector type only"); 444 445 ClangBuiltinStr = "__rvv_"; 446 switch (ScalarType) { 447 case ScalarTypeKind::Boolean: 448 ClangBuiltinStr += "bool" + utostr(64 / Scale.getValue()) + "_t"; 449 return; 450 case ScalarTypeKind::Float: 451 ClangBuiltinStr += "float"; 452 break; 453 case ScalarTypeKind::SignedInteger: 454 ClangBuiltinStr += "int"; 455 break; 456 case ScalarTypeKind::UnsignedInteger: 457 ClangBuiltinStr += "uint"; 458 break; 459 default: 460 llvm_unreachable("ScalarTypeKind is invalid"); 461 } 462 ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t"; 463 } 464 465 void RVVType::initTypeStr() { 466 assert(isValid() && "RVVType is invalid"); 467 468 if (IsConstant) 469 Str += "const "; 470 471 auto getTypeString = [&](StringRef TypeStr) { 472 if (isScalar()) 473 return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str(); 474 return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t") 475 .str(); 476 }; 477 478 switch (ScalarType) { 479 case ScalarTypeKind::Void: 480 Str = "void"; 481 return; 482 case ScalarTypeKind::Size_t: 483 Str = "size_t"; 484 if (IsPointer) 485 Str += " *"; 486 return; 487 case ScalarTypeKind::Ptrdiff_t: 488 Str = "ptrdiff_t"; 489 return; 490 case ScalarTypeKind::UnsignedLong: 491 Str = "unsigned long"; 492 return; 493 case ScalarTypeKind::SignedLong: 494 Str = "long"; 495 return; 496 case ScalarTypeKind::Boolean: 497 if (isScalar()) 498 Str += "bool"; 499 else 500 // Vector bool is special case, the formulate is 501 // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1 502 Str += "vbool" + utostr(64 / Scale.getValue()) + "_t"; 503 break; 504 case ScalarTypeKind::Float: 505 if (isScalar()) { 506 if (ElementBitwidth == 64) 507 Str += "double"; 508 else if (ElementBitwidth == 32) 509 Str += "float"; 510 assert((ElementBitwidth == 32 || ElementBitwidth == 64) && 511 "Unhandled floating type"); 512 } else 513 Str += getTypeString("float"); 514 break; 515 case ScalarTypeKind::SignedInteger: 516 Str += getTypeString("int"); 517 break; 518 case ScalarTypeKind::UnsignedInteger: 519 Str += getTypeString("uint"); 520 break; 521 default: 522 llvm_unreachable("ScalarType is invalid!"); 523 } 524 if (IsPointer) 525 Str += " *"; 526 } 527 528 void RVVType::initShortStr() { 529 switch (ScalarType) { 530 case ScalarTypeKind::Boolean: 531 assert(isVector()); 532 ShortStr = "b" + utostr(64 / Scale.getValue()); 533 return; 534 case ScalarTypeKind::Float: 535 ShortStr = "f" + utostr(ElementBitwidth); 536 break; 537 case ScalarTypeKind::SignedInteger: 538 ShortStr = "i" + utostr(ElementBitwidth); 539 break; 540 case ScalarTypeKind::UnsignedInteger: 541 ShortStr = "u" + utostr(ElementBitwidth); 542 break; 543 default: 544 PrintFatalError("Unhandled case!"); 545 } 546 if (isVector()) 547 ShortStr += LMUL.str(); 548 } 549 550 void RVVType::applyBasicType() { 551 switch (BT) { 552 case 'c': 553 ElementBitwidth = 8; 554 ScalarType = ScalarTypeKind::SignedInteger; 555 break; 556 case 's': 557 ElementBitwidth = 16; 558 ScalarType = ScalarTypeKind::SignedInteger; 559 break; 560 case 'i': 561 ElementBitwidth = 32; 562 ScalarType = ScalarTypeKind::SignedInteger; 563 break; 564 case 'l': 565 ElementBitwidth = 64; 566 ScalarType = ScalarTypeKind::SignedInteger; 567 break; 568 case 'h': 569 ElementBitwidth = 16; 570 ScalarType = ScalarTypeKind::Float; 571 break; 572 case 'f': 573 ElementBitwidth = 32; 574 ScalarType = ScalarTypeKind::Float; 575 break; 576 case 'd': 577 ElementBitwidth = 64; 578 ScalarType = ScalarTypeKind::Float; 579 break; 580 default: 581 PrintFatalError("Unhandled type code!"); 582 } 583 assert(ElementBitwidth != 0 && "Bad element bitwidth!"); 584 } 585 586 void RVVType::applyModifier(StringRef Transformer) { 587 if (Transformer.empty()) 588 return; 589 // Handle primitive type transformer 590 auto PType = Transformer.back(); 591 switch (PType) { 592 case 'e': 593 Scale = 0; 594 break; 595 case 'v': 596 Scale = LMUL.getScale(ElementBitwidth); 597 break; 598 case 'w': 599 ElementBitwidth *= 2; 600 LMUL *= 2; 601 Scale = LMUL.getScale(ElementBitwidth); 602 break; 603 case 'q': 604 ElementBitwidth *= 4; 605 LMUL *= 4; 606 Scale = LMUL.getScale(ElementBitwidth); 607 break; 608 case 'o': 609 ElementBitwidth *= 8; 610 LMUL *= 8; 611 Scale = LMUL.getScale(ElementBitwidth); 612 break; 613 case 'm': 614 ScalarType = ScalarTypeKind::Boolean; 615 Scale = LMUL.getScale(ElementBitwidth); 616 ElementBitwidth = 1; 617 break; 618 case '0': 619 ScalarType = ScalarTypeKind::Void; 620 break; 621 case 'z': 622 ScalarType = ScalarTypeKind::Size_t; 623 break; 624 case 't': 625 ScalarType = ScalarTypeKind::Ptrdiff_t; 626 break; 627 case 'u': 628 ScalarType = ScalarTypeKind::UnsignedLong; 629 break; 630 case 'l': 631 ScalarType = ScalarTypeKind::SignedLong; 632 break; 633 default: 634 PrintFatalError("Illegal primitive type transformers!"); 635 } 636 Transformer = Transformer.drop_back(); 637 638 // Extract and compute complex type transformer. It can only appear one time. 639 if (Transformer.startswith("(")) { 640 size_t Idx = Transformer.find(')'); 641 assert(Idx != StringRef::npos); 642 StringRef ComplexType = Transformer.slice(1, Idx); 643 Transformer = Transformer.drop_front(Idx + 1); 644 assert(Transformer.find('(') == StringRef::npos && 645 "Only allow one complex type transformer"); 646 647 auto UpdateAndCheckComplexProto = [&]() { 648 Scale = LMUL.getScale(ElementBitwidth); 649 const StringRef VectorPrototypes("vwqom"); 650 if (!VectorPrototypes.contains(PType)) 651 PrintFatalError("Complex type transformer only supports vector type!"); 652 if (Transformer.find_first_of("PCKWS") != StringRef::npos) 653 PrintFatalError( 654 "Illegal type transformer for Complex type transformer"); 655 }; 656 auto ComputeFixedLog2LMUL = 657 [&](StringRef Value, 658 std::function<bool(const int32_t &, const int32_t &)> Compare) { 659 int32_t Log2LMUL; 660 Value.getAsInteger(10, Log2LMUL); 661 if (!Compare(Log2LMUL, LMUL.Log2LMUL)) { 662 ScalarType = Invalid; 663 return false; 664 } 665 // Update new LMUL 666 LMUL = LMULType(Log2LMUL); 667 UpdateAndCheckComplexProto(); 668 return true; 669 }; 670 auto ComplexTT = ComplexType.split(":"); 671 if (ComplexTT.first == "Log2EEW") { 672 uint32_t Log2EEW; 673 ComplexTT.second.getAsInteger(10, Log2EEW); 674 // update new elmul = (eew/sew) * lmul 675 LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth)); 676 // update new eew 677 ElementBitwidth = 1 << Log2EEW; 678 ScalarType = ScalarTypeKind::SignedInteger; 679 UpdateAndCheckComplexProto(); 680 } else if (ComplexTT.first == "FixedSEW") { 681 uint32_t NewSEW; 682 ComplexTT.second.getAsInteger(10, NewSEW); 683 // Set invalid type if src and dst SEW are same. 684 if (ElementBitwidth == NewSEW) { 685 ScalarType = Invalid; 686 return; 687 } 688 // Update new SEW 689 ElementBitwidth = NewSEW; 690 UpdateAndCheckComplexProto(); 691 } else if (ComplexTT.first == "LFixedLog2LMUL") { 692 // New LMUL should be larger than old 693 if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>())) 694 return; 695 } else if (ComplexTT.first == "SFixedLog2LMUL") { 696 // New LMUL should be smaller than old 697 if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>())) 698 return; 699 } else { 700 PrintFatalError("Illegal complex type transformers!"); 701 } 702 } 703 704 // Compute the remain type transformers 705 for (char I : Transformer) { 706 switch (I) { 707 case 'P': 708 if (IsConstant) 709 PrintFatalError("'P' transformer cannot be used after 'C'"); 710 if (IsPointer) 711 PrintFatalError("'P' transformer cannot be used twice"); 712 IsPointer = true; 713 break; 714 case 'C': 715 if (IsConstant) 716 PrintFatalError("'C' transformer cannot be used twice"); 717 IsConstant = true; 718 break; 719 case 'K': 720 IsImmediate = true; 721 break; 722 case 'U': 723 ScalarType = ScalarTypeKind::UnsignedInteger; 724 break; 725 case 'I': 726 ScalarType = ScalarTypeKind::SignedInteger; 727 break; 728 case 'F': 729 ScalarType = ScalarTypeKind::Float; 730 break; 731 case 'S': 732 LMUL = LMULType(0); 733 // Update ElementBitwidth need to update Scale too. 734 Scale = LMUL.getScale(ElementBitwidth); 735 break; 736 default: 737 PrintFatalError("Illegal non-primitive type transformer!"); 738 } 739 } 740 } 741 742 //===----------------------------------------------------------------------===// 743 // RVVIntrinsic implementation 744 //===----------------------------------------------------------------------===// 745 RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix, 746 StringRef NewMangledName, StringRef IRName, 747 bool HasSideEffects, bool IsMask, 748 bool HasMaskedOffOperand, bool HasVL, 749 bool HasNoMaskedOverloaded, bool HasAutoDef, 750 StringRef ManualCodegen, const RVVTypes &OutInTypes, 751 const std::vector<int64_t> &NewIntrinsicTypes, 752 StringRef RequiredExtension) 753 : IRName(IRName), HasSideEffects(HasSideEffects), IsMask(IsMask), 754 HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), 755 HasNoMaskedOverloaded(HasNoMaskedOverloaded), HasAutoDef(HasAutoDef), 756 ManualCodegen(ManualCodegen.str()) { 757 758 // Init Name and MangledName 759 Name = NewName.str(); 760 if (NewMangledName.empty()) 761 MangledName = NewName.split("_").first.str(); 762 else 763 MangledName = NewMangledName.str(); 764 if (!Suffix.empty()) 765 Name += "_" + Suffix.str(); 766 if (IsMask) { 767 Name += "_m"; 768 } 769 // Init RISC-V extensions 770 for (const auto &T : OutInTypes) { 771 if (T->isFloatVector(16) || T->isFloat(16)) 772 RISCVExtensions |= RISCVExtension::Zfh; 773 else if (T->isFloatVector(32) || T->isFloat(32)) 774 RISCVExtensions |= RISCVExtension::F; 775 else if (T->isFloatVector(64) || T->isFloat(64)) 776 RISCVExtensions |= RISCVExtension::D; 777 } 778 if (RequiredExtension == "Zvamo") 779 RISCVExtensions |= RISCVExtension::Zvamo; 780 781 // Init OutputType and InputTypes 782 OutputType = OutInTypes[0]; 783 InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end()); 784 785 // IntrinsicTypes is nonmasked version index. Need to update it 786 // if there is maskedoff operand (It is always in first operand). 787 IntrinsicTypes = NewIntrinsicTypes; 788 if (IsMask && HasMaskedOffOperand) { 789 for (auto &I : IntrinsicTypes) { 790 if (I >= 0) 791 I += 1; 792 } 793 } 794 } 795 796 std::string RVVIntrinsic::getBuiltinTypeStr() const { 797 std::string S; 798 S += OutputType->getBuiltinStr(); 799 for (const auto &T : InputTypes) { 800 S += T->getBuiltinStr(); 801 } 802 return S; 803 } 804 805 void RVVIntrinsic::emitCodeGenSwitchBody(raw_ostream &OS) const { 806 if (!getIRName().empty()) 807 OS << " ID = Intrinsic::riscv_" + getIRName() + ";\n"; 808 if (hasManualCodegen()) { 809 OS << ManualCodegen; 810 OS << "break;\n"; 811 return; 812 } 813 814 if (isMask()) { 815 if (hasVL()) { 816 OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n"; 817 } else { 818 OS << " std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n"; 819 } 820 } 821 822 OS << " IntrinsicTypes = {"; 823 ListSeparator LS; 824 for (const auto &Idx : IntrinsicTypes) { 825 if (Idx == -1) 826 OS << LS << "ResultType"; 827 else 828 OS << LS << "Ops[" << Idx << "]->getType()"; 829 } 830 831 // VL could be i64 or i32, need to encode it in IntrinsicTypes. VL is 832 // always last operand. 833 if (hasVL()) 834 OS << ", Ops.back()->getType()"; 835 OS << "};\n"; 836 OS << " break;\n"; 837 } 838 839 void RVVIntrinsic::emitIntrinsicMacro(raw_ostream &OS) const { 840 OS << "#define " << getName() << "("; 841 if (!InputTypes.empty()) { 842 ListSeparator LS; 843 for (unsigned i = 0, e = InputTypes.size(); i != e; ++i) 844 OS << LS << "op" << i; 845 } 846 OS << ") \\\n"; 847 OS << "__builtin_rvv_" << getName() << "("; 848 if (!InputTypes.empty()) { 849 ListSeparator LS; 850 for (unsigned i = 0, e = InputTypes.size(); i != e; ++i) 851 OS << LS << "(" << InputTypes[i]->getTypeStr() << ")(op" << i << ")"; 852 } 853 OS << ")\n"; 854 } 855 856 void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const { 857 OS << "__attribute__((clang_builtin_alias("; 858 OS << "__builtin_rvv_" << getName() << ")))\n"; 859 OS << OutputType->getTypeStr() << " " << getMangledName() << "("; 860 // Emit function arguments 861 if (!InputTypes.empty()) { 862 ListSeparator LS; 863 for (unsigned i = 0; i < InputTypes.size(); ++i) 864 OS << LS << InputTypes[i]->getTypeStr() << " op" << i; 865 } 866 OS << ");\n\n"; 867 } 868 869 //===----------------------------------------------------------------------===// 870 // RVVEmitter implementation 871 //===----------------------------------------------------------------------===// 872 void RVVEmitter::createHeader(raw_ostream &OS) { 873 874 OS << "/*===---- riscv_vector.h - RISC-V V-extension RVVIntrinsics " 875 "-------------------===\n" 876 " *\n" 877 " *\n" 878 " * Part of the LLVM Project, under the Apache License v2.0 with LLVM " 879 "Exceptions.\n" 880 " * See https://llvm.org/LICENSE.txt for license information.\n" 881 " * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n" 882 " *\n" 883 " *===-----------------------------------------------------------------" 884 "------===\n" 885 " */\n\n"; 886 887 OS << "#ifndef __RISCV_VECTOR_H\n"; 888 OS << "#define __RISCV_VECTOR_H\n\n"; 889 890 OS << "#include <stdint.h>\n"; 891 OS << "#include <stddef.h>\n\n"; 892 893 OS << "#ifndef __riscv_vector\n"; 894 OS << "#error \"Vector intrinsics require the vector extension.\"\n"; 895 OS << "#endif\n\n"; 896 897 OS << "#ifdef __cplusplus\n"; 898 OS << "extern \"C\" {\n"; 899 OS << "#endif\n\n"; 900 901 std::vector<std::unique_ptr<RVVIntrinsic>> Defs; 902 createRVVIntrinsics(Defs); 903 904 // Print header code 905 if (!HeaderCode.empty()) { 906 OS << HeaderCode; 907 } 908 909 auto printType = [&](auto T) { 910 OS << "typedef " << T->getClangBuiltinStr() << " " << T->getTypeStr() 911 << ";\n"; 912 }; 913 914 constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3}; 915 // Print RVV boolean types. 916 for (int Log2LMUL : Log2LMULs) { 917 auto T = computeType('c', Log2LMUL, "m"); 918 if (T.hasValue()) 919 printType(T.getValue()); 920 } 921 // Print RVV int/float types. 922 for (char I : StringRef("csil")) { 923 for (int Log2LMUL : Log2LMULs) { 924 auto T = computeType(I, Log2LMUL, "v"); 925 if (T.hasValue()) { 926 printType(T.getValue()); 927 auto UT = computeType(I, Log2LMUL, "Uv"); 928 printType(UT.getValue()); 929 } 930 } 931 } 932 OS << "#if defined(__riscv_zfh)\n"; 933 for (int Log2LMUL : Log2LMULs) { 934 auto T = computeType('h', Log2LMUL, "v"); 935 if (T.hasValue()) 936 printType(T.getValue()); 937 } 938 OS << "#endif\n"; 939 940 OS << "#if defined(__riscv_f)\n"; 941 for (int Log2LMUL : Log2LMULs) { 942 auto T = computeType('f', Log2LMUL, "v"); 943 if (T.hasValue()) 944 printType(T.getValue()); 945 } 946 OS << "#endif\n"; 947 948 OS << "#if defined(__riscv_d)\n"; 949 for (int Log2LMUL : Log2LMULs) { 950 auto T = computeType('d', Log2LMUL, "v"); 951 if (T.hasValue()) 952 printType(T.getValue()); 953 } 954 OS << "#endif\n\n"; 955 956 // The same extension include in the same arch guard marco. 957 std::stable_sort(Defs.begin(), Defs.end(), 958 [](const std::unique_ptr<RVVIntrinsic> &A, 959 const std::unique_ptr<RVVIntrinsic> &B) { 960 return A->getRISCVExtensions() < B->getRISCVExtensions(); 961 }); 962 963 // Print intrinsic functions with macro 964 emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) { 965 Inst.emitIntrinsicMacro(OS); 966 }); 967 968 OS << "#define __riscv_v_intrinsic_overloading 1\n"; 969 970 // Print Overloaded APIs 971 OS << "#define __rvv_overloaded static inline " 972 "__attribute__((__always_inline__, __nodebug__, __overloadable__))\n"; 973 974 emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) { 975 if (!Inst.isMask() && !Inst.hasNoMaskedOverloaded()) 976 return; 977 OS << "__rvv_overloaded "; 978 Inst.emitMangledFuncDef(OS); 979 }); 980 981 OS << "\n#ifdef __cplusplus\n"; 982 OS << "}\n"; 983 OS << "#endif // __riscv_vector\n"; 984 OS << "#endif // __RISCV_VECTOR_H\n"; 985 } 986 987 void RVVEmitter::createBuiltins(raw_ostream &OS) { 988 std::vector<std::unique_ptr<RVVIntrinsic>> Defs; 989 createRVVIntrinsics(Defs); 990 991 OS << "#if defined(TARGET_BUILTIN) && !defined(RISCVV_BUILTIN)\n"; 992 OS << "#define RISCVV_BUILTIN(ID, TYPE, ATTRS) TARGET_BUILTIN(ID, TYPE, " 993 "ATTRS, \"experimental-v\")\n"; 994 OS << "#endif\n"; 995 for (auto &Def : Defs) { 996 OS << "RISCVV_BUILTIN(__builtin_rvv_" << Def->getName() << ",\"" 997 << Def->getBuiltinTypeStr() << "\", "; 998 if (!Def->hasSideEffects()) 999 OS << "\"n\")\n"; 1000 else 1001 OS << "\"\")\n"; 1002 } 1003 OS << "#undef RISCVV_BUILTIN\n"; 1004 } 1005 1006 void RVVEmitter::createCodeGen(raw_ostream &OS) { 1007 std::vector<std::unique_ptr<RVVIntrinsic>> Defs; 1008 createRVVIntrinsics(Defs); 1009 // IR name could be empty, use the stable sort preserves the relative order. 1010 std::stable_sort(Defs.begin(), Defs.end(), 1011 [](const std::unique_ptr<RVVIntrinsic> &A, 1012 const std::unique_ptr<RVVIntrinsic> &B) { 1013 return A->getIRName() < B->getIRName(); 1014 }); 1015 // Print switch body when the ir name or ManualCodegen changes from previous 1016 // iteration. 1017 RVVIntrinsic *PrevDef = Defs.begin()->get(); 1018 for (auto &Def : Defs) { 1019 StringRef CurIRName = Def->getIRName(); 1020 if (CurIRName != PrevDef->getIRName() || 1021 (Def->getManualCodegen() != PrevDef->getManualCodegen())) { 1022 PrevDef->emitCodeGenSwitchBody(OS); 1023 } 1024 PrevDef = Def.get(); 1025 OS << "case RISCV::BI__builtin_rvv_" << Def->getName() << ":\n"; 1026 } 1027 Defs.back()->emitCodeGenSwitchBody(OS); 1028 OS << "\n"; 1029 } 1030 1031 void RVVEmitter::parsePrototypes(StringRef Prototypes, 1032 std::function<void(StringRef)> Handler) { 1033 const StringRef Primaries("evwqom0ztul"); 1034 while (!Prototypes.empty()) { 1035 size_t Idx = 0; 1036 // Skip over complex prototype because it could contain primitive type 1037 // character. 1038 if (Prototypes[0] == '(') 1039 Idx = Prototypes.find_first_of(')'); 1040 Idx = Prototypes.find_first_of(Primaries, Idx); 1041 assert(Idx != StringRef::npos); 1042 Handler(Prototypes.slice(0, Idx + 1)); 1043 Prototypes = Prototypes.drop_front(Idx + 1); 1044 } 1045 } 1046 1047 std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL, 1048 StringRef Prototypes) { 1049 SmallVector<std::string> SuffixStrs; 1050 parsePrototypes(Prototypes, [&](StringRef Proto) { 1051 auto T = computeType(Type, Log2LMUL, Proto); 1052 SuffixStrs.push_back(T.getValue()->getShortStr()); 1053 }); 1054 return join(SuffixStrs, "_"); 1055 } 1056 1057 void RVVEmitter::createRVVIntrinsics( 1058 std::vector<std::unique_ptr<RVVIntrinsic>> &Out) { 1059 std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin"); 1060 for (auto *R : RV) { 1061 StringRef Name = R->getValueAsString("Name"); 1062 StringRef SuffixProto = R->getValueAsString("Suffix"); 1063 StringRef MangledName = R->getValueAsString("MangledName"); 1064 StringRef Prototypes = R->getValueAsString("Prototype"); 1065 StringRef TypeRange = R->getValueAsString("TypeRange"); 1066 bool HasMask = R->getValueAsBit("HasMask"); 1067 bool HasMaskedOffOperand = R->getValueAsBit("HasMaskedOffOperand"); 1068 bool HasVL = R->getValueAsBit("HasVL"); 1069 bool HasNoMaskedOverloaded = R->getValueAsBit("HasNoMaskedOverloaded"); 1070 bool HasSideEffects = R->getValueAsBit("HasSideEffects"); 1071 std::vector<int64_t> Log2LMULList = R->getValueAsListOfInts("Log2LMUL"); 1072 StringRef ManualCodegen = R->getValueAsString("ManualCodegen"); 1073 StringRef ManualCodegenMask = R->getValueAsString("ManualCodegenMask"); 1074 std::vector<int64_t> IntrinsicTypes = 1075 R->getValueAsListOfInts("IntrinsicTypes"); 1076 StringRef RequiredExtension = R->getValueAsString("RequiredExtension"); 1077 StringRef IRName = R->getValueAsString("IRName"); 1078 StringRef IRNameMask = R->getValueAsString("IRNameMask"); 1079 1080 StringRef HeaderCodeStr = R->getValueAsString("HeaderCode"); 1081 bool HasAutoDef = HeaderCodeStr.empty(); 1082 if (!HeaderCodeStr.empty()) { 1083 HeaderCode += HeaderCodeStr.str(); 1084 } 1085 // Parse prototype and create a list of primitive type with transformers 1086 // (operand) in ProtoSeq. ProtoSeq[0] is output operand. 1087 SmallVector<std::string> ProtoSeq; 1088 parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) { 1089 ProtoSeq.push_back(Proto.str()); 1090 }); 1091 1092 // Compute Builtin types 1093 SmallVector<std::string> ProtoMaskSeq = ProtoSeq; 1094 if (HasMask) { 1095 // If HasMaskedOffOperand, insert result type as first input operand. 1096 if (HasMaskedOffOperand) 1097 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, ProtoSeq[0]); 1098 // If HasMask, insert 'm' as first input operand. 1099 ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m"); 1100 } 1101 // If HasVL, append 'z' to last operand 1102 if (HasVL) { 1103 ProtoSeq.push_back("z"); 1104 ProtoMaskSeq.push_back("z"); 1105 } 1106 1107 // Create Intrinsics for each type and LMUL. 1108 for (char I : TypeRange) { 1109 for (int Log2LMUL : Log2LMULList) { 1110 Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, ProtoSeq); 1111 // Ignored to create new intrinsic if there are any illegal types. 1112 if (!Types.hasValue()) 1113 continue; 1114 1115 auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto); 1116 // Create a non-mask intrinsic 1117 Out.push_back(std::make_unique<RVVIntrinsic>( 1118 Name, SuffixStr, MangledName, IRName, HasSideEffects, 1119 /*IsMask=*/false, /*HasMaskedOffOperand=*/false, HasVL, 1120 HasNoMaskedOverloaded, HasAutoDef, ManualCodegen, Types.getValue(), 1121 IntrinsicTypes, RequiredExtension)); 1122 if (HasMask) { 1123 // Create a mask intrinsic 1124 Optional<RVVTypes> MaskTypes = 1125 computeTypes(I, Log2LMUL, ProtoMaskSeq); 1126 Out.push_back(std::make_unique<RVVIntrinsic>( 1127 Name, SuffixStr, MangledName, IRNameMask, HasSideEffects, 1128 /*IsMask=*/true, HasMaskedOffOperand, HasVL, 1129 HasNoMaskedOverloaded, HasAutoDef, ManualCodegenMask, 1130 MaskTypes.getValue(), IntrinsicTypes, RequiredExtension)); 1131 } 1132 } // end for Log2LMULList 1133 } // end for TypeRange 1134 } 1135 } 1136 1137 Optional<RVVTypes> 1138 RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, 1139 ArrayRef<std::string> PrototypeSeq) { 1140 RVVTypes Types; 1141 for (const std::string &Proto : PrototypeSeq) { 1142 auto T = computeType(BT, Log2LMUL, Proto); 1143 if (!T.hasValue()) 1144 return llvm::None; 1145 // Record legal type index 1146 Types.push_back(T.getValue()); 1147 } 1148 return Types; 1149 } 1150 1151 Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL, 1152 StringRef Proto) { 1153 std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str(); 1154 // Search first 1155 auto It = LegalTypes.find(Idx); 1156 if (It != LegalTypes.end()) 1157 return &(It->second); 1158 if (IllegalTypes.count(Idx)) 1159 return llvm::None; 1160 // Compute type and record the result. 1161 RVVType T(BT, Log2LMUL, Proto); 1162 if (T.isValid()) { 1163 // Record legal type index and value. 1164 LegalTypes.insert({Idx, T}); 1165 return &(LegalTypes[Idx]); 1166 } 1167 // Record illegal type index. 1168 IllegalTypes.insert(Idx); 1169 return llvm::None; 1170 } 1171 1172 void RVVEmitter::emitArchMacroAndBody( 1173 std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS, 1174 std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) { 1175 uint8_t PrevExt = (*Defs.begin())->getRISCVExtensions(); 1176 bool NeedEndif = emitExtDefStr(PrevExt, OS); 1177 for (auto &Def : Defs) { 1178 uint8_t CurExt = Def->getRISCVExtensions(); 1179 if (CurExt != PrevExt) { 1180 if (NeedEndif) 1181 OS << "#endif\n\n"; 1182 NeedEndif = emitExtDefStr(CurExt, OS); 1183 PrevExt = CurExt; 1184 } 1185 if (Def->hasAutoDef()) 1186 PrintBody(OS, *Def); 1187 } 1188 if (NeedEndif) 1189 OS << "#endif\n\n"; 1190 } 1191 1192 bool RVVEmitter::emitExtDefStr(uint8_t Extents, raw_ostream &OS) { 1193 if (Extents == RISCVExtension::Basic) 1194 return false; 1195 OS << "#if "; 1196 ListSeparator LS(" && "); 1197 if (Extents & RISCVExtension::F) 1198 OS << LS << "defined(__riscv_f)"; 1199 if (Extents & RISCVExtension::D) 1200 OS << LS << "defined(__riscv_d)"; 1201 if (Extents & RISCVExtension::Zfh) 1202 OS << LS << "defined(__riscv_zfh)"; 1203 if (Extents & RISCVExtension::Zvamo) 1204 OS << LS << "defined(__riscv_zvamo)"; 1205 OS << "\n"; 1206 return true; 1207 } 1208 1209 namespace clang { 1210 void EmitRVVHeader(RecordKeeper &Records, raw_ostream &OS) { 1211 RVVEmitter(Records).createHeader(OS); 1212 } 1213 1214 void EmitRVVBuiltins(RecordKeeper &Records, raw_ostream &OS) { 1215 RVVEmitter(Records).createBuiltins(OS); 1216 } 1217 1218 void EmitRVVBuiltinCG(RecordKeeper &Records, raw_ostream &OS) { 1219 RVVEmitter(Records).createCodeGen(OS); 1220 } 1221 1222 } // End namespace clang 1223