//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together, connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds a reference to the root Instruction, and the root node that should
// replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (internal instructions should only have uses within their containing node).
46*bdd1243dSDimitry Andric // A Node also contains the rotation and operation type that it represents. 47*bdd1243dSDimitry Andric // Operands contains pointers to other CompositeNodes, acting as the edges in 48*bdd1243dSDimitry Andric // the graph. ReplacementValue is the transformed Value* that has been emitted 49*bdd1243dSDimitry Andric // to the IR. 50*bdd1243dSDimitry Andric // 51*bdd1243dSDimitry Andric // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and 52*bdd1243dSDimitry Andric // ReplacementValue fields of that Node are relevant, where the ReplacementValue 53*bdd1243dSDimitry Andric // should be pre-populated. 54*bdd1243dSDimitry Andric // 55*bdd1243dSDimitry Andric //===----------------------------------------------------------------------===// 56*bdd1243dSDimitry Andric 57*bdd1243dSDimitry Andric #include "llvm/CodeGen/ComplexDeinterleavingPass.h" 58*bdd1243dSDimitry Andric #include "llvm/ADT/Statistic.h" 59*bdd1243dSDimitry Andric #include "llvm/Analysis/TargetLibraryInfo.h" 60*bdd1243dSDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 61*bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetLowering.h" 62*bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 63*bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h" 64*bdd1243dSDimitry Andric #include "llvm/IR/IRBuilder.h" 65*bdd1243dSDimitry Andric #include "llvm/InitializePasses.h" 66*bdd1243dSDimitry Andric #include "llvm/Target/TargetMachine.h" 67*bdd1243dSDimitry Andric #include "llvm/Transforms/Utils/Local.h" 68*bdd1243dSDimitry Andric #include <algorithm> 69*bdd1243dSDimitry Andric 70*bdd1243dSDimitry Andric using namespace llvm; 71*bdd1243dSDimitry Andric using namespace PatternMatch; 72*bdd1243dSDimitry Andric 73*bdd1243dSDimitry Andric #define DEBUG_TYPE "complex-deinterleaving" 74*bdd1243dSDimitry Andric 75*bdd1243dSDimitry Andric STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); 76*bdd1243dSDimitry 
Andric 77*bdd1243dSDimitry Andric static cl::opt<bool> ComplexDeinterleavingEnabled( 78*bdd1243dSDimitry Andric "enable-complex-deinterleaving", 79*bdd1243dSDimitry Andric cl::desc("Enable generation of complex instructions"), cl::init(true), 80*bdd1243dSDimitry Andric cl::Hidden); 81*bdd1243dSDimitry Andric 82*bdd1243dSDimitry Andric /// Checks the given mask, and determines whether said mask is interleaving. 83*bdd1243dSDimitry Andric /// 84*bdd1243dSDimitry Andric /// To be interleaving, a mask must alternate between `i` and `i + (Length / 85*bdd1243dSDimitry Andric /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a 86*bdd1243dSDimitry Andric /// 4x vector interleaving mask would be <0, 2, 1, 3>). 87*bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask); 88*bdd1243dSDimitry Andric 89*bdd1243dSDimitry Andric /// Checks the given mask, and determines whether said mask is deinterleaving. 90*bdd1243dSDimitry Andric /// 91*bdd1243dSDimitry Andric /// To be deinterleaving, a mask must increment in steps of 2, and either start 92*bdd1243dSDimitry Andric /// with 0 or 1. 93*bdd1243dSDimitry Andric /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or 94*bdd1243dSDimitry Andric /// <1, 3, 5, 7>). 
95*bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask); 96*bdd1243dSDimitry Andric 97*bdd1243dSDimitry Andric namespace { 98*bdd1243dSDimitry Andric 99*bdd1243dSDimitry Andric class ComplexDeinterleavingLegacyPass : public FunctionPass { 100*bdd1243dSDimitry Andric public: 101*bdd1243dSDimitry Andric static char ID; 102*bdd1243dSDimitry Andric 103*bdd1243dSDimitry Andric ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) 104*bdd1243dSDimitry Andric : FunctionPass(ID), TM(TM) { 105*bdd1243dSDimitry Andric initializeComplexDeinterleavingLegacyPassPass( 106*bdd1243dSDimitry Andric *PassRegistry::getPassRegistry()); 107*bdd1243dSDimitry Andric } 108*bdd1243dSDimitry Andric 109*bdd1243dSDimitry Andric StringRef getPassName() const override { 110*bdd1243dSDimitry Andric return "Complex Deinterleaving Pass"; 111*bdd1243dSDimitry Andric } 112*bdd1243dSDimitry Andric 113*bdd1243dSDimitry Andric bool runOnFunction(Function &F) override; 114*bdd1243dSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 115*bdd1243dSDimitry Andric AU.addRequired<TargetLibraryInfoWrapperPass>(); 116*bdd1243dSDimitry Andric AU.setPreservesCFG(); 117*bdd1243dSDimitry Andric } 118*bdd1243dSDimitry Andric 119*bdd1243dSDimitry Andric private: 120*bdd1243dSDimitry Andric const TargetMachine *TM; 121*bdd1243dSDimitry Andric }; 122*bdd1243dSDimitry Andric 123*bdd1243dSDimitry Andric class ComplexDeinterleavingGraph; 124*bdd1243dSDimitry Andric struct ComplexDeinterleavingCompositeNode { 125*bdd1243dSDimitry Andric 126*bdd1243dSDimitry Andric ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, 127*bdd1243dSDimitry Andric Instruction *R, Instruction *I) 128*bdd1243dSDimitry Andric : Operation(Op), Real(R), Imag(I) {} 129*bdd1243dSDimitry Andric 130*bdd1243dSDimitry Andric private: 131*bdd1243dSDimitry Andric friend class ComplexDeinterleavingGraph; 132*bdd1243dSDimitry Andric using NodePtr = 
std::shared_ptr<ComplexDeinterleavingCompositeNode>; 133*bdd1243dSDimitry Andric using RawNodePtr = ComplexDeinterleavingCompositeNode *; 134*bdd1243dSDimitry Andric 135*bdd1243dSDimitry Andric public: 136*bdd1243dSDimitry Andric ComplexDeinterleavingOperation Operation; 137*bdd1243dSDimitry Andric Instruction *Real; 138*bdd1243dSDimitry Andric Instruction *Imag; 139*bdd1243dSDimitry Andric 140*bdd1243dSDimitry Andric // Instructions that should only exist within this node, there should be no 141*bdd1243dSDimitry Andric // users of these instructions outside the node. An example of these would be 142*bdd1243dSDimitry Andric // the multiply instructions of a partial multiply operation. 143*bdd1243dSDimitry Andric SmallVector<Instruction *> InternalInstructions; 144*bdd1243dSDimitry Andric ComplexDeinterleavingRotation Rotation; 145*bdd1243dSDimitry Andric SmallVector<RawNodePtr> Operands; 146*bdd1243dSDimitry Andric Value *ReplacementNode = nullptr; 147*bdd1243dSDimitry Andric 148*bdd1243dSDimitry Andric void addInstruction(Instruction *I) { InternalInstructions.push_back(I); } 149*bdd1243dSDimitry Andric void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } 150*bdd1243dSDimitry Andric 151*bdd1243dSDimitry Andric bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions); 152*bdd1243dSDimitry Andric 153*bdd1243dSDimitry Andric void dump() { dump(dbgs()); } 154*bdd1243dSDimitry Andric void dump(raw_ostream &OS) { 155*bdd1243dSDimitry Andric auto PrintValue = [&](Value *V) { 156*bdd1243dSDimitry Andric if (V) { 157*bdd1243dSDimitry Andric OS << "\""; 158*bdd1243dSDimitry Andric V->print(OS, true); 159*bdd1243dSDimitry Andric OS << "\"\n"; 160*bdd1243dSDimitry Andric } else 161*bdd1243dSDimitry Andric OS << "nullptr\n"; 162*bdd1243dSDimitry Andric }; 163*bdd1243dSDimitry Andric auto PrintNodeRef = [&](RawNodePtr Ptr) { 164*bdd1243dSDimitry Andric if (Ptr) 165*bdd1243dSDimitry Andric OS << Ptr << "\n"; 166*bdd1243dSDimitry Andric else 
167*bdd1243dSDimitry Andric OS << "nullptr\n"; 168*bdd1243dSDimitry Andric }; 169*bdd1243dSDimitry Andric 170*bdd1243dSDimitry Andric OS << "- CompositeNode: " << this << "\n"; 171*bdd1243dSDimitry Andric OS << " Real: "; 172*bdd1243dSDimitry Andric PrintValue(Real); 173*bdd1243dSDimitry Andric OS << " Imag: "; 174*bdd1243dSDimitry Andric PrintValue(Imag); 175*bdd1243dSDimitry Andric OS << " ReplacementNode: "; 176*bdd1243dSDimitry Andric PrintValue(ReplacementNode); 177*bdd1243dSDimitry Andric OS << " Operation: " << (int)Operation << "\n"; 178*bdd1243dSDimitry Andric OS << " Rotation: " << ((int)Rotation * 90) << "\n"; 179*bdd1243dSDimitry Andric OS << " Operands: \n"; 180*bdd1243dSDimitry Andric for (const auto &Op : Operands) { 181*bdd1243dSDimitry Andric OS << " - "; 182*bdd1243dSDimitry Andric PrintNodeRef(Op); 183*bdd1243dSDimitry Andric } 184*bdd1243dSDimitry Andric OS << " InternalInstructions:\n"; 185*bdd1243dSDimitry Andric for (const auto &I : InternalInstructions) { 186*bdd1243dSDimitry Andric OS << " - \""; 187*bdd1243dSDimitry Andric I->print(OS, true); 188*bdd1243dSDimitry Andric OS << "\"\n"; 189*bdd1243dSDimitry Andric } 190*bdd1243dSDimitry Andric } 191*bdd1243dSDimitry Andric }; 192*bdd1243dSDimitry Andric 193*bdd1243dSDimitry Andric class ComplexDeinterleavingGraph { 194*bdd1243dSDimitry Andric public: 195*bdd1243dSDimitry Andric using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; 196*bdd1243dSDimitry Andric using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; 197*bdd1243dSDimitry Andric explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {} 198*bdd1243dSDimitry Andric 199*bdd1243dSDimitry Andric private: 200*bdd1243dSDimitry Andric const TargetLowering *TL; 201*bdd1243dSDimitry Andric Instruction *RootValue; 202*bdd1243dSDimitry Andric NodePtr RootNode; 203*bdd1243dSDimitry Andric SmallVector<NodePtr> CompositeNodes; 204*bdd1243dSDimitry Andric SmallPtrSet<Instruction *, 16> AllInstructions; 
205*bdd1243dSDimitry Andric 206*bdd1243dSDimitry Andric NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, 207*bdd1243dSDimitry Andric Instruction *R, Instruction *I) { 208*bdd1243dSDimitry Andric return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, 209*bdd1243dSDimitry Andric I); 210*bdd1243dSDimitry Andric } 211*bdd1243dSDimitry Andric 212*bdd1243dSDimitry Andric NodePtr submitCompositeNode(NodePtr Node) { 213*bdd1243dSDimitry Andric CompositeNodes.push_back(Node); 214*bdd1243dSDimitry Andric AllInstructions.insert(Node->Real); 215*bdd1243dSDimitry Andric AllInstructions.insert(Node->Imag); 216*bdd1243dSDimitry Andric for (auto *I : Node->InternalInstructions) 217*bdd1243dSDimitry Andric AllInstructions.insert(I); 218*bdd1243dSDimitry Andric return Node; 219*bdd1243dSDimitry Andric } 220*bdd1243dSDimitry Andric 221*bdd1243dSDimitry Andric NodePtr getContainingComposite(Value *R, Value *I) { 222*bdd1243dSDimitry Andric for (const auto &CN : CompositeNodes) { 223*bdd1243dSDimitry Andric if (CN->Real == R && CN->Imag == I) 224*bdd1243dSDimitry Andric return CN; 225*bdd1243dSDimitry Andric } 226*bdd1243dSDimitry Andric return nullptr; 227*bdd1243dSDimitry Andric } 228*bdd1243dSDimitry Andric 229*bdd1243dSDimitry Andric /// Identifies a complex partial multiply pattern and its rotation, based on 230*bdd1243dSDimitry Andric /// the following patterns 231*bdd1243dSDimitry Andric /// 232*bdd1243dSDimitry Andric /// 0: r: cr + ar * br 233*bdd1243dSDimitry Andric /// i: ci + ar * bi 234*bdd1243dSDimitry Andric /// 90: r: cr - ai * bi 235*bdd1243dSDimitry Andric /// i: ci + ai * br 236*bdd1243dSDimitry Andric /// 180: r: cr - ar * br 237*bdd1243dSDimitry Andric /// i: ci - ar * bi 238*bdd1243dSDimitry Andric /// 270: r: cr + ai * bi 239*bdd1243dSDimitry Andric /// i: ci - ai * br 240*bdd1243dSDimitry Andric NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); 241*bdd1243dSDimitry Andric 242*bdd1243dSDimitry Andric /// 
Identify the other branch of a Partial Mul, taking the CommonOperandI that 243*bdd1243dSDimitry Andric /// is partially known from identifyPartialMul, filling in the other half of 244*bdd1243dSDimitry Andric /// the complex pair. 245*bdd1243dSDimitry Andric NodePtr identifyNodeWithImplicitAdd( 246*bdd1243dSDimitry Andric Instruction *I, Instruction *J, 247*bdd1243dSDimitry Andric std::pair<Instruction *, Instruction *> &CommonOperandI); 248*bdd1243dSDimitry Andric 249*bdd1243dSDimitry Andric /// Identifies a complex add pattern and its rotation, based on the following 250*bdd1243dSDimitry Andric /// patterns. 251*bdd1243dSDimitry Andric /// 252*bdd1243dSDimitry Andric /// 90: r: ar - bi 253*bdd1243dSDimitry Andric /// i: ai + br 254*bdd1243dSDimitry Andric /// 270: r: ar + bi 255*bdd1243dSDimitry Andric /// i: ai - br 256*bdd1243dSDimitry Andric NodePtr identifyAdd(Instruction *Real, Instruction *Imag); 257*bdd1243dSDimitry Andric 258*bdd1243dSDimitry Andric NodePtr identifyNode(Instruction *I, Instruction *J); 259*bdd1243dSDimitry Andric 260*bdd1243dSDimitry Andric Value *replaceNode(RawNodePtr Node); 261*bdd1243dSDimitry Andric 262*bdd1243dSDimitry Andric public: 263*bdd1243dSDimitry Andric void dump() { dump(dbgs()); } 264*bdd1243dSDimitry Andric void dump(raw_ostream &OS) { 265*bdd1243dSDimitry Andric for (const auto &Node : CompositeNodes) 266*bdd1243dSDimitry Andric Node->dump(OS); 267*bdd1243dSDimitry Andric } 268*bdd1243dSDimitry Andric 269*bdd1243dSDimitry Andric /// Returns false if the deinterleaving operation should be cancelled for the 270*bdd1243dSDimitry Andric /// current graph. 271*bdd1243dSDimitry Andric bool identifyNodes(Instruction *RootI); 272*bdd1243dSDimitry Andric 273*bdd1243dSDimitry Andric /// Perform the actual replacement of the underlying instruction graph. 274*bdd1243dSDimitry Andric /// Returns false if the deinterleaving operation should be cancelled for the 275*bdd1243dSDimitry Andric /// current graph. 
276*bdd1243dSDimitry Andric void replaceNodes(); 277*bdd1243dSDimitry Andric }; 278*bdd1243dSDimitry Andric 279*bdd1243dSDimitry Andric class ComplexDeinterleaving { 280*bdd1243dSDimitry Andric public: 281*bdd1243dSDimitry Andric ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 282*bdd1243dSDimitry Andric : TL(tl), TLI(tli) {} 283*bdd1243dSDimitry Andric bool runOnFunction(Function &F); 284*bdd1243dSDimitry Andric 285*bdd1243dSDimitry Andric private: 286*bdd1243dSDimitry Andric bool evaluateBasicBlock(BasicBlock *B); 287*bdd1243dSDimitry Andric 288*bdd1243dSDimitry Andric const TargetLowering *TL = nullptr; 289*bdd1243dSDimitry Andric const TargetLibraryInfo *TLI = nullptr; 290*bdd1243dSDimitry Andric }; 291*bdd1243dSDimitry Andric 292*bdd1243dSDimitry Andric } // namespace 293*bdd1243dSDimitry Andric 294*bdd1243dSDimitry Andric char ComplexDeinterleavingLegacyPass::ID = 0; 295*bdd1243dSDimitry Andric 296*bdd1243dSDimitry Andric INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 297*bdd1243dSDimitry Andric "Complex Deinterleaving", false, false) 298*bdd1243dSDimitry Andric INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 299*bdd1243dSDimitry Andric "Complex Deinterleaving", false, false) 300*bdd1243dSDimitry Andric 301*bdd1243dSDimitry Andric PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 302*bdd1243dSDimitry Andric FunctionAnalysisManager &AM) { 303*bdd1243dSDimitry Andric const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 304*bdd1243dSDimitry Andric auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 305*bdd1243dSDimitry Andric if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 306*bdd1243dSDimitry Andric return PreservedAnalyses::all(); 307*bdd1243dSDimitry Andric 308*bdd1243dSDimitry Andric PreservedAnalyses PA; 309*bdd1243dSDimitry Andric PA.preserve<FunctionAnalysisManagerModuleProxy>(); 310*bdd1243dSDimitry Andric return PA; 311*bdd1243dSDimitry Andric 
} 312*bdd1243dSDimitry Andric 313*bdd1243dSDimitry Andric FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 314*bdd1243dSDimitry Andric return new ComplexDeinterleavingLegacyPass(TM); 315*bdd1243dSDimitry Andric } 316*bdd1243dSDimitry Andric 317*bdd1243dSDimitry Andric bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 318*bdd1243dSDimitry Andric const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 319*bdd1243dSDimitry Andric auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 320*bdd1243dSDimitry Andric return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 321*bdd1243dSDimitry Andric } 322*bdd1243dSDimitry Andric 323*bdd1243dSDimitry Andric bool ComplexDeinterleaving::runOnFunction(Function &F) { 324*bdd1243dSDimitry Andric if (!ComplexDeinterleavingEnabled) { 325*bdd1243dSDimitry Andric LLVM_DEBUG( 326*bdd1243dSDimitry Andric dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); 327*bdd1243dSDimitry Andric return false; 328*bdd1243dSDimitry Andric } 329*bdd1243dSDimitry Andric 330*bdd1243dSDimitry Andric if (!TL->isComplexDeinterleavingSupported()) { 331*bdd1243dSDimitry Andric LLVM_DEBUG( 332*bdd1243dSDimitry Andric dbgs() << "Complex deinterleaving has been disabled, target does " 333*bdd1243dSDimitry Andric "not support lowering of complex number operations.\n"); 334*bdd1243dSDimitry Andric return false; 335*bdd1243dSDimitry Andric } 336*bdd1243dSDimitry Andric 337*bdd1243dSDimitry Andric bool Changed = false; 338*bdd1243dSDimitry Andric for (auto &B : F) 339*bdd1243dSDimitry Andric Changed |= evaluateBasicBlock(&B); 340*bdd1243dSDimitry Andric 341*bdd1243dSDimitry Andric return Changed; 342*bdd1243dSDimitry Andric } 343*bdd1243dSDimitry Andric 344*bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask) { 345*bdd1243dSDimitry Andric // If the size is not even, it's not an interleaving mask 346*bdd1243dSDimitry Andric if ((Mask.size() & 1)) 
347*bdd1243dSDimitry Andric return false; 348*bdd1243dSDimitry Andric 349*bdd1243dSDimitry Andric int HalfNumElements = Mask.size() / 2; 350*bdd1243dSDimitry Andric for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 351*bdd1243dSDimitry Andric int MaskIdx = Idx * 2; 352*bdd1243dSDimitry Andric if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 353*bdd1243dSDimitry Andric return false; 354*bdd1243dSDimitry Andric } 355*bdd1243dSDimitry Andric 356*bdd1243dSDimitry Andric return true; 357*bdd1243dSDimitry Andric } 358*bdd1243dSDimitry Andric 359*bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask) { 360*bdd1243dSDimitry Andric int Offset = Mask[0]; 361*bdd1243dSDimitry Andric int HalfNumElements = Mask.size() / 2; 362*bdd1243dSDimitry Andric 363*bdd1243dSDimitry Andric for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 364*bdd1243dSDimitry Andric if (Mask[Idx] != (Idx * 2) + Offset) 365*bdd1243dSDimitry Andric return false; 366*bdd1243dSDimitry Andric } 367*bdd1243dSDimitry Andric 368*bdd1243dSDimitry Andric return true; 369*bdd1243dSDimitry Andric } 370*bdd1243dSDimitry Andric 371*bdd1243dSDimitry Andric bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 372*bdd1243dSDimitry Andric bool Changed = false; 373*bdd1243dSDimitry Andric 374*bdd1243dSDimitry Andric SmallVector<Instruction *> DeadInstrRoots; 375*bdd1243dSDimitry Andric 376*bdd1243dSDimitry Andric for (auto &I : *B) { 377*bdd1243dSDimitry Andric auto *SVI = dyn_cast<ShuffleVectorInst>(&I); 378*bdd1243dSDimitry Andric if (!SVI) 379*bdd1243dSDimitry Andric continue; 380*bdd1243dSDimitry Andric 381*bdd1243dSDimitry Andric // Look for a shufflevector that takes separate vectors of the real and 382*bdd1243dSDimitry Andric // imaginary components and recombines them into a single vector. 
383*bdd1243dSDimitry Andric if (!isInterleavingMask(SVI->getShuffleMask())) 384*bdd1243dSDimitry Andric continue; 385*bdd1243dSDimitry Andric 386*bdd1243dSDimitry Andric ComplexDeinterleavingGraph Graph(TL); 387*bdd1243dSDimitry Andric if (!Graph.identifyNodes(SVI)) 388*bdd1243dSDimitry Andric continue; 389*bdd1243dSDimitry Andric 390*bdd1243dSDimitry Andric Graph.replaceNodes(); 391*bdd1243dSDimitry Andric DeadInstrRoots.push_back(SVI); 392*bdd1243dSDimitry Andric Changed = true; 393*bdd1243dSDimitry Andric } 394*bdd1243dSDimitry Andric 395*bdd1243dSDimitry Andric for (const auto &I : DeadInstrRoots) { 396*bdd1243dSDimitry Andric if (!I || I->getParent() == nullptr) 397*bdd1243dSDimitry Andric continue; 398*bdd1243dSDimitry Andric llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI); 399*bdd1243dSDimitry Andric } 400*bdd1243dSDimitry Andric 401*bdd1243dSDimitry Andric return Changed; 402*bdd1243dSDimitry Andric } 403*bdd1243dSDimitry Andric 404*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr 405*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 406*bdd1243dSDimitry Andric Instruction *Real, Instruction *Imag, 407*bdd1243dSDimitry Andric std::pair<Instruction *, Instruction *> &PartialMatch) { 408*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 409*bdd1243dSDimitry Andric << "\n"); 410*bdd1243dSDimitry Andric 411*bdd1243dSDimitry Andric if (!Real->hasOneUse() || !Imag->hasOneUse()) { 412*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 413*bdd1243dSDimitry Andric return nullptr; 414*bdd1243dSDimitry Andric } 415*bdd1243dSDimitry Andric 416*bdd1243dSDimitry Andric if (Real->getOpcode() != Instruction::FMul || 417*bdd1243dSDimitry Andric Imag->getOpcode() != Instruction::FMul) { 418*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n"); 419*bdd1243dSDimitry Andric return nullptr; 
420*bdd1243dSDimitry Andric } 421*bdd1243dSDimitry Andric 422*bdd1243dSDimitry Andric Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0)); 423*bdd1243dSDimitry Andric Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1)); 424*bdd1243dSDimitry Andric Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0)); 425*bdd1243dSDimitry Andric Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1)); 426*bdd1243dSDimitry Andric if (!R0 || !R1 || !I0 || !I1) { 427*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); 428*bdd1243dSDimitry Andric return nullptr; 429*bdd1243dSDimitry Andric } 430*bdd1243dSDimitry Andric 431*bdd1243dSDimitry Andric // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 432*bdd1243dSDimitry Andric // rotations and use the operand. 433*bdd1243dSDimitry Andric unsigned Negs = 0; 434*bdd1243dSDimitry Andric SmallVector<Instruction *> FNegs; 435*bdd1243dSDimitry Andric if (R0->getOpcode() == Instruction::FNeg || 436*bdd1243dSDimitry Andric R1->getOpcode() == Instruction::FNeg) { 437*bdd1243dSDimitry Andric Negs |= 1; 438*bdd1243dSDimitry Andric if (R0->getOpcode() == Instruction::FNeg) { 439*bdd1243dSDimitry Andric FNegs.push_back(R0); 440*bdd1243dSDimitry Andric R0 = dyn_cast<Instruction>(R0->getOperand(0)); 441*bdd1243dSDimitry Andric } else { 442*bdd1243dSDimitry Andric FNegs.push_back(R1); 443*bdd1243dSDimitry Andric R1 = dyn_cast<Instruction>(R1->getOperand(0)); 444*bdd1243dSDimitry Andric } 445*bdd1243dSDimitry Andric if (!R0 || !R1) 446*bdd1243dSDimitry Andric return nullptr; 447*bdd1243dSDimitry Andric } 448*bdd1243dSDimitry Andric if (I0->getOpcode() == Instruction::FNeg || 449*bdd1243dSDimitry Andric I1->getOpcode() == Instruction::FNeg) { 450*bdd1243dSDimitry Andric Negs |= 2; 451*bdd1243dSDimitry Andric Negs ^= 1; 452*bdd1243dSDimitry Andric if (I0->getOpcode() == Instruction::FNeg) { 453*bdd1243dSDimitry Andric FNegs.push_back(I0); 454*bdd1243dSDimitry Andric I0 = 
dyn_cast<Instruction>(I0->getOperand(0)); 455*bdd1243dSDimitry Andric } else { 456*bdd1243dSDimitry Andric FNegs.push_back(I1); 457*bdd1243dSDimitry Andric I1 = dyn_cast<Instruction>(I1->getOperand(0)); 458*bdd1243dSDimitry Andric } 459*bdd1243dSDimitry Andric if (!I0 || !I1) 460*bdd1243dSDimitry Andric return nullptr; 461*bdd1243dSDimitry Andric } 462*bdd1243dSDimitry Andric 463*bdd1243dSDimitry Andric ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 464*bdd1243dSDimitry Andric 465*bdd1243dSDimitry Andric Instruction *CommonOperand; 466*bdd1243dSDimitry Andric Instruction *UncommonRealOp; 467*bdd1243dSDimitry Andric Instruction *UncommonImagOp; 468*bdd1243dSDimitry Andric 469*bdd1243dSDimitry Andric if (R0 == I0 || R0 == I1) { 470*bdd1243dSDimitry Andric CommonOperand = R0; 471*bdd1243dSDimitry Andric UncommonRealOp = R1; 472*bdd1243dSDimitry Andric } else if (R1 == I0 || R1 == I1) { 473*bdd1243dSDimitry Andric CommonOperand = R1; 474*bdd1243dSDimitry Andric UncommonRealOp = R0; 475*bdd1243dSDimitry Andric } else { 476*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No equal operand\n"); 477*bdd1243dSDimitry Andric return nullptr; 478*bdd1243dSDimitry Andric } 479*bdd1243dSDimitry Andric 480*bdd1243dSDimitry Andric UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 481*bdd1243dSDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 482*bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_270) 483*bdd1243dSDimitry Andric std::swap(UncommonRealOp, UncommonImagOp); 484*bdd1243dSDimitry Andric 485*bdd1243dSDimitry Andric // Between identifyPartialMul and here we need to have found a complete valid 486*bdd1243dSDimitry Andric // pair from the CommonOperand of each part. 
487*bdd1243dSDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 488*bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_180) 489*bdd1243dSDimitry Andric PartialMatch.first = CommonOperand; 490*bdd1243dSDimitry Andric else 491*bdd1243dSDimitry Andric PartialMatch.second = CommonOperand; 492*bdd1243dSDimitry Andric 493*bdd1243dSDimitry Andric if (!PartialMatch.first || !PartialMatch.second) { 494*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); 495*bdd1243dSDimitry Andric return nullptr; 496*bdd1243dSDimitry Andric } 497*bdd1243dSDimitry Andric 498*bdd1243dSDimitry Andric NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); 499*bdd1243dSDimitry Andric if (!CommonNode) { 500*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); 501*bdd1243dSDimitry Andric return nullptr; 502*bdd1243dSDimitry Andric } 503*bdd1243dSDimitry Andric 504*bdd1243dSDimitry Andric NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); 505*bdd1243dSDimitry Andric if (!UncommonNode) { 506*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); 507*bdd1243dSDimitry Andric return nullptr; 508*bdd1243dSDimitry Andric } 509*bdd1243dSDimitry Andric 510*bdd1243dSDimitry Andric NodePtr Node = prepareCompositeNode( 511*bdd1243dSDimitry Andric ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 512*bdd1243dSDimitry Andric Node->Rotation = Rotation; 513*bdd1243dSDimitry Andric Node->addOperand(CommonNode); 514*bdd1243dSDimitry Andric Node->addOperand(UncommonNode); 515*bdd1243dSDimitry Andric Node->InternalInstructions.append(FNegs); 516*bdd1243dSDimitry Andric return submitCompositeNode(Node); 517*bdd1243dSDimitry Andric } 518*bdd1243dSDimitry Andric 519*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr 520*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, 521*bdd1243dSDimitry Andric 
Instruction *Imag) { 522*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag 523*bdd1243dSDimitry Andric << "\n"); 524*bdd1243dSDimitry Andric // Determine rotation 525*bdd1243dSDimitry Andric ComplexDeinterleavingRotation Rotation; 526*bdd1243dSDimitry Andric if (Real->getOpcode() == Instruction::FAdd && 527*bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::FAdd) 528*bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_0; 529*bdd1243dSDimitry Andric else if (Real->getOpcode() == Instruction::FSub && 530*bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::FAdd) 531*bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_90; 532*bdd1243dSDimitry Andric else if (Real->getOpcode() == Instruction::FSub && 533*bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::FSub) 534*bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_180; 535*bdd1243dSDimitry Andric else if (Real->getOpcode() == Instruction::FAdd && 536*bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::FSub) 537*bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_270; 538*bdd1243dSDimitry Andric else { 539*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); 540*bdd1243dSDimitry Andric return nullptr; 541*bdd1243dSDimitry Andric } 542*bdd1243dSDimitry Andric 543*bdd1243dSDimitry Andric if (!Real->getFastMathFlags().allowContract() || 544*bdd1243dSDimitry Andric !Imag->getFastMathFlags().allowContract()) { 545*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); 546*bdd1243dSDimitry Andric return nullptr; 547*bdd1243dSDimitry Andric } 548*bdd1243dSDimitry Andric 549*bdd1243dSDimitry Andric Value *CR = Real->getOperand(0); 550*bdd1243dSDimitry Andric Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); 551*bdd1243dSDimitry Andric if (!RealMulI) 552*bdd1243dSDimitry Andric return 
nullptr; 553*bdd1243dSDimitry Andric Value *CI = Imag->getOperand(0); 554*bdd1243dSDimitry Andric Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); 555*bdd1243dSDimitry Andric if (!ImagMulI) 556*bdd1243dSDimitry Andric return nullptr; 557*bdd1243dSDimitry Andric 558*bdd1243dSDimitry Andric if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { 559*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); 560*bdd1243dSDimitry Andric return nullptr; 561*bdd1243dSDimitry Andric } 562*bdd1243dSDimitry Andric 563*bdd1243dSDimitry Andric Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0)); 564*bdd1243dSDimitry Andric Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1)); 565*bdd1243dSDimitry Andric Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0)); 566*bdd1243dSDimitry Andric Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1)); 567*bdd1243dSDimitry Andric if (!R0 || !R1 || !I0 || !I1) { 568*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); 569*bdd1243dSDimitry Andric return nullptr; 570*bdd1243dSDimitry Andric } 571*bdd1243dSDimitry Andric 572*bdd1243dSDimitry Andric Instruction *CommonOperand; 573*bdd1243dSDimitry Andric Instruction *UncommonRealOp; 574*bdd1243dSDimitry Andric Instruction *UncommonImagOp; 575*bdd1243dSDimitry Andric 576*bdd1243dSDimitry Andric if (R0 == I0 || R0 == I1) { 577*bdd1243dSDimitry Andric CommonOperand = R0; 578*bdd1243dSDimitry Andric UncommonRealOp = R1; 579*bdd1243dSDimitry Andric } else if (R1 == I0 || R1 == I1) { 580*bdd1243dSDimitry Andric CommonOperand = R1; 581*bdd1243dSDimitry Andric UncommonRealOp = R0; 582*bdd1243dSDimitry Andric } else { 583*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No equal operand\n"); 584*bdd1243dSDimitry Andric return nullptr; 585*bdd1243dSDimitry Andric } 586*bdd1243dSDimitry Andric 587*bdd1243dSDimitry Andric UncommonImagOp = (CommonOperand == I0) ? 
I1 : I0; 588*bdd1243dSDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 589*bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_270) 590*bdd1243dSDimitry Andric std::swap(UncommonRealOp, UncommonImagOp); 591*bdd1243dSDimitry Andric 592*bdd1243dSDimitry Andric std::pair<Instruction *, Instruction *> PartialMatch( 593*bdd1243dSDimitry Andric (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 594*bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_180) 595*bdd1243dSDimitry Andric ? CommonOperand 596*bdd1243dSDimitry Andric : nullptr, 597*bdd1243dSDimitry Andric (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 598*bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_270) 599*bdd1243dSDimitry Andric ? CommonOperand 600*bdd1243dSDimitry Andric : nullptr); 601*bdd1243dSDimitry Andric NodePtr CNode = identifyNodeWithImplicitAdd( 602*bdd1243dSDimitry Andric cast<Instruction>(CR), cast<Instruction>(CI), PartialMatch); 603*bdd1243dSDimitry Andric if (!CNode) { 604*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No cnode identified\n"); 605*bdd1243dSDimitry Andric return nullptr; 606*bdd1243dSDimitry Andric } 607*bdd1243dSDimitry Andric 608*bdd1243dSDimitry Andric NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); 609*bdd1243dSDimitry Andric if (!UncommonRes) { 610*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); 611*bdd1243dSDimitry Andric return nullptr; 612*bdd1243dSDimitry Andric } 613*bdd1243dSDimitry Andric 614*bdd1243dSDimitry Andric assert(PartialMatch.first && PartialMatch.second); 615*bdd1243dSDimitry Andric NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); 616*bdd1243dSDimitry Andric if (!CommonRes) { 617*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); 618*bdd1243dSDimitry Andric return nullptr; 619*bdd1243dSDimitry Andric } 620*bdd1243dSDimitry 
Andric 621*bdd1243dSDimitry Andric NodePtr Node = prepareCompositeNode( 622*bdd1243dSDimitry Andric ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 623*bdd1243dSDimitry Andric Node->addInstruction(RealMulI); 624*bdd1243dSDimitry Andric Node->addInstruction(ImagMulI); 625*bdd1243dSDimitry Andric Node->Rotation = Rotation; 626*bdd1243dSDimitry Andric Node->addOperand(CommonRes); 627*bdd1243dSDimitry Andric Node->addOperand(UncommonRes); 628*bdd1243dSDimitry Andric Node->addOperand(CNode); 629*bdd1243dSDimitry Andric return submitCompositeNode(Node); 630*bdd1243dSDimitry Andric } 631*bdd1243dSDimitry Andric 632*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr 633*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { 634*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); 635*bdd1243dSDimitry Andric 636*bdd1243dSDimitry Andric // Determine rotation 637*bdd1243dSDimitry Andric ComplexDeinterleavingRotation Rotation; 638*bdd1243dSDimitry Andric if ((Real->getOpcode() == Instruction::FSub && 639*bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::FAdd) || 640*bdd1243dSDimitry Andric (Real->getOpcode() == Instruction::Sub && 641*bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::Add)) 642*bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_90; 643*bdd1243dSDimitry Andric else if ((Real->getOpcode() == Instruction::FAdd && 644*bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::FSub) || 645*bdd1243dSDimitry Andric (Real->getOpcode() == Instruction::Add && 646*bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::Sub)) 647*bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_270; 648*bdd1243dSDimitry Andric else { 649*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n"); 650*bdd1243dSDimitry Andric return nullptr; 651*bdd1243dSDimitry Andric } 
652*bdd1243dSDimitry Andric 653*bdd1243dSDimitry Andric auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); 654*bdd1243dSDimitry Andric auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); 655*bdd1243dSDimitry Andric auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); 656*bdd1243dSDimitry Andric auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); 657*bdd1243dSDimitry Andric 658*bdd1243dSDimitry Andric if (!AR || !AI || !BR || !BI) { 659*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); 660*bdd1243dSDimitry Andric return nullptr; 661*bdd1243dSDimitry Andric } 662*bdd1243dSDimitry Andric 663*bdd1243dSDimitry Andric NodePtr ResA = identifyNode(AR, AI); 664*bdd1243dSDimitry Andric if (!ResA) { 665*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); 666*bdd1243dSDimitry Andric return nullptr; 667*bdd1243dSDimitry Andric } 668*bdd1243dSDimitry Andric NodePtr ResB = identifyNode(BR, BI); 669*bdd1243dSDimitry Andric if (!ResB) { 670*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); 671*bdd1243dSDimitry Andric return nullptr; 672*bdd1243dSDimitry Andric } 673*bdd1243dSDimitry Andric 674*bdd1243dSDimitry Andric NodePtr Node = 675*bdd1243dSDimitry Andric prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); 676*bdd1243dSDimitry Andric Node->Rotation = Rotation; 677*bdd1243dSDimitry Andric Node->addOperand(ResA); 678*bdd1243dSDimitry Andric Node->addOperand(ResB); 679*bdd1243dSDimitry Andric return submitCompositeNode(Node); 680*bdd1243dSDimitry Andric } 681*bdd1243dSDimitry Andric 682*bdd1243dSDimitry Andric static bool isInstructionPairAdd(Instruction *A, Instruction *B) { 683*bdd1243dSDimitry Andric unsigned OpcA = A->getOpcode(); 684*bdd1243dSDimitry Andric unsigned OpcB = B->getOpcode(); 685*bdd1243dSDimitry Andric 686*bdd1243dSDimitry Andric return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || 
687*bdd1243dSDimitry Andric (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || 688*bdd1243dSDimitry Andric (OpcA == Instruction::Sub && OpcB == Instruction::Add) || 689*bdd1243dSDimitry Andric (OpcA == Instruction::Add && OpcB == Instruction::Sub); 690*bdd1243dSDimitry Andric } 691*bdd1243dSDimitry Andric 692*bdd1243dSDimitry Andric static bool isInstructionPairMul(Instruction *A, Instruction *B) { 693*bdd1243dSDimitry Andric auto Pattern = 694*bdd1243dSDimitry Andric m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); 695*bdd1243dSDimitry Andric 696*bdd1243dSDimitry Andric return match(A, Pattern) && match(B, Pattern); 697*bdd1243dSDimitry Andric } 698*bdd1243dSDimitry Andric 699*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr 700*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { 701*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n"); 702*bdd1243dSDimitry Andric if (NodePtr CN = getContainingComposite(Real, Imag)) { 703*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); 704*bdd1243dSDimitry Andric return CN; 705*bdd1243dSDimitry Andric } 706*bdd1243dSDimitry Andric 707*bdd1243dSDimitry Andric auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); 708*bdd1243dSDimitry Andric auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); 709*bdd1243dSDimitry Andric if (RealShuffle && ImagShuffle) { 710*bdd1243dSDimitry Andric Value *RealOp1 = RealShuffle->getOperand(1); 711*bdd1243dSDimitry Andric if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { 712*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); 713*bdd1243dSDimitry Andric return nullptr; 714*bdd1243dSDimitry Andric } 715*bdd1243dSDimitry Andric Value *ImagOp1 = ImagShuffle->getOperand(1); 716*bdd1243dSDimitry Andric if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { 
717*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); 718*bdd1243dSDimitry Andric return nullptr; 719*bdd1243dSDimitry Andric } 720*bdd1243dSDimitry Andric 721*bdd1243dSDimitry Andric Value *RealOp0 = RealShuffle->getOperand(0); 722*bdd1243dSDimitry Andric Value *ImagOp0 = ImagShuffle->getOperand(0); 723*bdd1243dSDimitry Andric 724*bdd1243dSDimitry Andric if (RealOp0 != ImagOp0) { 725*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); 726*bdd1243dSDimitry Andric return nullptr; 727*bdd1243dSDimitry Andric } 728*bdd1243dSDimitry Andric 729*bdd1243dSDimitry Andric ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); 730*bdd1243dSDimitry Andric ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); 731*bdd1243dSDimitry Andric if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { 732*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); 733*bdd1243dSDimitry Andric return nullptr; 734*bdd1243dSDimitry Andric } 735*bdd1243dSDimitry Andric 736*bdd1243dSDimitry Andric if (RealMask[0] != 0 || ImagMask[0] != 1) { 737*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); 738*bdd1243dSDimitry Andric return nullptr; 739*bdd1243dSDimitry Andric } 740*bdd1243dSDimitry Andric 741*bdd1243dSDimitry Andric // Type checking, the shuffle type should be a vector type of the same 742*bdd1243dSDimitry Andric // scalar type, but half the size 743*bdd1243dSDimitry Andric auto CheckType = [&](ShuffleVectorInst *Shuffle) { 744*bdd1243dSDimitry Andric Value *Op = Shuffle->getOperand(0); 745*bdd1243dSDimitry Andric auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); 746*bdd1243dSDimitry Andric auto *OpTy = cast<FixedVectorType>(Op->getType()); 747*bdd1243dSDimitry Andric 748*bdd1243dSDimitry Andric if (OpTy->getScalarType() != ShuffleTy->getScalarType()) 749*bdd1243dSDimitry Andric return false; 750*bdd1243dSDimitry 
Andric if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) 751*bdd1243dSDimitry Andric return false; 752*bdd1243dSDimitry Andric 753*bdd1243dSDimitry Andric return true; 754*bdd1243dSDimitry Andric }; 755*bdd1243dSDimitry Andric 756*bdd1243dSDimitry Andric auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { 757*bdd1243dSDimitry Andric if (!CheckType(Shuffle)) 758*bdd1243dSDimitry Andric return false; 759*bdd1243dSDimitry Andric 760*bdd1243dSDimitry Andric ArrayRef<int> Mask = Shuffle->getShuffleMask(); 761*bdd1243dSDimitry Andric int Last = *Mask.rbegin(); 762*bdd1243dSDimitry Andric 763*bdd1243dSDimitry Andric Value *Op = Shuffle->getOperand(0); 764*bdd1243dSDimitry Andric auto *OpTy = cast<FixedVectorType>(Op->getType()); 765*bdd1243dSDimitry Andric int NumElements = OpTy->getNumElements(); 766*bdd1243dSDimitry Andric 767*bdd1243dSDimitry Andric // Ensure that the deinterleaving shuffle only pulls from the first 768*bdd1243dSDimitry Andric // shuffle operand. 
769*bdd1243dSDimitry Andric return Last < NumElements; 770*bdd1243dSDimitry Andric }; 771*bdd1243dSDimitry Andric 772*bdd1243dSDimitry Andric if (RealShuffle->getType() != ImagShuffle->getType()) { 773*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); 774*bdd1243dSDimitry Andric return nullptr; 775*bdd1243dSDimitry Andric } 776*bdd1243dSDimitry Andric if (!CheckDeinterleavingShuffle(RealShuffle)) { 777*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); 778*bdd1243dSDimitry Andric return nullptr; 779*bdd1243dSDimitry Andric } 780*bdd1243dSDimitry Andric if (!CheckDeinterleavingShuffle(ImagShuffle)) { 781*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); 782*bdd1243dSDimitry Andric return nullptr; 783*bdd1243dSDimitry Andric } 784*bdd1243dSDimitry Andric 785*bdd1243dSDimitry Andric NodePtr PlaceholderNode = 786*bdd1243dSDimitry Andric prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle, 787*bdd1243dSDimitry Andric RealShuffle, ImagShuffle); 788*bdd1243dSDimitry Andric PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); 789*bdd1243dSDimitry Andric return submitCompositeNode(PlaceholderNode); 790*bdd1243dSDimitry Andric } 791*bdd1243dSDimitry Andric if (RealShuffle || ImagShuffle) 792*bdd1243dSDimitry Andric return nullptr; 793*bdd1243dSDimitry Andric 794*bdd1243dSDimitry Andric auto *VTy = cast<FixedVectorType>(Real->getType()); 795*bdd1243dSDimitry Andric auto *NewVTy = 796*bdd1243dSDimitry Andric FixedVectorType::get(VTy->getScalarType(), VTy->getNumElements() * 2); 797*bdd1243dSDimitry Andric 798*bdd1243dSDimitry Andric if (TL->isComplexDeinterleavingOperationSupported( 799*bdd1243dSDimitry Andric ComplexDeinterleavingOperation::CMulPartial, NewVTy) && 800*bdd1243dSDimitry Andric isInstructionPairMul(Real, Imag)) { 801*bdd1243dSDimitry Andric return identifyPartialMul(Real, Imag); 802*bdd1243dSDimitry Andric } 803*bdd1243dSDimitry Andric 
804*bdd1243dSDimitry Andric if (TL->isComplexDeinterleavingOperationSupported( 805*bdd1243dSDimitry Andric ComplexDeinterleavingOperation::CAdd, NewVTy) && 806*bdd1243dSDimitry Andric isInstructionPairAdd(Real, Imag)) { 807*bdd1243dSDimitry Andric return identifyAdd(Real, Imag); 808*bdd1243dSDimitry Andric } 809*bdd1243dSDimitry Andric 810*bdd1243dSDimitry Andric return nullptr; 811*bdd1243dSDimitry Andric } 812*bdd1243dSDimitry Andric 813*bdd1243dSDimitry Andric bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 814*bdd1243dSDimitry Andric Instruction *Real; 815*bdd1243dSDimitry Andric Instruction *Imag; 816*bdd1243dSDimitry Andric if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) 817*bdd1243dSDimitry Andric return false; 818*bdd1243dSDimitry Andric 819*bdd1243dSDimitry Andric RootValue = RootI; 820*bdd1243dSDimitry Andric AllInstructions.insert(RootI); 821*bdd1243dSDimitry Andric RootNode = identifyNode(Real, Imag); 822*bdd1243dSDimitry Andric 823*bdd1243dSDimitry Andric LLVM_DEBUG({ 824*bdd1243dSDimitry Andric Function *F = RootI->getFunction(); 825*bdd1243dSDimitry Andric BasicBlock *B = RootI->getParent(); 826*bdd1243dSDimitry Andric dbgs() << "Complex deinterleaving graph for " << F->getName() 827*bdd1243dSDimitry Andric << "::" << B->getName() << ".\n"; 828*bdd1243dSDimitry Andric dump(dbgs()); 829*bdd1243dSDimitry Andric dbgs() << "\n"; 830*bdd1243dSDimitry Andric }); 831*bdd1243dSDimitry Andric 832*bdd1243dSDimitry Andric // Check all instructions have internal uses 833*bdd1243dSDimitry Andric for (const auto &Node : CompositeNodes) { 834*bdd1243dSDimitry Andric if (!Node->hasAllInternalUses(AllInstructions)) { 835*bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Invalid internal uses\n"); 836*bdd1243dSDimitry Andric return false; 837*bdd1243dSDimitry Andric } 838*bdd1243dSDimitry Andric } 839*bdd1243dSDimitry Andric return RootNode != nullptr; 840*bdd1243dSDimitry Andric } 841*bdd1243dSDimitry Andric 
842*bdd1243dSDimitry Andric Value *ComplexDeinterleavingGraph::replaceNode( 843*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::RawNodePtr Node) { 844*bdd1243dSDimitry Andric if (Node->ReplacementNode) 845*bdd1243dSDimitry Andric return Node->ReplacementNode; 846*bdd1243dSDimitry Andric 847*bdd1243dSDimitry Andric Value *Input0 = replaceNode(Node->Operands[0]); 848*bdd1243dSDimitry Andric Value *Input1 = replaceNode(Node->Operands[1]); 849*bdd1243dSDimitry Andric Value *Accumulator = 850*bdd1243dSDimitry Andric Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr; 851*bdd1243dSDimitry Andric 852*bdd1243dSDimitry Andric assert(Input0->getType() == Input1->getType() && 853*bdd1243dSDimitry Andric "Node inputs need to be of the same type"); 854*bdd1243dSDimitry Andric 855*bdd1243dSDimitry Andric Node->ReplacementNode = TL->createComplexDeinterleavingIR( 856*bdd1243dSDimitry Andric Node->Real, Node->Operation, Node->Rotation, Input0, Input1, Accumulator); 857*bdd1243dSDimitry Andric 858*bdd1243dSDimitry Andric assert(Node->ReplacementNode && "Target failed to create Intrinsic call."); 859*bdd1243dSDimitry Andric NumComplexTransformations += 1; 860*bdd1243dSDimitry Andric return Node->ReplacementNode; 861*bdd1243dSDimitry Andric } 862*bdd1243dSDimitry Andric 863*bdd1243dSDimitry Andric void ComplexDeinterleavingGraph::replaceNodes() { 864*bdd1243dSDimitry Andric Value *R = replaceNode(RootNode.get()); 865*bdd1243dSDimitry Andric assert(R && "Unable to find replacement for RootValue"); 866*bdd1243dSDimitry Andric RootValue->replaceAllUsesWith(R); 867*bdd1243dSDimitry Andric } 868*bdd1243dSDimitry Andric 869*bdd1243dSDimitry Andric bool ComplexDeinterleavingCompositeNode::hasAllInternalUses( 870*bdd1243dSDimitry Andric SmallPtrSet<Instruction *, 16> &AllInstructions) { 871*bdd1243dSDimitry Andric if (Operation == ComplexDeinterleavingOperation::Shuffle) 872*bdd1243dSDimitry Andric return true; 873*bdd1243dSDimitry Andric 874*bdd1243dSDimitry Andric for 
(auto *User : Real->users()) { 875*bdd1243dSDimitry Andric if (!AllInstructions.contains(cast<Instruction>(User))) 876*bdd1243dSDimitry Andric return false; 877*bdd1243dSDimitry Andric } 878*bdd1243dSDimitry Andric for (auto *User : Imag->users()) { 879*bdd1243dSDimitry Andric if (!AllInstructions.contains(cast<Instruction>(User))) 880*bdd1243dSDimitry Andric return false; 881*bdd1243dSDimitry Andric } 882*bdd1243dSDimitry Andric for (auto *I : InternalInstructions) { 883*bdd1243dSDimitry Andric for (auto *User : I->users()) { 884*bdd1243dSDimitry Andric if (!AllInstructions.contains(cast<Instruction>(User))) 885*bdd1243dSDimitry Andric return false; 886*bdd1243dSDimitry Andric } 887*bdd1243dSDimitry Andric } 888*bdd1243dSDimitry Andric return true; 889*bdd1243dSDimitry Andric } 890