xref: /llvm-project/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp (revision c82a6a025179e3b155c70f5ad8f84fa8ec2a9452)
1 //===- AMDGPUInstCombineIntrinsic.cpp - AMDGPU specific InstCombine pass --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the AMDGPU specific InstCombine hooks of
11 // TargetTransformInfo. It uses the target's detailed information to fold and
12 // simplify calls to AMDGPU intrinsics, while letting the target independent
13 // InstCombine transforms handle all generic instructions.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "AMDGPUInstrInfo.h"
18 #include "AMDGPUTargetTransformInfo.h"
19 #include "GCNSubtarget.h"
20 #include "llvm/ADT/FloatingPointMode.h"
21 #include "llvm/IR/IntrinsicsAMDGPU.h"
22 #include "llvm/Transforms/InstCombine/InstCombiner.h"
23 #include <optional>
24 
25 using namespace llvm;
26 using namespace llvm::PatternMatch;
27 
28 #define DEBUG_TYPE "AMDGPUtti"
29 
30 namespace {
31 
32 struct AMDGPUImageDMaskIntrinsic {
33   unsigned Intr;
34 };
35 
36 #define GET_AMDGPUImageDMaskIntrinsicTable_IMPL
37 #include "InstCombineTables.inc"
38 
39 } // end anonymous namespace
40 
41 // Constant fold llvm.amdgcn.fmed3 intrinsics for standard inputs.
42 //
43 // A single NaN input is folded to minnum, so we rely on that folding for
44 // handling NaNs.
45 static APFloat fmed3AMDGCN(const APFloat &Src0, const APFloat &Src1,
46                            const APFloat &Src2) {
47   APFloat Max3 = maxnum(maxnum(Src0, Src1), Src2);
48 
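  // The median of three values is the maximum of the two sources that are not
  // the overall maximum, so find which source equals Max3 and return the
  // maximum of the remaining two.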
49   APFloat::cmpResult Cmp0 = Max3.compare(Src0);
50   assert(Cmp0 != APFloat::cmpUnordered && "nans handled separately");
51   if (Cmp0 == APFloat::cmpEqual)
52     return maxnum(Src1, Src2);
53 
54   APFloat::cmpResult Cmp1 = Max3.compare(Src1);
55   assert(Cmp1 != APFloat::cmpUnordered && "nans handled separately");
56   if (Cmp1 == APFloat::cmpEqual)
57     return maxnum(Src0, Src2);
58 
59   return maxnum(Src0, Src1);
60 }
61 
62 // Check if a value can be converted to a 16-bit value without losing
63 // precision.
64 // The value is expected to be either a float (IsFloat = true) or an unsigned
65 // integer (IsFloat = false).
66 static bool canSafelyConvertTo16Bit(Value &V, bool IsFloat) {
67   Type *VTy = V.getType();
68   if (VTy->isHalfTy() || VTy->isIntegerTy(16)) {
69     // The value is already 16-bit, so we don't want to convert to 16-bit again!
70     return false;
71   }
72   if (IsFloat) {
73     if (ConstantFP *ConstFloat = dyn_cast<ConstantFP>(&V)) {
74       // We need to check that casting the value down to a half does not
75       // lose precision.
76       APFloat FloatValue(ConstFloat->getValueAPF());
77       bool LosesInfo = true;
78       FloatValue.convert(APFloat::IEEEhalf(), APFloat::rmTowardZero,
79                          &LosesInfo);
80       return !LosesInfo;
81     }
82   } else {
83     if (ConstantInt *ConstInt = dyn_cast<ConstantInt>(&V)) {
84       // We need to check that casting the value down to an i16 does not
85       // lose precision.
86       APInt IntValue(ConstInt->getValue());
87       return IntValue.getActiveBits() <= 16;
88     }
89   }
90 
91   Value *CastSrc;
92   bool IsExt = IsFloat ? match(&V, m_FPExt(PatternMatch::m_Value(CastSrc)))
93                        : match(&V, m_ZExt(PatternMatch::m_Value(CastSrc)));
94   if (IsExt) {
95     Type *CastSrcTy = CastSrc->getType();
96     if (CastSrcTy->isHalfTy() || CastSrcTy->isIntegerTy(16))
97       return true;
98   }
99 
100   return false;
101 }
102 
103 // Convert a value to 16-bit.
104 static Value *convertTo16Bit(Value &V, InstCombiner::BuilderTy &Builder) {
105   Type *VTy = V.getType();
106   if (isa<FPExtInst>(&V) || isa<SExtInst>(&V) || isa<ZExtInst>(&V))
107     return cast<Instruction>(&V)->getOperand(0);
108   if (VTy->isIntegerTy())
109     return Builder.CreateIntCast(&V, Type::getInt16Ty(V.getContext()), false);
110   if (VTy->isFloatingPointTy())
111     return Builder.CreateFPCast(&V, Type::getHalfTy(V.getContext()));
112 
113   llvm_unreachable("Should never be called!");
114 }
115 
116 /// Applies Func to OldIntr's arguments and overloaded types, creates a call
117 /// to intrinsic NewIntr with the modified arguments and types, and replaces
118 /// InstToReplace with this newly created intrinsic call.
119 static std::optional<Instruction *> modifyIntrinsicCall(
120     IntrinsicInst &OldIntr, Instruction &InstToReplace, unsigned NewIntr,
121     InstCombiner &IC,
122     std::function<void(SmallVectorImpl<Value *> &, SmallVectorImpl<Type *> &)>
123         Func) {
124   SmallVector<Type *, 4> ArgTys;
125   if (!Intrinsic::getIntrinsicSignature(OldIntr.getCalledFunction(), ArgTys))
126     return std::nullopt;
127 
128   SmallVector<Value *, 8> Args(OldIntr.args());
129 
130   // Modify arguments and types
131   Func(Args, ArgTys);
132 
133   CallInst *NewCall = IC.Builder.CreateIntrinsic(NewIntr, ArgTys, Args);
134   NewCall->takeName(&OldIntr);
135   NewCall->copyMetadata(OldIntr);
136   if (isa<FPMathOperator>(NewCall))
137     NewCall->copyFastMathFlags(&OldIntr);
138 
139   // Erase and replace uses
140   if (!InstToReplace.getType()->isVoidTy())
141     IC.replaceInstUsesWith(InstToReplace, NewCall);
142 
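  // If we replaced a different instruction (e.g. a user of the old intrinsic),
  // the old intrinsic itself still has to be erased.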
143   bool RemoveOldIntr = &OldIntr != &InstToReplace;
144 
145   auto *RetValue = IC.eraseInstFromFunction(InstToReplace);
146   if (RemoveOldIntr)
147     IC.eraseInstFromFunction(OldIntr);
148 
149   return RetValue;
150 }
151 
152 static std::optional<Instruction *>
153 simplifyAMDGCNImageIntrinsic(const GCNSubtarget *ST,
154                              const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr,
155                              IntrinsicInst &II, InstCombiner &IC) {
156   // Optimize _L to _LZ when _L is zero
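  // For example, llvm.amdgcn.image.sample.l.* with a constant lod that is zero
  // or negative folds to the corresponding llvm.amdgcn.image.sample.lz.* call
  // with the lod operand removed.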
157   if (const auto *LZMappingInfo =
158           AMDGPU::getMIMGLZMappingInfo(ImageDimIntr->BaseOpcode)) {
159     if (auto *ConstantLod =
160             dyn_cast<ConstantFP>(II.getOperand(ImageDimIntr->LodIndex))) {
161       if (ConstantLod->isZero() || ConstantLod->isNegative()) {
162         const AMDGPU::ImageDimIntrinsicInfo *NewImageDimIntr =
163             AMDGPU::getImageDimIntrinsicByBaseOpcode(LZMappingInfo->LZ,
164                                                      ImageDimIntr->Dim);
165         return modifyIntrinsicCall(
166             II, II, NewImageDimIntr->Intr, IC, [&](auto &Args, auto &ArgTys) {
167               Args.erase(Args.begin() + ImageDimIntr->LodIndex);
168             });
169       }
170     }
171   }
172 
173   // Optimize _mip away, when 'lod' is zero
174   if (const auto *MIPMappingInfo =
175           AMDGPU::getMIMGMIPMappingInfo(ImageDimIntr->BaseOpcode)) {
176     if (auto *ConstantMip =
177             dyn_cast<ConstantInt>(II.getOperand(ImageDimIntr->MipIndex))) {
178       if (ConstantMip->isZero()) {
179         const AMDGPU::ImageDimIntrinsicInfo *NewImageDimIntr =
180             AMDGPU::getImageDimIntrinsicByBaseOpcode(MIPMappingInfo->NONMIP,
181                                                      ImageDimIntr->Dim);
182         return modifyIntrinsicCall(
183             II, II, NewImageDimIntr->Intr, IC, [&](auto &Args, auto &ArgTys) {
184               Args.erase(Args.begin() + ImageDimIntr->MipIndex);
185             });
186       }
187     }
188   }
189 
190   // Optimize _bias away when 'bias' is zero
191   if (const auto *BiasMappingInfo =
192           AMDGPU::getMIMGBiasMappingInfo(ImageDimIntr->BaseOpcode)) {
193     if (auto *ConstantBias =
194             dyn_cast<ConstantFP>(II.getOperand(ImageDimIntr->BiasIndex))) {
195       if (ConstantBias->isZero()) {
196         const AMDGPU::ImageDimIntrinsicInfo *NewImageDimIntr =
197             AMDGPU::getImageDimIntrinsicByBaseOpcode(BiasMappingInfo->NoBias,
198                                                      ImageDimIntr->Dim);
199         return modifyIntrinsicCall(
200             II, II, NewImageDimIntr->Intr, IC, [&](auto &Args, auto &ArgTys) {
201               Args.erase(Args.begin() + ImageDimIntr->BiasIndex);
202               ArgTys.erase(ArgTys.begin() + ImageDimIntr->BiasTyArg);
203             });
204       }
205     }
206   }
207 
208   // Optimize _offset away when 'offset' is zero
209   if (const auto *OffsetMappingInfo =
210           AMDGPU::getMIMGOffsetMappingInfo(ImageDimIntr->BaseOpcode)) {
211     if (auto *ConstantOffset =
212             dyn_cast<ConstantInt>(II.getOperand(ImageDimIntr->OffsetIndex))) {
213       if (ConstantOffset->isZero()) {
214         const AMDGPU::ImageDimIntrinsicInfo *NewImageDimIntr =
215             AMDGPU::getImageDimIntrinsicByBaseOpcode(
216                 OffsetMappingInfo->NoOffset, ImageDimIntr->Dim);
217         return modifyIntrinsicCall(
218             II, II, NewImageDimIntr->Intr, IC, [&](auto &Args, auto &ArgTys) {
219               Args.erase(Args.begin() + ImageDimIntr->OffsetIndex);
220             });
221       }
222     }
223   }
224 
225   // Try to use D16
226   if (ST->hasD16Images()) {
227 
228     const AMDGPU::MIMGBaseOpcodeInfo *BaseOpcode =
229         AMDGPU::getMIMGBaseOpcodeInfo(ImageDimIntr->BaseOpcode);
230 
231     if (BaseOpcode->HasD16) {
232 
233       // If the only use of the image intrinsic is an fptrunc to half, then
234       // both the fptrunc and the image intrinsic are replaced by the same
235       // intrinsic returning half, which is selected as a D16 instruction.
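      // For example (sketch):
      //   %v = call <4 x float> @llvm.amdgcn.image.sample...(...)
      //   %h = fptrunc <4 x float> %v to <4 x half>
      // folds to a single image sample call returning <4 x half>.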
236       if (II.hasOneUse()) {
237         Instruction *User = II.user_back();
238 
239         if (User->getOpcode() == Instruction::FPTrunc &&
240             User->getType()->getScalarType()->isHalfTy()) {
241 
242           return modifyIntrinsicCall(II, *User, ImageDimIntr->Intr, IC,
243                                      [&](auto &Args, auto &ArgTys) {
244                                        // Change return type of image intrinsic.
245                                        // Set it to return type of fptrunc.
246                                        ArgTys[0] = User->getType();
247                                      });
248         }
249       }
250     }
251   }
252 
253   // Try to use A16 or G16
254   if (!ST->hasA16() && !ST->hasG16())
255     return std::nullopt;
256 
257   // Address is interpreted as float if the instruction has a sampler or as
258   // unsigned int if there is no sampler.
259   bool HasSampler =
260       AMDGPU::getMIMGBaseOpcodeInfo(ImageDimIntr->BaseOpcode)->Sampler;
261   bool FloatCoord = false;
262   // true means derivatives can be converted to 16 bit, coordinates not
263   bool OnlyDerivatives = false;
264 
265   for (unsigned OperandIndex = ImageDimIntr->GradientStart;
266        OperandIndex < ImageDimIntr->VAddrEnd; OperandIndex++) {
267     Value *Coord = II.getOperand(OperandIndex);
268     // If the values are not derived from 16-bit values, we cannot optimize.
269     if (!canSafelyConvertTo16Bit(*Coord, HasSampler)) {
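      // Give up if a gradient operand cannot be converted, or if the intrinsic
      // has no gradient operands at all; otherwise fall back to converting only
      // the gradients.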
270       if (OperandIndex < ImageDimIntr->CoordStart ||
271           ImageDimIntr->GradientStart == ImageDimIntr->CoordStart) {
272         return std::nullopt;
273       }
274       // All gradients can be converted, so convert only them
275       OnlyDerivatives = true;
276       break;
277     }
278 
279     assert(OperandIndex == ImageDimIntr->GradientStart ||
280            FloatCoord == Coord->getType()->isFloatingPointTy());
281     FloatCoord = Coord->getType()->isFloatingPointTy();
282   }
283 
284   if (!OnlyDerivatives && !ST->hasA16())
285     OnlyDerivatives = true; // Only supports G16
286 
287   // Check if there is a bias parameter and if it can be converted to f16
288   if (!OnlyDerivatives && ImageDimIntr->NumBiasArgs != 0) {
289     Value *Bias = II.getOperand(ImageDimIntr->BiasIndex);
290     assert(HasSampler &&
291            "Only image instructions with a sampler can have a bias");
292     if (!canSafelyConvertTo16Bit(*Bias, HasSampler))
293       OnlyDerivatives = true;
294   }
295 
296   if (OnlyDerivatives && (!ST->hasG16() || ImageDimIntr->GradientStart ==
297                                                ImageDimIntr->CoordStart))
298     return std::nullopt;
299 
300   Type *CoordType = FloatCoord ? Type::getHalfTy(II.getContext())
301                                : Type::getInt16Ty(II.getContext());
302 
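  // Rewrite the intrinsic in place: switch the gradient (and, unless only the
  // derivatives can be converted, the coordinate and bias) types to 16 bit and
  // narrow the corresponding operands.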
303   return modifyIntrinsicCall(
304       II, II, II.getIntrinsicID(), IC, [&](auto &Args, auto &ArgTys) {
305         ArgTys[ImageDimIntr->GradientTyArg] = CoordType;
306         if (!OnlyDerivatives) {
307           ArgTys[ImageDimIntr->CoordTyArg] = CoordType;
308 
309           // Change the bias type
310           if (ImageDimIntr->NumBiasArgs != 0)
311             ArgTys[ImageDimIntr->BiasTyArg] = Type::getHalfTy(II.getContext());
312         }
313 
314         unsigned EndIndex =
315             OnlyDerivatives ? ImageDimIntr->CoordStart : ImageDimIntr->VAddrEnd;
316         for (unsigned OperandIndex = ImageDimIntr->GradientStart;
317              OperandIndex < EndIndex; OperandIndex++) {
318           Args[OperandIndex] =
319               convertTo16Bit(*II.getOperand(OperandIndex), IC.Builder);
320         }
321 
322         // Convert the bias
323         if (!OnlyDerivatives && ImageDimIntr->NumBiasArgs != 0) {
324           Value *Bias = II.getOperand(ImageDimIntr->BiasIndex);
325           Args[ImageDimIntr->BiasIndex] = convertTo16Bit(*Bias, IC.Builder);
326         }
327       });
328 }
329 
330 bool GCNTTIImpl::canSimplifyLegacyMulToMul(const Instruction &I,
331                                            const Value *Op0, const Value *Op1,
332                                            InstCombiner &IC) const {
333   // The legacy behaviour is that multiplying +/-0.0 by anything, even NaN or
334   // infinity, gives +0.0. If we can prove we don't have one of the special
335   // cases then we can use a normal multiply instead.
336   // TODO: Create and use isKnownFiniteNonZero instead of just matching
337   // constants here.
338   if (match(Op0, PatternMatch::m_FiniteNonZero()) ||
339       match(Op1, PatternMatch::m_FiniteNonZero())) {
340     // One operand is not zero or infinity or NaN.
341     return true;
342   }
343 
344   SimplifyQuery SQ = IC.getSimplifyQuery().getWithInstruction(&I);
345   if (isKnownNeverInfOrNaN(Op0, /*Depth=*/0, SQ) &&
346       isKnownNeverInfOrNaN(Op1, /*Depth=*/0, SQ)) {
347     // Neither operand is infinity or NaN.
348     return true;
349   }
350   return false;
351 }
352 
353 /// Match an fpext from half to float, or a constant we can convert.
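/// Returns the half-typed source of the fpext, or an equivalent half constant
/// when the conversion is exact; otherwise returns nullptr.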
354 static Value *matchFPExtFromF16(Value *Arg) {
355   Value *Src = nullptr;
356   ConstantFP *CFP = nullptr;
357   if (match(Arg, m_OneUse(m_FPExt(m_Value(Src))))) {
358     if (Src->getType()->isHalfTy())
359       return Src;
360   } else if (match(Arg, m_ConstantFP(CFP))) {
361     bool LosesInfo;
362     APFloat Val(CFP->getValueAPF());
363     Val.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &LosesInfo);
364     if (!LosesInfo)
365       return ConstantFP::get(Type::getHalfTy(Arg->getContext()), Val);
366   }
367   return nullptr;
368 }
369 
370 // Trim all zero components from the end of the vector \p UseV and return
371 // a bitmask of the elements that remain demanded.
372 static APInt trimTrailingZerosInVector(InstCombiner &IC, Value *UseV,
373                                        Instruction *I) {
374   auto *VTy = cast<FixedVectorType>(UseV->getType());
375   unsigned VWidth = VTy->getNumElements();
376   APInt DemandedElts = APInt::getAllOnes(VWidth);
377 
378   for (int i = VWidth - 1; i > 0; --i) {
379     auto *Elt = findScalarElement(UseV, i);
380     if (!Elt)
381       break;
382 
383     if (auto *ConstElt = dyn_cast<Constant>(Elt)) {
384       if (!ConstElt->isNullValue() && !isa<UndefValue>(Elt))
385         break;
386     } else {
387       break;
388     }
389 
390     DemandedElts.clearBit(i);
391   }
392 
393   return DemandedElts;
394 }
395 
396 // Trim elements from the end of the vector \p V if they are
397 // equal to the first element of the vector.
398 static APInt defaultComponentBroadcast(Value *V) {
399   auto *VTy = cast<FixedVectorType>(V->getType());
400   unsigned VWidth = VTy->getNumElements();
401   APInt DemandedElts = APInt::getAllOnes(VWidth);
402   Value *FirstComponent = findScalarElement(V, 0);
403 
404   SmallVector<int> ShuffleMask;
405   if (auto *SVI = dyn_cast<ShuffleVectorInst>(V))
406     SVI->getShuffleMask(ShuffleMask);
407 
408   for (int I = VWidth - 1; I > 0; --I) {
409     if (ShuffleMask.empty()) {
410       auto *Elt = findScalarElement(V, I);
411       if (!Elt || (Elt != FirstComponent && !isa<UndefValue>(Elt)))
412         break;
413     } else {
414       // Detect identical elements in the shufflevector result, even though
415       // findScalarElement cannot tell us what that element is.
416       if (ShuffleMask[I] != ShuffleMask[0] && ShuffleMask[I] != PoisonMaskElem)
417         break;
418     }
419     DemandedElts.clearBit(I);
420   }
421 
422   return DemandedElts;
423 }
424 
425 static Value *simplifyAMDGCNMemoryIntrinsicDemanded(InstCombiner &IC,
426                                                     IntrinsicInst &II,
427                                                     APInt DemandedElts,
428                                                     int DMaskIdx = -1,
429                                                     bool IsLoad = true);
430 
431 /// Return true if it's legal to contract llvm.amdgcn.rcp(llvm.sqrt) to rsq.
432 static bool canContractSqrtToRsq(const FPMathOperator *SqrtOp) {
433   return (SqrtOp->getType()->isFloatTy() &&
434           (SqrtOp->hasApproxFunc() || SqrtOp->getFPAccuracy() >= 1.0f)) ||
435          SqrtOp->getType()->isHalfTy();
436 }
437 
438 /// Return true if we can easily prove that use U is uniform.
439 static bool isTriviallyUniform(const Use &U) {
440   Value *V = U.get();
441   if (isa<Constant>(V))
442     return true;
443   if (const auto *II = dyn_cast<IntrinsicInst>(V)) {
444     if (!AMDGPU::isIntrinsicAlwaysUniform(II->getIntrinsicID()))
445       return false;
446     // If II and U are in different blocks then there is a possibility of
447     // temporal divergence.
448     return II->getParent() == cast<Instruction>(U.getUser())->getParent();
449   }
450   return false;
451 }
452 
453 /// Simplify a lane index operand (e.g. llvm.amdgcn.readlane src1).
454 ///
455 /// The instruction only reads the low 5 bits for wave32, and 6 bits for wave64.
456 bool GCNTTIImpl::simplifyDemandedLaneMaskArg(InstCombiner &IC,
457                                              IntrinsicInst &II,
458                                              unsigned LaneArgIdx) const {
459   unsigned MaskBits = ST->getWavefrontSizeLog2();
460   APInt DemandedMask(32, maskTrailingOnes<unsigned>(MaskBits));
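  // e.g. on wave64 only the low 6 bits of the lane index are read, so a
  // constant lane index of 65 can be folded to 1.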
461 
462   KnownBits Known(32);
463   if (IC.SimplifyDemandedBits(&II, LaneArgIdx, DemandedMask, Known))
464     return true;
465 
466   if (!Known.isConstant())
467     return false;
468 
469   // Out of bounds indexes may appear in wave64 code compiled for wave32.
470   // Unlike the DAG version, SimplifyDemandedBits does not change constants, so
471   // manually fix it up.
472 
473   Value *LaneArg = II.getArgOperand(LaneArgIdx);
474   Constant *MaskedConst =
475       ConstantInt::get(LaneArg->getType(), Known.getConstant() & DemandedMask);
476   if (MaskedConst != LaneArg) {
477     II.getOperandUse(LaneArgIdx).set(MaskedConst);
478     return true;
479   }
480 
481   return false;
482 }
483 
484 std::optional<Instruction *>
485 GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
486   Intrinsic::ID IID = II.getIntrinsicID();
487   switch (IID) {
488   case Intrinsic::amdgcn_rcp: {
489     Value *Src = II.getArgOperand(0);
490 
491     // TODO: Move to ConstantFolding/InstSimplify?
492     if (isa<UndefValue>(Src)) {
493       Type *Ty = II.getType();
494       auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics()));
495       return IC.replaceInstUsesWith(II, QNaN);
496     }
497 
498     if (II.isStrictFP())
499       break;
500 
501     if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) {
502       const APFloat &ArgVal = C->getValueAPF();
503       APFloat Val(ArgVal.getSemantics(), 1);
504       Val.divide(ArgVal, APFloat::rmNearestTiesToEven);
505 
506       // This folded result is more precise than the instruction may give.
507       //
508       // TODO: The instruction always flushes denormal results (except for f16),
509       // should this also?
510       return IC.replaceInstUsesWith(II, ConstantFP::get(II.getContext(), Val));
511     }
512 
513     FastMathFlags FMF = cast<FPMathOperator>(II).getFastMathFlags();
514     if (!FMF.allowContract())
515       break;
516     auto *SrcCI = dyn_cast<IntrinsicInst>(Src);
517     if (!SrcCI)
518       break;
519 
520     auto IID = SrcCI->getIntrinsicID();
521     // llvm.amdgcn.rcp(llvm.amdgcn.sqrt(x)) -> llvm.amdgcn.rsq(x) if contractable
522     //
523     // llvm.amdgcn.rcp(llvm.sqrt(x)) -> llvm.amdgcn.rsq(x) if contractable and
524     // relaxed.
525     if (IID == Intrinsic::amdgcn_sqrt || IID == Intrinsic::sqrt) {
526       const FPMathOperator *SqrtOp = cast<FPMathOperator>(SrcCI);
527       FastMathFlags InnerFMF = SqrtOp->getFastMathFlags();
528       if (!InnerFMF.allowContract() || !SrcCI->hasOneUse())
529         break;
530 
531       if (IID == Intrinsic::sqrt && !canContractSqrtToRsq(SqrtOp))
532         break;
533 
534       Function *NewDecl = Intrinsic::getOrInsertDeclaration(
535           SrcCI->getModule(), Intrinsic::amdgcn_rsq, {SrcCI->getType()});
536 
537       InnerFMF |= FMF;
538       II.setFastMathFlags(InnerFMF);
539 
540       II.setCalledFunction(NewDecl);
541       return IC.replaceOperand(II, 0, SrcCI->getArgOperand(0));
542     }
543 
544     break;
545   }
546   case Intrinsic::amdgcn_sqrt:
547   case Intrinsic::amdgcn_rsq: {
548     Value *Src = II.getArgOperand(0);
549 
550     // TODO: Move to ConstantFolding/InstSimplify?
551     if (isa<UndefValue>(Src)) {
552       Type *Ty = II.getType();
553       auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics()));
554       return IC.replaceInstUsesWith(II, QNaN);
555     }
556 
557     // f16 amdgcn.sqrt is identical to regular sqrt.
558     if (IID == Intrinsic::amdgcn_sqrt && Src->getType()->isHalfTy()) {
559       Function *NewDecl = Intrinsic::getOrInsertDeclaration(
560           II.getModule(), Intrinsic::sqrt, {II.getType()});
561       II.setCalledFunction(NewDecl);
562       return &II;
563     }
564 
565     break;
566   }
567   case Intrinsic::amdgcn_log:
568   case Intrinsic::amdgcn_exp2: {
569     const bool IsLog = IID == Intrinsic::amdgcn_log;
570     const bool IsExp = IID == Intrinsic::amdgcn_exp2;
571     Value *Src = II.getArgOperand(0);
572     Type *Ty = II.getType();
573 
574     if (isa<PoisonValue>(Src))
575       return IC.replaceInstUsesWith(II, Src);
576 
577     if (IC.getSimplifyQuery().isUndefValue(Src))
578       return IC.replaceInstUsesWith(II, ConstantFP::getNaN(Ty));
579 
580     if (ConstantFP *C = dyn_cast<ConstantFP>(Src)) {
581       if (C->isInfinity()) {
582         // exp2(+inf) -> +inf
583         // log2(+inf) -> +inf
584         if (!C->isNegative())
585           return IC.replaceInstUsesWith(II, C);
586 
587         // exp2(-inf) -> 0
588         if (IsExp && C->isNegative())
589           return IC.replaceInstUsesWith(II, ConstantFP::getZero(Ty));
590       }
591 
592       if (II.isStrictFP())
593         break;
594 
595       if (C->isNaN()) {
596         Constant *Quieted = ConstantFP::get(Ty, C->getValue().makeQuiet());
597         return IC.replaceInstUsesWith(II, Quieted);
598       }
599 
600       // f32 instruction doesn't handle denormals, f16 does.
601       if (C->isZero() || (C->getValue().isDenormal() && Ty->isFloatTy())) {
602         Constant *FoldedValue = IsLog ? ConstantFP::getInfinity(Ty, true)
603                                       : ConstantFP::get(Ty, 1.0);
604         return IC.replaceInstUsesWith(II, FoldedValue);
605       }
606 
607       if (IsLog && C->isNegative())
608         return IC.replaceInstUsesWith(II, ConstantFP::getNaN(Ty));
609 
610       // TODO: Full constant folding matching hardware behavior.
611     }
612 
613     break;
614   }
615   case Intrinsic::amdgcn_frexp_mant:
616   case Intrinsic::amdgcn_frexp_exp: {
617     Value *Src = II.getArgOperand(0);
618     if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) {
619       int Exp;
620       APFloat Significand =
621           frexp(C->getValueAPF(), Exp, APFloat::rmNearestTiesToEven);
622 
623       if (IID == Intrinsic::amdgcn_frexp_mant) {
624         return IC.replaceInstUsesWith(
625             II, ConstantFP::get(II.getContext(), Significand));
626       }
627 
628       // Match instruction special case behavior.
629       if (Exp == APFloat::IEK_NaN || Exp == APFloat::IEK_Inf)
630         Exp = 0;
631 
632       return IC.replaceInstUsesWith(II, ConstantInt::get(II.getType(), Exp));
633     }
634 
635     if (isa<UndefValue>(Src)) {
636       return IC.replaceInstUsesWith(II, UndefValue::get(II.getType()));
637     }
638 
639     break;
640   }
641   case Intrinsic::amdgcn_class: {
642     Value *Src0 = II.getArgOperand(0);
643     Value *Src1 = II.getArgOperand(1);
644     const ConstantInt *CMask = dyn_cast<ConstantInt>(Src1);
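    // With a constant test mask this is equivalent to the generic
    // llvm.is.fpclass intrinsic, so rewrite the call in place.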
645     if (CMask) {
646       II.setCalledOperand(Intrinsic::getOrInsertDeclaration(
647           II.getModule(), Intrinsic::is_fpclass, Src0->getType()));
648 
649       // Clamp any excess bits, as they're illegal for the generic intrinsic.
650       II.setArgOperand(1, ConstantInt::get(Src1->getType(),
651                                            CMask->getZExtValue() & fcAllFlags));
652       return &II;
653     }
654 
655     // Propagate poison.
656     if (isa<PoisonValue>(Src0) || isa<PoisonValue>(Src1))
657       return IC.replaceInstUsesWith(II, PoisonValue::get(II.getType()));
658 
659     // llvm.amdgcn.class(_, undef) -> false
660     if (IC.getSimplifyQuery().isUndefValue(Src1))
661       return IC.replaceInstUsesWith(II, ConstantInt::get(II.getType(), false));
662 
663     // llvm.amdgcn.class(undef, mask) -> mask != 0
664     if (IC.getSimplifyQuery().isUndefValue(Src0)) {
665       Value *CmpMask = IC.Builder.CreateICmpNE(
666           Src1, ConstantInt::getNullValue(Src1->getType()));
667       return IC.replaceInstUsesWith(II, CmpMask);
668     }
669     break;
670   }
671   case Intrinsic::amdgcn_cvt_pkrtz: {
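    // cvt_pkrtz converts two floats to a packed <2 x half> using round toward
    // zero. If both operands fold to half values (constants are converted with
    // RTZ, matching the instruction), build the result vector directly.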
672     auto foldFPTruncToF16RTZ = [](Value *Arg) -> Value * {
673       Type *HalfTy = Type::getHalfTy(Arg->getContext());
674 
675       if (isa<PoisonValue>(Arg))
676         return PoisonValue::get(HalfTy);
677       if (isa<UndefValue>(Arg))
678         return UndefValue::get(HalfTy);
679 
680       ConstantFP *CFP = nullptr;
681       if (match(Arg, m_ConstantFP(CFP))) {
682         bool LosesInfo;
683         APFloat Val(CFP->getValueAPF());
684         Val.convert(APFloat::IEEEhalf(), APFloat::rmTowardZero, &LosesInfo);
685         return ConstantFP::get(HalfTy, Val);
686       }
687 
688       Value *Src = nullptr;
689       if (match(Arg, m_FPExt(m_Value(Src)))) {
690         if (Src->getType()->isHalfTy())
691           return Src;
692       }
693 
694       return nullptr;
695     };
696 
697     if (Value *Src0 = foldFPTruncToF16RTZ(II.getArgOperand(0))) {
698       if (Value *Src1 = foldFPTruncToF16RTZ(II.getArgOperand(1))) {
699         Value *V = PoisonValue::get(II.getType());
700         V = IC.Builder.CreateInsertElement(V, Src0, (uint64_t)0);
701         V = IC.Builder.CreateInsertElement(V, Src1, (uint64_t)1);
702         return IC.replaceInstUsesWith(II, V);
703       }
704     }
705 
706     break;
707   }
708   case Intrinsic::amdgcn_cvt_pknorm_i16:
709   case Intrinsic::amdgcn_cvt_pknorm_u16:
710   case Intrinsic::amdgcn_cvt_pk_i16:
711   case Intrinsic::amdgcn_cvt_pk_u16: {
712     Value *Src0 = II.getArgOperand(0);
713     Value *Src1 = II.getArgOperand(1);
714 
715     if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1)) {
716       return IC.replaceInstUsesWith(II, UndefValue::get(II.getType()));
717     }
718 
719     break;
720   }
721   case Intrinsic::amdgcn_ubfe:
722   case Intrinsic::amdgcn_sbfe: {
723     // Decompose simple cases into standard shifts.
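    // e.g. for i32, ubfe(x, 8, 8) becomes (x << 16) >> 24 using lshr to
    // extract bits [15:8]; sbfe uses ashr so the field is sign extended.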
724     Value *Src = II.getArgOperand(0);
725     if (isa<UndefValue>(Src)) {
726       return IC.replaceInstUsesWith(II, Src);
727     }
728 
729     unsigned Width;
730     Type *Ty = II.getType();
731     unsigned IntSize = Ty->getIntegerBitWidth();
732 
733     ConstantInt *CWidth = dyn_cast<ConstantInt>(II.getArgOperand(2));
734     if (CWidth) {
735       Width = CWidth->getZExtValue();
736       if ((Width & (IntSize - 1)) == 0) {
737         return IC.replaceInstUsesWith(II, ConstantInt::getNullValue(Ty));
738       }
739 
740       // Hardware ignores high bits, so remove those.
741       if (Width >= IntSize) {
742         return IC.replaceOperand(
743             II, 2, ConstantInt::get(CWidth->getType(), Width & (IntSize - 1)));
744       }
745     }
746 
747     unsigned Offset;
748     ConstantInt *COffset = dyn_cast<ConstantInt>(II.getArgOperand(1));
749     if (COffset) {
750       Offset = COffset->getZExtValue();
751       if (Offset >= IntSize) {
752         return IC.replaceOperand(
753             II, 1,
754             ConstantInt::get(COffset->getType(), Offset & (IntSize - 1)));
755       }
756     }
757 
758     bool Signed = IID == Intrinsic::amdgcn_sbfe;
759 
760     if (!CWidth || !COffset)
761       break;
762 
763     // The case of Width == 0 is handled above, which makes this transformation
764     // safe.  If Width == 0, the ashr and lshr instructions would produce poison
765     // since the shift amount would be equal to the bit width.
766     assert(Width != 0);
767 
768     // TODO: This allows folding to undef when the hardware has specific
769     // behavior?
770     if (Offset + Width < IntSize) {
771       Value *Shl = IC.Builder.CreateShl(Src, IntSize - Offset - Width);
772       Value *RightShift = Signed ? IC.Builder.CreateAShr(Shl, IntSize - Width)
773                                  : IC.Builder.CreateLShr(Shl, IntSize - Width);
774       RightShift->takeName(&II);
775       return IC.replaceInstUsesWith(II, RightShift);
776     }
777 
778     Value *RightShift = Signed ? IC.Builder.CreateAShr(Src, Offset)
779                                : IC.Builder.CreateLShr(Src, Offset);
780 
781     RightShift->takeName(&II);
782     return IC.replaceInstUsesWith(II, RightShift);
783   }
784   case Intrinsic::amdgcn_exp:
785   case Intrinsic::amdgcn_exp_row:
786   case Intrinsic::amdgcn_exp_compr: {
787     ConstantInt *En = cast<ConstantInt>(II.getArgOperand(1));
788     unsigned EnBits = En->getZExtValue();
789     if (EnBits == 0xf)
790       break; // All inputs enabled.
791 
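    // Sources whose enable bits are clear are ignored by the export, so they
    // can safely be replaced with undef.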
792     bool IsCompr = IID == Intrinsic::amdgcn_exp_compr;
793     bool Changed = false;
794     for (int I = 0; I < (IsCompr ? 2 : 4); ++I) {
795       if ((!IsCompr && (EnBits & (1 << I)) == 0) ||
796           (IsCompr && ((EnBits & (0x3 << (2 * I))) == 0))) {
797         Value *Src = II.getArgOperand(I + 2);
798         if (!isa<UndefValue>(Src)) {
799           IC.replaceOperand(II, I + 2, UndefValue::get(Src->getType()));
800           Changed = true;
801         }
802       }
803     }
804 
805     if (Changed) {
806       return &II;
807     }
808 
809     break;
810   }
811   case Intrinsic::amdgcn_fmed3: {
812     // Note this does not preserve proper sNaN behavior if IEEE-mode is enabled
813     // for the shader.
814 
815     Value *Src0 = II.getArgOperand(0);
816     Value *Src1 = II.getArgOperand(1);
817     Value *Src2 = II.getArgOperand(2);
818 
819     // Checking for NaN before canonicalization provides better fidelity when
820     // mapping other operations onto fmed3 since the order of operands is
821     // unchanged.
822     Value *V = nullptr;
823     if (match(Src0, PatternMatch::m_NaN()) || isa<UndefValue>(Src0)) {
824       V = IC.Builder.CreateMinNum(Src1, Src2);
825     } else if (match(Src1, PatternMatch::m_NaN()) || isa<UndefValue>(Src1)) {
826       V = IC.Builder.CreateMinNum(Src0, Src2);
827     } else if (match(Src2, PatternMatch::m_NaN()) || isa<UndefValue>(Src2)) {
828       V = IC.Builder.CreateMaxNum(Src0, Src1);
829     }
830 
831     if (V) {
832       if (auto *CI = dyn_cast<CallInst>(V)) {
833         CI->copyFastMathFlags(&II);
834         CI->takeName(&II);
835       }
836       return IC.replaceInstUsesWith(II, V);
837     }
838 
839     bool Swap = false;
840     // Canonicalize constants to RHS operands.
841     //
842     // fmed3(c0, x, c1) -> fmed3(x, c0, c1)
843     if (isa<Constant>(Src0) && !isa<Constant>(Src1)) {
844       std::swap(Src0, Src1);
845       Swap = true;
846     }
847 
848     if (isa<Constant>(Src1) && !isa<Constant>(Src2)) {
849       std::swap(Src1, Src2);
850       Swap = true;
851     }
852 
853     if (isa<Constant>(Src0) && !isa<Constant>(Src1)) {
854       std::swap(Src0, Src1);
855       Swap = true;
856     }
857 
858     if (Swap) {
859       II.setArgOperand(0, Src0);
860       II.setArgOperand(1, Src1);
861       II.setArgOperand(2, Src2);
862       return &II;
863     }
864 
865     if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) {
866       if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) {
867         if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) {
868           APFloat Result = fmed3AMDGCN(C0->getValueAPF(), C1->getValueAPF(),
869                                        C2->getValueAPF());
870           return IC.replaceInstUsesWith(
871               II, ConstantFP::get(IC.Builder.getContext(), Result));
872         }
873       }
874     }
875 
876     if (!ST->hasMed3_16())
877       break;
878 
879     // Repeat floating-point width reduction done for minnum/maxnum.
880     // fmed3((fpext X), (fpext Y), (fpext Z)) -> fpext (fmed3(X, Y, Z))
881     if (Value *X = matchFPExtFromF16(Src0)) {
882       if (Value *Y = matchFPExtFromF16(Src1)) {
883         if (Value *Z = matchFPExtFromF16(Src2)) {
884           Value *NewCall = IC.Builder.CreateIntrinsic(
885               IID, {X->getType()}, {X, Y, Z}, &II, II.getName());
886           return new FPExtInst(NewCall, II.getType());
887         }
888       }
889     }
890 
891     break;
892   }
893   case Intrinsic::amdgcn_icmp:
894   case Intrinsic::amdgcn_fcmp: {
895     const ConstantInt *CC = cast<ConstantInt>(II.getArgOperand(2));
896     // Guard against invalid arguments.
897     int64_t CCVal = CC->getZExtValue();
898     bool IsInteger = IID == Intrinsic::amdgcn_icmp;
899     if ((IsInteger && (CCVal < CmpInst::FIRST_ICMP_PREDICATE ||
900                        CCVal > CmpInst::LAST_ICMP_PREDICATE)) ||
901         (!IsInteger && (CCVal < CmpInst::FIRST_FCMP_PREDICATE ||
902                         CCVal > CmpInst::LAST_FCMP_PREDICATE)))
903       break;
904 
905     Value *Src0 = II.getArgOperand(0);
906     Value *Src1 = II.getArgOperand(1);
907 
908     if (auto *CSrc0 = dyn_cast<Constant>(Src0)) {
909       if (auto *CSrc1 = dyn_cast<Constant>(Src1)) {
910         Constant *CCmp = ConstantFoldCompareInstOperands(
911             (ICmpInst::Predicate)CCVal, CSrc0, CSrc1, DL);
912         if (CCmp && CCmp->isNullValue()) {
913           return IC.replaceInstUsesWith(
914               II, IC.Builder.CreateSExt(CCmp, II.getType()));
915         }
916 
917         // The result of V_ICMP/V_FCMP assembly instructions (which this
918         // intrinsic exposes) is one bit per thread, masked with the EXEC
919         // register (which contains the bitmask of live threads). So a
920         // comparison that always returns true is the same as a read of the
921         // EXEC register.
922         Metadata *MDArgs[] = {MDString::get(II.getContext(), "exec")};
923         MDNode *MD = MDNode::get(II.getContext(), MDArgs);
924         Value *Args[] = {MetadataAsValue::get(II.getContext(), MD)};
925         CallInst *NewCall = IC.Builder.CreateIntrinsic(Intrinsic::read_register,
926                                                        II.getType(), Args);
927         NewCall->addFnAttr(Attribute::Convergent);
928         NewCall->takeName(&II);
929         return IC.replaceInstUsesWith(II, NewCall);
930       }
931 
932       // Canonicalize constants to RHS.
933       CmpInst::Predicate SwapPred =
934           CmpInst::getSwappedPredicate(static_cast<CmpInst::Predicate>(CCVal));
935       II.setArgOperand(0, Src1);
936       II.setArgOperand(1, Src0);
937       II.setArgOperand(
938           2, ConstantInt::get(CC->getType(), static_cast<int>(SwapPred)));
939       return &II;
940     }
941 
942     if (CCVal != CmpInst::ICMP_EQ && CCVal != CmpInst::ICMP_NE)
943       break;
944 
945     // Canonicalize compare eq with true value to compare != 0
946     // llvm.amdgcn.icmp(zext (i1 x), 1, eq)
947     //   -> llvm.amdgcn.icmp(zext (i1 x), 0, ne)
948     // llvm.amdgcn.icmp(sext (i1 x), -1, eq)
949     //   -> llvm.amdgcn.icmp(sext (i1 x), 0, ne)
950     Value *ExtSrc;
951     if (CCVal == CmpInst::ICMP_EQ &&
952         ((match(Src1, PatternMatch::m_One()) &&
953           match(Src0, m_ZExt(PatternMatch::m_Value(ExtSrc)))) ||
954          (match(Src1, PatternMatch::m_AllOnes()) &&
955           match(Src0, m_SExt(PatternMatch::m_Value(ExtSrc))))) &&
956         ExtSrc->getType()->isIntegerTy(1)) {
957       IC.replaceOperand(II, 1, ConstantInt::getNullValue(Src1->getType()));
958       IC.replaceOperand(II, 2,
959                         ConstantInt::get(CC->getType(), CmpInst::ICMP_NE));
960       return &II;
961     }
962 
963     CmpPredicate SrcPred;
964     Value *SrcLHS;
965     Value *SrcRHS;
966 
967     // Fold compare eq/ne with 0 from a compare result as the predicate to the
968     // intrinsic. The typical use is a wave vote function in the library, which
969     // will be fed from a user code condition compared with 0. Fold in the
970     // redundant compare.
971 
972     // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, ne)
973     //   -> llvm.amdgcn.[if]cmp(a, b, pred)
974     //
975     // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, eq)
976     //   -> llvm.amdgcn.[if]cmp(a, b, inv pred)
977     if (match(Src1, PatternMatch::m_Zero()) &&
978         match(Src0, PatternMatch::m_ZExtOrSExt(
979                         m_Cmp(SrcPred, PatternMatch::m_Value(SrcLHS),
980                               PatternMatch::m_Value(SrcRHS))))) {
981       if (CCVal == CmpInst::ICMP_EQ)
982         SrcPred = CmpInst::getInversePredicate(SrcPred);
983 
984       Intrinsic::ID NewIID = CmpInst::isFPPredicate(SrcPred)
985                                  ? Intrinsic::amdgcn_fcmp
986                                  : Intrinsic::amdgcn_icmp;
987 
988       Type *Ty = SrcLHS->getType();
989       if (auto *CmpType = dyn_cast<IntegerType>(Ty)) {
990         // Promote to next legal integer type.
991         unsigned Width = CmpType->getBitWidth();
992         unsigned NewWidth = Width;
993 
994         // Don't do anything for i1 comparisons.
995         if (Width == 1)
996           break;
997 
998         if (Width <= 16)
999           NewWidth = 16;
1000         else if (Width <= 32)
1001           NewWidth = 32;
1002         else if (Width <= 64)
1003           NewWidth = 64;
1004         else
1005           break; // Can't handle this.
1006 
1007         if (Width != NewWidth) {
1008           IntegerType *CmpTy = IC.Builder.getIntNTy(NewWidth);
1009           if (CmpInst::isSigned(SrcPred)) {
1010             SrcLHS = IC.Builder.CreateSExt(SrcLHS, CmpTy);
1011             SrcRHS = IC.Builder.CreateSExt(SrcRHS, CmpTy);
1012           } else {
1013             SrcLHS = IC.Builder.CreateZExt(SrcLHS, CmpTy);
1014             SrcRHS = IC.Builder.CreateZExt(SrcRHS, CmpTy);
1015           }
1016         }
1017       } else if (!Ty->isFloatTy() && !Ty->isDoubleTy() && !Ty->isHalfTy())
1018         break;
1019 
1020       Value *Args[] = {SrcLHS, SrcRHS,
1021                        ConstantInt::get(CC->getType(), SrcPred)};
1022       CallInst *NewCall = IC.Builder.CreateIntrinsic(
1023           NewIID, {II.getType(), SrcLHS->getType()}, Args);
1024       NewCall->takeName(&II);
1025       return IC.replaceInstUsesWith(II, NewCall);
1026     }
1027 
1028     break;
1029   }
1030   case Intrinsic::amdgcn_mbcnt_hi: {
1031     // exec_hi is all 0, so this is just a copy.
1032     if (ST->isWave32())
1033       return IC.replaceInstUsesWith(II, II.getArgOperand(1));
1034     break;
1035   }
1036   case Intrinsic::amdgcn_ballot: {
1037     if (auto *Src = dyn_cast<ConstantInt>(II.getArgOperand(0))) {
1038       if (Src->isZero()) {
1039         // amdgcn.ballot(i1 0) is zero.
1040         return IC.replaceInstUsesWith(II, Constant::getNullValue(II.getType()));
1041       }
1042     }
1043     if (ST->isWave32() && II.getType()->getIntegerBitWidth() == 64) {
1044       // %b64 = call i64 ballot.i64(...)
1045       // =>
1046       // %b32 = call i32 ballot.i32(...)
1047       // %b64 = zext i32 %b32 to i64
1048       Value *Call = IC.Builder.CreateZExt(
1049           IC.Builder.CreateIntrinsic(Intrinsic::amdgcn_ballot,
1050                                      {IC.Builder.getInt32Ty()},
1051                                      {II.getArgOperand(0)}),
1052           II.getType());
1053       Call->takeName(&II);
1054       return IC.replaceInstUsesWith(II, Call);
1055     }
1056     break;
1057   }
1058   case Intrinsic::amdgcn_wavefrontsize: {
1059     if (ST->isWaveSizeKnown())
1060       return IC.replaceInstUsesWith(
1061           II, ConstantInt::get(II.getType(), ST->getWavefrontSize()));
1062     break;
1063   }
1064   case Intrinsic::amdgcn_wqm_vote: {
1065     // wqm_vote is identity when the argument is constant.
1066     if (!isa<Constant>(II.getArgOperand(0)))
1067       break;
1068 
1069     return IC.replaceInstUsesWith(II, II.getArgOperand(0));
1070   }
1071   case Intrinsic::amdgcn_kill: {
1072     const ConstantInt *C = dyn_cast<ConstantInt>(II.getArgOperand(0));
1073     if (!C || !C->getZExtValue())
1074       break;
1075 
1076     // amdgcn.kill(i1 1) is a no-op
1077     return IC.eraseInstFromFunction(II);
1078   }
1079   case Intrinsic::amdgcn_update_dpp: {
1080     Value *Old = II.getArgOperand(0);
1081 
1082     auto *BC = cast<ConstantInt>(II.getArgOperand(5));
1083     auto *RM = cast<ConstantInt>(II.getArgOperand(3));
1084     auto *BM = cast<ConstantInt>(II.getArgOperand(4));
1085     if (BC->isZeroValue() || RM->getZExtValue() != 0xF ||
1086         BM->getZExtValue() != 0xF || isa<UndefValue>(Old))
1087       break;
1088 
1089     // If bound_ctrl = 1 and row/bank masks are 0xf, we can omit the old value.
1090     return IC.replaceOperand(II, 0, UndefValue::get(Old->getType()));
1091   }
1092   case Intrinsic::amdgcn_permlane16:
1093   case Intrinsic::amdgcn_permlane16_var:
1094   case Intrinsic::amdgcn_permlanex16:
1095   case Intrinsic::amdgcn_permlanex16_var: {
1096     // Discard vdst_in if it's not going to be read.
1097     Value *VDstIn = II.getArgOperand(0);
1098     if (isa<UndefValue>(VDstIn))
1099       break;
1100 
1101     // FetchInvalid operand idx.
1102     unsigned int FiIdx = (IID == Intrinsic::amdgcn_permlane16 ||
1103                           IID == Intrinsic::amdgcn_permlanex16)
1104                              ? 4  /* for permlane16 and permlanex16 */
1105                              : 3; /* for permlane16_var and permlanex16_var */
1106 
1107     // BoundCtrl operand idx.
1108     // For permlane16 and permlanex16 it should be 5.
1109     // For permlane16_var and permlanex16_var it should be 4.
1110     unsigned int BcIdx = FiIdx + 1;
1111 
1112     ConstantInt *FetchInvalid = cast<ConstantInt>(II.getArgOperand(FiIdx));
1113     ConstantInt *BoundCtrl = cast<ConstantInt>(II.getArgOperand(BcIdx));
1114     if (!FetchInvalid->getZExtValue() && !BoundCtrl->getZExtValue())
1115       break;
1116 
1117     return IC.replaceOperand(II, 0, UndefValue::get(VDstIn->getType()));
1118   }
1119   case Intrinsic::amdgcn_permlane64:
1120   case Intrinsic::amdgcn_readfirstlane:
1121   case Intrinsic::amdgcn_readlane: {
1122     // If the first argument is uniform these intrinsics return it unchanged.
1123     const Use &Src = II.getArgOperandUse(0);
1124     if (isTriviallyUniform(Src))
1125       return IC.replaceInstUsesWith(II, Src.get());
1126 
1127     if (IID == Intrinsic::amdgcn_readlane &&
1128         simplifyDemandedLaneMaskArg(IC, II, 1))
1129       return &II;
1130 
1131     return std::nullopt;
1132   }
1133   case Intrinsic::amdgcn_writelane: {
1134     if (simplifyDemandedLaneMaskArg(IC, II, 1))
1135       return &II;
1136     return std::nullopt;
1137   }
1138   case Intrinsic::amdgcn_trig_preop: {
1139     // The intrinsic is declared with name mangling, but currently the
1140     // instruction only exists for f64
1141     if (!II.getType()->isDoubleTy())
1142       break;
1143 
1144     Value *Src = II.getArgOperand(0);
1145     Value *Segment = II.getArgOperand(1);
1146     if (isa<PoisonValue>(Src) || isa<PoisonValue>(Segment))
1147       return IC.replaceInstUsesWith(II, PoisonValue::get(II.getType()));
1148 
1149     if (isa<UndefValue>(Src)) {
1150       auto *QNaN = ConstantFP::get(
1151           II.getType(), APFloat::getQNaN(II.getType()->getFltSemantics()));
1152       return IC.replaceInstUsesWith(II, QNaN);
1153     }
1154 
1155     const ConstantFP *Csrc = dyn_cast<ConstantFP>(Src);
1156     if (!Csrc)
1157       break;
1158 
1159     if (II.isStrictFP())
1160       break;
1161 
1162     const APFloat &Fsrc = Csrc->getValueAPF();
1163     if (Fsrc.isNaN()) {
1164       auto *Quieted = ConstantFP::get(II.getType(), Fsrc.makeQuiet());
1165       return IC.replaceInstUsesWith(II, Quieted);
1166     }
1167 
1168     const ConstantInt *Cseg = dyn_cast<ConstantInt>(Segment);
1169     if (!Cseg)
1170       break;
1171 
1172     unsigned Exponent = (Fsrc.bitcastToAPInt().getZExtValue() >> 52) & 0x7ff;
1173     unsigned SegmentVal = Cseg->getValue().trunc(5).getZExtValue();
1174     unsigned Shift = SegmentVal * 53;
1175     if (Exponent > 1077)
1176       Shift += Exponent - 1077;
1177 
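    // The folded result is a 53-bit window of the 2.0/PI bit pattern starting
    // Shift bits in, scaled back down by the scalbn below.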
1178     // 2.0/PI table.
1179     static const uint32_t TwoByPi[] = {
1180         0xa2f9836e, 0x4e441529, 0xfc2757d1, 0xf534ddc0, 0xdb629599, 0x3c439041,
1181         0xfe5163ab, 0xdebbc561, 0xb7246e3a, 0x424dd2e0, 0x06492eea, 0x09d1921c,
1182         0xfe1deb1c, 0xb129a73e, 0xe88235f5, 0x2ebb4484, 0xe99c7026, 0xb45f7e41,
1183         0x3991d639, 0x835339f4, 0x9c845f8b, 0xbdf9283b, 0x1ff897ff, 0xde05980f,
1184         0xef2f118b, 0x5a0a6d1f, 0x6d367ecf, 0x27cb09b7, 0x4f463f66, 0x9e5fea2d,
1185         0x7527bac7, 0xebe5f17b, 0x3d0739f7, 0x8a5292ea, 0x6bfb5fb1, 0x1f8d5d08,
1186         0x56033046};
1187 
1188     // Return 0 for an out-of-bounds segment (hardware behavior).
1189     unsigned Idx = Shift >> 5;
1190     if (Idx + 2 >= std::size(TwoByPi)) {
1191       APFloat Zero = APFloat::getZero(II.getType()->getFltSemantics());
1192       return IC.replaceInstUsesWith(II, ConstantFP::get(II.getType(), Zero));
1193     }
1194 
1195     unsigned BShift = Shift & 0x1f;
1196     uint64_t Thi = Make_64(TwoByPi[Idx], TwoByPi[Idx + 1]);
1197     uint64_t Tlo = Make_64(TwoByPi[Idx + 2], 0);
1198     if (BShift)
1199       Thi = (Thi << BShift) | (Tlo >> (64 - BShift));
1200     Thi = Thi >> 11;
1201     APFloat Result = APFloat((double)Thi);
1202 
1203     int Scale = -53 - Shift;
1204     if (Exponent >= 1968)
1205       Scale += 128;
1206 
1207     Result = scalbn(Result, Scale, RoundingMode::NearestTiesToEven);
1208     return IC.replaceInstUsesWith(II, ConstantFP::get(Src->getType(), Result));
1209   }
1210   case Intrinsic::amdgcn_fmul_legacy: {
1211     Value *Op0 = II.getArgOperand(0);
1212     Value *Op1 = II.getArgOperand(1);
1213 
1214     // The legacy behaviour is that multiplying +/-0.0 by anything, even NaN or
1215     // infinity, gives +0.0.
1216     // TODO: Move to InstSimplify?
1217     if (match(Op0, PatternMatch::m_AnyZeroFP()) ||
1218         match(Op1, PatternMatch::m_AnyZeroFP()))
1219       return IC.replaceInstUsesWith(II, ConstantFP::getZero(II.getType()));
1220 
1221     // If we can prove we don't have one of the special cases then we can use a
1222     // normal fmul instruction instead.
1223     if (canSimplifyLegacyMulToMul(II, Op0, Op1, IC)) {
1224       auto *FMul = IC.Builder.CreateFMulFMF(Op0, Op1, &II);
1225       FMul->takeName(&II);
1226       return IC.replaceInstUsesWith(II, FMul);
1227     }
1228     break;
1229   }
1230   case Intrinsic::amdgcn_fma_legacy: {
1231     Value *Op0 = II.getArgOperand(0);
1232     Value *Op1 = II.getArgOperand(1);
1233     Value *Op2 = II.getArgOperand(2);
1234 
1235     // The legacy behaviour is that multiplying +/-0.0 by anything, even NaN or
1236     // infinity, gives +0.0.
1237     // TODO: Move to InstSimplify?
1238     if (match(Op0, PatternMatch::m_AnyZeroFP()) ||
1239         match(Op1, PatternMatch::m_AnyZeroFP())) {
1240       // It's tempting to just return Op2 here, but that would give the wrong
1241       // result if Op2 was -0.0.
1242       auto *Zero = ConstantFP::getZero(II.getType());
1243       auto *FAdd = IC.Builder.CreateFAddFMF(Zero, Op2, &II);
1244       FAdd->takeName(&II);
1245       return IC.replaceInstUsesWith(II, FAdd);
1246     }
1247 
1248     // If we can prove we don't have one of the special cases then we can use a
1249     // normal fma instead.
1250     if (canSimplifyLegacyMulToMul(II, Op0, Op1, IC)) {
1251       II.setCalledOperand(Intrinsic::getOrInsertDeclaration(
1252           II.getModule(), Intrinsic::fma, II.getType()));
1253       return &II;
1254     }
1255     break;
1256   }
1257   case Intrinsic::amdgcn_is_shared:
1258   case Intrinsic::amdgcn_is_private: {
1259     if (isa<UndefValue>(II.getArgOperand(0)))
1260       return IC.replaceInstUsesWith(II, UndefValue::get(II.getType()));
1261 
1262     if (isa<ConstantPointerNull>(II.getArgOperand(0)))
1263       return IC.replaceInstUsesWith(II, ConstantInt::getFalse(II.getType()));
1264     break;
1265   }
1266   case Intrinsic::amdgcn_raw_buffer_store_format:
1267   case Intrinsic::amdgcn_struct_buffer_store_format:
1268   case Intrinsic::amdgcn_raw_tbuffer_store:
1269   case Intrinsic::amdgcn_struct_tbuffer_store:
1270   case Intrinsic::amdgcn_image_store_1d:
1271   case Intrinsic::amdgcn_image_store_1darray:
1272   case Intrinsic::amdgcn_image_store_2d:
1273   case Intrinsic::amdgcn_image_store_2darray:
1274   case Intrinsic::amdgcn_image_store_2darraymsaa:
1275   case Intrinsic::amdgcn_image_store_2dmsaa:
1276   case Intrinsic::amdgcn_image_store_3d:
1277   case Intrinsic::amdgcn_image_store_cube:
1278   case Intrinsic::amdgcn_image_store_mip_1d:
1279   case Intrinsic::amdgcn_image_store_mip_1darray:
1280   case Intrinsic::amdgcn_image_store_mip_2d:
1281   case Intrinsic::amdgcn_image_store_mip_2darray:
1282   case Intrinsic::amdgcn_image_store_mip_3d:
1283   case Intrinsic::amdgcn_image_store_mip_cube: {
1284     if (!isa<FixedVectorType>(II.getArgOperand(0)->getType()))
1285       break;
1286 
1287     APInt DemandedElts;
1288     if (ST->hasDefaultComponentBroadcast())
1289       DemandedElts = defaultComponentBroadcast(II.getArgOperand(0));
1290     else if (ST->hasDefaultComponentZero())
1291       DemandedElts = trimTrailingZerosInVector(IC, II.getArgOperand(0), &II);
1292     else
1293       break;
1294 
1295     int DMaskIdx = getAMDGPUImageDMaskIntrinsic(II.getIntrinsicID()) ? 1 : -1;
1296     if (simplifyAMDGCNMemoryIntrinsicDemanded(IC, II, DemandedElts, DMaskIdx,
1297                                               false)) {
1298       return IC.eraseInstFromFunction(II);
1299     }
1300 
1301     break;
1302   }
1303   case Intrinsic::amdgcn_prng_b32: {
1304     auto *Src = II.getArgOperand(0);
1305     if (isa<UndefValue>(Src)) {
1306       return IC.replaceInstUsesWith(II, Src);
1307     }
1308     return std::nullopt;
1309   }
1310   case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
1311   case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
1312     Value *Src0 = II.getArgOperand(0);
1313     Value *Src1 = II.getArgOperand(1);
1314     uint64_t CBSZ = cast<ConstantInt>(II.getArgOperand(3))->getZExtValue();
1315     uint64_t BLGP = cast<ConstantInt>(II.getArgOperand(4))->getZExtValue();
1316     auto *Src0Ty = cast<FixedVectorType>(Src0->getType());
1317     auto *Src1Ty = cast<FixedVectorType>(Src1->getType());
1318 
1319     auto getFormatNumRegs = [](unsigned FormatVal) {
1320       switch (FormatVal) {
1321       case AMDGPU::MFMAScaleFormats::FP6_E2M3:
1322       case AMDGPU::MFMAScaleFormats::FP6_E3M2:
1323         return 6u;
1324       case AMDGPU::MFMAScaleFormats::FP4_E2M1:
1325         return 4u;
1326       case AMDGPU::MFMAScaleFormats::FP8_E4M3:
1327       case AMDGPU::MFMAScaleFormats::FP8_E5M2:
1328         return 8u;
1329       default:
1330         llvm_unreachable("invalid format value");
1331       }
1332     };
1333 
1334     bool MadeChange = false;
1335     unsigned Src0NumElts = getFormatNumRegs(CBSZ);
1336     unsigned Src1NumElts = getFormatNumRegs(BLGP);
1337 
1338     // Depending on the format used, fewer registers are required, so shrink
1339     // the vector type.
1340     if (Src0Ty->getNumElements() > Src0NumElts) {
1341       Src0 = IC.Builder.CreateExtractVector(
1342           FixedVectorType::get(Src0Ty->getElementType(), Src0NumElts), Src0,
1343           IC.Builder.getInt64(0));
1344       MadeChange = true;
1345     }
1346 
1347     if (Src1Ty->getNumElements() > Src1NumElts) {
1348       Src1 = IC.Builder.CreateExtractVector(
1349           FixedVectorType::get(Src1Ty->getElementType(), Src1NumElts), Src1,
1350           IC.Builder.getInt64(0));
1351       MadeChange = true;
1352     }
1353 
1354     if (!MadeChange)
1355       return std::nullopt;
1356 
1357     SmallVector<Value *, 10> Args(II.args());
1358     Args[0] = Src0;
1359     Args[1] = Src1;
1360 
1361     CallInst *NewII = IC.Builder.CreateIntrinsic(
1362         IID, {Src0->getType(), Src1->getType()}, Args, &II);
1363     NewII->takeName(&II);
1364     return IC.replaceInstUsesWith(II, NewII);
1365   }
1366   }
1367   if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
1368             AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {
1369     return simplifyAMDGCNImageIntrinsic(ST, ImageDimIntr, II, IC);
1370   }
1371   return std::nullopt;
1372 }
1373 
1374 /// Implement SimplifyDemandedVectorElts for amdgcn buffer and image intrinsics.
1375 ///
1376 /// For amdgcn image and buffer store intrinsics, simplification updates the
1377 /// definition of the intrinsic's stored vector operand, rather than the uses
1378 /// of the loaded result as it does for image and buffer loads.
1379 /// Note: This only supports non-TFE/LWE image intrinsic calls, since the
1380 ///       TFE/LWE variants have struct returns.
1381 static Value *simplifyAMDGCNMemoryIntrinsicDemanded(InstCombiner &IC,
1382                                                     IntrinsicInst &II,
1383                                                     APInt DemandedElts,
1384                                                     int DMaskIdx, bool IsLoad) {
1385 
1386   auto *IIVTy = cast<FixedVectorType>(IsLoad ? II.getType()
1387                                              : II.getOperand(0)->getType());
1388   unsigned VWidth = IIVTy->getNumElements();
1389   if (VWidth == 1)
1390     return nullptr;
1391   Type *EltTy = IIVTy->getElementType();
1392 
1393   IRBuilderBase::InsertPointGuard Guard(IC.Builder);
1394   IC.Builder.SetInsertPoint(&II);
1395 
1396   // Assume the arguments are unchanged and later override them, if needed.
1397   SmallVector<Value *, 16> Args(II.args());
1398 
1399   if (DMaskIdx < 0) {
1400     // Buffer case.
1401 
1402     const unsigned ActiveBits = DemandedElts.getActiveBits();
1403     const unsigned UnusedComponentsAtFront = DemandedElts.countr_zero();
1404 
1405     // Start by assuming the whole prefix of elements up to the highest
1406     // demanded one is needed, then clear the low bits and bump the offset if
1407     // the first few components are unused.
1408     DemandedElts = (1 << ActiveBits) - 1;
1409 
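    // For example, a raw buffer load of <4 x float> where only element 2 is
    // used becomes a scalar load with the offset advanced by two elements, and
    // the result is reinserted at index 2 of a poison vector.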
1410     if (UnusedComponentsAtFront > 0) {
1411       static const unsigned InvalidOffsetIdx = 0xf;
1412 
1413       unsigned OffsetIdx;
1414       switch (II.getIntrinsicID()) {
1415       case Intrinsic::amdgcn_raw_buffer_load:
1416       case Intrinsic::amdgcn_raw_ptr_buffer_load:
1417         OffsetIdx = 1;
1418         break;
1419       case Intrinsic::amdgcn_s_buffer_load:
1420         // If resulting type is vec3, there is no point in trimming the
1421         // load with updated offset, as the vec3 would most likely be widened to
1422         // vec4 anyway during lowering.
1423         if (ActiveBits == 4 && UnusedComponentsAtFront == 1)
1424           OffsetIdx = InvalidOffsetIdx;
1425         else
1426           OffsetIdx = 1;
1427         break;
1428       case Intrinsic::amdgcn_struct_buffer_load:
1429       case Intrinsic::amdgcn_struct_ptr_buffer_load:
1430         OffsetIdx = 2;
1431         break;
1432       default:
1433         // TODO: handle tbuffer* intrinsics.
1434         OffsetIdx = InvalidOffsetIdx;
1435         break;
1436       }
1437 
1438       if (OffsetIdx != InvalidOffsetIdx) {
1439         // Clear demanded bits and update the offset.
1440         DemandedElts &= ~((1 << UnusedComponentsAtFront) - 1);
1441         auto *Offset = Args[OffsetIdx];
1442         unsigned SingleComponentSizeInBits =
1443             IC.getDataLayout().getTypeSizeInBits(EltTy);
1444         unsigned OffsetAdd =
1445             UnusedComponentsAtFront * SingleComponentSizeInBits / 8;
1446         auto *OffsetAddVal = ConstantInt::get(Offset->getType(), OffsetAdd);
1447         Args[OffsetIdx] = IC.Builder.CreateAdd(Offset, OffsetAddVal);
1448       }
1449     }
1450   } else {
1451     // Image case.
1452 
1453     ConstantInt *DMask = cast<ConstantInt>(Args[DMaskIdx]);
1454     unsigned DMaskVal = DMask->getZExtValue() & 0xf;
1455 
1456     // dmask 0 has special semantics, do not simplify.
1457     if (DMaskVal == 0)
1458       return nullptr;
1459 
1460     // Mask off values that are undefined because the dmask doesn't cover them
1461     DemandedElts &= (1 << llvm::popcount(DMaskVal)) - 1;
1462 
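    // Rebuild the dmask so that only bits whose corresponding element is still
    // demanded remain set.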
1463     unsigned NewDMaskVal = 0;
1464     unsigned OrigLdStIdx = 0;
1465     for (unsigned SrcIdx = 0; SrcIdx < 4; ++SrcIdx) {
1466       const unsigned Bit = 1 << SrcIdx;
1467       if (!!(DMaskVal & Bit)) {
1468         if (!!DemandedElts[OrigLdStIdx])
1469           NewDMaskVal |= Bit;
1470         OrigLdStIdx++;
1471       }
1472     }
1473 
1474     if (DMaskVal != NewDMaskVal)
1475       Args[DMaskIdx] = ConstantInt::get(DMask->getType(), NewDMaskVal);
1476   }
1477 
1478   unsigned NewNumElts = DemandedElts.popcount();
1479   if (!NewNumElts)
1480     return PoisonValue::get(IIVTy);
1481 
1482   if (NewNumElts >= VWidth && DemandedElts.isMask()) {
1483     if (DMaskIdx >= 0)
1484       II.setArgOperand(DMaskIdx, Args[DMaskIdx]);
1485     return nullptr;
1486   }
1487 
1488   // Validate function argument and return types, extracting overloaded types
1489   // along the way.
1490   SmallVector<Type *, 6> OverloadTys;
1491   if (!Intrinsic::getIntrinsicSignature(II.getCalledFunction(), OverloadTys))
1492     return nullptr;
1493 
1494   Type *NewTy =
1495       (NewNumElts == 1) ? EltTy : FixedVectorType::get(EltTy, NewNumElts);
1496   OverloadTys[0] = NewTy;
1497 
1498   if (!IsLoad) {
1499     SmallVector<int, 8> EltMask;
1500     for (unsigned OrigStoreIdx = 0; OrigStoreIdx < VWidth; ++OrigStoreIdx)
1501       if (DemandedElts[OrigStoreIdx])
1502         EltMask.push_back(OrigStoreIdx);
1503 
1504     if (NewNumElts == 1)
1505       Args[0] = IC.Builder.CreateExtractElement(II.getOperand(0), EltMask[0]);
1506     else
1507       Args[0] = IC.Builder.CreateShuffleVector(II.getOperand(0), EltMask);
1508   }
1509 
1510   CallInst *NewCall =
1511       IC.Builder.CreateIntrinsic(II.getIntrinsicID(), OverloadTys, Args);
1512   NewCall->takeName(&II);
1513   NewCall->copyMetadata(II);
1514 
1515   if (IsLoad) {
1516     if (NewNumElts == 1) {
1517       return IC.Builder.CreateInsertElement(PoisonValue::get(IIVTy), NewCall,
1518                                             DemandedElts.countr_zero());
1519     }
1520 
1521     SmallVector<int, 8> EltMask;
1522     unsigned NewLoadIdx = 0;
1523     for (unsigned OrigLoadIdx = 0; OrigLoadIdx < VWidth; ++OrigLoadIdx) {
1524       if (!!DemandedElts[OrigLoadIdx])
1525         EltMask.push_back(NewLoadIdx++);
1526       else
1527         EltMask.push_back(NewNumElts);
1528     }
1529 
1530     auto *Shuffle = IC.Builder.CreateShuffleVector(NewCall, EltMask);
1531 
1532     return Shuffle;
1533   }
1534 
1535   return NewCall;
1536 }
1537 
1538 std::optional<Value *> GCNTTIImpl::simplifyDemandedVectorEltsIntrinsic(
1539     InstCombiner &IC, IntrinsicInst &II, APInt DemandedElts, APInt &UndefElts,
1540     APInt &UndefElts2, APInt &UndefElts3,
1541     std::function<void(Instruction *, unsigned, APInt, APInt &)>
1542         SimplifyAndSetOp) const {
1543   switch (II.getIntrinsicID()) {
1544   case Intrinsic::amdgcn_raw_buffer_load:
1545   case Intrinsic::amdgcn_raw_ptr_buffer_load:
1546   case Intrinsic::amdgcn_raw_buffer_load_format:
1547   case Intrinsic::amdgcn_raw_ptr_buffer_load_format:
1548   case Intrinsic::amdgcn_raw_tbuffer_load:
1549   case Intrinsic::amdgcn_raw_ptr_tbuffer_load:
1550   case Intrinsic::amdgcn_s_buffer_load:
1551   case Intrinsic::amdgcn_struct_buffer_load:
1552   case Intrinsic::amdgcn_struct_ptr_buffer_load:
1553   case Intrinsic::amdgcn_struct_buffer_load_format:
1554   case Intrinsic::amdgcn_struct_ptr_buffer_load_format:
1555   case Intrinsic::amdgcn_struct_tbuffer_load:
1556   case Intrinsic::amdgcn_struct_ptr_tbuffer_load:
1557     return simplifyAMDGCNMemoryIntrinsicDemanded(IC, II, DemandedElts);
1558   default: {
1559     if (getAMDGPUImageDMaskIntrinsic(II.getIntrinsicID())) {
1560       return simplifyAMDGCNMemoryIntrinsicDemanded(IC, II, DemandedElts, 0);
1561     }
1562     break;
1563   }
1564   }
1565   return std::nullopt;
1566 }
1567