1 //===- CombinerHelperCasts.cpp---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements CombinerHelper for G_ANYEXT, G_SEXT, G_TRUNC, and 10 // G_ZEXT 11 // 12 //===----------------------------------------------------------------------===// 13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" 14 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" 15 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" 16 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 17 #include "llvm/CodeGen/GlobalISel/Utils.h" 18 #include "llvm/CodeGen/LowLevelTypeUtils.h" 19 #include "llvm/CodeGen/MachineOperand.h" 20 #include "llvm/CodeGen/MachineRegisterInfo.h" 21 #include "llvm/CodeGen/TargetOpcodes.h" 22 #include "llvm/Support/Casting.h" 23 24 #define DEBUG_TYPE "gi-combiner" 25 26 using namespace llvm; 27 28 bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO, 29 BuildFnTy &MatchInfo) const { 30 GSext *Sext = cast<GSext>(getDefIgnoringCopies(MO.getReg(), MRI)); 31 GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Sext->getSrcReg(), MRI)); 32 33 Register Dst = Sext->getReg(0); 34 Register Src = Trunc->getSrcReg(); 35 36 LLT DstTy = MRI.getType(Dst); 37 LLT SrcTy = MRI.getType(Src); 38 39 if (DstTy == SrcTy) { 40 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); }; 41 return true; 42 } 43 44 if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() && 45 isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) { 46 MatchInfo = [=](MachineIRBuilder &B) { 47 B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoSWrap); 48 }; 49 return true; 50 } 51 52 if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() && 53 isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}})) { 54 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; 55 return true; 56 } 57 58 return false; 59 } 60 61 bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO, 62 BuildFnTy &MatchInfo) const { 63 GZext *Zext = cast<GZext>(getDefIgnoringCopies(MO.getReg(), MRI)); 64 GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Zext->getSrcReg(), MRI)); 65 66 Register Dst = Zext->getReg(0); 67 Register Src = Trunc->getSrcReg(); 68 69 LLT DstTy = MRI.getType(Dst); 70 LLT SrcTy = MRI.getType(Src); 71 72 if (DstTy == SrcTy) { 73 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); }; 74 return true; 75 } 76 77 if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() && 78 isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) { 79 MatchInfo = [=](MachineIRBuilder &B) { 80 B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoUWrap); 81 }; 82 return true; 83 } 84 85 if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() && 86 isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) { 87 MatchInfo = [=](MachineIRBuilder &B) { 88 B.buildZExt(Dst, Src, MachineInstr::MIFlag::NonNeg); 89 }; 90 return true; 91 } 92 93 return false; 94 } 95 96 bool CombinerHelper::matchNonNegZext(const MachineOperand &MO, 97 BuildFnTy &MatchInfo) const { 98 GZext *Zext = cast<GZext>(MRI.getVRegDef(MO.getReg())); 99 100 Register Dst = Zext->getReg(0); 101 Register Src = Zext->getSrcReg(); 102 103 LLT DstTy = MRI.getType(Dst); 104 LLT SrcTy = MRI.getType(Src); 105 const auto &TLI = getTargetLowering(); 106 107 // Convert zext nneg to sext if sext is the preferred form for the target. 108 if (isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}}) && 109 TLI.isSExtCheaperThanZExt(getMVTForLLT(SrcTy), getMVTForLLT(DstTy))) { 110 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; 111 return true; 112 } 113 114 return false; 115 } 116 117 bool CombinerHelper::matchTruncateOfExt(const MachineInstr &Root, 118 const MachineInstr &ExtMI, 119 BuildFnTy &MatchInfo) const { 120 const GTrunc *Trunc = cast<GTrunc>(&Root); 121 const GExtOp *Ext = cast<GExtOp>(&ExtMI); 122 123 if (!MRI.hasOneNonDBGUse(Ext->getReg(0))) 124 return false; 125 126 Register Dst = Trunc->getReg(0); 127 Register Src = Ext->getSrcReg(); 128 LLT DstTy = MRI.getType(Dst); 129 LLT SrcTy = MRI.getType(Src); 130 131 if (SrcTy == DstTy) { 132 // The source and the destination are equally sized. We need to copy. 133 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); }; 134 135 return true; 136 } 137 138 if (SrcTy.getScalarSizeInBits() < DstTy.getScalarSizeInBits()) { 139 // If the source is smaller than the destination, we need to extend. 140 141 if (!isLegalOrBeforeLegalizer({Ext->getOpcode(), {DstTy, SrcTy}})) 142 return false; 143 144 MatchInfo = [=](MachineIRBuilder &B) { 145 B.buildInstr(Ext->getOpcode(), {Dst}, {Src}); 146 }; 147 148 return true; 149 } 150 151 if (SrcTy.getScalarSizeInBits() > DstTy.getScalarSizeInBits()) { 152 // If the source is larger than the destination, then we need to truncate. 153 154 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) 155 return false; 156 157 MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Dst, Src); }; 158 159 return true; 160 } 161 162 return false; 163 } 164 165 bool CombinerHelper::isCastFree(unsigned Opcode, LLT ToTy, LLT FromTy) const { 166 const TargetLowering &TLI = getTargetLowering(); 167 LLVMContext &Ctx = getContext(); 168 169 switch (Opcode) { 170 case TargetOpcode::G_ANYEXT: 171 case TargetOpcode::G_ZEXT: 172 return TLI.isZExtFree(FromTy, ToTy, Ctx); 173 case TargetOpcode::G_TRUNC: 174 return TLI.isTruncateFree(FromTy, ToTy, Ctx); 175 default: 176 return false; 177 } 178 } 179 180 bool CombinerHelper::matchCastOfSelect(const MachineInstr &CastMI, 181 const MachineInstr &SelectMI, 182 BuildFnTy &MatchInfo) const { 183 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI); 184 const GSelect *Select = cast<GSelect>(&SelectMI); 185 186 if (!MRI.hasOneNonDBGUse(Select->getReg(0))) 187 return false; 188 189 Register Dst = Cast->getReg(0); 190 LLT DstTy = MRI.getType(Dst); 191 LLT CondTy = MRI.getType(Select->getCondReg()); 192 Register TrueReg = Select->getTrueReg(); 193 Register FalseReg = Select->getFalseReg(); 194 LLT SrcTy = MRI.getType(TrueReg); 195 Register Cond = Select->getCondReg(); 196 197 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SELECT, {DstTy, CondTy}})) 198 return false; 199 200 if (!isCastFree(Cast->getOpcode(), DstTy, SrcTy)) 201 return false; 202 203 MatchInfo = [=](MachineIRBuilder &B) { 204 auto True = B.buildInstr(Cast->getOpcode(), {DstTy}, {TrueReg}); 205 auto False = B.buildInstr(Cast->getOpcode(), {DstTy}, {FalseReg}); 206 B.buildSelect(Dst, Cond, True, False); 207 }; 208 209 return true; 210 } 211 212 bool CombinerHelper::matchExtOfExt(const MachineInstr &FirstMI, 213 const MachineInstr &SecondMI, 214 BuildFnTy &MatchInfo) const { 215 const GExtOp *First = cast<GExtOp>(&FirstMI); 216 const GExtOp *Second = cast<GExtOp>(&SecondMI); 217 218 Register Dst = First->getReg(0); 219 Register Src = Second->getSrcReg(); 220 LLT DstTy = MRI.getType(Dst); 221 LLT SrcTy = MRI.getType(Src); 222 223 if (!MRI.hasOneNonDBGUse(Second->getReg(0))) 224 return false; 225 226 // ext of ext -> later ext 227 if (First->getOpcode() == Second->getOpcode() && 228 isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) { 229 if (Second->getOpcode() == TargetOpcode::G_ZEXT) { 230 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; 231 if (Second->getFlag(MachineInstr::MIFlag::NonNeg)) 232 Flag = MachineInstr::MIFlag::NonNeg; 233 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); }; 234 return true; 235 } 236 // not zext -> no flags 237 MatchInfo = [=](MachineIRBuilder &B) { 238 B.buildInstr(Second->getOpcode(), {Dst}, {Src}); 239 }; 240 return true; 241 } 242 243 // anyext of sext/zext -> sext/zext 244 // -> pick anyext as second ext, then ext of ext 245 if (First->getOpcode() == TargetOpcode::G_ANYEXT && 246 isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) { 247 if (Second->getOpcode() == TargetOpcode::G_ZEXT) { 248 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; 249 if (Second->getFlag(MachineInstr::MIFlag::NonNeg)) 250 Flag = MachineInstr::MIFlag::NonNeg; 251 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); }; 252 return true; 253 } 254 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; 255 return true; 256 } 257 258 // sext/zext of anyext -> sext/zext 259 // -> pick anyext as first ext, then ext of ext 260 if (Second->getOpcode() == TargetOpcode::G_ANYEXT && 261 isLegalOrBeforeLegalizer({First->getOpcode(), {DstTy, SrcTy}})) { 262 if (First->getOpcode() == TargetOpcode::G_ZEXT) { 263 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; 264 if (First->getFlag(MachineInstr::MIFlag::NonNeg)) 265 Flag = MachineInstr::MIFlag::NonNeg; 266 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); }; 267 return true; 268 } 269 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; 270 return true; 271 } 272 273 return false; 274 } 275 276 bool CombinerHelper::matchCastOfBuildVector(const MachineInstr &CastMI, 277 const MachineInstr &BVMI, 278 BuildFnTy &MatchInfo) const { 279 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI); 280 const GBuildVector *BV = cast<GBuildVector>(&BVMI); 281 282 if (!MRI.hasOneNonDBGUse(BV->getReg(0))) 283 return false; 284 285 Register Dst = Cast->getReg(0); 286 // The type of the new build vector. 287 LLT DstTy = MRI.getType(Dst); 288 // The scalar or element type of the new build vector. 289 LLT ElemTy = DstTy.getScalarType(); 290 // The scalar or element type of the old build vector. 291 LLT InputElemTy = MRI.getType(BV->getReg(0)).getElementType(); 292 293 // Check legality of new build vector, the scalar casts, and profitability of 294 // the many casts. 295 if (!isLegalOrBeforeLegalizer( 296 {TargetOpcode::G_BUILD_VECTOR, {DstTy, ElemTy}}) || 297 !isLegalOrBeforeLegalizer({Cast->getOpcode(), {ElemTy, InputElemTy}}) || 298 !isCastFree(Cast->getOpcode(), ElemTy, InputElemTy)) 299 return false; 300 301 MatchInfo = [=](MachineIRBuilder &B) { 302 SmallVector<Register> Casts; 303 unsigned Elements = BV->getNumSources(); 304 for (unsigned I = 0; I < Elements; ++I) { 305 auto CastI = 306 B.buildInstr(Cast->getOpcode(), {ElemTy}, {BV->getSourceReg(I)}); 307 Casts.push_back(CastI.getReg(0)); 308 } 309 310 B.buildBuildVector(Dst, Casts); 311 }; 312 313 return true; 314 } 315 316 bool CombinerHelper::matchNarrowBinop(const MachineInstr &TruncMI, 317 const MachineInstr &BinopMI, 318 BuildFnTy &MatchInfo) const { 319 const GTrunc *Trunc = cast<GTrunc>(&TruncMI); 320 const GBinOp *BinOp = cast<GBinOp>(&BinopMI); 321 322 if (!MRI.hasOneNonDBGUse(BinOp->getReg(0))) 323 return false; 324 325 Register Dst = Trunc->getReg(0); 326 LLT DstTy = MRI.getType(Dst); 327 328 // Is narrow binop legal? 329 if (!isLegalOrBeforeLegalizer({BinOp->getOpcode(), {DstTy}})) 330 return false; 331 332 MatchInfo = [=](MachineIRBuilder &B) { 333 auto LHS = B.buildTrunc(DstTy, BinOp->getLHSReg()); 334 auto RHS = B.buildTrunc(DstTy, BinOp->getRHSReg()); 335 B.buildInstr(BinOp->getOpcode(), {Dst}, {LHS, RHS}); 336 }; 337 338 return true; 339 } 340 341 bool CombinerHelper::matchCastOfInteger(const MachineInstr &CastMI, 342 APInt &MatchInfo) const { 343 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI); 344 345 APInt Input = getIConstantFromReg(Cast->getSrcReg(), MRI); 346 347 LLT DstTy = MRI.getType(Cast->getReg(0)); 348 349 if (!isConstantLegalOrBeforeLegalizer(DstTy)) 350 return false; 351 352 switch (Cast->getOpcode()) { 353 case TargetOpcode::G_TRUNC: { 354 MatchInfo = Input.trunc(DstTy.getScalarSizeInBits()); 355 return true; 356 } 357 default: 358 return false; 359 } 360 } 361