xref: /llvm-project/llvm/lib/CodeGen/GlobalISel/CombinerHelperCasts.cpp (revision ba4bcce5f5ffa9e7d4af72c20fe4f1baf97075fc)
1 //===- CombinerHelperCasts.cpp---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements CombinerHelper for G_ANYEXT, G_SEXT, G_TRUNC, and
10 // G_ZEXT
11 //
12 //===----------------------------------------------------------------------===//
13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
14 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
15 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
16 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
17 #include "llvm/CodeGen/GlobalISel/Utils.h"
18 #include "llvm/CodeGen/LowLevelTypeUtils.h"
19 #include "llvm/CodeGen/MachineOperand.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22 #include "llvm/Support/Casting.h"
23 
24 #define DEBUG_TYPE "gi-combiner"
25 
26 using namespace llvm;
27 
28 bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO,
29                                       BuildFnTy &MatchInfo) {
30   GSext *Sext = cast<GSext>(getDefIgnoringCopies(MO.getReg(), MRI));
31   GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Sext->getSrcReg(), MRI));
32 
33   Register Dst = Sext->getReg(0);
34   Register Src = Trunc->getSrcReg();
35 
36   LLT DstTy = MRI.getType(Dst);
37   LLT SrcTy = MRI.getType(Src);
38 
39   if (DstTy == SrcTy) {
40     MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
41     return true;
42   }
43 
44   if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
45       isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
46     MatchInfo = [=](MachineIRBuilder &B) {
47       B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoSWrap);
48     };
49     return true;
50   }
51 
52   if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
53       isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}})) {
54     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
55     return true;
56   }
57 
58   return false;
59 }
60 
61 bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO,
62                                       BuildFnTy &MatchInfo) {
63   GZext *Zext = cast<GZext>(getDefIgnoringCopies(MO.getReg(), MRI));
64   GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Zext->getSrcReg(), MRI));
65 
66   Register Dst = Zext->getReg(0);
67   Register Src = Trunc->getSrcReg();
68 
69   LLT DstTy = MRI.getType(Dst);
70   LLT SrcTy = MRI.getType(Src);
71 
72   if (DstTy == SrcTy) {
73     MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
74     return true;
75   }
76 
77   if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
78       isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
79     MatchInfo = [=](MachineIRBuilder &B) {
80       B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoUWrap);
81     };
82     return true;
83   }
84 
85   if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
86       isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) {
87     MatchInfo = [=](MachineIRBuilder &B) {
88       B.buildZExt(Dst, Src, MachineInstr::MIFlag::NonNeg);
89     };
90     return true;
91   }
92 
93   return false;
94 }
95 
96 bool CombinerHelper::matchNonNegZext(const MachineOperand &MO,
97                                      BuildFnTy &MatchInfo) {
98   GZext *Zext = cast<GZext>(MRI.getVRegDef(MO.getReg()));
99 
100   Register Dst = Zext->getReg(0);
101   Register Src = Zext->getSrcReg();
102 
103   LLT DstTy = MRI.getType(Dst);
104   LLT SrcTy = MRI.getType(Src);
105   const auto &TLI = getTargetLowering();
106 
107   // Convert zext nneg to sext if sext is the preferred form for the target.
108   if (isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}}) &&
109       TLI.isSExtCheaperThanZExt(getMVTForLLT(SrcTy), getMVTForLLT(DstTy))) {
110     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
111     return true;
112   }
113 
114   return false;
115 }
116 
117 bool CombinerHelper::matchTruncateOfExt(const MachineInstr &Root,
118                                         const MachineInstr &ExtMI,
119                                         BuildFnTy &MatchInfo) {
120   const GTrunc *Trunc = cast<GTrunc>(&Root);
121   const GExtOp *Ext = cast<GExtOp>(&ExtMI);
122 
123   if (!MRI.hasOneNonDBGUse(Ext->getReg(0)))
124     return false;
125 
126   Register Dst = Trunc->getReg(0);
127   Register Src = Ext->getSrcReg();
128   LLT DstTy = MRI.getType(Dst);
129   LLT SrcTy = MRI.getType(Src);
130 
131   if (SrcTy == DstTy) {
132     // The source and the destination are equally sized. We need to copy.
133     MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
134 
135     return true;
136   }
137 
138   if (SrcTy.getScalarSizeInBits() < DstTy.getScalarSizeInBits()) {
139     // If the source is smaller than the destination, we need to extend.
140 
141     if (!isLegalOrBeforeLegalizer({Ext->getOpcode(), {DstTy, SrcTy}}))
142       return false;
143 
144     MatchInfo = [=](MachineIRBuilder &B) {
145       B.buildInstr(Ext->getOpcode(), {Dst}, {Src});
146     };
147 
148     return true;
149   }
150 
151   if (SrcTy.getScalarSizeInBits() > DstTy.getScalarSizeInBits()) {
152     // If the source is larger than the destination, then we need to truncate.
153 
154     if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}}))
155       return false;
156 
157     MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Dst, Src); };
158 
159     return true;
160   }
161 
162   return false;
163 }
164 
165 bool CombinerHelper::isCastFree(unsigned Opcode, LLT ToTy, LLT FromTy) const {
166   const TargetLowering &TLI = getTargetLowering();
167   const DataLayout &DL = getDataLayout();
168   LLVMContext &Ctx = getContext();
169 
170   switch (Opcode) {
171   case TargetOpcode::G_ANYEXT:
172   case TargetOpcode::G_ZEXT:
173     return TLI.isZExtFree(FromTy, ToTy, DL, Ctx);
174   case TargetOpcode::G_TRUNC:
175     return TLI.isTruncateFree(FromTy, ToTy, DL, Ctx);
176   default:
177     return false;
178   }
179 }
180 
181 bool CombinerHelper::matchCastOfSelect(const MachineInstr &CastMI,
182                                        const MachineInstr &SelectMI,
183                                        BuildFnTy &MatchInfo) {
184   const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
185   const GSelect *Select = cast<GSelect>(&SelectMI);
186 
187   if (!MRI.hasOneNonDBGUse(Select->getReg(0)))
188     return false;
189 
190   Register Dst = Cast->getReg(0);
191   LLT DstTy = MRI.getType(Dst);
192   LLT CondTy = MRI.getType(Select->getCondReg());
193   Register TrueReg = Select->getTrueReg();
194   Register FalseReg = Select->getFalseReg();
195   LLT SrcTy = MRI.getType(TrueReg);
196   Register Cond = Select->getCondReg();
197 
198   if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SELECT, {DstTy, CondTy}}))
199     return false;
200 
201   if (!isCastFree(Cast->getOpcode(), DstTy, SrcTy))
202     return false;
203 
204   MatchInfo = [=](MachineIRBuilder &B) {
205     auto True = B.buildInstr(Cast->getOpcode(), {DstTy}, {TrueReg});
206     auto False = B.buildInstr(Cast->getOpcode(), {DstTy}, {FalseReg});
207     B.buildSelect(Dst, Cond, True, False);
208   };
209 
210   return true;
211 }
212 
213 bool CombinerHelper::matchExtOfExt(const MachineInstr &FirstMI,
214                                    const MachineInstr &SecondMI,
215                                    BuildFnTy &MatchInfo) {
216   const GExtOp *First = cast<GExtOp>(&FirstMI);
217   const GExtOp *Second = cast<GExtOp>(&SecondMI);
218 
219   Register Dst = First->getReg(0);
220   Register Src = Second->getSrcReg();
221   LLT DstTy = MRI.getType(Dst);
222   LLT SrcTy = MRI.getType(Src);
223 
224   if (!MRI.hasOneNonDBGUse(Second->getReg(0)))
225     return false;
226 
227   // ext of ext -> later ext
228   if (First->getOpcode() == Second->getOpcode() &&
229       isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) {
230     if (Second->getOpcode() == TargetOpcode::G_ZEXT) {
231       MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
232       if (Second->getFlag(MachineInstr::MIFlag::NonNeg))
233         Flag = MachineInstr::MIFlag::NonNeg;
234       MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); };
235       return true;
236     }
237     // not zext -> no flags
238     MatchInfo = [=](MachineIRBuilder &B) {
239       B.buildInstr(Second->getOpcode(), {Dst}, {Src});
240     };
241     return true;
242   }
243 
244   // anyext of sext/zext  -> sext/zext
245   // -> pick anyext as second ext, then ext of ext
246   if (First->getOpcode() == TargetOpcode::G_ANYEXT &&
247       isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) {
248     if (Second->getOpcode() == TargetOpcode::G_ZEXT) {
249       MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
250       if (Second->getFlag(MachineInstr::MIFlag::NonNeg))
251         Flag = MachineInstr::MIFlag::NonNeg;
252       MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); };
253       return true;
254     }
255     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
256     return true;
257   }
258 
259   // sext/zext of anyext -> sext/zext
260   // -> pick anyext as first ext, then ext of ext
261   if (Second->getOpcode() == TargetOpcode::G_ANYEXT &&
262       isLegalOrBeforeLegalizer({First->getOpcode(), {DstTy, SrcTy}})) {
263     if (First->getOpcode() == TargetOpcode::G_ZEXT) {
264       MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
265       if (First->getFlag(MachineInstr::MIFlag::NonNeg))
266         Flag = MachineInstr::MIFlag::NonNeg;
267       MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); };
268       return true;
269     }
270     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
271     return true;
272   }
273 
274   return false;
275 }
276 
277 bool CombinerHelper::matchCastOfBuildVector(const MachineInstr &CastMI,
278                                             const MachineInstr &BVMI,
279                                             BuildFnTy &MatchInfo) {
280   const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
281   const GBuildVector *BV = cast<GBuildVector>(&BVMI);
282 
283   if (!MRI.hasOneNonDBGUse(BV->getReg(0)))
284     return false;
285 
286   Register Dst = Cast->getReg(0);
287   // The type of the new build vector.
288   LLT DstTy = MRI.getType(Dst);
289   // The scalar or element type of the new build vector.
290   LLT ElemTy = DstTy.getScalarType();
291   // The scalar or element type of the old build vector.
292   LLT InputElemTy = MRI.getType(BV->getReg(0)).getElementType();
293 
294   // Check legality of new build vector, the scalar casts, and profitability of
295   // the many casts.
296   if (!isLegalOrBeforeLegalizer(
297           {TargetOpcode::G_BUILD_VECTOR, {DstTy, ElemTy}}) ||
298       !isLegalOrBeforeLegalizer({Cast->getOpcode(), {ElemTy, InputElemTy}}) ||
299       !isCastFree(Cast->getOpcode(), ElemTy, InputElemTy))
300     return false;
301 
302   MatchInfo = [=](MachineIRBuilder &B) {
303     SmallVector<Register> Casts;
304     unsigned Elements = BV->getNumSources();
305     for (unsigned I = 0; I < Elements; ++I) {
306       auto CastI =
307           B.buildInstr(Cast->getOpcode(), {ElemTy}, {BV->getSourceReg(I)});
308       Casts.push_back(CastI.getReg(0));
309     }
310 
311     B.buildBuildVector(Dst, Casts);
312   };
313 
314   return true;
315 }
316 
317 bool CombinerHelper::matchNarrowBinop(const MachineInstr &TruncMI,
318                                       const MachineInstr &BinopMI,
319                                       BuildFnTy &MatchInfo) {
320   const GTrunc *Trunc = cast<GTrunc>(&TruncMI);
321   const GBinOp *BinOp = cast<GBinOp>(&BinopMI);
322 
323   if (!MRI.hasOneNonDBGUse(BinOp->getReg(0)))
324     return false;
325 
326   Register Dst = Trunc->getReg(0);
327   LLT DstTy = MRI.getType(Dst);
328 
329   // Is narrow binop legal?
330   if (!isLegalOrBeforeLegalizer({BinOp->getOpcode(), {DstTy}}))
331     return false;
332 
333   MatchInfo = [=](MachineIRBuilder &B) {
334     auto LHS = B.buildTrunc(DstTy, BinOp->getLHSReg());
335     auto RHS = B.buildTrunc(DstTy, BinOp->getRHSReg());
336     B.buildInstr(BinOp->getOpcode(), {Dst}, {LHS, RHS});
337   };
338 
339   return true;
340 }
341 
342 bool CombinerHelper::matchCastOfInteger(const MachineInstr &CastMI,
343                                         APInt &MatchInfo) {
344   const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
345 
346   APInt Input = getIConstantFromReg(Cast->getSrcReg(), MRI);
347 
348   LLT DstTy = MRI.getType(Cast->getReg(0));
349 
350   if (!isConstantLegalOrBeforeLegalizer(DstTy))
351     return false;
352 
353   switch (Cast->getOpcode()) {
354   case TargetOpcode::G_TRUNC: {
355     MatchInfo = Input.trunc(DstTy.getScalarSizeInBits());
356     return true;
357   }
358   default:
359     return false;
360   }
361 }
362