1 //===- CombinerHelperCasts.cpp---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements CombinerHelper for G_ANYEXT, G_SEXT, G_TRUNC, and 10 // G_ZEXT 11 // 12 //===----------------------------------------------------------------------===// 13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" 14 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" 15 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" 16 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 17 #include "llvm/CodeGen/GlobalISel/Utils.h" 18 #include "llvm/CodeGen/LowLevelTypeUtils.h" 19 #include "llvm/CodeGen/MachineOperand.h" 20 #include "llvm/CodeGen/MachineRegisterInfo.h" 21 #include "llvm/CodeGen/TargetOpcodes.h" 22 #include "llvm/Support/Casting.h" 23 24 #define DEBUG_TYPE "gi-combiner" 25 26 using namespace llvm; 27 28 bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO, 29 BuildFnTy &MatchInfo) { 30 GSext *Sext = cast<GSext>(getDefIgnoringCopies(MO.getReg(), MRI)); 31 GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Sext->getSrcReg(), MRI)); 32 33 Register Dst = Sext->getReg(0); 34 Register Src = Trunc->getSrcReg(); 35 36 LLT DstTy = MRI.getType(Dst); 37 LLT SrcTy = MRI.getType(Src); 38 39 if (DstTy == SrcTy) { 40 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); }; 41 return true; 42 } 43 44 if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() && 45 isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) { 46 MatchInfo = [=](MachineIRBuilder &B) { 47 B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoSWrap); 48 }; 49 return true; 50 } 51 52 if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() && 53 isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}})) { 54 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; 55 return true; 56 } 57 58 return false; 59 } 60 61 bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO, 62 BuildFnTy &MatchInfo) { 63 GZext *Zext = cast<GZext>(getDefIgnoringCopies(MO.getReg(), MRI)); 64 GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Zext->getSrcReg(), MRI)); 65 66 Register Dst = Zext->getReg(0); 67 Register Src = Trunc->getSrcReg(); 68 69 LLT DstTy = MRI.getType(Dst); 70 LLT SrcTy = MRI.getType(Src); 71 72 if (DstTy == SrcTy) { 73 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); }; 74 return true; 75 } 76 77 if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() && 78 isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) { 79 MatchInfo = [=](MachineIRBuilder &B) { 80 B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoUWrap); 81 }; 82 return true; 83 } 84 85 if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() && 86 isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) { 87 MatchInfo = [=](MachineIRBuilder &B) { 88 B.buildZExt(Dst, Src, MachineInstr::MIFlag::NonNeg); 89 }; 90 return true; 91 } 92 93 return false; 94 } 95 96 bool CombinerHelper::matchNonNegZext(const MachineOperand &MO, 97 BuildFnTy &MatchInfo) { 98 GZext *Zext = cast<GZext>(MRI.getVRegDef(MO.getReg())); 99 100 Register Dst = Zext->getReg(0); 101 Register Src = Zext->getSrcReg(); 102 103 LLT DstTy = MRI.getType(Dst); 104 LLT SrcTy = MRI.getType(Src); 105 const auto &TLI = getTargetLowering(); 106 107 // Convert zext nneg to sext if sext is the preferred form for the target. 108 if (isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}}) && 109 TLI.isSExtCheaperThanZExt(getMVTForLLT(SrcTy), getMVTForLLT(DstTy))) { 110 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; 111 return true; 112 } 113 114 return false; 115 } 116 117 bool CombinerHelper::matchTruncateOfExt(const MachineInstr &Root, 118 const MachineInstr &ExtMI, 119 BuildFnTy &MatchInfo) { 120 const GTrunc *Trunc = cast<GTrunc>(&Root); 121 const GExtOp *Ext = cast<GExtOp>(&ExtMI); 122 123 if (!MRI.hasOneNonDBGUse(Ext->getReg(0))) 124 return false; 125 126 Register Dst = Trunc->getReg(0); 127 Register Src = Ext->getSrcReg(); 128 LLT DstTy = MRI.getType(Dst); 129 LLT SrcTy = MRI.getType(Src); 130 131 if (SrcTy == DstTy) { 132 // The source and the destination are equally sized. We need to copy. 133 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); }; 134 135 return true; 136 } 137 138 if (SrcTy.getScalarSizeInBits() < DstTy.getScalarSizeInBits()) { 139 // If the source is smaller than the destination, we need to extend. 140 141 if (!isLegalOrBeforeLegalizer({Ext->getOpcode(), {DstTy, SrcTy}})) 142 return false; 143 144 MatchInfo = [=](MachineIRBuilder &B) { 145 B.buildInstr(Ext->getOpcode(), {Dst}, {Src}); 146 }; 147 148 return true; 149 } 150 151 if (SrcTy.getScalarSizeInBits() > DstTy.getScalarSizeInBits()) { 152 // If the source is larger than the destination, then we need to truncate. 153 154 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) 155 return false; 156 157 MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Dst, Src); }; 158 159 return true; 160 } 161 162 return false; 163 } 164 165 bool CombinerHelper::isCastFree(unsigned Opcode, LLT ToTy, LLT FromTy) const { 166 const TargetLowering &TLI = getTargetLowering(); 167 const DataLayout &DL = getDataLayout(); 168 LLVMContext &Ctx = getContext(); 169 170 switch (Opcode) { 171 case TargetOpcode::G_ANYEXT: 172 case TargetOpcode::G_ZEXT: 173 return TLI.isZExtFree(FromTy, ToTy, DL, Ctx); 174 case TargetOpcode::G_TRUNC: 175 return TLI.isTruncateFree(FromTy, ToTy, DL, Ctx); 176 default: 177 return false; 178 } 179 } 180 181 bool CombinerHelper::matchCastOfSelect(const MachineInstr &CastMI, 182 const MachineInstr &SelectMI, 183 BuildFnTy &MatchInfo) { 184 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI); 185 const GSelect *Select = cast<GSelect>(&SelectMI); 186 187 if (!MRI.hasOneNonDBGUse(Select->getReg(0))) 188 return false; 189 190 Register Dst = Cast->getReg(0); 191 LLT DstTy = MRI.getType(Dst); 192 LLT CondTy = MRI.getType(Select->getCondReg()); 193 Register TrueReg = Select->getTrueReg(); 194 Register FalseReg = Select->getFalseReg(); 195 LLT SrcTy = MRI.getType(TrueReg); 196 Register Cond = Select->getCondReg(); 197 198 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SELECT, {DstTy, CondTy}})) 199 return false; 200 201 if (!isCastFree(Cast->getOpcode(), DstTy, SrcTy)) 202 return false; 203 204 MatchInfo = [=](MachineIRBuilder &B) { 205 auto True = B.buildInstr(Cast->getOpcode(), {DstTy}, {TrueReg}); 206 auto False = B.buildInstr(Cast->getOpcode(), {DstTy}, {FalseReg}); 207 B.buildSelect(Dst, Cond, True, False); 208 }; 209 210 return true; 211 } 212 213 bool CombinerHelper::matchExtOfExt(const MachineInstr &FirstMI, 214 const MachineInstr &SecondMI, 215 BuildFnTy &MatchInfo) { 216 const GExtOp *First = cast<GExtOp>(&FirstMI); 217 const GExtOp *Second = cast<GExtOp>(&SecondMI); 218 219 Register Dst = First->getReg(0); 220 Register Src = Second->getSrcReg(); 221 LLT DstTy = MRI.getType(Dst); 222 LLT SrcTy = MRI.getType(Src); 223 224 if (!MRI.hasOneNonDBGUse(Second->getReg(0))) 225 return false; 226 227 // ext of ext -> later ext 228 if (First->getOpcode() == Second->getOpcode() && 229 isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) { 230 if (Second->getOpcode() == TargetOpcode::G_ZEXT) { 231 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; 232 if (Second->getFlag(MachineInstr::MIFlag::NonNeg)) 233 Flag = MachineInstr::MIFlag::NonNeg; 234 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); }; 235 return true; 236 } 237 // not zext -> no flags 238 MatchInfo = [=](MachineIRBuilder &B) { 239 B.buildInstr(Second->getOpcode(), {Dst}, {Src}); 240 }; 241 return true; 242 } 243 244 // anyext of sext/zext -> sext/zext 245 // -> pick anyext as second ext, then ext of ext 246 if (First->getOpcode() == TargetOpcode::G_ANYEXT && 247 isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) { 248 if (Second->getOpcode() == TargetOpcode::G_ZEXT) { 249 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; 250 if (Second->getFlag(MachineInstr::MIFlag::NonNeg)) 251 Flag = MachineInstr::MIFlag::NonNeg; 252 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); }; 253 return true; 254 } 255 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; 256 return true; 257 } 258 259 // sext/zext of anyext -> sext/zext 260 // -> pick anyext as first ext, then ext of ext 261 if (Second->getOpcode() == TargetOpcode::G_ANYEXT && 262 isLegalOrBeforeLegalizer({First->getOpcode(), {DstTy, SrcTy}})) { 263 if (First->getOpcode() == TargetOpcode::G_ZEXT) { 264 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; 265 if (First->getFlag(MachineInstr::MIFlag::NonNeg)) 266 Flag = MachineInstr::MIFlag::NonNeg; 267 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); }; 268 return true; 269 } 270 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; 271 return true; 272 } 273 274 return false; 275 } 276 277 bool CombinerHelper::matchCastOfBuildVector(const MachineInstr &CastMI, 278 const MachineInstr &BVMI, 279 BuildFnTy &MatchInfo) { 280 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI); 281 const GBuildVector *BV = cast<GBuildVector>(&BVMI); 282 283 if (!MRI.hasOneNonDBGUse(BV->getReg(0))) 284 return false; 285 286 Register Dst = Cast->getReg(0); 287 // The type of the new build vector. 288 LLT DstTy = MRI.getType(Dst); 289 // The scalar or element type of the new build vector. 290 LLT ElemTy = DstTy.getScalarType(); 291 // The scalar or element type of the old build vector. 292 LLT InputElemTy = MRI.getType(BV->getReg(0)).getElementType(); 293 294 // Check legality of new build vector, the scalar casts, and profitability of 295 // the many casts. 296 if (!isLegalOrBeforeLegalizer( 297 {TargetOpcode::G_BUILD_VECTOR, {DstTy, ElemTy}}) || 298 !isLegalOrBeforeLegalizer({Cast->getOpcode(), {ElemTy, InputElemTy}}) || 299 !isCastFree(Cast->getOpcode(), ElemTy, InputElemTy)) 300 return false; 301 302 MatchInfo = [=](MachineIRBuilder &B) { 303 SmallVector<Register> Casts; 304 unsigned Elements = BV->getNumSources(); 305 for (unsigned I = 0; I < Elements; ++I) { 306 auto CastI = 307 B.buildInstr(Cast->getOpcode(), {ElemTy}, {BV->getSourceReg(I)}); 308 Casts.push_back(CastI.getReg(0)); 309 } 310 311 B.buildBuildVector(Dst, Casts); 312 }; 313 314 return true; 315 } 316 317 bool CombinerHelper::matchNarrowBinop(const MachineInstr &TruncMI, 318 const MachineInstr &BinopMI, 319 BuildFnTy &MatchInfo) { 320 const GTrunc *Trunc = cast<GTrunc>(&TruncMI); 321 const GBinOp *BinOp = cast<GBinOp>(&BinopMI); 322 323 if (!MRI.hasOneNonDBGUse(BinOp->getReg(0))) 324 return false; 325 326 Register Dst = Trunc->getReg(0); 327 LLT DstTy = MRI.getType(Dst); 328 329 // Is narrow binop legal? 330 if (!isLegalOrBeforeLegalizer({BinOp->getOpcode(), {DstTy}})) 331 return false; 332 333 MatchInfo = [=](MachineIRBuilder &B) { 334 auto LHS = B.buildTrunc(DstTy, BinOp->getLHSReg()); 335 auto RHS = B.buildTrunc(DstTy, BinOp->getRHSReg()); 336 B.buildInstr(BinOp->getOpcode(), {Dst}, {LHS, RHS}); 337 }; 338 339 return true; 340 } 341 342 bool CombinerHelper::matchCastOfInteger(const MachineInstr &CastMI, 343 APInt &MatchInfo) { 344 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI); 345 346 APInt Input = getIConstantFromReg(Cast->getSrcReg(), MRI); 347 348 LLT DstTy = MRI.getType(Cast->getReg(0)); 349 350 if (!isConstantLegalOrBeforeLegalizer(DstTy)) 351 return false; 352 353 switch (Cast->getOpcode()) { 354 case TargetOpcode::G_TRUNC: { 355 MatchInfo = Input.trunc(DstTy.getScalarSizeInBits()); 356 return true; 357 } 358 default: 359 return false; 360 } 361 } 362