1f1f57a31SAmjad Aboud //===- TruncInstCombine.cpp -----------------------------------------------===//
2f1f57a31SAmjad Aboud //
32946cd70SChandler Carruth // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42946cd70SChandler Carruth // See https://llvm.org/LICENSE.txt for license information.
52946cd70SChandler Carruth // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f1f57a31SAmjad Aboud //
7f1f57a31SAmjad Aboud //===----------------------------------------------------------------------===//
8f1f57a31SAmjad Aboud //
90dd84013SAnton Afanasyev // TruncInstCombine - looks for expression graphs post-dominated by TruncInst
100dd84013SAnton Afanasyev // and for each eligible graph, it will create a reduced bit-width expression,
110dd84013SAnton Afanasyev // replace the old expression with this new one and remove the old expression.
120dd84013SAnton Afanasyev // Eligible expression graph is such that:
13f1f57a31SAmjad Aboud // 1. Contains only supported instructions.
14f1f57a31SAmjad Aboud // 2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
15f1f57a31SAmjad Aboud // 3. Can be evaluated into type with reduced legal bit-width.
160dd84013SAnton Afanasyev // 4. All instructions in the graph must not have users outside the graph.
17f1f57a31SAmjad Aboud // The only exception is for {ZExt, SExt}Inst with operand type equal to
18f1f57a31SAmjad Aboud // the new reduced type evaluated in (3).
19f1f57a31SAmjad Aboud //
// The motivation for this optimization is that evaluating an expression using
21f1f57a31SAmjad Aboud // smaller bit-width is preferable, especially for vectorization where we can
22f1f57a31SAmjad Aboud // fit more values in one vectorized instruction. In addition, this optimization
23f1f57a31SAmjad Aboud // may decrease the number of cast instructions, but will not increase it.
24f1f57a31SAmjad Aboud //
25f1f57a31SAmjad Aboud //===----------------------------------------------------------------------===//
26f1f57a31SAmjad Aboud
27f1f57a31SAmjad Aboud #include "AggressiveInstCombineInternal.h"
28f1f57a31SAmjad Aboud #include "llvm/ADT/STLExtras.h"
29e987ee63SRoman Lebedev #include "llvm/ADT/Statistic.h"
30f1f57a31SAmjad Aboud #include "llvm/Analysis/ConstantFolding.h"
31f1f57a31SAmjad Aboud #include "llvm/IR/DataLayout.h"
32d895bff5SAmjad Aboud #include "llvm/IR/Dominators.h"
33f1f57a31SAmjad Aboud #include "llvm/IR/IRBuilder.h"
340347f3eaSSimon Pilgrim #include "llvm/IR/Instruction.h"
351f3e35b6SAnton Afanasyev #include "llvm/Support/KnownBits.h"
36e987ee63SRoman Lebedev
37f1f57a31SAmjad Aboud using namespace llvm;
38f1f57a31SAmjad Aboud
39f1f57a31SAmjad Aboud #define DEBUG_TYPE "aggressive-instcombine"
40f1f57a31SAmjad Aboud
// Pass statistics (see llvm/ADT/Statistic.h), grouped under DEBUG_TYPE.
// NumExprsReduced is bumped once per trunc whose expression graph was reduced
// (in run()); NumInstrsReduced is bumped by the number of graph nodes handed to
// ReduceExpressionGraph.
STATISTIC(NumExprsReduced, "Number of truncations eliminated by reducing bit "
                           "width of expression graph");
STATISTIC(NumInstrsReduced,
          "Number of instructions whose bit width was reduced");
