//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
// Instead of relying on shufflevector instructions, vector interleaving and
// deinterleaving can also be represented by the vector.interleave2 and
// vector.deinterleave2 intrinsics.
// Scalable vectors can be represented only by these intrinsics, whereas
// fixed-width vectors are recognized in both forms: as shufflevector
// instructions and as intrinsics.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together, connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation. It also maps each root Instruction (the point from which a
// complex computation starts) to the node that should replace it, and keeps
// the roots in topological order.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary values.
// Instructions that make up the identified operations are expected to have
// only internal uses (see checkNodes).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>

76bdd1243dSDimitry Andric using namespace llvm; 77bdd1243dSDimitry Andric using namespace PatternMatch; 78bdd1243dSDimitry Andric 79bdd1243dSDimitry Andric #define DEBUG_TYPE "complex-deinterleaving" 80bdd1243dSDimitry Andric 81bdd1243dSDimitry Andric STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); 82bdd1243dSDimitry Andric 83bdd1243dSDimitry Andric static cl::opt<bool> ComplexDeinterleavingEnabled( 84bdd1243dSDimitry Andric "enable-complex-deinterleaving", 85bdd1243dSDimitry Andric cl::desc("Enable generation of complex instructions"), cl::init(true), 86bdd1243dSDimitry Andric cl::Hidden); 87bdd1243dSDimitry Andric 88bdd1243dSDimitry Andric /// Checks the given mask, and determines whether said mask is interleaving. 89bdd1243dSDimitry Andric /// 90bdd1243dSDimitry Andric /// To be interleaving, a mask must alternate between `i` and `i + (Length / 91bdd1243dSDimitry Andric /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a 92bdd1243dSDimitry Andric /// 4x vector interleaving mask would be <0, 2, 1, 3>). 93bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask); 94bdd1243dSDimitry Andric 95bdd1243dSDimitry Andric /// Checks the given mask, and determines whether said mask is deinterleaving. 96bdd1243dSDimitry Andric /// 97bdd1243dSDimitry Andric /// To be deinterleaving, a mask must increment in steps of 2, and either start 98bdd1243dSDimitry Andric /// with 0 or 1. 99bdd1243dSDimitry Andric /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or 100bdd1243dSDimitry Andric /// <1, 3, 5, 7>). 101bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask); 102bdd1243dSDimitry Andric 103*06c3fb27SDimitry Andric /// Returns true if the operation is a negation of V, and it works for both 104*06c3fb27SDimitry Andric /// integers and floats. 
105*06c3fb27SDimitry Andric static bool isNeg(Value *V); 106*06c3fb27SDimitry Andric 107*06c3fb27SDimitry Andric /// Returns the operand for negation operation. 108*06c3fb27SDimitry Andric static Value *getNegOperand(Value *V); 109*06c3fb27SDimitry Andric 110bdd1243dSDimitry Andric namespace { 111bdd1243dSDimitry Andric 112bdd1243dSDimitry Andric class ComplexDeinterleavingLegacyPass : public FunctionPass { 113bdd1243dSDimitry Andric public: 114bdd1243dSDimitry Andric static char ID; 115bdd1243dSDimitry Andric 116bdd1243dSDimitry Andric ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) 117bdd1243dSDimitry Andric : FunctionPass(ID), TM(TM) { 118bdd1243dSDimitry Andric initializeComplexDeinterleavingLegacyPassPass( 119bdd1243dSDimitry Andric *PassRegistry::getPassRegistry()); 120bdd1243dSDimitry Andric } 121bdd1243dSDimitry Andric 122bdd1243dSDimitry Andric StringRef getPassName() const override { 123bdd1243dSDimitry Andric return "Complex Deinterleaving Pass"; 124bdd1243dSDimitry Andric } 125bdd1243dSDimitry Andric 126bdd1243dSDimitry Andric bool runOnFunction(Function &F) override; 127bdd1243dSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 128bdd1243dSDimitry Andric AU.addRequired<TargetLibraryInfoWrapperPass>(); 129bdd1243dSDimitry Andric AU.setPreservesCFG(); 130bdd1243dSDimitry Andric } 131bdd1243dSDimitry Andric 132bdd1243dSDimitry Andric private: 133bdd1243dSDimitry Andric const TargetMachine *TM; 134bdd1243dSDimitry Andric }; 135bdd1243dSDimitry Andric 136bdd1243dSDimitry Andric class ComplexDeinterleavingGraph; 137bdd1243dSDimitry Andric struct ComplexDeinterleavingCompositeNode { 138bdd1243dSDimitry Andric 139bdd1243dSDimitry Andric ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, 140*06c3fb27SDimitry Andric Value *R, Value *I) 141bdd1243dSDimitry Andric : Operation(Op), Real(R), Imag(I) {} 142bdd1243dSDimitry Andric 143bdd1243dSDimitry Andric private: 144bdd1243dSDimitry Andric friend class 
ComplexDeinterleavingGraph; 145bdd1243dSDimitry Andric using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; 146bdd1243dSDimitry Andric using RawNodePtr = ComplexDeinterleavingCompositeNode *; 147bdd1243dSDimitry Andric 148bdd1243dSDimitry Andric public: 149bdd1243dSDimitry Andric ComplexDeinterleavingOperation Operation; 150*06c3fb27SDimitry Andric Value *Real; 151*06c3fb27SDimitry Andric Value *Imag; 152bdd1243dSDimitry Andric 153*06c3fb27SDimitry Andric // This two members are required exclusively for generating 154*06c3fb27SDimitry Andric // ComplexDeinterleavingOperation::Symmetric operations. 155*06c3fb27SDimitry Andric unsigned Opcode; 156*06c3fb27SDimitry Andric std::optional<FastMathFlags> Flags; 157*06c3fb27SDimitry Andric 158*06c3fb27SDimitry Andric ComplexDeinterleavingRotation Rotation = 159*06c3fb27SDimitry Andric ComplexDeinterleavingRotation::Rotation_0; 160bdd1243dSDimitry Andric SmallVector<RawNodePtr> Operands; 161bdd1243dSDimitry Andric Value *ReplacementNode = nullptr; 162bdd1243dSDimitry Andric 163bdd1243dSDimitry Andric void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } 164bdd1243dSDimitry Andric 165bdd1243dSDimitry Andric void dump() { dump(dbgs()); } 166bdd1243dSDimitry Andric void dump(raw_ostream &OS) { 167bdd1243dSDimitry Andric auto PrintValue = [&](Value *V) { 168bdd1243dSDimitry Andric if (V) { 169bdd1243dSDimitry Andric OS << "\""; 170bdd1243dSDimitry Andric V->print(OS, true); 171bdd1243dSDimitry Andric OS << "\"\n"; 172bdd1243dSDimitry Andric } else 173bdd1243dSDimitry Andric OS << "nullptr\n"; 174bdd1243dSDimitry Andric }; 175bdd1243dSDimitry Andric auto PrintNodeRef = [&](RawNodePtr Ptr) { 176bdd1243dSDimitry Andric if (Ptr) 177bdd1243dSDimitry Andric OS << Ptr << "\n"; 178bdd1243dSDimitry Andric else 179bdd1243dSDimitry Andric OS << "nullptr\n"; 180bdd1243dSDimitry Andric }; 181bdd1243dSDimitry Andric 182bdd1243dSDimitry Andric OS << "- CompositeNode: " << this << "\n"; 183bdd1243dSDimitry Andric OS 
<< " Real: "; 184bdd1243dSDimitry Andric PrintValue(Real); 185bdd1243dSDimitry Andric OS << " Imag: "; 186bdd1243dSDimitry Andric PrintValue(Imag); 187bdd1243dSDimitry Andric OS << " ReplacementNode: "; 188bdd1243dSDimitry Andric PrintValue(ReplacementNode); 189bdd1243dSDimitry Andric OS << " Operation: " << (int)Operation << "\n"; 190bdd1243dSDimitry Andric OS << " Rotation: " << ((int)Rotation * 90) << "\n"; 191bdd1243dSDimitry Andric OS << " Operands: \n"; 192bdd1243dSDimitry Andric for (const auto &Op : Operands) { 193bdd1243dSDimitry Andric OS << " - "; 194bdd1243dSDimitry Andric PrintNodeRef(Op); 195bdd1243dSDimitry Andric } 196bdd1243dSDimitry Andric } 197bdd1243dSDimitry Andric }; 198bdd1243dSDimitry Andric 199bdd1243dSDimitry Andric class ComplexDeinterleavingGraph { 200bdd1243dSDimitry Andric public: 201*06c3fb27SDimitry Andric struct Product { 202*06c3fb27SDimitry Andric Value *Multiplier; 203*06c3fb27SDimitry Andric Value *Multiplicand; 204*06c3fb27SDimitry Andric bool IsPositive; 205*06c3fb27SDimitry Andric }; 206*06c3fb27SDimitry Andric 207*06c3fb27SDimitry Andric using Addend = std::pair<Value *, bool>; 208bdd1243dSDimitry Andric using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; 209bdd1243dSDimitry Andric using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; 210*06c3fb27SDimitry Andric 211*06c3fb27SDimitry Andric // Helper struct for holding info about potential partial multiplication 212*06c3fb27SDimitry Andric // candidates 213*06c3fb27SDimitry Andric struct PartialMulCandidate { 214*06c3fb27SDimitry Andric Value *Common; 215*06c3fb27SDimitry Andric NodePtr Node; 216*06c3fb27SDimitry Andric unsigned RealIdx; 217*06c3fb27SDimitry Andric unsigned ImagIdx; 218*06c3fb27SDimitry Andric bool IsNodeInverted; 219*06c3fb27SDimitry Andric }; 220*06c3fb27SDimitry Andric 221*06c3fb27SDimitry Andric explicit ComplexDeinterleavingGraph(const TargetLowering *TL, 222*06c3fb27SDimitry Andric const TargetLibraryInfo *TLI) 223*06c3fb27SDimitry 
Andric : TL(TL), TLI(TLI) {} 224bdd1243dSDimitry Andric 225bdd1243dSDimitry Andric private: 226*06c3fb27SDimitry Andric const TargetLowering *TL = nullptr; 227*06c3fb27SDimitry Andric const TargetLibraryInfo *TLI = nullptr; 228bdd1243dSDimitry Andric SmallVector<NodePtr> CompositeNodes; 229*06c3fb27SDimitry Andric 230*06c3fb27SDimitry Andric SmallPtrSet<Instruction *, 16> FinalInstructions; 231*06c3fb27SDimitry Andric 232*06c3fb27SDimitry Andric /// Root instructions are instructions from which complex computation starts 233*06c3fb27SDimitry Andric std::map<Instruction *, NodePtr> RootToNode; 234*06c3fb27SDimitry Andric 235*06c3fb27SDimitry Andric /// Topologically sorted root instructions 236*06c3fb27SDimitry Andric SmallVector<Instruction *, 1> OrderedRoots; 237*06c3fb27SDimitry Andric 238*06c3fb27SDimitry Andric /// When examining a basic block for complex deinterleaving, if it is a simple 239*06c3fb27SDimitry Andric /// one-block loop, then the only incoming block is 'Incoming' and the 240*06c3fb27SDimitry Andric /// 'BackEdge' block is the block itself." 241*06c3fb27SDimitry Andric BasicBlock *BackEdge = nullptr; 242*06c3fb27SDimitry Andric BasicBlock *Incoming = nullptr; 243*06c3fb27SDimitry Andric 244*06c3fb27SDimitry Andric /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction 245*06c3fb27SDimitry Andric /// %OutsideUser as it is shown in the IR: 246*06c3fb27SDimitry Andric /// 247*06c3fb27SDimitry Andric /// vector.body: 248*06c3fb27SDimitry Andric /// %PHInode = phi <vector type> [ zeroinitializer, %entry ], 249*06c3fb27SDimitry Andric /// [ %ReductionOp, %vector.body ] 250*06c3fb27SDimitry Andric /// ... 251*06c3fb27SDimitry Andric /// %ReductionOp = fadd i64 ... 252*06c3fb27SDimitry Andric /// ... 
253*06c3fb27SDimitry Andric /// br i1 %condition, label %vector.body, %middle.block 254*06c3fb27SDimitry Andric /// 255*06c3fb27SDimitry Andric /// middle.block: 256*06c3fb27SDimitry Andric /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) 257*06c3fb27SDimitry Andric /// 258*06c3fb27SDimitry Andric /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding 259*06c3fb27SDimitry Andric /// `llvm.vector.reduce.fadd` when unroll factor isn't one. 260*06c3fb27SDimitry Andric std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo; 261*06c3fb27SDimitry Andric 262*06c3fb27SDimitry Andric /// In the process of detecting a reduction, we consider a pair of 263*06c3fb27SDimitry Andric /// %ReductionOP, which we refer to as real and imag (or vice versa), and 264*06c3fb27SDimitry Andric /// traverse the use-tree to detect complex operations. As this is a reduction 265*06c3fb27SDimitry Andric /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds 266*06c3fb27SDimitry Andric /// to the %ReductionOPs that we suspect to be complex. 267*06c3fb27SDimitry Andric /// RealPHI and ImagPHI are used by the identifyPHINode method. 268*06c3fb27SDimitry Andric PHINode *RealPHI = nullptr; 269*06c3fb27SDimitry Andric PHINode *ImagPHI = nullptr; 270*06c3fb27SDimitry Andric 271*06c3fb27SDimitry Andric /// Set this flag to true if RealPHI and ImagPHI were reached during reduction 272*06c3fb27SDimitry Andric /// detection. 273*06c3fb27SDimitry Andric bool PHIsFound = false; 274*06c3fb27SDimitry Andric 275*06c3fb27SDimitry Andric /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. 276*06c3fb27SDimitry Andric /// The new PHINode corresponds to a vector of deinterleaved complex numbers. 277*06c3fb27SDimitry Andric /// This mapping is populated during 278*06c3fb27SDimitry Andric /// ComplexDeinterleavingOperation::ReductionPHI node replacement. 
It is then 279*06c3fb27SDimitry Andric /// used in the ComplexDeinterleavingOperation::ReductionOperation node 280*06c3fb27SDimitry Andric /// replacement process. 281*06c3fb27SDimitry Andric std::map<PHINode *, PHINode *> OldToNewPHI; 282bdd1243dSDimitry Andric 283bdd1243dSDimitry Andric NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, 284*06c3fb27SDimitry Andric Value *R, Value *I) { 285*06c3fb27SDimitry Andric assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && 286*06c3fb27SDimitry Andric Operation != ComplexDeinterleavingOperation::ReductionOperation) || 287*06c3fb27SDimitry Andric (R && I)) && 288*06c3fb27SDimitry Andric "Reduction related nodes must have Real and Imaginary parts"); 289bdd1243dSDimitry Andric return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, 290bdd1243dSDimitry Andric I); 291bdd1243dSDimitry Andric } 292bdd1243dSDimitry Andric 293bdd1243dSDimitry Andric NodePtr submitCompositeNode(NodePtr Node) { 294bdd1243dSDimitry Andric CompositeNodes.push_back(Node); 295bdd1243dSDimitry Andric return Node; 296bdd1243dSDimitry Andric } 297bdd1243dSDimitry Andric 298bdd1243dSDimitry Andric NodePtr getContainingComposite(Value *R, Value *I) { 299bdd1243dSDimitry Andric for (const auto &CN : CompositeNodes) { 300bdd1243dSDimitry Andric if (CN->Real == R && CN->Imag == I) 301bdd1243dSDimitry Andric return CN; 302bdd1243dSDimitry Andric } 303bdd1243dSDimitry Andric return nullptr; 304bdd1243dSDimitry Andric } 305bdd1243dSDimitry Andric 306bdd1243dSDimitry Andric /// Identifies a complex partial multiply pattern and its rotation, based on 307bdd1243dSDimitry Andric /// the following patterns 308bdd1243dSDimitry Andric /// 309bdd1243dSDimitry Andric /// 0: r: cr + ar * br 310bdd1243dSDimitry Andric /// i: ci + ar * bi 311bdd1243dSDimitry Andric /// 90: r: cr - ai * bi 312bdd1243dSDimitry Andric /// i: ci + ai * br 313bdd1243dSDimitry Andric /// 180: r: cr - ar * br 314bdd1243dSDimitry Andric /// i: ci 
- ar * bi 315bdd1243dSDimitry Andric /// 270: r: cr + ai * bi 316bdd1243dSDimitry Andric /// i: ci - ai * br 317bdd1243dSDimitry Andric NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); 318bdd1243dSDimitry Andric 319bdd1243dSDimitry Andric /// Identify the other branch of a Partial Mul, taking the CommonOperandI that 320bdd1243dSDimitry Andric /// is partially known from identifyPartialMul, filling in the other half of 321bdd1243dSDimitry Andric /// the complex pair. 322*06c3fb27SDimitry Andric NodePtr 323*06c3fb27SDimitry Andric identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, 324*06c3fb27SDimitry Andric std::pair<Value *, Value *> &CommonOperandI); 325bdd1243dSDimitry Andric 326bdd1243dSDimitry Andric /// Identifies a complex add pattern and its rotation, based on the following 327bdd1243dSDimitry Andric /// patterns. 328bdd1243dSDimitry Andric /// 329bdd1243dSDimitry Andric /// 90: r: ar - bi 330bdd1243dSDimitry Andric /// i: ai + br 331bdd1243dSDimitry Andric /// 270: r: ar + bi 332bdd1243dSDimitry Andric /// i: ai - br 333bdd1243dSDimitry Andric NodePtr identifyAdd(Instruction *Real, Instruction *Imag); 334*06c3fb27SDimitry Andric NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); 335bdd1243dSDimitry Andric 336*06c3fb27SDimitry Andric NodePtr identifyNode(Value *R, Value *I); 337bdd1243dSDimitry Andric 338*06c3fb27SDimitry Andric /// Determine if a sum of complex numbers can be formed from \p RealAddends 339*06c3fb27SDimitry Andric /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. 340*06c3fb27SDimitry Andric /// Return nullptr if it is not possible to construct a complex number. 341*06c3fb27SDimitry Andric /// \p Flags are needed to generate symmetric Add and Sub operations. 
342*06c3fb27SDimitry Andric NodePtr identifyAdditions(std::list<Addend> &RealAddends, 343*06c3fb27SDimitry Andric std::list<Addend> &ImagAddends, 344*06c3fb27SDimitry Andric std::optional<FastMathFlags> Flags, 345*06c3fb27SDimitry Andric NodePtr Accumulator); 346*06c3fb27SDimitry Andric 347*06c3fb27SDimitry Andric /// Extract one addend that have both real and imaginary parts positive. 348*06c3fb27SDimitry Andric NodePtr extractPositiveAddend(std::list<Addend> &RealAddends, 349*06c3fb27SDimitry Andric std::list<Addend> &ImagAddends); 350*06c3fb27SDimitry Andric 351*06c3fb27SDimitry Andric /// Determine if sum of multiplications of complex numbers can be formed from 352*06c3fb27SDimitry Andric /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result 353*06c3fb27SDimitry Andric /// to it. Return nullptr if it is not possible to construct a complex number. 354*06c3fb27SDimitry Andric NodePtr identifyMultiplications(std::vector<Product> &RealMuls, 355*06c3fb27SDimitry Andric std::vector<Product> &ImagMuls, 356*06c3fb27SDimitry Andric NodePtr Accumulator); 357*06c3fb27SDimitry Andric 358*06c3fb27SDimitry Andric /// Go through pairs of multiplication (one Real and one Imag) and find all 359*06c3fb27SDimitry Andric /// possible candidates for partial multiplication and put them into \p 360*06c3fb27SDimitry Andric /// Candidates. Returns true if all Product has pair with common operand 361*06c3fb27SDimitry Andric bool collectPartialMuls(const std::vector<Product> &RealMuls, 362*06c3fb27SDimitry Andric const std::vector<Product> &ImagMuls, 363*06c3fb27SDimitry Andric std::vector<PartialMulCandidate> &Candidates); 364*06c3fb27SDimitry Andric 365*06c3fb27SDimitry Andric /// If the code is compiled with -Ofast or expressions have `reassoc` flag, 366*06c3fb27SDimitry Andric /// the order of complex computation operations may be significantly altered, 367*06c3fb27SDimitry Andric /// and the real and imaginary parts may not be executed in parallel. 
This 368*06c3fb27SDimitry Andric /// function takes this into consideration and employs a more general approach 369*06c3fb27SDimitry Andric /// to identify complex computations. Initially, it gathers all the addends 370*06c3fb27SDimitry Andric /// and multiplicands and then constructs a complex expression from them. 371*06c3fb27SDimitry Andric NodePtr identifyReassocNodes(Instruction *I, Instruction *J); 372*06c3fb27SDimitry Andric 373*06c3fb27SDimitry Andric NodePtr identifyRoot(Instruction *I); 374*06c3fb27SDimitry Andric 375*06c3fb27SDimitry Andric /// Identifies the Deinterleave operation applied to a vector containing 376*06c3fb27SDimitry Andric /// complex numbers. There are two ways to represent the Deinterleave 377*06c3fb27SDimitry Andric /// operation: 378*06c3fb27SDimitry Andric /// * Using two shufflevectors with even indices for /pReal instruction and 379*06c3fb27SDimitry Andric /// odd indices for /pImag instructions (only for fixed-width vectors) 380*06c3fb27SDimitry Andric /// * Using two extractvalue instructions applied to `vector.deinterleave2` 381*06c3fb27SDimitry Andric /// intrinsic (for both fixed and scalable vectors) 382*06c3fb27SDimitry Andric NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); 383*06c3fb27SDimitry Andric 384*06c3fb27SDimitry Andric /// identifying the operation that represents a complex number repeated in a 385*06c3fb27SDimitry Andric /// Splat vector. There are two possible types of splats: ConstantExpr with 386*06c3fb27SDimitry Andric /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an 387*06c3fb27SDimitry Andric /// initialization mask with all values set to zero. 
388*06c3fb27SDimitry Andric NodePtr identifySplat(Value *Real, Value *Imag); 389*06c3fb27SDimitry Andric 390*06c3fb27SDimitry Andric NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); 391*06c3fb27SDimitry Andric 392*06c3fb27SDimitry Andric /// Identifies SelectInsts in a loop that has reduction with predication masks 393*06c3fb27SDimitry Andric /// and/or predicated tail folding 394*06c3fb27SDimitry Andric NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); 395*06c3fb27SDimitry Andric 396*06c3fb27SDimitry Andric Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); 397*06c3fb27SDimitry Andric 398*06c3fb27SDimitry Andric /// Complete IR modifications after producing new reduction operation: 399*06c3fb27SDimitry Andric /// * Populate the PHINode generated for 400*06c3fb27SDimitry Andric /// ComplexDeinterleavingOperation::ReductionPHI 401*06c3fb27SDimitry Andric /// * Deinterleave the final value outside of the loop and repurpose original 402*06c3fb27SDimitry Andric /// reduction users 403*06c3fb27SDimitry Andric void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); 404bdd1243dSDimitry Andric 405bdd1243dSDimitry Andric public: 406bdd1243dSDimitry Andric void dump() { dump(dbgs()); } 407bdd1243dSDimitry Andric void dump(raw_ostream &OS) { 408bdd1243dSDimitry Andric for (const auto &Node : CompositeNodes) 409bdd1243dSDimitry Andric Node->dump(OS); 410bdd1243dSDimitry Andric } 411bdd1243dSDimitry Andric 412bdd1243dSDimitry Andric /// Returns false if the deinterleaving operation should be cancelled for the 413bdd1243dSDimitry Andric /// current graph. 414bdd1243dSDimitry Andric bool identifyNodes(Instruction *RootI); 415bdd1243dSDimitry Andric 416*06c3fb27SDimitry Andric /// In case \pB is one-block loop, this function seeks potential reductions 417*06c3fb27SDimitry Andric /// and populates ReductionInfo. Returns true if any reductions were 418*06c3fb27SDimitry Andric /// identified. 
419*06c3fb27SDimitry Andric bool collectPotentialReductions(BasicBlock *B); 420*06c3fb27SDimitry Andric 421*06c3fb27SDimitry Andric void identifyReductionNodes(); 422*06c3fb27SDimitry Andric 423*06c3fb27SDimitry Andric /// Check that every instruction, from the roots to the leaves, has internal 424*06c3fb27SDimitry Andric /// uses. 425*06c3fb27SDimitry Andric bool checkNodes(); 426*06c3fb27SDimitry Andric 427bdd1243dSDimitry Andric /// Perform the actual replacement of the underlying instruction graph. 428bdd1243dSDimitry Andric void replaceNodes(); 429bdd1243dSDimitry Andric }; 430bdd1243dSDimitry Andric 431bdd1243dSDimitry Andric class ComplexDeinterleaving { 432bdd1243dSDimitry Andric public: 433bdd1243dSDimitry Andric ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 434bdd1243dSDimitry Andric : TL(tl), TLI(tli) {} 435bdd1243dSDimitry Andric bool runOnFunction(Function &F); 436bdd1243dSDimitry Andric 437bdd1243dSDimitry Andric private: 438bdd1243dSDimitry Andric bool evaluateBasicBlock(BasicBlock *B); 439bdd1243dSDimitry Andric 440bdd1243dSDimitry Andric const TargetLowering *TL = nullptr; 441bdd1243dSDimitry Andric const TargetLibraryInfo *TLI = nullptr; 442bdd1243dSDimitry Andric }; 443bdd1243dSDimitry Andric 444bdd1243dSDimitry Andric } // namespace 445bdd1243dSDimitry Andric 446bdd1243dSDimitry Andric char ComplexDeinterleavingLegacyPass::ID = 0; 447bdd1243dSDimitry Andric 448bdd1243dSDimitry Andric INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 449bdd1243dSDimitry Andric "Complex Deinterleaving", false, false) 450bdd1243dSDimitry Andric INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 451bdd1243dSDimitry Andric "Complex Deinterleaving", false, false) 452bdd1243dSDimitry Andric 453bdd1243dSDimitry Andric PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 454bdd1243dSDimitry Andric FunctionAnalysisManager &AM) { 455bdd1243dSDimitry Andric const TargetLowering *TL = 
TM->getSubtargetImpl(F)->getTargetLowering(); 456bdd1243dSDimitry Andric auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 457bdd1243dSDimitry Andric if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 458bdd1243dSDimitry Andric return PreservedAnalyses::all(); 459bdd1243dSDimitry Andric 460bdd1243dSDimitry Andric PreservedAnalyses PA; 461bdd1243dSDimitry Andric PA.preserve<FunctionAnalysisManagerModuleProxy>(); 462bdd1243dSDimitry Andric return PA; 463bdd1243dSDimitry Andric } 464bdd1243dSDimitry Andric 465bdd1243dSDimitry Andric FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 466bdd1243dSDimitry Andric return new ComplexDeinterleavingLegacyPass(TM); 467bdd1243dSDimitry Andric } 468bdd1243dSDimitry Andric 469bdd1243dSDimitry Andric bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 470bdd1243dSDimitry Andric const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 471bdd1243dSDimitry Andric auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 472bdd1243dSDimitry Andric return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 473bdd1243dSDimitry Andric } 474bdd1243dSDimitry Andric 475bdd1243dSDimitry Andric bool ComplexDeinterleaving::runOnFunction(Function &F) { 476bdd1243dSDimitry Andric if (!ComplexDeinterleavingEnabled) { 477bdd1243dSDimitry Andric LLVM_DEBUG( 478bdd1243dSDimitry Andric dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); 479bdd1243dSDimitry Andric return false; 480bdd1243dSDimitry Andric } 481bdd1243dSDimitry Andric 482bdd1243dSDimitry Andric if (!TL->isComplexDeinterleavingSupported()) { 483bdd1243dSDimitry Andric LLVM_DEBUG( 484bdd1243dSDimitry Andric dbgs() << "Complex deinterleaving has been disabled, target does " 485bdd1243dSDimitry Andric "not support lowering of complex number operations.\n"); 486bdd1243dSDimitry Andric return false; 487bdd1243dSDimitry Andric } 488bdd1243dSDimitry Andric 489bdd1243dSDimitry Andric bool Changed = false; 
490bdd1243dSDimitry Andric for (auto &B : F) 491bdd1243dSDimitry Andric Changed |= evaluateBasicBlock(&B); 492bdd1243dSDimitry Andric 493bdd1243dSDimitry Andric return Changed; 494bdd1243dSDimitry Andric } 495bdd1243dSDimitry Andric 496bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask) { 497bdd1243dSDimitry Andric // If the size is not even, it's not an interleaving mask 498bdd1243dSDimitry Andric if ((Mask.size() & 1)) 499bdd1243dSDimitry Andric return false; 500bdd1243dSDimitry Andric 501bdd1243dSDimitry Andric int HalfNumElements = Mask.size() / 2; 502bdd1243dSDimitry Andric for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 503bdd1243dSDimitry Andric int MaskIdx = Idx * 2; 504bdd1243dSDimitry Andric if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 505bdd1243dSDimitry Andric return false; 506bdd1243dSDimitry Andric } 507bdd1243dSDimitry Andric 508bdd1243dSDimitry Andric return true; 509bdd1243dSDimitry Andric } 510bdd1243dSDimitry Andric 511bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask) { 512bdd1243dSDimitry Andric int Offset = Mask[0]; 513bdd1243dSDimitry Andric int HalfNumElements = Mask.size() / 2; 514bdd1243dSDimitry Andric 515bdd1243dSDimitry Andric for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 516bdd1243dSDimitry Andric if (Mask[Idx] != (Idx * 2) + Offset) 517bdd1243dSDimitry Andric return false; 518bdd1243dSDimitry Andric } 519bdd1243dSDimitry Andric 520bdd1243dSDimitry Andric return true; 521bdd1243dSDimitry Andric } 522bdd1243dSDimitry Andric 523*06c3fb27SDimitry Andric bool isNeg(Value *V) { 524*06c3fb27SDimitry Andric return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value())); 525*06c3fb27SDimitry Andric } 526*06c3fb27SDimitry Andric 527*06c3fb27SDimitry Andric Value *getNegOperand(Value *V) { 528*06c3fb27SDimitry Andric assert(isNeg(V)); 529*06c3fb27SDimitry Andric auto *I = cast<Instruction>(V); 530*06c3fb27SDimitry Andric if (I->getOpcode() == Instruction::FNeg) 
531*06c3fb27SDimitry Andric return I->getOperand(0); 532*06c3fb27SDimitry Andric 533*06c3fb27SDimitry Andric return I->getOperand(1); 534*06c3fb27SDimitry Andric } 535*06c3fb27SDimitry Andric 536bdd1243dSDimitry Andric bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 537*06c3fb27SDimitry Andric ComplexDeinterleavingGraph Graph(TL, TLI); 538*06c3fb27SDimitry Andric if (Graph.collectPotentialReductions(B)) 539*06c3fb27SDimitry Andric Graph.identifyReductionNodes(); 540bdd1243dSDimitry Andric 541*06c3fb27SDimitry Andric for (auto &I : *B) 542*06c3fb27SDimitry Andric Graph.identifyNodes(&I); 543bdd1243dSDimitry Andric 544*06c3fb27SDimitry Andric if (Graph.checkNodes()) { 545bdd1243dSDimitry Andric Graph.replaceNodes(); 546*06c3fb27SDimitry Andric return true; 547bdd1243dSDimitry Andric } 548bdd1243dSDimitry Andric 549*06c3fb27SDimitry Andric return false; 550bdd1243dSDimitry Andric } 551bdd1243dSDimitry Andric 552bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr 553bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 554bdd1243dSDimitry Andric Instruction *Real, Instruction *Imag, 555*06c3fb27SDimitry Andric std::pair<Value *, Value *> &PartialMatch) { 556bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 557bdd1243dSDimitry Andric << "\n"); 558bdd1243dSDimitry Andric 559bdd1243dSDimitry Andric if (!Real->hasOneUse() || !Imag->hasOneUse()) { 560bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 561bdd1243dSDimitry Andric return nullptr; 562bdd1243dSDimitry Andric } 563bdd1243dSDimitry Andric 564*06c3fb27SDimitry Andric if ((Real->getOpcode() != Instruction::FMul && 565*06c3fb27SDimitry Andric Real->getOpcode() != Instruction::Mul) || 566*06c3fb27SDimitry Andric (Imag->getOpcode() != Instruction::FMul && 567*06c3fb27SDimitry Andric Imag->getOpcode() != Instruction::Mul)) { 568*06c3fb27SDimitry Andric LLVM_DEBUG( 569*06c3fb27SDimitry 
Andric dbgs() << " - Real or imaginary instruction is not fmul or mul\n"); 570bdd1243dSDimitry Andric return nullptr; 571bdd1243dSDimitry Andric } 572bdd1243dSDimitry Andric 573*06c3fb27SDimitry Andric Value *R0 = Real->getOperand(0); 574*06c3fb27SDimitry Andric Value *R1 = Real->getOperand(1); 575*06c3fb27SDimitry Andric Value *I0 = Imag->getOperand(0); 576*06c3fb27SDimitry Andric Value *I1 = Imag->getOperand(1); 577bdd1243dSDimitry Andric 578bdd1243dSDimitry Andric // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 579bdd1243dSDimitry Andric // rotations and use the operand. 580bdd1243dSDimitry Andric unsigned Negs = 0; 581*06c3fb27SDimitry Andric Value *Op; 582*06c3fb27SDimitry Andric if (match(R0, m_Neg(m_Value(Op)))) { 583bdd1243dSDimitry Andric Negs |= 1; 584*06c3fb27SDimitry Andric R0 = Op; 585*06c3fb27SDimitry Andric } else if (match(R1, m_Neg(m_Value(Op)))) { 586*06c3fb27SDimitry Andric Negs |= 1; 587*06c3fb27SDimitry Andric R1 = Op; 588bdd1243dSDimitry Andric } 589*06c3fb27SDimitry Andric 590*06c3fb27SDimitry Andric if (isNeg(I0)) { 591bdd1243dSDimitry Andric Negs |= 2; 592bdd1243dSDimitry Andric Negs ^= 1; 593*06c3fb27SDimitry Andric I0 = Op; 594*06c3fb27SDimitry Andric } else if (match(I1, m_Neg(m_Value(Op)))) { 595*06c3fb27SDimitry Andric Negs |= 2; 596*06c3fb27SDimitry Andric Negs ^= 1; 597*06c3fb27SDimitry Andric I1 = Op; 598bdd1243dSDimitry Andric } 599bdd1243dSDimitry Andric 600bdd1243dSDimitry Andric ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 601bdd1243dSDimitry Andric 602*06c3fb27SDimitry Andric Value *CommonOperand; 603*06c3fb27SDimitry Andric Value *UncommonRealOp; 604*06c3fb27SDimitry Andric Value *UncommonImagOp; 605bdd1243dSDimitry Andric 606bdd1243dSDimitry Andric if (R0 == I0 || R0 == I1) { 607bdd1243dSDimitry Andric CommonOperand = R0; 608bdd1243dSDimitry Andric UncommonRealOp = R1; 609bdd1243dSDimitry Andric } else if (R1 == I0 || R1 == I1) { 610bdd1243dSDimitry Andric 
CommonOperand = R1; 611bdd1243dSDimitry Andric UncommonRealOp = R0; 612bdd1243dSDimitry Andric } else { 613bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No equal operand\n"); 614bdd1243dSDimitry Andric return nullptr; 615bdd1243dSDimitry Andric } 616bdd1243dSDimitry Andric 617bdd1243dSDimitry Andric UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 618bdd1243dSDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 619bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_270) 620bdd1243dSDimitry Andric std::swap(UncommonRealOp, UncommonImagOp); 621bdd1243dSDimitry Andric 622bdd1243dSDimitry Andric // Between identifyPartialMul and here we need to have found a complete valid 623bdd1243dSDimitry Andric // pair from the CommonOperand of each part. 624bdd1243dSDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 625bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_180) 626bdd1243dSDimitry Andric PartialMatch.first = CommonOperand; 627bdd1243dSDimitry Andric else 628bdd1243dSDimitry Andric PartialMatch.second = CommonOperand; 629bdd1243dSDimitry Andric 630bdd1243dSDimitry Andric if (!PartialMatch.first || !PartialMatch.second) { 631bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); 632bdd1243dSDimitry Andric return nullptr; 633bdd1243dSDimitry Andric } 634bdd1243dSDimitry Andric 635bdd1243dSDimitry Andric NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); 636bdd1243dSDimitry Andric if (!CommonNode) { 637bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); 638bdd1243dSDimitry Andric return nullptr; 639bdd1243dSDimitry Andric } 640bdd1243dSDimitry Andric 641bdd1243dSDimitry Andric NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); 642bdd1243dSDimitry Andric if (!UncommonNode) { 643bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); 644bdd1243dSDimitry 
Andric return nullptr; 645bdd1243dSDimitry Andric } 646bdd1243dSDimitry Andric 647bdd1243dSDimitry Andric NodePtr Node = prepareCompositeNode( 648bdd1243dSDimitry Andric ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 649bdd1243dSDimitry Andric Node->Rotation = Rotation; 650bdd1243dSDimitry Andric Node->addOperand(CommonNode); 651bdd1243dSDimitry Andric Node->addOperand(UncommonNode); 652bdd1243dSDimitry Andric return submitCompositeNode(Node); 653bdd1243dSDimitry Andric } 654bdd1243dSDimitry Andric 655bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr 656bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, 657bdd1243dSDimitry Andric Instruction *Imag) { 658bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag 659bdd1243dSDimitry Andric << "\n"); 660bdd1243dSDimitry Andric // Determine rotation 661*06c3fb27SDimitry Andric auto IsAdd = [](unsigned Op) { 662*06c3fb27SDimitry Andric return Op == Instruction::FAdd || Op == Instruction::Add; 663*06c3fb27SDimitry Andric }; 664*06c3fb27SDimitry Andric auto IsSub = [](unsigned Op) { 665*06c3fb27SDimitry Andric return Op == Instruction::FSub || Op == Instruction::Sub; 666*06c3fb27SDimitry Andric }; 667bdd1243dSDimitry Andric ComplexDeinterleavingRotation Rotation; 668*06c3fb27SDimitry Andric if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 669bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_0; 670*06c3fb27SDimitry Andric else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 671bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_90; 672*06c3fb27SDimitry Andric else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode())) 673bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_180; 674*06c3fb27SDimitry Andric else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode())) 675bdd1243dSDimitry Andric Rotation = 
ComplexDeinterleavingRotation::Rotation_270; 676bdd1243dSDimitry Andric else { 677bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); 678bdd1243dSDimitry Andric return nullptr; 679bdd1243dSDimitry Andric } 680bdd1243dSDimitry Andric 681*06c3fb27SDimitry Andric if (isa<FPMathOperator>(Real) && 682*06c3fb27SDimitry Andric (!Real->getFastMathFlags().allowContract() || 683*06c3fb27SDimitry Andric !Imag->getFastMathFlags().allowContract())) { 684bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); 685bdd1243dSDimitry Andric return nullptr; 686bdd1243dSDimitry Andric } 687bdd1243dSDimitry Andric 688bdd1243dSDimitry Andric Value *CR = Real->getOperand(0); 689bdd1243dSDimitry Andric Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); 690bdd1243dSDimitry Andric if (!RealMulI) 691bdd1243dSDimitry Andric return nullptr; 692bdd1243dSDimitry Andric Value *CI = Imag->getOperand(0); 693bdd1243dSDimitry Andric Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); 694bdd1243dSDimitry Andric if (!ImagMulI) 695bdd1243dSDimitry Andric return nullptr; 696bdd1243dSDimitry Andric 697bdd1243dSDimitry Andric if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { 698bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); 699bdd1243dSDimitry Andric return nullptr; 700bdd1243dSDimitry Andric } 701bdd1243dSDimitry Andric 702*06c3fb27SDimitry Andric Value *R0 = RealMulI->getOperand(0); 703*06c3fb27SDimitry Andric Value *R1 = RealMulI->getOperand(1); 704*06c3fb27SDimitry Andric Value *I0 = ImagMulI->getOperand(0); 705*06c3fb27SDimitry Andric Value *I1 = ImagMulI->getOperand(1); 706bdd1243dSDimitry Andric 707*06c3fb27SDimitry Andric Value *CommonOperand; 708*06c3fb27SDimitry Andric Value *UncommonRealOp; 709*06c3fb27SDimitry Andric Value *UncommonImagOp; 710bdd1243dSDimitry Andric 711bdd1243dSDimitry Andric if (R0 == I0 || R0 == I1) { 712bdd1243dSDimitry Andric 
CommonOperand = R0; 713bdd1243dSDimitry Andric UncommonRealOp = R1; 714bdd1243dSDimitry Andric } else if (R1 == I0 || R1 == I1) { 715bdd1243dSDimitry Andric CommonOperand = R1; 716bdd1243dSDimitry Andric UncommonRealOp = R0; 717bdd1243dSDimitry Andric } else { 718bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No equal operand\n"); 719bdd1243dSDimitry Andric return nullptr; 720bdd1243dSDimitry Andric } 721bdd1243dSDimitry Andric 722bdd1243dSDimitry Andric UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 723bdd1243dSDimitry Andric if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 724bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_270) 725bdd1243dSDimitry Andric std::swap(UncommonRealOp, UncommonImagOp); 726bdd1243dSDimitry Andric 727*06c3fb27SDimitry Andric std::pair<Value *, Value *> PartialMatch( 728bdd1243dSDimitry Andric (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 729bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_180) 730bdd1243dSDimitry Andric ? CommonOperand 731bdd1243dSDimitry Andric : nullptr, 732bdd1243dSDimitry Andric (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 733bdd1243dSDimitry Andric Rotation == ComplexDeinterleavingRotation::Rotation_270) 734bdd1243dSDimitry Andric ? 
CommonOperand 735bdd1243dSDimitry Andric : nullptr); 736*06c3fb27SDimitry Andric 737*06c3fb27SDimitry Andric auto *CRInst = dyn_cast<Instruction>(CR); 738*06c3fb27SDimitry Andric auto *CIInst = dyn_cast<Instruction>(CI); 739*06c3fb27SDimitry Andric 740*06c3fb27SDimitry Andric if (!CRInst || !CIInst) { 741*06c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); 742*06c3fb27SDimitry Andric return nullptr; 743*06c3fb27SDimitry Andric } 744*06c3fb27SDimitry Andric 745*06c3fb27SDimitry Andric NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); 746bdd1243dSDimitry Andric if (!CNode) { 747bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No cnode identified\n"); 748bdd1243dSDimitry Andric return nullptr; 749bdd1243dSDimitry Andric } 750bdd1243dSDimitry Andric 751bdd1243dSDimitry Andric NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); 752bdd1243dSDimitry Andric if (!UncommonRes) { 753bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); 754bdd1243dSDimitry Andric return nullptr; 755bdd1243dSDimitry Andric } 756bdd1243dSDimitry Andric 757bdd1243dSDimitry Andric assert(PartialMatch.first && PartialMatch.second); 758bdd1243dSDimitry Andric NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); 759bdd1243dSDimitry Andric if (!CommonRes) { 760bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); 761bdd1243dSDimitry Andric return nullptr; 762bdd1243dSDimitry Andric } 763bdd1243dSDimitry Andric 764bdd1243dSDimitry Andric NodePtr Node = prepareCompositeNode( 765bdd1243dSDimitry Andric ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 766bdd1243dSDimitry Andric Node->Rotation = Rotation; 767bdd1243dSDimitry Andric Node->addOperand(CommonRes); 768bdd1243dSDimitry Andric Node->addOperand(UncommonRes); 769bdd1243dSDimitry Andric Node->addOperand(CNode); 770bdd1243dSDimitry Andric return submitCompositeNode(Node); 771bdd1243dSDimitry 
Andric } 772bdd1243dSDimitry Andric 773bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr 774bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { 775bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); 776bdd1243dSDimitry Andric 777bdd1243dSDimitry Andric // Determine rotation 778bdd1243dSDimitry Andric ComplexDeinterleavingRotation Rotation; 779bdd1243dSDimitry Andric if ((Real->getOpcode() == Instruction::FSub && 780bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::FAdd) || 781bdd1243dSDimitry Andric (Real->getOpcode() == Instruction::Sub && 782bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::Add)) 783bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_90; 784bdd1243dSDimitry Andric else if ((Real->getOpcode() == Instruction::FAdd && 785bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::FSub) || 786bdd1243dSDimitry Andric (Real->getOpcode() == Instruction::Add && 787bdd1243dSDimitry Andric Imag->getOpcode() == Instruction::Sub)) 788bdd1243dSDimitry Andric Rotation = ComplexDeinterleavingRotation::Rotation_270; 789bdd1243dSDimitry Andric else { 790bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n"); 791bdd1243dSDimitry Andric return nullptr; 792bdd1243dSDimitry Andric } 793bdd1243dSDimitry Andric 794bdd1243dSDimitry Andric auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); 795bdd1243dSDimitry Andric auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); 796bdd1243dSDimitry Andric auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); 797bdd1243dSDimitry Andric auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); 798bdd1243dSDimitry Andric 799bdd1243dSDimitry Andric if (!AR || !AI || !BR || !BI) { 800bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); 801bdd1243dSDimitry Andric return nullptr; 802bdd1243dSDimitry Andric } 
803bdd1243dSDimitry Andric 804bdd1243dSDimitry Andric NodePtr ResA = identifyNode(AR, AI); 805bdd1243dSDimitry Andric if (!ResA) { 806bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); 807bdd1243dSDimitry Andric return nullptr; 808bdd1243dSDimitry Andric } 809bdd1243dSDimitry Andric NodePtr ResB = identifyNode(BR, BI); 810bdd1243dSDimitry Andric if (!ResB) { 811bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); 812bdd1243dSDimitry Andric return nullptr; 813bdd1243dSDimitry Andric } 814bdd1243dSDimitry Andric 815bdd1243dSDimitry Andric NodePtr Node = 816bdd1243dSDimitry Andric prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); 817bdd1243dSDimitry Andric Node->Rotation = Rotation; 818bdd1243dSDimitry Andric Node->addOperand(ResA); 819bdd1243dSDimitry Andric Node->addOperand(ResB); 820bdd1243dSDimitry Andric return submitCompositeNode(Node); 821bdd1243dSDimitry Andric } 822bdd1243dSDimitry Andric 823bdd1243dSDimitry Andric static bool isInstructionPairAdd(Instruction *A, Instruction *B) { 824bdd1243dSDimitry Andric unsigned OpcA = A->getOpcode(); 825bdd1243dSDimitry Andric unsigned OpcB = B->getOpcode(); 826bdd1243dSDimitry Andric 827bdd1243dSDimitry Andric return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || 828bdd1243dSDimitry Andric (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || 829bdd1243dSDimitry Andric (OpcA == Instruction::Sub && OpcB == Instruction::Add) || 830bdd1243dSDimitry Andric (OpcA == Instruction::Add && OpcB == Instruction::Sub); 831bdd1243dSDimitry Andric } 832bdd1243dSDimitry Andric 833bdd1243dSDimitry Andric static bool isInstructionPairMul(Instruction *A, Instruction *B) { 834bdd1243dSDimitry Andric auto Pattern = 835bdd1243dSDimitry Andric m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); 836bdd1243dSDimitry Andric 837bdd1243dSDimitry Andric return match(A, Pattern) && match(B, 
Pattern); 838bdd1243dSDimitry Andric } 839bdd1243dSDimitry Andric 840*06c3fb27SDimitry Andric static bool isInstructionPotentiallySymmetric(Instruction *I) { 841*06c3fb27SDimitry Andric switch (I->getOpcode()) { 842*06c3fb27SDimitry Andric case Instruction::FAdd: 843*06c3fb27SDimitry Andric case Instruction::FSub: 844*06c3fb27SDimitry Andric case Instruction::FMul: 845*06c3fb27SDimitry Andric case Instruction::FNeg: 846*06c3fb27SDimitry Andric case Instruction::Add: 847*06c3fb27SDimitry Andric case Instruction::Sub: 848*06c3fb27SDimitry Andric case Instruction::Mul: 849*06c3fb27SDimitry Andric return true; 850*06c3fb27SDimitry Andric default: 851*06c3fb27SDimitry Andric return false; 852*06c3fb27SDimitry Andric } 853*06c3fb27SDimitry Andric } 854*06c3fb27SDimitry Andric 855bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr 856*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, 857*06c3fb27SDimitry Andric Instruction *Imag) { 858*06c3fb27SDimitry Andric if (Real->getOpcode() != Imag->getOpcode()) 859*06c3fb27SDimitry Andric return nullptr; 860*06c3fb27SDimitry Andric 861*06c3fb27SDimitry Andric if (!isInstructionPotentiallySymmetric(Real) || 862*06c3fb27SDimitry Andric !isInstructionPotentiallySymmetric(Imag)) 863*06c3fb27SDimitry Andric return nullptr; 864*06c3fb27SDimitry Andric 865*06c3fb27SDimitry Andric auto *R0 = Real->getOperand(0); 866*06c3fb27SDimitry Andric auto *I0 = Imag->getOperand(0); 867*06c3fb27SDimitry Andric 868*06c3fb27SDimitry Andric NodePtr Op0 = identifyNode(R0, I0); 869*06c3fb27SDimitry Andric NodePtr Op1 = nullptr; 870*06c3fb27SDimitry Andric if (Op0 == nullptr) 871*06c3fb27SDimitry Andric return nullptr; 872*06c3fb27SDimitry Andric 873*06c3fb27SDimitry Andric if (Real->isBinaryOp()) { 874*06c3fb27SDimitry Andric auto *R1 = Real->getOperand(1); 875*06c3fb27SDimitry Andric auto *I1 = Imag->getOperand(1); 876*06c3fb27SDimitry Andric Op1 = identifyNode(R1, I1); 877*06c3fb27SDimitry Andric 
if (Op1 == nullptr) 878*06c3fb27SDimitry Andric return nullptr; 879*06c3fb27SDimitry Andric } 880*06c3fb27SDimitry Andric 881*06c3fb27SDimitry Andric if (isa<FPMathOperator>(Real) && 882*06c3fb27SDimitry Andric Real->getFastMathFlags() != Imag->getFastMathFlags()) 883*06c3fb27SDimitry Andric return nullptr; 884*06c3fb27SDimitry Andric 885*06c3fb27SDimitry Andric auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, 886*06c3fb27SDimitry Andric Real, Imag); 887*06c3fb27SDimitry Andric Node->Opcode = Real->getOpcode(); 888*06c3fb27SDimitry Andric if (isa<FPMathOperator>(Real)) 889*06c3fb27SDimitry Andric Node->Flags = Real->getFastMathFlags(); 890*06c3fb27SDimitry Andric 891*06c3fb27SDimitry Andric Node->addOperand(Op0); 892*06c3fb27SDimitry Andric if (Real->isBinaryOp()) 893*06c3fb27SDimitry Andric Node->addOperand(Op1); 894*06c3fb27SDimitry Andric 895*06c3fb27SDimitry Andric return submitCompositeNode(Node); 896*06c3fb27SDimitry Andric } 897*06c3fb27SDimitry Andric 898*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr 899*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { 900*06c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n"); 901*06c3fb27SDimitry Andric assert(R->getType() == I->getType() && 902*06c3fb27SDimitry Andric "Real and imaginary parts should not have different types"); 903*06c3fb27SDimitry Andric if (NodePtr CN = getContainingComposite(R, I)) { 904bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); 905bdd1243dSDimitry Andric return CN; 906bdd1243dSDimitry Andric } 907bdd1243dSDimitry Andric 908*06c3fb27SDimitry Andric if (NodePtr CN = identifySplat(R, I)) 909*06c3fb27SDimitry Andric return CN; 910*06c3fb27SDimitry Andric 911*06c3fb27SDimitry Andric auto *Real = dyn_cast<Instruction>(R); 912*06c3fb27SDimitry Andric auto *Imag = dyn_cast<Instruction>(I); 913*06c3fb27SDimitry Andric if (!Real || !Imag) 914*06c3fb27SDimitry 
Andric return nullptr; 915*06c3fb27SDimitry Andric 916*06c3fb27SDimitry Andric if (NodePtr CN = identifyDeinterleave(Real, Imag)) 917*06c3fb27SDimitry Andric return CN; 918*06c3fb27SDimitry Andric 919*06c3fb27SDimitry Andric if (NodePtr CN = identifyPHINode(Real, Imag)) 920*06c3fb27SDimitry Andric return CN; 921*06c3fb27SDimitry Andric 922*06c3fb27SDimitry Andric if (NodePtr CN = identifySelectNode(Real, Imag)) 923*06c3fb27SDimitry Andric return CN; 924*06c3fb27SDimitry Andric 925*06c3fb27SDimitry Andric auto *VTy = cast<VectorType>(Real->getType()); 926*06c3fb27SDimitry Andric auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 927*06c3fb27SDimitry Andric 928*06c3fb27SDimitry Andric bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( 929*06c3fb27SDimitry Andric ComplexDeinterleavingOperation::CMulPartial, NewVTy); 930*06c3fb27SDimitry Andric bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( 931*06c3fb27SDimitry Andric ComplexDeinterleavingOperation::CAdd, NewVTy); 932*06c3fb27SDimitry Andric 933*06c3fb27SDimitry Andric if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { 934*06c3fb27SDimitry Andric if (NodePtr CN = identifyPartialMul(Real, Imag)) 935*06c3fb27SDimitry Andric return CN; 936*06c3fb27SDimitry Andric } 937*06c3fb27SDimitry Andric 938*06c3fb27SDimitry Andric if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { 939*06c3fb27SDimitry Andric if (NodePtr CN = identifyAdd(Real, Imag)) 940*06c3fb27SDimitry Andric return CN; 941*06c3fb27SDimitry Andric } 942*06c3fb27SDimitry Andric 943*06c3fb27SDimitry Andric if (HasCMulSupport && HasCAddSupport) { 944*06c3fb27SDimitry Andric if (NodePtr CN = identifyReassocNodes(Real, Imag)) 945*06c3fb27SDimitry Andric return CN; 946*06c3fb27SDimitry Andric } 947*06c3fb27SDimitry Andric 948*06c3fb27SDimitry Andric if (NodePtr CN = identifySymmetricOperation(Real, Imag)) 949*06c3fb27SDimitry Andric return CN; 950*06c3fb27SDimitry Andric 951*06c3fb27SDimitry Andric 
LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); 952*06c3fb27SDimitry Andric return nullptr; 953*06c3fb27SDimitry Andric } 954*06c3fb27SDimitry Andric 955*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr 956*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, 957*06c3fb27SDimitry Andric Instruction *Imag) { 958*06c3fb27SDimitry Andric auto IsOperationSupported = [](unsigned Opcode) -> bool { 959*06c3fb27SDimitry Andric return Opcode == Instruction::FAdd || Opcode == Instruction::FSub || 960*06c3fb27SDimitry Andric Opcode == Instruction::FNeg || Opcode == Instruction::Add || 961*06c3fb27SDimitry Andric Opcode == Instruction::Sub; 962*06c3fb27SDimitry Andric }; 963*06c3fb27SDimitry Andric 964*06c3fb27SDimitry Andric if (!IsOperationSupported(Real->getOpcode()) || 965*06c3fb27SDimitry Andric !IsOperationSupported(Imag->getOpcode())) 966*06c3fb27SDimitry Andric return nullptr; 967*06c3fb27SDimitry Andric 968*06c3fb27SDimitry Andric std::optional<FastMathFlags> Flags; 969*06c3fb27SDimitry Andric if (isa<FPMathOperator>(Real)) { 970*06c3fb27SDimitry Andric if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { 971*06c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are " 972*06c3fb27SDimitry Andric "not identical\n"); 973*06c3fb27SDimitry Andric return nullptr; 974*06c3fb27SDimitry Andric } 975*06c3fb27SDimitry Andric 976*06c3fb27SDimitry Andric Flags = Real->getFastMathFlags(); 977*06c3fb27SDimitry Andric if (!Flags->allowReassoc()) { 978*06c3fb27SDimitry Andric LLVM_DEBUG( 979*06c3fb27SDimitry Andric dbgs() 980*06c3fb27SDimitry Andric << "the 'Reassoc' attribute is missing in the FastMath flags\n"); 981*06c3fb27SDimitry Andric return nullptr; 982*06c3fb27SDimitry Andric } 983*06c3fb27SDimitry Andric } 984*06c3fb27SDimitry Andric 985*06c3fb27SDimitry Andric // Collect multiplications and addend instructions from the given instruction 986*06c3fb27SDimitry Andric 
// while traversing it operands. Additionally, verify that all instructions 987*06c3fb27SDimitry Andric // have the same fast math flags. 988*06c3fb27SDimitry Andric auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, 989*06c3fb27SDimitry Andric std::list<Addend> &Addends) -> bool { 990*06c3fb27SDimitry Andric SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; 991*06c3fb27SDimitry Andric SmallPtrSet<Value *, 8> Visited; 992*06c3fb27SDimitry Andric while (!Worklist.empty()) { 993*06c3fb27SDimitry Andric auto [V, IsPositive] = Worklist.back(); 994*06c3fb27SDimitry Andric Worklist.pop_back(); 995*06c3fb27SDimitry Andric if (!Visited.insert(V).second) 996*06c3fb27SDimitry Andric continue; 997*06c3fb27SDimitry Andric 998*06c3fb27SDimitry Andric Instruction *I = dyn_cast<Instruction>(V); 999*06c3fb27SDimitry Andric if (!I) { 1000*06c3fb27SDimitry Andric Addends.emplace_back(V, IsPositive); 1001*06c3fb27SDimitry Andric continue; 1002*06c3fb27SDimitry Andric } 1003*06c3fb27SDimitry Andric 1004*06c3fb27SDimitry Andric // If an instruction has more than one user, it indicates that it either 1005*06c3fb27SDimitry Andric // has an external user, which will be later checked by the checkNodes 1006*06c3fb27SDimitry Andric // function, or it is a subexpression utilized by multiple expressions. In 1007*06c3fb27SDimitry Andric // the latter case, we will attempt to separately identify the complex 1008*06c3fb27SDimitry Andric // operation from here in order to create a shared 1009*06c3fb27SDimitry Andric // ComplexDeinterleavingCompositeNode. 
1010*06c3fb27SDimitry Andric if (I != Insn && I->getNumUses() > 1) { 1011*06c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); 1012*06c3fb27SDimitry Andric Addends.emplace_back(I, IsPositive); 1013*06c3fb27SDimitry Andric continue; 1014*06c3fb27SDimitry Andric } 1015*06c3fb27SDimitry Andric switch (I->getOpcode()) { 1016*06c3fb27SDimitry Andric case Instruction::FAdd: 1017*06c3fb27SDimitry Andric case Instruction::Add: 1018*06c3fb27SDimitry Andric Worklist.emplace_back(I->getOperand(1), IsPositive); 1019*06c3fb27SDimitry Andric Worklist.emplace_back(I->getOperand(0), IsPositive); 1020*06c3fb27SDimitry Andric break; 1021*06c3fb27SDimitry Andric case Instruction::FSub: 1022*06c3fb27SDimitry Andric Worklist.emplace_back(I->getOperand(1), !IsPositive); 1023*06c3fb27SDimitry Andric Worklist.emplace_back(I->getOperand(0), IsPositive); 1024*06c3fb27SDimitry Andric break; 1025*06c3fb27SDimitry Andric case Instruction::Sub: 1026*06c3fb27SDimitry Andric if (isNeg(I)) { 1027*06c3fb27SDimitry Andric Worklist.emplace_back(getNegOperand(I), !IsPositive); 1028*06c3fb27SDimitry Andric } else { 1029*06c3fb27SDimitry Andric Worklist.emplace_back(I->getOperand(1), !IsPositive); 1030*06c3fb27SDimitry Andric Worklist.emplace_back(I->getOperand(0), IsPositive); 1031*06c3fb27SDimitry Andric } 1032*06c3fb27SDimitry Andric break; 1033*06c3fb27SDimitry Andric case Instruction::FMul: 1034*06c3fb27SDimitry Andric case Instruction::Mul: { 1035*06c3fb27SDimitry Andric Value *A, *B; 1036*06c3fb27SDimitry Andric if (isNeg(I->getOperand(0))) { 1037*06c3fb27SDimitry Andric A = getNegOperand(I->getOperand(0)); 1038*06c3fb27SDimitry Andric IsPositive = !IsPositive; 1039*06c3fb27SDimitry Andric } else { 1040*06c3fb27SDimitry Andric A = I->getOperand(0); 1041*06c3fb27SDimitry Andric } 1042*06c3fb27SDimitry Andric 1043*06c3fb27SDimitry Andric if (isNeg(I->getOperand(1))) { 1044*06c3fb27SDimitry Andric B = getNegOperand(I->getOperand(1)); 1045*06c3fb27SDimitry 
Andric IsPositive = !IsPositive; 1046*06c3fb27SDimitry Andric } else { 1047*06c3fb27SDimitry Andric B = I->getOperand(1); 1048*06c3fb27SDimitry Andric } 1049*06c3fb27SDimitry Andric Muls.push_back(Product{A, B, IsPositive}); 1050*06c3fb27SDimitry Andric break; 1051*06c3fb27SDimitry Andric } 1052*06c3fb27SDimitry Andric case Instruction::FNeg: 1053*06c3fb27SDimitry Andric Worklist.emplace_back(I->getOperand(0), !IsPositive); 1054*06c3fb27SDimitry Andric break; 1055*06c3fb27SDimitry Andric default: 1056*06c3fb27SDimitry Andric Addends.emplace_back(I, IsPositive); 1057*06c3fb27SDimitry Andric continue; 1058*06c3fb27SDimitry Andric } 1059*06c3fb27SDimitry Andric 1060*06c3fb27SDimitry Andric if (Flags && I->getFastMathFlags() != *Flags) { 1061*06c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " 1062*06c3fb27SDimitry Andric "inconsistent with the root instructions' flags: " 1063*06c3fb27SDimitry Andric << *I << "\n"); 1064*06c3fb27SDimitry Andric return false; 1065*06c3fb27SDimitry Andric } 1066*06c3fb27SDimitry Andric } 1067*06c3fb27SDimitry Andric return true; 1068*06c3fb27SDimitry Andric }; 1069*06c3fb27SDimitry Andric 1070*06c3fb27SDimitry Andric std::vector<Product> RealMuls, ImagMuls; 1071*06c3fb27SDimitry Andric std::list<Addend> RealAddends, ImagAddends; 1072*06c3fb27SDimitry Andric if (!Collect(Real, RealMuls, RealAddends) || 1073*06c3fb27SDimitry Andric !Collect(Imag, ImagMuls, ImagAddends)) 1074*06c3fb27SDimitry Andric return nullptr; 1075*06c3fb27SDimitry Andric 1076*06c3fb27SDimitry Andric if (RealAddends.size() != ImagAddends.size()) 1077*06c3fb27SDimitry Andric return nullptr; 1078*06c3fb27SDimitry Andric 1079*06c3fb27SDimitry Andric NodePtr FinalNode; 1080*06c3fb27SDimitry Andric if (!RealMuls.empty() || !ImagMuls.empty()) { 1081*06c3fb27SDimitry Andric // If there are multiplicands, extract positive addend and use it as an 1082*06c3fb27SDimitry Andric // accumulator 1083*06c3fb27SDimitry Andric FinalNode = 
extractPositiveAddend(RealAddends, ImagAddends);
    // identifyMultiplications returns nullptr unless every product could be
    // paired into a partial multiplication.
    FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
    if (!FinalNode)
      return nullptr;
  }

  // Identify and process remaining additions
  if (!RealAddends.empty() || !ImagAddends.empty()) {
    FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
    if (!FinalNode)
      return nullptr;
  }
  assert(FinalNode && "FinalNode can not be nullptr here");
  // Set the Real and Imag fields of the final node and submit it
  FinalNode->Real = Real;
  FinalNode->Imag = Imag;
  submitCompositeNode(FinalNode);
  return FinalNode;
}

// Collects candidate partial multiplications.
//
// For every product in RealMuls, each product in ImagMuls that shares an
// operand with it (the "common" operand) is examined. Call the remaining
// operands A (from the real product) and B (from the imaginary product). The
// pair is tried as a complex node in both orders: identifyNode(A, B) and
// identifyNode(B, A). Every success is appended to PartialMulCandidates as
// {Common, Node, i, j, Swapped}. Here i and j index RealMuls and ImagMuls, and
// Swapped is true for the identifyNode(B, A) order.
//
// Returns false as soon as some product in RealMuls has no candidate at all;
// candidates appended before that point are left in PartialMulCandidates.
// Returns true otherwise. Products in ImagMuls are not required to have a
// candidate here.
bool ComplexDeinterleavingGraph::collectPartialMuls(
    const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
    std::vector<PartialMulCandidate> &PartialMulCandidates) {
  // Helper function to extract a common operand from two products. The real
  // product's multiplicand is checked first, then its multiplier.
  auto FindCommonInstruction = [](const Product &Real,
                                  const Product &Imag) -> Value * {
    if (Real.Multiplicand == Imag.Multiplicand ||
        Real.Multiplicand == Imag.Multiplier)
      return Real.Multiplicand;

    if (Real.Multiplier == Imag.Multiplicand ||
        Real.Multiplier == Imag.Multiplier)
      return Real.Multiplier;

    return nullptr;
  };

  // Iterate over the real and imaginary multiplications to find common
  // operands. If a common operand is found, a partial multiplication candidate
  // is created and added to the candidates vector. The function returns false
  // if no common operands are found for any product.
  for (unsigned i = 0; i < RealMuls.size(); ++i) {
    bool FoundCommon = false;
    for (unsigned j = 0; j < ImagMuls.size(); ++j) {
      auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
      if (!Common)
        continue;

      // The operands left over once the common operand is removed.
      auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
                                                   : RealMuls[i].Multiplicand;
      auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
                                                   : ImagMuls[j].Multiplicand;

      auto Node = identifyNode(A, B);
      if (Node) {
        FoundCommon = true;
        PartialMulCandidates.push_back({Common, Node, i, j, false});
      }

      // Also try the swapped order; the last field records the inversion.
      Node = identifyNode(B, A);
      if (Node) {
        FoundCommon = true;
        PartialMulCandidates.push_back({Common, Node, i, j, true});
      }
    }
    if (!FoundCommon)
      return false;
  }
  return true;
}

// Combines the products in RealMuls and ImagMuls into a chain of CMulPartial
// composite nodes, starting from Accumulator (which may be null).
//
// Returns nullptr when:
//   - the two lists differ in size;
//   - collectPartialMuls fails;
//   - a candidate's common operand was never matched to a complex node;
//   - any product is left unprocessed.
// Otherwise returns the most recently created node, or Accumulator if none
// was created.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyMultiplications(
    std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
    NodePtr Accumulator = nullptr) {
  if (RealMuls.size() != ImagMuls.size())
    return nullptr;

  std::vector<PartialMulCandidate> Info;
  if (!collectPartialMuls(RealMuls, ImagMuls, Info))
    return nullptr;

  // Map to store common instruction to node pointers
  std::map<Value *, NodePtr> CommonToNode;
  std::vector<bool> Processed(Info.size(), false);
  for (unsigned I = 0; I < Info.size(); ++I) {
    if (Processed[I])
      continue;
1171*06c3fb27SDimitry Andric 1172*06c3fb27SDimitry Andric PartialMulCandidate &InfoA = Info[I]; 1173*06c3fb27SDimitry Andric for (unsigned J = I + 1; J < Info.size(); ++J) { 1174*06c3fb27SDimitry Andric if (Processed[J]) 1175*06c3fb27SDimitry Andric continue; 1176*06c3fb27SDimitry Andric 1177*06c3fb27SDimitry Andric PartialMulCandidate &InfoB = Info[J]; 1178*06c3fb27SDimitry Andric auto *InfoReal = &InfoA; 1179*06c3fb27SDimitry Andric auto *InfoImag = &InfoB; 1180*06c3fb27SDimitry Andric 1181*06c3fb27SDimitry Andric auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1182*06c3fb27SDimitry Andric if (!NodeFromCommon) { 1183*06c3fb27SDimitry Andric std::swap(InfoReal, InfoImag); 1184*06c3fb27SDimitry Andric NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1185*06c3fb27SDimitry Andric } 1186*06c3fb27SDimitry Andric if (!NodeFromCommon) 1187*06c3fb27SDimitry Andric continue; 1188*06c3fb27SDimitry Andric 1189*06c3fb27SDimitry Andric CommonToNode[InfoReal->Common] = NodeFromCommon; 1190*06c3fb27SDimitry Andric CommonToNode[InfoImag->Common] = NodeFromCommon; 1191*06c3fb27SDimitry Andric Processed[I] = true; 1192*06c3fb27SDimitry Andric Processed[J] = true; 1193*06c3fb27SDimitry Andric } 1194*06c3fb27SDimitry Andric } 1195*06c3fb27SDimitry Andric 1196*06c3fb27SDimitry Andric std::vector<bool> ProcessedReal(RealMuls.size(), false); 1197*06c3fb27SDimitry Andric std::vector<bool> ProcessedImag(ImagMuls.size(), false); 1198*06c3fb27SDimitry Andric NodePtr Result = Accumulator; 1199*06c3fb27SDimitry Andric for (auto &PMI : Info) { 1200*06c3fb27SDimitry Andric if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) 1201*06c3fb27SDimitry Andric continue; 1202*06c3fb27SDimitry Andric 1203*06c3fb27SDimitry Andric auto It = CommonToNode.find(PMI.Common); 1204*06c3fb27SDimitry Andric // TODO: Process independent complex multiplications. Cases like this: 1205*06c3fb27SDimitry Andric // A.real() * B where both A and B are complex numbers. 
1206*06c3fb27SDimitry Andric if (It == CommonToNode.end()) { 1207*06c3fb27SDimitry Andric LLVM_DEBUG({ 1208*06c3fb27SDimitry Andric dbgs() << "Unprocessed independent partial multiplication:\n"; 1209*06c3fb27SDimitry Andric for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) 1210*06c3fb27SDimitry Andric dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier 1211*06c3fb27SDimitry Andric << " multiplied by " << *Mul->Multiplicand << "\n"; 1212*06c3fb27SDimitry Andric }); 1213*06c3fb27SDimitry Andric return nullptr; 1214*06c3fb27SDimitry Andric } 1215*06c3fb27SDimitry Andric 1216*06c3fb27SDimitry Andric auto &RealMul = RealMuls[PMI.RealIdx]; 1217*06c3fb27SDimitry Andric auto &ImagMul = ImagMuls[PMI.ImagIdx]; 1218*06c3fb27SDimitry Andric 1219*06c3fb27SDimitry Andric auto NodeA = It->second; 1220*06c3fb27SDimitry Andric auto NodeB = PMI.Node; 1221*06c3fb27SDimitry Andric auto IsMultiplicandReal = PMI.Common == NodeA->Real; 1222*06c3fb27SDimitry Andric // The following table illustrates the relationship between multiplications 1223*06c3fb27SDimitry Andric // and rotations. 
If we consider the multiplication (X + iY) * (U + iV), we 1224*06c3fb27SDimitry Andric // can see: 1225*06c3fb27SDimitry Andric // 1226*06c3fb27SDimitry Andric // Rotation | Real | Imag | 1227*06c3fb27SDimitry Andric // ---------+--------+--------+ 1228*06c3fb27SDimitry Andric // 0 | x * u | x * v | 1229*06c3fb27SDimitry Andric // 90 | -y * v | y * u | 1230*06c3fb27SDimitry Andric // 180 | -x * u | -x * v | 1231*06c3fb27SDimitry Andric // 270 | y * v | -y * u | 1232*06c3fb27SDimitry Andric // 1233*06c3fb27SDimitry Andric // Check if the candidate can indeed be represented by partial 1234*06c3fb27SDimitry Andric // multiplication 1235*06c3fb27SDimitry Andric // TODO: Add support for multiplication by complex one 1236*06c3fb27SDimitry Andric if ((IsMultiplicandReal && PMI.IsNodeInverted) || 1237*06c3fb27SDimitry Andric (!IsMultiplicandReal && !PMI.IsNodeInverted)) 1238*06c3fb27SDimitry Andric continue; 1239*06c3fb27SDimitry Andric 1240*06c3fb27SDimitry Andric // Determine the rotation based on the multiplications 1241*06c3fb27SDimitry Andric ComplexDeinterleavingRotation Rotation; 1242*06c3fb27SDimitry Andric if (IsMultiplicandReal) { 1243*06c3fb27SDimitry Andric // Detect 0 and 180 degrees rotation 1244*06c3fb27SDimitry Andric if (RealMul.IsPositive && ImagMul.IsPositive) 1245*06c3fb27SDimitry Andric Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; 1246*06c3fb27SDimitry Andric else if (!RealMul.IsPositive && !ImagMul.IsPositive) 1247*06c3fb27SDimitry Andric Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; 1248*06c3fb27SDimitry Andric else 1249*06c3fb27SDimitry Andric continue; 1250*06c3fb27SDimitry Andric 1251*06c3fb27SDimitry Andric } else { 1252*06c3fb27SDimitry Andric // Detect 90 and 270 degrees rotation 1253*06c3fb27SDimitry Andric if (!RealMul.IsPositive && ImagMul.IsPositive) 1254*06c3fb27SDimitry Andric Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; 1255*06c3fb27SDimitry Andric else if (RealMul.IsPositive && 
!ImagMul.IsPositive) 1256*06c3fb27SDimitry Andric Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; 1257*06c3fb27SDimitry Andric else 1258*06c3fb27SDimitry Andric continue; 1259*06c3fb27SDimitry Andric } 1260*06c3fb27SDimitry Andric 1261*06c3fb27SDimitry Andric LLVM_DEBUG({ 1262*06c3fb27SDimitry Andric dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; 1263*06c3fb27SDimitry Andric dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; 1264*06c3fb27SDimitry Andric dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; 1265*06c3fb27SDimitry Andric dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; 1266*06c3fb27SDimitry Andric dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; 1267*06c3fb27SDimitry Andric dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1268*06c3fb27SDimitry Andric }); 1269*06c3fb27SDimitry Andric 1270*06c3fb27SDimitry Andric NodePtr NodeMul = prepareCompositeNode( 1271*06c3fb27SDimitry Andric ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); 1272*06c3fb27SDimitry Andric NodeMul->Rotation = Rotation; 1273*06c3fb27SDimitry Andric NodeMul->addOperand(NodeA); 1274*06c3fb27SDimitry Andric NodeMul->addOperand(NodeB); 1275*06c3fb27SDimitry Andric if (Result) 1276*06c3fb27SDimitry Andric NodeMul->addOperand(Result); 1277*06c3fb27SDimitry Andric submitCompositeNode(NodeMul); 1278*06c3fb27SDimitry Andric Result = NodeMul; 1279*06c3fb27SDimitry Andric ProcessedReal[PMI.RealIdx] = true; 1280*06c3fb27SDimitry Andric ProcessedImag[PMI.ImagIdx] = true; 1281*06c3fb27SDimitry Andric } 1282*06c3fb27SDimitry Andric 1283*06c3fb27SDimitry Andric // Ensure all products have been processed, if not return nullptr. 
1284*06c3fb27SDimitry Andric if (!all_of(ProcessedReal, [](bool V) { return V; }) || 1285*06c3fb27SDimitry Andric !all_of(ProcessedImag, [](bool V) { return V; })) { 1286*06c3fb27SDimitry Andric 1287*06c3fb27SDimitry Andric // Dump debug information about which partial multiplications are not 1288*06c3fb27SDimitry Andric // processed. 1289*06c3fb27SDimitry Andric LLVM_DEBUG({ 1290*06c3fb27SDimitry Andric dbgs() << "Unprocessed products (Real):\n"; 1291*06c3fb27SDimitry Andric for (size_t i = 0; i < ProcessedReal.size(); ++i) { 1292*06c3fb27SDimitry Andric if (!ProcessedReal[i]) 1293*06c3fb27SDimitry Andric dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") 1294*06c3fb27SDimitry Andric << *RealMuls[i].Multiplier << " multiplied by " 1295*06c3fb27SDimitry Andric << *RealMuls[i].Multiplicand << "\n"; 1296*06c3fb27SDimitry Andric } 1297*06c3fb27SDimitry Andric dbgs() << "Unprocessed products (Imag):\n"; 1298*06c3fb27SDimitry Andric for (size_t i = 0; i < ProcessedImag.size(); ++i) { 1299*06c3fb27SDimitry Andric if (!ProcessedImag[i]) 1300*06c3fb27SDimitry Andric dbgs().indent(4) << (ImagMuls[i].IsPositive ? 
"+" : "-") 1301*06c3fb27SDimitry Andric << *ImagMuls[i].Multiplier << " multiplied by " 1302*06c3fb27SDimitry Andric << *ImagMuls[i].Multiplicand << "\n"; 1303*06c3fb27SDimitry Andric } 1304*06c3fb27SDimitry Andric }); 1305*06c3fb27SDimitry Andric return nullptr; 1306*06c3fb27SDimitry Andric } 1307*06c3fb27SDimitry Andric 1308*06c3fb27SDimitry Andric return Result; 1309*06c3fb27SDimitry Andric } 1310*06c3fb27SDimitry Andric 1311*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr 1312*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyAdditions( 1313*06c3fb27SDimitry Andric std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends, 1314*06c3fb27SDimitry Andric std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) { 1315*06c3fb27SDimitry Andric if (RealAddends.size() != ImagAddends.size()) 1316*06c3fb27SDimitry Andric return nullptr; 1317*06c3fb27SDimitry Andric 1318*06c3fb27SDimitry Andric NodePtr Result; 1319*06c3fb27SDimitry Andric // If we have accumulator use it as first addend 1320*06c3fb27SDimitry Andric if (Accumulator) 1321*06c3fb27SDimitry Andric Result = Accumulator; 1322*06c3fb27SDimitry Andric // Otherwise find an element with both positive real and imaginary parts. 
// identifyAdditions (continued): choose the starting node, then pair addends.
  else
    Result = extractPositiveAddend(RealAddends, ImagAddends);

  if (!Result)
    return nullptr;

  // Repeatedly take the first remaining real addend and search for an
  // imaginary addend it can be combined with.
  while (!RealAddends.empty()) {
    auto ItR = RealAddends.begin();
    auto [R, IsPositiveR] = *ItR;

    bool FoundImag = false;
    for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
      auto [I, IsPositiveI] = *ItI;
      // Signs (real, imag) -> rotation:
      // (+,+) 0, (-,+) 90, (-,-) 180, (+,-) 270.
      ComplexDeinterleavingRotation Rotation;
      if (IsPositiveR && IsPositiveI)
        Rotation = ComplexDeinterleavingRotation::Rotation_0;
      else if (!IsPositiveR && IsPositiveI)
        Rotation = ComplexDeinterleavingRotation::Rotation_90;
      else if (!IsPositiveR && !IsPositiveI)
        Rotation = ComplexDeinterleavingRotation::Rotation_180;
      else
        Rotation = ComplexDeinterleavingRotation::Rotation_270;

      // For 90/270 the operands are handed to identifyNode in swapped
      // (imag, real) order.
      NodePtr AddNode;
      if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
          Rotation == ComplexDeinterleavingRotation::Rotation_180) {
        AddNode = identifyNode(R, I);
      } else {
        AddNode = identifyNode(I, R);
      }
      if (AddNode) {
        LLVM_DEBUG({
          dbgs() << "Identified addition:\n";
          dbgs().indent(4) << "X: " << *R << "\n";
          dbgs().indent(4) << "Y: " << *I << "\n";
          dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
        });

        NodePtr TmpNode;
        if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
          // Both parts positive: plain element-wise addition.
          TmpNode = prepareCompositeNode(
              ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
          if (Flags) {
            TmpNode->Opcode = Instruction::FAdd;
            TmpNode->Flags = *Flags;
          } else {
            TmpNode->Opcode = Instruction::Add;
          }
        } else if (Rotation ==
                   llvm::ComplexDeinterleavingRotation::Rotation_180) {
          // Both parts negative: plain element-wise subtraction.
          TmpNode = prepareCompositeNode(
              ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
          if (Flags) {
            TmpNode->Opcode = Instruction::FSub;
            TmpNode->Flags = *Flags;
          } else {
            TmpNode->Opcode = Instruction::Sub;
          }
        } else {
          // Mixed signs: a complex add with a 90 or 270 degree rotation.
          TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
                                         nullptr, nullptr);
          TmpNode->Rotation = Rotation;
        }

        TmpNode->addOperand(Result);
        TmpNode->addOperand(AddNode);
        submitCompositeNode(TmpNode);
        Result = TmpNode;
        RealAddends.erase(ItR);
Andric ImagAddends.erase(ItI); 1393*06c3fb27SDimitry Andric FoundImag = true; 1394*06c3fb27SDimitry Andric break; 1395*06c3fb27SDimitry Andric } 1396*06c3fb27SDimitry Andric } 1397*06c3fb27SDimitry Andric if (!FoundImag) 1398*06c3fb27SDimitry Andric return nullptr; 1399*06c3fb27SDimitry Andric } 1400*06c3fb27SDimitry Andric return Result; 1401*06c3fb27SDimitry Andric } 1402*06c3fb27SDimitry Andric 1403*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr 1404*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::extractPositiveAddend( 1405*06c3fb27SDimitry Andric std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { 1406*06c3fb27SDimitry Andric for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { 1407*06c3fb27SDimitry Andric for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1408*06c3fb27SDimitry Andric auto [R, IsPositiveR] = *ItR; 1409*06c3fb27SDimitry Andric auto [I, IsPositiveI] = *ItI; 1410*06c3fb27SDimitry Andric if (IsPositiveR && IsPositiveI) { 1411*06c3fb27SDimitry Andric auto Result = identifyNode(R, I); 1412*06c3fb27SDimitry Andric if (Result) { 1413*06c3fb27SDimitry Andric RealAddends.erase(ItR); 1414*06c3fb27SDimitry Andric ImagAddends.erase(ItI); 1415*06c3fb27SDimitry Andric return Result; 1416*06c3fb27SDimitry Andric } 1417*06c3fb27SDimitry Andric } 1418*06c3fb27SDimitry Andric } 1419*06c3fb27SDimitry Andric } 1420*06c3fb27SDimitry Andric return nullptr; 1421*06c3fb27SDimitry Andric } 1422*06c3fb27SDimitry Andric 1423*06c3fb27SDimitry Andric bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 1424*06c3fb27SDimitry Andric // This potential root instruction might already have been recognized as 1425*06c3fb27SDimitry Andric // reduction. 
Because RootToNode maps both Real and Imaginary parts to 1426*06c3fb27SDimitry Andric // CompositeNode we should choose only one either Real or Imag instruction to 1427*06c3fb27SDimitry Andric // use as an anchor for generating complex instruction. 1428*06c3fb27SDimitry Andric auto It = RootToNode.find(RootI); 1429*06c3fb27SDimitry Andric if (It != RootToNode.end() && It->second->Real == RootI) { 1430*06c3fb27SDimitry Andric OrderedRoots.push_back(RootI); 1431*06c3fb27SDimitry Andric return true; 1432*06c3fb27SDimitry Andric } 1433*06c3fb27SDimitry Andric 1434*06c3fb27SDimitry Andric auto RootNode = identifyRoot(RootI); 1435*06c3fb27SDimitry Andric if (!RootNode) 1436*06c3fb27SDimitry Andric return false; 1437*06c3fb27SDimitry Andric 1438*06c3fb27SDimitry Andric LLVM_DEBUG({ 1439*06c3fb27SDimitry Andric Function *F = RootI->getFunction(); 1440*06c3fb27SDimitry Andric BasicBlock *B = RootI->getParent(); 1441*06c3fb27SDimitry Andric dbgs() << "Complex deinterleaving graph for " << F->getName() 1442*06c3fb27SDimitry Andric << "::" << B->getName() << ".\n"; 1443*06c3fb27SDimitry Andric dump(dbgs()); 1444*06c3fb27SDimitry Andric dbgs() << "\n"; 1445*06c3fb27SDimitry Andric }); 1446*06c3fb27SDimitry Andric RootToNode[RootI] = RootNode; 1447*06c3fb27SDimitry Andric OrderedRoots.push_back(RootI); 1448*06c3fb27SDimitry Andric return true; 1449*06c3fb27SDimitry Andric } 1450*06c3fb27SDimitry Andric 1451*06c3fb27SDimitry Andric bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { 1452*06c3fb27SDimitry Andric bool FoundPotentialReduction = false; 1453*06c3fb27SDimitry Andric 1454*06c3fb27SDimitry Andric auto *Br = dyn_cast<BranchInst>(B->getTerminator()); 1455*06c3fb27SDimitry Andric if (!Br || Br->getNumSuccessors() != 2) 1456*06c3fb27SDimitry Andric return false; 1457*06c3fb27SDimitry Andric 1458*06c3fb27SDimitry Andric // Identify simple one-block loop 1459*06c3fb27SDimitry Andric if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) 
1460*06c3fb27SDimitry Andric return false; 1461*06c3fb27SDimitry Andric 1462*06c3fb27SDimitry Andric SmallVector<PHINode *> PHIs; 1463*06c3fb27SDimitry Andric for (auto &PHI : B->phis()) { 1464*06c3fb27SDimitry Andric if (PHI.getNumIncomingValues() != 2) 1465*06c3fb27SDimitry Andric continue; 1466*06c3fb27SDimitry Andric 1467*06c3fb27SDimitry Andric if (!PHI.getType()->isVectorTy()) 1468*06c3fb27SDimitry Andric continue; 1469*06c3fb27SDimitry Andric 1470*06c3fb27SDimitry Andric auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); 1471*06c3fb27SDimitry Andric if (!ReductionOp) 1472*06c3fb27SDimitry Andric continue; 1473*06c3fb27SDimitry Andric 1474*06c3fb27SDimitry Andric // Check if final instruction is reduced outside of current block 1475*06c3fb27SDimitry Andric Instruction *FinalReduction = nullptr; 1476*06c3fb27SDimitry Andric auto NumUsers = 0u; 1477*06c3fb27SDimitry Andric for (auto *U : ReductionOp->users()) { 1478*06c3fb27SDimitry Andric ++NumUsers; 1479*06c3fb27SDimitry Andric if (U == &PHI) 1480*06c3fb27SDimitry Andric continue; 1481*06c3fb27SDimitry Andric FinalReduction = dyn_cast<Instruction>(U); 1482*06c3fb27SDimitry Andric } 1483*06c3fb27SDimitry Andric 1484*06c3fb27SDimitry Andric if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B || 1485*06c3fb27SDimitry Andric isa<PHINode>(FinalReduction)) 1486*06c3fb27SDimitry Andric continue; 1487*06c3fb27SDimitry Andric 1488*06c3fb27SDimitry Andric ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; 1489*06c3fb27SDimitry Andric BackEdge = B; 1490*06c3fb27SDimitry Andric auto BackEdgeIdx = PHI.getBasicBlockIndex(B); 1491*06c3fb27SDimitry Andric auto IncomingIdx = BackEdgeIdx == 0 ? 
1 : 0; 1492*06c3fb27SDimitry Andric Incoming = PHI.getIncomingBlock(IncomingIdx); 1493*06c3fb27SDimitry Andric FoundPotentialReduction = true; 1494*06c3fb27SDimitry Andric 1495*06c3fb27SDimitry Andric // If the initial value of PHINode is an Instruction, consider it a leaf 1496*06c3fb27SDimitry Andric // value of a complex deinterleaving graph. 1497*06c3fb27SDimitry Andric if (auto *InitPHI = 1498*06c3fb27SDimitry Andric dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) 1499*06c3fb27SDimitry Andric FinalInstructions.insert(InitPHI); 1500*06c3fb27SDimitry Andric } 1501*06c3fb27SDimitry Andric return FoundPotentialReduction; 1502*06c3fb27SDimitry Andric } 1503*06c3fb27SDimitry Andric 1504*06c3fb27SDimitry Andric void ComplexDeinterleavingGraph::identifyReductionNodes() { 1505*06c3fb27SDimitry Andric SmallVector<bool> Processed(ReductionInfo.size(), false); 1506*06c3fb27SDimitry Andric SmallVector<Instruction *> OperationInstruction; 1507*06c3fb27SDimitry Andric for (auto &P : ReductionInfo) 1508*06c3fb27SDimitry Andric OperationInstruction.push_back(P.first); 1509*06c3fb27SDimitry Andric 1510*06c3fb27SDimitry Andric // Identify a complex computation by evaluating two reduction operations that 1511*06c3fb27SDimitry Andric // potentially could be involved 1512*06c3fb27SDimitry Andric for (size_t i = 0; i < OperationInstruction.size(); ++i) { 1513*06c3fb27SDimitry Andric if (Processed[i]) 1514*06c3fb27SDimitry Andric continue; 1515*06c3fb27SDimitry Andric for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { 1516*06c3fb27SDimitry Andric if (Processed[j]) 1517*06c3fb27SDimitry Andric continue; 1518*06c3fb27SDimitry Andric 1519*06c3fb27SDimitry Andric auto *Real = OperationInstruction[i]; 1520*06c3fb27SDimitry Andric auto *Imag = OperationInstruction[j]; 1521*06c3fb27SDimitry Andric if (Real->getType() != Imag->getType()) 1522*06c3fb27SDimitry Andric continue; 1523*06c3fb27SDimitry Andric 1524*06c3fb27SDimitry Andric RealPHI = 
ReductionInfo[Real].first; 1525*06c3fb27SDimitry Andric ImagPHI = ReductionInfo[Imag].first; 1526*06c3fb27SDimitry Andric PHIsFound = false; 1527*06c3fb27SDimitry Andric auto Node = identifyNode(Real, Imag); 1528*06c3fb27SDimitry Andric if (!Node) { 1529*06c3fb27SDimitry Andric std::swap(Real, Imag); 1530*06c3fb27SDimitry Andric std::swap(RealPHI, ImagPHI); 1531*06c3fb27SDimitry Andric Node = identifyNode(Real, Imag); 1532*06c3fb27SDimitry Andric } 1533*06c3fb27SDimitry Andric 1534*06c3fb27SDimitry Andric // If a node is identified and reduction PHINode is used in the chain of 1535*06c3fb27SDimitry Andric // operations, mark its operation instructions as used to prevent 1536*06c3fb27SDimitry Andric // re-identification and attach the node to the real part 1537*06c3fb27SDimitry Andric if (Node && PHIsFound) { 1538*06c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " 1539*06c3fb27SDimitry Andric << *Real << " / " << *Imag << "\n"); 1540*06c3fb27SDimitry Andric Processed[i] = true; 1541*06c3fb27SDimitry Andric Processed[j] = true; 1542*06c3fb27SDimitry Andric auto RootNode = prepareCompositeNode( 1543*06c3fb27SDimitry Andric ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); 1544*06c3fb27SDimitry Andric RootNode->addOperand(Node); 1545*06c3fb27SDimitry Andric RootToNode[Real] = RootNode; 1546*06c3fb27SDimitry Andric RootToNode[Imag] = RootNode; 1547*06c3fb27SDimitry Andric submitCompositeNode(RootNode); 1548*06c3fb27SDimitry Andric break; 1549*06c3fb27SDimitry Andric } 1550*06c3fb27SDimitry Andric } 1551*06c3fb27SDimitry Andric } 1552*06c3fb27SDimitry Andric 1553*06c3fb27SDimitry Andric RealPHI = nullptr; 1554*06c3fb27SDimitry Andric ImagPHI = nullptr; 1555*06c3fb27SDimitry Andric } 1556*06c3fb27SDimitry Andric 1557*06c3fb27SDimitry Andric bool ComplexDeinterleavingGraph::checkNodes() { 1558*06c3fb27SDimitry Andric // Collect all instructions from roots to leaves 1559*06c3fb27SDimitry Andric 
SmallPtrSet<Instruction *, 16> AllInstructions; 1560*06c3fb27SDimitry Andric SmallVector<Instruction *, 8> Worklist; 1561*06c3fb27SDimitry Andric for (auto &Pair : RootToNode) 1562*06c3fb27SDimitry Andric Worklist.push_back(Pair.first); 1563*06c3fb27SDimitry Andric 1564*06c3fb27SDimitry Andric // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG 1565*06c3fb27SDimitry Andric // chains 1566*06c3fb27SDimitry Andric while (!Worklist.empty()) { 1567*06c3fb27SDimitry Andric auto *I = Worklist.back(); 1568*06c3fb27SDimitry Andric Worklist.pop_back(); 1569*06c3fb27SDimitry Andric 1570*06c3fb27SDimitry Andric if (!AllInstructions.insert(I).second) 1571*06c3fb27SDimitry Andric continue; 1572*06c3fb27SDimitry Andric 1573*06c3fb27SDimitry Andric for (Value *Op : I->operands()) { 1574*06c3fb27SDimitry Andric if (auto *OpI = dyn_cast<Instruction>(Op)) { 1575*06c3fb27SDimitry Andric if (!FinalInstructions.count(I)) 1576*06c3fb27SDimitry Andric Worklist.emplace_back(OpI); 1577*06c3fb27SDimitry Andric } 1578*06c3fb27SDimitry Andric } 1579*06c3fb27SDimitry Andric } 1580*06c3fb27SDimitry Andric 1581*06c3fb27SDimitry Andric // Find instructions that have users outside of chain 1582*06c3fb27SDimitry Andric SmallVector<Instruction *, 2> OuterInstructions; 1583*06c3fb27SDimitry Andric for (auto *I : AllInstructions) { 1584*06c3fb27SDimitry Andric // Skip root nodes 1585*06c3fb27SDimitry Andric if (RootToNode.count(I)) 1586*06c3fb27SDimitry Andric continue; 1587*06c3fb27SDimitry Andric 1588*06c3fb27SDimitry Andric for (User *U : I->users()) { 1589*06c3fb27SDimitry Andric if (AllInstructions.count(cast<Instruction>(U))) 1590*06c3fb27SDimitry Andric continue; 1591*06c3fb27SDimitry Andric 1592*06c3fb27SDimitry Andric // Found an instruction that is not used by XCMLA/XCADD chain 1593*06c3fb27SDimitry Andric Worklist.emplace_back(I); 1594*06c3fb27SDimitry Andric break; 1595*06c3fb27SDimitry Andric } 1596*06c3fb27SDimitry Andric } 1597*06c3fb27SDimitry Andric 
1598*06c3fb27SDimitry Andric // If any instructions are found to be used outside, find and remove roots 1599*06c3fb27SDimitry Andric // that somehow connect to those instructions. 1600*06c3fb27SDimitry Andric SmallPtrSet<Instruction *, 16> Visited; 1601*06c3fb27SDimitry Andric while (!Worklist.empty()) { 1602*06c3fb27SDimitry Andric auto *I = Worklist.back(); 1603*06c3fb27SDimitry Andric Worklist.pop_back(); 1604*06c3fb27SDimitry Andric if (!Visited.insert(I).second) 1605*06c3fb27SDimitry Andric continue; 1606*06c3fb27SDimitry Andric 1607*06c3fb27SDimitry Andric // Found an impacted root node. Removing it from the nodes to be 1608*06c3fb27SDimitry Andric // deinterleaved 1609*06c3fb27SDimitry Andric if (RootToNode.count(I)) { 1610*06c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "Instruction " << *I 1611*06c3fb27SDimitry Andric << " could be deinterleaved but its chain of complex " 1612*06c3fb27SDimitry Andric "operations have an outside user\n"); 1613*06c3fb27SDimitry Andric RootToNode.erase(I); 1614*06c3fb27SDimitry Andric } 1615*06c3fb27SDimitry Andric 1616*06c3fb27SDimitry Andric if (!AllInstructions.count(I) || FinalInstructions.count(I)) 1617*06c3fb27SDimitry Andric continue; 1618*06c3fb27SDimitry Andric 1619*06c3fb27SDimitry Andric for (User *U : I->users()) 1620*06c3fb27SDimitry Andric Worklist.emplace_back(cast<Instruction>(U)); 1621*06c3fb27SDimitry Andric 1622*06c3fb27SDimitry Andric for (Value *Op : I->operands()) { 1623*06c3fb27SDimitry Andric if (auto *OpI = dyn_cast<Instruction>(Op)) 1624*06c3fb27SDimitry Andric Worklist.emplace_back(OpI); 1625*06c3fb27SDimitry Andric } 1626*06c3fb27SDimitry Andric } 1627*06c3fb27SDimitry Andric return !RootToNode.empty(); 1628*06c3fb27SDimitry Andric } 1629*06c3fb27SDimitry Andric 1630*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr 1631*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { 1632*06c3fb27SDimitry Andric if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { 
1633*06c3fb27SDimitry Andric if (Intrinsic->getIntrinsicID() != 1634*06c3fb27SDimitry Andric Intrinsic::experimental_vector_interleave2) 1635*06c3fb27SDimitry Andric return nullptr; 1636*06c3fb27SDimitry Andric 1637*06c3fb27SDimitry Andric auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); 1638*06c3fb27SDimitry Andric auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); 1639*06c3fb27SDimitry Andric if (!Real || !Imag) 1640*06c3fb27SDimitry Andric return nullptr; 1641*06c3fb27SDimitry Andric 1642*06c3fb27SDimitry Andric return identifyNode(Real, Imag); 1643*06c3fb27SDimitry Andric } 1644*06c3fb27SDimitry Andric 1645*06c3fb27SDimitry Andric auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); 1646*06c3fb27SDimitry Andric if (!SVI) 1647*06c3fb27SDimitry Andric return nullptr; 1648*06c3fb27SDimitry Andric 1649*06c3fb27SDimitry Andric // Look for a shufflevector that takes separate vectors of the real and 1650*06c3fb27SDimitry Andric // imaginary components and recombines them into a single vector. 
1651*06c3fb27SDimitry Andric if (!isInterleavingMask(SVI->getShuffleMask())) 1652*06c3fb27SDimitry Andric return nullptr; 1653*06c3fb27SDimitry Andric 1654*06c3fb27SDimitry Andric Instruction *Real; 1655*06c3fb27SDimitry Andric Instruction *Imag; 1656*06c3fb27SDimitry Andric if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) 1657*06c3fb27SDimitry Andric return nullptr; 1658*06c3fb27SDimitry Andric 1659*06c3fb27SDimitry Andric return identifyNode(Real, Imag); 1660*06c3fb27SDimitry Andric } 1661*06c3fb27SDimitry Andric 1662*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr 1663*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, 1664*06c3fb27SDimitry Andric Instruction *Imag) { 1665*06c3fb27SDimitry Andric Instruction *I = nullptr; 1666*06c3fb27SDimitry Andric Value *FinalValue = nullptr; 1667*06c3fb27SDimitry Andric if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && 1668*06c3fb27SDimitry Andric match(Imag, m_ExtractValue<1>(m_Specific(I))) && 1669*06c3fb27SDimitry Andric match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>( 1670*06c3fb27SDimitry Andric m_Value(FinalValue)))) { 1671*06c3fb27SDimitry Andric NodePtr PlaceholderNode = prepareCompositeNode( 1672*06c3fb27SDimitry Andric llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); 1673*06c3fb27SDimitry Andric PlaceholderNode->ReplacementNode = FinalValue; 1674*06c3fb27SDimitry Andric FinalInstructions.insert(Real); 1675*06c3fb27SDimitry Andric FinalInstructions.insert(Imag); 1676*06c3fb27SDimitry Andric return submitCompositeNode(PlaceholderNode); 1677*06c3fb27SDimitry Andric } 1678*06c3fb27SDimitry Andric 1679bdd1243dSDimitry Andric auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); 1680bdd1243dSDimitry Andric auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); 1681*06c3fb27SDimitry Andric if (!RealShuffle || !ImagShuffle) { 1682*06c3fb27SDimitry Andric if (RealShuffle || ImagShuffle) 1683*06c3fb27SDimitry 
Andric LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); 1684*06c3fb27SDimitry Andric return nullptr; 1685*06c3fb27SDimitry Andric } 1686*06c3fb27SDimitry Andric 1687bdd1243dSDimitry Andric Value *RealOp1 = RealShuffle->getOperand(1); 1688bdd1243dSDimitry Andric if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { 1689bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); 1690bdd1243dSDimitry Andric return nullptr; 1691bdd1243dSDimitry Andric } 1692bdd1243dSDimitry Andric Value *ImagOp1 = ImagShuffle->getOperand(1); 1693bdd1243dSDimitry Andric if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { 1694bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); 1695bdd1243dSDimitry Andric return nullptr; 1696bdd1243dSDimitry Andric } 1697bdd1243dSDimitry Andric 1698bdd1243dSDimitry Andric Value *RealOp0 = RealShuffle->getOperand(0); 1699bdd1243dSDimitry Andric Value *ImagOp0 = ImagShuffle->getOperand(0); 1700bdd1243dSDimitry Andric 1701bdd1243dSDimitry Andric if (RealOp0 != ImagOp0) { 1702bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); 1703bdd1243dSDimitry Andric return nullptr; 1704bdd1243dSDimitry Andric } 1705bdd1243dSDimitry Andric 1706bdd1243dSDimitry Andric ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); 1707bdd1243dSDimitry Andric ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); 1708bdd1243dSDimitry Andric if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { 1709bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); 1710bdd1243dSDimitry Andric return nullptr; 1711bdd1243dSDimitry Andric } 1712bdd1243dSDimitry Andric 1713bdd1243dSDimitry Andric if (RealMask[0] != 0 || ImagMask[0] != 1) { 1714bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); 1715bdd1243dSDimitry Andric return nullptr; 
1716bdd1243dSDimitry Andric } 1717bdd1243dSDimitry Andric 1718bdd1243dSDimitry Andric // Type checking, the shuffle type should be a vector type of the same 1719bdd1243dSDimitry Andric // scalar type, but half the size 1720bdd1243dSDimitry Andric auto CheckType = [&](ShuffleVectorInst *Shuffle) { 1721bdd1243dSDimitry Andric Value *Op = Shuffle->getOperand(0); 1722bdd1243dSDimitry Andric auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); 1723bdd1243dSDimitry Andric auto *OpTy = cast<FixedVectorType>(Op->getType()); 1724bdd1243dSDimitry Andric 1725bdd1243dSDimitry Andric if (OpTy->getScalarType() != ShuffleTy->getScalarType()) 1726bdd1243dSDimitry Andric return false; 1727bdd1243dSDimitry Andric if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) 1728bdd1243dSDimitry Andric return false; 1729bdd1243dSDimitry Andric 1730bdd1243dSDimitry Andric return true; 1731bdd1243dSDimitry Andric }; 1732bdd1243dSDimitry Andric 1733bdd1243dSDimitry Andric auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { 1734bdd1243dSDimitry Andric if (!CheckType(Shuffle)) 1735bdd1243dSDimitry Andric return false; 1736bdd1243dSDimitry Andric 1737bdd1243dSDimitry Andric ArrayRef<int> Mask = Shuffle->getShuffleMask(); 1738bdd1243dSDimitry Andric int Last = *Mask.rbegin(); 1739bdd1243dSDimitry Andric 1740bdd1243dSDimitry Andric Value *Op = Shuffle->getOperand(0); 1741bdd1243dSDimitry Andric auto *OpTy = cast<FixedVectorType>(Op->getType()); 1742bdd1243dSDimitry Andric int NumElements = OpTy->getNumElements(); 1743bdd1243dSDimitry Andric 1744bdd1243dSDimitry Andric // Ensure that the deinterleaving shuffle only pulls from the first 1745bdd1243dSDimitry Andric // shuffle operand. 
1746bdd1243dSDimitry Andric return Last < NumElements; 1747bdd1243dSDimitry Andric }; 1748bdd1243dSDimitry Andric 1749bdd1243dSDimitry Andric if (RealShuffle->getType() != ImagShuffle->getType()) { 1750bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); 1751bdd1243dSDimitry Andric return nullptr; 1752bdd1243dSDimitry Andric } 1753bdd1243dSDimitry Andric if (!CheckDeinterleavingShuffle(RealShuffle)) { 1754bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); 1755bdd1243dSDimitry Andric return nullptr; 1756bdd1243dSDimitry Andric } 1757bdd1243dSDimitry Andric if (!CheckDeinterleavingShuffle(ImagShuffle)) { 1758bdd1243dSDimitry Andric LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); 1759bdd1243dSDimitry Andric return nullptr; 1760bdd1243dSDimitry Andric } 1761bdd1243dSDimitry Andric 1762bdd1243dSDimitry Andric NodePtr PlaceholderNode = 1763*06c3fb27SDimitry Andric prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, 1764bdd1243dSDimitry Andric RealShuffle, ImagShuffle); 1765bdd1243dSDimitry Andric PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); 1766*06c3fb27SDimitry Andric FinalInstructions.insert(RealShuffle); 1767*06c3fb27SDimitry Andric FinalInstructions.insert(ImagShuffle); 1768bdd1243dSDimitry Andric return submitCompositeNode(PlaceholderNode); 1769bdd1243dSDimitry Andric } 1770bdd1243dSDimitry Andric 1771*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr 1772*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) { 1773*06c3fb27SDimitry Andric auto IsSplat = [](Value *V) -> bool { 1774*06c3fb27SDimitry Andric // Fixed-width vector with constants 1775*06c3fb27SDimitry Andric if (isa<ConstantDataVector>(V)) 1776*06c3fb27SDimitry Andric return true; 1777bdd1243dSDimitry Andric 1778*06c3fb27SDimitry Andric VectorType *VTy; 1779*06c3fb27SDimitry Andric ArrayRef<int> Mask; 1780*06c3fb27SDimitry Andric // Splats are 
represented differently depending on whether the repeated 1781*06c3fb27SDimitry Andric // value is a constant or an Instruction 1782*06c3fb27SDimitry Andric if (auto *Const = dyn_cast<ConstantExpr>(V)) { 1783*06c3fb27SDimitry Andric if (Const->getOpcode() != Instruction::ShuffleVector) 1784bdd1243dSDimitry Andric return false; 1785*06c3fb27SDimitry Andric VTy = cast<VectorType>(Const->getType()); 1786*06c3fb27SDimitry Andric Mask = Const->getShuffleMask(); 1787*06c3fb27SDimitry Andric } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) { 1788*06c3fb27SDimitry Andric VTy = Shuf->getType(); 1789*06c3fb27SDimitry Andric Mask = Shuf->getShuffleMask(); 1790*06c3fb27SDimitry Andric } else { 1791bdd1243dSDimitry Andric return false; 1792bdd1243dSDimitry Andric } 1793*06c3fb27SDimitry Andric 1794*06c3fb27SDimitry Andric // When the data type is <1 x Type>, it's not possible to differentiate 1795*06c3fb27SDimitry Andric // between the ComplexDeinterleaving::Deinterleave and 1796*06c3fb27SDimitry Andric // ComplexDeinterleaving::Splat operations. 
1797*06c3fb27SDimitry Andric if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1) 1798*06c3fb27SDimitry Andric return false; 1799*06c3fb27SDimitry Andric 1800*06c3fb27SDimitry Andric return all_equal(Mask) && Mask[0] == 0; 1801*06c3fb27SDimitry Andric }; 1802*06c3fb27SDimitry Andric 1803*06c3fb27SDimitry Andric if (!IsSplat(R) || !IsSplat(I)) 1804*06c3fb27SDimitry Andric return nullptr; 1805*06c3fb27SDimitry Andric 1806*06c3fb27SDimitry Andric auto *Real = dyn_cast<Instruction>(R); 1807*06c3fb27SDimitry Andric auto *Imag = dyn_cast<Instruction>(I); 1808*06c3fb27SDimitry Andric if ((!Real && Imag) || (Real && !Imag)) 1809*06c3fb27SDimitry Andric return nullptr; 1810*06c3fb27SDimitry Andric 1811*06c3fb27SDimitry Andric if (Real && Imag) { 1812*06c3fb27SDimitry Andric // Non-constant splats should be in the same basic block 1813*06c3fb27SDimitry Andric if (Real->getParent() != Imag->getParent()) 1814*06c3fb27SDimitry Andric return nullptr; 1815*06c3fb27SDimitry Andric 1816*06c3fb27SDimitry Andric FinalInstructions.insert(Real); 1817*06c3fb27SDimitry Andric FinalInstructions.insert(Imag); 1818bdd1243dSDimitry Andric } 1819*06c3fb27SDimitry Andric NodePtr PlaceholderNode = 1820*06c3fb27SDimitry Andric prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I); 1821*06c3fb27SDimitry Andric return submitCompositeNode(PlaceholderNode); 1822bdd1243dSDimitry Andric } 1823bdd1243dSDimitry Andric 1824*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr 1825*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, 1826*06c3fb27SDimitry Andric Instruction *Imag) { 1827*06c3fb27SDimitry Andric if (Real != RealPHI || Imag != ImagPHI) 1828*06c3fb27SDimitry Andric return nullptr; 1829*06c3fb27SDimitry Andric 1830*06c3fb27SDimitry Andric PHIsFound = true; 1831*06c3fb27SDimitry Andric NodePtr PlaceholderNode = prepareCompositeNode( 1832*06c3fb27SDimitry Andric ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); 
1833*06c3fb27SDimitry Andric return submitCompositeNode(PlaceholderNode); 1834*06c3fb27SDimitry Andric } 1835*06c3fb27SDimitry Andric 1836*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr 1837*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, 1838*06c3fb27SDimitry Andric Instruction *Imag) { 1839*06c3fb27SDimitry Andric auto *SelectReal = dyn_cast<SelectInst>(Real); 1840*06c3fb27SDimitry Andric auto *SelectImag = dyn_cast<SelectInst>(Imag); 1841*06c3fb27SDimitry Andric if (!SelectReal || !SelectImag) 1842*06c3fb27SDimitry Andric return nullptr; 1843*06c3fb27SDimitry Andric 1844*06c3fb27SDimitry Andric Instruction *MaskA, *MaskB; 1845*06c3fb27SDimitry Andric Instruction *AR, *AI, *RA, *BI; 1846*06c3fb27SDimitry Andric if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR), 1847*06c3fb27SDimitry Andric m_Instruction(RA))) || 1848*06c3fb27SDimitry Andric !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI), 1849*06c3fb27SDimitry Andric m_Instruction(BI)))) 1850*06c3fb27SDimitry Andric return nullptr; 1851*06c3fb27SDimitry Andric 1852*06c3fb27SDimitry Andric if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB)) 1853*06c3fb27SDimitry Andric return nullptr; 1854*06c3fb27SDimitry Andric 1855*06c3fb27SDimitry Andric if (!MaskA->getType()->isVectorTy()) 1856*06c3fb27SDimitry Andric return nullptr; 1857*06c3fb27SDimitry Andric 1858*06c3fb27SDimitry Andric auto NodeA = identifyNode(AR, AI); 1859*06c3fb27SDimitry Andric if (!NodeA) 1860*06c3fb27SDimitry Andric return nullptr; 1861*06c3fb27SDimitry Andric 1862*06c3fb27SDimitry Andric auto NodeB = identifyNode(RA, BI); 1863*06c3fb27SDimitry Andric if (!NodeB) 1864*06c3fb27SDimitry Andric return nullptr; 1865*06c3fb27SDimitry Andric 1866*06c3fb27SDimitry Andric NodePtr PlaceholderNode = prepareCompositeNode( 1867*06c3fb27SDimitry Andric ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); 1868*06c3fb27SDimitry Andric PlaceholderNode->addOperand(NodeA); 
1869*06c3fb27SDimitry Andric PlaceholderNode->addOperand(NodeB); 1870*06c3fb27SDimitry Andric FinalInstructions.insert(MaskA); 1871*06c3fb27SDimitry Andric FinalInstructions.insert(MaskB); 1872*06c3fb27SDimitry Andric return submitCompositeNode(PlaceholderNode); 1873*06c3fb27SDimitry Andric } 1874*06c3fb27SDimitry Andric 1875*06c3fb27SDimitry Andric static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, 1876*06c3fb27SDimitry Andric std::optional<FastMathFlags> Flags, 1877*06c3fb27SDimitry Andric Value *InputA, Value *InputB) { 1878*06c3fb27SDimitry Andric Value *I; 1879*06c3fb27SDimitry Andric switch (Opcode) { 1880*06c3fb27SDimitry Andric case Instruction::FNeg: 1881*06c3fb27SDimitry Andric I = B.CreateFNeg(InputA); 1882*06c3fb27SDimitry Andric break; 1883*06c3fb27SDimitry Andric case Instruction::FAdd: 1884*06c3fb27SDimitry Andric I = B.CreateFAdd(InputA, InputB); 1885*06c3fb27SDimitry Andric break; 1886*06c3fb27SDimitry Andric case Instruction::Add: 1887*06c3fb27SDimitry Andric I = B.CreateAdd(InputA, InputB); 1888*06c3fb27SDimitry Andric break; 1889*06c3fb27SDimitry Andric case Instruction::FSub: 1890*06c3fb27SDimitry Andric I = B.CreateFSub(InputA, InputB); 1891*06c3fb27SDimitry Andric break; 1892*06c3fb27SDimitry Andric case Instruction::Sub: 1893*06c3fb27SDimitry Andric I = B.CreateSub(InputA, InputB); 1894*06c3fb27SDimitry Andric break; 1895*06c3fb27SDimitry Andric case Instruction::FMul: 1896*06c3fb27SDimitry Andric I = B.CreateFMul(InputA, InputB); 1897*06c3fb27SDimitry Andric break; 1898*06c3fb27SDimitry Andric case Instruction::Mul: 1899*06c3fb27SDimitry Andric I = B.CreateMul(InputA, InputB); 1900*06c3fb27SDimitry Andric break; 1901*06c3fb27SDimitry Andric default: 1902*06c3fb27SDimitry Andric llvm_unreachable("Incorrect symmetric opcode"); 1903*06c3fb27SDimitry Andric } 1904*06c3fb27SDimitry Andric if (Flags) 1905*06c3fb27SDimitry Andric cast<Instruction>(I)->setFastMathFlags(*Flags); 1906*06c3fb27SDimitry Andric return I; 
1907*06c3fb27SDimitry Andric } 1908*06c3fb27SDimitry Andric 1909*06c3fb27SDimitry Andric Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, 1910*06c3fb27SDimitry Andric RawNodePtr Node) { 1911bdd1243dSDimitry Andric if (Node->ReplacementNode) 1912bdd1243dSDimitry Andric return Node->ReplacementNode; 1913bdd1243dSDimitry Andric 1914*06c3fb27SDimitry Andric auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { 1915*06c3fb27SDimitry Andric return Node->Operands.size() > Idx 1916*06c3fb27SDimitry Andric ? replaceNode(Builder, Node->Operands[Idx]) 1917*06c3fb27SDimitry Andric : nullptr; 1918*06c3fb27SDimitry Andric }; 1919bdd1243dSDimitry Andric 1920*06c3fb27SDimitry Andric Value *ReplacementNode; 1921*06c3fb27SDimitry Andric switch (Node->Operation) { 1922*06c3fb27SDimitry Andric case ComplexDeinterleavingOperation::CAdd: 1923*06c3fb27SDimitry Andric case ComplexDeinterleavingOperation::CMulPartial: 1924*06c3fb27SDimitry Andric case ComplexDeinterleavingOperation::Symmetric: { 1925*06c3fb27SDimitry Andric Value *Input0 = ReplaceOperandIfExist(Node, 0); 1926*06c3fb27SDimitry Andric Value *Input1 = ReplaceOperandIfExist(Node, 1); 1927*06c3fb27SDimitry Andric Value *Accumulator = ReplaceOperandIfExist(Node, 2); 1928*06c3fb27SDimitry Andric assert(!Input1 || (Input0->getType() == Input1->getType() && 1929*06c3fb27SDimitry Andric "Node inputs need to be of the same type")); 1930*06c3fb27SDimitry Andric assert(!Accumulator || 1931*06c3fb27SDimitry Andric (Input0->getType() == Accumulator->getType() && 1932*06c3fb27SDimitry Andric "Accumulator and input need to be of the same type")); 1933*06c3fb27SDimitry Andric if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) 1934*06c3fb27SDimitry Andric ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, 1935*06c3fb27SDimitry Andric Input0, Input1); 1936*06c3fb27SDimitry Andric else 1937*06c3fb27SDimitry Andric ReplacementNode = 
TL->createComplexDeinterleavingIR( 1938*06c3fb27SDimitry Andric Builder, Node->Operation, Node->Rotation, Input0, Input1, 1939*06c3fb27SDimitry Andric Accumulator); 1940*06c3fb27SDimitry Andric break; 1941*06c3fb27SDimitry Andric } 1942*06c3fb27SDimitry Andric case ComplexDeinterleavingOperation::Deinterleave: 1943*06c3fb27SDimitry Andric llvm_unreachable("Deinterleave node should already have ReplacementNode"); 1944*06c3fb27SDimitry Andric break; 1945*06c3fb27SDimitry Andric case ComplexDeinterleavingOperation::Splat: { 1946*06c3fb27SDimitry Andric auto *NewTy = VectorType::getDoubleElementsVectorType( 1947*06c3fb27SDimitry Andric cast<VectorType>(Node->Real->getType())); 1948*06c3fb27SDimitry Andric auto *R = dyn_cast<Instruction>(Node->Real); 1949*06c3fb27SDimitry Andric auto *I = dyn_cast<Instruction>(Node->Imag); 1950*06c3fb27SDimitry Andric if (R && I) { 1951*06c3fb27SDimitry Andric // Splats that are not constant are interleaved where they are located 1952*06c3fb27SDimitry Andric Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode(); 1953*06c3fb27SDimitry Andric IRBuilder<> IRB(InsertPoint); 1954*06c3fb27SDimitry Andric ReplacementNode = 1955*06c3fb27SDimitry Andric IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy, 1956*06c3fb27SDimitry Andric {Node->Real, Node->Imag}); 1957*06c3fb27SDimitry Andric } else { 1958*06c3fb27SDimitry Andric ReplacementNode = 1959*06c3fb27SDimitry Andric Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, 1960*06c3fb27SDimitry Andric NewTy, {Node->Real, Node->Imag}); 1961*06c3fb27SDimitry Andric } 1962*06c3fb27SDimitry Andric break; 1963*06c3fb27SDimitry Andric } 1964*06c3fb27SDimitry Andric case ComplexDeinterleavingOperation::ReductionPHI: { 1965*06c3fb27SDimitry Andric // If Operation is ReductionPHI, a new empty PHINode is created. 1966*06c3fb27SDimitry Andric // It is filled later when the ReductionOperation is processed. 
1967*06c3fb27SDimitry Andric auto *VTy = cast<VectorType>(Node->Real->getType()); 1968*06c3fb27SDimitry Andric auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 1969*06c3fb27SDimitry Andric auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI()); 1970*06c3fb27SDimitry Andric OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI; 1971*06c3fb27SDimitry Andric ReplacementNode = NewPHI; 1972*06c3fb27SDimitry Andric break; 1973*06c3fb27SDimitry Andric } 1974*06c3fb27SDimitry Andric case ComplexDeinterleavingOperation::ReductionOperation: 1975*06c3fb27SDimitry Andric ReplacementNode = replaceNode(Builder, Node->Operands[0]); 1976*06c3fb27SDimitry Andric processReductionOperation(ReplacementNode, Node); 1977*06c3fb27SDimitry Andric break; 1978*06c3fb27SDimitry Andric case ComplexDeinterleavingOperation::ReductionSelect: { 1979*06c3fb27SDimitry Andric auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0); 1980*06c3fb27SDimitry Andric auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0); 1981*06c3fb27SDimitry Andric auto *A = replaceNode(Builder, Node->Operands[0]); 1982*06c3fb27SDimitry Andric auto *B = replaceNode(Builder, Node->Operands[1]); 1983*06c3fb27SDimitry Andric auto *NewMaskTy = VectorType::getDoubleElementsVectorType( 1984*06c3fb27SDimitry Andric cast<VectorType>(MaskReal->getType())); 1985*06c3fb27SDimitry Andric auto *NewMask = 1986*06c3fb27SDimitry Andric Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, 1987*06c3fb27SDimitry Andric NewMaskTy, {MaskReal, MaskImag}); 1988*06c3fb27SDimitry Andric ReplacementNode = Builder.CreateSelect(NewMask, A, B); 1989*06c3fb27SDimitry Andric break; 1990*06c3fb27SDimitry Andric } 1991*06c3fb27SDimitry Andric } 1992bdd1243dSDimitry Andric 1993*06c3fb27SDimitry Andric assert(ReplacementNode && "Target failed to create Intrinsic call."); 1994bdd1243dSDimitry Andric NumComplexTransformations += 1; 1995*06c3fb27SDimitry Andric Node->ReplacementNode = ReplacementNode; 
1996*06c3fb27SDimitry Andric return ReplacementNode; 1997*06c3fb27SDimitry Andric } 1998*06c3fb27SDimitry Andric 1999*06c3fb27SDimitry Andric void ComplexDeinterleavingGraph::processReductionOperation( 2000*06c3fb27SDimitry Andric Value *OperationReplacement, RawNodePtr Node) { 2001*06c3fb27SDimitry Andric auto *Real = cast<Instruction>(Node->Real); 2002*06c3fb27SDimitry Andric auto *Imag = cast<Instruction>(Node->Imag); 2003*06c3fb27SDimitry Andric auto *OldPHIReal = ReductionInfo[Real].first; 2004*06c3fb27SDimitry Andric auto *OldPHIImag = ReductionInfo[Imag].first; 2005*06c3fb27SDimitry Andric auto *NewPHI = OldToNewPHI[OldPHIReal]; 2006*06c3fb27SDimitry Andric 2007*06c3fb27SDimitry Andric auto *VTy = cast<VectorType>(Real->getType()); 2008*06c3fb27SDimitry Andric auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 2009*06c3fb27SDimitry Andric 2010*06c3fb27SDimitry Andric // We have to interleave initial origin values coming from IncomingBlock 2011*06c3fb27SDimitry Andric Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); 2012*06c3fb27SDimitry Andric Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); 2013*06c3fb27SDimitry Andric 2014*06c3fb27SDimitry Andric IRBuilder<> Builder(Incoming->getTerminator()); 2015*06c3fb27SDimitry Andric auto *NewInit = Builder.CreateIntrinsic( 2016*06c3fb27SDimitry Andric Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag}); 2017*06c3fb27SDimitry Andric 2018*06c3fb27SDimitry Andric NewPHI->addIncoming(NewInit, Incoming); 2019*06c3fb27SDimitry Andric NewPHI->addIncoming(OperationReplacement, BackEdge); 2020*06c3fb27SDimitry Andric 2021*06c3fb27SDimitry Andric // Deinterleave complex vector outside of loop so that it can be finally 2022*06c3fb27SDimitry Andric // reduced 2023*06c3fb27SDimitry Andric auto *FinalReductionReal = ReductionInfo[Real].second; 2024*06c3fb27SDimitry Andric auto *FinalReductionImag = ReductionInfo[Imag].second; 2025*06c3fb27SDimitry Andric 
2026*06c3fb27SDimitry Andric Builder.SetInsertPoint( 2027*06c3fb27SDimitry Andric &*FinalReductionReal->getParent()->getFirstInsertionPt()); 2028*06c3fb27SDimitry Andric auto *Deinterleave = Builder.CreateIntrinsic( 2029*06c3fb27SDimitry Andric Intrinsic::experimental_vector_deinterleave2, 2030*06c3fb27SDimitry Andric OperationReplacement->getType(), OperationReplacement); 2031*06c3fb27SDimitry Andric 2032*06c3fb27SDimitry Andric auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); 2033*06c3fb27SDimitry Andric FinalReductionReal->replaceUsesOfWith(Real, NewReal); 2034*06c3fb27SDimitry Andric 2035*06c3fb27SDimitry Andric Builder.SetInsertPoint(FinalReductionImag); 2036*06c3fb27SDimitry Andric auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); 2037*06c3fb27SDimitry Andric FinalReductionImag->replaceUsesOfWith(Imag, NewImag); 2038bdd1243dSDimitry Andric } 2039bdd1243dSDimitry Andric 2040bdd1243dSDimitry Andric void ComplexDeinterleavingGraph::replaceNodes() { 2041*06c3fb27SDimitry Andric SmallVector<Instruction *, 16> DeadInstrRoots; 2042*06c3fb27SDimitry Andric for (auto *RootInstruction : OrderedRoots) { 2043*06c3fb27SDimitry Andric // Check if this potential root went through check process and we can 2044*06c3fb27SDimitry Andric // deinterleave it 2045*06c3fb27SDimitry Andric if (!RootToNode.count(RootInstruction)) 2046*06c3fb27SDimitry Andric continue; 2047*06c3fb27SDimitry Andric 2048*06c3fb27SDimitry Andric IRBuilder<> Builder(RootInstruction); 2049*06c3fb27SDimitry Andric auto RootNode = RootToNode[RootInstruction]; 2050*06c3fb27SDimitry Andric Value *R = replaceNode(Builder, RootNode.get()); 2051*06c3fb27SDimitry Andric 2052*06c3fb27SDimitry Andric if (RootNode->Operation == 2053*06c3fb27SDimitry Andric ComplexDeinterleavingOperation::ReductionOperation) { 2054*06c3fb27SDimitry Andric auto *RootReal = cast<Instruction>(RootNode->Real); 2055*06c3fb27SDimitry Andric auto *RootImag = cast<Instruction>(RootNode->Imag); 
2056*06c3fb27SDimitry Andric ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); 2057*06c3fb27SDimitry Andric ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); 2058*06c3fb27SDimitry Andric DeadInstrRoots.push_back(cast<Instruction>(RootReal)); 2059*06c3fb27SDimitry Andric DeadInstrRoots.push_back(cast<Instruction>(RootImag)); 2060*06c3fb27SDimitry Andric } else { 2061*06c3fb27SDimitry Andric assert(R && "Unable to find replacement for RootInstruction"); 2062*06c3fb27SDimitry Andric DeadInstrRoots.push_back(RootInstruction); 2063*06c3fb27SDimitry Andric RootInstruction->replaceAllUsesWith(R); 2064*06c3fb27SDimitry Andric } 2065bdd1243dSDimitry Andric } 2066bdd1243dSDimitry Andric 2067*06c3fb27SDimitry Andric for (auto *I : DeadInstrRoots) 2068*06c3fb27SDimitry Andric RecursivelyDeleteTriviallyDeadInstructions(I, TLI); 2069bdd1243dSDimitry Andric } 2070