xref: /llvm-project/llvm/lib/CodeGen/GlobalISel/CombinerHelperCasts.cpp (revision db8c84fc7a75dd60bcfff7160b51e1a55e7e0f73)
1 //===- CombinerHelperCasts.cpp---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements CombinerHelper for G_ANYEXT, G_SEXT, G_TRUNC, and
10 // G_ZEXT
11 //
12 //===----------------------------------------------------------------------===//
13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
14 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
15 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
16 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
17 #include "llvm/CodeGen/GlobalISel/Utils.h"
18 #include "llvm/CodeGen/LowLevelTypeUtils.h"
19 #include "llvm/CodeGen/MachineOperand.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22 #include "llvm/Support/Casting.h"
23 
24 #define DEBUG_TYPE "gi-combiner"
25 
26 using namespace llvm;
27 
28 bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO,
29                                       BuildFnTy &MatchInfo) {
30   GSext *Sext = cast<GSext>(getDefIgnoringCopies(MO.getReg(), MRI));
31   GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Sext->getSrcReg(), MRI));
32 
33   Register Dst = Sext->getReg(0);
34   Register Src = Trunc->getSrcReg();
35 
36   LLT DstTy = MRI.getType(Dst);
37   LLT SrcTy = MRI.getType(Src);
38 
39   if (DstTy == SrcTy) {
40     MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
41     return true;
42   }
43 
44   if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
45       isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
46     MatchInfo = [=](MachineIRBuilder &B) {
47       B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoSWrap);
48     };
49     return true;
50   }
51 
52   if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
53       isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}})) {
54     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
55     return true;
56   }
57 
58   return false;
59 }
60 
61 bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO,
62                                       BuildFnTy &MatchInfo) {
63   GZext *Zext = cast<GZext>(getDefIgnoringCopies(MO.getReg(), MRI));
64   GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Zext->getSrcReg(), MRI));
65 
66   Register Dst = Zext->getReg(0);
67   Register Src = Trunc->getSrcReg();
68 
69   LLT DstTy = MRI.getType(Dst);
70   LLT SrcTy = MRI.getType(Src);
71 
72   if (DstTy == SrcTy) {
73     MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
74     return true;
75   }
76 
77   if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
78       isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
79     MatchInfo = [=](MachineIRBuilder &B) {
80       B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoUWrap);
81     };
82     return true;
83   }
84 
85   if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
86       isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) {
87     MatchInfo = [=](MachineIRBuilder &B) {
88       B.buildZExt(Dst, Src, MachineInstr::MIFlag::NonNeg);
89     };
90     return true;
91   }
92 
93   return false;
94 }
95 
96 bool CombinerHelper::matchNonNegZext(const MachineOperand &MO,
97                                      BuildFnTy &MatchInfo) {
98   GZext *Zext = cast<GZext>(MRI.getVRegDef(MO.getReg()));
99 
100   Register Dst = Zext->getReg(0);
101   Register Src = Zext->getSrcReg();
102 
103   LLT DstTy = MRI.getType(Dst);
104   LLT SrcTy = MRI.getType(Src);
105   const auto &TLI = getTargetLowering();
106 
107   // Convert zext nneg to sext if sext is the preferred form for the target.
108   if (isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}}) &&
109       TLI.isSExtCheaperThanZExt(getMVTForLLT(SrcTy), getMVTForLLT(DstTy))) {
110     MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
111     return true;
112   }
113 
114   return false;
115 }
116 
117 bool CombinerHelper::matchTruncateOfExt(const MachineInstr &Root,
118                                         const MachineInstr &ExtMI,
119                                         BuildFnTy &MatchInfo) {
120   const GTrunc *Trunc = cast<GTrunc>(&Root);
121   const GExtOp *Ext = cast<GExtOp>(&ExtMI);
122 
123   if (!MRI.hasOneNonDBGUse(Ext->getReg(0)))
124     return false;
125 
126   Register Dst = Trunc->getReg(0);
127   Register Src = Ext->getSrcReg();
128   LLT DstTy = MRI.getType(Dst);
129   LLT SrcTy = MRI.getType(Src);
130 
131   if (SrcTy == DstTy) {
132     // The source and the destination are equally sized. We need to copy.
133     MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
134 
135     return true;
136   }
137 
138   if (SrcTy.getScalarSizeInBits() < DstTy.getScalarSizeInBits()) {
139     // If the source is smaller than the destination, we need to extend.
140 
141     if (!isLegalOrBeforeLegalizer({Ext->getOpcode(), {DstTy, SrcTy}}))
142       return false;
143 
144     MatchInfo = [=](MachineIRBuilder &B) {
145       B.buildInstr(Ext->getOpcode(), {Dst}, {Src});
146     };
147 
148     return true;
149   }
150 
151   if (SrcTy.getScalarSizeInBits() > DstTy.getScalarSizeInBits()) {
152     // If the source is larger than the destination, then we need to truncate.
153 
154     if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}}))
155       return false;
156 
157     MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Dst, Src); };
158 
159     return true;
160   }
161 
162   return false;
163 }
164 
165 bool CombinerHelper::isCastFree(unsigned Opcode, LLT ToTy, LLT FromTy) const {
166   const TargetLowering &TLI = getTargetLowering();
167   const DataLayout &DL = getDataLayout();
168   LLVMContext &Ctx = getContext();
169 
170   switch (Opcode) {
171   case TargetOpcode::G_ANYEXT:
172   case TargetOpcode::G_ZEXT:
173     return TLI.isZExtFree(FromTy, ToTy, DL, Ctx);
174   case TargetOpcode::G_TRUNC:
175     return TLI.isTruncateFree(FromTy, ToTy, DL, Ctx);
176   default:
177     return false;
178   }
179 }
180 
181 bool CombinerHelper::matchCastOfSelect(const MachineInstr &CastMI,
182                                        const MachineInstr &SelectMI,
183                                        BuildFnTy &MatchInfo) {
184   const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
185   const GSelect *Select = cast<GSelect>(&SelectMI);
186 
187   if (!MRI.hasOneNonDBGUse(Select->getReg(0)))
188     return false;
189 
190   Register Dst = Cast->getReg(0);
191   LLT DstTy = MRI.getType(Dst);
192   LLT CondTy = MRI.getType(Select->getCondReg());
193   Register TrueReg = Select->getTrueReg();
194   Register FalseReg = Select->getFalseReg();
195   LLT SrcTy = MRI.getType(TrueReg);
196   Register Cond = Select->getCondReg();
197 
198   if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SELECT, {DstTy, CondTy}}))
199     return false;
200 
201   if (!isCastFree(Cast->getOpcode(), DstTy, SrcTy))
202     return false;
203 
204   MatchInfo = [=](MachineIRBuilder &B) {
205     auto True = B.buildInstr(Cast->getOpcode(), {DstTy}, {TrueReg});
206     auto False = B.buildInstr(Cast->getOpcode(), {DstTy}, {FalseReg});
207     B.buildSelect(Dst, Cond, True, False);
208   };
209 
210   return true;
211 }
212