1 //===- VPIntrinsicTest.cpp - VPIntrinsic unit tests ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/ADT/SmallVector.h" 10 #include "llvm/AsmParser/Parser.h" 11 #include "llvm/CodeGen/ISDOpcodes.h" 12 #include "llvm/IR/Constants.h" 13 #include "llvm/IR/IRBuilder.h" 14 #include "llvm/IR/IntrinsicInst.h" 15 #include "llvm/IR/LLVMContext.h" 16 #include "llvm/IR/Module.h" 17 #include "llvm/IR/Verifier.h" 18 #include "llvm/Support/SourceMgr.h" 19 #include "gtest/gtest.h" 20 #include <optional> 21 #include <sstream> 22 23 using namespace llvm; 24 25 namespace { 26 27 static const char *ReductionIntOpcodes[] = { 28 "add", "mul", "and", "or", "xor", "smin", "smax", "umin", "umax"}; 29 30 static const char *ReductionFPOpcodes[] = {"fadd", "fmul", "fmin", 31 "fmax", "fminimum", "fmaximum"}; 32 33 class VPIntrinsicTest : public testing::Test { 34 protected: 35 LLVMContext Context; 36 37 VPIntrinsicTest() : Context() {} 38 39 LLVMContext C; 40 SMDiagnostic Err; 41 42 std::unique_ptr<Module> createVPDeclarationModule() { 43 const char *BinaryIntOpcodes[] = {"add", "sub", "mul", "sdiv", "srem", 44 "udiv", "urem", "and", "xor", "or", 45 "ashr", "lshr", "shl", "smin", "smax", 46 "umin", "umax"}; 47 std::stringstream Str; 48 for (const char *BinaryIntOpcode : BinaryIntOpcodes) 49 Str << " declare <8 x i32> @llvm.vp." << BinaryIntOpcode 50 << ".v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) "; 51 52 const char *BinaryFPOpcodes[] = {"fadd", "fsub", "fmul", "fdiv", 53 "frem", "minnum", "maxnum", "minimum", 54 "maximum", "copysign"}; 55 for (const char *BinaryFPOpcode : BinaryFPOpcodes) 56 Str << " declare <8 x float> @llvm.vp." 
<< BinaryFPOpcode 57 << ".v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) "; 58 59 Str << " declare <8 x float> @llvm.vp.floor.v8f32(<8 x float>, <8 x i1>, " 60 "i32)"; 61 Str << " declare <8 x float> @llvm.vp.round.v8f32(<8 x float>, <8 x i1>, " 62 "i32)"; 63 Str << " declare <8 x float> @llvm.vp.roundeven.v8f32(<8 x float>, <8 x " 64 "i1>, " 65 "i32)"; 66 Str << " declare <8 x float> @llvm.vp.roundtozero.v8f32(<8 x float>, <8 x " 67 "i1>, " 68 "i32)"; 69 Str << " declare <8 x float> @llvm.vp.rint.v8f32(<8 x float>, <8 x i1>, " 70 "i32)"; 71 Str << " declare <8 x float> @llvm.vp.nearbyint.v8f32(<8 x float>, <8 x " 72 "i1>, " 73 "i32)"; 74 Str << " declare <8 x float> @llvm.vp.ceil.v8f32(<8 x float>, <8 x i1>, " 75 "i32)"; 76 Str << " declare <8 x i32> @llvm.vp.lrint.v8i32.v8f32(<8 x float>, " 77 "<8 x i1>, i32)"; 78 Str << " declare <8 x i64> @llvm.vp.llrint.v8i64.v8f32(<8 x float>, " 79 "<8 x i1>, i32)"; 80 Str << " declare <8 x float> @llvm.vp.fneg.v8f32(<8 x float>, <8 x i1>, " 81 "i32)"; 82 Str << " declare <8 x float> @llvm.vp.fabs.v8f32(<8 x float>, <8 x i1>, " 83 "i32)"; 84 Str << " declare <8 x float> @llvm.vp.sqrt.v8f32(<8 x float>, <8 x i1>, " 85 "i32)"; 86 Str << " declare <8 x float> @llvm.vp.fma.v8f32(<8 x float>, <8 x float>, " 87 "<8 x float>, <8 x i1>, i32) "; 88 Str << " declare <8 x float> @llvm.vp.fmuladd.v8f32(<8 x float>, " 89 "<8 x float>, <8 x float>, <8 x i1>, i32) "; 90 91 Str << " declare void @llvm.vp.store.v8i32.p0v8i32(<8 x i32>, <8 x i32>*, " 92 "<8 x i1>, i32) "; 93 Str << "declare void " 94 "@llvm.experimental.vp.strided.store.v8i32.i32(<8 x i32>, " 95 "i32*, i32, <8 x i1>, i32) "; 96 Str << "declare void " 97 "@llvm.experimental.vp.strided.store.v8i32.p1i32.i32(<8 x i32>, " 98 "i32 addrspace(1)*, i32, <8 x i1>, i32) "; 99 Str << " declare void @llvm.vp.scatter.v8i32.v8p0i32(<8 x i32>, <8 x " 100 "i32*>, <8 x i1>, i32) "; 101 Str << " declare <8 x i32> @llvm.vp.load.v8i32.p0v8i32(<8 x i32>*, <8 x " 102 "i1>, i32) "; 103 Str << 
"declare <8 x i32> " 104 "@llvm.experimental.vp.strided.load.v8i32.i32(i32*, i32, <8 " 105 "x i1>, i32) "; 106 Str << "declare <8 x i32> " 107 "@llvm.experimental.vp.strided.load.v8i32.p1i32.i32(i32 " 108 "addrspace(1)*, i32, <8 x i1>, i32) "; 109 Str << " declare <8 x i32> @llvm.vp.gather.v8i32.v8p0i32(<8 x i32*>, <8 x " 110 "i1>, i32) "; 111 Str << " declare <8 x i32> @llvm.experimental.vp.splat.v8i32(i32, <8 x " 112 "i1>, i32) "; 113 114 for (const char *ReductionOpcode : ReductionIntOpcodes) 115 Str << " declare i32 @llvm.vp.reduce." << ReductionOpcode 116 << ".v8i32(i32, <8 x i32>, <8 x i1>, i32) "; 117 118 for (const char *ReductionOpcode : ReductionFPOpcodes) 119 Str << " declare float @llvm.vp.reduce." << ReductionOpcode 120 << ".v8f32(float, <8 x float>, <8 x i1>, i32) "; 121 122 Str << " declare <8 x i32> @llvm.vp.merge.v8i32(<8 x i1>, <8 x i32>, <8 x " 123 "i32>, i32)"; 124 Str << " declare <8 x i32> @llvm.vp.select.v8i32(<8 x i1>, <8 x i32>, <8 x " 125 "i32>, i32)"; 126 Str << " declare <8 x i1> @llvm.vp.is.fpclass.v8f32(<8 x float>, i32, <8 x " 127 "i1>, i32)"; 128 Str << " declare <8 x i32> @llvm.experimental.vp.splice.v8i32(<8 x " 129 "i32>, <8 x i32>, i32, <8 x i1>, i32, i32) "; 130 131 Str << " declare <8 x i32> @llvm.vp.fptoui.v8i32" 132 << ".v8f32(<8 x float>, <8 x i1>, i32) "; 133 Str << " declare <8 x i32> @llvm.vp.fptosi.v8i32" 134 << ".v8f32(<8 x float>, <8 x i1>, i32) "; 135 Str << " declare <8 x float> @llvm.vp.uitofp.v8f32" 136 << ".v8i32(<8 x i32>, <8 x i1>, i32) "; 137 Str << " declare <8 x float> @llvm.vp.sitofp.v8f32" 138 << ".v8i32(<8 x i32>, <8 x i1>, i32) "; 139 Str << " declare <8 x float> @llvm.vp.fptrunc.v8f32" 140 << ".v8f64(<8 x double>, <8 x i1>, i32) "; 141 Str << " declare <8 x double> @llvm.vp.fpext.v8f64" 142 << ".v8f32(<8 x float>, <8 x i1>, i32) "; 143 Str << " declare <8 x i32> @llvm.vp.trunc.v8i32" 144 << ".v8i64(<8 x i64>, <8 x i1>, i32) "; 145 Str << " declare <8 x i64> @llvm.vp.zext.v8i64" 146 << ".v8i32(<8 x i32>, 
<8 x i1>, i32) "; 147 Str << " declare <8 x i64> @llvm.vp.sext.v8i64" 148 << ".v8i32(<8 x i32>, <8 x i1>, i32) "; 149 Str << " declare <8 x i32> @llvm.vp.ptrtoint.v8i32" 150 << ".v8p0i32(<8 x i32*>, <8 x i1>, i32) "; 151 Str << " declare <8 x i32*> @llvm.vp.inttoptr.v8p0i32" 152 << ".v8i32(<8 x i32>, <8 x i1>, i32) "; 153 154 Str << " declare <8 x i1> @llvm.vp.fcmp.v8f32" 155 << "(<8 x float>, <8 x float>, metadata, <8 x i1>, i32) "; 156 Str << " declare <8 x i1> @llvm.vp.icmp.v8i16" 157 << "(<8 x i16>, <8 x i16>, metadata, <8 x i1>, i32) "; 158 159 Str << " declare <8 x i32> @llvm.experimental.vp.reverse.v8i32(<8 x i32>, " 160 "<8 x i1>, i32) "; 161 Str << " declare <8 x i16> @llvm.vp.abs.v8i16" 162 << "(<8 x i16>, i1 immarg, <8 x i1>, i32) "; 163 Str << " declare <8 x i16> @llvm.vp.bitreverse.v8i16" 164 << "(<8 x i16>, <8 x i1>, i32) "; 165 Str << " declare <8 x i16> @llvm.vp.bswap.v8i16" 166 << "(<8 x i16>, <8 x i1>, i32) "; 167 Str << " declare <8 x i16> @llvm.vp.ctpop.v8i16" 168 << "(<8 x i16>, <8 x i1>, i32) "; 169 Str << " declare <8 x i16> @llvm.vp.ctlz.v8i16" 170 << "(<8 x i16>, i1 immarg, <8 x i1>, i32) "; 171 Str << " declare <8 x i16> @llvm.vp.cttz.v8i16" 172 << "(<8 x i16>, i1 immarg, <8 x i1>, i32) "; 173 Str << " declare <8 x i16> @llvm.vp.sadd.sat.v8i16" 174 << "(<8 x i16>, <8 x i16>, <8 x i1>, i32) "; 175 Str << " declare <8 x i16> @llvm.vp.uadd.sat.v8i16" 176 << "(<8 x i16>, <8 x i16>, <8 x i1>, i32) "; 177 Str << " declare <8 x i16> @llvm.vp.ssub.sat.v8i16" 178 << "(<8 x i16>, <8 x i16>, <8 x i1>, i32) "; 179 Str << " declare <8 x i16> @llvm.vp.usub.sat.v8i16" 180 << "(<8 x i16>, <8 x i16>, <8 x i1>, i32) "; 181 Str << " declare <8 x i16> @llvm.vp.fshl.v8i16" 182 << "(<8 x i16>, <8 x i16>, <8 x i16>, <8 x i1>, i32) "; 183 Str << " declare <8 x i16> @llvm.vp.fshr.v8i16" 184 << "(<8 x i16>, <8 x i16>, <8 x i16>, <8 x i1>, i32) "; 185 Str << " declare i32 @llvm.vp.cttz.elts.i32.v8i16" 186 << "(<8 x i16>, i1 immarg, <8 x i1>, i32) "; 187 188 return 
parseAssemblyString(Str.str(), Err, C); 189 } 190 }; 191 192 /// Check that the property scopes include/llvm/IR/VPIntrinsics.def are closed. 193 TEST_F(VPIntrinsicTest, VPIntrinsicsDefScopes) { 194 std::optional<Intrinsic::ID> ScopeVPID; 195 #define BEGIN_REGISTER_VP_INTRINSIC(VPID, ...) \ 196 ASSERT_FALSE(ScopeVPID.has_value()); \ 197 ScopeVPID = Intrinsic::VPID; 198 #define END_REGISTER_VP_INTRINSIC(VPID) \ 199 ASSERT_TRUE(ScopeVPID.has_value()); \ 200 ASSERT_EQ(*ScopeVPID, Intrinsic::VPID); \ 201 ScopeVPID = std::nullopt; 202 203 std::optional<ISD::NodeType> ScopeOPC; 204 #define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) \ 205 ASSERT_FALSE(ScopeOPC.has_value()); \ 206 ScopeOPC = ISD::SDOPC; 207 #define END_REGISTER_VP_SDNODE(SDOPC) \ 208 ASSERT_TRUE(ScopeOPC.has_value()); \ 209 ASSERT_EQ(*ScopeOPC, ISD::SDOPC); \ 210 ScopeOPC = std::nullopt; 211 #include "llvm/IR/VPIntrinsics.def" 212 213 ASSERT_FALSE(ScopeVPID.has_value()); 214 ASSERT_FALSE(ScopeOPC.has_value()); 215 } 216 217 /// Check that every VP intrinsic in the test module is recognized as a VP 218 /// intrinsic. 219 TEST_F(VPIntrinsicTest, VPModuleComplete) { 220 std::unique_ptr<Module> M = createVPDeclarationModule(); 221 assert(M); 222 223 // Check that all @llvm.vp.* functions in the module are recognized vp 224 // intrinsics. 225 std::set<Intrinsic::ID> SeenIDs; 226 for (const auto &VPDecl : *M) { 227 ASSERT_TRUE(VPDecl.isIntrinsic()); 228 ASSERT_TRUE(VPIntrinsic::isVPIntrinsic(VPDecl.getIntrinsicID())); 229 SeenIDs.insert(VPDecl.getIntrinsicID()); 230 } 231 232 // Check that every registered VP intrinsic has an instance in the test 233 // module. 234 #define BEGIN_REGISTER_VP_INTRINSIC(VPID, ...) \ 235 ASSERT_TRUE(SeenIDs.count(Intrinsic::VPID)); 236 #include "llvm/IR/VPIntrinsics.def" 237 } 238 239 /// Check that VPIntrinsic:canIgnoreVectorLengthParam() returns true 240 /// if the vector length parameter does not mask off any lanes. 
TEST_F(VPIntrinsicTest, CanIgnoreVectorLength) {
  // Local context/diagnostic (intentionally shadowing the fixture members).
  LLVMContext C;
  SMDiagnostic Err;

  // Twelve VP calls %r0..%r11 exercising fixed-width (<256 x i64>) and
  // scalable (nxv1/nxv2) vectors with various explicit vector lengths.
  std::unique_ptr<Module> M =
      parseAssemblyString(
          "declare <256 x i64> @llvm.vp.mul.v256i64(<256 x i64>, <256 x i64>, <256 x i1>, i32)"
          "declare <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i1>, i32)"
          "declare <vscale x 1 x i64> @llvm.vp.mul.nxv1i64(<vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i1>, i32)"
          "declare i32 @llvm.vscale.i32()"
          "define void @test_static_vlen( "
          " <256 x i64> %i0, <vscale x 2 x i64> %si0x2, <vscale x 1 x i64> %si0x1,"
          " <256 x i64> %i1, <vscale x 2 x i64> %si1x2, <vscale x 1 x i64> %si1x1,"
          " <256 x i1> %m, <vscale x 2 x i1> %smx2, <vscale x 1 x i1> %smx1, i32 %vl) { "
          // Fixed-width: only VL == 256 (the full lane count) is ignorable.
          " %r0 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 %vl)"
          " %r1 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 256)"
          " %r2 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 0)"
          " %r3 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 7)"
          " %r4 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 123)"
          // Scalable: VL expressed in terms of vscale.
          " %vs = call i32 @llvm.vscale.i32()"
          " %vs.x2 = mul i32 %vs, 2"
          " %r5 = call <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64> %si0x2, <vscale x 2 x i64> %si1x2, <vscale x 2 x i1> %smx2, i32 %vs.x2)"
          " %r6 = call <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64> %si0x2, <vscale x 2 x i64> %si1x2, <vscale x 2 x i1> %smx2, i32 %vs)"
          " %r7 = call <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64> %si0x2, <vscale x 2 x i64> %si1x2, <vscale x 2 x i1> %smx2, i32 99999)"
          " %r8 = call <vscale x 1 x i64> @llvm.vp.mul.nxv1i64(<vscale x 1 x i64> %si0x1, <vscale x 1 x i64> %si1x1, <vscale x 1 x i1> %smx1, i32 %vs)"
          " %r9 = call <vscale x 1 x i64> @llvm.vp.mul.nxv1i64(<vscale x 1 x i64> %si0x1, <vscale x 1 x i64> %si1x1, <vscale x 1 x i1> %smx1, i32 1)"
          " %r10 = call <vscale x 1 x i64> @llvm.vp.mul.nxv1i64(<vscale x 1 x i64> %si0x1, <vscale x 1 x i64> %si1x1, <vscale x 1 x i1> %smx1, i32 %vs.x2)"
          // %vs + 2 is neither provably >= the lane count nor a known constant.
          " %vs.wat = add i32 %vs, 2"
          " %r11 = call <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64> %si0x2, <vscale x 2 x i64> %si1x2, <vscale x 2 x i1> %smx2, i32 %vs.wat)"
          " ret void "
          "}",
          Err, C);

  auto *F = M->getFunction("test_static_vlen");
  assert(F);

  // One expected result per VP call, in program order %r0..%r11:
  //   %r1 (VL 256 == fixed width), %r5 (2*vscale for nxv2), %r8 (vscale for
  //   nxv1) and %r10 (2*vscale >= vscale for nxv1) cover all lanes -> true;
  //   everything else -> false.
  const bool Expected[] = {false, true, false, false, false, true,
                           false, false, true, false, true, false};
  const auto *ExpectedIt = std::begin(Expected);
  for (auto &I : F->getEntryBlock()) {
    VPIntrinsic *VPI = dyn_cast<VPIntrinsic>(&I);
    if (!VPI)
      continue; // Skip non-VP instructions (e.g. the vscale call and muls).

    ASSERT_NE(ExpectedIt, std::end(Expected));
    ASSERT_EQ(*ExpectedIt, VPI->canIgnoreVectorLengthParam());
    ++ExpectedIt;
  }
}

