xref: /llvm-project/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (revision b900379e26d9f49977c4d772f1b2b681fc5147d4)
1 //===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains DXIL intrinsic expansions for those that don't have
/// opcodes in DirectX Intermediate Language (DXIL).
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILIntrinsicExpansion.h"
14 #include "DirectX.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/CodeGen/Passes.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/InstrTypes.h"
20 #include "llvm/IR/Instruction.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Intrinsics.h"
23 #include "llvm/IR/IntrinsicsDirectX.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/IR/PassManager.h"
26 #include "llvm/IR/Type.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Support/MathExtras.h"
30 
31 #define DEBUG_TYPE "dxil-intrinsic-expansion"
32 
33 using namespace llvm;
34 
35 class DXILIntrinsicExpansionLegacy : public ModulePass {
36 
37 public:
38   bool runOnModule(Module &M) override;
39   DXILIntrinsicExpansionLegacy() : ModulePass(ID) {}
40 
41   static char ID; // Pass identification.
42 };
43 
44 static bool isIntrinsicExpansion(Function &F) {
45   switch (F.getIntrinsicID()) {
46   case Intrinsic::abs:
47   case Intrinsic::atan2:
48   case Intrinsic::exp:
49   case Intrinsic::log:
50   case Intrinsic::log10:
51   case Intrinsic::pow:
52   case Intrinsic::dx_all:
53   case Intrinsic::dx_any:
54   case Intrinsic::dx_cross:
55   case Intrinsic::dx_uclamp:
56   case Intrinsic::dx_sclamp:
57   case Intrinsic::dx_nclamp:
58   case Intrinsic::dx_degrees:
59   case Intrinsic::dx_lerp:
60   case Intrinsic::dx_normalize:
61   case Intrinsic::dx_fdot:
62   case Intrinsic::dx_sdot:
63   case Intrinsic::dx_udot:
64   case Intrinsic::dx_sign:
65   case Intrinsic::dx_step:
66   case Intrinsic::dx_radians:
67   case Intrinsic::vector_reduce_add:
68   case Intrinsic::vector_reduce_fadd:
69     return true;
70   }
71   return false;
72 }
73 static Value *expandVecReduceAdd(CallInst *Orig, Intrinsic::ID IntrinsicId) {
74   assert(IntrinsicId == Intrinsic::vector_reduce_add ||
75          IntrinsicId == Intrinsic::vector_reduce_fadd);
76 
77   IRBuilder<> Builder(Orig);
78   bool IsFAdd = (IntrinsicId == Intrinsic::vector_reduce_fadd);
79 
80   Value *X = Orig->getOperand(IsFAdd ? 1 : 0);
81   Type *Ty = X->getType();
82   auto *XVec = dyn_cast<FixedVectorType>(Ty);
83   unsigned XVecSize = XVec->getNumElements();
84   Value *Sum = Builder.CreateExtractElement(X, static_cast<uint64_t>(0));
85 
86   // Handle the initial start value for floating-point addition.
87   if (IsFAdd) {
88     Constant *StartValue = dyn_cast<Constant>(Orig->getOperand(0));
89     if (StartValue && !StartValue->isZeroValue())
90       Sum = Builder.CreateFAdd(Sum, StartValue);
91   }
92 
93   // Accumulate the remaining vector elements.
94   for (unsigned I = 1; I < XVecSize; I++) {
95     Value *Elt = Builder.CreateExtractElement(X, I);
96     if (IsFAdd)
97       Sum = Builder.CreateFAdd(Sum, Elt);
98     else
99       Sum = Builder.CreateAdd(Sum, Elt);
100   }
101 
102   return Sum;
103 }
104 
105 static Value *expandAbs(CallInst *Orig) {
106   Value *X = Orig->getOperand(0);
107   IRBuilder<> Builder(Orig);
108   Type *Ty = X->getType();
109   Type *EltTy = Ty->getScalarType();
110   Constant *Zero = Ty->isVectorTy()
111                        ? ConstantVector::getSplat(
112                              ElementCount::getFixed(
113                                  cast<FixedVectorType>(Ty)->getNumElements()),
114                              ConstantInt::get(EltTy, 0))
115                        : ConstantInt::get(EltTy, 0);
116   auto *V = Builder.CreateSub(Zero, X);
117   return Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr,
118                                  "dx.max");
119 }
120 
121 static Value *expandCrossIntrinsic(CallInst *Orig) {
122 
123   VectorType *VT = cast<VectorType>(Orig->getType());
124   if (cast<FixedVectorType>(VT)->getNumElements() != 3)
125     report_fatal_error(Twine("return vector must have exactly 3 elements"),
126                        /* gen_crash_diag=*/false);
127 
128   Value *op0 = Orig->getOperand(0);
129   Value *op1 = Orig->getOperand(1);
130   IRBuilder<> Builder(Orig);
131 
132   Value *op0_x = Builder.CreateExtractElement(op0, (uint64_t)0, "x0");
133   Value *op0_y = Builder.CreateExtractElement(op0, 1, "x1");
134   Value *op0_z = Builder.CreateExtractElement(op0, 2, "x2");
135 
136   Value *op1_x = Builder.CreateExtractElement(op1, (uint64_t)0, "y0");
137   Value *op1_y = Builder.CreateExtractElement(op1, 1, "y1");
138   Value *op1_z = Builder.CreateExtractElement(op1, 2, "y2");
139 
140   auto MulSub = [&](Value *x0, Value *y0, Value *x1, Value *y1) -> Value * {
141     Value *xy = Builder.CreateFMul(x0, y1);
142     Value *yx = Builder.CreateFMul(y0, x1);
143     return Builder.CreateFSub(xy, yx, Orig->getName());
144   };
145 
146   Value *yz_zy = MulSub(op0_y, op0_z, op1_y, op1_z);
147   Value *zx_xz = MulSub(op0_z, op0_x, op1_z, op1_x);
148   Value *xy_yx = MulSub(op0_x, op0_y, op1_x, op1_y);
149 
150   Value *cross = UndefValue::get(VT);
151   cross = Builder.CreateInsertElement(cross, yz_zy, (uint64_t)0);
152   cross = Builder.CreateInsertElement(cross, zx_xz, 1);
153   cross = Builder.CreateInsertElement(cross, xy_yx, 2);
154   return cross;
155 }
156 
157 // Create appropriate DXIL float dot intrinsic for the given A and B operands
158 // The appropriate opcode will be determined by the size of the operands
159 // The dot product is placed in the position indicated by Orig
160 static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
161   Type *ATy = A->getType();
162   [[maybe_unused]] Type *BTy = B->getType();
163   assert(ATy->isVectorTy() && BTy->isVectorTy());
164 
165   IRBuilder<> Builder(Orig);
166 
167   auto *AVec = dyn_cast<FixedVectorType>(ATy);
168 
169   assert(ATy->getScalarType()->isFloatingPointTy());
170 
171   Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
172   switch (AVec->getNumElements()) {
173   case 2:
174     DotIntrinsic = Intrinsic::dx_dot2;
175     break;
176   case 3:
177     DotIntrinsic = Intrinsic::dx_dot3;
178     break;
179   case 4:
180     DotIntrinsic = Intrinsic::dx_dot4;
181     break;
182   default:
183     report_fatal_error(
184         Twine("Invalid dot product input vector: length is outside 2-4"),
185         /* gen_crash_diag=*/false);
186     return nullptr;
187   }
188   return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
189                                  ArrayRef<Value *>{A, B}, nullptr, "dot");
190 }
191 
192 // Create the appropriate DXIL float dot intrinsic for the operands of Orig
193 // The appropriate opcode will be determined by the size of the operands
194 // The dot product is placed in the position indicated by Orig
195 static Value *expandFloatDotIntrinsic(CallInst *Orig) {
196   return expandFloatDotIntrinsic(Orig, Orig->getOperand(0),
197                                  Orig->getOperand(1));
198 }
199 
200 // Expand integer dot product to multiply and add ops
201 static Value *expandIntegerDotIntrinsic(CallInst *Orig,
202                                         Intrinsic::ID DotIntrinsic) {
203   assert(DotIntrinsic == Intrinsic::dx_sdot ||
204          DotIntrinsic == Intrinsic::dx_udot);
205   Value *A = Orig->getOperand(0);
206   Value *B = Orig->getOperand(1);
207   Type *ATy = A->getType();
208   [[maybe_unused]] Type *BTy = B->getType();
209   assert(ATy->isVectorTy() && BTy->isVectorTy());
210 
211   IRBuilder<> Builder(Orig);
212 
213   auto *AVec = dyn_cast<FixedVectorType>(ATy);
214 
215   assert(ATy->getScalarType()->isIntegerTy());
216 
217   Value *Result;
218   Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
219                                    ? Intrinsic::dx_imad
220                                    : Intrinsic::dx_umad;
221   Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
222   Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
223   Result = Builder.CreateMul(Elt0, Elt1);
224   for (unsigned I = 1; I < AVec->getNumElements(); I++) {
225     Elt0 = Builder.CreateExtractElement(A, I);
226     Elt1 = Builder.CreateExtractElement(B, I);
227     Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
228                                      ArrayRef<Value *>{Elt0, Elt1, Result},
229                                      nullptr, "dx.mad");
230   }
231   return Result;
232 }
233 
234 static Value *expandExpIntrinsic(CallInst *Orig) {
235   Value *X = Orig->getOperand(0);
236   IRBuilder<> Builder(Orig);
237   Type *Ty = X->getType();
238   Type *EltTy = Ty->getScalarType();
239   Constant *Log2eConst =
240       Ty->isVectorTy() ? ConstantVector::getSplat(
241                              ElementCount::getFixed(
242                                  cast<FixedVectorType>(Ty)->getNumElements()),
243                              ConstantFP::get(EltTy, numbers::log2ef))
244                        : ConstantFP::get(EltTy, numbers::log2ef);
245   Value *NewX = Builder.CreateFMul(Log2eConst, X);
246   auto *Exp2Call =
247       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
248   Exp2Call->setTailCall(Orig->isTailCall());
249   Exp2Call->setAttributes(Orig->getAttributes());
250   return Exp2Call;
251 }
252 
253 static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
254                                       Intrinsic::ID intrinsicId) {
255   Value *X = Orig->getOperand(0);
256   IRBuilder<> Builder(Orig);
257   Type *Ty = X->getType();
258   Type *EltTy = Ty->getScalarType();
259 
260   auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result,
261                             Value *Elt) {
262     if (IntrinsicId == Intrinsic::dx_any)
263       return Builder.CreateOr(Result, Elt);
264     assert(IntrinsicId == Intrinsic::dx_all);
265     return Builder.CreateAnd(Result, Elt);
266   };
267 
268   Value *Result = nullptr;
269   if (!Ty->isVectorTy()) {
270     Result = EltTy->isFloatingPointTy()
271                  ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
272                  : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
273   } else {
274     auto *XVec = dyn_cast<FixedVectorType>(Ty);
275     Value *Cond =
276         EltTy->isFloatingPointTy()
277             ? Builder.CreateFCmpUNE(
278                   X, ConstantVector::getSplat(
279                          ElementCount::getFixed(XVec->getNumElements()),
280                          ConstantFP::get(EltTy, 0)))
281             : Builder.CreateICmpNE(
282                   X, ConstantVector::getSplat(
283                          ElementCount::getFixed(XVec->getNumElements()),
284                          ConstantInt::get(EltTy, 0)));
285     Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
286     for (unsigned I = 1; I < XVec->getNumElements(); I++) {
287       Value *Elt = Builder.CreateExtractElement(Cond, I);
288       Result = ApplyOp(intrinsicId, Result, Elt);
289     }
290   }
291   return Result;
292 }
293 
294 static Value *expandLerpIntrinsic(CallInst *Orig) {
295   Value *X = Orig->getOperand(0);
296   Value *Y = Orig->getOperand(1);
297   Value *S = Orig->getOperand(2);
298   IRBuilder<> Builder(Orig);
299   auto *V = Builder.CreateFSub(Y, X);
300   V = Builder.CreateFMul(S, V);
301   return Builder.CreateFAdd(X, V, "dx.lerp");
302 }
303 
304 static Value *expandLogIntrinsic(CallInst *Orig,
305                                  float LogConstVal = numbers::ln2f) {
306   Value *X = Orig->getOperand(0);
307   IRBuilder<> Builder(Orig);
308   Type *Ty = X->getType();
309   Type *EltTy = Ty->getScalarType();
310   Constant *Ln2Const =
311       Ty->isVectorTy() ? ConstantVector::getSplat(
312                              ElementCount::getFixed(
313                                  cast<FixedVectorType>(Ty)->getNumElements()),
314                              ConstantFP::get(EltTy, LogConstVal))
315                        : ConstantFP::get(EltTy, LogConstVal);
316   auto *Log2Call =
317       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
318   Log2Call->setTailCall(Orig->isTailCall());
319   Log2Call->setAttributes(Orig->getAttributes());
320   return Builder.CreateFMul(Ln2Const, Log2Call);
321 }
322 static Value *expandLog10Intrinsic(CallInst *Orig) {
323   return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
324 }
325 
326 // Use dot product of vector operand with itself to calculate the length.
327 // Divide the vector by that length to normalize it.
328 static Value *expandNormalizeIntrinsic(CallInst *Orig) {
329   Value *X = Orig->getOperand(0);
330   Type *Ty = Orig->getType();
331   Type *EltTy = Ty->getScalarType();
332   IRBuilder<> Builder(Orig);
333 
334   auto *XVec = dyn_cast<FixedVectorType>(Ty);
335   if (!XVec) {
336     if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
337       const APFloat &fpVal = constantFP->getValueAPF();
338       if (fpVal.isZero())
339         report_fatal_error(Twine("Invalid input scalar: length is zero"),
340                            /* gen_crash_diag=*/false);
341     }
342     return Builder.CreateFDiv(X, X);
343   }
344 
345   Value *DotProduct = expandFloatDotIntrinsic(Orig, X, X);
346 
347   // verify that the length is non-zero
348   // (if the dot product is non-zero, then the length is non-zero)
349   if (auto *constantFP = dyn_cast<ConstantFP>(DotProduct)) {
350     const APFloat &fpVal = constantFP->getValueAPF();
351     if (fpVal.isZero())
352       report_fatal_error(Twine("Invalid input vector: length is zero"),
353                          /* gen_crash_diag=*/false);
354   }
355 
356   Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
357                                                 ArrayRef<Value *>{DotProduct},
358                                                 nullptr, "dx.rsqrt");
359 
360   Value *MultiplicandVec =
361       Builder.CreateVectorSplat(XVec->getNumElements(), Multiplicand);
362   return Builder.CreateFMul(X, MultiplicandVec);
363 }
364 
365 static Value *expandAtan2Intrinsic(CallInst *Orig) {
366   Value *Y = Orig->getOperand(0);
367   Value *X = Orig->getOperand(1);
368   Type *Ty = X->getType();
369   IRBuilder<> Builder(Orig);
370   Builder.setFastMathFlags(Orig->getFastMathFlags());
371 
372   Value *Tan = Builder.CreateFDiv(Y, X);
373 
374   CallInst *Atan =
375       Builder.CreateIntrinsic(Ty, Intrinsic::atan, {Tan}, nullptr, "Elt.Atan");
376   Atan->setTailCall(Orig->isTailCall());
377   Atan->setAttributes(Orig->getAttributes());
378 
379   // Modify atan result based on https://en.wikipedia.org/wiki/Atan2.
380   Constant *Pi = ConstantFP::get(Ty, llvm::numbers::pi);
381   Constant *HalfPi = ConstantFP::get(Ty, llvm::numbers::pi / 2);
382   Constant *NegHalfPi = ConstantFP::get(Ty, -llvm::numbers::pi / 2);
383   Constant *Zero = ConstantFP::get(Ty, 0);
384   Value *AtanAddPi = Builder.CreateFAdd(Atan, Pi);
385   Value *AtanSubPi = Builder.CreateFSub(Atan, Pi);
386 
387   // x > 0 -> atan.
388   Value *Result = Atan;
389   Value *XLt0 = Builder.CreateFCmpOLT(X, Zero);
390   Value *XEq0 = Builder.CreateFCmpOEQ(X, Zero);
391   Value *YGe0 = Builder.CreateFCmpOGE(Y, Zero);
392   Value *YLt0 = Builder.CreateFCmpOLT(Y, Zero);
393 
394   // x < 0, y >= 0 -> atan + pi.
395   Value *XLt0AndYGe0 = Builder.CreateAnd(XLt0, YGe0);
396   Result = Builder.CreateSelect(XLt0AndYGe0, AtanAddPi, Result);
397 
398   // x < 0, y < 0 -> atan - pi.
399   Value *XLt0AndYLt0 = Builder.CreateAnd(XLt0, YLt0);
400   Result = Builder.CreateSelect(XLt0AndYLt0, AtanSubPi, Result);
401 
402   // x == 0, y < 0 -> -pi/2
403   Value *XEq0AndYLt0 = Builder.CreateAnd(XEq0, YLt0);
404   Result = Builder.CreateSelect(XEq0AndYLt0, NegHalfPi, Result);
405 
406   // x == 0, y > 0 -> pi/2
407   Value *XEq0AndYGe0 = Builder.CreateAnd(XEq0, YGe0);
408   Result = Builder.CreateSelect(XEq0AndYGe0, HalfPi, Result);
409 
410   return Result;
411 }
412 
413 static Value *expandPowIntrinsic(CallInst *Orig) {
414 
415   Value *X = Orig->getOperand(0);
416   Value *Y = Orig->getOperand(1);
417   Type *Ty = X->getType();
418   IRBuilder<> Builder(Orig);
419 
420   auto *Log2Call =
421       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
422   auto *Mul = Builder.CreateFMul(Log2Call, Y);
423   auto *Exp2Call =
424       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
425   Exp2Call->setTailCall(Orig->isTailCall());
426   Exp2Call->setAttributes(Orig->getAttributes());
427   return Exp2Call;
428 }
429 
430 static Value *expandStepIntrinsic(CallInst *Orig) {
431 
432   Value *X = Orig->getOperand(0);
433   Value *Y = Orig->getOperand(1);
434   Type *Ty = X->getType();
435   IRBuilder<> Builder(Orig);
436 
437   Constant *One = ConstantFP::get(Ty->getScalarType(), 1.0);
438   Constant *Zero = ConstantFP::get(Ty->getScalarType(), 0.0);
439   Value *Cond = Builder.CreateFCmpOLT(Y, X);
440 
441   if (Ty != Ty->getScalarType()) {
442     auto *XVec = dyn_cast<FixedVectorType>(Ty);
443     One = ConstantVector::getSplat(
444         ElementCount::getFixed(XVec->getNumElements()), One);
445     Zero = ConstantVector::getSplat(
446         ElementCount::getFixed(XVec->getNumElements()), Zero);
447   }
448 
449   return Builder.CreateSelect(Cond, Zero, One);
450 }
451 
452 static Value *expandRadiansIntrinsic(CallInst *Orig) {
453   Value *X = Orig->getOperand(0);
454   Type *Ty = X->getType();
455   IRBuilder<> Builder(Orig);
456   Value *PiOver180 = ConstantFP::get(Ty, llvm::numbers::pi / 180.0);
457   return Builder.CreateFMul(X, PiOver180);
458 }
459 
460 static Intrinsic::ID getMaxForClamp(Intrinsic::ID ClampIntrinsic) {
461   if (ClampIntrinsic == Intrinsic::dx_uclamp)
462     return Intrinsic::umax;
463   if (ClampIntrinsic == Intrinsic::dx_sclamp)
464     return Intrinsic::smax;
465   assert(ClampIntrinsic == Intrinsic::dx_nclamp);
466   return Intrinsic::maxnum;
467 }
468 
469 static Intrinsic::ID getMinForClamp(Intrinsic::ID ClampIntrinsic) {
470   if (ClampIntrinsic == Intrinsic::dx_uclamp)
471     return Intrinsic::umin;
472   if (ClampIntrinsic == Intrinsic::dx_sclamp)
473     return Intrinsic::smin;
474   assert(ClampIntrinsic == Intrinsic::dx_nclamp);
475   return Intrinsic::minnum;
476 }
477 
478 static Value *expandClampIntrinsic(CallInst *Orig,
479                                    Intrinsic::ID ClampIntrinsic) {
480   Value *X = Orig->getOperand(0);
481   Value *Min = Orig->getOperand(1);
482   Value *Max = Orig->getOperand(2);
483   Type *Ty = X->getType();
484   IRBuilder<> Builder(Orig);
485   auto *MaxCall = Builder.CreateIntrinsic(Ty, getMaxForClamp(ClampIntrinsic),
486                                           {X, Min}, nullptr, "dx.max");
487   return Builder.CreateIntrinsic(Ty, getMinForClamp(ClampIntrinsic),
488                                  {MaxCall, Max}, nullptr, "dx.min");
489 }
490 
491 static Value *expandDegreesIntrinsic(CallInst *Orig) {
492   Value *X = Orig->getOperand(0);
493   Type *Ty = X->getType();
494   IRBuilder<> Builder(Orig);
495   Value *DegreesRatio = ConstantFP::get(Ty, 180.0 * llvm::numbers::inv_pi);
496   return Builder.CreateFMul(X, DegreesRatio);
497 }
498 
499 static Value *expandSignIntrinsic(CallInst *Orig) {
500   Value *X = Orig->getOperand(0);
501   Type *Ty = X->getType();
502   Type *ScalarTy = Ty->getScalarType();
503   Type *RetTy = Orig->getType();
504   Constant *Zero = Constant::getNullValue(Ty);
505 
506   IRBuilder<> Builder(Orig);
507 
508   Value *GT;
509   Value *LT;
510   if (ScalarTy->isFloatingPointTy()) {
511     GT = Builder.CreateFCmpOLT(Zero, X);
512     LT = Builder.CreateFCmpOLT(X, Zero);
513   } else {
514     assert(ScalarTy->isIntegerTy());
515     GT = Builder.CreateICmpSLT(Zero, X);
516     LT = Builder.CreateICmpSLT(X, Zero);
517   }
518 
519   Value *ZextGT = Builder.CreateZExt(GT, RetTy);
520   Value *ZextLT = Builder.CreateZExt(LT, RetTy);
521 
522   return Builder.CreateSub(ZextGT, ZextLT);
523 }
524 
525 static bool expandIntrinsic(Function &F, CallInst *Orig) {
526   Value *Result = nullptr;
527   Intrinsic::ID IntrinsicId = F.getIntrinsicID();
528   switch (IntrinsicId) {
529   case Intrinsic::abs:
530     Result = expandAbs(Orig);
531     break;
532   case Intrinsic::atan2:
533     Result = expandAtan2Intrinsic(Orig);
534     break;
535   case Intrinsic::exp:
536     Result = expandExpIntrinsic(Orig);
537     break;
538   case Intrinsic::log:
539     Result = expandLogIntrinsic(Orig);
540     break;
541   case Intrinsic::log10:
542     Result = expandLog10Intrinsic(Orig);
543     break;
544   case Intrinsic::pow:
545     Result = expandPowIntrinsic(Orig);
546     break;
547   case Intrinsic::dx_all:
548   case Intrinsic::dx_any:
549     Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
550     break;
551   case Intrinsic::dx_cross:
552     Result = expandCrossIntrinsic(Orig);
553     break;
554   case Intrinsic::dx_uclamp:
555   case Intrinsic::dx_sclamp:
556   case Intrinsic::dx_nclamp:
557     Result = expandClampIntrinsic(Orig, IntrinsicId);
558     break;
559   case Intrinsic::dx_degrees:
560     Result = expandDegreesIntrinsic(Orig);
561     break;
562   case Intrinsic::dx_lerp:
563     Result = expandLerpIntrinsic(Orig);
564     break;
565   case Intrinsic::dx_normalize:
566     Result = expandNormalizeIntrinsic(Orig);
567     break;
568   case Intrinsic::dx_fdot:
569     Result = expandFloatDotIntrinsic(Orig);
570     break;
571   case Intrinsic::dx_sdot:
572   case Intrinsic::dx_udot:
573     Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
574     break;
575   case Intrinsic::dx_sign:
576     Result = expandSignIntrinsic(Orig);
577     break;
578   case Intrinsic::dx_step:
579     Result = expandStepIntrinsic(Orig);
580     break;
581   case Intrinsic::dx_radians:
582     Result = expandRadiansIntrinsic(Orig);
583     break;
584   case Intrinsic::vector_reduce_add:
585   case Intrinsic::vector_reduce_fadd:
586     Result = expandVecReduceAdd(Orig, IntrinsicId);
587     break;
588   }
589   if (Result) {
590     Orig->replaceAllUsesWith(Result);
591     Orig->eraseFromParent();
592     return true;
593   }
594   return false;
595 }
596 
597 static bool expansionIntrinsics(Module &M) {
598   for (auto &F : make_early_inc_range(M.functions())) {
599     if (!isIntrinsicExpansion(F))
600       continue;
601     bool IntrinsicExpanded = false;
602     for (User *U : make_early_inc_range(F.users())) {
603       auto *IntrinsicCall = dyn_cast<CallInst>(U);
604       if (!IntrinsicCall)
605         continue;
606       IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall);
607     }
608     if (F.user_empty() && IntrinsicExpanded)
609       F.eraseFromParent();
610   }
611   return true;
612 }
613 
614 PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
615                                               ModuleAnalysisManager &) {
616   if (expansionIntrinsics(M))
617     return PreservedAnalyses::none();
618   return PreservedAnalyses::all();
619 }
620 
621 bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
622   return expansionIntrinsics(M);
623 }
624 
625 char DXILIntrinsicExpansionLegacy::ID = 0;
626 
627 INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
628                       "DXIL Intrinsic Expansion", false, false)
629 INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
630                     "DXIL Intrinsic Expansion", false, false)
631 
632 ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
633   return new DXILIntrinsicExpansionLegacy();
634 }
635