//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
// Instead of relying on shuffle operations, vector interleaving and
// deinterleaving can also be represented by the vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas fixed-width vectors are recognized in both forms:
// the shufflevector instruction and the intrinsics.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together, connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// maps each root Instruction (where a complex computation starts) to the node
// that should replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). Nodes are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
48d52e2839SNicholas Guy // Each node maintains a reference to its Real and Imaginary instructions, 49d52e2839SNicholas Guy // as well as any additional instructions that make up the identified operation 50d52e2839SNicholas Guy // (Internal instructions should only have uses within their containing node). 51d52e2839SNicholas Guy // A Node also contains the rotation and operation type that it represents. 52d52e2839SNicholas Guy // Operands contains pointers to other CompositeNodes, acting as the edges in 53d52e2839SNicholas Guy // the graph. ReplacementValue is the transformed Value* that has been emitted 54d52e2839SNicholas Guy // to the IR. 55d52e2839SNicholas Guy // 56d52e2839SNicholas Guy // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and 57d52e2839SNicholas Guy // ReplacementValue fields of that Node are relevant, where the ReplacementValue 58d52e2839SNicholas Guy // should be pre-populated. 59d52e2839SNicholas Guy // 60d52e2839SNicholas Guy //===----------------------------------------------------------------------===// 61d52e2839SNicholas Guy 62d52e2839SNicholas Guy #include "llvm/CodeGen/ComplexDeinterleavingPass.h" 634c9223c7SFlorian Hahn #include "llvm/ADT/MapVector.h" 64d52e2839SNicholas Guy #include "llvm/ADT/Statistic.h" 65d52e2839SNicholas Guy #include "llvm/Analysis/TargetLibraryInfo.h" 66d52e2839SNicholas Guy #include "llvm/Analysis/TargetTransformInfo.h" 67d52e2839SNicholas Guy #include "llvm/CodeGen/TargetLowering.h" 68d52e2839SNicholas Guy #include "llvm/CodeGen/TargetSubtargetInfo.h" 69d52e2839SNicholas Guy #include "llvm/IR/IRBuilder.h" 70ac73c48eSElliot Goodrich #include "llvm/IR/PatternMatch.h" 71d52e2839SNicholas Guy #include "llvm/InitializePasses.h" 72d52e2839SNicholas Guy #include "llvm/Target/TargetMachine.h" 73d52e2839SNicholas Guy #include "llvm/Transforms/Utils/Local.h" 74d52e2839SNicholas Guy #include <algorithm> 75d52e2839SNicholas Guy 76d52e2839SNicholas Guy using namespace llvm; 77d52e2839SNicholas Guy 
using namespace PatternMatch; 78d52e2839SNicholas Guy 79d52e2839SNicholas Guy #define DEBUG_TYPE "complex-deinterleaving" 80d52e2839SNicholas Guy 81d52e2839SNicholas Guy STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); 82d52e2839SNicholas Guy 83d52e2839SNicholas Guy static cl::opt<bool> ComplexDeinterleavingEnabled( 84d52e2839SNicholas Guy "enable-complex-deinterleaving", 85d52e2839SNicholas Guy cl::desc("Enable generation of complex instructions"), cl::init(true), 86d52e2839SNicholas Guy cl::Hidden); 87d52e2839SNicholas Guy 88d52e2839SNicholas Guy /// Checks the given mask, and determines whether said mask is interleaving. 89d52e2839SNicholas Guy /// 90d52e2839SNicholas Guy /// To be interleaving, a mask must alternate between `i` and `i + (Length / 91d52e2839SNicholas Guy /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a 92d52e2839SNicholas Guy /// 4x vector interleaving mask would be <0, 2, 1, 3>). 93d52e2839SNicholas Guy static bool isInterleavingMask(ArrayRef<int> Mask); 94d52e2839SNicholas Guy 95d52e2839SNicholas Guy /// Checks the given mask, and determines whether said mask is deinterleaving. 96d52e2839SNicholas Guy /// 97d52e2839SNicholas Guy /// To be deinterleaving, a mask must increment in steps of 2, and either start 98d52e2839SNicholas Guy /// with 0 or 1. 99d52e2839SNicholas Guy /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or 100d52e2839SNicholas Guy /// <1, 3, 5, 7>). 101d52e2839SNicholas Guy static bool isDeinterleavingMask(ArrayRef<int> Mask); 102d52e2839SNicholas Guy 103c15557d6SIgor Kirillov /// Returns true if the operation is a negation of V, and it works for both 104c15557d6SIgor Kirillov /// integers and floats. 105c15557d6SIgor Kirillov static bool isNeg(Value *V); 106c15557d6SIgor Kirillov 107c15557d6SIgor Kirillov /// Returns the operand for negation operation. 
108c15557d6SIgor Kirillov static Value *getNegOperand(Value *V); 109c15557d6SIgor Kirillov 110d52e2839SNicholas Guy namespace { 1118e1b49c3SNicholas Guy template <typename T, typename IterT> 1128e1b49c3SNicholas Guy std::optional<T> findCommonBetweenCollections(IterT A, IterT B) { 1138e1b49c3SNicholas Guy auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); }); 1148e1b49c3SNicholas Guy if (Common != A.end()) 1158e1b49c3SNicholas Guy return std::make_optional(*Common); 1168e1b49c3SNicholas Guy return std::nullopt; 1178e1b49c3SNicholas Guy } 118d52e2839SNicholas Guy 119d52e2839SNicholas Guy class ComplexDeinterleavingLegacyPass : public FunctionPass { 120d52e2839SNicholas Guy public: 121d52e2839SNicholas Guy static char ID; 122d52e2839SNicholas Guy 123d52e2839SNicholas Guy ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) 124d52e2839SNicholas Guy : FunctionPass(ID), TM(TM) { 125d52e2839SNicholas Guy initializeComplexDeinterleavingLegacyPassPass( 126d52e2839SNicholas Guy *PassRegistry::getPassRegistry()); 127d52e2839SNicholas Guy } 128d52e2839SNicholas Guy 129d52e2839SNicholas Guy StringRef getPassName() const override { 130d52e2839SNicholas Guy return "Complex Deinterleaving Pass"; 131d52e2839SNicholas Guy } 132d52e2839SNicholas Guy 133d52e2839SNicholas Guy bool runOnFunction(Function &F) override; 134d52e2839SNicholas Guy void getAnalysisUsage(AnalysisUsage &AU) const override { 135d52e2839SNicholas Guy AU.addRequired<TargetLibraryInfoWrapperPass>(); 136d52e2839SNicholas Guy AU.setPreservesCFG(); 137d52e2839SNicholas Guy } 138d52e2839SNicholas Guy 139d52e2839SNicholas Guy private: 140d52e2839SNicholas Guy const TargetMachine *TM; 141d52e2839SNicholas Guy }; 142d52e2839SNicholas Guy 143d52e2839SNicholas Guy class ComplexDeinterleavingGraph; 144d52e2839SNicholas Guy struct ComplexDeinterleavingCompositeNode { 145d52e2839SNicholas Guy 146d52e2839SNicholas Guy ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, 
147b4f9c3a9SIgor Kirillov Value *R, Value *I) 148d52e2839SNicholas Guy : Operation(Op), Real(R), Imag(I) {} 149d52e2839SNicholas Guy 150d52e2839SNicholas Guy private: 151d52e2839SNicholas Guy friend class ComplexDeinterleavingGraph; 152d52e2839SNicholas Guy using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; 153d52e2839SNicholas Guy using RawNodePtr = ComplexDeinterleavingCompositeNode *; 1548e1b49c3SNicholas Guy bool OperandsValid = true; 155d52e2839SNicholas Guy 156d52e2839SNicholas Guy public: 157d52e2839SNicholas Guy ComplexDeinterleavingOperation Operation; 158b4f9c3a9SIgor Kirillov Value *Real; 159b4f9c3a9SIgor Kirillov Value *Imag; 160d52e2839SNicholas Guy 1611a1e7610SIgor Kirillov // This two members are required exclusively for generating 1621a1e7610SIgor Kirillov // ComplexDeinterleavingOperation::Symmetric operations. 1631a1e7610SIgor Kirillov unsigned Opcode; 164c15557d6SIgor Kirillov std::optional<FastMathFlags> Flags; 1651a1e7610SIgor Kirillov 166aab0ca3eSAkshay Khadse ComplexDeinterleavingRotation Rotation = 167aab0ca3eSAkshay Khadse ComplexDeinterleavingRotation::Rotation_0; 168d52e2839SNicholas Guy SmallVector<RawNodePtr> Operands; 169d52e2839SNicholas Guy Value *ReplacementNode = nullptr; 170d52e2839SNicholas Guy 1718e1b49c3SNicholas Guy void addOperand(NodePtr Node) { 1728e1b49c3SNicholas Guy if (!Node || !Node.get()) 1738e1b49c3SNicholas Guy OperandsValid = false; 1748e1b49c3SNicholas Guy Operands.push_back(Node.get()); 1758e1b49c3SNicholas Guy } 176d52e2839SNicholas Guy 177d52e2839SNicholas Guy void dump() { dump(dbgs()); } 178d52e2839SNicholas Guy void dump(raw_ostream &OS) { 179d52e2839SNicholas Guy auto PrintValue = [&](Value *V) { 180d52e2839SNicholas Guy if (V) { 181d52e2839SNicholas Guy OS << "\""; 182d52e2839SNicholas Guy V->print(OS, true); 183d52e2839SNicholas Guy OS << "\"\n"; 184d52e2839SNicholas Guy } else 185d52e2839SNicholas Guy OS << "nullptr\n"; 186d52e2839SNicholas Guy }; 187d52e2839SNicholas Guy auto 
PrintNodeRef = [&](RawNodePtr Ptr) { 188d52e2839SNicholas Guy if (Ptr) 189d52e2839SNicholas Guy OS << Ptr << "\n"; 190d52e2839SNicholas Guy else 191d52e2839SNicholas Guy OS << "nullptr\n"; 192d52e2839SNicholas Guy }; 193d52e2839SNicholas Guy 194d52e2839SNicholas Guy OS << "- CompositeNode: " << this << "\n"; 195d52e2839SNicholas Guy OS << " Real: "; 196d52e2839SNicholas Guy PrintValue(Real); 197d52e2839SNicholas Guy OS << " Imag: "; 198d52e2839SNicholas Guy PrintValue(Imag); 199d52e2839SNicholas Guy OS << " ReplacementNode: "; 200d52e2839SNicholas Guy PrintValue(ReplacementNode); 201d52e2839SNicholas Guy OS << " Operation: " << (int)Operation << "\n"; 202d52e2839SNicholas Guy OS << " Rotation: " << ((int)Rotation * 90) << "\n"; 203d52e2839SNicholas Guy OS << " Operands: \n"; 204d52e2839SNicholas Guy for (const auto &Op : Operands) { 205d52e2839SNicholas Guy OS << " - "; 206d52e2839SNicholas Guy PrintNodeRef(Op); 207d52e2839SNicholas Guy } 208d52e2839SNicholas Guy } 2098e1b49c3SNicholas Guy 2108e1b49c3SNicholas Guy bool areOperandsValid() { return OperandsValid; } 211d52e2839SNicholas Guy }; 212d52e2839SNicholas Guy 213d52e2839SNicholas Guy class ComplexDeinterleavingGraph { 214d52e2839SNicholas Guy public: 2151a1e7610SIgor Kirillov struct Product { 216b4f9c3a9SIgor Kirillov Value *Multiplier; 217b4f9c3a9SIgor Kirillov Value *Multiplicand; 2181a1e7610SIgor Kirillov bool IsPositive; 2191a1e7610SIgor Kirillov }; 2201a1e7610SIgor Kirillov 221b4f9c3a9SIgor Kirillov using Addend = std::pair<Value *, bool>; 222d52e2839SNicholas Guy using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; 223d52e2839SNicholas Guy using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; 2241a1e7610SIgor Kirillov 2251a1e7610SIgor Kirillov // Helper struct for holding info about potential partial multiplication 2261a1e7610SIgor Kirillov // candidates 2271a1e7610SIgor Kirillov struct PartialMulCandidate { 228b4f9c3a9SIgor Kirillov Value *Common; 2291a1e7610SIgor Kirillov NodePtr 
Node; 2301a1e7610SIgor Kirillov unsigned RealIdx; 2311a1e7610SIgor Kirillov unsigned ImagIdx; 2321a1e7610SIgor Kirillov bool IsNodeInverted; 2331a1e7610SIgor Kirillov }; 2341a1e7610SIgor Kirillov 235c692e87aSIgor Kirillov explicit ComplexDeinterleavingGraph(const TargetLowering *TL, 236c692e87aSIgor Kirillov const TargetLibraryInfo *TLI) 237c692e87aSIgor Kirillov : TL(TL), TLI(TLI) {} 238d52e2839SNicholas Guy 239d52e2839SNicholas Guy private: 2408bf7f86dSAkshay Khadse const TargetLowering *TL = nullptr; 241c692e87aSIgor Kirillov const TargetLibraryInfo *TLI = nullptr; 242d52e2839SNicholas Guy SmallVector<NodePtr> CompositeNodes; 24346b2ad02SIgor Kirillov DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult; 244c692e87aSIgor Kirillov 245c692e87aSIgor Kirillov SmallPtrSet<Instruction *, 16> FinalInstructions; 246c692e87aSIgor Kirillov 247c692e87aSIgor Kirillov /// Root instructions are instructions from which complex computation starts 248c692e87aSIgor Kirillov std::map<Instruction *, NodePtr> RootToNode; 249c692e87aSIgor Kirillov 250c692e87aSIgor Kirillov /// Topologically sorted root instructions 251c692e87aSIgor Kirillov SmallVector<Instruction *, 1> OrderedRoots; 252d52e2839SNicholas Guy 2532cbc265cSIgor Kirillov /// When examining a basic block for complex deinterleaving, if it is a simple 2542cbc265cSIgor Kirillov /// one-block loop, then the only incoming block is 'Incoming' and the 2552cbc265cSIgor Kirillov /// 'BackEdge' block is the block itself." 
2562cbc265cSIgor Kirillov BasicBlock *BackEdge = nullptr; 2572cbc265cSIgor Kirillov BasicBlock *Incoming = nullptr; 2582cbc265cSIgor Kirillov 2592cbc265cSIgor Kirillov /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction 2602cbc265cSIgor Kirillov /// %OutsideUser as it is shown in the IR: 2612cbc265cSIgor Kirillov /// 2622cbc265cSIgor Kirillov /// vector.body: 2632cbc265cSIgor Kirillov /// %PHInode = phi <vector type> [ zeroinitializer, %entry ], 2642cbc265cSIgor Kirillov /// [ %ReductionOp, %vector.body ] 2652cbc265cSIgor Kirillov /// ... 2662cbc265cSIgor Kirillov /// %ReductionOp = fadd i64 ... 2672cbc265cSIgor Kirillov /// ... 2682cbc265cSIgor Kirillov /// br i1 %condition, label %vector.body, %middle.block 2692cbc265cSIgor Kirillov /// 2702cbc265cSIgor Kirillov /// middle.block: 2712cbc265cSIgor Kirillov /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) 2722cbc265cSIgor Kirillov /// 2732cbc265cSIgor Kirillov /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding 2742cbc265cSIgor Kirillov /// `llvm.vector.reduce.fadd` when unroll factor isn't one. 2754c9223c7SFlorian Hahn MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo; 2762cbc265cSIgor Kirillov 2772cbc265cSIgor Kirillov /// In the process of detecting a reduction, we consider a pair of 2782cbc265cSIgor Kirillov /// %ReductionOP, which we refer to as real and imag (or vice versa), and 2792cbc265cSIgor Kirillov /// traverse the use-tree to detect complex operations. As this is a reduction 2802cbc265cSIgor Kirillov /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds 2812cbc265cSIgor Kirillov /// to the %ReductionOPs that we suspect to be complex. 2822cbc265cSIgor Kirillov /// RealPHI and ImagPHI are used by the identifyPHINode method. 
2832cbc265cSIgor Kirillov PHINode *RealPHI = nullptr; 2842cbc265cSIgor Kirillov PHINode *ImagPHI = nullptr; 2852cbc265cSIgor Kirillov 2860aecf7ffSIgor Kirillov /// Set this flag to true if RealPHI and ImagPHI were reached during reduction 2870aecf7ffSIgor Kirillov /// detection. 2880aecf7ffSIgor Kirillov bool PHIsFound = false; 2890aecf7ffSIgor Kirillov 2902cbc265cSIgor Kirillov /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. 2912cbc265cSIgor Kirillov /// The new PHINode corresponds to a vector of deinterleaved complex numbers. 2922cbc265cSIgor Kirillov /// This mapping is populated during 2932cbc265cSIgor Kirillov /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then 2942cbc265cSIgor Kirillov /// used in the ComplexDeinterleavingOperation::ReductionOperation node 2952cbc265cSIgor Kirillov /// replacement process. 2962cbc265cSIgor Kirillov std::map<PHINode *, PHINode *> OldToNewPHI; 2972cbc265cSIgor Kirillov 298d52e2839SNicholas Guy NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, 299b4f9c3a9SIgor Kirillov Value *R, Value *I) { 3002cbc265cSIgor Kirillov assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && 3012cbc265cSIgor Kirillov Operation != ComplexDeinterleavingOperation::ReductionOperation) || 3022cbc265cSIgor Kirillov (R && I)) && 3032cbc265cSIgor Kirillov "Reduction related nodes must have Real and Imaginary parts"); 304d52e2839SNicholas Guy return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, 305d52e2839SNicholas Guy I); 306d52e2839SNicholas Guy } 307d52e2839SNicholas Guy 308d52e2839SNicholas Guy NodePtr submitCompositeNode(NodePtr Node) { 309d52e2839SNicholas Guy CompositeNodes.push_back(Node); 3108e1b49c3SNicholas Guy if (Node->Real) 31146b2ad02SIgor Kirillov CachedResult[{Node->Real, Node->Imag}] = Node; 312d52e2839SNicholas Guy return Node; 313d52e2839SNicholas Guy } 314d52e2839SNicholas Guy 315d52e2839SNicholas Guy /// Identifies a complex 
partial multiply pattern and its rotation, based on 316d52e2839SNicholas Guy /// the following patterns 317d52e2839SNicholas Guy /// 318d52e2839SNicholas Guy /// 0: r: cr + ar * br 319d52e2839SNicholas Guy /// i: ci + ar * bi 320d52e2839SNicholas Guy /// 90: r: cr - ai * bi 321d52e2839SNicholas Guy /// i: ci + ai * br 322d52e2839SNicholas Guy /// 180: r: cr - ar * br 323d52e2839SNicholas Guy /// i: ci - ar * bi 324d52e2839SNicholas Guy /// 270: r: cr + ai * bi 325d52e2839SNicholas Guy /// i: ci - ai * br 326d52e2839SNicholas Guy NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); 327d52e2839SNicholas Guy 328d52e2839SNicholas Guy /// Identify the other branch of a Partial Mul, taking the CommonOperandI that 329d52e2839SNicholas Guy /// is partially known from identifyPartialMul, filling in the other half of 330d52e2839SNicholas Guy /// the complex pair. 331b4f9c3a9SIgor Kirillov NodePtr 332b4f9c3a9SIgor Kirillov identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, 333b4f9c3a9SIgor Kirillov std::pair<Value *, Value *> &CommonOperandI); 334d52e2839SNicholas Guy 335d52e2839SNicholas Guy /// Identifies a complex add pattern and its rotation, based on the following 336d52e2839SNicholas Guy /// patterns. 
337d52e2839SNicholas Guy /// 338d52e2839SNicholas Guy /// 90: r: ar - bi 339d52e2839SNicholas Guy /// i: ai + br 340d52e2839SNicholas Guy /// 270: r: ar + bi 341d52e2839SNicholas Guy /// i: ai - br 342d52e2839SNicholas Guy NodePtr identifyAdd(Instruction *Real, Instruction *Imag); 34396615c85SNicholas Guy NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); 3448e1b49c3SNicholas Guy NodePtr identifyPartialReduction(Value *R, Value *I); 3458e1b49c3SNicholas Guy NodePtr identifyDotProduct(Value *Inst); 346d52e2839SNicholas Guy 347b4f9c3a9SIgor Kirillov NodePtr identifyNode(Value *R, Value *I); 348d52e2839SNicholas Guy 3491a1e7610SIgor Kirillov /// Determine if a sum of complex numbers can be formed from \p RealAddends 3501a1e7610SIgor Kirillov /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. 3511a1e7610SIgor Kirillov /// Return nullptr if it is not possible to construct a complex number. 3521a1e7610SIgor Kirillov /// \p Flags are needed to generate symmetric Add and Sub operations. 3531a1e7610SIgor Kirillov NodePtr identifyAdditions(std::list<Addend> &RealAddends, 354c15557d6SIgor Kirillov std::list<Addend> &ImagAddends, 355c15557d6SIgor Kirillov std::optional<FastMathFlags> Flags, 3561a1e7610SIgor Kirillov NodePtr Accumulator); 3571a1e7610SIgor Kirillov 3581a1e7610SIgor Kirillov /// Extract one addend that have both real and imaginary parts positive. 3591a1e7610SIgor Kirillov NodePtr extractPositiveAddend(std::list<Addend> &RealAddends, 3601a1e7610SIgor Kirillov std::list<Addend> &ImagAddends); 3611a1e7610SIgor Kirillov 3621a1e7610SIgor Kirillov /// Determine if sum of multiplications of complex numbers can be formed from 3631a1e7610SIgor Kirillov /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result 3641a1e7610SIgor Kirillov /// to it. Return nullptr if it is not possible to construct a complex number. 
3651a1e7610SIgor Kirillov NodePtr identifyMultiplications(std::vector<Product> &RealMuls, 3661a1e7610SIgor Kirillov std::vector<Product> &ImagMuls, 3671a1e7610SIgor Kirillov NodePtr Accumulator); 3681a1e7610SIgor Kirillov 3691a1e7610SIgor Kirillov /// Go through pairs of multiplication (one Real and one Imag) and find all 3701a1e7610SIgor Kirillov /// possible candidates for partial multiplication and put them into \p 3711a1e7610SIgor Kirillov /// Candidates. Returns true if all Product has pair with common operand 3721a1e7610SIgor Kirillov bool collectPartialMuls(const std::vector<Product> &RealMuls, 3731a1e7610SIgor Kirillov const std::vector<Product> &ImagMuls, 3741a1e7610SIgor Kirillov std::vector<PartialMulCandidate> &Candidates); 3751a1e7610SIgor Kirillov 3761a1e7610SIgor Kirillov /// If the code is compiled with -Ofast or expressions have `reassoc` flag, 3771a1e7610SIgor Kirillov /// the order of complex computation operations may be significantly altered, 3781a1e7610SIgor Kirillov /// and the real and imaginary parts may not be executed in parallel. This 3791a1e7610SIgor Kirillov /// function takes this into consideration and employs a more general approach 3801a1e7610SIgor Kirillov /// to identify complex computations. Initially, it gathers all the addends 3811a1e7610SIgor Kirillov /// and multiplicands and then constructs a complex expression from them. 3821a1e7610SIgor Kirillov NodePtr identifyReassocNodes(Instruction *I, Instruction *J); 3831a1e7610SIgor Kirillov 3846850bc35SIgor Kirillov NodePtr identifyRoot(Instruction *I); 3856850bc35SIgor Kirillov 3866850bc35SIgor Kirillov /// Identifies the Deinterleave operation applied to a vector containing 3876850bc35SIgor Kirillov /// complex numbers. 
There are two ways to represent the Deinterleave 3886850bc35SIgor Kirillov /// operation: 3896850bc35SIgor Kirillov /// * Using two shufflevectors with even indices for /pReal instruction and 3906850bc35SIgor Kirillov /// odd indices for /pImag instructions (only for fixed-width vectors) 3916850bc35SIgor Kirillov /// * Using two extractvalue instructions applied to `vector.deinterleave2` 3926850bc35SIgor Kirillov /// intrinsic (for both fixed and scalable vectors) 3936850bc35SIgor Kirillov NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); 3946850bc35SIgor Kirillov 3957f20407cSIgor Kirillov /// identifying the operation that represents a complex number repeated in a 3967f20407cSIgor Kirillov /// Splat vector. There are two possible types of splats: ConstantExpr with 3977f20407cSIgor Kirillov /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an 3987f20407cSIgor Kirillov /// initialization mask with all values set to zero. 3997f20407cSIgor Kirillov NodePtr identifySplat(Value *Real, Value *Imag); 4007f20407cSIgor Kirillov 4012cbc265cSIgor Kirillov NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); 4022cbc265cSIgor Kirillov 40304a8070bSIgor Kirillov /// Identifies SelectInsts in a loop that has reduction with predication masks 40404a8070bSIgor Kirillov /// and/or predicated tail folding 40504a8070bSIgor Kirillov NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); 40604a8070bSIgor Kirillov 40740a81d31SIgor Kirillov Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); 408d52e2839SNicholas Guy 4092cbc265cSIgor Kirillov /// Complete IR modifications after producing new reduction operation: 4102cbc265cSIgor Kirillov /// * Populate the PHINode generated for 4112cbc265cSIgor Kirillov /// ComplexDeinterleavingOperation::ReductionPHI 4122cbc265cSIgor Kirillov /// * Deinterleave the final value outside of the loop and repurpose original 4132cbc265cSIgor Kirillov /// reduction users 4142cbc265cSIgor Kirillov 
void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); 4158e1b49c3SNicholas Guy void processReductionSingle(Value *OperationReplacement, RawNodePtr Node); 4162cbc265cSIgor Kirillov 417d52e2839SNicholas Guy public: 418d52e2839SNicholas Guy void dump() { dump(dbgs()); } 419d52e2839SNicholas Guy void dump(raw_ostream &OS) { 420d52e2839SNicholas Guy for (const auto &Node : CompositeNodes) 421d52e2839SNicholas Guy Node->dump(OS); 422d52e2839SNicholas Guy } 423d52e2839SNicholas Guy 424d52e2839SNicholas Guy /// Returns false if the deinterleaving operation should be cancelled for the 425d52e2839SNicholas Guy /// current graph. 426d52e2839SNicholas Guy bool identifyNodes(Instruction *RootI); 427d52e2839SNicholas Guy 4282cbc265cSIgor Kirillov /// In case \pB is one-block loop, this function seeks potential reductions 4292cbc265cSIgor Kirillov /// and populates ReductionInfo. Returns true if any reductions were 4302cbc265cSIgor Kirillov /// identified. 4312cbc265cSIgor Kirillov bool collectPotentialReductions(BasicBlock *B); 4322cbc265cSIgor Kirillov 4332cbc265cSIgor Kirillov void identifyReductionNodes(); 4342cbc265cSIgor Kirillov 435c692e87aSIgor Kirillov /// Check that every instruction, from the roots to the leaves, has internal 436c692e87aSIgor Kirillov /// uses. 437c692e87aSIgor Kirillov bool checkNodes(); 438c692e87aSIgor Kirillov 439d52e2839SNicholas Guy /// Perform the actual replacement of the underlying instruction graph. 
440d52e2839SNicholas Guy void replaceNodes(); 441d52e2839SNicholas Guy }; 442d52e2839SNicholas Guy 443d52e2839SNicholas Guy class ComplexDeinterleaving { 444d52e2839SNicholas Guy public: 445d52e2839SNicholas Guy ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 446d52e2839SNicholas Guy : TL(tl), TLI(tli) {} 447d52e2839SNicholas Guy bool runOnFunction(Function &F); 448d52e2839SNicholas Guy 449d52e2839SNicholas Guy private: 450d52e2839SNicholas Guy bool evaluateBasicBlock(BasicBlock *B); 451d52e2839SNicholas Guy 452d52e2839SNicholas Guy const TargetLowering *TL = nullptr; 453d52e2839SNicholas Guy const TargetLibraryInfo *TLI = nullptr; 454d52e2839SNicholas Guy }; 455d52e2839SNicholas Guy 456d52e2839SNicholas Guy } // namespace 457d52e2839SNicholas Guy 458d52e2839SNicholas Guy char ComplexDeinterleavingLegacyPass::ID = 0; 459d52e2839SNicholas Guy 460d52e2839SNicholas Guy INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 461d52e2839SNicholas Guy "Complex Deinterleaving", false, false) 462d52e2839SNicholas Guy INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 463d52e2839SNicholas Guy "Complex Deinterleaving", false, false) 464d52e2839SNicholas Guy 465d52e2839SNicholas Guy PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 466d52e2839SNicholas Guy FunctionAnalysisManager &AM) { 467d52e2839SNicholas Guy const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 468d52e2839SNicholas Guy auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 469d52e2839SNicholas Guy if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 470d52e2839SNicholas Guy return PreservedAnalyses::all(); 471d52e2839SNicholas Guy 472d52e2839SNicholas Guy PreservedAnalyses PA; 473d52e2839SNicholas Guy PA.preserve<FunctionAnalysisManagerModuleProxy>(); 474d52e2839SNicholas Guy return PA; 475d52e2839SNicholas Guy } 476d52e2839SNicholas Guy 477d52e2839SNicholas Guy FunctionPass 
*llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 478d52e2839SNicholas Guy return new ComplexDeinterleavingLegacyPass(TM); 479d52e2839SNicholas Guy } 480d52e2839SNicholas Guy 481d52e2839SNicholas Guy bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 482d52e2839SNicholas Guy const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 483d52e2839SNicholas Guy auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 484d52e2839SNicholas Guy return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 485d52e2839SNicholas Guy } 486d52e2839SNicholas Guy 487d52e2839SNicholas Guy bool ComplexDeinterleaving::runOnFunction(Function &F) { 488d52e2839SNicholas Guy if (!ComplexDeinterleavingEnabled) { 489d52e2839SNicholas Guy LLVM_DEBUG( 490d52e2839SNicholas Guy dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); 491d52e2839SNicholas Guy return false; 492d52e2839SNicholas Guy } 493d52e2839SNicholas Guy 494d52e2839SNicholas Guy if (!TL->isComplexDeinterleavingSupported()) { 495d52e2839SNicholas Guy LLVM_DEBUG( 496d52e2839SNicholas Guy dbgs() << "Complex deinterleaving has been disabled, target does " 497d52e2839SNicholas Guy "not support lowering of complex number operations.\n"); 498d52e2839SNicholas Guy return false; 499d52e2839SNicholas Guy } 500d52e2839SNicholas Guy 501d52e2839SNicholas Guy bool Changed = false; 502d52e2839SNicholas Guy for (auto &B : F) 503d52e2839SNicholas Guy Changed |= evaluateBasicBlock(&B); 504d52e2839SNicholas Guy 505d52e2839SNicholas Guy return Changed; 506d52e2839SNicholas Guy } 507d52e2839SNicholas Guy 508d52e2839SNicholas Guy static bool isInterleavingMask(ArrayRef<int> Mask) { 509d52e2839SNicholas Guy // If the size is not even, it's not an interleaving mask 510d52e2839SNicholas Guy if ((Mask.size() & 1)) 511d52e2839SNicholas Guy return false; 512d52e2839SNicholas Guy 513d52e2839SNicholas Guy int HalfNumElements = Mask.size() / 2; 514d52e2839SNicholas Guy for (int Idx = 0; Idx < 
HalfNumElements; ++Idx) { 515d52e2839SNicholas Guy int MaskIdx = Idx * 2; 516d52e2839SNicholas Guy if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 517d52e2839SNicholas Guy return false; 518d52e2839SNicholas Guy } 519d52e2839SNicholas Guy 520d52e2839SNicholas Guy return true; 521d52e2839SNicholas Guy } 522d52e2839SNicholas Guy 523d52e2839SNicholas Guy static bool isDeinterleavingMask(ArrayRef<int> Mask) { 524d52e2839SNicholas Guy int Offset = Mask[0]; 525d52e2839SNicholas Guy int HalfNumElements = Mask.size() / 2; 526d52e2839SNicholas Guy 527d52e2839SNicholas Guy for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 528d52e2839SNicholas Guy if (Mask[Idx] != (Idx * 2) + Offset) 529d52e2839SNicholas Guy return false; 530d52e2839SNicholas Guy } 531d52e2839SNicholas Guy 532d52e2839SNicholas Guy return true; 533d52e2839SNicholas Guy } 534d52e2839SNicholas Guy 535c15557d6SIgor Kirillov bool isNeg(Value *V) { 536c15557d6SIgor Kirillov return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value())); 537c15557d6SIgor Kirillov } 538c15557d6SIgor Kirillov 539c15557d6SIgor Kirillov Value *getNegOperand(Value *V) { 540c15557d6SIgor Kirillov assert(isNeg(V)); 541c15557d6SIgor Kirillov auto *I = cast<Instruction>(V); 542c15557d6SIgor Kirillov if (I->getOpcode() == Instruction::FNeg) 543c15557d6SIgor Kirillov return I->getOperand(0); 544c15557d6SIgor Kirillov 545c15557d6SIgor Kirillov return I->getOperand(1); 546c15557d6SIgor Kirillov } 547c15557d6SIgor Kirillov 548d52e2839SNicholas Guy bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 549c692e87aSIgor Kirillov ComplexDeinterleavingGraph Graph(TL, TLI); 5502cbc265cSIgor Kirillov if (Graph.collectPotentialReductions(B)) 5512cbc265cSIgor Kirillov Graph.identifyReductionNodes(); 5522cbc265cSIgor Kirillov 5536850bc35SIgor Kirillov for (auto &I : *B) 5546850bc35SIgor Kirillov Graph.identifyNodes(&I); 555d52e2839SNicholas Guy 556c692e87aSIgor Kirillov if (Graph.checkNodes()) { 557d52e2839SNicholas Guy 
Graph.replaceNodes(); 558c692e87aSIgor Kirillov return true; 559d52e2839SNicholas Guy } 560d52e2839SNicholas Guy 561c692e87aSIgor Kirillov return false; 562d52e2839SNicholas Guy } 563d52e2839SNicholas Guy 564d52e2839SNicholas Guy ComplexDeinterleavingGraph::NodePtr 565d52e2839SNicholas Guy ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 566d52e2839SNicholas Guy Instruction *Real, Instruction *Imag, 567b4f9c3a9SIgor Kirillov std::pair<Value *, Value *> &PartialMatch) { 568d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 569d52e2839SNicholas Guy << "\n"); 570d52e2839SNicholas Guy 571d52e2839SNicholas Guy if (!Real->hasOneUse() || !Imag->hasOneUse()) { 572d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 573d52e2839SNicholas Guy return nullptr; 574d52e2839SNicholas Guy } 575d52e2839SNicholas Guy 576c15557d6SIgor Kirillov if ((Real->getOpcode() != Instruction::FMul && 577c15557d6SIgor Kirillov Real->getOpcode() != Instruction::Mul) || 578c15557d6SIgor Kirillov (Imag->getOpcode() != Instruction::FMul && 579c15557d6SIgor Kirillov Imag->getOpcode() != Instruction::Mul)) { 580c15557d6SIgor Kirillov LLVM_DEBUG( 581c15557d6SIgor Kirillov dbgs() << " - Real or imaginary instruction is not fmul or mul\n"); 582d52e2839SNicholas Guy return nullptr; 583d52e2839SNicholas Guy } 584d52e2839SNicholas Guy 585b4f9c3a9SIgor Kirillov Value *R0 = Real->getOperand(0); 586b4f9c3a9SIgor Kirillov Value *R1 = Real->getOperand(1); 587b4f9c3a9SIgor Kirillov Value *I0 = Imag->getOperand(0); 588b4f9c3a9SIgor Kirillov Value *I1 = Imag->getOperand(1); 589d52e2839SNicholas Guy 590d52e2839SNicholas Guy // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 591d52e2839SNicholas Guy // rotations and use the operand. 
592d52e2839SNicholas Guy unsigned Negs = 0; 593b4f9c3a9SIgor Kirillov Value *Op; 594b4f9c3a9SIgor Kirillov if (match(R0, m_Neg(m_Value(Op)))) { 595d52e2839SNicholas Guy Negs |= 1; 596b4f9c3a9SIgor Kirillov R0 = Op; 597b4f9c3a9SIgor Kirillov } else if (match(R1, m_Neg(m_Value(Op)))) { 598b4f9c3a9SIgor Kirillov Negs |= 1; 599b4f9c3a9SIgor Kirillov R1 = Op; 600d52e2839SNicholas Guy } 601b4f9c3a9SIgor Kirillov 602c15557d6SIgor Kirillov if (isNeg(I0)) { 603d52e2839SNicholas Guy Negs |= 2; 604d52e2839SNicholas Guy Negs ^= 1; 605b4f9c3a9SIgor Kirillov I0 = Op; 606b4f9c3a9SIgor Kirillov } else if (match(I1, m_Neg(m_Value(Op)))) { 607b4f9c3a9SIgor Kirillov Negs |= 2; 608b4f9c3a9SIgor Kirillov Negs ^= 1; 609b4f9c3a9SIgor Kirillov I1 = Op; 610d52e2839SNicholas Guy } 611d52e2839SNicholas Guy 612d52e2839SNicholas Guy ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 613d52e2839SNicholas Guy 614b4f9c3a9SIgor Kirillov Value *CommonOperand; 615b4f9c3a9SIgor Kirillov Value *UncommonRealOp; 616b4f9c3a9SIgor Kirillov Value *UncommonImagOp; 617d52e2839SNicholas Guy 618d52e2839SNicholas Guy if (R0 == I0 || R0 == I1) { 619d52e2839SNicholas Guy CommonOperand = R0; 620d52e2839SNicholas Guy UncommonRealOp = R1; 621d52e2839SNicholas Guy } else if (R1 == I0 || R1 == I1) { 622d52e2839SNicholas Guy CommonOperand = R1; 623d52e2839SNicholas Guy UncommonRealOp = R0; 624d52e2839SNicholas Guy } else { 625d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - No equal operand\n"); 626d52e2839SNicholas Guy return nullptr; 627d52e2839SNicholas Guy } 628d52e2839SNicholas Guy 629d52e2839SNicholas Guy UncommonImagOp = (CommonOperand == I0) ? 
I1 : I0; 630d52e2839SNicholas Guy if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 631d52e2839SNicholas Guy Rotation == ComplexDeinterleavingRotation::Rotation_270) 632d52e2839SNicholas Guy std::swap(UncommonRealOp, UncommonImagOp); 633d52e2839SNicholas Guy 634d52e2839SNicholas Guy // Between identifyPartialMul and here we need to have found a complete valid 635d52e2839SNicholas Guy // pair from the CommonOperand of each part. 636d52e2839SNicholas Guy if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 637d52e2839SNicholas Guy Rotation == ComplexDeinterleavingRotation::Rotation_180) 638d52e2839SNicholas Guy PartialMatch.first = CommonOperand; 639d52e2839SNicholas Guy else 640d52e2839SNicholas Guy PartialMatch.second = CommonOperand; 641d52e2839SNicholas Guy 642d52e2839SNicholas Guy if (!PartialMatch.first || !PartialMatch.second) { 643d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); 644d52e2839SNicholas Guy return nullptr; 645d52e2839SNicholas Guy } 646d52e2839SNicholas Guy 647d52e2839SNicholas Guy NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); 648d52e2839SNicholas Guy if (!CommonNode) { 649d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); 650d52e2839SNicholas Guy return nullptr; 651d52e2839SNicholas Guy } 652d52e2839SNicholas Guy 653d52e2839SNicholas Guy NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); 654d52e2839SNicholas Guy if (!UncommonNode) { 655d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); 656d52e2839SNicholas Guy return nullptr; 657d52e2839SNicholas Guy } 658d52e2839SNicholas Guy 659d52e2839SNicholas Guy NodePtr Node = prepareCompositeNode( 660d52e2839SNicholas Guy ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 661d52e2839SNicholas Guy Node->Rotation = Rotation; 662d52e2839SNicholas Guy Node->addOperand(CommonNode); 663d52e2839SNicholas Guy Node->addOperand(UncommonNode); 
664d52e2839SNicholas Guy return submitCompositeNode(Node); 665d52e2839SNicholas Guy } 666d52e2839SNicholas Guy 667d52e2839SNicholas Guy ComplexDeinterleavingGraph::NodePtr 668d52e2839SNicholas Guy ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, 669d52e2839SNicholas Guy Instruction *Imag) { 670d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag 671d52e2839SNicholas Guy << "\n"); 672d52e2839SNicholas Guy // Determine rotation 673c15557d6SIgor Kirillov auto IsAdd = [](unsigned Op) { 674c15557d6SIgor Kirillov return Op == Instruction::FAdd || Op == Instruction::Add; 675c15557d6SIgor Kirillov }; 676c15557d6SIgor Kirillov auto IsSub = [](unsigned Op) { 677c15557d6SIgor Kirillov return Op == Instruction::FSub || Op == Instruction::Sub; 678c15557d6SIgor Kirillov }; 679d52e2839SNicholas Guy ComplexDeinterleavingRotation Rotation; 680c15557d6SIgor Kirillov if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 681d52e2839SNicholas Guy Rotation = ComplexDeinterleavingRotation::Rotation_0; 682c15557d6SIgor Kirillov else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 683d52e2839SNicholas Guy Rotation = ComplexDeinterleavingRotation::Rotation_90; 684c15557d6SIgor Kirillov else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode())) 685d52e2839SNicholas Guy Rotation = ComplexDeinterleavingRotation::Rotation_180; 686c15557d6SIgor Kirillov else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode())) 687d52e2839SNicholas Guy Rotation = ComplexDeinterleavingRotation::Rotation_270; 688d52e2839SNicholas Guy else { 689d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); 690d52e2839SNicholas Guy return nullptr; 691d52e2839SNicholas Guy } 692d52e2839SNicholas Guy 693c15557d6SIgor Kirillov if (isa<FPMathOperator>(Real) && 694c15557d6SIgor Kirillov (!Real->getFastMathFlags().allowContract() || 695c15557d6SIgor Kirillov !Imag->getFastMathFlags().allowContract())) { 696d52e2839SNicholas Guy 
LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); 697d52e2839SNicholas Guy return nullptr; 698d52e2839SNicholas Guy } 699d52e2839SNicholas Guy 700d52e2839SNicholas Guy Value *CR = Real->getOperand(0); 701d52e2839SNicholas Guy Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); 702d52e2839SNicholas Guy if (!RealMulI) 703d52e2839SNicholas Guy return nullptr; 704d52e2839SNicholas Guy Value *CI = Imag->getOperand(0); 705d52e2839SNicholas Guy Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); 706d52e2839SNicholas Guy if (!ImagMulI) 707d52e2839SNicholas Guy return nullptr; 708d52e2839SNicholas Guy 709d52e2839SNicholas Guy if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { 710d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); 711d52e2839SNicholas Guy return nullptr; 712d52e2839SNicholas Guy } 713d52e2839SNicholas Guy 714b4f9c3a9SIgor Kirillov Value *R0 = RealMulI->getOperand(0); 715b4f9c3a9SIgor Kirillov Value *R1 = RealMulI->getOperand(1); 716b4f9c3a9SIgor Kirillov Value *I0 = ImagMulI->getOperand(0); 717b4f9c3a9SIgor Kirillov Value *I1 = ImagMulI->getOperand(1); 718d52e2839SNicholas Guy 719b4f9c3a9SIgor Kirillov Value *CommonOperand; 720b4f9c3a9SIgor Kirillov Value *UncommonRealOp; 721b4f9c3a9SIgor Kirillov Value *UncommonImagOp; 722d52e2839SNicholas Guy 723d52e2839SNicholas Guy if (R0 == I0 || R0 == I1) { 724d52e2839SNicholas Guy CommonOperand = R0; 725d52e2839SNicholas Guy UncommonRealOp = R1; 726d52e2839SNicholas Guy } else if (R1 == I0 || R1 == I1) { 727d52e2839SNicholas Guy CommonOperand = R1; 728d52e2839SNicholas Guy UncommonRealOp = R0; 729d52e2839SNicholas Guy } else { 730d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - No equal operand\n"); 731d52e2839SNicholas Guy return nullptr; 732d52e2839SNicholas Guy } 733d52e2839SNicholas Guy 734d52e2839SNicholas Guy UncommonImagOp = (CommonOperand == I0) ? 
I1 : I0; 735d52e2839SNicholas Guy if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 736d52e2839SNicholas Guy Rotation == ComplexDeinterleavingRotation::Rotation_270) 737d52e2839SNicholas Guy std::swap(UncommonRealOp, UncommonImagOp); 738d52e2839SNicholas Guy 739b4f9c3a9SIgor Kirillov std::pair<Value *, Value *> PartialMatch( 740d52e2839SNicholas Guy (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 741d52e2839SNicholas Guy Rotation == ComplexDeinterleavingRotation::Rotation_180) 742d52e2839SNicholas Guy ? CommonOperand 743d52e2839SNicholas Guy : nullptr, 744d52e2839SNicholas Guy (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 745d52e2839SNicholas Guy Rotation == ComplexDeinterleavingRotation::Rotation_270) 746d52e2839SNicholas Guy ? CommonOperand 747d52e2839SNicholas Guy : nullptr); 74849384f14SNicholas Guy 74949384f14SNicholas Guy auto *CRInst = dyn_cast<Instruction>(CR); 75049384f14SNicholas Guy auto *CIInst = dyn_cast<Instruction>(CI); 75149384f14SNicholas Guy 75249384f14SNicholas Guy if (!CRInst || !CIInst) { 75349384f14SNicholas Guy LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); 75449384f14SNicholas Guy return nullptr; 75549384f14SNicholas Guy } 75649384f14SNicholas Guy 75749384f14SNicholas Guy NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); 758d52e2839SNicholas Guy if (!CNode) { 759d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - No cnode identified\n"); 760d52e2839SNicholas Guy return nullptr; 761d52e2839SNicholas Guy } 762d52e2839SNicholas Guy 763d52e2839SNicholas Guy NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); 764d52e2839SNicholas Guy if (!UncommonRes) { 765d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); 766d52e2839SNicholas Guy return nullptr; 767d52e2839SNicholas Guy } 768d52e2839SNicholas Guy 769d52e2839SNicholas Guy assert(PartialMatch.first && PartialMatch.second); 770d52e2839SNicholas Guy NodePtr CommonRes = 
identifyNode(PartialMatch.first, PartialMatch.second); 771d52e2839SNicholas Guy if (!CommonRes) { 772d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); 773d52e2839SNicholas Guy return nullptr; 774d52e2839SNicholas Guy } 775d52e2839SNicholas Guy 776d52e2839SNicholas Guy NodePtr Node = prepareCompositeNode( 777d52e2839SNicholas Guy ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 778d52e2839SNicholas Guy Node->Rotation = Rotation; 779d52e2839SNicholas Guy Node->addOperand(CommonRes); 780d52e2839SNicholas Guy Node->addOperand(UncommonRes); 781d52e2839SNicholas Guy Node->addOperand(CNode); 782d52e2839SNicholas Guy return submitCompositeNode(Node); 783d52e2839SNicholas Guy } 784d52e2839SNicholas Guy 785d52e2839SNicholas Guy ComplexDeinterleavingGraph::NodePtr 786d52e2839SNicholas Guy ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { 787d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); 788d52e2839SNicholas Guy 789d52e2839SNicholas Guy // Determine rotation 790d52e2839SNicholas Guy ComplexDeinterleavingRotation Rotation; 791a3dc5b53SNicholas Guy if ((Real->getOpcode() == Instruction::FSub && 792a3dc5b53SNicholas Guy Imag->getOpcode() == Instruction::FAdd) || 793a3dc5b53SNicholas Guy (Real->getOpcode() == Instruction::Sub && 794a3dc5b53SNicholas Guy Imag->getOpcode() == Instruction::Add)) 795d52e2839SNicholas Guy Rotation = ComplexDeinterleavingRotation::Rotation_90; 796a3dc5b53SNicholas Guy else if ((Real->getOpcode() == Instruction::FAdd && 797a3dc5b53SNicholas Guy Imag->getOpcode() == Instruction::FSub) || 798a3dc5b53SNicholas Guy (Real->getOpcode() == Instruction::Add && 799a3dc5b53SNicholas Guy Imag->getOpcode() == Instruction::Sub)) 800d52e2839SNicholas Guy Rotation = ComplexDeinterleavingRotation::Rotation_270; 801d52e2839SNicholas Guy else { 802d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n"); 803d52e2839SNicholas 
Guy return nullptr; 804d52e2839SNicholas Guy } 805d52e2839SNicholas Guy 806a3dc5b53SNicholas Guy auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); 807a3dc5b53SNicholas Guy auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); 808a3dc5b53SNicholas Guy auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); 809a3dc5b53SNicholas Guy auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); 810a3dc5b53SNicholas Guy 811a3dc5b53SNicholas Guy if (!AR || !AI || !BR || !BI) { 812a3dc5b53SNicholas Guy LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); 813a3dc5b53SNicholas Guy return nullptr; 814a3dc5b53SNicholas Guy } 815d52e2839SNicholas Guy 816d52e2839SNicholas Guy NodePtr ResA = identifyNode(AR, AI); 817d52e2839SNicholas Guy if (!ResA) { 818d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); 819d52e2839SNicholas Guy return nullptr; 820d52e2839SNicholas Guy } 821d52e2839SNicholas Guy NodePtr ResB = identifyNode(BR, BI); 822d52e2839SNicholas Guy if (!ResB) { 823d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); 824d52e2839SNicholas Guy return nullptr; 825d52e2839SNicholas Guy } 826d52e2839SNicholas Guy 827d52e2839SNicholas Guy NodePtr Node = 828d52e2839SNicholas Guy prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); 829d52e2839SNicholas Guy Node->Rotation = Rotation; 830d52e2839SNicholas Guy Node->addOperand(ResA); 831d52e2839SNicholas Guy Node->addOperand(ResB); 832d52e2839SNicholas Guy return submitCompositeNode(Node); 833d52e2839SNicholas Guy } 834d52e2839SNicholas Guy 835d52e2839SNicholas Guy static bool isInstructionPairAdd(Instruction *A, Instruction *B) { 836d52e2839SNicholas Guy unsigned OpcA = A->getOpcode(); 837d52e2839SNicholas Guy unsigned OpcB = B->getOpcode(); 838a3dc5b53SNicholas Guy 839d52e2839SNicholas Guy return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || 840a3dc5b53SNicholas Guy (OpcA == Instruction::FAdd && OpcB 
== Instruction::FSub) || 841a3dc5b53SNicholas Guy (OpcA == Instruction::Sub && OpcB == Instruction::Add) || 842a3dc5b53SNicholas Guy (OpcA == Instruction::Add && OpcB == Instruction::Sub); 843d52e2839SNicholas Guy } 844d52e2839SNicholas Guy 845d52e2839SNicholas Guy static bool isInstructionPairMul(Instruction *A, Instruction *B) { 846d52e2839SNicholas Guy auto Pattern = 847d52e2839SNicholas Guy m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); 848d52e2839SNicholas Guy 849d52e2839SNicholas Guy return match(A, Pattern) && match(B, Pattern); 850d52e2839SNicholas Guy } 851d52e2839SNicholas Guy 85296615c85SNicholas Guy static bool isInstructionPotentiallySymmetric(Instruction *I) { 85396615c85SNicholas Guy switch (I->getOpcode()) { 85496615c85SNicholas Guy case Instruction::FAdd: 85596615c85SNicholas Guy case Instruction::FSub: 85696615c85SNicholas Guy case Instruction::FMul: 85796615c85SNicholas Guy case Instruction::FNeg: 858c15557d6SIgor Kirillov case Instruction::Add: 859c15557d6SIgor Kirillov case Instruction::Sub: 860c15557d6SIgor Kirillov case Instruction::Mul: 86196615c85SNicholas Guy return true; 86296615c85SNicholas Guy default: 86396615c85SNicholas Guy return false; 86496615c85SNicholas Guy } 86596615c85SNicholas Guy } 86696615c85SNicholas Guy 86796615c85SNicholas Guy ComplexDeinterleavingGraph::NodePtr 86896615c85SNicholas Guy ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, 86996615c85SNicholas Guy Instruction *Imag) { 87096615c85SNicholas Guy if (Real->getOpcode() != Imag->getOpcode()) 87196615c85SNicholas Guy return nullptr; 87296615c85SNicholas Guy 87396615c85SNicholas Guy if (!isInstructionPotentiallySymmetric(Real) || 87496615c85SNicholas Guy !isInstructionPotentiallySymmetric(Imag)) 87596615c85SNicholas Guy return nullptr; 87696615c85SNicholas Guy 877b4f9c3a9SIgor Kirillov auto *R0 = Real->getOperand(0); 878b4f9c3a9SIgor Kirillov auto *I0 = Imag->getOperand(0); 87996615c85SNicholas Guy 88096615c85SNicholas 
Guy NodePtr Op0 = identifyNode(R0, I0); 88196615c85SNicholas Guy NodePtr Op1 = nullptr; 88296615c85SNicholas Guy if (Op0 == nullptr) 88396615c85SNicholas Guy return nullptr; 88496615c85SNicholas Guy 88596615c85SNicholas Guy if (Real->isBinaryOp()) { 886b4f9c3a9SIgor Kirillov auto *R1 = Real->getOperand(1); 887b4f9c3a9SIgor Kirillov auto *I1 = Imag->getOperand(1); 88896615c85SNicholas Guy Op1 = identifyNode(R1, I1); 88996615c85SNicholas Guy if (Op1 == nullptr) 89096615c85SNicholas Guy return nullptr; 89196615c85SNicholas Guy } 89296615c85SNicholas Guy 8931a1e7610SIgor Kirillov if (isa<FPMathOperator>(Real) && 8941a1e7610SIgor Kirillov Real->getFastMathFlags() != Imag->getFastMathFlags()) 8951a1e7610SIgor Kirillov return nullptr; 8961a1e7610SIgor Kirillov 89796615c85SNicholas Guy auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, 89896615c85SNicholas Guy Real, Imag); 8991a1e7610SIgor Kirillov Node->Opcode = Real->getOpcode(); 9001a1e7610SIgor Kirillov if (isa<FPMathOperator>(Real)) 9011a1e7610SIgor Kirillov Node->Flags = Real->getFastMathFlags(); 9021a1e7610SIgor Kirillov 90396615c85SNicholas Guy Node->addOperand(Op0); 90496615c85SNicholas Guy if (Real->isBinaryOp()) 90596615c85SNicholas Guy Node->addOperand(Op1); 90696615c85SNicholas Guy 90796615c85SNicholas Guy return submitCompositeNode(Node); 90896615c85SNicholas Guy } 90996615c85SNicholas Guy 910d52e2839SNicholas Guy ComplexDeinterleavingGraph::NodePtr 9118e1b49c3SNicholas Guy ComplexDeinterleavingGraph::identifyDotProduct(Value *V) { 91276714be5SFlorian Hahn 9138e1b49c3SNicholas Guy if (!TL->isComplexDeinterleavingOperationSupported( 9148e1b49c3SNicholas Guy ComplexDeinterleavingOperation::CDot, V->getType())) { 9158e1b49c3SNicholas Guy LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving " 9168e1b49c3SNicholas Guy "operation CDot with the type " 9178e1b49c3SNicholas Guy << *V->getType() << "\n"); 9188e1b49c3SNicholas Guy return nullptr; 9198e1b49c3SNicholas Guy } 
9208e1b49c3SNicholas Guy 9218e1b49c3SNicholas Guy auto *Inst = cast<Instruction>(V); 9228e1b49c3SNicholas Guy auto *RealUser = cast<Instruction>(*Inst->user_begin()); 9238e1b49c3SNicholas Guy 9248e1b49c3SNicholas Guy NodePtr CN = 9258e1b49c3SNicholas Guy prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr); 9268e1b49c3SNicholas Guy 9278e1b49c3SNicholas Guy NodePtr ANode; 9288e1b49c3SNicholas Guy 9298e1b49c3SNicholas Guy const Intrinsic::ID PartialReduceInt = 9308e1b49c3SNicholas Guy Intrinsic::experimental_vector_partial_reduce_add; 9318e1b49c3SNicholas Guy 9328e1b49c3SNicholas Guy Value *AReal = nullptr; 9338e1b49c3SNicholas Guy Value *AImag = nullptr; 9348e1b49c3SNicholas Guy Value *BReal = nullptr; 9358e1b49c3SNicholas Guy Value *BImag = nullptr; 9368e1b49c3SNicholas Guy Value *Phi = nullptr; 9378e1b49c3SNicholas Guy 9388e1b49c3SNicholas Guy auto UnwrapCast = [](Value *V) -> Value * { 9398e1b49c3SNicholas Guy if (auto *CI = dyn_cast<CastInst>(V)) 9408e1b49c3SNicholas Guy return CI->getOperand(0); 9418e1b49c3SNicholas Guy return V; 9428e1b49c3SNicholas Guy }; 9438e1b49c3SNicholas Guy 9448e1b49c3SNicholas Guy auto PatternRot0 = m_Intrinsic<PartialReduceInt>( 9458e1b49c3SNicholas Guy m_Intrinsic<PartialReduceInt>(m_Value(Phi), 9468e1b49c3SNicholas Guy m_Mul(m_Value(BReal), m_Value(AReal))), 9478e1b49c3SNicholas Guy m_Neg(m_Mul(m_Value(BImag), m_Value(AImag)))); 9488e1b49c3SNicholas Guy 9498e1b49c3SNicholas Guy auto PatternRot270 = m_Intrinsic<PartialReduceInt>( 9508e1b49c3SNicholas Guy m_Intrinsic<PartialReduceInt>( 9518e1b49c3SNicholas Guy m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))), 9528e1b49c3SNicholas Guy m_Mul(m_Value(BImag), m_Value(AReal))); 9538e1b49c3SNicholas Guy 9548e1b49c3SNicholas Guy if (match(Inst, PatternRot0)) { 9558e1b49c3SNicholas Guy CN->Rotation = ComplexDeinterleavingRotation::Rotation_0; 9568e1b49c3SNicholas Guy } else if (match(Inst, PatternRot270)) { 9578e1b49c3SNicholas Guy CN->Rotation = 
ComplexDeinterleavingRotation::Rotation_270; 9588e1b49c3SNicholas Guy } else { 9598e1b49c3SNicholas Guy Value *A0, *A1; 9608e1b49c3SNicholas Guy // The rotations 90 and 180 share the same operation pattern, so inspect the 9618e1b49c3SNicholas Guy // order of the operands, identifying where the real and imaginary 9628e1b49c3SNicholas Guy // components of A go, to discern between the aforementioned rotations. 9638e1b49c3SNicholas Guy auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>( 9648e1b49c3SNicholas Guy m_Intrinsic<PartialReduceInt>(m_Value(Phi), 9658e1b49c3SNicholas Guy m_Mul(m_Value(BReal), m_Value(A0))), 9668e1b49c3SNicholas Guy m_Mul(m_Value(BImag), m_Value(A1))); 9678e1b49c3SNicholas Guy 9688e1b49c3SNicholas Guy if (!match(Inst, PatternRot90Rot180)) 9698e1b49c3SNicholas Guy return nullptr; 9708e1b49c3SNicholas Guy 9718e1b49c3SNicholas Guy A0 = UnwrapCast(A0); 9728e1b49c3SNicholas Guy A1 = UnwrapCast(A1); 9738e1b49c3SNicholas Guy 9748e1b49c3SNicholas Guy // Test if A0 is real/A1 is imag 9758e1b49c3SNicholas Guy ANode = identifyNode(A0, A1); 9768e1b49c3SNicholas Guy if (!ANode) { 9778e1b49c3SNicholas Guy // Test if A0 is imag/A1 is real 9788e1b49c3SNicholas Guy ANode = identifyNode(A1, A0); 9798e1b49c3SNicholas Guy // Unable to identify operand components, thus unable to identify rotation 9808e1b49c3SNicholas Guy if (!ANode) 9818e1b49c3SNicholas Guy return nullptr; 9828e1b49c3SNicholas Guy CN->Rotation = ComplexDeinterleavingRotation::Rotation_90; 9838e1b49c3SNicholas Guy AReal = A1; 9848e1b49c3SNicholas Guy AImag = A0; 9858e1b49c3SNicholas Guy } else { 9868e1b49c3SNicholas Guy AReal = A0; 9878e1b49c3SNicholas Guy AImag = A1; 9888e1b49c3SNicholas Guy CN->Rotation = ComplexDeinterleavingRotation::Rotation_180; 9898e1b49c3SNicholas Guy } 9908e1b49c3SNicholas Guy } 9918e1b49c3SNicholas Guy 9928e1b49c3SNicholas Guy AReal = UnwrapCast(AReal); 9938e1b49c3SNicholas Guy AImag = UnwrapCast(AImag); 9948e1b49c3SNicholas Guy BReal = UnwrapCast(BReal); 
9958e1b49c3SNicholas Guy BImag = UnwrapCast(BImag); 9968e1b49c3SNicholas Guy 9978e1b49c3SNicholas Guy VectorType *VTy = cast<VectorType>(V->getType()); 9988e1b49c3SNicholas Guy Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2); 9998e1b49c3SNicholas Guy if (AReal->getType() != ExpectedOperandTy) 10008e1b49c3SNicholas Guy return nullptr; 10018e1b49c3SNicholas Guy if (AImag->getType() != ExpectedOperandTy) 10028e1b49c3SNicholas Guy return nullptr; 10038e1b49c3SNicholas Guy if (BReal->getType() != ExpectedOperandTy) 10048e1b49c3SNicholas Guy return nullptr; 10058e1b49c3SNicholas Guy if (BImag->getType() != ExpectedOperandTy) 10068e1b49c3SNicholas Guy return nullptr; 10078e1b49c3SNicholas Guy 10088e1b49c3SNicholas Guy if (Phi->getType() != VTy && RealUser->getType() != VTy) 10098e1b49c3SNicholas Guy return nullptr; 10108e1b49c3SNicholas Guy 10118e1b49c3SNicholas Guy NodePtr Node = identifyNode(AReal, AImag); 10128e1b49c3SNicholas Guy 10138e1b49c3SNicholas Guy // In the case that a node was identified to figure out the rotation, ensure 10148e1b49c3SNicholas Guy // that trying to identify a node with AReal and AImag post-unwrap results in 10158e1b49c3SNicholas Guy // the same node 10168e1b49c3SNicholas Guy if (ANode && Node != ANode) { 10178e1b49c3SNicholas Guy LLVM_DEBUG( 10188e1b49c3SNicholas Guy dbgs() 10198e1b49c3SNicholas Guy << "Identified node is different from previously identified node. 
" 10208e1b49c3SNicholas Guy "Unable to confidently generate a complex operation node\n"); 10218e1b49c3SNicholas Guy return nullptr; 10228e1b49c3SNicholas Guy } 10238e1b49c3SNicholas Guy 10248e1b49c3SNicholas Guy CN->addOperand(Node); 10258e1b49c3SNicholas Guy CN->addOperand(identifyNode(BReal, BImag)); 10268e1b49c3SNicholas Guy CN->addOperand(identifyNode(Phi, RealUser)); 10278e1b49c3SNicholas Guy 10288e1b49c3SNicholas Guy return submitCompositeNode(CN); 10298e1b49c3SNicholas Guy } 10308e1b49c3SNicholas Guy 10318e1b49c3SNicholas Guy ComplexDeinterleavingGraph::NodePtr 10328e1b49c3SNicholas Guy ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) { 10338e1b49c3SNicholas Guy // Partial reductions don't support non-vector types, so check these first 10348e1b49c3SNicholas Guy if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType())) 10358e1b49c3SNicholas Guy return nullptr; 10368e1b49c3SNicholas Guy 10378e1b49c3SNicholas Guy auto CommonUser = 10388e1b49c3SNicholas Guy findCommonBetweenCollections<Value *>(R->users(), I->users()); 10398e1b49c3SNicholas Guy if (!CommonUser) 10408e1b49c3SNicholas Guy return nullptr; 10418e1b49c3SNicholas Guy 10428e1b49c3SNicholas Guy auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser); 10438e1b49c3SNicholas Guy if (!IInst || IInst->getIntrinsicID() != 10448e1b49c3SNicholas Guy Intrinsic::experimental_vector_partial_reduce_add) 10458e1b49c3SNicholas Guy return nullptr; 10468e1b49c3SNicholas Guy 10478e1b49c3SNicholas Guy if (NodePtr CN = identifyDotProduct(IInst)) 10488e1b49c3SNicholas Guy return CN; 10498e1b49c3SNicholas Guy 10508e1b49c3SNicholas Guy return nullptr; 10518e1b49c3SNicholas Guy } 10528e1b49c3SNicholas Guy 10538e1b49c3SNicholas Guy ComplexDeinterleavingGraph::NodePtr 10548e1b49c3SNicholas Guy ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { 105546b2ad02SIgor Kirillov auto It = CachedResult.find({R, I}); 105646b2ad02SIgor Kirillov if (It != CachedResult.end()) { 1057d52e2839SNicholas 
Guy LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); 105846b2ad02SIgor Kirillov return It->second; 1059d52e2839SNicholas Guy } 1060d52e2839SNicholas Guy 10618e1b49c3SNicholas Guy if (NodePtr CN = identifyPartialReduction(R, I)) 10628e1b49c3SNicholas Guy return CN; 10638e1b49c3SNicholas Guy 10648e1b49c3SNicholas Guy bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I); 10658e1b49c3SNicholas Guy if (!IsReduction && R->getType() != I->getType()) 10668e1b49c3SNicholas Guy return nullptr; 10678e1b49c3SNicholas Guy 10687f20407cSIgor Kirillov if (NodePtr CN = identifySplat(R, I)) 10697f20407cSIgor Kirillov return CN; 10707f20407cSIgor Kirillov 1071b4f9c3a9SIgor Kirillov auto *Real = dyn_cast<Instruction>(R); 1072b4f9c3a9SIgor Kirillov auto *Imag = dyn_cast<Instruction>(I); 1073b4f9c3a9SIgor Kirillov if (!Real || !Imag) 1074b4f9c3a9SIgor Kirillov return nullptr; 1075b4f9c3a9SIgor Kirillov 10761a1e7610SIgor Kirillov if (NodePtr CN = identifyDeinterleave(Real, Imag)) 10771a1e7610SIgor Kirillov return CN; 10786850bc35SIgor Kirillov 10792cbc265cSIgor Kirillov if (NodePtr CN = identifyPHINode(Real, Imag)) 10802cbc265cSIgor Kirillov return CN; 10812cbc265cSIgor Kirillov 108204a8070bSIgor Kirillov if (NodePtr CN = identifySelectNode(Real, Imag)) 108304a8070bSIgor Kirillov return CN; 108404a8070bSIgor Kirillov 10856850bc35SIgor Kirillov auto *VTy = cast<VectorType>(Real->getType()); 10866850bc35SIgor Kirillov auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 10876850bc35SIgor Kirillov 10881a1e7610SIgor Kirillov bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( 10891a1e7610SIgor Kirillov ComplexDeinterleavingOperation::CMulPartial, NewVTy); 10901a1e7610SIgor Kirillov bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( 10911a1e7610SIgor Kirillov ComplexDeinterleavingOperation::CAdd, NewVTy); 10921a1e7610SIgor Kirillov 10931a1e7610SIgor Kirillov if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { 10941a1e7610SIgor 
Kirillov if (NodePtr CN = identifyPartialMul(Real, Imag)) 10951a1e7610SIgor Kirillov return CN; 10966850bc35SIgor Kirillov } 10976850bc35SIgor Kirillov 10981a1e7610SIgor Kirillov if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { 10991a1e7610SIgor Kirillov if (NodePtr CN = identifyAdd(Real, Imag)) 11001a1e7610SIgor Kirillov return CN; 11016850bc35SIgor Kirillov } 11026850bc35SIgor Kirillov 11031a1e7610SIgor Kirillov if (HasCMulSupport && HasCAddSupport) { 11041a1e7610SIgor Kirillov if (NodePtr CN = identifyReassocNodes(Real, Imag)) 11051a1e7610SIgor Kirillov return CN; 11061a1e7610SIgor Kirillov } 11071a1e7610SIgor Kirillov 11081a1e7610SIgor Kirillov if (NodePtr CN = identifySymmetricOperation(Real, Imag)) 11091a1e7610SIgor Kirillov return CN; 11101a1e7610SIgor Kirillov 11111a1e7610SIgor Kirillov LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); 111246b2ad02SIgor Kirillov CachedResult[{R, I}] = nullptr; 11131a1e7610SIgor Kirillov return nullptr; 11141a1e7610SIgor Kirillov } 11151a1e7610SIgor Kirillov 11161a1e7610SIgor Kirillov ComplexDeinterleavingGraph::NodePtr 11171a1e7610SIgor Kirillov ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, 11181a1e7610SIgor Kirillov Instruction *Imag) { 1119c15557d6SIgor Kirillov auto IsOperationSupported = [](unsigned Opcode) -> bool { 1120c15557d6SIgor Kirillov return Opcode == Instruction::FAdd || Opcode == Instruction::FSub || 1121c15557d6SIgor Kirillov Opcode == Instruction::FNeg || Opcode == Instruction::Add || 1122c15557d6SIgor Kirillov Opcode == Instruction::Sub; 1123c15557d6SIgor Kirillov }; 1124b4f9c3a9SIgor Kirillov 1125c15557d6SIgor Kirillov if (!IsOperationSupported(Real->getOpcode()) || 1126c15557d6SIgor Kirillov !IsOperationSupported(Imag->getOpcode())) 11271a1e7610SIgor Kirillov return nullptr; 11281a1e7610SIgor Kirillov 1129c15557d6SIgor Kirillov std::optional<FastMathFlags> Flags; 1130c15557d6SIgor Kirillov if (isa<FPMathOperator>(Real)) { 11311a1e7610SIgor Kirillov if 
(Real->getFastMathFlags() != Imag->getFastMathFlags()) { 1132c15557d6SIgor Kirillov LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are " 1133c15557d6SIgor Kirillov "not identical\n"); 11341a1e7610SIgor Kirillov return nullptr; 11351a1e7610SIgor Kirillov } 11361a1e7610SIgor Kirillov 1137c15557d6SIgor Kirillov Flags = Real->getFastMathFlags(); 1138c15557d6SIgor Kirillov if (!Flags->allowReassoc()) { 11391a1e7610SIgor Kirillov LLVM_DEBUG( 1140c15557d6SIgor Kirillov dbgs() 1141c15557d6SIgor Kirillov << "the 'Reassoc' attribute is missing in the FastMath flags\n"); 11421a1e7610SIgor Kirillov return nullptr; 11431a1e7610SIgor Kirillov } 1144c15557d6SIgor Kirillov } 11451a1e7610SIgor Kirillov 11461a1e7610SIgor Kirillov // Collect multiplications and addend instructions from the given instruction 11471a1e7610SIgor Kirillov // while traversing it operands. Additionally, verify that all instructions 11481a1e7610SIgor Kirillov // have the same fast math flags. 11491a1e7610SIgor Kirillov auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, 11501a1e7610SIgor Kirillov std::list<Addend> &Addends) -> bool { 11511a1e7610SIgor Kirillov SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; 11521a1e7610SIgor Kirillov SmallPtrSet<Value *, 8> Visited; 11531a1e7610SIgor Kirillov while (!Worklist.empty()) { 11541a1e7610SIgor Kirillov auto [V, IsPositive] = Worklist.back(); 11551a1e7610SIgor Kirillov Worklist.pop_back(); 11561a1e7610SIgor Kirillov if (!Visited.insert(V).second) 11571a1e7610SIgor Kirillov continue; 11581a1e7610SIgor Kirillov 11591a1e7610SIgor Kirillov Instruction *I = dyn_cast<Instruction>(V); 1160b4f9c3a9SIgor Kirillov if (!I) { 1161b4f9c3a9SIgor Kirillov Addends.emplace_back(V, IsPositive); 1162b4f9c3a9SIgor Kirillov continue; 1163b4f9c3a9SIgor Kirillov } 11641a1e7610SIgor Kirillov 11651a1e7610SIgor Kirillov // If an instruction has more than one user, it indicates that it either 11661a1e7610SIgor Kirillov // has an 
external user, which will be later checked by the checkNodes 11671a1e7610SIgor Kirillov // function, or it is a subexpression utilized by multiple expressions. In 11681a1e7610SIgor Kirillov // the latter case, we will attempt to separately identify the complex 11691a1e7610SIgor Kirillov // operation from here in order to create a shared 11701a1e7610SIgor Kirillov // ComplexDeinterleavingCompositeNode. 11711a1e7610SIgor Kirillov if (I != Insn && I->getNumUses() > 1) { 11721a1e7610SIgor Kirillov LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); 11731a1e7610SIgor Kirillov Addends.emplace_back(I, IsPositive); 11741a1e7610SIgor Kirillov continue; 11751a1e7610SIgor Kirillov } 1176c15557d6SIgor Kirillov switch (I->getOpcode()) { 1177c15557d6SIgor Kirillov case Instruction::FAdd: 1178c15557d6SIgor Kirillov case Instruction::Add: 11791a1e7610SIgor Kirillov Worklist.emplace_back(I->getOperand(1), IsPositive); 11801a1e7610SIgor Kirillov Worklist.emplace_back(I->getOperand(0), IsPositive); 1181c15557d6SIgor Kirillov break; 1182c15557d6SIgor Kirillov case Instruction::FSub: 11831a1e7610SIgor Kirillov Worklist.emplace_back(I->getOperand(1), !IsPositive); 11841a1e7610SIgor Kirillov Worklist.emplace_back(I->getOperand(0), IsPositive); 1185c15557d6SIgor Kirillov break; 1186c15557d6SIgor Kirillov case Instruction::Sub: 1187c15557d6SIgor Kirillov if (isNeg(I)) { 1188c15557d6SIgor Kirillov Worklist.emplace_back(getNegOperand(I), !IsPositive); 1189c15557d6SIgor Kirillov } else { 1190c15557d6SIgor Kirillov Worklist.emplace_back(I->getOperand(1), !IsPositive); 1191c15557d6SIgor Kirillov Worklist.emplace_back(I->getOperand(0), IsPositive); 1192c15557d6SIgor Kirillov } 1193c15557d6SIgor Kirillov break; 1194c15557d6SIgor Kirillov case Instruction::FMul: 1195c15557d6SIgor Kirillov case Instruction::Mul: { 1196b4f9c3a9SIgor Kirillov Value *A, *B; 1197c15557d6SIgor Kirillov if (isNeg(I->getOperand(0))) { 1198c15557d6SIgor Kirillov A = getNegOperand(I->getOperand(0)); 
11991a1e7610SIgor Kirillov IsPositive = !IsPositive; 1200b4f9c3a9SIgor Kirillov } else { 1201b4f9c3a9SIgor Kirillov A = I->getOperand(0); 12021a1e7610SIgor Kirillov } 1203b4f9c3a9SIgor Kirillov 1204c15557d6SIgor Kirillov if (isNeg(I->getOperand(1))) { 1205c15557d6SIgor Kirillov B = getNegOperand(I->getOperand(1)); 12061a1e7610SIgor Kirillov IsPositive = !IsPositive; 1207b4f9c3a9SIgor Kirillov } else { 1208b4f9c3a9SIgor Kirillov B = I->getOperand(1); 12091a1e7610SIgor Kirillov } 12101a1e7610SIgor Kirillov Muls.push_back(Product{A, B, IsPositive}); 1211c15557d6SIgor Kirillov break; 1212c15557d6SIgor Kirillov } 1213c15557d6SIgor Kirillov case Instruction::FNeg: 12141a1e7610SIgor Kirillov Worklist.emplace_back(I->getOperand(0), !IsPositive); 1215c15557d6SIgor Kirillov break; 1216c15557d6SIgor Kirillov default: 12171a1e7610SIgor Kirillov Addends.emplace_back(I, IsPositive); 12181a1e7610SIgor Kirillov continue; 12191a1e7610SIgor Kirillov } 12201a1e7610SIgor Kirillov 1221c15557d6SIgor Kirillov if (Flags && I->getFastMathFlags() != *Flags) { 12221a1e7610SIgor Kirillov LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " 12231a1e7610SIgor Kirillov "inconsistent with the root instructions' flags: " 12241a1e7610SIgor Kirillov << *I << "\n"); 12251a1e7610SIgor Kirillov return false; 12261a1e7610SIgor Kirillov } 12271a1e7610SIgor Kirillov } 12281a1e7610SIgor Kirillov return true; 12291a1e7610SIgor Kirillov }; 12301a1e7610SIgor Kirillov 12311a1e7610SIgor Kirillov std::vector<Product> RealMuls, ImagMuls; 12321a1e7610SIgor Kirillov std::list<Addend> RealAddends, ImagAddends; 12331a1e7610SIgor Kirillov if (!Collect(Real, RealMuls, RealAddends) || 12341a1e7610SIgor Kirillov !Collect(Imag, ImagMuls, ImagAddends)) 12351a1e7610SIgor Kirillov return nullptr; 12361a1e7610SIgor Kirillov 12371a1e7610SIgor Kirillov if (RealAddends.size() != ImagAddends.size()) 12381a1e7610SIgor Kirillov return nullptr; 12391a1e7610SIgor Kirillov 12401a1e7610SIgor Kirillov NodePtr FinalNode; 
12411a1e7610SIgor Kirillov if (!RealMuls.empty() || !ImagMuls.empty()) { 12421a1e7610SIgor Kirillov // If there are multiplicands, extract positive addend and use it as an 12431a1e7610SIgor Kirillov // accumulator 12441a1e7610SIgor Kirillov FinalNode = extractPositiveAddend(RealAddends, ImagAddends); 12451a1e7610SIgor Kirillov FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); 12461a1e7610SIgor Kirillov if (!FinalNode) 12471a1e7610SIgor Kirillov return nullptr; 12481a1e7610SIgor Kirillov } 12491a1e7610SIgor Kirillov 12501a1e7610SIgor Kirillov // Identify and process remaining additions 12511a1e7610SIgor Kirillov if (!RealAddends.empty() || !ImagAddends.empty()) { 12521a1e7610SIgor Kirillov FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); 12531a1e7610SIgor Kirillov if (!FinalNode) 12541a1e7610SIgor Kirillov return nullptr; 12551a1e7610SIgor Kirillov } 1256c78acc92SWang, Xin10 assert(FinalNode && "FinalNode can not be nullptr here"); 12571a1e7610SIgor Kirillov // Set the Real and Imag fields of the final node and submit it 12581a1e7610SIgor Kirillov FinalNode->Real = Real; 12591a1e7610SIgor Kirillov FinalNode->Imag = Imag; 12601a1e7610SIgor Kirillov submitCompositeNode(FinalNode); 12611a1e7610SIgor Kirillov return FinalNode; 12621a1e7610SIgor Kirillov } 12631a1e7610SIgor Kirillov 12641a1e7610SIgor Kirillov bool ComplexDeinterleavingGraph::collectPartialMuls( 12651a1e7610SIgor Kirillov const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, 12661a1e7610SIgor Kirillov std::vector<PartialMulCandidate> &PartialMulCandidates) { 12671a1e7610SIgor Kirillov // Helper function to extract a common operand from two products 12681a1e7610SIgor Kirillov auto FindCommonInstruction = [](const Product &Real, 1269b4f9c3a9SIgor Kirillov const Product &Imag) -> Value * { 12701a1e7610SIgor Kirillov if (Real.Multiplicand == Imag.Multiplicand || 12711a1e7610SIgor Kirillov Real.Multiplicand == Imag.Multiplier) 12721a1e7610SIgor 
Kirillov return Real.Multiplicand; 12731a1e7610SIgor Kirillov 12741a1e7610SIgor Kirillov if (Real.Multiplier == Imag.Multiplicand || 12751a1e7610SIgor Kirillov Real.Multiplier == Imag.Multiplier) 12761a1e7610SIgor Kirillov return Real.Multiplier; 12771a1e7610SIgor Kirillov 12781a1e7610SIgor Kirillov return nullptr; 12791a1e7610SIgor Kirillov }; 12801a1e7610SIgor Kirillov 12811a1e7610SIgor Kirillov // Iterating over real and imaginary multiplications to find common operands 12821a1e7610SIgor Kirillov // If a common operand is found, a partial multiplication candidate is created 12831a1e7610SIgor Kirillov // and added to the candidates vector The function returns false if no common 12841a1e7610SIgor Kirillov // operands are found for any product 12851a1e7610SIgor Kirillov for (unsigned i = 0; i < RealMuls.size(); ++i) { 12861a1e7610SIgor Kirillov bool FoundCommon = false; 12871a1e7610SIgor Kirillov for (unsigned j = 0; j < ImagMuls.size(); ++j) { 12881a1e7610SIgor Kirillov auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]); 12891a1e7610SIgor Kirillov if (!Common) 12901a1e7610SIgor Kirillov continue; 12911a1e7610SIgor Kirillov 12921a1e7610SIgor Kirillov auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier 12931a1e7610SIgor Kirillov : RealMuls[i].Multiplicand; 12941a1e7610SIgor Kirillov auto *B = ImagMuls[j].Multiplicand == Common ? 
ImagMuls[j].Multiplier 12951a1e7610SIgor Kirillov : ImagMuls[j].Multiplicand; 12961a1e7610SIgor Kirillov 12971a1e7610SIgor Kirillov auto Node = identifyNode(A, B); 1298b4f9c3a9SIgor Kirillov if (Node) { 12991a1e7610SIgor Kirillov FoundCommon = true; 1300b4f9c3a9SIgor Kirillov PartialMulCandidates.push_back({Common, Node, i, j, false}); 1301b4f9c3a9SIgor Kirillov } 1302b4f9c3a9SIgor Kirillov 1303b4f9c3a9SIgor Kirillov Node = identifyNode(B, A); 1304b4f9c3a9SIgor Kirillov if (Node) { 1305b4f9c3a9SIgor Kirillov FoundCommon = true; 1306b4f9c3a9SIgor Kirillov PartialMulCandidates.push_back({Common, Node, i, j, true}); 1307b4f9c3a9SIgor Kirillov } 13081a1e7610SIgor Kirillov } 13091a1e7610SIgor Kirillov if (!FoundCommon) 13101a1e7610SIgor Kirillov return false; 13111a1e7610SIgor Kirillov } 13121a1e7610SIgor Kirillov return true; 13131a1e7610SIgor Kirillov } 13141a1e7610SIgor Kirillov 13151a1e7610SIgor Kirillov ComplexDeinterleavingGraph::NodePtr 13161a1e7610SIgor Kirillov ComplexDeinterleavingGraph::identifyMultiplications( 13171a1e7610SIgor Kirillov std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, 13181a1e7610SIgor Kirillov NodePtr Accumulator = nullptr) { 13191a1e7610SIgor Kirillov if (RealMuls.size() != ImagMuls.size()) 13201a1e7610SIgor Kirillov return nullptr; 13211a1e7610SIgor Kirillov 13221a1e7610SIgor Kirillov std::vector<PartialMulCandidate> Info; 13231a1e7610SIgor Kirillov if (!collectPartialMuls(RealMuls, ImagMuls, Info)) 13241a1e7610SIgor Kirillov return nullptr; 13251a1e7610SIgor Kirillov 13261a1e7610SIgor Kirillov // Map to store common instruction to node pointers 1327b4f9c3a9SIgor Kirillov std::map<Value *, NodePtr> CommonToNode; 13281a1e7610SIgor Kirillov std::vector<bool> Processed(Info.size(), false); 13291a1e7610SIgor Kirillov for (unsigned I = 0; I < Info.size(); ++I) { 13301a1e7610SIgor Kirillov if (Processed[I]) 13311a1e7610SIgor Kirillov continue; 13321a1e7610SIgor Kirillov 13331a1e7610SIgor Kirillov PartialMulCandidate &InfoA = 
Info[I]; 13341a1e7610SIgor Kirillov for (unsigned J = I + 1; J < Info.size(); ++J) { 13351a1e7610SIgor Kirillov if (Processed[J]) 13361a1e7610SIgor Kirillov continue; 13371a1e7610SIgor Kirillov 13381a1e7610SIgor Kirillov PartialMulCandidate &InfoB = Info[J]; 13391a1e7610SIgor Kirillov auto *InfoReal = &InfoA; 13401a1e7610SIgor Kirillov auto *InfoImag = &InfoB; 13411a1e7610SIgor Kirillov 13421a1e7610SIgor Kirillov auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 13431a1e7610SIgor Kirillov if (!NodeFromCommon) { 13441a1e7610SIgor Kirillov std::swap(InfoReal, InfoImag); 13451a1e7610SIgor Kirillov NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 13461a1e7610SIgor Kirillov } 13471a1e7610SIgor Kirillov if (!NodeFromCommon) 13481a1e7610SIgor Kirillov continue; 13491a1e7610SIgor Kirillov 13501a1e7610SIgor Kirillov CommonToNode[InfoReal->Common] = NodeFromCommon; 13511a1e7610SIgor Kirillov CommonToNode[InfoImag->Common] = NodeFromCommon; 13521a1e7610SIgor Kirillov Processed[I] = true; 13531a1e7610SIgor Kirillov Processed[J] = true; 13541a1e7610SIgor Kirillov } 13551a1e7610SIgor Kirillov } 13561a1e7610SIgor Kirillov 13571a1e7610SIgor Kirillov std::vector<bool> ProcessedReal(RealMuls.size(), false); 13581a1e7610SIgor Kirillov std::vector<bool> ProcessedImag(ImagMuls.size(), false); 13591a1e7610SIgor Kirillov NodePtr Result = Accumulator; 13601a1e7610SIgor Kirillov for (auto &PMI : Info) { 13611a1e7610SIgor Kirillov if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) 13621a1e7610SIgor Kirillov continue; 13631a1e7610SIgor Kirillov 13641a1e7610SIgor Kirillov auto It = CommonToNode.find(PMI.Common); 13651a1e7610SIgor Kirillov // TODO: Process independent complex multiplications. Cases like this: 13661a1e7610SIgor Kirillov // A.real() * B where both A and B are complex numbers. 
13671a1e7610SIgor Kirillov if (It == CommonToNode.end()) { 13681a1e7610SIgor Kirillov LLVM_DEBUG({ 13691a1e7610SIgor Kirillov dbgs() << "Unprocessed independent partial multiplication:\n"; 13701a1e7610SIgor Kirillov for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) 13711a1e7610SIgor Kirillov dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier 13721a1e7610SIgor Kirillov << " multiplied by " << *Mul->Multiplicand << "\n"; 13731a1e7610SIgor Kirillov }); 13741a1e7610SIgor Kirillov return nullptr; 13751a1e7610SIgor Kirillov } 13761a1e7610SIgor Kirillov 13771a1e7610SIgor Kirillov auto &RealMul = RealMuls[PMI.RealIdx]; 13781a1e7610SIgor Kirillov auto &ImagMul = ImagMuls[PMI.ImagIdx]; 13791a1e7610SIgor Kirillov 13801a1e7610SIgor Kirillov auto NodeA = It->second; 13811a1e7610SIgor Kirillov auto NodeB = PMI.Node; 13821a1e7610SIgor Kirillov auto IsMultiplicandReal = PMI.Common == NodeA->Real; 13831a1e7610SIgor Kirillov // The following table illustrates the relationship between multiplications 13841a1e7610SIgor Kirillov // and rotations. 
If we consider the multiplication (X + iY) * (U + iV), we 13851a1e7610SIgor Kirillov // can see: 13861a1e7610SIgor Kirillov // 13871a1e7610SIgor Kirillov // Rotation | Real | Imag | 13881a1e7610SIgor Kirillov // ---------+--------+--------+ 13891a1e7610SIgor Kirillov // 0 | x * u | x * v | 13901a1e7610SIgor Kirillov // 90 | -y * v | y * u | 13911a1e7610SIgor Kirillov // 180 | -x * u | -x * v | 13921a1e7610SIgor Kirillov // 270 | y * v | -y * u | 13931a1e7610SIgor Kirillov // 13941a1e7610SIgor Kirillov // Check if the candidate can indeed be represented by partial 13951a1e7610SIgor Kirillov // multiplication 13961a1e7610SIgor Kirillov // TODO: Add support for multiplication by complex one 13971a1e7610SIgor Kirillov if ((IsMultiplicandReal && PMI.IsNodeInverted) || 13981a1e7610SIgor Kirillov (!IsMultiplicandReal && !PMI.IsNodeInverted)) 13991a1e7610SIgor Kirillov continue; 14001a1e7610SIgor Kirillov 14011a1e7610SIgor Kirillov // Determine the rotation based on the multiplications 14021a1e7610SIgor Kirillov ComplexDeinterleavingRotation Rotation; 14031a1e7610SIgor Kirillov if (IsMultiplicandReal) { 14041a1e7610SIgor Kirillov // Detect 0 and 180 degrees rotation 14051a1e7610SIgor Kirillov if (RealMul.IsPositive && ImagMul.IsPositive) 14061a1e7610SIgor Kirillov Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; 14071a1e7610SIgor Kirillov else if (!RealMul.IsPositive && !ImagMul.IsPositive) 14081a1e7610SIgor Kirillov Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; 14091a1e7610SIgor Kirillov else 14101a1e7610SIgor Kirillov continue; 14111a1e7610SIgor Kirillov 14121a1e7610SIgor Kirillov } else { 14131a1e7610SIgor Kirillov // Detect 90 and 270 degrees rotation 14141a1e7610SIgor Kirillov if (!RealMul.IsPositive && ImagMul.IsPositive) 14151a1e7610SIgor Kirillov Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; 14161a1e7610SIgor Kirillov else if (RealMul.IsPositive && !ImagMul.IsPositive) 14171a1e7610SIgor Kirillov Rotation = 
llvm::ComplexDeinterleavingRotation::Rotation_270; 14181a1e7610SIgor Kirillov else 14191a1e7610SIgor Kirillov continue; 14201a1e7610SIgor Kirillov } 14211a1e7610SIgor Kirillov 14221a1e7610SIgor Kirillov LLVM_DEBUG({ 14231a1e7610SIgor Kirillov dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; 14241a1e7610SIgor Kirillov dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; 14251a1e7610SIgor Kirillov dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; 14261a1e7610SIgor Kirillov dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; 14271a1e7610SIgor Kirillov dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; 14281a1e7610SIgor Kirillov dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 14291a1e7610SIgor Kirillov }); 14301a1e7610SIgor Kirillov 14311a1e7610SIgor Kirillov NodePtr NodeMul = prepareCompositeNode( 14321a1e7610SIgor Kirillov ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); 14331a1e7610SIgor Kirillov NodeMul->Rotation = Rotation; 14341a1e7610SIgor Kirillov NodeMul->addOperand(NodeA); 14351a1e7610SIgor Kirillov NodeMul->addOperand(NodeB); 14361a1e7610SIgor Kirillov if (Result) 14371a1e7610SIgor Kirillov NodeMul->addOperand(Result); 14381a1e7610SIgor Kirillov submitCompositeNode(NodeMul); 14391a1e7610SIgor Kirillov Result = NodeMul; 14401a1e7610SIgor Kirillov ProcessedReal[PMI.RealIdx] = true; 14411a1e7610SIgor Kirillov ProcessedImag[PMI.ImagIdx] = true; 14421a1e7610SIgor Kirillov } 14431a1e7610SIgor Kirillov 14441a1e7610SIgor Kirillov // Ensure all products have been processed, if not return nullptr. 14451a1e7610SIgor Kirillov if (!all_of(ProcessedReal, [](bool V) { return V; }) || 14461a1e7610SIgor Kirillov !all_of(ProcessedImag, [](bool V) { return V; })) { 14471a1e7610SIgor Kirillov 14481a1e7610SIgor Kirillov // Dump debug information about which partial multiplications are not 14491a1e7610SIgor Kirillov // processed. 
14501a1e7610SIgor Kirillov LLVM_DEBUG({ 14511a1e7610SIgor Kirillov dbgs() << "Unprocessed products (Real):\n"; 14521a1e7610SIgor Kirillov for (size_t i = 0; i < ProcessedReal.size(); ++i) { 14531a1e7610SIgor Kirillov if (!ProcessedReal[i]) 14541a1e7610SIgor Kirillov dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") 14551a1e7610SIgor Kirillov << *RealMuls[i].Multiplier << " multiplied by " 14561a1e7610SIgor Kirillov << *RealMuls[i].Multiplicand << "\n"; 14571a1e7610SIgor Kirillov } 14581a1e7610SIgor Kirillov dbgs() << "Unprocessed products (Imag):\n"; 14591a1e7610SIgor Kirillov for (size_t i = 0; i < ProcessedImag.size(); ++i) { 14601a1e7610SIgor Kirillov if (!ProcessedImag[i]) 14611a1e7610SIgor Kirillov dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") 14621a1e7610SIgor Kirillov << *ImagMuls[i].Multiplier << " multiplied by " 14631a1e7610SIgor Kirillov << *ImagMuls[i].Multiplicand << "\n"; 14641a1e7610SIgor Kirillov } 14651a1e7610SIgor Kirillov }); 14661a1e7610SIgor Kirillov return nullptr; 14671a1e7610SIgor Kirillov } 14681a1e7610SIgor Kirillov 14691a1e7610SIgor Kirillov return Result; 14701a1e7610SIgor Kirillov } 14711a1e7610SIgor Kirillov 14721a1e7610SIgor Kirillov ComplexDeinterleavingGraph::NodePtr 1473c15557d6SIgor Kirillov ComplexDeinterleavingGraph::identifyAdditions( 1474c15557d6SIgor Kirillov std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends, 1475c15557d6SIgor Kirillov std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) { 14761a1e7610SIgor Kirillov if (RealAddends.size() != ImagAddends.size()) 14771a1e7610SIgor Kirillov return nullptr; 14781a1e7610SIgor Kirillov 14791a1e7610SIgor Kirillov NodePtr Result; 14801a1e7610SIgor Kirillov // If we have accumulator use it as first addend 14811a1e7610SIgor Kirillov if (Accumulator) 14821a1e7610SIgor Kirillov Result = Accumulator; 14831a1e7610SIgor Kirillov // Otherwise find an element with both positive real and imaginary parts. 
14841a1e7610SIgor Kirillov else 14851a1e7610SIgor Kirillov Result = extractPositiveAddend(RealAddends, ImagAddends); 14861a1e7610SIgor Kirillov 14871a1e7610SIgor Kirillov if (!Result) 14881a1e7610SIgor Kirillov return nullptr; 14891a1e7610SIgor Kirillov 14901a1e7610SIgor Kirillov while (!RealAddends.empty()) { 14911a1e7610SIgor Kirillov auto ItR = RealAddends.begin(); 14921a1e7610SIgor Kirillov auto [R, IsPositiveR] = *ItR; 14931a1e7610SIgor Kirillov 14941a1e7610SIgor Kirillov bool FoundImag = false; 14951a1e7610SIgor Kirillov for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 14961a1e7610SIgor Kirillov auto [I, IsPositiveI] = *ItI; 14971a1e7610SIgor Kirillov ComplexDeinterleavingRotation Rotation; 14981a1e7610SIgor Kirillov if (IsPositiveR && IsPositiveI) 14991a1e7610SIgor Kirillov Rotation = ComplexDeinterleavingRotation::Rotation_0; 15001a1e7610SIgor Kirillov else if (!IsPositiveR && IsPositiveI) 15011a1e7610SIgor Kirillov Rotation = ComplexDeinterleavingRotation::Rotation_90; 15021a1e7610SIgor Kirillov else if (!IsPositiveR && !IsPositiveI) 15031a1e7610SIgor Kirillov Rotation = ComplexDeinterleavingRotation::Rotation_180; 15041a1e7610SIgor Kirillov else 15051a1e7610SIgor Kirillov Rotation = ComplexDeinterleavingRotation::Rotation_270; 15061a1e7610SIgor Kirillov 15071a1e7610SIgor Kirillov NodePtr AddNode; 15081a1e7610SIgor Kirillov if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 15091a1e7610SIgor Kirillov Rotation == ComplexDeinterleavingRotation::Rotation_180) { 15101a1e7610SIgor Kirillov AddNode = identifyNode(R, I); 15111a1e7610SIgor Kirillov } else { 15121a1e7610SIgor Kirillov AddNode = identifyNode(I, R); 15131a1e7610SIgor Kirillov } 15141a1e7610SIgor Kirillov if (AddNode) { 15151a1e7610SIgor Kirillov LLVM_DEBUG({ 15161a1e7610SIgor Kirillov dbgs() << "Identified addition:\n"; 15171a1e7610SIgor Kirillov dbgs().indent(4) << "X: " << *R << "\n"; 15181a1e7610SIgor Kirillov dbgs().indent(4) << "Y: " << *I << "\n"; 
15191a1e7610SIgor Kirillov dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 15201a1e7610SIgor Kirillov }); 15211a1e7610SIgor Kirillov 15221a1e7610SIgor Kirillov NodePtr TmpNode; 15231a1e7610SIgor Kirillov if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { 15241a1e7610SIgor Kirillov TmpNode = prepareCompositeNode( 15251a1e7610SIgor Kirillov ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1526c15557d6SIgor Kirillov if (Flags) { 15271a1e7610SIgor Kirillov TmpNode->Opcode = Instruction::FAdd; 1528c15557d6SIgor Kirillov TmpNode->Flags = *Flags; 1529c15557d6SIgor Kirillov } else { 1530c15557d6SIgor Kirillov TmpNode->Opcode = Instruction::Add; 1531c15557d6SIgor Kirillov } 15321a1e7610SIgor Kirillov } else if (Rotation == 15331a1e7610SIgor Kirillov llvm::ComplexDeinterleavingRotation::Rotation_180) { 15341a1e7610SIgor Kirillov TmpNode = prepareCompositeNode( 15351a1e7610SIgor Kirillov ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1536c15557d6SIgor Kirillov if (Flags) { 15371a1e7610SIgor Kirillov TmpNode->Opcode = Instruction::FSub; 1538c15557d6SIgor Kirillov TmpNode->Flags = *Flags; 1539c15557d6SIgor Kirillov } else { 1540c15557d6SIgor Kirillov TmpNode->Opcode = Instruction::Sub; 1541c15557d6SIgor Kirillov } 15421a1e7610SIgor Kirillov } else { 15431a1e7610SIgor Kirillov TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, 15441a1e7610SIgor Kirillov nullptr, nullptr); 15451a1e7610SIgor Kirillov TmpNode->Rotation = Rotation; 15461a1e7610SIgor Kirillov } 15471a1e7610SIgor Kirillov 15481a1e7610SIgor Kirillov TmpNode->addOperand(Result); 15491a1e7610SIgor Kirillov TmpNode->addOperand(AddNode); 15501a1e7610SIgor Kirillov submitCompositeNode(TmpNode); 15511a1e7610SIgor Kirillov Result = TmpNode; 15521a1e7610SIgor Kirillov RealAddends.erase(ItR); 15531a1e7610SIgor Kirillov ImagAddends.erase(ItI); 15541a1e7610SIgor Kirillov FoundImag = true; 15551a1e7610SIgor Kirillov break; 15561a1e7610SIgor Kirillov } 
15571a1e7610SIgor Kirillov } 15581a1e7610SIgor Kirillov if (!FoundImag) 15591a1e7610SIgor Kirillov return nullptr; 15601a1e7610SIgor Kirillov } 15611a1e7610SIgor Kirillov return Result; 15621a1e7610SIgor Kirillov } 15631a1e7610SIgor Kirillov 15641a1e7610SIgor Kirillov ComplexDeinterleavingGraph::NodePtr 15651a1e7610SIgor Kirillov ComplexDeinterleavingGraph::extractPositiveAddend( 15661a1e7610SIgor Kirillov std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { 15671a1e7610SIgor Kirillov for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { 15681a1e7610SIgor Kirillov for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 15691a1e7610SIgor Kirillov auto [R, IsPositiveR] = *ItR; 15701a1e7610SIgor Kirillov auto [I, IsPositiveI] = *ItI; 15711a1e7610SIgor Kirillov if (IsPositiveR && IsPositiveI) { 15721a1e7610SIgor Kirillov auto Result = identifyNode(R, I); 15731a1e7610SIgor Kirillov if (Result) { 15741a1e7610SIgor Kirillov RealAddends.erase(ItR); 15751a1e7610SIgor Kirillov ImagAddends.erase(ItI); 15761a1e7610SIgor Kirillov return Result; 15771a1e7610SIgor Kirillov } 15781a1e7610SIgor Kirillov } 15791a1e7610SIgor Kirillov } 15801a1e7610SIgor Kirillov } 15811a1e7610SIgor Kirillov return nullptr; 15826850bc35SIgor Kirillov } 15836850bc35SIgor Kirillov 15846850bc35SIgor Kirillov bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 15852cbc265cSIgor Kirillov // This potential root instruction might already have been recognized as 15862cbc265cSIgor Kirillov // reduction. Because RootToNode maps both Real and Imaginary parts to 15872cbc265cSIgor Kirillov // CompositeNode we should choose only one either Real or Imag instruction to 15882cbc265cSIgor Kirillov // use as an anchor for generating complex instruction. 
15892cbc265cSIgor Kirillov auto It = RootToNode.find(RootI); 1590e2cb07c3SIgor Kirillov if (It != RootToNode.end()) { 1591e2cb07c3SIgor Kirillov auto RootNode = It->second; 1592e2cb07c3SIgor Kirillov assert(RootNode->Operation == 15938e1b49c3SNicholas Guy ComplexDeinterleavingOperation::ReductionOperation || 15948e1b49c3SNicholas Guy RootNode->Operation == 15958e1b49c3SNicholas Guy ComplexDeinterleavingOperation::ReductionSingle); 1596e2cb07c3SIgor Kirillov // Find out which part, Real or Imag, comes later, and only if we come to 1597e2cb07c3SIgor Kirillov // the latest part, add it to OrderedRoots. 1598e2cb07c3SIgor Kirillov auto *R = cast<Instruction>(RootNode->Real); 15998e1b49c3SNicholas Guy auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr; 16008e1b49c3SNicholas Guy 16018e1b49c3SNicholas Guy Instruction *ReplacementAnchor; 16028e1b49c3SNicholas Guy if (I) 16038e1b49c3SNicholas Guy ReplacementAnchor = R->comesBefore(I) ? I : R; 16048e1b49c3SNicholas Guy else 16058e1b49c3SNicholas Guy ReplacementAnchor = R; 16068e1b49c3SNicholas Guy 1607e2cb07c3SIgor Kirillov if (ReplacementAnchor != RootI) 1608e2cb07c3SIgor Kirillov return false; 16092cbc265cSIgor Kirillov OrderedRoots.push_back(RootI); 16102cbc265cSIgor Kirillov return true; 16112cbc265cSIgor Kirillov } 16122cbc265cSIgor Kirillov 16136850bc35SIgor Kirillov auto RootNode = identifyRoot(RootI); 16146850bc35SIgor Kirillov if (!RootNode) 16156850bc35SIgor Kirillov return false; 16166850bc35SIgor Kirillov 16176850bc35SIgor Kirillov LLVM_DEBUG({ 16186850bc35SIgor Kirillov Function *F = RootI->getFunction(); 16196850bc35SIgor Kirillov BasicBlock *B = RootI->getParent(); 16206850bc35SIgor Kirillov dbgs() << "Complex deinterleaving graph for " << F->getName() 16216850bc35SIgor Kirillov << "::" << B->getName() << ".\n"; 16226850bc35SIgor Kirillov dump(dbgs()); 16236850bc35SIgor Kirillov dbgs() << "\n"; 16246850bc35SIgor Kirillov }); 16256850bc35SIgor Kirillov RootToNode[RootI] = RootNode; 
16266850bc35SIgor Kirillov OrderedRoots.push_back(RootI); 16276850bc35SIgor Kirillov return true; 16286850bc35SIgor Kirillov } 16296850bc35SIgor Kirillov 16302cbc265cSIgor Kirillov bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { 16312cbc265cSIgor Kirillov bool FoundPotentialReduction = false; 16322cbc265cSIgor Kirillov 16332cbc265cSIgor Kirillov auto *Br = dyn_cast<BranchInst>(B->getTerminator()); 16342cbc265cSIgor Kirillov if (!Br || Br->getNumSuccessors() != 2) 16352cbc265cSIgor Kirillov return false; 16362cbc265cSIgor Kirillov 16372cbc265cSIgor Kirillov // Identify simple one-block loop 16382cbc265cSIgor Kirillov if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) 16392cbc265cSIgor Kirillov return false; 16402cbc265cSIgor Kirillov 16412cbc265cSIgor Kirillov SmallVector<PHINode *> PHIs; 16422cbc265cSIgor Kirillov for (auto &PHI : B->phis()) { 16432cbc265cSIgor Kirillov if (PHI.getNumIncomingValues() != 2) 16442cbc265cSIgor Kirillov continue; 16452cbc265cSIgor Kirillov 16462cbc265cSIgor Kirillov if (!PHI.getType()->isVectorTy()) 16472cbc265cSIgor Kirillov continue; 16482cbc265cSIgor Kirillov 16492cbc265cSIgor Kirillov auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); 16502cbc265cSIgor Kirillov if (!ReductionOp) 16512cbc265cSIgor Kirillov continue; 16522cbc265cSIgor Kirillov 16532cbc265cSIgor Kirillov // Check if final instruction is reduced outside of current block 16542cbc265cSIgor Kirillov Instruction *FinalReduction = nullptr; 16552cbc265cSIgor Kirillov auto NumUsers = 0u; 16562cbc265cSIgor Kirillov for (auto *U : ReductionOp->users()) { 16572cbc265cSIgor Kirillov ++NumUsers; 16582cbc265cSIgor Kirillov if (U == &PHI) 16592cbc265cSIgor Kirillov continue; 16602cbc265cSIgor Kirillov FinalReduction = dyn_cast<Instruction>(U); 16612cbc265cSIgor Kirillov } 16622cbc265cSIgor Kirillov 16630aecf7ffSIgor Kirillov if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B || 16640aecf7ffSIgor 
Kirillov isa<PHINode>(FinalReduction)) 16652cbc265cSIgor Kirillov continue; 16662cbc265cSIgor Kirillov 16672cbc265cSIgor Kirillov ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; 16682cbc265cSIgor Kirillov BackEdge = B; 16692cbc265cSIgor Kirillov auto BackEdgeIdx = PHI.getBasicBlockIndex(B); 16702cbc265cSIgor Kirillov auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0; 16712cbc265cSIgor Kirillov Incoming = PHI.getIncomingBlock(IncomingIdx); 16722cbc265cSIgor Kirillov FoundPotentialReduction = true; 16732cbc265cSIgor Kirillov 16742cbc265cSIgor Kirillov // If the initial value of PHINode is an Instruction, consider it a leaf 16752cbc265cSIgor Kirillov // value of a complex deinterleaving graph. 16762cbc265cSIgor Kirillov if (auto *InitPHI = 16772cbc265cSIgor Kirillov dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) 16782cbc265cSIgor Kirillov FinalInstructions.insert(InitPHI); 16792cbc265cSIgor Kirillov } 16802cbc265cSIgor Kirillov return FoundPotentialReduction; 16812cbc265cSIgor Kirillov } 16822cbc265cSIgor Kirillov 16832cbc265cSIgor Kirillov void ComplexDeinterleavingGraph::identifyReductionNodes() { 16842cbc265cSIgor Kirillov SmallVector<bool> Processed(ReductionInfo.size(), false); 16852cbc265cSIgor Kirillov SmallVector<Instruction *> OperationInstruction; 16862cbc265cSIgor Kirillov for (auto &P : ReductionInfo) 16872cbc265cSIgor Kirillov OperationInstruction.push_back(P.first); 16882cbc265cSIgor Kirillov 16892cbc265cSIgor Kirillov // Identify a complex computation by evaluating two reduction operations that 16902cbc265cSIgor Kirillov // potentially could be involved 16912cbc265cSIgor Kirillov for (size_t i = 0; i < OperationInstruction.size(); ++i) { 16922cbc265cSIgor Kirillov if (Processed[i]) 16932cbc265cSIgor Kirillov continue; 16942cbc265cSIgor Kirillov for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { 16952cbc265cSIgor Kirillov if (Processed[j]) 16962cbc265cSIgor Kirillov continue; 16972cbc265cSIgor Kirillov auto *Real = 
OperationInstruction[i]; 16982cbc265cSIgor Kirillov auto *Imag = OperationInstruction[j]; 16991fce8df5SIgor Kirillov if (Real->getType() != Imag->getType()) 17001fce8df5SIgor Kirillov continue; 17012cbc265cSIgor Kirillov 17022cbc265cSIgor Kirillov RealPHI = ReductionInfo[Real].first; 17032cbc265cSIgor Kirillov ImagPHI = ReductionInfo[Imag].first; 17040aecf7ffSIgor Kirillov PHIsFound = false; 17052cbc265cSIgor Kirillov auto Node = identifyNode(Real, Imag); 17062cbc265cSIgor Kirillov if (!Node) { 17072cbc265cSIgor Kirillov std::swap(Real, Imag); 17082cbc265cSIgor Kirillov std::swap(RealPHI, ImagPHI); 17092cbc265cSIgor Kirillov Node = identifyNode(Real, Imag); 17102cbc265cSIgor Kirillov } 17112cbc265cSIgor Kirillov 17120aecf7ffSIgor Kirillov // If a node is identified and reduction PHINode is used in the chain of 17130aecf7ffSIgor Kirillov // operations, mark its operation instructions as used to prevent 17140aecf7ffSIgor Kirillov // re-identification and attach the node to the real part 17150aecf7ffSIgor Kirillov if (Node && PHIsFound) { 17162cbc265cSIgor Kirillov LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " 17172cbc265cSIgor Kirillov << *Real << " / " << *Imag << "\n"); 17182cbc265cSIgor Kirillov Processed[i] = true; 17192cbc265cSIgor Kirillov Processed[j] = true; 17202cbc265cSIgor Kirillov auto RootNode = prepareCompositeNode( 17212cbc265cSIgor Kirillov ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); 17222cbc265cSIgor Kirillov RootNode->addOperand(Node); 17232cbc265cSIgor Kirillov RootToNode[Real] = RootNode; 17242cbc265cSIgor Kirillov RootToNode[Imag] = RootNode; 17252cbc265cSIgor Kirillov submitCompositeNode(RootNode); 17262cbc265cSIgor Kirillov break; 17272cbc265cSIgor Kirillov } 17282cbc265cSIgor Kirillov } 17298e1b49c3SNicholas Guy 17308e1b49c3SNicholas Guy auto *Real = OperationInstruction[i]; 17318e1b49c3SNicholas Guy // We want to check that we have 2 operands, but the function attributes 17328e1b49c3SNicholas 
Guy // being counted as operands bloats this value. 1733*1b294353SNicholas Guy if (Processed[i] || Real->getNumOperands() < 2) 17348e1b49c3SNicholas Guy continue; 17358e1b49c3SNicholas Guy 17368e1b49c3SNicholas Guy RealPHI = ReductionInfo[Real].first; 17378e1b49c3SNicholas Guy ImagPHI = nullptr; 17388e1b49c3SNicholas Guy PHIsFound = false; 17398e1b49c3SNicholas Guy auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1)); 17408e1b49c3SNicholas Guy if (Node && PHIsFound) { 17418e1b49c3SNicholas Guy LLVM_DEBUG( 17428e1b49c3SNicholas Guy dbgs() << "Identified single reduction starting from instruction: " 17438e1b49c3SNicholas Guy << *Real << "/" << *ReductionInfo[Real].second << "\n"); 17448e1b49c3SNicholas Guy Processed[i] = true; 17458e1b49c3SNicholas Guy auto RootNode = prepareCompositeNode( 17468e1b49c3SNicholas Guy ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr); 17478e1b49c3SNicholas Guy RootNode->addOperand(Node); 17488e1b49c3SNicholas Guy RootToNode[Real] = RootNode; 17498e1b49c3SNicholas Guy submitCompositeNode(RootNode); 17508e1b49c3SNicholas Guy } 17512cbc265cSIgor Kirillov } 17522cbc265cSIgor Kirillov 17532cbc265cSIgor Kirillov RealPHI = nullptr; 17542cbc265cSIgor Kirillov ImagPHI = nullptr; 17552cbc265cSIgor Kirillov } 17562cbc265cSIgor Kirillov 17576850bc35SIgor Kirillov bool ComplexDeinterleavingGraph::checkNodes() { 17588e1b49c3SNicholas Guy 17598e1b49c3SNicholas Guy bool FoundDeinterleaveNode = false; 17608e1b49c3SNicholas Guy for (NodePtr N : CompositeNodes) { 17618e1b49c3SNicholas Guy if (!N->areOperandsValid()) 17628e1b49c3SNicholas Guy return false; 17638e1b49c3SNicholas Guy if (N->Operation == ComplexDeinterleavingOperation::Deinterleave) 17648e1b49c3SNicholas Guy FoundDeinterleaveNode = true; 17658e1b49c3SNicholas Guy } 17668e1b49c3SNicholas Guy 17678e1b49c3SNicholas Guy // We need a deinterleave node in order to guarantee that we're working with 17688e1b49c3SNicholas Guy // complex numbers. 
17698e1b49c3SNicholas Guy if (!FoundDeinterleaveNode) { 17708e1b49c3SNicholas Guy LLVM_DEBUG( 17718e1b49c3SNicholas Guy dbgs() << "Couldn't find a deinterleave node within the graph, cannot " 17728e1b49c3SNicholas Guy "guarantee safety during graph transformation.\n"); 17738e1b49c3SNicholas Guy return false; 17748e1b49c3SNicholas Guy } 17758e1b49c3SNicholas Guy 17766850bc35SIgor Kirillov // Collect all instructions from roots to leaves 17776850bc35SIgor Kirillov SmallPtrSet<Instruction *, 16> AllInstructions; 17786850bc35SIgor Kirillov SmallVector<Instruction *, 8> Worklist; 17792cbc265cSIgor Kirillov for (auto &Pair : RootToNode) 17802cbc265cSIgor Kirillov Worklist.push_back(Pair.first); 17816850bc35SIgor Kirillov 17826850bc35SIgor Kirillov // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG 17836850bc35SIgor Kirillov // chains 17846850bc35SIgor Kirillov while (!Worklist.empty()) { 17856850bc35SIgor Kirillov auto *I = Worklist.back(); 17866850bc35SIgor Kirillov Worklist.pop_back(); 17876850bc35SIgor Kirillov 17886850bc35SIgor Kirillov if (!AllInstructions.insert(I).second) 17896850bc35SIgor Kirillov continue; 17906850bc35SIgor Kirillov 17916850bc35SIgor Kirillov for (Value *Op : I->operands()) { 17926850bc35SIgor Kirillov if (auto *OpI = dyn_cast<Instruction>(Op)) { 17936850bc35SIgor Kirillov if (!FinalInstructions.count(I)) 17946850bc35SIgor Kirillov Worklist.emplace_back(OpI); 17956850bc35SIgor Kirillov } 17966850bc35SIgor Kirillov } 17976850bc35SIgor Kirillov } 17986850bc35SIgor Kirillov 17996850bc35SIgor Kirillov // Find instructions that have users outside of chain 18006850bc35SIgor Kirillov SmallVector<Instruction *, 2> OuterInstructions; 18016850bc35SIgor Kirillov for (auto *I : AllInstructions) { 18026850bc35SIgor Kirillov // Skip root nodes 18036850bc35SIgor Kirillov if (RootToNode.count(I)) 18046850bc35SIgor Kirillov continue; 18056850bc35SIgor Kirillov 18066850bc35SIgor Kirillov for (User *U : I->users()) { 18076850bc35SIgor 
Kirillov if (AllInstructions.count(cast<Instruction>(U))) 18086850bc35SIgor Kirillov continue; 18096850bc35SIgor Kirillov 18106850bc35SIgor Kirillov // Found an instruction that is not used by XCMLA/XCADD chain 18116850bc35SIgor Kirillov Worklist.emplace_back(I); 18126850bc35SIgor Kirillov break; 18136850bc35SIgor Kirillov } 18146850bc35SIgor Kirillov } 18156850bc35SIgor Kirillov 18166850bc35SIgor Kirillov // If any instructions are found to be used outside, find and remove roots 18176850bc35SIgor Kirillov // that somehow connect to those instructions. 18186850bc35SIgor Kirillov SmallPtrSet<Instruction *, 16> Visited; 18196850bc35SIgor Kirillov while (!Worklist.empty()) { 18206850bc35SIgor Kirillov auto *I = Worklist.back(); 18216850bc35SIgor Kirillov Worklist.pop_back(); 18226850bc35SIgor Kirillov if (!Visited.insert(I).second) 18236850bc35SIgor Kirillov continue; 18246850bc35SIgor Kirillov 18256850bc35SIgor Kirillov // Found an impacted root node. Removing it from the nodes to be 18266850bc35SIgor Kirillov // deinterleaved 18276850bc35SIgor Kirillov if (RootToNode.count(I)) { 18286850bc35SIgor Kirillov LLVM_DEBUG(dbgs() << "Instruction " << *I 18296850bc35SIgor Kirillov << " could be deinterleaved but its chain of complex " 18306850bc35SIgor Kirillov "operations have an outside user\n"); 18316850bc35SIgor Kirillov RootToNode.erase(I); 18326850bc35SIgor Kirillov } 18336850bc35SIgor Kirillov 18346850bc35SIgor Kirillov if (!AllInstructions.count(I) || FinalInstructions.count(I)) 18356850bc35SIgor Kirillov continue; 18366850bc35SIgor Kirillov 18376850bc35SIgor Kirillov for (User *U : I->users()) 18386850bc35SIgor Kirillov Worklist.emplace_back(cast<Instruction>(U)); 18396850bc35SIgor Kirillov 18406850bc35SIgor Kirillov for (Value *Op : I->operands()) { 18416850bc35SIgor Kirillov if (auto *OpI = dyn_cast<Instruction>(Op)) 18426850bc35SIgor Kirillov Worklist.emplace_back(OpI); 18436850bc35SIgor Kirillov } 18446850bc35SIgor Kirillov } 18456850bc35SIgor Kirillov return 
!RootToNode.empty(); 18466850bc35SIgor Kirillov } 18476850bc35SIgor Kirillov 18486850bc35SIgor Kirillov ComplexDeinterleavingGraph::NodePtr 18496850bc35SIgor Kirillov ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { 18506850bc35SIgor Kirillov if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { 1851bfc03171SMaciej Gabka if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2) 18526850bc35SIgor Kirillov return nullptr; 18536850bc35SIgor Kirillov 18546850bc35SIgor Kirillov auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); 18556850bc35SIgor Kirillov auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); 18566850bc35SIgor Kirillov if (!Real || !Imag) 18576850bc35SIgor Kirillov return nullptr; 18586850bc35SIgor Kirillov 18596850bc35SIgor Kirillov return identifyNode(Real, Imag); 18606850bc35SIgor Kirillov } 18616850bc35SIgor Kirillov 18626850bc35SIgor Kirillov auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); 18636850bc35SIgor Kirillov if (!SVI) 18646850bc35SIgor Kirillov return nullptr; 18656850bc35SIgor Kirillov 18666850bc35SIgor Kirillov // Look for a shufflevector that takes separate vectors of the real and 18676850bc35SIgor Kirillov // imaginary components and recombines them into a single vector. 
18686850bc35SIgor Kirillov if (!isInterleavingMask(SVI->getShuffleMask())) 18696850bc35SIgor Kirillov return nullptr; 18706850bc35SIgor Kirillov 18716850bc35SIgor Kirillov Instruction *Real; 18726850bc35SIgor Kirillov Instruction *Imag; 18736850bc35SIgor Kirillov if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) 18746850bc35SIgor Kirillov return nullptr; 18756850bc35SIgor Kirillov 18766850bc35SIgor Kirillov return identifyNode(Real, Imag); 18776850bc35SIgor Kirillov } 18786850bc35SIgor Kirillov 18796850bc35SIgor Kirillov ComplexDeinterleavingGraph::NodePtr 18806850bc35SIgor Kirillov ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, 18816850bc35SIgor Kirillov Instruction *Imag) { 18826850bc35SIgor Kirillov Instruction *I = nullptr; 18836850bc35SIgor Kirillov Value *FinalValue = nullptr; 18846850bc35SIgor Kirillov if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && 18856850bc35SIgor Kirillov match(Imag, m_ExtractValue<1>(m_Specific(I))) && 1886bfc03171SMaciej Gabka match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>( 18876850bc35SIgor Kirillov m_Value(FinalValue)))) { 18886850bc35SIgor Kirillov NodePtr PlaceholderNode = prepareCompositeNode( 18896850bc35SIgor Kirillov llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); 18906850bc35SIgor Kirillov PlaceholderNode->ReplacementNode = FinalValue; 18916850bc35SIgor Kirillov FinalInstructions.insert(Real); 18926850bc35SIgor Kirillov FinalInstructions.insert(Imag); 18936850bc35SIgor Kirillov return submitCompositeNode(PlaceholderNode); 18946850bc35SIgor Kirillov } 18956850bc35SIgor Kirillov 1896d52e2839SNicholas Guy auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); 1897d52e2839SNicholas Guy auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); 18986850bc35SIgor Kirillov if (!RealShuffle || !ImagShuffle) { 18996850bc35SIgor Kirillov if (RealShuffle || ImagShuffle) 19006850bc35SIgor Kirillov LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't 
be.\n"); 19016850bc35SIgor Kirillov return nullptr; 19026850bc35SIgor Kirillov } 19036850bc35SIgor Kirillov 1904d52e2839SNicholas Guy Value *RealOp1 = RealShuffle->getOperand(1); 1905d52e2839SNicholas Guy if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { 1906d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); 1907d52e2839SNicholas Guy return nullptr; 1908d52e2839SNicholas Guy } 1909d52e2839SNicholas Guy Value *ImagOp1 = ImagShuffle->getOperand(1); 1910d52e2839SNicholas Guy if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { 1911d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); 1912d52e2839SNicholas Guy return nullptr; 1913d52e2839SNicholas Guy } 1914d52e2839SNicholas Guy 1915d52e2839SNicholas Guy Value *RealOp0 = RealShuffle->getOperand(0); 1916d52e2839SNicholas Guy Value *ImagOp0 = ImagShuffle->getOperand(0); 1917d52e2839SNicholas Guy 1918d52e2839SNicholas Guy if (RealOp0 != ImagOp0) { 1919d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); 1920d52e2839SNicholas Guy return nullptr; 1921d52e2839SNicholas Guy } 1922d52e2839SNicholas Guy 1923d52e2839SNicholas Guy ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); 1924d52e2839SNicholas Guy ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); 1925d52e2839SNicholas Guy if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { 1926d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); 1927d52e2839SNicholas Guy return nullptr; 1928d52e2839SNicholas Guy } 1929d52e2839SNicholas Guy 1930d52e2839SNicholas Guy if (RealMask[0] != 0 || ImagMask[0] != 1) { 1931d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); 1932d52e2839SNicholas Guy return nullptr; 1933d52e2839SNicholas Guy } 1934d52e2839SNicholas Guy 1935d52e2839SNicholas Guy // Type checking, the shuffle type should be a vector type of the same 
1936d52e2839SNicholas Guy // scalar type, but half the size 1937d52e2839SNicholas Guy auto CheckType = [&](ShuffleVectorInst *Shuffle) { 1938d52e2839SNicholas Guy Value *Op = Shuffle->getOperand(0); 1939d52e2839SNicholas Guy auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); 1940d52e2839SNicholas Guy auto *OpTy = cast<FixedVectorType>(Op->getType()); 1941d52e2839SNicholas Guy 1942d52e2839SNicholas Guy if (OpTy->getScalarType() != ShuffleTy->getScalarType()) 1943d52e2839SNicholas Guy return false; 1944d52e2839SNicholas Guy if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) 1945d52e2839SNicholas Guy return false; 1946d52e2839SNicholas Guy 1947d52e2839SNicholas Guy return true; 1948d52e2839SNicholas Guy }; 1949d52e2839SNicholas Guy 1950d52e2839SNicholas Guy auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { 1951d52e2839SNicholas Guy if (!CheckType(Shuffle)) 1952d52e2839SNicholas Guy return false; 1953d52e2839SNicholas Guy 1954d52e2839SNicholas Guy ArrayRef<int> Mask = Shuffle->getShuffleMask(); 1955d52e2839SNicholas Guy int Last = *Mask.rbegin(); 1956d52e2839SNicholas Guy 1957d52e2839SNicholas Guy Value *Op = Shuffle->getOperand(0); 1958d52e2839SNicholas Guy auto *OpTy = cast<FixedVectorType>(Op->getType()); 1959d52e2839SNicholas Guy int NumElements = OpTy->getNumElements(); 1960d52e2839SNicholas Guy 1961d52e2839SNicholas Guy // Ensure that the deinterleaving shuffle only pulls from the first 1962d52e2839SNicholas Guy // shuffle operand. 
1963d52e2839SNicholas Guy return Last < NumElements; 1964d52e2839SNicholas Guy }; 1965d52e2839SNicholas Guy 1966d52e2839SNicholas Guy if (RealShuffle->getType() != ImagShuffle->getType()) { 1967d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); 1968d52e2839SNicholas Guy return nullptr; 1969d52e2839SNicholas Guy } 1970d52e2839SNicholas Guy if (!CheckDeinterleavingShuffle(RealShuffle)) { 1971d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); 1972d52e2839SNicholas Guy return nullptr; 1973d52e2839SNicholas Guy } 1974d52e2839SNicholas Guy if (!CheckDeinterleavingShuffle(ImagShuffle)) { 1975d52e2839SNicholas Guy LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); 1976d52e2839SNicholas Guy return nullptr; 1977d52e2839SNicholas Guy } 1978d52e2839SNicholas Guy 1979d52e2839SNicholas Guy NodePtr PlaceholderNode = 19806850bc35SIgor Kirillov prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, 1981d52e2839SNicholas Guy RealShuffle, ImagShuffle); 1982d52e2839SNicholas Guy PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); 1983c692e87aSIgor Kirillov FinalInstructions.insert(RealShuffle); 1984c692e87aSIgor Kirillov FinalInstructions.insert(ImagShuffle); 1985d52e2839SNicholas Guy return submitCompositeNode(PlaceholderNode); 1986d52e2839SNicholas Guy } 1987d52e2839SNicholas Guy 19882cbc265cSIgor Kirillov ComplexDeinterleavingGraph::NodePtr 19897f20407cSIgor Kirillov ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) { 19907f20407cSIgor Kirillov auto IsSplat = [](Value *V) -> bool { 19917f20407cSIgor Kirillov // Fixed-width vector with constants 19927f20407cSIgor Kirillov if (isa<ConstantDataVector>(V)) 19937f20407cSIgor Kirillov return true; 19947f20407cSIgor Kirillov 19957f20407cSIgor Kirillov VectorType *VTy; 19967f20407cSIgor Kirillov ArrayRef<int> Mask; 19977f20407cSIgor Kirillov // Splats are represented differently depending on whether the repeated 19987f20407cSIgor 
Kirillov // value is a constant or an Instruction 19997f20407cSIgor Kirillov if (auto *Const = dyn_cast<ConstantExpr>(V)) { 20007f20407cSIgor Kirillov if (Const->getOpcode() != Instruction::ShuffleVector) 20017f20407cSIgor Kirillov return false; 20027f20407cSIgor Kirillov VTy = cast<VectorType>(Const->getType()); 20037f20407cSIgor Kirillov Mask = Const->getShuffleMask(); 20047f20407cSIgor Kirillov } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) { 20057f20407cSIgor Kirillov VTy = Shuf->getType(); 20067f20407cSIgor Kirillov Mask = Shuf->getShuffleMask(); 20077f20407cSIgor Kirillov } else { 20087f20407cSIgor Kirillov return false; 20097f20407cSIgor Kirillov } 20107f20407cSIgor Kirillov 20117f20407cSIgor Kirillov // When the data type is <1 x Type>, it's not possible to differentiate 20127f20407cSIgor Kirillov // between the ComplexDeinterleaving::Deinterleave and 20137f20407cSIgor Kirillov // ComplexDeinterleaving::Splat operations. 20147f20407cSIgor Kirillov if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1) 20157f20407cSIgor Kirillov return false; 20167f20407cSIgor Kirillov 20177f20407cSIgor Kirillov return all_equal(Mask) && Mask[0] == 0; 20187f20407cSIgor Kirillov }; 20197f20407cSIgor Kirillov 20207f20407cSIgor Kirillov if (!IsSplat(R) || !IsSplat(I)) 20217f20407cSIgor Kirillov return nullptr; 20227f20407cSIgor Kirillov 20237f20407cSIgor Kirillov auto *Real = dyn_cast<Instruction>(R); 20247f20407cSIgor Kirillov auto *Imag = dyn_cast<Instruction>(I); 20257f20407cSIgor Kirillov if ((!Real && Imag) || (Real && !Imag)) 20267f20407cSIgor Kirillov return nullptr; 20277f20407cSIgor Kirillov 20287f20407cSIgor Kirillov if (Real && Imag) { 20297f20407cSIgor Kirillov // Non-constant splats should be in the same basic block 20307f20407cSIgor Kirillov if (Real->getParent() != Imag->getParent()) 20317f20407cSIgor Kirillov return nullptr; 20327f20407cSIgor Kirillov 20337f20407cSIgor Kirillov FinalInstructions.insert(Real); 20347f20407cSIgor 
Kirillov FinalInstructions.insert(Imag); 20357f20407cSIgor Kirillov } 20367f20407cSIgor Kirillov NodePtr PlaceholderNode = 20377f20407cSIgor Kirillov prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I); 20387f20407cSIgor Kirillov return submitCompositeNode(PlaceholderNode); 20397f20407cSIgor Kirillov } 20407f20407cSIgor Kirillov 20417f20407cSIgor Kirillov ComplexDeinterleavingGraph::NodePtr 20422cbc265cSIgor Kirillov ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, 20432cbc265cSIgor Kirillov Instruction *Imag) { 20448e1b49c3SNicholas Guy if (Real != RealPHI || (ImagPHI && Imag != ImagPHI)) 20452cbc265cSIgor Kirillov return nullptr; 20462cbc265cSIgor Kirillov 20470aecf7ffSIgor Kirillov PHIsFound = true; 20482cbc265cSIgor Kirillov NodePtr PlaceholderNode = prepareCompositeNode( 20492cbc265cSIgor Kirillov ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); 20502cbc265cSIgor Kirillov return submitCompositeNode(PlaceholderNode); 20512cbc265cSIgor Kirillov } 20522cbc265cSIgor Kirillov 205304a8070bSIgor Kirillov ComplexDeinterleavingGraph::NodePtr 205404a8070bSIgor Kirillov ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, 205504a8070bSIgor Kirillov Instruction *Imag) { 205604a8070bSIgor Kirillov auto *SelectReal = dyn_cast<SelectInst>(Real); 205704a8070bSIgor Kirillov auto *SelectImag = dyn_cast<SelectInst>(Imag); 205804a8070bSIgor Kirillov if (!SelectReal || !SelectImag) 205904a8070bSIgor Kirillov return nullptr; 206004a8070bSIgor Kirillov 206104a8070bSIgor Kirillov Instruction *MaskA, *MaskB; 206204a8070bSIgor Kirillov Instruction *AR, *AI, *RA, *BI; 206304a8070bSIgor Kirillov if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR), 206404a8070bSIgor Kirillov m_Instruction(RA))) || 206504a8070bSIgor Kirillov !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI), 206604a8070bSIgor Kirillov m_Instruction(BI)))) 206704a8070bSIgor Kirillov return nullptr; 206804a8070bSIgor Kirillov 206904a8070bSIgor 
Kirillov if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB)) 207004a8070bSIgor Kirillov return nullptr; 207104a8070bSIgor Kirillov 207204a8070bSIgor Kirillov if (!MaskA->getType()->isVectorTy()) 207304a8070bSIgor Kirillov return nullptr; 207404a8070bSIgor Kirillov 207504a8070bSIgor Kirillov auto NodeA = identifyNode(AR, AI); 207604a8070bSIgor Kirillov if (!NodeA) 207704a8070bSIgor Kirillov return nullptr; 207804a8070bSIgor Kirillov 207904a8070bSIgor Kirillov auto NodeB = identifyNode(RA, BI); 208004a8070bSIgor Kirillov if (!NodeB) 208104a8070bSIgor Kirillov return nullptr; 208204a8070bSIgor Kirillov 208304a8070bSIgor Kirillov NodePtr PlaceholderNode = prepareCompositeNode( 208404a8070bSIgor Kirillov ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); 208504a8070bSIgor Kirillov PlaceholderNode->addOperand(NodeA); 208604a8070bSIgor Kirillov PlaceholderNode->addOperand(NodeB); 208704a8070bSIgor Kirillov FinalInstructions.insert(MaskA); 208804a8070bSIgor Kirillov FinalInstructions.insert(MaskB); 208904a8070bSIgor Kirillov return submitCompositeNode(PlaceholderNode); 209004a8070bSIgor Kirillov } 209104a8070bSIgor Kirillov 20921a1e7610SIgor Kirillov static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, 2093c15557d6SIgor Kirillov std::optional<FastMathFlags> Flags, 2094c15557d6SIgor Kirillov Value *InputA, Value *InputB) { 20951a1e7610SIgor Kirillov Value *I; 20961a1e7610SIgor Kirillov switch (Opcode) { 209796615c85SNicholas Guy case Instruction::FNeg: 20981a1e7610SIgor Kirillov I = B.CreateFNeg(InputA); 20991a1e7610SIgor Kirillov break; 210096615c85SNicholas Guy case Instruction::FAdd: 21011a1e7610SIgor Kirillov I = B.CreateFAdd(InputA, InputB); 21021a1e7610SIgor Kirillov break; 2103c15557d6SIgor Kirillov case Instruction::Add: 2104c15557d6SIgor Kirillov I = B.CreateAdd(InputA, InputB); 2105c15557d6SIgor Kirillov break; 210696615c85SNicholas Guy case Instruction::FSub: 21071a1e7610SIgor Kirillov I = B.CreateFSub(InputA, InputB); 
21081a1e7610SIgor Kirillov break; 2109c15557d6SIgor Kirillov case Instruction::Sub: 2110c15557d6SIgor Kirillov I = B.CreateSub(InputA, InputB); 2111c15557d6SIgor Kirillov break; 211296615c85SNicholas Guy case Instruction::FMul: 21131a1e7610SIgor Kirillov I = B.CreateFMul(InputA, InputB); 21141a1e7610SIgor Kirillov break; 2115c15557d6SIgor Kirillov case Instruction::Mul: 2116c15557d6SIgor Kirillov I = B.CreateMul(InputA, InputB); 2117c15557d6SIgor Kirillov break; 21181a1e7610SIgor Kirillov default: 21191a1e7610SIgor Kirillov llvm_unreachable("Incorrect symmetric opcode"); 212096615c85SNicholas Guy } 2121c15557d6SIgor Kirillov if (Flags) 2122c15557d6SIgor Kirillov cast<Instruction>(I)->setFastMathFlags(*Flags); 21231a1e7610SIgor Kirillov return I; 212496615c85SNicholas Guy } 212596615c85SNicholas Guy 212640a81d31SIgor Kirillov Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, 212740a81d31SIgor Kirillov RawNodePtr Node) { 2128d52e2839SNicholas Guy if (Node->ReplacementNode) 2129d52e2839SNicholas Guy return Node->ReplacementNode; 2130d52e2839SNicholas Guy 21312cbc265cSIgor Kirillov auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { 21322cbc265cSIgor Kirillov return Node->Operands.size() > Idx 21332cbc265cSIgor Kirillov ? 
replaceNode(Builder, Node->Operands[Idx]) 213440a81d31SIgor Kirillov : nullptr; 21352cbc265cSIgor Kirillov }; 2136d52e2839SNicholas Guy 21372cbc265cSIgor Kirillov Value *ReplacementNode; 21382cbc265cSIgor Kirillov switch (Node->Operation) { 21398e1b49c3SNicholas Guy case ComplexDeinterleavingOperation::CDot: { 21408e1b49c3SNicholas Guy Value *Input0 = ReplaceOperandIfExist(Node, 0); 21418e1b49c3SNicholas Guy Value *Input1 = ReplaceOperandIfExist(Node, 1); 21428e1b49c3SNicholas Guy Value *Accumulator = ReplaceOperandIfExist(Node, 2); 21438e1b49c3SNicholas Guy assert(!Input1 || (Input0->getType() == Input1->getType() && 21448e1b49c3SNicholas Guy "Node inputs need to be of the same type")); 21458e1b49c3SNicholas Guy ReplacementNode = TL->createComplexDeinterleavingIR( 21468e1b49c3SNicholas Guy Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator); 21478e1b49c3SNicholas Guy break; 21488e1b49c3SNicholas Guy } 21492cbc265cSIgor Kirillov case ComplexDeinterleavingOperation::CAdd: 21502cbc265cSIgor Kirillov case ComplexDeinterleavingOperation::CMulPartial: 21512cbc265cSIgor Kirillov case ComplexDeinterleavingOperation::Symmetric: { 21522cbc265cSIgor Kirillov Value *Input0 = ReplaceOperandIfExist(Node, 0); 21532cbc265cSIgor Kirillov Value *Input1 = ReplaceOperandIfExist(Node, 1); 21542cbc265cSIgor Kirillov Value *Accumulator = ReplaceOperandIfExist(Node, 2); 21552cbc265cSIgor Kirillov assert(!Input1 || (Input0->getType() == Input1->getType() && 21562cbc265cSIgor Kirillov "Node inputs need to be of the same type")); 21572cbc265cSIgor Kirillov assert(!Accumulator || 21582cbc265cSIgor Kirillov (Input0->getType() == Accumulator->getType() && 21592cbc265cSIgor Kirillov "Accumulator and input need to be of the same type")); 216096615c85SNicholas Guy if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) 21612cbc265cSIgor Kirillov ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, 21622cbc265cSIgor Kirillov Input0, Input1); 
216396615c85SNicholas Guy else 21642cbc265cSIgor Kirillov ReplacementNode = TL->createComplexDeinterleavingIR( 21652cbc265cSIgor Kirillov Builder, Node->Operation, Node->Rotation, Input0, Input1, 21662cbc265cSIgor Kirillov Accumulator); 21672cbc265cSIgor Kirillov break; 21682cbc265cSIgor Kirillov } 21692cbc265cSIgor Kirillov case ComplexDeinterleavingOperation::Deinterleave: 21702cbc265cSIgor Kirillov llvm_unreachable("Deinterleave node should already have ReplacementNode"); 21712cbc265cSIgor Kirillov break; 21727f20407cSIgor Kirillov case ComplexDeinterleavingOperation::Splat: { 21737f20407cSIgor Kirillov auto *NewTy = VectorType::getDoubleElementsVectorType( 21747f20407cSIgor Kirillov cast<VectorType>(Node->Real->getType())); 21757f20407cSIgor Kirillov auto *R = dyn_cast<Instruction>(Node->Real); 21767f20407cSIgor Kirillov auto *I = dyn_cast<Instruction>(Node->Imag); 21777f20407cSIgor Kirillov if (R && I) { 21787f20407cSIgor Kirillov // Splats that are not constant are interleaved where they are located 21797f20407cSIgor Kirillov Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode(); 21807f20407cSIgor Kirillov IRBuilder<> IRB(InsertPoint); 2181bfc03171SMaciej Gabka ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2, 21827f20407cSIgor Kirillov NewTy, {Node->Real, Node->Imag}); 2183bfc03171SMaciej Gabka } else { 2184bfc03171SMaciej Gabka ReplacementNode = Builder.CreateIntrinsic( 2185bfc03171SMaciej Gabka Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag}); 21867f20407cSIgor Kirillov } 21877f20407cSIgor Kirillov break; 21887f20407cSIgor Kirillov } 21892cbc265cSIgor Kirillov case ComplexDeinterleavingOperation::ReductionPHI: { 21902cbc265cSIgor Kirillov // If Operation is ReductionPHI, a new empty PHINode is created. 21912cbc265cSIgor Kirillov // It is filled later when the ReductionOperation is processed. 
21928e1b49c3SNicholas Guy auto *OldPHI = cast<PHINode>(Node->Real); 21932cbc265cSIgor Kirillov auto *VTy = cast<VectorType>(Node->Real->getType()); 21942cbc265cSIgor Kirillov auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 2195f33f66beSJeremy Morse auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt()); 21968e1b49c3SNicholas Guy OldToNewPHI[OldPHI] = NewPHI; 21972cbc265cSIgor Kirillov ReplacementNode = NewPHI; 21982cbc265cSIgor Kirillov break; 21992cbc265cSIgor Kirillov } 22008e1b49c3SNicholas Guy case ComplexDeinterleavingOperation::ReductionSingle: 22018e1b49c3SNicholas Guy ReplacementNode = replaceNode(Builder, Node->Operands[0]); 22028e1b49c3SNicholas Guy processReductionSingle(ReplacementNode, Node); 22038e1b49c3SNicholas Guy break; 22042cbc265cSIgor Kirillov case ComplexDeinterleavingOperation::ReductionOperation: 22052cbc265cSIgor Kirillov ReplacementNode = replaceNode(Builder, Node->Operands[0]); 22062cbc265cSIgor Kirillov processReductionOperation(ReplacementNode, Node); 22072cbc265cSIgor Kirillov break; 220804a8070bSIgor Kirillov case ComplexDeinterleavingOperation::ReductionSelect: { 2209b4f9c3a9SIgor Kirillov auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0); 2210b4f9c3a9SIgor Kirillov auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0); 221104a8070bSIgor Kirillov auto *A = replaceNode(Builder, Node->Operands[0]); 221204a8070bSIgor Kirillov auto *B = replaceNode(Builder, Node->Operands[1]); 221304a8070bSIgor Kirillov auto *NewMaskTy = VectorType::getDoubleElementsVectorType( 221404a8070bSIgor Kirillov cast<VectorType>(MaskReal->getType())); 2215bfc03171SMaciej Gabka auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, 221604a8070bSIgor Kirillov NewMaskTy, {MaskReal, MaskImag}); 221704a8070bSIgor Kirillov ReplacementNode = Builder.CreateSelect(NewMask, A, B); 221804a8070bSIgor Kirillov break; 221904a8070bSIgor Kirillov } 22202cbc265cSIgor Kirillov } 2221d52e2839SNicholas Guy 
22222cbc265cSIgor Kirillov assert(ReplacementNode && "Target failed to create Intrinsic call."); 2223d52e2839SNicholas Guy NumComplexTransformations += 1; 22242cbc265cSIgor Kirillov Node->ReplacementNode = ReplacementNode; 22252cbc265cSIgor Kirillov return ReplacementNode; 22262cbc265cSIgor Kirillov } 22272cbc265cSIgor Kirillov 22288e1b49c3SNicholas Guy void ComplexDeinterleavingGraph::processReductionSingle( 22298e1b49c3SNicholas Guy Value *OperationReplacement, RawNodePtr Node) { 22308e1b49c3SNicholas Guy auto *Real = cast<Instruction>(Node->Real); 22318e1b49c3SNicholas Guy auto *OldPHI = ReductionInfo[Real].first; 22328e1b49c3SNicholas Guy auto *NewPHI = OldToNewPHI[OldPHI]; 22338e1b49c3SNicholas Guy auto *VTy = cast<VectorType>(Real->getType()); 22348e1b49c3SNicholas Guy auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 22358e1b49c3SNicholas Guy 22368e1b49c3SNicholas Guy Value *Init = OldPHI->getIncomingValueForBlock(Incoming); 22378e1b49c3SNicholas Guy 22388e1b49c3SNicholas Guy IRBuilder<> Builder(Incoming->getTerminator()); 22398e1b49c3SNicholas Guy 22408e1b49c3SNicholas Guy Value *NewInit = nullptr; 22418e1b49c3SNicholas Guy if (auto *C = dyn_cast<Constant>(Init)) { 22428e1b49c3SNicholas Guy if (C->isZeroValue()) 22438e1b49c3SNicholas Guy NewInit = Constant::getNullValue(NewVTy); 22448e1b49c3SNicholas Guy } 22458e1b49c3SNicholas Guy 22468e1b49c3SNicholas Guy if (!NewInit) 22478e1b49c3SNicholas Guy NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy, 22488e1b49c3SNicholas Guy {Init, Constant::getNullValue(VTy)}); 22498e1b49c3SNicholas Guy 22508e1b49c3SNicholas Guy NewPHI->addIncoming(NewInit, Incoming); 22518e1b49c3SNicholas Guy NewPHI->addIncoming(OperationReplacement, BackEdge); 22528e1b49c3SNicholas Guy 22538e1b49c3SNicholas Guy auto *FinalReduction = ReductionInfo[Real].second; 22548e1b49c3SNicholas Guy Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt()); 22558e1b49c3SNicholas Guy 
22568e1b49c3SNicholas Guy auto *AddReduce = Builder.CreateAddReduce(OperationReplacement); 22578e1b49c3SNicholas Guy FinalReduction->replaceAllUsesWith(AddReduce); 22588e1b49c3SNicholas Guy } 22598e1b49c3SNicholas Guy 22602cbc265cSIgor Kirillov void ComplexDeinterleavingGraph::processReductionOperation( 22612cbc265cSIgor Kirillov Value *OperationReplacement, RawNodePtr Node) { 2262b4f9c3a9SIgor Kirillov auto *Real = cast<Instruction>(Node->Real); 2263b4f9c3a9SIgor Kirillov auto *Imag = cast<Instruction>(Node->Imag); 2264b4f9c3a9SIgor Kirillov auto *OldPHIReal = ReductionInfo[Real].first; 2265b4f9c3a9SIgor Kirillov auto *OldPHIImag = ReductionInfo[Imag].first; 22662cbc265cSIgor Kirillov auto *NewPHI = OldToNewPHI[OldPHIReal]; 22672cbc265cSIgor Kirillov 2268b4f9c3a9SIgor Kirillov auto *VTy = cast<VectorType>(Real->getType()); 22692cbc265cSIgor Kirillov auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 22702cbc265cSIgor Kirillov 22712cbc265cSIgor Kirillov // We have to interleave initial origin values coming from IncomingBlock 22722cbc265cSIgor Kirillov Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); 22732cbc265cSIgor Kirillov Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); 22742cbc265cSIgor Kirillov 22752cbc265cSIgor Kirillov IRBuilder<> Builder(Incoming->getTerminator()); 2276bfc03171SMaciej Gabka auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy, 2277bfc03171SMaciej Gabka {InitReal, InitImag}); 22782cbc265cSIgor Kirillov 22792cbc265cSIgor Kirillov NewPHI->addIncoming(NewInit, Incoming); 22802cbc265cSIgor Kirillov NewPHI->addIncoming(OperationReplacement, BackEdge); 22812cbc265cSIgor Kirillov 22822cbc265cSIgor Kirillov // Deinterleave complex vector outside of loop so that it can be finally 22832cbc265cSIgor Kirillov // reduced 2284b4f9c3a9SIgor Kirillov auto *FinalReductionReal = ReductionInfo[Real].second; 2285b4f9c3a9SIgor Kirillov auto *FinalReductionImag = ReductionInfo[Imag].second; 
22862cbc265cSIgor Kirillov 22872cbc265cSIgor Kirillov Builder.SetInsertPoint( 22882cbc265cSIgor Kirillov &*FinalReductionReal->getParent()->getFirstInsertionPt()); 2289bfc03171SMaciej Gabka auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2, 2290bfc03171SMaciej Gabka OperationReplacement->getType(), 2291bfc03171SMaciej Gabka OperationReplacement); 22922cbc265cSIgor Kirillov 22932cbc265cSIgor Kirillov auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); 2294b4f9c3a9SIgor Kirillov FinalReductionReal->replaceUsesOfWith(Real, NewReal); 22952cbc265cSIgor Kirillov 22962cbc265cSIgor Kirillov Builder.SetInsertPoint(FinalReductionImag); 22972cbc265cSIgor Kirillov auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); 2298b4f9c3a9SIgor Kirillov FinalReductionImag->replaceUsesOfWith(Imag, NewImag); 2299d52e2839SNicholas Guy } 2300d52e2839SNicholas Guy 2301d52e2839SNicholas Guy void ComplexDeinterleavingGraph::replaceNodes() { 2302c692e87aSIgor Kirillov SmallVector<Instruction *, 16> DeadInstrRoots; 2303c692e87aSIgor Kirillov for (auto *RootInstruction : OrderedRoots) { 2304c692e87aSIgor Kirillov // Check if this potential root went through check process and we can 2305c692e87aSIgor Kirillov // deinterleave it 2306c692e87aSIgor Kirillov if (!RootToNode.count(RootInstruction)) 2307c692e87aSIgor Kirillov continue; 2308c692e87aSIgor Kirillov 2309c692e87aSIgor Kirillov IRBuilder<> Builder(RootInstruction); 2310c692e87aSIgor Kirillov auto RootNode = RootToNode[RootInstruction]; 231140a81d31SIgor Kirillov Value *R = replaceNode(Builder, RootNode.get()); 23122cbc265cSIgor Kirillov 23132cbc265cSIgor Kirillov if (RootNode->Operation == 23142cbc265cSIgor Kirillov ComplexDeinterleavingOperation::ReductionOperation) { 2315b4f9c3a9SIgor Kirillov auto *RootReal = cast<Instruction>(RootNode->Real); 2316b4f9c3a9SIgor Kirillov auto *RootImag = cast<Instruction>(RootNode->Imag); 2317b4f9c3a9SIgor Kirillov 
ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); 2318b4f9c3a9SIgor Kirillov ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); 23198e1b49c3SNicholas Guy DeadInstrRoots.push_back(RootReal); 23208e1b49c3SNicholas Guy DeadInstrRoots.push_back(RootImag); 23218e1b49c3SNicholas Guy } else if (RootNode->Operation == 23228e1b49c3SNicholas Guy ComplexDeinterleavingOperation::ReductionSingle) { 23238e1b49c3SNicholas Guy auto *RootInst = cast<Instruction>(RootNode->Real); 23248e1b49c3SNicholas Guy ReductionInfo[RootInst].first->removeIncomingValue(BackEdge); 23258e1b49c3SNicholas Guy DeadInstrRoots.push_back(ReductionInfo[RootInst].second); 23262cbc265cSIgor Kirillov } else { 2327c692e87aSIgor Kirillov assert(R && "Unable to find replacement for RootInstruction"); 2328c692e87aSIgor Kirillov DeadInstrRoots.push_back(RootInstruction); 2329c692e87aSIgor Kirillov RootInstruction->replaceAllUsesWith(R); 2330d52e2839SNicholas Guy } 23312cbc265cSIgor Kirillov } 2332d52e2839SNicholas Guy 2333c692e87aSIgor Kirillov for (auto *I : DeadInstrRoots) 2334c692e87aSIgor Kirillov RecursivelyDeleteTriviallyDeadInstructions(I, TLI); 2335d52e2839SNicholas Guy } 2336