1 //==- SemaRISCVVectorLookup.cpp - Name Lookup for RISC-V Vector Intrinsic -==// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements name lookup for RISC-V vector intrinsic. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "clang/AST/ASTContext.h" 14 #include "clang/AST/Decl.h" 15 #include "clang/Basic/Builtins.h" 16 #include "clang/Basic/TargetInfo.h" 17 #include "clang/Lex/Preprocessor.h" 18 #include "clang/Sema/Lookup.h" 19 #include "clang/Sema/RISCVIntrinsicManager.h" 20 #include "clang/Sema/Sema.h" 21 #include "clang/Support/RISCVVIntrinsicUtils.h" 22 #include "llvm/ADT/SmallVector.h" 23 #include <optional> 24 #include <string> 25 #include <vector> 26 27 using namespace llvm; 28 using namespace clang; 29 using namespace clang::RISCV; 30 31 namespace { 32 33 // Function definition of a RVV intrinsic. 34 struct RVVIntrinsicDef { 35 /// Full function name with suffix, e.g. vadd_vv_i32m1. 36 std::string Name; 37 38 /// Overloaded function name, e.g. vadd. 39 std::string OverloadName; 40 41 /// Mapping to which clang built-in function, e.g. __builtin_rvv_vadd. 42 std::string BuiltinName; 43 44 /// Function signature, first element is return type. 45 RVVTypes Signature; 46 }; 47 48 struct RVVOverloadIntrinsicDef { 49 // Indexes of RISCVIntrinsicManagerImpl::IntrinsicList. 50 SmallVector<size_t, 8> Indexes; 51 }; 52 53 } // namespace 54 55 static const PrototypeDescriptor RVVSignatureTable[] = { 56 #define DECL_SIGNATURE_TABLE 57 #include "clang/Basic/riscv_vector_builtin_sema.inc" 58 #undef DECL_SIGNATURE_TABLE 59 }; 60 61 static const RVVIntrinsicRecord RVVIntrinsicRecords[] = { 62 #define DECL_INTRINSIC_RECORDS 63 #include "clang/Basic/riscv_vector_builtin_sema.inc" 64 #undef DECL_INTRINSIC_RECORDS 65 }; 66 67 // Get subsequence of signature table. 68 static ArrayRef<PrototypeDescriptor> ProtoSeq2ArrayRef(uint16_t Index, 69 uint8_t Length) { 70 return ArrayRef(&RVVSignatureTable[Index], Length); 71 } 72 73 static QualType RVVType2Qual(ASTContext &Context, const RVVType *Type) { 74 QualType QT; 75 switch (Type->getScalarType()) { 76 case ScalarTypeKind::Void: 77 QT = Context.VoidTy; 78 break; 79 case ScalarTypeKind::Size_t: 80 QT = Context.getSizeType(); 81 break; 82 case ScalarTypeKind::Ptrdiff_t: 83 QT = Context.getPointerDiffType(); 84 break; 85 case ScalarTypeKind::UnsignedLong: 86 QT = Context.UnsignedLongTy; 87 break; 88 case ScalarTypeKind::SignedLong: 89 QT = Context.LongTy; 90 break; 91 case ScalarTypeKind::Boolean: 92 QT = Context.BoolTy; 93 break; 94 case ScalarTypeKind::SignedInteger: 95 QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), true); 96 break; 97 case ScalarTypeKind::UnsignedInteger: 98 QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false); 99 break; 100 case ScalarTypeKind::Float: 101 switch (Type->getElementBitwidth()) { 102 case 64: 103 QT = Context.DoubleTy; 104 break; 105 case 32: 106 QT = Context.FloatTy; 107 break; 108 case 16: 109 QT = Context.Float16Ty; 110 break; 111 default: 112 llvm_unreachable("Unsupported floating point width."); 113 } 114 break; 115 case Invalid: 116 llvm_unreachable("Unhandled type."); 117 } 118 if (Type->isVector()) 119 QT = Context.getScalableVectorType(QT, *Type->getScale()); 120 121 if (Type->isConstant()) 122 QT = Context.getConstType(QT); 123 124 // Transform the type to a pointer as the last step, if necessary. 125 if (Type->isPointer()) 126 QT = Context.getPointerType(QT); 127 128 return QT; 129 } 130 131 namespace { 132 class RISCVIntrinsicManagerImpl : public sema::RISCVIntrinsicManager { 133 private: 134 Sema &S; 135 ASTContext &Context; 136 RVVTypeCache TypeCache; 137 138 // List of all RVV intrinsic. 139 std::vector<RVVIntrinsicDef> IntrinsicList; 140 // Mapping function name to index of IntrinsicList. 141 StringMap<size_t> Intrinsics; 142 // Mapping function name to RVVOverloadIntrinsicDef. 143 StringMap<RVVOverloadIntrinsicDef> OverloadIntrinsics; 144 145 // Create IntrinsicList 146 void InitIntrinsicList(); 147 148 // Create RVVIntrinsicDef. 149 void InitRVVIntrinsic(const RVVIntrinsicRecord &Record, StringRef SuffixStr, 150 StringRef OverloadedSuffixStr, bool IsMask, 151 RVVTypes &Types, bool HasPolicy, Policy PolicyAttrs); 152 153 // Create FunctionDecl for a vector intrinsic. 154 void CreateRVVIntrinsicDecl(LookupResult &LR, IdentifierInfo *II, 155 Preprocessor &PP, unsigned Index, 156 bool IsOverload); 157 158 public: 159 RISCVIntrinsicManagerImpl(clang::Sema &S) : S(S), Context(S.Context) { 160 InitIntrinsicList(); 161 } 162 163 // Create RISC-V vector intrinsic and insert into symbol table if found, and 164 // return true, otherwise return false. 165 bool CreateIntrinsicIfFound(LookupResult &LR, IdentifierInfo *II, 166 Preprocessor &PP) override; 167 }; 168 } // namespace 169 170 void RISCVIntrinsicManagerImpl::InitIntrinsicList() { 171 const TargetInfo &TI = Context.getTargetInfo(); 172 bool HasVectorFloat32 = TI.hasFeature("zve32f"); 173 bool HasVectorFloat64 = TI.hasFeature("zve64d"); 174 bool HasZvfh = TI.hasFeature("experimental-zvfh"); 175 bool HasRV64 = TI.hasFeature("64bit"); 176 bool HasFullMultiply = TI.hasFeature("v"); 177 178 // Construction of RVVIntrinsicRecords need to sync with createRVVIntrinsics 179 // in RISCVVEmitter.cpp. 180 for (auto &Record : RVVIntrinsicRecords) { 181 // Create Intrinsics for each type and LMUL. 182 BasicType BaseType = BasicType::Unknown; 183 ArrayRef<PrototypeDescriptor> BasicProtoSeq = 184 ProtoSeq2ArrayRef(Record.PrototypeIndex, Record.PrototypeLength); 185 ArrayRef<PrototypeDescriptor> SuffixProto = 186 ProtoSeq2ArrayRef(Record.SuffixIndex, Record.SuffixLength); 187 ArrayRef<PrototypeDescriptor> OverloadedSuffixProto = ProtoSeq2ArrayRef( 188 Record.OverloadedSuffixIndex, Record.OverloadedSuffixSize); 189 190 PolicyScheme UnMaskedPolicyScheme = 191 static_cast<PolicyScheme>(Record.UnMaskedPolicyScheme); 192 PolicyScheme MaskedPolicyScheme = 193 static_cast<PolicyScheme>(Record.MaskedPolicyScheme); 194 195 const Policy DefaultPolicy; 196 197 llvm::SmallVector<PrototypeDescriptor> ProtoSeq = 198 RVVIntrinsic::computeBuiltinTypes(BasicProtoSeq, /*IsMasked=*/false, 199 /*HasMaskedOffOperand=*/false, 200 Record.HasVL, Record.NF, 201 UnMaskedPolicyScheme, DefaultPolicy); 202 203 llvm::SmallVector<PrototypeDescriptor> ProtoMaskSeq = 204 RVVIntrinsic::computeBuiltinTypes( 205 BasicProtoSeq, /*IsMasked=*/true, Record.HasMaskedOffOperand, 206 Record.HasVL, Record.NF, MaskedPolicyScheme, DefaultPolicy); 207 208 bool UnMaskedHasPolicy = UnMaskedPolicyScheme != PolicyScheme::SchemeNone; 209 bool MaskedHasPolicy = MaskedPolicyScheme != PolicyScheme::SchemeNone; 210 SmallVector<Policy> SupportedUnMaskedPolicies = 211 RVVIntrinsic::getSupportedUnMaskedPolicies(); 212 SmallVector<Policy> SupportedMaskedPolicies = 213 RVVIntrinsic::getSupportedMaskedPolicies(Record.HasTailPolicy, 214 Record.HasMaskPolicy); 215 216 for (unsigned int TypeRangeMaskShift = 0; 217 TypeRangeMaskShift <= static_cast<unsigned int>(BasicType::MaxOffset); 218 ++TypeRangeMaskShift) { 219 unsigned int BaseTypeI = 1 << TypeRangeMaskShift; 220 BaseType = static_cast<BasicType>(BaseTypeI); 221 222 if ((BaseTypeI & Record.TypeRangeMask) != BaseTypeI) 223 continue; 224 225 // Check requirement. 226 if (BaseType == BasicType::Float16 && !HasZvfh) 227 continue; 228 229 if (BaseType == BasicType::Float32 && !HasVectorFloat32) 230 continue; 231 232 if (BaseType == BasicType::Float64 && !HasVectorFloat64) 233 continue; 234 235 if (((Record.RequiredExtensions & RVV_REQ_RV64) == RVV_REQ_RV64) && 236 !HasRV64) 237 continue; 238 239 if ((BaseType == BasicType::Int64) && 240 ((Record.RequiredExtensions & RVV_REQ_FullMultiply) == 241 RVV_REQ_FullMultiply) && 242 !HasFullMultiply) 243 continue; 244 245 // Expanded with different LMUL. 246 for (int Log2LMUL = -3; Log2LMUL <= 3; Log2LMUL++) { 247 if (!(Record.Log2LMULMask & (1 << (Log2LMUL + 3)))) 248 continue; 249 250 std::optional<RVVTypes> Types = 251 TypeCache.computeTypes(BaseType, Log2LMUL, Record.NF, ProtoSeq); 252 253 // Ignored to create new intrinsic if there are any illegal types. 254 if (!Types.has_value()) 255 continue; 256 257 std::string SuffixStr = RVVIntrinsic::getSuffixStr( 258 TypeCache, BaseType, Log2LMUL, SuffixProto); 259 std::string OverloadedSuffixStr = RVVIntrinsic::getSuffixStr( 260 TypeCache, BaseType, Log2LMUL, OverloadedSuffixProto); 261 262 // Create non-masked intrinsic. 263 InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, false, *Types, 264 UnMaskedHasPolicy, DefaultPolicy); 265 266 // Create non-masked policy intrinsic. 267 if (Record.UnMaskedPolicyScheme != PolicyScheme::SchemeNone) { 268 for (auto P : SupportedUnMaskedPolicies) { 269 llvm::SmallVector<PrototypeDescriptor> PolicyPrototype = 270 RVVIntrinsic::computeBuiltinTypes( 271 BasicProtoSeq, /*IsMasked=*/false, 272 /*HasMaskedOffOperand=*/false, Record.HasVL, Record.NF, 273 UnMaskedPolicyScheme, P); 274 std::optional<RVVTypes> PolicyTypes = TypeCache.computeTypes( 275 BaseType, Log2LMUL, Record.NF, PolicyPrototype); 276 InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, 277 /*IsMask=*/false, *PolicyTypes, UnMaskedHasPolicy, 278 P); 279 } 280 } 281 if (!Record.HasMasked) 282 continue; 283 // Create masked intrinsic. 284 std::optional<RVVTypes> MaskTypes = 285 TypeCache.computeTypes(BaseType, Log2LMUL, Record.NF, ProtoMaskSeq); 286 InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, true, 287 *MaskTypes, MaskedHasPolicy, DefaultPolicy); 288 if (Record.MaskedPolicyScheme == PolicyScheme::SchemeNone) 289 continue; 290 // Create masked policy intrinsic. 291 for (auto P : SupportedMaskedPolicies) { 292 llvm::SmallVector<PrototypeDescriptor> PolicyPrototype = 293 RVVIntrinsic::computeBuiltinTypes( 294 BasicProtoSeq, /*IsMasked=*/true, Record.HasMaskedOffOperand, 295 Record.HasVL, Record.NF, MaskedPolicyScheme, P); 296 std::optional<RVVTypes> PolicyTypes = TypeCache.computeTypes( 297 BaseType, Log2LMUL, Record.NF, PolicyPrototype); 298 InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, 299 /*IsMask=*/true, *PolicyTypes, MaskedHasPolicy, P); 300 } 301 } // End for different LMUL 302 } // End for different TypeRange 303 } 304 } 305 306 // Compute name and signatures for intrinsic with practical types. 307 void RISCVIntrinsicManagerImpl::InitRVVIntrinsic( 308 const RVVIntrinsicRecord &Record, StringRef SuffixStr, 309 StringRef OverloadedSuffixStr, bool IsMasked, RVVTypes &Signature, 310 bool HasPolicy, Policy PolicyAttrs) { 311 // Function name, e.g. vadd_vv_i32m1. 312 std::string Name = Record.Name; 313 if (!SuffixStr.empty()) 314 Name += "_" + SuffixStr.str(); 315 316 // Overloaded function name, e.g. vadd. 317 std::string OverloadedName; 318 if (!Record.OverloadedName) 319 OverloadedName = StringRef(Record.Name).split("_").first.str(); 320 else 321 OverloadedName = Record.OverloadedName; 322 if (!OverloadedSuffixStr.empty()) 323 OverloadedName += "_" + OverloadedSuffixStr.str(); 324 325 // clang built-in function name, e.g. __builtin_rvv_vadd. 326 std::string BuiltinName = "__builtin_rvv_" + std::string(Record.Name); 327 328 RVVIntrinsic::updateNamesAndPolicy(IsMasked, HasPolicy, Name, BuiltinName, 329 OverloadedName, PolicyAttrs); 330 331 // Put into IntrinsicList. 332 size_t Index = IntrinsicList.size(); 333 IntrinsicList.push_back({Name, OverloadedName, BuiltinName, Signature}); 334 335 // Creating mapping to Intrinsics. 336 Intrinsics.insert({Name, Index}); 337 338 // Get the RVVOverloadIntrinsicDef. 339 RVVOverloadIntrinsicDef &OverloadIntrinsicDef = 340 OverloadIntrinsics[OverloadedName]; 341 342 // And added the index. 343 OverloadIntrinsicDef.Indexes.push_back(Index); 344 } 345 346 void RISCVIntrinsicManagerImpl::CreateRVVIntrinsicDecl(LookupResult &LR, 347 IdentifierInfo *II, 348 Preprocessor &PP, 349 unsigned Index, 350 bool IsOverload) { 351 ASTContext &Context = S.Context; 352 RVVIntrinsicDef &IDef = IntrinsicList[Index]; 353 RVVTypes Sigs = IDef.Signature; 354 size_t SigLength = Sigs.size(); 355 RVVType *ReturnType = Sigs[0]; 356 QualType RetType = RVVType2Qual(Context, ReturnType); 357 SmallVector<QualType, 8> ArgTypes; 358 QualType BuiltinFuncType; 359 360 // Skip return type, and convert RVVType to QualType for arguments. 361 for (size_t i = 1; i < SigLength; ++i) 362 ArgTypes.push_back(RVVType2Qual(Context, Sigs[i])); 363 364 FunctionProtoType::ExtProtoInfo PI( 365 Context.getDefaultCallingConvention(false, false, true)); 366 367 PI.Variadic = false; 368 369 SourceLocation Loc = LR.getNameLoc(); 370 BuiltinFuncType = Context.getFunctionType(RetType, ArgTypes, PI); 371 DeclContext *Parent = Context.getTranslationUnitDecl(); 372 373 FunctionDecl *RVVIntrinsicDecl = FunctionDecl::Create( 374 Context, Parent, Loc, Loc, II, BuiltinFuncType, /*TInfo=*/nullptr, 375 SC_Extern, S.getCurFPFeatures().isFPConstrained(), 376 /*isInlineSpecified*/ false, 377 /*hasWrittenPrototype*/ true); 378 379 // Create Decl objects for each parameter, adding them to the 380 // FunctionDecl. 381 const auto *FP = cast<FunctionProtoType>(BuiltinFuncType); 382 SmallVector<ParmVarDecl *, 8> ParmList; 383 for (unsigned IParm = 0, E = FP->getNumParams(); IParm != E; ++IParm) { 384 ParmVarDecl *Parm = 385 ParmVarDecl::Create(Context, RVVIntrinsicDecl, Loc, Loc, nullptr, 386 FP->getParamType(IParm), nullptr, SC_None, nullptr); 387 Parm->setScopeInfo(0, IParm); 388 ParmList.push_back(Parm); 389 } 390 RVVIntrinsicDecl->setParams(ParmList); 391 392 // Add function attributes. 393 if (IsOverload) 394 RVVIntrinsicDecl->addAttr(OverloadableAttr::CreateImplicit(Context)); 395 396 // Setup alias to __builtin_rvv_* 397 IdentifierInfo &IntrinsicII = PP.getIdentifierTable().get(IDef.BuiltinName); 398 RVVIntrinsicDecl->addAttr( 399 BuiltinAliasAttr::CreateImplicit(S.Context, &IntrinsicII)); 400 401 // Add to symbol table. 402 LR.addDecl(RVVIntrinsicDecl); 403 } 404 405 bool RISCVIntrinsicManagerImpl::CreateIntrinsicIfFound(LookupResult &LR, 406 IdentifierInfo *II, 407 Preprocessor &PP) { 408 StringRef Name = II->getName(); 409 410 // Lookup the function name from the overload intrinsics first. 411 auto OvIItr = OverloadIntrinsics.find(Name); 412 if (OvIItr != OverloadIntrinsics.end()) { 413 const RVVOverloadIntrinsicDef &OvIntrinsicDef = OvIItr->second; 414 for (auto Index : OvIntrinsicDef.Indexes) 415 CreateRVVIntrinsicDecl(LR, II, PP, Index, 416 /*IsOverload*/ true); 417 418 // If we added overloads, need to resolve the lookup result. 419 LR.resolveKind(); 420 return true; 421 } 422 423 // Lookup the function name from the intrinsics. 424 auto Itr = Intrinsics.find(Name); 425 if (Itr != Intrinsics.end()) { 426 CreateRVVIntrinsicDecl(LR, II, PP, Itr->second, 427 /*IsOverload*/ false); 428 return true; 429 } 430 431 // It's not an RVV intrinsics. 432 return false; 433 } 434 435 namespace clang { 436 std::unique_ptr<clang::sema::RISCVIntrinsicManager> 437 CreateRISCVIntrinsicManager(Sema &S) { 438 return std::make_unique<RISCVIntrinsicManagerImpl>(S); 439 } 440 } // namespace clang 441