xref: /llvm-project/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp (revision ab976a17121374ae3407374b2aa6306e95863eb3)
1*ab976a17SStephen Long //===- LowerVectorIntrinsics.cpp ------------------------------------------===//
2*ab976a17SStephen Long //
3*ab976a17SStephen Long // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*ab976a17SStephen Long // See https://llvm.org/LICENSE.txt for license information.
5*ab976a17SStephen Long // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*ab976a17SStephen Long //
7*ab976a17SStephen Long //===----------------------------------------------------------------------===//
8*ab976a17SStephen Long 
9*ab976a17SStephen Long #include "llvm/Transforms/Utils/LowerVectorIntrinsics.h"
10*ab976a17SStephen Long #include "llvm/IR/IRBuilder.h"
11*ab976a17SStephen Long #include "llvm/IR/IntrinsicInst.h"
12*ab976a17SStephen Long #include "llvm/Support/Debug.h"
13*ab976a17SStephen Long 
14*ab976a17SStephen Long #define DEBUG_TYPE "lower-vector-intrinsics"
15*ab976a17SStephen Long 
16*ab976a17SStephen Long using namespace llvm;
17*ab976a17SStephen Long 
18*ab976a17SStephen Long bool llvm::lowerUnaryVectorIntrinsicAsLoop(Module &M, CallInst *CI) {
19*ab976a17SStephen Long   Type *ArgTy = CI->getArgOperand(0)->getType();
20*ab976a17SStephen Long   VectorType *VecTy = cast<VectorType>(ArgTy);
21*ab976a17SStephen Long 
22*ab976a17SStephen Long   BasicBlock *PreLoopBB = CI->getParent();
23*ab976a17SStephen Long   BasicBlock *PostLoopBB = nullptr;
24*ab976a17SStephen Long   Function *ParentFunc = PreLoopBB->getParent();
25*ab976a17SStephen Long   LLVMContext &Ctx = PreLoopBB->getContext();
26*ab976a17SStephen Long 
27*ab976a17SStephen Long   PostLoopBB = PreLoopBB->splitBasicBlock(CI);
28*ab976a17SStephen Long   BasicBlock *LoopBB = BasicBlock::Create(Ctx, "", ParentFunc, PostLoopBB);
29*ab976a17SStephen Long   PreLoopBB->getTerminator()->setSuccessor(0, LoopBB);
30*ab976a17SStephen Long 
31*ab976a17SStephen Long   // Loop preheader
32*ab976a17SStephen Long   IRBuilder<> PreLoopBuilder(PreLoopBB->getTerminator());
33*ab976a17SStephen Long   Value *LoopEnd = nullptr;
34*ab976a17SStephen Long   if (auto *ScalableVecTy = dyn_cast<ScalableVectorType>(VecTy)) {
35*ab976a17SStephen Long     Value *VScale = PreLoopBuilder.CreateVScale(
36*ab976a17SStephen Long         ConstantInt::get(PreLoopBuilder.getInt64Ty(), 1));
37*ab976a17SStephen Long     Value *N = ConstantInt::get(PreLoopBuilder.getInt64Ty(),
38*ab976a17SStephen Long                                 ScalableVecTy->getMinNumElements());
39*ab976a17SStephen Long     LoopEnd = PreLoopBuilder.CreateMul(VScale, N);
40*ab976a17SStephen Long   } else {
41*ab976a17SStephen Long     FixedVectorType *FixedVecTy = cast<FixedVectorType>(VecTy);
42*ab976a17SStephen Long     LoopEnd = ConstantInt::get(PreLoopBuilder.getInt64Ty(),
43*ab976a17SStephen Long                                FixedVecTy->getNumElements());
44*ab976a17SStephen Long   }
45*ab976a17SStephen Long 
46*ab976a17SStephen Long   // Loop body
47*ab976a17SStephen Long   IRBuilder<> LoopBuilder(LoopBB);
48*ab976a17SStephen Long   Type *Int64Ty = LoopBuilder.getInt64Ty();
49*ab976a17SStephen Long 
50*ab976a17SStephen Long   PHINode *LoopIndex = LoopBuilder.CreatePHI(Int64Ty, 2);
51*ab976a17SStephen Long   LoopIndex->addIncoming(ConstantInt::get(Int64Ty, 0U), PreLoopBB);
52*ab976a17SStephen Long   PHINode *Vec = LoopBuilder.CreatePHI(VecTy, 2);
53*ab976a17SStephen Long   Vec->addIncoming(CI->getArgOperand(0), PreLoopBB);
54*ab976a17SStephen Long 
55*ab976a17SStephen Long   Value *Elem = LoopBuilder.CreateExtractElement(Vec, LoopIndex);
56*ab976a17SStephen Long   Function *Exp = Intrinsic::getOrInsertDeclaration(&M, CI->getIntrinsicID(),
57*ab976a17SStephen Long                                                     VecTy->getElementType());
58*ab976a17SStephen Long   Value *Res = LoopBuilder.CreateCall(Exp, Elem);
59*ab976a17SStephen Long   Value *NewVec = LoopBuilder.CreateInsertElement(Vec, Res, LoopIndex);
60*ab976a17SStephen Long   Vec->addIncoming(NewVec, LoopBB);
61*ab976a17SStephen Long 
62*ab976a17SStephen Long   Value *One = ConstantInt::get(Int64Ty, 1U);
63*ab976a17SStephen Long   Value *NextLoopIndex = LoopBuilder.CreateAdd(LoopIndex, One);
64*ab976a17SStephen Long   LoopIndex->addIncoming(NextLoopIndex, LoopBB);
65*ab976a17SStephen Long 
66*ab976a17SStephen Long   Value *ExitCond =
67*ab976a17SStephen Long       LoopBuilder.CreateICmp(CmpInst::ICMP_EQ, NextLoopIndex, LoopEnd);
68*ab976a17SStephen Long   LoopBuilder.CreateCondBr(ExitCond, PostLoopBB, LoopBB);
69*ab976a17SStephen Long 
70*ab976a17SStephen Long   CI->replaceAllUsesWith(NewVec);
71*ab976a17SStephen Long   CI->eraseFromParent();
72*ab976a17SStephen Long   return true;
73*ab976a17SStephen Long }
74