xref: /llvm-project/llvm/lib/CodeGen/GlobalISel/CombinerHelperCasts.cpp (revision ee7ca0dddafb609090ad1789570c099d95c0afb6)
1 //===- CombinerHelperCasts.cpp---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements CombinerHelper combines for G_ANYEXT, G_SEXT, G_TRUNC,
// and G_ZEXT, including casts of G_SELECT, G_BUILD_VECTOR, binary operations,
// and integer constants.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
14 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
15 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
16 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
17 #include "llvm/CodeGen/GlobalISel/Utils.h"
18 #include "llvm/CodeGen/LowLevelTypeUtils.h"
19 #include "llvm/CodeGen/MachineOperand.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22 #include "llvm/Support/Casting.h"
23 
24 #define DEBUG_TYPE "gi-combiner"
25 
26 using namespace llvm;
27 
28 bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO,
29                                       BuildFnTy &MatchInfo) const {
30   GSext *Sext = cast<GSext>(getDefIgnoringCopies(MO.getReg(), MRI));
31   GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Sext->getSrcReg(), MRI));
32 
33   Register Dst = Sext->getReg(0);
34   Register Src = Trunc->getSrcReg();
35 
36   LLT DstTy = MRI.getType(Dst);
37   LLT SrcTy = MRI.getType(Src);
38 
39   if (DstTy == SrcTy) {
40     MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
41     return true;
42   }
43 
44   if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
45       isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
46     MatchInfo = [=](MachineIRBuilder &B) {
47       B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoSWrap);
48     };
49     return true;
50   }
51 
52   if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
53       isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}})) {
54     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
55     return true;
56   }
57 
58   return false;
59 }
60 
61 bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO,
62                                       BuildFnTy &MatchInfo) const {
63   GZext *Zext = cast<GZext>(getDefIgnoringCopies(MO.getReg(), MRI));
64   GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Zext->getSrcReg(), MRI));
65 
66   Register Dst = Zext->getReg(0);
67   Register Src = Trunc->getSrcReg();
68 
69   LLT DstTy = MRI.getType(Dst);
70   LLT SrcTy = MRI.getType(Src);
71 
72   if (DstTy == SrcTy) {
73     MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
74     return true;
75   }
76 
77   if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
78       isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
79     MatchInfo = [=](MachineIRBuilder &B) {
80       B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoUWrap);
81     };
82     return true;
83   }
84 
85   if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
86       isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) {
87     MatchInfo = [=](MachineIRBuilder &B) {
88       B.buildZExt(Dst, Src, MachineInstr::MIFlag::NonNeg);
89     };
90     return true;
91   }
92 
93   return false;
94 }
95 
96 bool CombinerHelper::matchNonNegZext(const MachineOperand &MO,
97                                      BuildFnTy &MatchInfo) const {
98   GZext *Zext = cast<GZext>(MRI.getVRegDef(MO.getReg()));
99 
100   Register Dst = Zext->getReg(0);
101   Register Src = Zext->getSrcReg();
102 
103   LLT DstTy = MRI.getType(Dst);
104   LLT SrcTy = MRI.getType(Src);
105   const auto &TLI = getTargetLowering();
106 
107   // Convert zext nneg to sext if sext is the preferred form for the target.
108   if (isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}}) &&
109       TLI.isSExtCheaperThanZExt(getMVTForLLT(SrcTy), getMVTForLLT(DstTy))) {
110     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
111     return true;
112   }
113 
114   return false;
115 }
116 
117 bool CombinerHelper::matchTruncateOfExt(const MachineInstr &Root,
118                                         const MachineInstr &ExtMI,
119                                         BuildFnTy &MatchInfo) const {
120   const GTrunc *Trunc = cast<GTrunc>(&Root);
121   const GExtOp *Ext = cast<GExtOp>(&ExtMI);
122 
123   if (!MRI.hasOneNonDBGUse(Ext->getReg(0)))
124     return false;
125 
126   Register Dst = Trunc->getReg(0);
127   Register Src = Ext->getSrcReg();
128   LLT DstTy = MRI.getType(Dst);
129   LLT SrcTy = MRI.getType(Src);
130 
131   if (SrcTy == DstTy) {
132     // The source and the destination are equally sized. We need to copy.
133     MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
134 
135     return true;
136   }
137 
138   if (SrcTy.getScalarSizeInBits() < DstTy.getScalarSizeInBits()) {
139     // If the source is smaller than the destination, we need to extend.
140 
141     if (!isLegalOrBeforeLegalizer({Ext->getOpcode(), {DstTy, SrcTy}}))
142       return false;
143 
144     MatchInfo = [=](MachineIRBuilder &B) {
145       B.buildInstr(Ext->getOpcode(), {Dst}, {Src});
146     };
147 
148     return true;
149   }
150 
151   if (SrcTy.getScalarSizeInBits() > DstTy.getScalarSizeInBits()) {
152     // If the source is larger than the destination, then we need to truncate.
153 
154     if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}}))
155       return false;
156 
157     MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Dst, Src); };
158 
159     return true;
160   }
161 
162   return false;
163 }
164 
165 bool CombinerHelper::isCastFree(unsigned Opcode, LLT ToTy, LLT FromTy) const {
166   const TargetLowering &TLI = getTargetLowering();
167   LLVMContext &Ctx = getContext();
168 
169   switch (Opcode) {
170   case TargetOpcode::G_ANYEXT:
171   case TargetOpcode::G_ZEXT:
172     return TLI.isZExtFree(FromTy, ToTy, Ctx);
173   case TargetOpcode::G_TRUNC:
174     return TLI.isTruncateFree(FromTy, ToTy, Ctx);
175   default:
176     return false;
177   }
178 }
179 
180 bool CombinerHelper::matchCastOfSelect(const MachineInstr &CastMI,
181                                        const MachineInstr &SelectMI,
182                                        BuildFnTy &MatchInfo) const {
183   const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
184   const GSelect *Select = cast<GSelect>(&SelectMI);
185 
186   if (!MRI.hasOneNonDBGUse(Select->getReg(0)))
187     return false;
188 
189   Register Dst = Cast->getReg(0);
190   LLT DstTy = MRI.getType(Dst);
191   LLT CondTy = MRI.getType(Select->getCondReg());
192   Register TrueReg = Select->getTrueReg();
193   Register FalseReg = Select->getFalseReg();
194   LLT SrcTy = MRI.getType(TrueReg);
195   Register Cond = Select->getCondReg();
196 
197   if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SELECT, {DstTy, CondTy}}))
198     return false;
199 
200   if (!isCastFree(Cast->getOpcode(), DstTy, SrcTy))
201     return false;
202 
203   MatchInfo = [=](MachineIRBuilder &B) {
204     auto True = B.buildInstr(Cast->getOpcode(), {DstTy}, {TrueReg});
205     auto False = B.buildInstr(Cast->getOpcode(), {DstTy}, {FalseReg});
206     B.buildSelect(Dst, Cond, True, False);
207   };
208 
209   return true;
210 }
211 
212 bool CombinerHelper::matchExtOfExt(const MachineInstr &FirstMI,
213                                    const MachineInstr &SecondMI,
214                                    BuildFnTy &MatchInfo) const {
215   const GExtOp *First = cast<GExtOp>(&FirstMI);
216   const GExtOp *Second = cast<GExtOp>(&SecondMI);
217 
218   Register Dst = First->getReg(0);
219   Register Src = Second->getSrcReg();
220   LLT DstTy = MRI.getType(Dst);
221   LLT SrcTy = MRI.getType(Src);
222 
223   if (!MRI.hasOneNonDBGUse(Second->getReg(0)))
224     return false;
225 
226   // ext of ext -> later ext
227   if (First->getOpcode() == Second->getOpcode() &&
228       isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) {
229     if (Second->getOpcode() == TargetOpcode::G_ZEXT) {
230       MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
231       if (Second->getFlag(MachineInstr::MIFlag::NonNeg))
232         Flag = MachineInstr::MIFlag::NonNeg;
233       MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); };
234       return true;
235     }
236     // not zext -> no flags
237     MatchInfo = [=](MachineIRBuilder &B) {
238       B.buildInstr(Second->getOpcode(), {Dst}, {Src});
239     };
240     return true;
241   }
242 
243   // anyext of sext/zext  -> sext/zext
244   // -> pick anyext as second ext, then ext of ext
245   if (First->getOpcode() == TargetOpcode::G_ANYEXT &&
246       isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) {
247     if (Second->getOpcode() == TargetOpcode::G_ZEXT) {
248       MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
249       if (Second->getFlag(MachineInstr::MIFlag::NonNeg))
250         Flag = MachineInstr::MIFlag::NonNeg;
251       MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); };
252       return true;
253     }
254     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
255     return true;
256   }
257 
258   // sext/zext of anyext -> sext/zext
259   // -> pick anyext as first ext, then ext of ext
260   if (Second->getOpcode() == TargetOpcode::G_ANYEXT &&
261       isLegalOrBeforeLegalizer({First->getOpcode(), {DstTy, SrcTy}})) {
262     if (First->getOpcode() == TargetOpcode::G_ZEXT) {
263       MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
264       if (First->getFlag(MachineInstr::MIFlag::NonNeg))
265         Flag = MachineInstr::MIFlag::NonNeg;
266       MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); };
267       return true;
268     }
269     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
270     return true;
271   }
272 
273   return false;
274 }
275 
276 bool CombinerHelper::matchCastOfBuildVector(const MachineInstr &CastMI,
277                                             const MachineInstr &BVMI,
278                                             BuildFnTy &MatchInfo) const {
279   const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
280   const GBuildVector *BV = cast<GBuildVector>(&BVMI);
281 
282   if (!MRI.hasOneNonDBGUse(BV->getReg(0)))
283     return false;
284 
285   Register Dst = Cast->getReg(0);
286   // The type of the new build vector.
287   LLT DstTy = MRI.getType(Dst);
288   // The scalar or element type of the new build vector.
289   LLT ElemTy = DstTy.getScalarType();
290   // The scalar or element type of the old build vector.
291   LLT InputElemTy = MRI.getType(BV->getReg(0)).getElementType();
292 
293   // Check legality of new build vector, the scalar casts, and profitability of
294   // the many casts.
295   if (!isLegalOrBeforeLegalizer(
296           {TargetOpcode::G_BUILD_VECTOR, {DstTy, ElemTy}}) ||
297       !isLegalOrBeforeLegalizer({Cast->getOpcode(), {ElemTy, InputElemTy}}) ||
298       !isCastFree(Cast->getOpcode(), ElemTy, InputElemTy))
299     return false;
300 
301   MatchInfo = [=](MachineIRBuilder &B) {
302     SmallVector<Register> Casts;
303     unsigned Elements = BV->getNumSources();
304     for (unsigned I = 0; I < Elements; ++I) {
305       auto CastI =
306           B.buildInstr(Cast->getOpcode(), {ElemTy}, {BV->getSourceReg(I)});
307       Casts.push_back(CastI.getReg(0));
308     }
309 
310     B.buildBuildVector(Dst, Casts);
311   };
312 
313   return true;
314 }
315 
316 bool CombinerHelper::matchNarrowBinop(const MachineInstr &TruncMI,
317                                       const MachineInstr &BinopMI,
318                                       BuildFnTy &MatchInfo) const {
319   const GTrunc *Trunc = cast<GTrunc>(&TruncMI);
320   const GBinOp *BinOp = cast<GBinOp>(&BinopMI);
321 
322   if (!MRI.hasOneNonDBGUse(BinOp->getReg(0)))
323     return false;
324 
325   Register Dst = Trunc->getReg(0);
326   LLT DstTy = MRI.getType(Dst);
327 
328   // Is narrow binop legal?
329   if (!isLegalOrBeforeLegalizer({BinOp->getOpcode(), {DstTy}}))
330     return false;
331 
332   MatchInfo = [=](MachineIRBuilder &B) {
333     auto LHS = B.buildTrunc(DstTy, BinOp->getLHSReg());
334     auto RHS = B.buildTrunc(DstTy, BinOp->getRHSReg());
335     B.buildInstr(BinOp->getOpcode(), {Dst}, {LHS, RHS});
336   };
337 
338   return true;
339 }
340 
341 bool CombinerHelper::matchCastOfInteger(const MachineInstr &CastMI,
342                                         APInt &MatchInfo) const {
343   const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
344 
345   APInt Input = getIConstantFromReg(Cast->getSrcReg(), MRI);
346 
347   LLT DstTy = MRI.getType(Cast->getReg(0));
348 
349   if (!isConstantLegalOrBeforeLegalizer(DstTy))
350     return false;
351 
352   switch (Cast->getOpcode()) {
353   case TargetOpcode::G_TRUNC: {
354     MatchInfo = Input.trunc(DstTy.getScalarSizeInBits());
355     return true;
356   }
357   default:
358     return false;
359   }
360 }
361