45e987ee63SRoman Lebedev
46f1f57a31SAmjad Aboud /// Given an instruction and a container, it fills all the relevant operands of
470dd84013SAnton Afanasyev /// that instruction, with respect to the Trunc expression graph optimizaton.
getRelevantOperands(Instruction * I,SmallVectorImpl<Value * > & Ops)48f1f57a31SAmjad Aboud static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
49f1f57a31SAmjad Aboud unsigned Opc = I->getOpcode();
50f1f57a31SAmjad Aboud switch (Opc) {
51f1f57a31SAmjad Aboud case Instruction::Trunc:
52f1f57a31SAmjad Aboud case Instruction::ZExt:
53f1f57a31SAmjad Aboud case Instruction::SExt:
54f1f57a31SAmjad Aboud // These CastInst are considered leaves of the evaluated expression, thus,
55f1f57a31SAmjad Aboud // their operands are not relevent.
56f1f57a31SAmjad Aboud break;
57f1f57a31SAmjad Aboud case Instruction::Add:
58f1f57a31SAmjad Aboud case Instruction::Sub:
59f1f57a31SAmjad Aboud case Instruction::Mul:
60f1f57a31SAmjad Aboud case Instruction::And:
61f1f57a31SAmjad Aboud case Instruction::Or:
62f1f57a31SAmjad Aboud case Instruction::Xor:
631f3e35b6SAnton Afanasyev case Instruction::Shl:
64cfb6dfcbSAnton Afanasyev case Instruction::LShr:
65bed58763SAnton Afanasyev case Instruction::AShr:
6654d8ebbbSAnton Afanasyev case Instruction::UDiv:
6754d8ebbbSAnton Afanasyev case Instruction::URem:
686a5f49a1SAnton Afanasyev case Instruction::InsertElement:
69f1f57a31SAmjad Aboud Ops.push_back(I->getOperand(0));
70f1f57a31SAmjad Aboud Ops.push_back(I->getOperand(1));
71f1f57a31SAmjad Aboud break;
726a5f49a1SAnton Afanasyev case Instruction::ExtractElement:
736a5f49a1SAnton Afanasyev Ops.push_back(I->getOperand(0));
746a5f49a1SAnton Afanasyev break;
753bda9059SAyman Musa case Instruction::Select:
763bda9059SAyman Musa Ops.push_back(I->getOperand(1));
773bda9059SAyman Musa Ops.push_back(I->getOperand(2));
783bda9059SAyman Musa break;
790dd84013SAnton Afanasyev case Instruction::PHI:
800dd84013SAnton Afanasyev for (Value *V : cast<PHINode>(I)->incoming_values())
810dd84013SAnton Afanasyev Ops.push_back(V);
820dd84013SAnton Afanasyev break;
83f1f57a31SAmjad Aboud default:
84f1f57a31SAmjad Aboud llvm_unreachable("Unreachable!");
85f1f57a31SAmjad Aboud }
86f1f57a31SAmjad Aboud }
87f1f57a31SAmjad Aboud
buildTruncExpressionGraph()880dd84013SAnton Afanasyev bool TruncInstCombine::buildTruncExpressionGraph() {
89f1f57a31SAmjad Aboud SmallVector<Value *, 8> Worklist;
90f1f57a31SAmjad Aboud SmallVector<Instruction *, 8> Stack;
910dd84013SAnton Afanasyev // Clear old instructions info.
92f1f57a31SAmjad Aboud InstInfoMap.clear();
93f1f57a31SAmjad Aboud
94f1f57a31SAmjad Aboud Worklist.push_back(CurrentTruncInst->getOperand(0));
95f1f57a31SAmjad Aboud
96f1f57a31SAmjad Aboud while (!Worklist.empty()) {
97f1f57a31SAmjad Aboud Value *Curr = Worklist.back();
98f1f57a31SAmjad Aboud
99f1f57a31SAmjad Aboud if (isa<Constant>(Curr)) {
100f1f57a31SAmjad Aboud Worklist.pop_back();
101f1f57a31SAmjad Aboud continue;
102f1f57a31SAmjad Aboud }
103f1f57a31SAmjad Aboud
104f1f57a31SAmjad Aboud auto *I = dyn_cast<Instruction>(Curr);
105f1f57a31SAmjad Aboud if (!I)
106f1f57a31SAmjad Aboud return false;
107f1f57a31SAmjad Aboud
108f1f57a31SAmjad Aboud if (!Stack.empty() && Stack.back() == I) {
109f1f57a31SAmjad Aboud // Already handled all instruction operands, can remove it from both the
110f1f57a31SAmjad Aboud // Worklist and the Stack, and add it to the instruction info map.
111f1f57a31SAmjad Aboud Worklist.pop_back();
112f1f57a31SAmjad Aboud Stack.pop_back();
113f1f57a31SAmjad Aboud // Insert I to the Info map.
114f1f57a31SAmjad Aboud InstInfoMap.insert(std::make_pair(I, Info()));
115f1f57a31SAmjad Aboud continue;
116f1f57a31SAmjad Aboud }
117f1f57a31SAmjad Aboud
118f1f57a31SAmjad Aboud if (InstInfoMap.count(I)) {
119f1f57a31SAmjad Aboud Worklist.pop_back();
120f1f57a31SAmjad Aboud continue;
121f1f57a31SAmjad Aboud }
122f1f57a31SAmjad Aboud
123f1f57a31SAmjad Aboud // Add the instruction to the stack before start handling its operands.
124f1f57a31SAmjad Aboud Stack.push_back(I);
125f1f57a31SAmjad Aboud
126f1f57a31SAmjad Aboud unsigned Opc = I->getOpcode();
127f1f57a31SAmjad Aboud switch (Opc) {
128f1f57a31SAmjad Aboud case Instruction::Trunc:
129f1f57a31SAmjad Aboud case Instruction::ZExt:
130f1f57a31SAmjad Aboud case Instruction::SExt:
131f1f57a31SAmjad Aboud // trunc(trunc(x)) -> trunc(x)
132f1f57a31SAmjad Aboud // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest
133f1f57a31SAmjad Aboud // trunc(ext(x)) -> trunc(x) if the source type is larger than the new
134f1f57a31SAmjad Aboud // dest
135f1f57a31SAmjad Aboud break;
136f1f57a31SAmjad Aboud case Instruction::Add:
137f1f57a31SAmjad Aboud case Instruction::Sub:
138f1f57a31SAmjad Aboud case Instruction::Mul:
139f1f57a31SAmjad Aboud case Instruction::And:
140f1f57a31SAmjad Aboud case Instruction::Or:
1413bda9059SAyman Musa case Instruction::Xor:
1421f3e35b6SAnton Afanasyev case Instruction::Shl:
143cfb6dfcbSAnton Afanasyev case Instruction::LShr:
144bed58763SAnton Afanasyev case Instruction::AShr:
14554d8ebbbSAnton Afanasyev case Instruction::UDiv:
14654d8ebbbSAnton Afanasyev case Instruction::URem:
1476a5f49a1SAnton Afanasyev case Instruction::InsertElement:
1486a5f49a1SAnton Afanasyev case Instruction::ExtractElement:
14935f02aa0SAyman Musa case Instruction::Select: {
150f1f57a31SAmjad Aboud SmallVector<Value *, 2> Operands;
151f1f57a31SAmjad Aboud getRelevantOperands(I, Operands);
152e53472deSKazu Hirata append_range(Worklist, Operands);
153f1f57a31SAmjad Aboud break;
154f1f57a31SAmjad Aboud }
1550dd84013SAnton Afanasyev case Instruction::PHI: {
1560dd84013SAnton Afanasyev SmallVector<Value *, 2> Operands;
1570dd84013SAnton Afanasyev getRelevantOperands(I, Operands);
1580dd84013SAnton Afanasyev // Add only operands not in Stack to prevent cycle
1590dd84013SAnton Afanasyev for (auto *Op : Operands)
16021de2888SKazu Hirata if (!llvm::is_contained(Stack, Op))
1610dd84013SAnton Afanasyev Worklist.push_back(Op);
1620dd84013SAnton Afanasyev break;
1630dd84013SAnton Afanasyev }
164f1f57a31SAmjad Aboud default:
165f1f57a31SAmjad Aboud // TODO: Can handle more cases here:
1666a5f49a1SAnton Afanasyev // 1. shufflevector
16754d8ebbbSAnton Afanasyev // 2. sdiv, srem
168f1f57a31SAmjad Aboud // ...
169f1f57a31SAmjad Aboud return false;
170f1f57a31SAmjad Aboud }
171f1f57a31SAmjad Aboud }
172f1f57a31SAmjad Aboud return true;
173f1f57a31SAmjad Aboud }
174f1f57a31SAmjad Aboud
getMinBitWidth()175f1f57a31SAmjad Aboud unsigned TruncInstCombine::getMinBitWidth() {
176f1f57a31SAmjad Aboud SmallVector<Value *, 8> Worklist;
177f1f57a31SAmjad Aboud SmallVector<Instruction *, 8> Stack;
178f1f57a31SAmjad Aboud
179f1f57a31SAmjad Aboud Value *Src = CurrentTruncInst->getOperand(0);
180f1f57a31SAmjad Aboud Type *DstTy = CurrentTruncInst->getType();
181f1f57a31SAmjad Aboud unsigned TruncBitWidth = DstTy->getScalarSizeInBits();
182f1f57a31SAmjad Aboud unsigned OrigBitWidth =
183f1f57a31SAmjad Aboud CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
184f1f57a31SAmjad Aboud
185f1f57a31SAmjad Aboud if (isa<Constant>(Src))
186f1f57a31SAmjad Aboud return TruncBitWidth;
187f1f57a31SAmjad Aboud
188f1f57a31SAmjad Aboud Worklist.push_back(Src);
189f1f57a31SAmjad Aboud InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth;
190f1f57a31SAmjad Aboud
191f1f57a31SAmjad Aboud while (!Worklist.empty()) {
192f1f57a31SAmjad Aboud Value *Curr = Worklist.back();
193f1f57a31SAmjad Aboud
194f1f57a31SAmjad Aboud if (isa<Constant>(Curr)) {
195f1f57a31SAmjad Aboud Worklist.pop_back();
196f1f57a31SAmjad Aboud continue;
197f1f57a31SAmjad Aboud }
198f1f57a31SAmjad Aboud
199f1f57a31SAmjad Aboud // Otherwise, it must be an instruction.
200f1f57a31SAmjad Aboud auto *I = cast<Instruction>(Curr);
201f1f57a31SAmjad Aboud
202f1f57a31SAmjad Aboud auto &Info = InstInfoMap[I];
203f1f57a31SAmjad Aboud
204f1f57a31SAmjad Aboud SmallVector<Value *, 2> Operands;
205f1f57a31SAmjad Aboud getRelevantOperands(I, Operands);
206f1f57a31SAmjad Aboud
207f1f57a31SAmjad Aboud if (!Stack.empty() && Stack.back() == I) {
208f1f57a31SAmjad Aboud // Already handled all instruction operands, can remove it from both, the
209f1f57a31SAmjad Aboud // Worklist and the Stack, and update MinBitWidth.
210f1f57a31SAmjad Aboud Worklist.pop_back();
211f1f57a31SAmjad Aboud Stack.pop_back();
212f1f57a31SAmjad Aboud for (auto *Operand : Operands)
213f1f57a31SAmjad Aboud if (auto *IOp = dyn_cast<Instruction>(Operand))
214f1f57a31SAmjad Aboud Info.MinBitWidth =
215f1f57a31SAmjad Aboud std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);
216f1f57a31SAmjad Aboud continue;
217f1f57a31SAmjad Aboud }
218f1f57a31SAmjad Aboud
219f1f57a31SAmjad Aboud // Add the instruction to the stack before start handling its operands.
220f1f57a31SAmjad Aboud Stack.push_back(I);
221f1f57a31SAmjad Aboud unsigned ValidBitWidth = Info.ValidBitWidth;
222f1f57a31SAmjad Aboud
223f1f57a31SAmjad Aboud // Update minimum bit-width before handling its operands. This is required
224f1f57a31SAmjad Aboud // when the instruction is part of a loop.
225f1f57a31SAmjad Aboud Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth);
226f1f57a31SAmjad Aboud
227f1f57a31SAmjad Aboud for (auto *Operand : Operands)
228f1f57a31SAmjad Aboud if (auto *IOp = dyn_cast<Instruction>(Operand)) {
229f1f57a31SAmjad Aboud // If we already calculated the minimum bit-width for this valid
230f1f57a31SAmjad Aboud // bit-width, or for a smaller valid bit-width, then just keep the
231f1f57a31SAmjad Aboud // answer we already calculated.
232f1f57a31SAmjad Aboud unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
233f1f57a31SAmjad Aboud if (IOpBitwidth >= ValidBitWidth)
234f1f57a31SAmjad Aboud continue;
23549a4d85fSAyman Musa InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
236f1f57a31SAmjad Aboud Worklist.push_back(IOp);
237f1f57a31SAmjad Aboud }
238f1f57a31SAmjad Aboud }
239f1f57a31SAmjad Aboud unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth;
240f1f57a31SAmjad Aboud assert(MinBitWidth >= TruncBitWidth);
241f1f57a31SAmjad Aboud
242f1f57a31SAmjad Aboud if (MinBitWidth > TruncBitWidth) {
243f1f57a31SAmjad Aboud // In this case reducing expression with vector type might generate a new
244f1f57a31SAmjad Aboud // vector type, which is not preferable as it might result in generating
245f1f57a31SAmjad Aboud // sub-optimal code.
246f1f57a31SAmjad Aboud if (DstTy->isVectorTy())
247f1f57a31SAmjad Aboud return OrigBitWidth;
248f1f57a31SAmjad Aboud // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth).
249f1f57a31SAmjad Aboud Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth);
250f1f57a31SAmjad Aboud // Update minimum bit-width with the new destination type bit-width if
251f1f57a31SAmjad Aboud // succeeded to find such, otherwise, with original bit-width.
252f1f57a31SAmjad Aboud MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth;
253f1f57a31SAmjad Aboud } else { // MinBitWidth == TruncBitWidth
254f1f57a31SAmjad Aboud // In this case the expression can be evaluated with the trunc instruction
255f1f57a31SAmjad Aboud // destination type, and trunc instruction can be omitted. However, we
256f1f57a31SAmjad Aboud // should not perform the evaluation if the original type is a legal scalar
257f1f57a31SAmjad Aboud // type and the target type is illegal.
258f1f57a31SAmjad Aboud bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth);
259f1f57a31SAmjad Aboud bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth);
260f1f57a31SAmjad Aboud if (!DstTy->isVectorTy() && FromLegal && !ToLegal)
261f1f57a31SAmjad Aboud return OrigBitWidth;
262f1f57a31SAmjad Aboud }
263f1f57a31SAmjad Aboud return MinBitWidth;
264f1f57a31SAmjad Aboud }
265f1f57a31SAmjad Aboud
getBestTruncatedType()266f1f57a31SAmjad Aboud Type *TruncInstCombine::getBestTruncatedType() {
2670dd84013SAnton Afanasyev if (!buildTruncExpressionGraph())
268f1f57a31SAmjad Aboud return nullptr;
269f1f57a31SAmjad Aboud
270f1f57a31SAmjad Aboud // We don't want to duplicate instructions, which isn't profitable. Thus, we
271f1f57a31SAmjad Aboud // can't shrink something that has multiple users, unless all users are
272f1f57a31SAmjad Aboud // post-dominated by the trunc instruction, i.e., were visited during the
273f1f57a31SAmjad Aboud // expression evaluation.
274f1f57a31SAmjad Aboud unsigned DesiredBitWidth = 0;
275f1f57a31SAmjad Aboud for (auto Itr : InstInfoMap) {
276f1f57a31SAmjad Aboud Instruction *I = Itr.first;
277f1f57a31SAmjad Aboud if (I->hasOneUse())
278f1f57a31SAmjad Aboud continue;
279f1f57a31SAmjad Aboud bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I));
280f1f57a31SAmjad Aboud for (auto *U : I->users())
281f1f57a31SAmjad Aboud if (auto *UI = dyn_cast<Instruction>(U))
282f1f57a31SAmjad Aboud if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) {
283f1f57a31SAmjad Aboud if (!IsExtInst)
284f1f57a31SAmjad Aboud return nullptr;
285f1f57a31SAmjad Aboud // If this is an extension from the dest type, we can eliminate it,
286f1f57a31SAmjad Aboud // even if it has multiple users. Thus, update the DesiredBitWidth and
287f1f57a31SAmjad Aboud // validate all extension instructions agrees on same DesiredBitWidth.
288f1f57a31SAmjad Aboud unsigned ExtInstBitWidth =
289f1f57a31SAmjad Aboud I->getOperand(0)->getType()->getScalarSizeInBits();
290f1f57a31SAmjad Aboud if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth)
291f1f57a31SAmjad Aboud return nullptr;
292f1f57a31SAmjad Aboud DesiredBitWidth = ExtInstBitWidth;
293f1f57a31SAmjad Aboud }
294f1f57a31SAmjad Aboud }
295f1f57a31SAmjad Aboud
296f1f57a31SAmjad Aboud unsigned OrigBitWidth =
297f1f57a31SAmjad Aboud CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
298f1f57a31SAmjad Aboud
299cfb6dfcbSAnton Afanasyev // Initialize MinBitWidth for shift instructions with the minimum number
300bed58763SAnton Afanasyev // that is greater than shift amount (i.e. shift amount + 1).
301bed58763SAnton Afanasyev // For `lshr` adjust MinBitWidth so that all potentially truncated
302bed58763SAnton Afanasyev // bits of the value-to-be-shifted are zeros.
303bed58763SAnton Afanasyev // For `ashr` adjust MinBitWidth so that all potentially truncated
304bed58763SAnton Afanasyev // bits of the value-to-be-shifted are sign bits (all zeros or ones)
305bed58763SAnton Afanasyev // and even one (first) untruncated bit is sign bit.
306bed58763SAnton Afanasyev // Exit early if MinBitWidth is not less than original bitwidth.
3071f3e35b6SAnton Afanasyev for (auto &Itr : InstInfoMap) {
3081f3e35b6SAnton Afanasyev Instruction *I = Itr.first;
309bed58763SAnton Afanasyev if (I->isShift()) {
3108c0a1940SAnton Afanasyev KnownBits KnownRHS = computeKnownBits(I->getOperand(1));
311803270c0SAnton Afanasyev unsigned MinBitWidth = KnownRHS.getMaxValue()
3123890ce70SAnton Afanasyev .uadd_sat(APInt(OrigBitWidth, 1))
3133890ce70SAnton Afanasyev .getLimitedValue(OrigBitWidth);
3143890ce70SAnton Afanasyev if (MinBitWidth == OrigBitWidth)
3151f3e35b6SAnton Afanasyev return nullptr;
316cfb6dfcbSAnton Afanasyev if (I->getOpcode() == Instruction::LShr) {
3178c0a1940SAnton Afanasyev KnownBits KnownLHS = computeKnownBits(I->getOperand(0));
318cfb6dfcbSAnton Afanasyev MinBitWidth =
319cfb6dfcbSAnton Afanasyev std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits());
320bed58763SAnton Afanasyev }
321bed58763SAnton Afanasyev if (I->getOpcode() == Instruction::AShr) {
3228c0a1940SAnton Afanasyev unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0));
323bed58763SAnton Afanasyev MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);
324bed58763SAnton Afanasyev }
325cfb6dfcbSAnton Afanasyev if (MinBitWidth >= OrigBitWidth)
326cfb6dfcbSAnton Afanasyev return nullptr;
3271f3e35b6SAnton Afanasyev Itr.second.MinBitWidth = MinBitWidth;
3281f3e35b6SAnton Afanasyev }
32954d8ebbbSAnton Afanasyev if (I->getOpcode() == Instruction::UDiv ||
33054d8ebbbSAnton Afanasyev I->getOpcode() == Instruction::URem) {
33154d8ebbbSAnton Afanasyev unsigned MinBitWidth = 0;
33254d8ebbbSAnton Afanasyev for (const auto &Op : I->operands()) {
33354d8ebbbSAnton Afanasyev KnownBits Known = computeKnownBits(Op);
33454d8ebbbSAnton Afanasyev MinBitWidth =
33554d8ebbbSAnton Afanasyev std::max(Known.getMaxValue().getActiveBits(), MinBitWidth);
33654d8ebbbSAnton Afanasyev if (MinBitWidth >= OrigBitWidth)
33754d8ebbbSAnton Afanasyev return nullptr;
33854d8ebbbSAnton Afanasyev }
33954d8ebbbSAnton Afanasyev Itr.second.MinBitWidth = MinBitWidth;
34054d8ebbbSAnton Afanasyev }
3411f3e35b6SAnton Afanasyev }
3421f3e35b6SAnton Afanasyev
343f1f57a31SAmjad Aboud // Calculate minimum allowed bit-width allowed for shrinking the currently
344f1f57a31SAmjad Aboud // visited truncate's operand.
345f1f57a31SAmjad Aboud unsigned MinBitWidth = getMinBitWidth();
346f1f57a31SAmjad Aboud
347f1f57a31SAmjad Aboud // Check that we can shrink to smaller bit-width than original one and that
348f1f57a31SAmjad Aboud // it is similar to the DesiredBitWidth is such exists.
349f1f57a31SAmjad Aboud if (MinBitWidth >= OrigBitWidth ||
350f1f57a31SAmjad Aboud (DesiredBitWidth && DesiredBitWidth != MinBitWidth))
351f1f57a31SAmjad Aboud return nullptr;
352f1f57a31SAmjad Aboud
353f1f57a31SAmjad Aboud return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth);
354f1f57a31SAmjad Aboud }
355f1f57a31SAmjad Aboud
356f1f57a31SAmjad Aboud /// Given a reduced scalar type \p Ty and a \p V value, return a reduced type
357f1f57a31SAmjad Aboud /// for \p V, according to its type, if it vector type, return the vector
358f1f57a31SAmjad Aboud /// version of \p Ty, otherwise return \p Ty.
getReducedType(Value * V,Type * Ty)359f1f57a31SAmjad Aboud static Type *getReducedType(Value *V, Type *Ty) {
360f1f57a31SAmjad Aboud assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type");
361137674f8SJun Ma if (auto *VTy = dyn_cast<VectorType>(V->getType()))
362137674f8SJun Ma return VectorType::get(Ty, VTy->getElementCount());
363f1f57a31SAmjad Aboud return Ty;
364f1f57a31SAmjad Aboud }
365f1f57a31SAmjad Aboud
getReducedOperand(Value * V,Type * SclTy)366f1f57a31SAmjad Aboud Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {
367f1f57a31SAmjad Aboud Type *Ty = getReducedType(V, SclTy);
368f1f57a31SAmjad Aboud if (auto *C = dyn_cast<Constant>(V)) {
369*3a2fbf54SNikita Popov C = ConstantExpr::getTrunc(C, Ty);
370f1f57a31SAmjad Aboud // If we got a constantexpr back, try to simplify it with DL info.
3710e890cd4SNikita Popov return ConstantFoldConstant(C, DL, &TLI);
372f1f57a31SAmjad Aboud }
373f1f57a31SAmjad Aboud
374f1f57a31SAmjad Aboud auto *I = cast<Instruction>(V);
375f1f57a31SAmjad Aboud Info Entry = InstInfoMap.lookup(I);
376f1f57a31SAmjad Aboud assert(Entry.NewValue);
377f1f57a31SAmjad Aboud return Entry.NewValue;
378f1f57a31SAmjad Aboud }
379f1f57a31SAmjad Aboud
/// Rewrite the expression graph recorded in InstInfoMap in the reduced scalar
/// type \p SclTy (or its vector counterpart), replace CurrentTruncInst with the
/// reduced root, and erase the old instructions that are left without uses.
///
/// Nodes are rewritten in InstInfoMap iteration order. buildTruncExpressionGraph
/// records a node only after its relevant operands, so non-PHI operands are
/// expected to be rewritten before their users (getReducedOperand asserts
/// this). PHI nodes are created empty and their incoming values are filled in
/// only after every node has been rewritten, because an incoming value may
/// close a cycle back through the PHI.
void TruncInstCombine::ReduceExpressionGraph(Type *SclTy) {
  NumInstrsReduced += InstInfoMap.size();
  // Pairs of old and new phi-nodes; new phis get their incoming values later.
  SmallVector<std::pair<PHINode *, PHINode *>, 2> OldNewPHINodes;
  for (auto &Itr : InstInfoMap) { // Forward
    Instruction *I = Itr.first;
    TruncInstCombine::Info &NodeInfo = Itr.second;

    assert(!NodeInfo.NewValue && "Instruction has been evaluated");

    // New instructions are inserted right before the one they replace.
    IRBuilder<> Builder(I);
    Value *Res = nullptr;
    unsigned Opc = I->getOpcode();
    switch (Opc) {
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt: {
      Type *Ty = getReducedType(I, SclTy);
      // If the source type of the cast is the type we're trying for then we can
      // just return the source. There's no need to insert it because it is not
      // new.
      if (I->getOperand(0)->getType() == Ty) {
        assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst");
        NodeInfo.NewValue = I->getOperand(0);
        continue;
      }
      // Otherwise, must be the same type of cast, so just reinsert a new one.
      // This also handles the case of zext(trunc(x)) -> zext(x).
      Res = Builder.CreateIntCast(I->getOperand(0), Ty,
                                  Opc == Instruction::SExt);

      // Keep the pass Worklist in sync with the rewritten cast:
      // 1. If the old cast is still queued and the new one is a TruncInst,
      //    replace the queued entry with the new TruncInst.
      // 2. If the old cast is still queued and the new one is not a
      //    TruncInst, drop the entry.
      // 3. If the old cast is not queued and the new one is a TruncInst,
      //    queue the new TruncInst.
      auto *Entry = find(Worklist, I);
      if (Entry != Worklist.end()) {
        if (auto *NewCI = dyn_cast<TruncInst>(Res))
          *Entry = NewCI;
        else
          Worklist.erase(Entry);
      } else if (auto *NewCI = dyn_cast<TruncInst>(Res))
        Worklist.push_back(NewCI);
      break;
    }
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::Shl:
    case Instruction::LShr:
    case Instruction::AShr:
    case Instruction::UDiv:
    case Instruction::URem: {
      Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
      Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
      // The new binop is created without the original wrap flags.
      Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
      // Preserve `exact` flag since truncation doesn't change exactness
      if (auto *PEO = dyn_cast<PossiblyExactOperator>(I))
        if (auto *ResI = dyn_cast<Instruction>(Res))
          ResI->setIsExact(PEO->isExact());
      break;
    }
    case Instruction::ExtractElement: {
      // Only the vector is reduced; the index keeps its original type.
      Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
      Value *Idx = I->getOperand(1);
      Res = Builder.CreateExtractElement(Vec, Idx);
      break;
    }
    case Instruction::InsertElement: {
      // Vector and inserted element are reduced; the index is kept.
      Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
      Value *NewElt = getReducedOperand(I->getOperand(1), SclTy);
      Value *Idx = I->getOperand(2);
      Res = Builder.CreateInsertElement(Vec, NewElt, Idx);
      break;
    }
    case Instruction::Select: {
      // The condition is reused unchanged.
      Value *Op0 = I->getOperand(0);
      Value *LHS = getReducedOperand(I->getOperand(1), SclTy);
      Value *RHS = getReducedOperand(I->getOperand(2), SclTy);
      Res = Builder.CreateSelect(Op0, LHS, RHS);
      break;
    }
    case Instruction::PHI: {
      // Create an empty phi now; incoming values are added after the loop.
      Res = Builder.CreatePHI(getReducedType(I, SclTy), I->getNumOperands());
      OldNewPHINodes.push_back(
          std::make_pair(cast<PHINode>(I), cast<PHINode>(Res)));
      break;
    }
    default:
      llvm_unreachable("Unhandled instruction");
    }

    NodeInfo.NewValue = Res;
    if (auto *ResI = dyn_cast<Instruction>(Res))
      ResI->takeName(I);
  }

  // Every node now has a NewValue, so the new phis can be completed.
  for (auto &Node : OldNewPHINodes) {
    PHINode *OldPN = Node.first;
    PHINode *NewPN = Node.second;
    for (auto Incoming : zip(OldPN->incoming_values(), OldPN->blocks()))
      NewPN->addIncoming(getReducedOperand(std::get<0>(Incoming), SclTy),
                         std::get<1>(Incoming));
  }

  // Replace the trunc with the reduced root, casting it (zext) to the trunc's
  // destination type when the reduced type is different.
  Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy);
  Type *DstTy = CurrentTruncInst->getType();
  if (Res->getType() != DstTy) {
    IRBuilder<> Builder(CurrentTruncInst);
    Res = Builder.CreateIntCast(Res, DstTy, false);
    if (auto *ResI = dyn_cast<Instruction>(Res))
      ResI->takeName(CurrentTruncInst);
  }
  CurrentTruncInst->replaceAllUsesWith(Res);

  // Erase old expression graph, which was replaced by the reduced expression
  // graph.
  CurrentTruncInst->eraseFromParent();
  // First, erase old phi-nodes and its uses. Their remaining uses are replaced
  // with poison, which breaks every cycle in the old graph.
  for (auto &Node : OldNewPHINodes) {
    PHINode *OldPN = Node.first;
    OldPN->replaceAllUsesWith(PoisonValue::get(OldPN->getType()));
    InstInfoMap.erase(OldPN);
    OldPN->eraseFromParent();
  }
  // Now we have expression graph turned into dag.
  // We iterate backward, which means we visit the instruction before we
  // visit any of its operands, this way, when we get to the operand, we already
  // removed the instructions (from the expression dag) that uses it.
  for (auto &I : llvm::reverse(InstInfoMap)) {
    // We still need to check that the instruction has no users before we erase
    // it, because {SExt, ZExt}Inst Instruction might have other users that was
    // not reduced, in such case, we need to keep that instruction.
    if (I.first->use_empty())
      I.first->eraseFromParent();
    else
      assert((isa<SExtInst>(I.first) || isa<ZExtInst>(I.first)) &&
             "Only {SExt, ZExt}Inst might have unreduced users");
  }
}
524f1f57a31SAmjad Aboud
run(Function & F)525f1f57a31SAmjad Aboud bool TruncInstCombine::run(Function &F) {
526f1f57a31SAmjad Aboud bool MadeIRChange = false;
527f1f57a31SAmjad Aboud
528f1f57a31SAmjad Aboud // Collect all TruncInst in the function into the Worklist for evaluating.
529d895bff5SAmjad Aboud for (auto &BB : F) {
530d895bff5SAmjad Aboud // Ignore unreachable basic block.
531d895bff5SAmjad Aboud if (!DT.isReachableFromEntry(&BB))
532d895bff5SAmjad Aboud continue;
533f1f57a31SAmjad Aboud for (auto &I : BB)
534f1f57a31SAmjad Aboud if (auto *CI = dyn_cast<TruncInst>(&I))
535f1f57a31SAmjad Aboud Worklist.push_back(CI);
536d895bff5SAmjad Aboud }
537f1f57a31SAmjad Aboud
538f1f57a31SAmjad Aboud // Process all TruncInst in the Worklist, for each instruction:
5390dd84013SAnton Afanasyev // 1. Check if it dominates an eligible expression graph to be reduced.
5400dd84013SAnton Afanasyev // 2. Create a reduced expression graph and replace the old one with it.
541f1f57a31SAmjad Aboud while (!Worklist.empty()) {
542f1f57a31SAmjad Aboud CurrentTruncInst = Worklist.pop_back_val();
543f1f57a31SAmjad Aboud
544f1f57a31SAmjad Aboud if (Type *NewDstSclTy = getBestTruncatedType()) {
545d34e60caSNicola Zaghen LLVM_DEBUG(
5460dd84013SAnton Afanasyev dbgs() << "ICE: TruncInstCombine reducing type of expression graph "
547f1f57a31SAmjad Aboud "dominated by: "
548f1f57a31SAmjad Aboud << CurrentTruncInst << '\n');
5490dd84013SAnton Afanasyev ReduceExpressionGraph(NewDstSclTy);
5500dd84013SAnton Afanasyev ++NumExprsReduced;
551f1f57a31SAmjad Aboud MadeIRChange = true;
552f1f57a31SAmjad Aboud }
553f1f57a31SAmjad Aboud }
554f1f57a31SAmjad Aboud
555f1f57a31SAmjad Aboud return MadeIRChange;
556f1f57a31SAmjad Aboud }
557