xref: /llvm-project/llvm/unittests/IR/VPIntrinsicTest.cpp (revision fa789dffb1e12c2aece0187aeacc48dfb1768340)
1 //===- VPIntrinsicTest.cpp - VPIntrinsic unit tests ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "llvm/ADT/SmallVector.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"
#include <iterator>
#include <memory>
#include <optional>
#include <set>
#include <sstream>
#include <vector>
22 
23 using namespace llvm;
24 
25 namespace {
26 
/// Integer reduction kinds; each name X forms the intrinsic
/// @llvm.vp.reduce.X.v8i32 in the test IR built below. The table is never
/// modified, so both the characters and the pointers are const.
static const char *const ReductionIntOpcodes[] = {
    "add", "mul", "and", "or", "xor", "smin", "smax", "umin", "umax"};

/// Floating-point reduction kinds; each name X forms the intrinsic
/// @llvm.vp.reduce.X.v8f32 in the test IR built below.
static const char *const ReductionFPOpcodes[] = {
    "fadd", "fmul", "fmin", "fmax", "fminimum", "fmaximum"};
32 
33 class VPIntrinsicTest : public testing::Test {
34 protected:
35   LLVMContext Context;
36 
37   VPIntrinsicTest() : Context() {}
38 
39   LLVMContext C;
40   SMDiagnostic Err;
41 
42   std::unique_ptr<Module> createVPDeclarationModule() {
43     const char *BinaryIntOpcodes[] = {"add",  "sub",  "mul", "sdiv", "srem",
44                                       "udiv", "urem", "and", "xor",  "or",
45                                       "ashr", "lshr", "shl", "smin", "smax",
46                                       "umin", "umax"};
47     std::stringstream Str;
48     for (const char *BinaryIntOpcode : BinaryIntOpcodes)
49       Str << " declare <8 x i32> @llvm.vp." << BinaryIntOpcode
50           << ".v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) ";
51 
52     const char *BinaryFPOpcodes[] = {"fadd",    "fsub",    "fmul",   "fdiv",
53                                      "frem",    "minnum",  "maxnum", "minimum",
54                                      "maximum", "copysign"};
55     for (const char *BinaryFPOpcode : BinaryFPOpcodes)
56       Str << " declare <8 x float> @llvm.vp." << BinaryFPOpcode
57           << ".v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) ";
58 
59     Str << " declare <8 x float> @llvm.vp.floor.v8f32(<8 x float>, <8 x i1>, "
60            "i32)";
61     Str << " declare <8 x float> @llvm.vp.round.v8f32(<8 x float>, <8 x i1>, "
62            "i32)";
63     Str << " declare <8 x float> @llvm.vp.roundeven.v8f32(<8 x float>, <8 x "
64            "i1>, "
65            "i32)";
66     Str << " declare <8 x float> @llvm.vp.roundtozero.v8f32(<8 x float>, <8 x "
67            "i1>, "
68            "i32)";
69     Str << " declare <8 x float> @llvm.vp.rint.v8f32(<8 x float>, <8 x i1>, "
70            "i32)";
71     Str << " declare <8 x float> @llvm.vp.nearbyint.v8f32(<8 x float>, <8 x "
72            "i1>, "
73            "i32)";
74     Str << " declare <8 x float> @llvm.vp.ceil.v8f32(<8 x float>, <8 x i1>, "
75            "i32)";
76     Str << " declare <8 x i32> @llvm.vp.lrint.v8i32.v8f32(<8 x float>, "
77            "<8 x i1>, i32)";
78     Str << " declare <8 x i64> @llvm.vp.llrint.v8i64.v8f32(<8 x float>, "
79            "<8 x i1>, i32)";
80     Str << " declare <8 x float> @llvm.vp.fneg.v8f32(<8 x float>, <8 x i1>, "
81            "i32)";
82     Str << " declare <8 x float> @llvm.vp.fabs.v8f32(<8 x float>, <8 x i1>, "
83            "i32)";
84     Str << " declare <8 x float> @llvm.vp.sqrt.v8f32(<8 x float>, <8 x i1>, "
85            "i32)";
86     Str << " declare <8 x float> @llvm.vp.fma.v8f32(<8 x float>, <8 x float>, "
87            "<8 x float>, <8 x i1>, i32) ";
88     Str << " declare <8 x float> @llvm.vp.fmuladd.v8f32(<8 x float>, "
89            "<8 x float>, <8 x float>, <8 x i1>, i32) ";
90 
91     Str << " declare void @llvm.vp.store.v8i32.p0v8i32(<8 x i32>, <8 x i32>*, "
92            "<8 x i1>, i32) ";
93     Str << "declare void "
94            "@llvm.experimental.vp.strided.store.v8i32.i32(<8 x i32>, "
95            "i32*, i32, <8 x i1>, i32) ";
96     Str << "declare void "
97            "@llvm.experimental.vp.strided.store.v8i32.p1i32.i32(<8 x i32>, "
98            "i32 addrspace(1)*, i32, <8 x i1>, i32) ";
99     Str << " declare void @llvm.vp.scatter.v8i32.v8p0i32(<8 x i32>, <8 x "
100            "i32*>, <8 x i1>, i32) ";
101     Str << " declare <8 x i32> @llvm.vp.load.v8i32.p0v8i32(<8 x i32>*, <8 x "
102            "i1>, i32) ";
103     Str << "declare <8 x i32> "
104            "@llvm.experimental.vp.strided.load.v8i32.i32(i32*, i32, <8 "
105            "x i1>, i32) ";
106     Str << "declare <8 x i32> "
107            "@llvm.experimental.vp.strided.load.v8i32.p1i32.i32(i32 "
108            "addrspace(1)*, i32, <8 x i1>, i32) ";
109     Str << " declare <8 x i32> @llvm.vp.gather.v8i32.v8p0i32(<8 x i32*>, <8 x "
110            "i1>, i32) ";
111     Str << " declare <8 x i32> @llvm.experimental.vp.splat.v8i32(i32, <8 x "
112            "i1>, i32) ";
113 
114     for (const char *ReductionOpcode : ReductionIntOpcodes)
115       Str << " declare i32 @llvm.vp.reduce." << ReductionOpcode
116           << ".v8i32(i32, <8 x i32>, <8 x i1>, i32) ";
117 
118     for (const char *ReductionOpcode : ReductionFPOpcodes)
119       Str << " declare float @llvm.vp.reduce." << ReductionOpcode
120           << ".v8f32(float, <8 x float>, <8 x i1>, i32) ";
121 
122     Str << " declare <8 x i32> @llvm.vp.merge.v8i32(<8 x i1>, <8 x i32>, <8 x "
123            "i32>, i32)";
124     Str << " declare <8 x i32> @llvm.vp.select.v8i32(<8 x i1>, <8 x i32>, <8 x "
125            "i32>, i32)";
126     Str << " declare <8 x i1> @llvm.vp.is.fpclass.v8f32(<8 x float>, i32, <8 x "
127            "i1>, i32)";
128     Str << " declare <8 x i32> @llvm.experimental.vp.splice.v8i32(<8 x "
129            "i32>, <8 x i32>, i32, <8 x i1>, i32, i32) ";
130 
131     Str << " declare <8 x i32> @llvm.vp.fptoui.v8i32"
132         << ".v8f32(<8 x float>, <8 x i1>, i32) ";
133     Str << " declare <8 x i32> @llvm.vp.fptosi.v8i32"
134         << ".v8f32(<8 x float>, <8 x i1>, i32) ";
135     Str << " declare <8 x float> @llvm.vp.uitofp.v8f32"
136         << ".v8i32(<8 x i32>, <8 x i1>, i32) ";
137     Str << " declare <8 x float> @llvm.vp.sitofp.v8f32"
138         << ".v8i32(<8 x i32>, <8 x i1>, i32) ";
139     Str << " declare <8 x float> @llvm.vp.fptrunc.v8f32"
140         << ".v8f64(<8 x double>, <8 x i1>, i32) ";
141     Str << " declare <8 x double> @llvm.vp.fpext.v8f64"
142         << ".v8f32(<8 x float>, <8 x i1>, i32) ";
143     Str << " declare <8 x i32> @llvm.vp.trunc.v8i32"
144         << ".v8i64(<8 x i64>, <8 x i1>, i32) ";
145     Str << " declare <8 x i64> @llvm.vp.zext.v8i64"
146         << ".v8i32(<8 x i32>, <8 x i1>, i32) ";
147     Str << " declare <8 x i64> @llvm.vp.sext.v8i64"
148         << ".v8i32(<8 x i32>, <8 x i1>, i32) ";
149     Str << " declare <8 x i32> @llvm.vp.ptrtoint.v8i32"
150         << ".v8p0i32(<8 x i32*>, <8 x i1>, i32) ";
151     Str << " declare <8 x i32*> @llvm.vp.inttoptr.v8p0i32"
152         << ".v8i32(<8 x i32>, <8 x i1>, i32) ";
153 
154     Str << " declare <8 x i1> @llvm.vp.fcmp.v8f32"
155         << "(<8 x float>, <8 x float>, metadata, <8 x i1>, i32) ";
156     Str << " declare <8 x i1> @llvm.vp.icmp.v8i16"
157         << "(<8 x i16>, <8 x i16>, metadata, <8 x i1>, i32) ";
158 
159     Str << " declare <8 x i32> @llvm.experimental.vp.reverse.v8i32(<8 x i32>, "
160            "<8 x i1>, i32) ";
161     Str << " declare <8 x i16> @llvm.vp.abs.v8i16"
162         << "(<8 x i16>, i1 immarg, <8 x i1>, i32) ";
163     Str << " declare <8 x i16> @llvm.vp.bitreverse.v8i16"
164         << "(<8 x i16>, <8 x i1>, i32) ";
165     Str << " declare <8 x i16> @llvm.vp.bswap.v8i16"
166         << "(<8 x i16>, <8 x i1>, i32) ";
167     Str << " declare <8 x i16> @llvm.vp.ctpop.v8i16"
168         << "(<8 x i16>, <8 x i1>, i32) ";
169     Str << " declare <8 x i16> @llvm.vp.ctlz.v8i16"
170         << "(<8 x i16>, i1 immarg, <8 x i1>, i32) ";
171     Str << " declare <8 x i16> @llvm.vp.cttz.v8i16"
172         << "(<8 x i16>, i1 immarg, <8 x i1>, i32) ";
173     Str << " declare <8 x i16> @llvm.vp.sadd.sat.v8i16"
174         << "(<8 x i16>, <8 x i16>, <8 x i1>, i32) ";
175     Str << " declare <8 x i16> @llvm.vp.uadd.sat.v8i16"
176         << "(<8 x i16>, <8 x i16>, <8 x i1>, i32) ";
177     Str << " declare <8 x i16> @llvm.vp.ssub.sat.v8i16"
178         << "(<8 x i16>, <8 x i16>, <8 x i1>, i32) ";
179     Str << " declare <8 x i16> @llvm.vp.usub.sat.v8i16"
180         << "(<8 x i16>, <8 x i16>, <8 x i1>, i32) ";
181     Str << " declare <8 x i16> @llvm.vp.fshl.v8i16"
182         << "(<8 x i16>, <8 x i16>, <8 x i16>, <8 x i1>, i32) ";
183     Str << " declare <8 x i16> @llvm.vp.fshr.v8i16"
184         << "(<8 x i16>, <8 x i16>, <8 x i16>, <8 x i1>, i32) ";
185     Str << " declare i32 @llvm.vp.cttz.elts.i32.v8i16"
186         << "(<8 x i16>, i1 immarg, <8 x i1>, i32) ";
187 
188     return parseAssemblyString(Str.str(), Err, C);
189   }
190 };
191 
192 /// Check that the property scopes include/llvm/IR/VPIntrinsics.def are closed.
193 TEST_F(VPIntrinsicTest, VPIntrinsicsDefScopes) {
194   std::optional<Intrinsic::ID> ScopeVPID;
195 #define BEGIN_REGISTER_VP_INTRINSIC(VPID, ...)                                 \
196   ASSERT_FALSE(ScopeVPID.has_value());                                         \
197   ScopeVPID = Intrinsic::VPID;
198 #define END_REGISTER_VP_INTRINSIC(VPID)                                        \
199   ASSERT_TRUE(ScopeVPID.has_value());                                          \
200   ASSERT_EQ(*ScopeVPID, Intrinsic::VPID);                                      \
201   ScopeVPID = std::nullopt;
202 
203   std::optional<ISD::NodeType> ScopeOPC;
204 #define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...)                                   \
205   ASSERT_FALSE(ScopeOPC.has_value());                                          \
206   ScopeOPC = ISD::SDOPC;
207 #define END_REGISTER_VP_SDNODE(SDOPC)                                          \
208   ASSERT_TRUE(ScopeOPC.has_value());                                           \
209   ASSERT_EQ(*ScopeOPC, ISD::SDOPC);                                            \
210   ScopeOPC = std::nullopt;
211 #include "llvm/IR/VPIntrinsics.def"
212 
213   ASSERT_FALSE(ScopeVPID.has_value());
214   ASSERT_FALSE(ScopeOPC.has_value());
215 }
216 
217 /// Check that every VP intrinsic in the test module is recognized as a VP
218 /// intrinsic.
219 TEST_F(VPIntrinsicTest, VPModuleComplete) {
220   std::unique_ptr<Module> M = createVPDeclarationModule();
221   assert(M);
222 
223   // Check that all @llvm.vp.* functions in the module are recognized vp
224   // intrinsics.
225   std::set<Intrinsic::ID> SeenIDs;
226   for (const auto &VPDecl : *M) {
227     ASSERT_TRUE(VPDecl.isIntrinsic());
228     ASSERT_TRUE(VPIntrinsic::isVPIntrinsic(VPDecl.getIntrinsicID()));
229     SeenIDs.insert(VPDecl.getIntrinsicID());
230   }
231 
232   // Check that every registered VP intrinsic has an instance in the test
233   // module.
234 #define BEGIN_REGISTER_VP_INTRINSIC(VPID, ...)                                 \
235   ASSERT_TRUE(SeenIDs.count(Intrinsic::VPID));
236 #include "llvm/IR/VPIntrinsics.def"
237 }
238 
239 /// Check that VPIntrinsic:canIgnoreVectorLengthParam() returns true
240 /// if the vector length parameter does not mask off any lanes.
241 TEST_F(VPIntrinsicTest, CanIgnoreVectorLength) {
242   LLVMContext C;
243   SMDiagnostic Err;
244 
245   std::unique_ptr<Module> M =
246       parseAssemblyString(
247 "declare <256 x i64> @llvm.vp.mul.v256i64(<256 x i64>, <256 x i64>, <256 x i1>, i32)"
248 "declare <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i1>, i32)"
249 "declare <vscale x 1 x i64> @llvm.vp.mul.nxv1i64(<vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i1>, i32)"
250 "declare i32 @llvm.vscale.i32()"
251 "define void @test_static_vlen( "
252 "      <256 x i64> %i0, <vscale x 2 x i64> %si0x2, <vscale x 1 x i64> %si0x1,"
253 "      <256 x i64> %i1, <vscale x 2 x i64> %si1x2, <vscale x 1 x i64> %si1x1,"
254 "      <256 x i1> %m, <vscale x 2 x i1> %smx2, <vscale x 1 x i1> %smx1, i32 %vl) { "
255 "  %r0 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 %vl)"
256 "  %r1 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 256)"
257 "  %r2 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 0)"
258 "  %r3 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 7)"
259 "  %r4 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 123)"
260 "  %vs = call i32 @llvm.vscale.i32()"
261 "  %vs.x2 = mul i32 %vs, 2"
262 "  %r5 = call <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64> %si0x2, <vscale x 2 x i64> %si1x2, <vscale x 2 x i1> %smx2, i32 %vs.x2)"
263 "  %r6 = call <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64> %si0x2, <vscale x 2 x i64> %si1x2, <vscale x 2 x i1> %smx2, i32 %vs)"
264 "  %r7 = call <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64> %si0x2, <vscale x 2 x i64> %si1x2, <vscale x 2 x i1> %smx2, i32 99999)"
265 "  %r8 = call <vscale x 1 x i64> @llvm.vp.mul.nxv1i64(<vscale x 1 x i64> %si0x1, <vscale x 1 x i64> %si1x1, <vscale x 1 x i1> %smx1, i32 %vs)"
266 "  %r9 = call <vscale x 1 x i64> @llvm.vp.mul.nxv1i64(<vscale x 1 x i64> %si0x1, <vscale x 1 x i64> %si1x1, <vscale x 1 x i1> %smx1, i32 1)"
267 "  %r10 = call <vscale x 1 x i64> @llvm.vp.mul.nxv1i64(<vscale x 1 x i64> %si0x1, <vscale x 1 x i64> %si1x1, <vscale x 1 x i1> %smx1, i32 %vs.x2)"
268 "  %vs.wat = add i32 %vs, 2"
269 "  %r11 = call <vscale x 2 x i64> @llvm.vp.mul.nxv2i64(<vscale x 2 x i64> %si0x2, <vscale x 2 x i64> %si1x2, <vscale x 2 x i1> %smx2, i32 %vs.wat)"
270 "  ret void "
271 "}",
272           Err, C);
273 
274   auto *F = M->getFunction("test_static_vlen");
275   assert(F);
276 
277   const bool Expected[] = {false, true,  false, false, false, true,
278                            false, false, true,  false, true,  false};
279   const auto *ExpectedIt = std::begin(Expected);
280   for (auto &I : F->getEntryBlock()) {
281     VPIntrinsic *VPI = dyn_cast<VPIntrinsic>(&I);
282     if (!VPI)
283       continue;
284 
285     ASSERT_NE(ExpectedIt, std::end(Expected));
286     ASSERT_EQ(*ExpectedIt, VPI->canIgnoreVectorLengthParam());
287     ++ExpectedIt;
288   }
289 }
290 
291 /// Check that the argument returned by
292 /// VPIntrinsic::get<X>ParamPos(Intrinsic::ID) has the expected type.
293 TEST_F(VPIntrinsicTest, GetParamPos) {
294   std::unique_ptr<Module> M = createVPDeclarationModule();
295   assert(M);
296 
297   for (Function &F : *M) {
298     ASSERT_TRUE(F.isIntrinsic());
299     std::optional<unsigned> MaskParamPos =
300         VPIntrinsic::getMaskParamPos(F.getIntrinsicID());
301     if (MaskParamPos) {
302       Type *MaskParamType = F.getArg(*MaskParamPos)->getType();
303       ASSERT_TRUE(MaskParamType->isVectorTy());
304       ASSERT_TRUE(
305           cast<VectorType>(MaskParamType)->getElementType()->isIntegerTy(1));
306     }
307 
308     std::optional<unsigned> VecLenParamPos =
309         VPIntrinsic::getVectorLengthParamPos(F.getIntrinsicID());
310     if (VecLenParamPos) {
311       Type *VecLenParamType = F.getArg(*VecLenParamPos)->getType();
312       ASSERT_TRUE(VecLenParamType->isIntegerTy(32));
313     }
314   }
315 }
316 
317 /// Check that going from Opcode to VP intrinsic and back results in the same
318 /// Opcode.
319 TEST_F(VPIntrinsicTest, OpcodeRoundTrip) {
320   std::vector<unsigned> Opcodes;
321   Opcodes.reserve(100);
322 
323   {
324 #define HANDLE_INST(OCNum, OCName, Class) Opcodes.push_back(OCNum);
325 #include "llvm/IR/Instruction.def"
326   }
327 
328   unsigned FullTripCounts = 0;
329   for (unsigned OC : Opcodes) {
330     Intrinsic::ID VPID = VPIntrinsic::getForOpcode(OC);
331     // No equivalent VP intrinsic available.
332     if (VPID == Intrinsic::not_intrinsic)
333       continue;
334 
335     std::optional<unsigned> RoundTripOC =
336         VPIntrinsic::getFunctionalOpcodeForVP(VPID);
337     // No equivalent Opcode available.
338     if (!RoundTripOC)
339       continue;
340 
341     ASSERT_EQ(*RoundTripOC, OC);
342     ++FullTripCounts;
343   }
344   ASSERT_NE(FullTripCounts, 0u);
345 }
346 
347 /// Check that going from VP intrinsic to Opcode and back results in the same
348 /// intrinsic id.
349 TEST_F(VPIntrinsicTest, IntrinsicIDRoundTrip) {
350   std::unique_ptr<Module> M = createVPDeclarationModule();
351   assert(M);
352 
353   unsigned FullTripCounts = 0;
354   for (const auto &VPDecl : *M) {
355     auto VPID = VPDecl.getIntrinsicID();
356     std::optional<unsigned> OC = VPIntrinsic::getFunctionalOpcodeForVP(VPID);
357 
358     // no equivalent Opcode available
359     if (!OC)
360       continue;
361 
362     Intrinsic::ID RoundTripVPID = VPIntrinsic::getForOpcode(*OC);
363 
364     ASSERT_EQ(RoundTripVPID, VPID);
365     ++FullTripCounts;
366   }
367   ASSERT_NE(FullTripCounts, 0u);
368 }
369 
370 /// Check that going from intrinsic to VP intrinsic and back results in the same
371 /// intrinsic.
372 TEST_F(VPIntrinsicTest, IntrinsicToVPRoundTrip) {
373   bool IsFullTrip = false;
374   Intrinsic::ID IntrinsicID = Intrinsic::not_intrinsic + 1;
375   for (; IntrinsicID < Intrinsic::num_intrinsics; IntrinsicID++) {
376     Intrinsic::ID VPID = VPIntrinsic::getForIntrinsic(IntrinsicID);
377     // No equivalent VP intrinsic available.
378     if (VPID == Intrinsic::not_intrinsic)
379       continue;
380 
381     // Return itself if passed intrinsic ID is VP intrinsic.
382     if (VPIntrinsic::isVPIntrinsic(IntrinsicID)) {
383       ASSERT_EQ(IntrinsicID, VPID);
384       continue;
385     }
386 
387     std::optional<Intrinsic::ID> RoundTripIntrinsicID =
388         VPIntrinsic::getFunctionalIntrinsicIDForVP(VPID);
389     // No equivalent non-predicated intrinsic available.
390     if (!RoundTripIntrinsicID)
391       continue;
392 
393     ASSERT_EQ(*RoundTripIntrinsicID, IntrinsicID);
394     IsFullTrip = true;
395   }
396   ASSERT_TRUE(IsFullTrip);
397 }
398 
399 /// Check that going from VP intrinsic to equivalent non-predicated intrinsic
400 /// and back results in the same intrinsic.
401 TEST_F(VPIntrinsicTest, VPToNonPredIntrinsicRoundTrip) {
402   std::unique_ptr<Module> M = createVPDeclarationModule();
403   assert(M);
404 
405   bool IsFullTrip = false;
406   for (const auto &VPDecl : *M) {
407     auto VPID = VPDecl.getIntrinsicID();
408     std::optional<Intrinsic::ID> NonPredID =
409         VPIntrinsic::getFunctionalIntrinsicIDForVP(VPID);
410 
411     // No equivalent non-predicated intrinsic available
412     if (!NonPredID)
413       continue;
414 
415     Intrinsic::ID RoundTripVPID = VPIntrinsic::getForIntrinsic(*NonPredID);
416 
417     ASSERT_EQ(RoundTripVPID, VPID);
418     IsFullTrip = true;
419   }
420   ASSERT_TRUE(IsFullTrip);
421 }
422 
423 /// Check that VPIntrinsic::getOrInsertDeclarationForParams works.
424 TEST_F(VPIntrinsicTest, VPIntrinsicDeclarationForParams) {
425   std::unique_ptr<Module> M = createVPDeclarationModule();
426   assert(M);
427 
428   auto OutM = std::make_unique<Module>("", M->getContext());
429 
430   for (auto &F : *M) {
431     auto *FuncTy = F.getFunctionType();
432 
433     // Declare intrinsic anew with explicit types.
434     std::vector<Value *> Values;
435     for (auto *ParamTy : FuncTy->params())
436       Values.push_back(UndefValue::get(ParamTy));
437 
438     ASSERT_NE(F.getIntrinsicID(), Intrinsic::not_intrinsic);
439     auto *NewDecl = VPIntrinsic::getOrInsertDeclarationForParams(
440         OutM.get(), F.getIntrinsicID(), FuncTy->getReturnType(), Values);
441     ASSERT_TRUE(NewDecl);
442 
443     // Check that 'old decl' == 'new decl'.
444     ASSERT_EQ(F.getIntrinsicID(), NewDecl->getIntrinsicID());
445     FunctionType::param_iterator ItNewParams =
446         NewDecl->getFunctionType()->param_begin();
447     FunctionType::param_iterator EndItNewParams =
448         NewDecl->getFunctionType()->param_end();
449     for (auto *ParamTy : FuncTy->params()) {
450       ASSERT_NE(ItNewParams, EndItNewParams);
451       ASSERT_EQ(*ItNewParams, ParamTy);
452       ++ItNewParams;
453     }
454   }
455 }
456 
457 } // end anonymous namespace
458 
459 /// Check various properties of VPReductionIntrinsics
460 TEST_F(VPIntrinsicTest, VPReductions) {
461   LLVMContext C;
462   SMDiagnostic Err;
463 
464   std::stringstream Str;
465   Str << "declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, "
466          "i32)";
467   for (const char *ReductionOpcode : ReductionIntOpcodes)
468     Str << " declare i32 @llvm.vp.reduce." << ReductionOpcode
469         << ".v8i32(i32, <8 x i32>, <8 x i1>, i32) ";
470 
471   for (const char *ReductionOpcode : ReductionFPOpcodes)
472     Str << " declare float @llvm.vp.reduce." << ReductionOpcode
473         << ".v8f32(float, <8 x float>, <8 x i1>, i32) ";
474 
475   Str << "define void @test_reductions(i32 %start, <8 x i32> %val, float "
476          "%fpstart, <8 x float> %fpval, <8 x i1> %m, i32 %vl) {";
477 
478   // Mix in a regular non-reduction intrinsic to check that the
479   // VPReductionIntrinsic subclass works as intended.
480   Str << "  %r0 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %val, <8 x i32> "
481          "%val, <8 x i1> %m, i32 %vl)";
482 
483   unsigned Idx = 1;
484   for (const char *ReductionOpcode : ReductionIntOpcodes)
485     Str << "  %r" << Idx++ << " = call i32 @llvm.vp.reduce." << ReductionOpcode
486         << ".v8i32(i32 %start, <8 x i32> %val, <8 x i1> %m, i32 %vl)";
487   for (const char *ReductionOpcode : ReductionFPOpcodes)
488     Str << "  %r" << Idx++ << " = call float @llvm.vp.reduce."
489         << ReductionOpcode
490         << ".v8f32(float %fpstart, <8 x float> %fpval, <8 x i1> %m, i32 %vl)";
491 
492   Str << "  ret void"
493          "}";
494 
495   std::unique_ptr<Module> M = parseAssemblyString(Str.str(), Err, C);
496   assert(M);
497 
498   auto *F = M->getFunction("test_reductions");
499   assert(F);
500 
501   for (const auto &I : F->getEntryBlock()) {
502     const VPIntrinsic *VPI = dyn_cast<VPIntrinsic>(&I);
503     if (!VPI)
504       continue;
505 
506     Intrinsic::ID ID = VPI->getIntrinsicID();
507     const auto *VPRedI = dyn_cast<VPReductionIntrinsic>(&I);
508 
509     if (!VPReductionIntrinsic::isVPReduction(ID)) {
510       EXPECT_EQ(VPRedI, nullptr);
511       EXPECT_EQ(VPReductionIntrinsic::getStartParamPos(ID).has_value(), false);
512       EXPECT_EQ(VPReductionIntrinsic::getVectorParamPos(ID).has_value(), false);
513       continue;
514     }
515 
516     EXPECT_EQ(VPReductionIntrinsic::getStartParamPos(ID).has_value(), true);
517     EXPECT_EQ(VPReductionIntrinsic::getVectorParamPos(ID).has_value(), true);
518     ASSERT_NE(VPRedI, nullptr);
519     EXPECT_EQ(VPReductionIntrinsic::getStartParamPos(ID),
520               VPRedI->getStartParamPos());
521     EXPECT_EQ(VPReductionIntrinsic::getVectorParamPos(ID),
522               VPRedI->getVectorParamPos());
523     EXPECT_EQ(VPRedI->getStartParamPos(), 0u);
524     EXPECT_EQ(VPRedI->getVectorParamPos(), 1u);
525   }
526 }
527