/// Check that the argument returned by
/// VPIntrinsic::get<X>ParamPos(Intrinsic::ID) has the expected type.
293 TEST_F(VPIntrinsicTest, GetParamPos) { 294 std::unique_ptr<Module> M = createVPDeclarationModule(); 295 assert(M); 296 297 for (Function &F : *M) { 298 ASSERT_TRUE(F.isIntrinsic()); 299 std::optional<unsigned> MaskParamPos = 300 VPIntrinsic::getMaskParamPos(F.getIntrinsicID()); 301 if (MaskParamPos) { 302 Type *MaskParamType = F.getArg(*MaskParamPos)->getType(); 303 ASSERT_TRUE(MaskParamType->isVectorTy()); 304 ASSERT_TRUE( 305 cast<VectorType>(MaskParamType)->getElementType()->isIntegerTy(1)); 306 } 307 308 std::optional<unsigned> VecLenParamPos = 309 VPIntrinsic::getVectorLengthParamPos(F.getIntrinsicID()); 310 if (VecLenParamPos) { 311 Type *VecLenParamType = F.getArg(*VecLenParamPos)->getType(); 312 ASSERT_TRUE(VecLenParamType->isIntegerTy(32)); 313 } 314 } 315 } 316 317 /// Check that going from Opcode to VP intrinsic and back results in the same 318 /// Opcode. 319 TEST_F(VPIntrinsicTest, OpcodeRoundTrip) { 320 std::vector<unsigned> Opcodes; 321 Opcodes.reserve(100); 322 323 { 324 #define HANDLE_INST(OCNum, OCName, Class) Opcodes.push_back(OCNum); 325 #include "llvm/IR/Instruction.def" 326 } 327 328 unsigned FullTripCounts = 0; 329 for (unsigned OC : Opcodes) { 330 Intrinsic::ID VPID = VPIntrinsic::getForOpcode(OC); 331 // No equivalent VP intrinsic available. 332 if (VPID == Intrinsic::not_intrinsic) 333 continue; 334 335 std::optional<unsigned> RoundTripOC = 336 VPIntrinsic::getFunctionalOpcodeForVP(VPID); 337 // No equivalent Opcode available. 338 if (!RoundTripOC) 339 continue; 340 341 ASSERT_EQ(*RoundTripOC, OC); 342 ++FullTripCounts; 343 } 344 ASSERT_NE(FullTripCounts, 0u); 345 } 346 347 /// Check that going from VP intrinsic to Opcode and back results in the same 348 /// intrinsic id. 
349 TEST_F(VPIntrinsicTest, IntrinsicIDRoundTrip) { 350 std::unique_ptr<Module> M = createVPDeclarationModule(); 351 assert(M); 352 353 unsigned FullTripCounts = 0; 354 for (const auto &VPDecl : *M) { 355 auto VPID = VPDecl.getIntrinsicID(); 356 std::optional<unsigned> OC = VPIntrinsic::getFunctionalOpcodeForVP(VPID); 357 358 // no equivalent Opcode available 359 if (!OC) 360 continue; 361 362 Intrinsic::ID RoundTripVPID = VPIntrinsic::getForOpcode(*OC); 363 364 ASSERT_EQ(RoundTripVPID, VPID); 365 ++FullTripCounts; 366 } 367 ASSERT_NE(FullTripCounts, 0u); 368 } 369 370 /// Check that going from intrinsic to VP intrinsic and back results in the same 371 /// intrinsic. 372 TEST_F(VPIntrinsicTest, IntrinsicToVPRoundTrip) { 373 bool IsFullTrip = false; 374 Intrinsic::ID IntrinsicID = Intrinsic::not_intrinsic + 1; 375 for (; IntrinsicID < Intrinsic::num_intrinsics; IntrinsicID++) { 376 Intrinsic::ID VPID = VPIntrinsic::getForIntrinsic(IntrinsicID); 377 // No equivalent VP intrinsic available. 378 if (VPID == Intrinsic::not_intrinsic) 379 continue; 380 381 // Return itself if passed intrinsic ID is VP intrinsic. 382 if (VPIntrinsic::isVPIntrinsic(IntrinsicID)) { 383 ASSERT_EQ(IntrinsicID, VPID); 384 continue; 385 } 386 387 std::optional<Intrinsic::ID> RoundTripIntrinsicID = 388 VPIntrinsic::getFunctionalIntrinsicIDForVP(VPID); 389 // No equivalent non-predicated intrinsic available. 390 if (!RoundTripIntrinsicID) 391 continue; 392 393 ASSERT_EQ(*RoundTripIntrinsicID, IntrinsicID); 394 IsFullTrip = true; 395 } 396 ASSERT_TRUE(IsFullTrip); 397 } 398 399 /// Check that going from VP intrinsic to equivalent non-predicated intrinsic 400 /// and back results in the same intrinsic. 
TEST_F(VPIntrinsicTest, VPToNonPredIntrinsicRoundTrip) {
  std::unique_ptr<Module> M = createVPDeclarationModule();
  assert(M);

  bool IsFullTrip = false;
  for (const auto &VPDecl : *M) {
    auto VPID = VPDecl.getIntrinsicID();
    std::optional<Intrinsic::ID> NonPredID =
        VPIntrinsic::getFunctionalIntrinsicIDForVP(VPID);

    // No equivalent non-predicated intrinsic available
    if (!NonPredID)
      continue;

    // Mapping the non-predicated id back must recover the original VP id.
    Intrinsic::ID RoundTripVPID = VPIntrinsic::getForIntrinsic(*NonPredID);

    ASSERT_EQ(RoundTripVPID, VPID);
    IsFullTrip = true;
  }
  // At least one declaration must have completed the full round trip.
  ASSERT_TRUE(IsFullTrip);
}

/// Check that VPIntrinsic::getOrInsertDeclarationForParams works.
TEST_F(VPIntrinsicTest, VPIntrinsicDeclarationForParams) {
  std::unique_ptr<Module> M = createVPDeclarationModule();
  assert(M);

  // Fresh module to receive the re-created declarations.
  auto OutM = std::make_unique<Module>("", M->getContext());

  for (auto &F : *M) {
    auto *FuncTy = F.getFunctionType();

    // Declare intrinsic anew with explicit types, using placeholder values of
    // each parameter type to drive the overload resolution.
    std::vector<Value *> Values;
    for (auto *ParamTy : FuncTy->params())
      Values.push_back(UndefValue::get(ParamTy));

    ASSERT_NE(F.getIntrinsicID(), Intrinsic::not_intrinsic);
    auto *NewDecl = VPIntrinsic::getOrInsertDeclarationForParams(
        OutM.get(), F.getIntrinsicID(), FuncTy->getReturnType(), Values);
    ASSERT_TRUE(NewDecl);

    // Check that 'old decl' == 'new decl': same intrinsic id and identical
    // parameter types, position by position.
    ASSERT_EQ(F.getIntrinsicID(), NewDecl->getIntrinsicID());
    FunctionType::param_iterator ItNewParams =
        NewDecl->getFunctionType()->param_begin();
    FunctionType::param_iterator EndItNewParams =
        NewDecl->getFunctionType()->param_end();
    for (auto *ParamTy : FuncTy->params()) {
      ASSERT_NE(ItNewParams, EndItNewParams);
      ASSERT_EQ(*ItNewParams, ParamTy);
      ++ItNewParams;
    }
  }
}

} // end anonymous namespace

/// Check various properties of VPReductionIntrinsics
TEST_F(VPIntrinsicTest, VPReductions) {
  // Local context/diagnostic (intentionally shadowing the fixture members).
  LLVMContext C;
  SMDiagnostic Err;

  // Build a module declaring one non-reduction VP intrinsic plus every
  // @llvm.vp.reduce.* variant, then call each of them once.
  std::stringstream Str;
  Str << "declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, "
         "i32)";
  for (const char *ReductionOpcode : ReductionIntOpcodes)
    Str << " declare i32 @llvm.vp.reduce." << ReductionOpcode
        << ".v8i32(i32, <8 x i32>, <8 x i1>, i32) ";

  for (const char *ReductionOpcode : ReductionFPOpcodes)
    Str << " declare float @llvm.vp.reduce." << ReductionOpcode
        << ".v8f32(float, <8 x float>, <8 x i1>, i32) ";

  Str << "define void @test_reductions(i32 %start, <8 x i32> %val, float "
         "%fpstart, <8 x float> %fpval, <8 x i1> %m, i32 %vl) {";

  // Mix in a regular non-reduction intrinsic to check that the
  // VPReductionIntrinsic subclass works as intended.
  Str << " %r0 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %val, <8 x i32> "
         "%val, <8 x i1> %m, i32 %vl)";

  unsigned Idx = 1;
  for (const char *ReductionOpcode : ReductionIntOpcodes)
    Str << " %r" << Idx++ << " = call i32 @llvm.vp.reduce." << ReductionOpcode
        << ".v8i32(i32 %start, <8 x i32> %val, <8 x i1> %m, i32 %vl)";
  for (const char *ReductionOpcode : ReductionFPOpcodes)
    Str << " %r" << Idx++ << " = call float @llvm.vp.reduce."
        << ReductionOpcode
        << ".v8f32(float %fpstart, <8 x float> %fpval, <8 x i1> %m, i32 %vl)";

  Str << " ret void"
         "}";

  std::unique_ptr<Module> M = parseAssemblyString(Str.str(), Err, C);
  assert(M);

  auto *F = M->getFunction("test_reductions");
  assert(F);

  for (const auto &I : F->getEntryBlock()) {
    const VPIntrinsic *VPI = dyn_cast<VPIntrinsic>(&I);
    if (!VPI)
      continue;

    Intrinsic::ID ID = VPI->getIntrinsicID();
    const auto *VPRedI = dyn_cast<VPReductionIntrinsic>(&I);

    // Non-reduction VP calls (here: %r0, the vp.mul) must not be classified
    // as reductions and must report no start/vector parameter positions.
    if (!VPReductionIntrinsic::isVPReduction(ID)) {
      EXPECT_EQ(VPRedI, nullptr);
      EXPECT_EQ(VPReductionIntrinsic::getStartParamPos(ID).has_value(), false);
      EXPECT_EQ(VPReductionIntrinsic::getVectorParamPos(ID).has_value(), false);
      continue;
    }

    // Reductions: static (by-ID) and instance queries must agree, and the
    // declared signatures above put the start value at 0 and the vector at 1.
    EXPECT_EQ(VPReductionIntrinsic::getStartParamPos(ID).has_value(), true);
    EXPECT_EQ(VPReductionIntrinsic::getVectorParamPos(ID).has_value(), true);
    ASSERT_NE(VPRedI, nullptr);
    EXPECT_EQ(VPReductionIntrinsic::getStartParamPos(ID),
              VPRedI->getStartParamPos());
    EXPECT_EQ(VPReductionIntrinsic::getVectorParamPos(ID),
              VPRedI->getVectorParamPos());
    EXPECT_EQ(VPRedI->getStartParamPos(), 0u);
    EXPECT_EQ(VPRedI->getVectorParamPos(), 1u);
  }
}