xref: /freebsd-src/contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1bdd1243dSDimitry Andric //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2bdd1243dSDimitry Andric //
3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bdd1243dSDimitry Andric //
7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
8bdd1243dSDimitry Andric //
9bdd1243dSDimitry Andric // Identification:
10bdd1243dSDimitry Andric // This step is responsible for finding the patterns that can be lowered to
11bdd1243dSDimitry Andric // complex instructions, and building a graph to represent the complex
12bdd1243dSDimitry Andric // structures. Starting from the "Converging Shuffle" (a shuffle that
13bdd1243dSDimitry Andric // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14bdd1243dSDimitry Andric // operands are evaluated and identified as "Composite Nodes" (collections of
15bdd1243dSDimitry Andric // instructions that can potentially be lowered to a single complex
16bdd1243dSDimitry Andric // instruction). This is performed by checking the real and imaginary components
17bdd1243dSDimitry Andric // and tracking the data flow for each component while following the operand
18bdd1243dSDimitry Andric // pairs. Validity of each node is expected to be done upon creation, and any
19bdd1243dSDimitry Andric // validation errors should halt traversal and prevent further graph
20bdd1243dSDimitry Andric // construction.
2106c3fb27SDimitry Andric // Instead of relying on Shuffle operations, vector interleaving and
2206c3fb27SDimitry Andric // deinterleaving can be represented by vector.interleave2 and
2306c3fb27SDimitry Andric // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
2406c3fb27SDimitry Andric // these intrinsics, whereas fixed-width vectors are recognized in both forms:
2506c3fb27SDimitry Andric // shufflevector instructions and intrinsics.
26bdd1243dSDimitry Andric //
27bdd1243dSDimitry Andric // Replacement:
28bdd1243dSDimitry Andric // This step traverses the graph built up by identification, delegating to the
29bdd1243dSDimitry Andric // target to validate and generate the correct intrinsics, and plumbs them
30bdd1243dSDimitry Andric // together connecting each end of the new intrinsics graph to the existing
31bdd1243dSDimitry Andric // use-def chain. This step is assumed to finish successfully, as all
32bdd1243dSDimitry Andric // information is expected to be correct by this point.
33bdd1243dSDimitry Andric //
34bdd1243dSDimitry Andric //
35bdd1243dSDimitry Andric // Internal data structure:
36bdd1243dSDimitry Andric // ComplexDeinterleavingGraph:
37bdd1243dSDimitry Andric // Keeps references to all the valid CompositeNodes formed as part of the
38bdd1243dSDimitry Andric // transformation, and every Instruction contained within said nodes. It also
39bdd1243dSDimitry Andric // holds onto a reference to the root Instruction, and the root node that should
40bdd1243dSDimitry Andric // replace it.
41bdd1243dSDimitry Andric //
42bdd1243dSDimitry Andric // ComplexDeinterleavingCompositeNode:
43bdd1243dSDimitry Andric // A CompositeNode represents a single transformation point; each node should
44bdd1243dSDimitry Andric // transform into a single complex instruction (ignoring vector splitting, which
45bdd1243dSDimitry Andric // would generate more instructions per node). They are identified in a
46bdd1243dSDimitry Andric // depth-first manner, traversing and identifying the operands of each
47bdd1243dSDimitry Andric // instruction in the order they appear in the IR.
48bdd1243dSDimitry Andric // Each node maintains a reference to its Real and Imaginary instructions,
49bdd1243dSDimitry Andric // as well as any additional instructions that make up the identified operation
50bdd1243dSDimitry Andric // (Internal instructions should only have uses within their containing node).
51bdd1243dSDimitry Andric // A Node also contains the rotation and operation type that it represents.
52bdd1243dSDimitry Andric // Operands contains pointers to other CompositeNodes, acting as the edges in
53bdd1243dSDimitry Andric // the graph. ReplacementNode is the transformed Value* that has been emitted
54bdd1243dSDimitry Andric // to the IR.
55bdd1243dSDimitry Andric //
56bdd1243dSDimitry Andric // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57bdd1243dSDimitry Andric // ReplacementNode fields of that Node are relevant, where the ReplacementNode
58bdd1243dSDimitry Andric // should be pre-populated.
59bdd1243dSDimitry Andric //
60bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
61bdd1243dSDimitry Andric 
62bdd1243dSDimitry Andric #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
635f757f3fSDimitry Andric #include "llvm/ADT/MapVector.h"
64bdd1243dSDimitry Andric #include "llvm/ADT/Statistic.h"
65bdd1243dSDimitry Andric #include "llvm/Analysis/TargetLibraryInfo.h"
66bdd1243dSDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
67bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetLowering.h"
68bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
69bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h"
70bdd1243dSDimitry Andric #include "llvm/IR/IRBuilder.h"
7106c3fb27SDimitry Andric #include "llvm/IR/PatternMatch.h"
72bdd1243dSDimitry Andric #include "llvm/InitializePasses.h"
73bdd1243dSDimitry Andric #include "llvm/Target/TargetMachine.h"
74bdd1243dSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
75bdd1243dSDimitry Andric #include <algorithm>
76bdd1243dSDimitry Andric 
77bdd1243dSDimitry Andric using namespace llvm;
78bdd1243dSDimitry Andric using namespace PatternMatch;
79bdd1243dSDimitry Andric 
80bdd1243dSDimitry Andric #define DEBUG_TYPE "complex-deinterleaving"
81bdd1243dSDimitry Andric 
82bdd1243dSDimitry Andric STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
83bdd1243dSDimitry Andric 
84bdd1243dSDimitry Andric static cl::opt<bool> ComplexDeinterleavingEnabled(
85bdd1243dSDimitry Andric     "enable-complex-deinterleaving",
86bdd1243dSDimitry Andric     cl::desc("Enable generation of complex instructions"), cl::init(true),
87bdd1243dSDimitry Andric     cl::Hidden);
88bdd1243dSDimitry Andric 
89bdd1243dSDimitry Andric /// Checks the given mask, and determines whether said mask is interleaving.
90bdd1243dSDimitry Andric ///
91bdd1243dSDimitry Andric /// To be interleaving, a mask must alternate between `i` and `i + (Length /
92bdd1243dSDimitry Andric /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
93bdd1243dSDimitry Andric /// 4x vector interleaving mask would be <0, 2, 1, 3>).
94bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask);
95bdd1243dSDimitry Andric 
96bdd1243dSDimitry Andric /// Checks the given mask, and determines whether said mask is deinterleaving.
97bdd1243dSDimitry Andric ///
98bdd1243dSDimitry Andric /// To be deinterleaving, a mask must increment in steps of 2, and either start
99bdd1243dSDimitry Andric /// with 0 or 1.
100bdd1243dSDimitry Andric /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
101bdd1243dSDimitry Andric /// <1, 3, 5, 7>).
102bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask);
103bdd1243dSDimitry Andric 
10406c3fb27SDimitry Andric /// Returns true if the operation is a negation of V, and it works for both
10506c3fb27SDimitry Andric /// integers and floats.
10606c3fb27SDimitry Andric static bool isNeg(Value *V);
10706c3fb27SDimitry Andric 
10806c3fb27SDimitry Andric /// Returns the operand for negation operation.
10906c3fb27SDimitry Andric static Value *getNegOperand(Value *V);
11006c3fb27SDimitry Andric 
111bdd1243dSDimitry Andric namespace {
112bdd1243dSDimitry Andric 
113bdd1243dSDimitry Andric class ComplexDeinterleavingLegacyPass : public FunctionPass {
114bdd1243dSDimitry Andric public:
115bdd1243dSDimitry Andric   static char ID;
116bdd1243dSDimitry Andric 
117bdd1243dSDimitry Andric   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
118bdd1243dSDimitry Andric       : FunctionPass(ID), TM(TM) {
119bdd1243dSDimitry Andric     initializeComplexDeinterleavingLegacyPassPass(
120bdd1243dSDimitry Andric         *PassRegistry::getPassRegistry());
121bdd1243dSDimitry Andric   }
122bdd1243dSDimitry Andric 
123bdd1243dSDimitry Andric   StringRef getPassName() const override {
124bdd1243dSDimitry Andric     return "Complex Deinterleaving Pass";
125bdd1243dSDimitry Andric   }
126bdd1243dSDimitry Andric 
127bdd1243dSDimitry Andric   bool runOnFunction(Function &F) override;
128bdd1243dSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
129bdd1243dSDimitry Andric     AU.addRequired<TargetLibraryInfoWrapperPass>();
130bdd1243dSDimitry Andric     AU.setPreservesCFG();
131bdd1243dSDimitry Andric   }
132bdd1243dSDimitry Andric 
133bdd1243dSDimitry Andric private:
134bdd1243dSDimitry Andric   const TargetMachine *TM;
135bdd1243dSDimitry Andric };
136bdd1243dSDimitry Andric 
137bdd1243dSDimitry Andric class ComplexDeinterleavingGraph;
138bdd1243dSDimitry Andric struct ComplexDeinterleavingCompositeNode {
139bdd1243dSDimitry Andric 
140bdd1243dSDimitry Andric   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
14106c3fb27SDimitry Andric                                      Value *R, Value *I)
142bdd1243dSDimitry Andric       : Operation(Op), Real(R), Imag(I) {}
143bdd1243dSDimitry Andric 
144bdd1243dSDimitry Andric private:
145bdd1243dSDimitry Andric   friend class ComplexDeinterleavingGraph;
146bdd1243dSDimitry Andric   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
147bdd1243dSDimitry Andric   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148bdd1243dSDimitry Andric 
149bdd1243dSDimitry Andric public:
150bdd1243dSDimitry Andric   ComplexDeinterleavingOperation Operation;
15106c3fb27SDimitry Andric   Value *Real;
15206c3fb27SDimitry Andric   Value *Imag;
153bdd1243dSDimitry Andric 
15406c3fb27SDimitry Andric   // This two members are required exclusively for generating
15506c3fb27SDimitry Andric   // ComplexDeinterleavingOperation::Symmetric operations.
15606c3fb27SDimitry Andric   unsigned Opcode;
15706c3fb27SDimitry Andric   std::optional<FastMathFlags> Flags;
15806c3fb27SDimitry Andric 
15906c3fb27SDimitry Andric   ComplexDeinterleavingRotation Rotation =
16006c3fb27SDimitry Andric       ComplexDeinterleavingRotation::Rotation_0;
161bdd1243dSDimitry Andric   SmallVector<RawNodePtr> Operands;
162bdd1243dSDimitry Andric   Value *ReplacementNode = nullptr;
163bdd1243dSDimitry Andric 
164bdd1243dSDimitry Andric   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
165bdd1243dSDimitry Andric 
166bdd1243dSDimitry Andric   void dump() { dump(dbgs()); }
167bdd1243dSDimitry Andric   void dump(raw_ostream &OS) {
168bdd1243dSDimitry Andric     auto PrintValue = [&](Value *V) {
169bdd1243dSDimitry Andric       if (V) {
170bdd1243dSDimitry Andric         OS << "\"";
171bdd1243dSDimitry Andric         V->print(OS, true);
172bdd1243dSDimitry Andric         OS << "\"\n";
173bdd1243dSDimitry Andric       } else
174bdd1243dSDimitry Andric         OS << "nullptr\n";
175bdd1243dSDimitry Andric     };
176bdd1243dSDimitry Andric     auto PrintNodeRef = [&](RawNodePtr Ptr) {
177bdd1243dSDimitry Andric       if (Ptr)
178bdd1243dSDimitry Andric         OS << Ptr << "\n";
179bdd1243dSDimitry Andric       else
180bdd1243dSDimitry Andric         OS << "nullptr\n";
181bdd1243dSDimitry Andric     };
182bdd1243dSDimitry Andric 
183bdd1243dSDimitry Andric     OS << "- CompositeNode: " << this << "\n";
184bdd1243dSDimitry Andric     OS << "  Real: ";
185bdd1243dSDimitry Andric     PrintValue(Real);
186bdd1243dSDimitry Andric     OS << "  Imag: ";
187bdd1243dSDimitry Andric     PrintValue(Imag);
188bdd1243dSDimitry Andric     OS << "  ReplacementNode: ";
189bdd1243dSDimitry Andric     PrintValue(ReplacementNode);
190bdd1243dSDimitry Andric     OS << "  Operation: " << (int)Operation << "\n";
191bdd1243dSDimitry Andric     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
192bdd1243dSDimitry Andric     OS << "  Operands: \n";
193bdd1243dSDimitry Andric     for (const auto &Op : Operands) {
194bdd1243dSDimitry Andric       OS << "    - ";
195bdd1243dSDimitry Andric       PrintNodeRef(Op);
196bdd1243dSDimitry Andric     }
197bdd1243dSDimitry Andric   }
198bdd1243dSDimitry Andric };
199bdd1243dSDimitry Andric 
/// Builds and owns the graph of ComplexDeinterleavingCompositeNodes: it
/// identifies complex patterns starting from root instructions and later
/// replaces the identified instructions with the target's complex operations.
200bdd1243dSDimitry Andric class ComplexDeinterleavingGraph {
201bdd1243dSDimitry Andric public:
20206c3fb27SDimitry Andric   struct Product {
20306c3fb27SDimitry Andric     Value *Multiplier;
20406c3fb27SDimitry Andric     Value *Multiplicand;
20506c3fb27SDimitry Andric     bool IsPositive;
20606c3fb27SDimitry Andric   };
20706c3fb27SDimitry Andric 
20806c3fb27SDimitry Andric   using Addend = std::pair<Value *, bool>;
209bdd1243dSDimitry Andric   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
210bdd1243dSDimitry Andric   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
21106c3fb27SDimitry Andric 
21206c3fb27SDimitry Andric   // Helper struct for holding info about potential partial multiplication
21306c3fb27SDimitry Andric   // candidates
21406c3fb27SDimitry Andric   struct PartialMulCandidate {
21506c3fb27SDimitry Andric     Value *Common;
21606c3fb27SDimitry Andric     NodePtr Node;
21706c3fb27SDimitry Andric     unsigned RealIdx;
21806c3fb27SDimitry Andric     unsigned ImagIdx;
21906c3fb27SDimitry Andric     bool IsNodeInverted;
22006c3fb27SDimitry Andric   };
22106c3fb27SDimitry Andric 
22206c3fb27SDimitry Andric   explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
22306c3fb27SDimitry Andric                                       const TargetLibraryInfo *TLI)
22406c3fb27SDimitry Andric       : TL(TL), TLI(TLI) {}
225bdd1243dSDimitry Andric 
226bdd1243dSDimitry Andric private:
22706c3fb27SDimitry Andric   const TargetLowering *TL = nullptr;
22806c3fb27SDimitry Andric   const TargetLibraryInfo *TLI = nullptr;
229bdd1243dSDimitry Andric   SmallVector<NodePtr> CompositeNodes;
2308a4dda33SDimitry Andric   DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
23106c3fb27SDimitry Andric 
23206c3fb27SDimitry Andric   SmallPtrSet<Instruction *, 16> FinalInstructions;
23306c3fb27SDimitry Andric 
23406c3fb27SDimitry Andric   /// Root instructions are instructions from which complex computation starts.
23506c3fb27SDimitry Andric   std::map<Instruction *, NodePtr> RootToNode;
23606c3fb27SDimitry Andric 
23706c3fb27SDimitry Andric   /// Topologically sorted root instructions.
23806c3fb27SDimitry Andric   SmallVector<Instruction *, 1> OrderedRoots;
23906c3fb27SDimitry Andric 
24006c3fb27SDimitry Andric   /// When examining a basic block for complex deinterleaving, if it is a simple
24106c3fb27SDimitry Andric   /// one-block loop, then the only incoming block is 'Incoming' and the
24206c3fb27SDimitry Andric   /// 'BackEdge' block is the block itself.
24306c3fb27SDimitry Andric   BasicBlock *BackEdge = nullptr;
24406c3fb27SDimitry Andric   BasicBlock *Incoming = nullptr;
24506c3fb27SDimitry Andric 
24606c3fb27SDimitry Andric   /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
24706c3fb27SDimitry Andric   /// %OutsideUser as it is shown in the IR:
24806c3fb27SDimitry Andric   ///
24906c3fb27SDimitry Andric   /// vector.body:
25006c3fb27SDimitry Andric   ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
25106c3fb27SDimitry Andric   ///                                [ %ReductionOp, %vector.body ]
25206c3fb27SDimitry Andric   ///   ...
25306c3fb27SDimitry Andric   ///   %ReductionOp = fadd <vector type> ...
25406c3fb27SDimitry Andric   ///   ...
25506c3fb27SDimitry Andric   ///   br i1 %condition, label %vector.body, %middle.block
25606c3fb27SDimitry Andric   ///
25706c3fb27SDimitry Andric   /// middle.block:
25806c3fb27SDimitry Andric   ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
25906c3fb27SDimitry Andric   ///
26006c3fb27SDimitry Andric   /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
26106c3fb27SDimitry Andric   /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
2625f757f3fSDimitry Andric   MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
26306c3fb27SDimitry Andric 
26406c3fb27SDimitry Andric   /// In the process of detecting a reduction, we consider a pair of
26506c3fb27SDimitry Andric   /// %ReductionOP, which we refer to as real and imag (or vice versa), and
26606c3fb27SDimitry Andric   /// traverse the use-tree to detect complex operations. As this is a reduction
26706c3fb27SDimitry Andric   /// operation, it will eventually reach RealPHI and ImagPHI, which correspond
26806c3fb27SDimitry Andric   /// to the %ReductionOPs that we suspect to be complex.
26906c3fb27SDimitry Andric   /// RealPHI and ImagPHI are used by the identifyPHINode method.
27006c3fb27SDimitry Andric   PHINode *RealPHI = nullptr;
27106c3fb27SDimitry Andric   PHINode *ImagPHI = nullptr;
27206c3fb27SDimitry Andric 
27306c3fb27SDimitry Andric   /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
27406c3fb27SDimitry Andric   /// detection.
27506c3fb27SDimitry Andric   bool PHIsFound = false;
27606c3fb27SDimitry Andric 
27706c3fb27SDimitry Andric   /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
27806c3fb27SDimitry Andric   /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
27906c3fb27SDimitry Andric   /// This mapping is populated during
28006c3fb27SDimitry Andric   /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
28106c3fb27SDimitry Andric   /// used in the ComplexDeinterleavingOperation::ReductionOperation node
28206c3fb27SDimitry Andric   /// replacement process.
28306c3fb27SDimitry Andric   std::map<PHINode *, PHINode *> OldToNewPHI;
284bdd1243dSDimitry Andric 
  /// Allocates a new node for (\p R, \p I) without registering it in the
  /// graph; pass it to submitCompositeNode to register it.
285bdd1243dSDimitry Andric   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
28606c3fb27SDimitry Andric                                Value *R, Value *I) {
28706c3fb27SDimitry Andric     assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
28806c3fb27SDimitry Andric              Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
28906c3fb27SDimitry Andric             (R && I)) &&
29006c3fb27SDimitry Andric            "Reduction related nodes must have Real and Imaginary parts");
291bdd1243dSDimitry Andric     return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
292bdd1243dSDimitry Andric                                                                 I);
293bdd1243dSDimitry Andric   }
294bdd1243dSDimitry Andric 
  /// Registers \p Node in CompositeNodes and, when both Real and Imag are set,
  /// caches it in CachedResult under the (Real, Imag) pair. Returns \p Node.
295bdd1243dSDimitry Andric   NodePtr submitCompositeNode(NodePtr Node) {
296bdd1243dSDimitry Andric     CompositeNodes.push_back(Node);
2978a4dda33SDimitry Andric     if (Node->Real && Node->Imag)
2988a4dda33SDimitry Andric       CachedResult[{Node->Real, Node->Imag}] = Node;
299bdd1243dSDimitry Andric     return Node;
300bdd1243dSDimitry Andric   }
301bdd1243dSDimitry Andric 
302bdd1243dSDimitry Andric   /// Identifies a complex partial multiply pattern and its rotation, based on
303bdd1243dSDimitry Andric   /// the following patterns
304bdd1243dSDimitry Andric   ///
305bdd1243dSDimitry Andric   ///  0:  r: cr + ar * br
306bdd1243dSDimitry Andric   ///      i: ci + ar * bi
307bdd1243dSDimitry Andric   /// 90:  r: cr - ai * bi
308bdd1243dSDimitry Andric   ///      i: ci + ai * br
309bdd1243dSDimitry Andric   /// 180: r: cr - ar * br
310bdd1243dSDimitry Andric   ///      i: ci - ar * bi
311bdd1243dSDimitry Andric   /// 270: r: cr + ai * bi
312bdd1243dSDimitry Andric   ///      i: ci - ai * br
313bdd1243dSDimitry Andric   NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
314bdd1243dSDimitry Andric 
315bdd1243dSDimitry Andric   /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
316bdd1243dSDimitry Andric   /// is partially known from identifyPartialMul, filling in the other half of
317bdd1243dSDimitry Andric   /// the complex pair.
31806c3fb27SDimitry Andric   NodePtr
31906c3fb27SDimitry Andric   identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
32006c3fb27SDimitry Andric                               std::pair<Value *, Value *> &CommonOperandI);
321bdd1243dSDimitry Andric 
322bdd1243dSDimitry Andric   /// Identifies a complex add pattern and its rotation, based on the following
323bdd1243dSDimitry Andric   /// patterns.
324bdd1243dSDimitry Andric   ///
325bdd1243dSDimitry Andric   /// 90:  r: ar - bi
326bdd1243dSDimitry Andric   ///      i: ai + br
327bdd1243dSDimitry Andric   /// 270: r: ar + bi
328bdd1243dSDimitry Andric   ///      i: ai - br
329bdd1243dSDimitry Andric   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
33006c3fb27SDimitry Andric   NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
331bdd1243dSDimitry Andric 
33206c3fb27SDimitry Andric   NodePtr identifyNode(Value *R, Value *I);
333bdd1243dSDimitry Andric 
33406c3fb27SDimitry Andric   /// Determine if a sum of complex numbers can be formed from \p RealAddends
33506c3fb27SDimitry Andric   /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
33606c3fb27SDimitry Andric   /// Return nullptr if it is not possible to construct a complex number.
33706c3fb27SDimitry Andric   /// \p Flags are needed to generate symmetric Add and Sub operations.
33806c3fb27SDimitry Andric   NodePtr identifyAdditions(std::list<Addend> &RealAddends,
33906c3fb27SDimitry Andric                             std::list<Addend> &ImagAddends,
34006c3fb27SDimitry Andric                             std::optional<FastMathFlags> Flags,
34106c3fb27SDimitry Andric                             NodePtr Accumulator);
34206c3fb27SDimitry Andric 
34306c3fb27SDimitry Andric   /// Extract one addend that has both real and imaginary parts positive.
34406c3fb27SDimitry Andric   NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
34506c3fb27SDimitry Andric                                 std::list<Addend> &ImagAddends);
34606c3fb27SDimitry Andric 
34706c3fb27SDimitry Andric   /// Determine if sum of multiplications of complex numbers can be formed from
34806c3fb27SDimitry Andric   /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
34906c3fb27SDimitry Andric   /// to it. Return nullptr if it is not possible to construct a complex number.
35006c3fb27SDimitry Andric   NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
35106c3fb27SDimitry Andric                                   std::vector<Product> &ImagMuls,
35206c3fb27SDimitry Andric                                   NodePtr Accumulator);
35306c3fb27SDimitry Andric 
35406c3fb27SDimitry Andric   /// Go through pairs of multiplication (one Real and one Imag) and find all
35506c3fb27SDimitry Andric   /// possible candidates for partial multiplication and put them into \p
35606c3fb27SDimitry Andric   /// Candidates. Returns true if every Product has a pair with a common operand.
35706c3fb27SDimitry Andric   bool collectPartialMuls(const std::vector<Product> &RealMuls,
35806c3fb27SDimitry Andric                           const std::vector<Product> &ImagMuls,
35906c3fb27SDimitry Andric                           std::vector<PartialMulCandidate> &Candidates);
36006c3fb27SDimitry Andric 
36106c3fb27SDimitry Andric   /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
36206c3fb27SDimitry Andric   /// the order of complex computation operations may be significantly altered,
36306c3fb27SDimitry Andric   /// and the real and imaginary parts may not be executed in parallel. This
36406c3fb27SDimitry Andric   /// function takes this into consideration and employs a more general approach
36506c3fb27SDimitry Andric   /// to identify complex computations. Initially, it gathers all the addends
36606c3fb27SDimitry Andric   /// and multiplicands and then constructs a complex expression from them.
36706c3fb27SDimitry Andric   NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
36806c3fb27SDimitry Andric 
36906c3fb27SDimitry Andric   NodePtr identifyRoot(Instruction *I);
37006c3fb27SDimitry Andric 
37106c3fb27SDimitry Andric   /// Identifies the Deinterleave operation applied to a vector containing
37206c3fb27SDimitry Andric   /// complex numbers. There are two ways to represent the Deinterleave
37306c3fb27SDimitry Andric   /// operation:
37406c3fb27SDimitry Andric   /// * Using two shufflevectors with even indices for the \p Real instruction
37506c3fb27SDimitry Andric   /// and odd indices for the \p Imag instruction (only for fixed-width vectors)
37606c3fb27SDimitry Andric   /// * Using two extractvalue instructions applied to `vector.deinterleave2`
37706c3fb27SDimitry Andric   /// intrinsic (for both fixed and scalable vectors)
37806c3fb27SDimitry Andric   NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
37906c3fb27SDimitry Andric 
38006c3fb27SDimitry Andric   /// Identifies the operation that represents a complex number repeated in a
38106c3fb27SDimitry Andric   /// Splat vector. There are two possible types of splats: ConstantExpr with
38206c3fb27SDimitry Andric   /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
38306c3fb27SDimitry Andric   /// initialization mask with all values set to zero.
38406c3fb27SDimitry Andric   NodePtr identifySplat(Value *Real, Value *Imag);
38506c3fb27SDimitry Andric 
38606c3fb27SDimitry Andric   NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
38706c3fb27SDimitry Andric 
38806c3fb27SDimitry Andric   /// Identifies SelectInsts in a loop that has reduction with predication masks
38906c3fb27SDimitry Andric   /// and/or predicated tail folding
39006c3fb27SDimitry Andric   NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
39106c3fb27SDimitry Andric 
39206c3fb27SDimitry Andric   Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
39306c3fb27SDimitry Andric 
39406c3fb27SDimitry Andric   /// Complete IR modifications after producing new reduction operation:
39506c3fb27SDimitry Andric   /// * Populate the PHINode generated for
39606c3fb27SDimitry Andric   /// ComplexDeinterleavingOperation::ReductionPHI
39706c3fb27SDimitry Andric   /// * Deinterleave the final value outside of the loop and repurpose original
39806c3fb27SDimitry Andric   /// reduction users
39906c3fb27SDimitry Andric   void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
400bdd1243dSDimitry Andric 
401bdd1243dSDimitry Andric public:
  /// Print every composite node of the graph (to dbgs() by default).
402bdd1243dSDimitry Andric   void dump() { dump(dbgs()); }
403bdd1243dSDimitry Andric   void dump(raw_ostream &OS) {
404bdd1243dSDimitry Andric     for (const auto &Node : CompositeNodes)
405bdd1243dSDimitry Andric       Node->dump(OS);
406bdd1243dSDimitry Andric   }
407bdd1243dSDimitry Andric 
408bdd1243dSDimitry Andric   /// Returns false if the deinterleaving operation should be cancelled for the
409bdd1243dSDimitry Andric   /// current graph.
410bdd1243dSDimitry Andric   bool identifyNodes(Instruction *RootI);
411bdd1243dSDimitry Andric 
41206c3fb27SDimitry Andric   /// In case \p B is a one-block loop, this function seeks potential reductions
41306c3fb27SDimitry Andric   /// and populates ReductionInfo. Returns true if any reductions were
41406c3fb27SDimitry Andric   /// identified.
41506c3fb27SDimitry Andric   bool collectPotentialReductions(BasicBlock *B);
41606c3fb27SDimitry Andric 
41706c3fb27SDimitry Andric   void identifyReductionNodes();
41806c3fb27SDimitry Andric 
41906c3fb27SDimitry Andric   /// Check that every instruction, from the roots to the leaves, has internal
42006c3fb27SDimitry Andric   /// uses.
42106c3fb27SDimitry Andric   bool checkNodes();
42206c3fb27SDimitry Andric 
423bdd1243dSDimitry Andric   /// Perform the actual replacement of the underlying instruction graph.
424bdd1243dSDimitry Andric   void replaceNodes();
425bdd1243dSDimitry Andric };
426bdd1243dSDimitry Andric 
427bdd1243dSDimitry Andric class ComplexDeinterleaving {
428bdd1243dSDimitry Andric public:
429bdd1243dSDimitry Andric   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
430bdd1243dSDimitry Andric       : TL(tl), TLI(tli) {}
431bdd1243dSDimitry Andric   bool runOnFunction(Function &F);
432bdd1243dSDimitry Andric 
433bdd1243dSDimitry Andric private:
434bdd1243dSDimitry Andric   bool evaluateBasicBlock(BasicBlock *B);
435bdd1243dSDimitry Andric 
436bdd1243dSDimitry Andric   const TargetLowering *TL = nullptr;
437bdd1243dSDimitry Andric   const TargetLibraryInfo *TLI = nullptr;
438bdd1243dSDimitry Andric };
439bdd1243dSDimitry Andric 
440bdd1243dSDimitry Andric } // namespace
441bdd1243dSDimitry Andric 
442bdd1243dSDimitry Andric char ComplexDeinterleavingLegacyPass::ID = 0;
443bdd1243dSDimitry Andric 
444bdd1243dSDimitry Andric INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
445bdd1243dSDimitry Andric                       "Complex Deinterleaving", false, false)
446bdd1243dSDimitry Andric INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
447bdd1243dSDimitry Andric                     "Complex Deinterleaving", false, false)
448bdd1243dSDimitry Andric 
449bdd1243dSDimitry Andric PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
450bdd1243dSDimitry Andric                                                  FunctionAnalysisManager &AM) {
451bdd1243dSDimitry Andric   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
452bdd1243dSDimitry Andric   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
453bdd1243dSDimitry Andric   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
454bdd1243dSDimitry Andric     return PreservedAnalyses::all();
455bdd1243dSDimitry Andric 
456bdd1243dSDimitry Andric   PreservedAnalyses PA;
457bdd1243dSDimitry Andric   PA.preserve<FunctionAnalysisManagerModuleProxy>();
458bdd1243dSDimitry Andric   return PA;
459bdd1243dSDimitry Andric }
460bdd1243dSDimitry Andric 
461bdd1243dSDimitry Andric FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
462bdd1243dSDimitry Andric   return new ComplexDeinterleavingLegacyPass(TM);
463bdd1243dSDimitry Andric }
464bdd1243dSDimitry Andric 
465bdd1243dSDimitry Andric bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
466bdd1243dSDimitry Andric   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
467bdd1243dSDimitry Andric   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
468bdd1243dSDimitry Andric   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
469bdd1243dSDimitry Andric }
470bdd1243dSDimitry Andric 
471bdd1243dSDimitry Andric bool ComplexDeinterleaving::runOnFunction(Function &F) {
472bdd1243dSDimitry Andric   if (!ComplexDeinterleavingEnabled) {
473bdd1243dSDimitry Andric     LLVM_DEBUG(
474bdd1243dSDimitry Andric         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
475bdd1243dSDimitry Andric     return false;
476bdd1243dSDimitry Andric   }
477bdd1243dSDimitry Andric 
478bdd1243dSDimitry Andric   if (!TL->isComplexDeinterleavingSupported()) {
479bdd1243dSDimitry Andric     LLVM_DEBUG(
480bdd1243dSDimitry Andric         dbgs() << "Complex deinterleaving has been disabled, target does "
481bdd1243dSDimitry Andric                   "not support lowering of complex number operations.\n");
482bdd1243dSDimitry Andric     return false;
483bdd1243dSDimitry Andric   }
484bdd1243dSDimitry Andric 
485bdd1243dSDimitry Andric   bool Changed = false;
486bdd1243dSDimitry Andric   for (auto &B : F)
487bdd1243dSDimitry Andric     Changed |= evaluateBasicBlock(&B);
488bdd1243dSDimitry Andric 
489bdd1243dSDimitry Andric   return Changed;
490bdd1243dSDimitry Andric }
491bdd1243dSDimitry Andric 
492bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask) {
493bdd1243dSDimitry Andric   // If the size is not even, it's not an interleaving mask
494bdd1243dSDimitry Andric   if ((Mask.size() & 1))
495bdd1243dSDimitry Andric     return false;
496bdd1243dSDimitry Andric 
497bdd1243dSDimitry Andric   int HalfNumElements = Mask.size() / 2;
498bdd1243dSDimitry Andric   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
499bdd1243dSDimitry Andric     int MaskIdx = Idx * 2;
500bdd1243dSDimitry Andric     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
501bdd1243dSDimitry Andric       return false;
502bdd1243dSDimitry Andric   }
503bdd1243dSDimitry Andric 
504bdd1243dSDimitry Andric   return true;
505bdd1243dSDimitry Andric }
506bdd1243dSDimitry Andric 
507bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask) {
508bdd1243dSDimitry Andric   int Offset = Mask[0];
509bdd1243dSDimitry Andric   int HalfNumElements = Mask.size() / 2;
510bdd1243dSDimitry Andric 
511bdd1243dSDimitry Andric   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
512bdd1243dSDimitry Andric     if (Mask[Idx] != (Idx * 2) + Offset)
513bdd1243dSDimitry Andric       return false;
514bdd1243dSDimitry Andric   }
515bdd1243dSDimitry Andric 
516bdd1243dSDimitry Andric   return true;
517bdd1243dSDimitry Andric }
518bdd1243dSDimitry Andric 
51906c3fb27SDimitry Andric bool isNeg(Value *V) {
52006c3fb27SDimitry Andric   return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
52106c3fb27SDimitry Andric }
52206c3fb27SDimitry Andric 
52306c3fb27SDimitry Andric Value *getNegOperand(Value *V) {
52406c3fb27SDimitry Andric   assert(isNeg(V));
52506c3fb27SDimitry Andric   auto *I = cast<Instruction>(V);
52606c3fb27SDimitry Andric   if (I->getOpcode() == Instruction::FNeg)
52706c3fb27SDimitry Andric     return I->getOperand(0);
52806c3fb27SDimitry Andric 
52906c3fb27SDimitry Andric   return I->getOperand(1);
53006c3fb27SDimitry Andric }
53106c3fb27SDimitry Andric 
532bdd1243dSDimitry Andric bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
53306c3fb27SDimitry Andric   ComplexDeinterleavingGraph Graph(TL, TLI);
53406c3fb27SDimitry Andric   if (Graph.collectPotentialReductions(B))
53506c3fb27SDimitry Andric     Graph.identifyReductionNodes();
536bdd1243dSDimitry Andric 
53706c3fb27SDimitry Andric   for (auto &I : *B)
53806c3fb27SDimitry Andric     Graph.identifyNodes(&I);
539bdd1243dSDimitry Andric 
54006c3fb27SDimitry Andric   if (Graph.checkNodes()) {
541bdd1243dSDimitry Andric     Graph.replaceNodes();
54206c3fb27SDimitry Andric     return true;
543bdd1243dSDimitry Andric   }
544bdd1243dSDimitry Andric 
54506c3fb27SDimitry Andric   return false;
546bdd1243dSDimitry Andric }
547bdd1243dSDimitry Andric 
548bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
549bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
550bdd1243dSDimitry Andric     Instruction *Real, Instruction *Imag,
55106c3fb27SDimitry Andric     std::pair<Value *, Value *> &PartialMatch) {
552bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
553bdd1243dSDimitry Andric                     << "\n");
554bdd1243dSDimitry Andric 
555bdd1243dSDimitry Andric   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
556bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
557bdd1243dSDimitry Andric     return nullptr;
558bdd1243dSDimitry Andric   }
559bdd1243dSDimitry Andric 
56006c3fb27SDimitry Andric   if ((Real->getOpcode() != Instruction::FMul &&
56106c3fb27SDimitry Andric        Real->getOpcode() != Instruction::Mul) ||
56206c3fb27SDimitry Andric       (Imag->getOpcode() != Instruction::FMul &&
56306c3fb27SDimitry Andric        Imag->getOpcode() != Instruction::Mul)) {
56406c3fb27SDimitry Andric     LLVM_DEBUG(
56506c3fb27SDimitry Andric         dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
566bdd1243dSDimitry Andric     return nullptr;
567bdd1243dSDimitry Andric   }
568bdd1243dSDimitry Andric 
56906c3fb27SDimitry Andric   Value *R0 = Real->getOperand(0);
57006c3fb27SDimitry Andric   Value *R1 = Real->getOperand(1);
57106c3fb27SDimitry Andric   Value *I0 = Imag->getOperand(0);
57206c3fb27SDimitry Andric   Value *I1 = Imag->getOperand(1);
573bdd1243dSDimitry Andric 
574bdd1243dSDimitry Andric   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
575bdd1243dSDimitry Andric   // rotations and use the operand.
576bdd1243dSDimitry Andric   unsigned Negs = 0;
57706c3fb27SDimitry Andric   Value *Op;
57806c3fb27SDimitry Andric   if (match(R0, m_Neg(m_Value(Op)))) {
579bdd1243dSDimitry Andric     Negs |= 1;
58006c3fb27SDimitry Andric     R0 = Op;
58106c3fb27SDimitry Andric   } else if (match(R1, m_Neg(m_Value(Op)))) {
58206c3fb27SDimitry Andric     Negs |= 1;
58306c3fb27SDimitry Andric     R1 = Op;
584bdd1243dSDimitry Andric   }
58506c3fb27SDimitry Andric 
58606c3fb27SDimitry Andric   if (isNeg(I0)) {
587bdd1243dSDimitry Andric     Negs |= 2;
588bdd1243dSDimitry Andric     Negs ^= 1;
58906c3fb27SDimitry Andric     I0 = Op;
59006c3fb27SDimitry Andric   } else if (match(I1, m_Neg(m_Value(Op)))) {
59106c3fb27SDimitry Andric     Negs |= 2;
59206c3fb27SDimitry Andric     Negs ^= 1;
59306c3fb27SDimitry Andric     I1 = Op;
594bdd1243dSDimitry Andric   }
595bdd1243dSDimitry Andric 
596bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
597bdd1243dSDimitry Andric 
59806c3fb27SDimitry Andric   Value *CommonOperand;
59906c3fb27SDimitry Andric   Value *UncommonRealOp;
60006c3fb27SDimitry Andric   Value *UncommonImagOp;
601bdd1243dSDimitry Andric 
602bdd1243dSDimitry Andric   if (R0 == I0 || R0 == I1) {
603bdd1243dSDimitry Andric     CommonOperand = R0;
604bdd1243dSDimitry Andric     UncommonRealOp = R1;
605bdd1243dSDimitry Andric   } else if (R1 == I0 || R1 == I1) {
606bdd1243dSDimitry Andric     CommonOperand = R1;
607bdd1243dSDimitry Andric     UncommonRealOp = R0;
608bdd1243dSDimitry Andric   } else {
609bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
610bdd1243dSDimitry Andric     return nullptr;
611bdd1243dSDimitry Andric   }
612bdd1243dSDimitry Andric 
613bdd1243dSDimitry Andric   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
614bdd1243dSDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
615bdd1243dSDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_270)
616bdd1243dSDimitry Andric     std::swap(UncommonRealOp, UncommonImagOp);
617bdd1243dSDimitry Andric 
618bdd1243dSDimitry Andric   // Between identifyPartialMul and here we need to have found a complete valid
619bdd1243dSDimitry Andric   // pair from the CommonOperand of each part.
620bdd1243dSDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
621bdd1243dSDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_180)
622bdd1243dSDimitry Andric     PartialMatch.first = CommonOperand;
623bdd1243dSDimitry Andric   else
624bdd1243dSDimitry Andric     PartialMatch.second = CommonOperand;
625bdd1243dSDimitry Andric 
626bdd1243dSDimitry Andric   if (!PartialMatch.first || !PartialMatch.second) {
627bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
628bdd1243dSDimitry Andric     return nullptr;
629bdd1243dSDimitry Andric   }
630bdd1243dSDimitry Andric 
631bdd1243dSDimitry Andric   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
632bdd1243dSDimitry Andric   if (!CommonNode) {
633bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
634bdd1243dSDimitry Andric     return nullptr;
635bdd1243dSDimitry Andric   }
636bdd1243dSDimitry Andric 
637bdd1243dSDimitry Andric   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
638bdd1243dSDimitry Andric   if (!UncommonNode) {
639bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
640bdd1243dSDimitry Andric     return nullptr;
641bdd1243dSDimitry Andric   }
642bdd1243dSDimitry Andric 
643bdd1243dSDimitry Andric   NodePtr Node = prepareCompositeNode(
644bdd1243dSDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
645bdd1243dSDimitry Andric   Node->Rotation = Rotation;
646bdd1243dSDimitry Andric   Node->addOperand(CommonNode);
647bdd1243dSDimitry Andric   Node->addOperand(UncommonNode);
648bdd1243dSDimitry Andric   return submitCompositeNode(Node);
649bdd1243dSDimitry Andric }
650bdd1243dSDimitry Andric 
651bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
652bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
653bdd1243dSDimitry Andric                                                Instruction *Imag) {
654bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
655bdd1243dSDimitry Andric                     << "\n");
656bdd1243dSDimitry Andric   // Determine rotation
65706c3fb27SDimitry Andric   auto IsAdd = [](unsigned Op) {
65806c3fb27SDimitry Andric     return Op == Instruction::FAdd || Op == Instruction::Add;
65906c3fb27SDimitry Andric   };
66006c3fb27SDimitry Andric   auto IsSub = [](unsigned Op) {
66106c3fb27SDimitry Andric     return Op == Instruction::FSub || Op == Instruction::Sub;
66206c3fb27SDimitry Andric   };
663bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation;
66406c3fb27SDimitry Andric   if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
665bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_0;
66606c3fb27SDimitry Andric   else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
667bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_90;
66806c3fb27SDimitry Andric   else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
669bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_180;
67006c3fb27SDimitry Andric   else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
671bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_270;
672bdd1243dSDimitry Andric   else {
673bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
674bdd1243dSDimitry Andric     return nullptr;
675bdd1243dSDimitry Andric   }
676bdd1243dSDimitry Andric 
67706c3fb27SDimitry Andric   if (isa<FPMathOperator>(Real) &&
67806c3fb27SDimitry Andric       (!Real->getFastMathFlags().allowContract() ||
67906c3fb27SDimitry Andric        !Imag->getFastMathFlags().allowContract())) {
680bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
681bdd1243dSDimitry Andric     return nullptr;
682bdd1243dSDimitry Andric   }
683bdd1243dSDimitry Andric 
684bdd1243dSDimitry Andric   Value *CR = Real->getOperand(0);
685bdd1243dSDimitry Andric   Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
686bdd1243dSDimitry Andric   if (!RealMulI)
687bdd1243dSDimitry Andric     return nullptr;
688bdd1243dSDimitry Andric   Value *CI = Imag->getOperand(0);
689bdd1243dSDimitry Andric   Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
690bdd1243dSDimitry Andric   if (!ImagMulI)
691bdd1243dSDimitry Andric     return nullptr;
692bdd1243dSDimitry Andric 
693bdd1243dSDimitry Andric   if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
694bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
695bdd1243dSDimitry Andric     return nullptr;
696bdd1243dSDimitry Andric   }
697bdd1243dSDimitry Andric 
69806c3fb27SDimitry Andric   Value *R0 = RealMulI->getOperand(0);
69906c3fb27SDimitry Andric   Value *R1 = RealMulI->getOperand(1);
70006c3fb27SDimitry Andric   Value *I0 = ImagMulI->getOperand(0);
70106c3fb27SDimitry Andric   Value *I1 = ImagMulI->getOperand(1);
702bdd1243dSDimitry Andric 
70306c3fb27SDimitry Andric   Value *CommonOperand;
70406c3fb27SDimitry Andric   Value *UncommonRealOp;
70506c3fb27SDimitry Andric   Value *UncommonImagOp;
706bdd1243dSDimitry Andric 
707bdd1243dSDimitry Andric   if (R0 == I0 || R0 == I1) {
708bdd1243dSDimitry Andric     CommonOperand = R0;
709bdd1243dSDimitry Andric     UncommonRealOp = R1;
710bdd1243dSDimitry Andric   } else if (R1 == I0 || R1 == I1) {
711bdd1243dSDimitry Andric     CommonOperand = R1;
712bdd1243dSDimitry Andric     UncommonRealOp = R0;
713bdd1243dSDimitry Andric   } else {
714bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
715bdd1243dSDimitry Andric     return nullptr;
716bdd1243dSDimitry Andric   }
717bdd1243dSDimitry Andric 
718bdd1243dSDimitry Andric   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
719bdd1243dSDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
720bdd1243dSDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_270)
721bdd1243dSDimitry Andric     std::swap(UncommonRealOp, UncommonImagOp);
722bdd1243dSDimitry Andric 
72306c3fb27SDimitry Andric   std::pair<Value *, Value *> PartialMatch(
724bdd1243dSDimitry Andric       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
725bdd1243dSDimitry Andric        Rotation == ComplexDeinterleavingRotation::Rotation_180)
726bdd1243dSDimitry Andric           ? CommonOperand
727bdd1243dSDimitry Andric           : nullptr,
728bdd1243dSDimitry Andric       (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
729bdd1243dSDimitry Andric        Rotation == ComplexDeinterleavingRotation::Rotation_270)
730bdd1243dSDimitry Andric           ? CommonOperand
731bdd1243dSDimitry Andric           : nullptr);
73206c3fb27SDimitry Andric 
73306c3fb27SDimitry Andric   auto *CRInst = dyn_cast<Instruction>(CR);
73406c3fb27SDimitry Andric   auto *CIInst = dyn_cast<Instruction>(CI);
73506c3fb27SDimitry Andric 
73606c3fb27SDimitry Andric   if (!CRInst || !CIInst) {
73706c3fb27SDimitry Andric     LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
73806c3fb27SDimitry Andric     return nullptr;
73906c3fb27SDimitry Andric   }
74006c3fb27SDimitry Andric 
74106c3fb27SDimitry Andric   NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
742bdd1243dSDimitry Andric   if (!CNode) {
743bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
744bdd1243dSDimitry Andric     return nullptr;
745bdd1243dSDimitry Andric   }
746bdd1243dSDimitry Andric 
747bdd1243dSDimitry Andric   NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
748bdd1243dSDimitry Andric   if (!UncommonRes) {
749bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
750bdd1243dSDimitry Andric     return nullptr;
751bdd1243dSDimitry Andric   }
752bdd1243dSDimitry Andric 
753bdd1243dSDimitry Andric   assert(PartialMatch.first && PartialMatch.second);
754bdd1243dSDimitry Andric   NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
755bdd1243dSDimitry Andric   if (!CommonRes) {
756bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
757bdd1243dSDimitry Andric     return nullptr;
758bdd1243dSDimitry Andric   }
759bdd1243dSDimitry Andric 
760bdd1243dSDimitry Andric   NodePtr Node = prepareCompositeNode(
761bdd1243dSDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
762bdd1243dSDimitry Andric   Node->Rotation = Rotation;
763bdd1243dSDimitry Andric   Node->addOperand(CommonRes);
764bdd1243dSDimitry Andric   Node->addOperand(UncommonRes);
765bdd1243dSDimitry Andric   Node->addOperand(CNode);
766bdd1243dSDimitry Andric   return submitCompositeNode(Node);
767bdd1243dSDimitry Andric }
768bdd1243dSDimitry Andric 
769bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
770bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
771bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
772bdd1243dSDimitry Andric 
773bdd1243dSDimitry Andric   // Determine rotation
774bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation;
775bdd1243dSDimitry Andric   if ((Real->getOpcode() == Instruction::FSub &&
776bdd1243dSDimitry Andric        Imag->getOpcode() == Instruction::FAdd) ||
777bdd1243dSDimitry Andric       (Real->getOpcode() == Instruction::Sub &&
778bdd1243dSDimitry Andric        Imag->getOpcode() == Instruction::Add))
779bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_90;
780bdd1243dSDimitry Andric   else if ((Real->getOpcode() == Instruction::FAdd &&
781bdd1243dSDimitry Andric             Imag->getOpcode() == Instruction::FSub) ||
782bdd1243dSDimitry Andric            (Real->getOpcode() == Instruction::Add &&
783bdd1243dSDimitry Andric             Imag->getOpcode() == Instruction::Sub))
784bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_270;
785bdd1243dSDimitry Andric   else {
786bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
787bdd1243dSDimitry Andric     return nullptr;
788bdd1243dSDimitry Andric   }
789bdd1243dSDimitry Andric 
790bdd1243dSDimitry Andric   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
791bdd1243dSDimitry Andric   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
792bdd1243dSDimitry Andric   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
793bdd1243dSDimitry Andric   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
794bdd1243dSDimitry Andric 
795bdd1243dSDimitry Andric   if (!AR || !AI || !BR || !BI) {
796bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
797bdd1243dSDimitry Andric     return nullptr;
798bdd1243dSDimitry Andric   }
799bdd1243dSDimitry Andric 
800bdd1243dSDimitry Andric   NodePtr ResA = identifyNode(AR, AI);
801bdd1243dSDimitry Andric   if (!ResA) {
802bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
803bdd1243dSDimitry Andric     return nullptr;
804bdd1243dSDimitry Andric   }
805bdd1243dSDimitry Andric   NodePtr ResB = identifyNode(BR, BI);
806bdd1243dSDimitry Andric   if (!ResB) {
807bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
808bdd1243dSDimitry Andric     return nullptr;
809bdd1243dSDimitry Andric   }
810bdd1243dSDimitry Andric 
811bdd1243dSDimitry Andric   NodePtr Node =
812bdd1243dSDimitry Andric       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
813bdd1243dSDimitry Andric   Node->Rotation = Rotation;
814bdd1243dSDimitry Andric   Node->addOperand(ResA);
815bdd1243dSDimitry Andric   Node->addOperand(ResB);
816bdd1243dSDimitry Andric   return submitCompositeNode(Node);
817bdd1243dSDimitry Andric }
818bdd1243dSDimitry Andric 
819bdd1243dSDimitry Andric static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
820bdd1243dSDimitry Andric   unsigned OpcA = A->getOpcode();
821bdd1243dSDimitry Andric   unsigned OpcB = B->getOpcode();
822bdd1243dSDimitry Andric 
823bdd1243dSDimitry Andric   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
824bdd1243dSDimitry Andric          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
825bdd1243dSDimitry Andric          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
826bdd1243dSDimitry Andric          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
827bdd1243dSDimitry Andric }
828bdd1243dSDimitry Andric 
829bdd1243dSDimitry Andric static bool isInstructionPairMul(Instruction *A, Instruction *B) {
830bdd1243dSDimitry Andric   auto Pattern =
831bdd1243dSDimitry Andric       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
832bdd1243dSDimitry Andric 
833bdd1243dSDimitry Andric   return match(A, Pattern) && match(B, Pattern);
834bdd1243dSDimitry Andric }
835bdd1243dSDimitry Andric 
83606c3fb27SDimitry Andric static bool isInstructionPotentiallySymmetric(Instruction *I) {
83706c3fb27SDimitry Andric   switch (I->getOpcode()) {
83806c3fb27SDimitry Andric   case Instruction::FAdd:
83906c3fb27SDimitry Andric   case Instruction::FSub:
84006c3fb27SDimitry Andric   case Instruction::FMul:
84106c3fb27SDimitry Andric   case Instruction::FNeg:
84206c3fb27SDimitry Andric   case Instruction::Add:
84306c3fb27SDimitry Andric   case Instruction::Sub:
84406c3fb27SDimitry Andric   case Instruction::Mul:
84506c3fb27SDimitry Andric     return true;
84606c3fb27SDimitry Andric   default:
84706c3fb27SDimitry Andric     return false;
84806c3fb27SDimitry Andric   }
84906c3fb27SDimitry Andric }
85006c3fb27SDimitry Andric 
851bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
85206c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
85306c3fb27SDimitry Andric                                                        Instruction *Imag) {
85406c3fb27SDimitry Andric   if (Real->getOpcode() != Imag->getOpcode())
85506c3fb27SDimitry Andric     return nullptr;
85606c3fb27SDimitry Andric 
85706c3fb27SDimitry Andric   if (!isInstructionPotentiallySymmetric(Real) ||
85806c3fb27SDimitry Andric       !isInstructionPotentiallySymmetric(Imag))
85906c3fb27SDimitry Andric     return nullptr;
86006c3fb27SDimitry Andric 
86106c3fb27SDimitry Andric   auto *R0 = Real->getOperand(0);
86206c3fb27SDimitry Andric   auto *I0 = Imag->getOperand(0);
86306c3fb27SDimitry Andric 
86406c3fb27SDimitry Andric   NodePtr Op0 = identifyNode(R0, I0);
86506c3fb27SDimitry Andric   NodePtr Op1 = nullptr;
86606c3fb27SDimitry Andric   if (Op0 == nullptr)
86706c3fb27SDimitry Andric     return nullptr;
86806c3fb27SDimitry Andric 
86906c3fb27SDimitry Andric   if (Real->isBinaryOp()) {
87006c3fb27SDimitry Andric     auto *R1 = Real->getOperand(1);
87106c3fb27SDimitry Andric     auto *I1 = Imag->getOperand(1);
87206c3fb27SDimitry Andric     Op1 = identifyNode(R1, I1);
87306c3fb27SDimitry Andric     if (Op1 == nullptr)
87406c3fb27SDimitry Andric       return nullptr;
87506c3fb27SDimitry Andric   }
87606c3fb27SDimitry Andric 
87706c3fb27SDimitry Andric   if (isa<FPMathOperator>(Real) &&
87806c3fb27SDimitry Andric       Real->getFastMathFlags() != Imag->getFastMathFlags())
87906c3fb27SDimitry Andric     return nullptr;
88006c3fb27SDimitry Andric 
88106c3fb27SDimitry Andric   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
88206c3fb27SDimitry Andric                                    Real, Imag);
88306c3fb27SDimitry Andric   Node->Opcode = Real->getOpcode();
88406c3fb27SDimitry Andric   if (isa<FPMathOperator>(Real))
88506c3fb27SDimitry Andric     Node->Flags = Real->getFastMathFlags();
88606c3fb27SDimitry Andric 
88706c3fb27SDimitry Andric   Node->addOperand(Op0);
88806c3fb27SDimitry Andric   if (Real->isBinaryOp())
88906c3fb27SDimitry Andric     Node->addOperand(Op1);
89006c3fb27SDimitry Andric 
89106c3fb27SDimitry Andric   return submitCompositeNode(Node);
89206c3fb27SDimitry Andric }
89306c3fb27SDimitry Andric 
89406c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
89506c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
89606c3fb27SDimitry Andric   LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
89706c3fb27SDimitry Andric   assert(R->getType() == I->getType() &&
89806c3fb27SDimitry Andric          "Real and imaginary parts should not have different types");
8998a4dda33SDimitry Andric 
9008a4dda33SDimitry Andric   auto It = CachedResult.find({R, I});
9018a4dda33SDimitry Andric   if (It != CachedResult.end()) {
902bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
9038a4dda33SDimitry Andric     return It->second;
904bdd1243dSDimitry Andric   }
905bdd1243dSDimitry Andric 
90606c3fb27SDimitry Andric   if (NodePtr CN = identifySplat(R, I))
90706c3fb27SDimitry Andric     return CN;
90806c3fb27SDimitry Andric 
90906c3fb27SDimitry Andric   auto *Real = dyn_cast<Instruction>(R);
91006c3fb27SDimitry Andric   auto *Imag = dyn_cast<Instruction>(I);
91106c3fb27SDimitry Andric   if (!Real || !Imag)
91206c3fb27SDimitry Andric     return nullptr;
91306c3fb27SDimitry Andric 
91406c3fb27SDimitry Andric   if (NodePtr CN = identifyDeinterleave(Real, Imag))
91506c3fb27SDimitry Andric     return CN;
91606c3fb27SDimitry Andric 
91706c3fb27SDimitry Andric   if (NodePtr CN = identifyPHINode(Real, Imag))
91806c3fb27SDimitry Andric     return CN;
91906c3fb27SDimitry Andric 
92006c3fb27SDimitry Andric   if (NodePtr CN = identifySelectNode(Real, Imag))
92106c3fb27SDimitry Andric     return CN;
92206c3fb27SDimitry Andric 
92306c3fb27SDimitry Andric   auto *VTy = cast<VectorType>(Real->getType());
92406c3fb27SDimitry Andric   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
92506c3fb27SDimitry Andric 
92606c3fb27SDimitry Andric   bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
92706c3fb27SDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, NewVTy);
92806c3fb27SDimitry Andric   bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
92906c3fb27SDimitry Andric       ComplexDeinterleavingOperation::CAdd, NewVTy);
93006c3fb27SDimitry Andric 
93106c3fb27SDimitry Andric   if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
93206c3fb27SDimitry Andric     if (NodePtr CN = identifyPartialMul(Real, Imag))
93306c3fb27SDimitry Andric       return CN;
93406c3fb27SDimitry Andric   }
93506c3fb27SDimitry Andric 
93606c3fb27SDimitry Andric   if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
93706c3fb27SDimitry Andric     if (NodePtr CN = identifyAdd(Real, Imag))
93806c3fb27SDimitry Andric       return CN;
93906c3fb27SDimitry Andric   }
94006c3fb27SDimitry Andric 
94106c3fb27SDimitry Andric   if (HasCMulSupport && HasCAddSupport) {
94206c3fb27SDimitry Andric     if (NodePtr CN = identifyReassocNodes(Real, Imag))
94306c3fb27SDimitry Andric       return CN;
94406c3fb27SDimitry Andric   }
94506c3fb27SDimitry Andric 
94606c3fb27SDimitry Andric   if (NodePtr CN = identifySymmetricOperation(Real, Imag))
94706c3fb27SDimitry Andric     return CN;
94806c3fb27SDimitry Andric 
94906c3fb27SDimitry Andric   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
9508a4dda33SDimitry Andric   CachedResult[{R, I}] = nullptr;
95106c3fb27SDimitry Andric   return nullptr;
95206c3fb27SDimitry Andric }
95306c3fb27SDimitry Andric 
95406c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
95506c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
95606c3fb27SDimitry Andric                                                  Instruction *Imag) {
95706c3fb27SDimitry Andric   auto IsOperationSupported = [](unsigned Opcode) -> bool {
95806c3fb27SDimitry Andric     return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
95906c3fb27SDimitry Andric            Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
96006c3fb27SDimitry Andric            Opcode == Instruction::Sub;
96106c3fb27SDimitry Andric   };
96206c3fb27SDimitry Andric 
96306c3fb27SDimitry Andric   if (!IsOperationSupported(Real->getOpcode()) ||
96406c3fb27SDimitry Andric       !IsOperationSupported(Imag->getOpcode()))
96506c3fb27SDimitry Andric     return nullptr;
96606c3fb27SDimitry Andric 
96706c3fb27SDimitry Andric   std::optional<FastMathFlags> Flags;
96806c3fb27SDimitry Andric   if (isa<FPMathOperator>(Real)) {
96906c3fb27SDimitry Andric     if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
97006c3fb27SDimitry Andric       LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
97106c3fb27SDimitry Andric                            "not identical\n");
97206c3fb27SDimitry Andric       return nullptr;
97306c3fb27SDimitry Andric     }
97406c3fb27SDimitry Andric 
97506c3fb27SDimitry Andric     Flags = Real->getFastMathFlags();
97606c3fb27SDimitry Andric     if (!Flags->allowReassoc()) {
97706c3fb27SDimitry Andric       LLVM_DEBUG(
97806c3fb27SDimitry Andric           dbgs()
97906c3fb27SDimitry Andric           << "the 'Reassoc' attribute is missing in the FastMath flags\n");
98006c3fb27SDimitry Andric       return nullptr;
98106c3fb27SDimitry Andric     }
98206c3fb27SDimitry Andric   }
98306c3fb27SDimitry Andric 
98406c3fb27SDimitry Andric   // Collect multiplications and addend instructions from the given instruction
98506c3fb27SDimitry Andric   // while traversing it operands. Additionally, verify that all instructions
98606c3fb27SDimitry Andric   // have the same fast math flags.
98706c3fb27SDimitry Andric   auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
98806c3fb27SDimitry Andric                           std::list<Addend> &Addends) -> bool {
98906c3fb27SDimitry Andric     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
99006c3fb27SDimitry Andric     SmallPtrSet<Value *, 8> Visited;
99106c3fb27SDimitry Andric     while (!Worklist.empty()) {
99206c3fb27SDimitry Andric       auto [V, IsPositive] = Worklist.back();
99306c3fb27SDimitry Andric       Worklist.pop_back();
99406c3fb27SDimitry Andric       if (!Visited.insert(V).second)
99506c3fb27SDimitry Andric         continue;
99606c3fb27SDimitry Andric 
99706c3fb27SDimitry Andric       Instruction *I = dyn_cast<Instruction>(V);
99806c3fb27SDimitry Andric       if (!I) {
99906c3fb27SDimitry Andric         Addends.emplace_back(V, IsPositive);
100006c3fb27SDimitry Andric         continue;
100106c3fb27SDimitry Andric       }
100206c3fb27SDimitry Andric 
100306c3fb27SDimitry Andric       // If an instruction has more than one user, it indicates that it either
100406c3fb27SDimitry Andric       // has an external user, which will be later checked by the checkNodes
100506c3fb27SDimitry Andric       // function, or it is a subexpression utilized by multiple expressions. In
100606c3fb27SDimitry Andric       // the latter case, we will attempt to separately identify the complex
100706c3fb27SDimitry Andric       // operation from here in order to create a shared
100806c3fb27SDimitry Andric       // ComplexDeinterleavingCompositeNode.
100906c3fb27SDimitry Andric       if (I != Insn && I->getNumUses() > 1) {
101006c3fb27SDimitry Andric         LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
101106c3fb27SDimitry Andric         Addends.emplace_back(I, IsPositive);
101206c3fb27SDimitry Andric         continue;
101306c3fb27SDimitry Andric       }
101406c3fb27SDimitry Andric       switch (I->getOpcode()) {
101506c3fb27SDimitry Andric       case Instruction::FAdd:
101606c3fb27SDimitry Andric       case Instruction::Add:
101706c3fb27SDimitry Andric         Worklist.emplace_back(I->getOperand(1), IsPositive);
101806c3fb27SDimitry Andric         Worklist.emplace_back(I->getOperand(0), IsPositive);
101906c3fb27SDimitry Andric         break;
102006c3fb27SDimitry Andric       case Instruction::FSub:
102106c3fb27SDimitry Andric         Worklist.emplace_back(I->getOperand(1), !IsPositive);
102206c3fb27SDimitry Andric         Worklist.emplace_back(I->getOperand(0), IsPositive);
102306c3fb27SDimitry Andric         break;
102406c3fb27SDimitry Andric       case Instruction::Sub:
102506c3fb27SDimitry Andric         if (isNeg(I)) {
102606c3fb27SDimitry Andric           Worklist.emplace_back(getNegOperand(I), !IsPositive);
102706c3fb27SDimitry Andric         } else {
102806c3fb27SDimitry Andric           Worklist.emplace_back(I->getOperand(1), !IsPositive);
102906c3fb27SDimitry Andric           Worklist.emplace_back(I->getOperand(0), IsPositive);
103006c3fb27SDimitry Andric         }
103106c3fb27SDimitry Andric         break;
103206c3fb27SDimitry Andric       case Instruction::FMul:
103306c3fb27SDimitry Andric       case Instruction::Mul: {
103406c3fb27SDimitry Andric         Value *A, *B;
103506c3fb27SDimitry Andric         if (isNeg(I->getOperand(0))) {
103606c3fb27SDimitry Andric           A = getNegOperand(I->getOperand(0));
103706c3fb27SDimitry Andric           IsPositive = !IsPositive;
103806c3fb27SDimitry Andric         } else {
103906c3fb27SDimitry Andric           A = I->getOperand(0);
104006c3fb27SDimitry Andric         }
104106c3fb27SDimitry Andric 
104206c3fb27SDimitry Andric         if (isNeg(I->getOperand(1))) {
104306c3fb27SDimitry Andric           B = getNegOperand(I->getOperand(1));
104406c3fb27SDimitry Andric           IsPositive = !IsPositive;
104506c3fb27SDimitry Andric         } else {
104606c3fb27SDimitry Andric           B = I->getOperand(1);
104706c3fb27SDimitry Andric         }
104806c3fb27SDimitry Andric         Muls.push_back(Product{A, B, IsPositive});
104906c3fb27SDimitry Andric         break;
105006c3fb27SDimitry Andric       }
105106c3fb27SDimitry Andric       case Instruction::FNeg:
105206c3fb27SDimitry Andric         Worklist.emplace_back(I->getOperand(0), !IsPositive);
105306c3fb27SDimitry Andric         break;
105406c3fb27SDimitry Andric       default:
105506c3fb27SDimitry Andric         Addends.emplace_back(I, IsPositive);
105606c3fb27SDimitry Andric         continue;
105706c3fb27SDimitry Andric       }
105806c3fb27SDimitry Andric 
105906c3fb27SDimitry Andric       if (Flags && I->getFastMathFlags() != *Flags) {
106006c3fb27SDimitry Andric         LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
106106c3fb27SDimitry Andric                              "inconsistent with the root instructions' flags: "
106206c3fb27SDimitry Andric                           << *I << "\n");
106306c3fb27SDimitry Andric         return false;
106406c3fb27SDimitry Andric       }
106506c3fb27SDimitry Andric     }
106606c3fb27SDimitry Andric     return true;
106706c3fb27SDimitry Andric   };
106806c3fb27SDimitry Andric 
106906c3fb27SDimitry Andric   std::vector<Product> RealMuls, ImagMuls;
107006c3fb27SDimitry Andric   std::list<Addend> RealAddends, ImagAddends;
107106c3fb27SDimitry Andric   if (!Collect(Real, RealMuls, RealAddends) ||
107206c3fb27SDimitry Andric       !Collect(Imag, ImagMuls, ImagAddends))
107306c3fb27SDimitry Andric     return nullptr;
107406c3fb27SDimitry Andric 
107506c3fb27SDimitry Andric   if (RealAddends.size() != ImagAddends.size())
107606c3fb27SDimitry Andric     return nullptr;
107706c3fb27SDimitry Andric 
107806c3fb27SDimitry Andric   NodePtr FinalNode;
107906c3fb27SDimitry Andric   if (!RealMuls.empty() || !ImagMuls.empty()) {
108006c3fb27SDimitry Andric     // If there are multiplicands, extract positive addend and use it as an
108106c3fb27SDimitry Andric     // accumulator
108206c3fb27SDimitry Andric     FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
108306c3fb27SDimitry Andric     FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
108406c3fb27SDimitry Andric     if (!FinalNode)
108506c3fb27SDimitry Andric       return nullptr;
108606c3fb27SDimitry Andric   }
108706c3fb27SDimitry Andric 
108806c3fb27SDimitry Andric   // Identify and process remaining additions
108906c3fb27SDimitry Andric   if (!RealAddends.empty() || !ImagAddends.empty()) {
109006c3fb27SDimitry Andric     FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
109106c3fb27SDimitry Andric     if (!FinalNode)
109206c3fb27SDimitry Andric       return nullptr;
109306c3fb27SDimitry Andric   }
109406c3fb27SDimitry Andric   assert(FinalNode && "FinalNode can not be nullptr here");
109506c3fb27SDimitry Andric   // Set the Real and Imag fields of the final node and submit it
109606c3fb27SDimitry Andric   FinalNode->Real = Real;
109706c3fb27SDimitry Andric   FinalNode->Imag = Imag;
109806c3fb27SDimitry Andric   submitCompositeNode(FinalNode);
109906c3fb27SDimitry Andric   return FinalNode;
110006c3fb27SDimitry Andric }
110106c3fb27SDimitry Andric 
110206c3fb27SDimitry Andric bool ComplexDeinterleavingGraph::collectPartialMuls(
110306c3fb27SDimitry Andric     const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
110406c3fb27SDimitry Andric     std::vector<PartialMulCandidate> &PartialMulCandidates) {
110506c3fb27SDimitry Andric   // Helper function to extract a common operand from two products
110606c3fb27SDimitry Andric   auto FindCommonInstruction = [](const Product &Real,
110706c3fb27SDimitry Andric                                   const Product &Imag) -> Value * {
110806c3fb27SDimitry Andric     if (Real.Multiplicand == Imag.Multiplicand ||
110906c3fb27SDimitry Andric         Real.Multiplicand == Imag.Multiplier)
111006c3fb27SDimitry Andric       return Real.Multiplicand;
111106c3fb27SDimitry Andric 
111206c3fb27SDimitry Andric     if (Real.Multiplier == Imag.Multiplicand ||
111306c3fb27SDimitry Andric         Real.Multiplier == Imag.Multiplier)
111406c3fb27SDimitry Andric       return Real.Multiplier;
111506c3fb27SDimitry Andric 
111606c3fb27SDimitry Andric     return nullptr;
111706c3fb27SDimitry Andric   };
111806c3fb27SDimitry Andric 
111906c3fb27SDimitry Andric   // Iterating over real and imaginary multiplications to find common operands
112006c3fb27SDimitry Andric   // If a common operand is found, a partial multiplication candidate is created
112106c3fb27SDimitry Andric   // and added to the candidates vector The function returns false if no common
112206c3fb27SDimitry Andric   // operands are found for any product
112306c3fb27SDimitry Andric   for (unsigned i = 0; i < RealMuls.size(); ++i) {
112406c3fb27SDimitry Andric     bool FoundCommon = false;
112506c3fb27SDimitry Andric     for (unsigned j = 0; j < ImagMuls.size(); ++j) {
112606c3fb27SDimitry Andric       auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
112706c3fb27SDimitry Andric       if (!Common)
112806c3fb27SDimitry Andric         continue;
112906c3fb27SDimitry Andric 
113006c3fb27SDimitry Andric       auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
113106c3fb27SDimitry Andric                                                    : RealMuls[i].Multiplicand;
113206c3fb27SDimitry Andric       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
113306c3fb27SDimitry Andric                                                    : ImagMuls[j].Multiplicand;
113406c3fb27SDimitry Andric 
113506c3fb27SDimitry Andric       auto Node = identifyNode(A, B);
113606c3fb27SDimitry Andric       if (Node) {
113706c3fb27SDimitry Andric         FoundCommon = true;
113806c3fb27SDimitry Andric         PartialMulCandidates.push_back({Common, Node, i, j, false});
113906c3fb27SDimitry Andric       }
114006c3fb27SDimitry Andric 
114106c3fb27SDimitry Andric       Node = identifyNode(B, A);
114206c3fb27SDimitry Andric       if (Node) {
114306c3fb27SDimitry Andric         FoundCommon = true;
114406c3fb27SDimitry Andric         PartialMulCandidates.push_back({Common, Node, i, j, true});
114506c3fb27SDimitry Andric       }
114606c3fb27SDimitry Andric     }
114706c3fb27SDimitry Andric     if (!FoundCommon)
114806c3fb27SDimitry Andric       return false;
114906c3fb27SDimitry Andric   }
115006c3fb27SDimitry Andric   return true;
115106c3fb27SDimitry Andric }
115206c3fb27SDimitry Andric 
115306c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
115406c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyMultiplications(
115506c3fb27SDimitry Andric     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
115606c3fb27SDimitry Andric     NodePtr Accumulator = nullptr) {
115706c3fb27SDimitry Andric   if (RealMuls.size() != ImagMuls.size())
115806c3fb27SDimitry Andric     return nullptr;
115906c3fb27SDimitry Andric 
116006c3fb27SDimitry Andric   std::vector<PartialMulCandidate> Info;
116106c3fb27SDimitry Andric   if (!collectPartialMuls(RealMuls, ImagMuls, Info))
116206c3fb27SDimitry Andric     return nullptr;
116306c3fb27SDimitry Andric 
116406c3fb27SDimitry Andric   // Map to store common instruction to node pointers
116506c3fb27SDimitry Andric   std::map<Value *, NodePtr> CommonToNode;
116606c3fb27SDimitry Andric   std::vector<bool> Processed(Info.size(), false);
116706c3fb27SDimitry Andric   for (unsigned I = 0; I < Info.size(); ++I) {
116806c3fb27SDimitry Andric     if (Processed[I])
116906c3fb27SDimitry Andric       continue;
117006c3fb27SDimitry Andric 
117106c3fb27SDimitry Andric     PartialMulCandidate &InfoA = Info[I];
117206c3fb27SDimitry Andric     for (unsigned J = I + 1; J < Info.size(); ++J) {
117306c3fb27SDimitry Andric       if (Processed[J])
117406c3fb27SDimitry Andric         continue;
117506c3fb27SDimitry Andric 
117606c3fb27SDimitry Andric       PartialMulCandidate &InfoB = Info[J];
117706c3fb27SDimitry Andric       auto *InfoReal = &InfoA;
117806c3fb27SDimitry Andric       auto *InfoImag = &InfoB;
117906c3fb27SDimitry Andric 
118006c3fb27SDimitry Andric       auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
118106c3fb27SDimitry Andric       if (!NodeFromCommon) {
118206c3fb27SDimitry Andric         std::swap(InfoReal, InfoImag);
118306c3fb27SDimitry Andric         NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
118406c3fb27SDimitry Andric       }
118506c3fb27SDimitry Andric       if (!NodeFromCommon)
118606c3fb27SDimitry Andric         continue;
118706c3fb27SDimitry Andric 
118806c3fb27SDimitry Andric       CommonToNode[InfoReal->Common] = NodeFromCommon;
118906c3fb27SDimitry Andric       CommonToNode[InfoImag->Common] = NodeFromCommon;
119006c3fb27SDimitry Andric       Processed[I] = true;
119106c3fb27SDimitry Andric       Processed[J] = true;
119206c3fb27SDimitry Andric     }
119306c3fb27SDimitry Andric   }
119406c3fb27SDimitry Andric 
119506c3fb27SDimitry Andric   std::vector<bool> ProcessedReal(RealMuls.size(), false);
119606c3fb27SDimitry Andric   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
119706c3fb27SDimitry Andric   NodePtr Result = Accumulator;
119806c3fb27SDimitry Andric   for (auto &PMI : Info) {
119906c3fb27SDimitry Andric     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
120006c3fb27SDimitry Andric       continue;
120106c3fb27SDimitry Andric 
120206c3fb27SDimitry Andric     auto It = CommonToNode.find(PMI.Common);
120306c3fb27SDimitry Andric     // TODO: Process independent complex multiplications. Cases like this:
120406c3fb27SDimitry Andric     //  A.real() * B where both A and B are complex numbers.
120506c3fb27SDimitry Andric     if (It == CommonToNode.end()) {
120606c3fb27SDimitry Andric       LLVM_DEBUG({
120706c3fb27SDimitry Andric         dbgs() << "Unprocessed independent partial multiplication:\n";
120806c3fb27SDimitry Andric         for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
120906c3fb27SDimitry Andric           dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
121006c3fb27SDimitry Andric                            << " multiplied by " << *Mul->Multiplicand << "\n";
121106c3fb27SDimitry Andric       });
121206c3fb27SDimitry Andric       return nullptr;
121306c3fb27SDimitry Andric     }
121406c3fb27SDimitry Andric 
121506c3fb27SDimitry Andric     auto &RealMul = RealMuls[PMI.RealIdx];
121606c3fb27SDimitry Andric     auto &ImagMul = ImagMuls[PMI.ImagIdx];
121706c3fb27SDimitry Andric 
121806c3fb27SDimitry Andric     auto NodeA = It->second;
121906c3fb27SDimitry Andric     auto NodeB = PMI.Node;
122006c3fb27SDimitry Andric     auto IsMultiplicandReal = PMI.Common == NodeA->Real;
122106c3fb27SDimitry Andric     // The following table illustrates the relationship between multiplications
122206c3fb27SDimitry Andric     // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
122306c3fb27SDimitry Andric     // can see:
122406c3fb27SDimitry Andric     //
122506c3fb27SDimitry Andric     // Rotation |   Real |   Imag |
122606c3fb27SDimitry Andric     // ---------+--------+--------+
122706c3fb27SDimitry Andric     //        0 |  x * u |  x * v |
122806c3fb27SDimitry Andric     //       90 | -y * v |  y * u |
122906c3fb27SDimitry Andric     //      180 | -x * u | -x * v |
123006c3fb27SDimitry Andric     //      270 |  y * v | -y * u |
123106c3fb27SDimitry Andric     //
123206c3fb27SDimitry Andric     // Check if the candidate can indeed be represented by partial
123306c3fb27SDimitry Andric     // multiplication
123406c3fb27SDimitry Andric     // TODO: Add support for multiplication by complex one
123506c3fb27SDimitry Andric     if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
123606c3fb27SDimitry Andric         (!IsMultiplicandReal && !PMI.IsNodeInverted))
123706c3fb27SDimitry Andric       continue;
123806c3fb27SDimitry Andric 
123906c3fb27SDimitry Andric     // Determine the rotation based on the multiplications
124006c3fb27SDimitry Andric     ComplexDeinterleavingRotation Rotation;
124106c3fb27SDimitry Andric     if (IsMultiplicandReal) {
124206c3fb27SDimitry Andric       // Detect 0 and 180 degrees rotation
124306c3fb27SDimitry Andric       if (RealMul.IsPositive && ImagMul.IsPositive)
124406c3fb27SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
124506c3fb27SDimitry Andric       else if (!RealMul.IsPositive && !ImagMul.IsPositive)
124606c3fb27SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
124706c3fb27SDimitry Andric       else
124806c3fb27SDimitry Andric         continue;
124906c3fb27SDimitry Andric 
125006c3fb27SDimitry Andric     } else {
125106c3fb27SDimitry Andric       // Detect 90 and 270 degrees rotation
125206c3fb27SDimitry Andric       if (!RealMul.IsPositive && ImagMul.IsPositive)
125306c3fb27SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
125406c3fb27SDimitry Andric       else if (RealMul.IsPositive && !ImagMul.IsPositive)
125506c3fb27SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
125606c3fb27SDimitry Andric       else
125706c3fb27SDimitry Andric         continue;
125806c3fb27SDimitry Andric     }
125906c3fb27SDimitry Andric 
126006c3fb27SDimitry Andric     LLVM_DEBUG({
126106c3fb27SDimitry Andric       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
126206c3fb27SDimitry Andric       dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
126306c3fb27SDimitry Andric       dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
126406c3fb27SDimitry Andric       dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
126506c3fb27SDimitry Andric       dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
126606c3fb27SDimitry Andric       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
126706c3fb27SDimitry Andric     });
126806c3fb27SDimitry Andric 
126906c3fb27SDimitry Andric     NodePtr NodeMul = prepareCompositeNode(
127006c3fb27SDimitry Andric         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
127106c3fb27SDimitry Andric     NodeMul->Rotation = Rotation;
127206c3fb27SDimitry Andric     NodeMul->addOperand(NodeA);
127306c3fb27SDimitry Andric     NodeMul->addOperand(NodeB);
127406c3fb27SDimitry Andric     if (Result)
127506c3fb27SDimitry Andric       NodeMul->addOperand(Result);
127606c3fb27SDimitry Andric     submitCompositeNode(NodeMul);
127706c3fb27SDimitry Andric     Result = NodeMul;
127806c3fb27SDimitry Andric     ProcessedReal[PMI.RealIdx] = true;
127906c3fb27SDimitry Andric     ProcessedImag[PMI.ImagIdx] = true;
128006c3fb27SDimitry Andric   }
128106c3fb27SDimitry Andric 
128206c3fb27SDimitry Andric   // Ensure all products have been processed, if not return nullptr.
128306c3fb27SDimitry Andric   if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
128406c3fb27SDimitry Andric       !all_of(ProcessedImag, [](bool V) { return V; })) {
128506c3fb27SDimitry Andric 
128606c3fb27SDimitry Andric     // Dump debug information about which partial multiplications are not
128706c3fb27SDimitry Andric     // processed.
128806c3fb27SDimitry Andric     LLVM_DEBUG({
128906c3fb27SDimitry Andric       dbgs() << "Unprocessed products (Real):\n";
129006c3fb27SDimitry Andric       for (size_t i = 0; i < ProcessedReal.size(); ++i) {
129106c3fb27SDimitry Andric         if (!ProcessedReal[i])
129206c3fb27SDimitry Andric           dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
129306c3fb27SDimitry Andric                            << *RealMuls[i].Multiplier << " multiplied by "
129406c3fb27SDimitry Andric                            << *RealMuls[i].Multiplicand << "\n";
129506c3fb27SDimitry Andric       }
129606c3fb27SDimitry Andric       dbgs() << "Unprocessed products (Imag):\n";
129706c3fb27SDimitry Andric       for (size_t i = 0; i < ProcessedImag.size(); ++i) {
129806c3fb27SDimitry Andric         if (!ProcessedImag[i])
129906c3fb27SDimitry Andric           dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
130006c3fb27SDimitry Andric                            << *ImagMuls[i].Multiplier << " multiplied by "
130106c3fb27SDimitry Andric                            << *ImagMuls[i].Multiplicand << "\n";
130206c3fb27SDimitry Andric       }
130306c3fb27SDimitry Andric     });
130406c3fb27SDimitry Andric     return nullptr;
130506c3fb27SDimitry Andric   }
130606c3fb27SDimitry Andric 
130706c3fb27SDimitry Andric   return Result;
130806c3fb27SDimitry Andric }
130906c3fb27SDimitry Andric 
131006c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
131106c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyAdditions(
131206c3fb27SDimitry Andric     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
131306c3fb27SDimitry Andric     std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
131406c3fb27SDimitry Andric   if (RealAddends.size() != ImagAddends.size())
131506c3fb27SDimitry Andric     return nullptr;
131606c3fb27SDimitry Andric 
131706c3fb27SDimitry Andric   NodePtr Result;
131806c3fb27SDimitry Andric   // If we have accumulator use it as first addend
131906c3fb27SDimitry Andric   if (Accumulator)
132006c3fb27SDimitry Andric     Result = Accumulator;
132106c3fb27SDimitry Andric   // Otherwise find an element with both positive real and imaginary parts.
132206c3fb27SDimitry Andric   else
132306c3fb27SDimitry Andric     Result = extractPositiveAddend(RealAddends, ImagAddends);
132406c3fb27SDimitry Andric 
132506c3fb27SDimitry Andric   if (!Result)
132606c3fb27SDimitry Andric     return nullptr;
132706c3fb27SDimitry Andric 
132806c3fb27SDimitry Andric   while (!RealAddends.empty()) {
132906c3fb27SDimitry Andric     auto ItR = RealAddends.begin();
133006c3fb27SDimitry Andric     auto [R, IsPositiveR] = *ItR;
133106c3fb27SDimitry Andric 
133206c3fb27SDimitry Andric     bool FoundImag = false;
133306c3fb27SDimitry Andric     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
133406c3fb27SDimitry Andric       auto [I, IsPositiveI] = *ItI;
133506c3fb27SDimitry Andric       ComplexDeinterleavingRotation Rotation;
133606c3fb27SDimitry Andric       if (IsPositiveR && IsPositiveI)
133706c3fb27SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_0;
133806c3fb27SDimitry Andric       else if (!IsPositiveR && IsPositiveI)
133906c3fb27SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_90;
134006c3fb27SDimitry Andric       else if (!IsPositiveR && !IsPositiveI)
134106c3fb27SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_180;
134206c3fb27SDimitry Andric       else
134306c3fb27SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_270;
134406c3fb27SDimitry Andric 
134506c3fb27SDimitry Andric       NodePtr AddNode;
134606c3fb27SDimitry Andric       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
134706c3fb27SDimitry Andric           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
134806c3fb27SDimitry Andric         AddNode = identifyNode(R, I);
134906c3fb27SDimitry Andric       } else {
135006c3fb27SDimitry Andric         AddNode = identifyNode(I, R);
135106c3fb27SDimitry Andric       }
135206c3fb27SDimitry Andric       if (AddNode) {
135306c3fb27SDimitry Andric         LLVM_DEBUG({
135406c3fb27SDimitry Andric           dbgs() << "Identified addition:\n";
135506c3fb27SDimitry Andric           dbgs().indent(4) << "X: " << *R << "\n";
135606c3fb27SDimitry Andric           dbgs().indent(4) << "Y: " << *I << "\n";
135706c3fb27SDimitry Andric           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
135806c3fb27SDimitry Andric         });
135906c3fb27SDimitry Andric 
136006c3fb27SDimitry Andric         NodePtr TmpNode;
136106c3fb27SDimitry Andric         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
136206c3fb27SDimitry Andric           TmpNode = prepareCompositeNode(
136306c3fb27SDimitry Andric               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
136406c3fb27SDimitry Andric           if (Flags) {
136506c3fb27SDimitry Andric             TmpNode->Opcode = Instruction::FAdd;
136606c3fb27SDimitry Andric             TmpNode->Flags = *Flags;
136706c3fb27SDimitry Andric           } else {
136806c3fb27SDimitry Andric             TmpNode->Opcode = Instruction::Add;
136906c3fb27SDimitry Andric           }
137006c3fb27SDimitry Andric         } else if (Rotation ==
137106c3fb27SDimitry Andric                    llvm::ComplexDeinterleavingRotation::Rotation_180) {
137206c3fb27SDimitry Andric           TmpNode = prepareCompositeNode(
137306c3fb27SDimitry Andric               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
137406c3fb27SDimitry Andric           if (Flags) {
137506c3fb27SDimitry Andric             TmpNode->Opcode = Instruction::FSub;
137606c3fb27SDimitry Andric             TmpNode->Flags = *Flags;
137706c3fb27SDimitry Andric           } else {
137806c3fb27SDimitry Andric             TmpNode->Opcode = Instruction::Sub;
137906c3fb27SDimitry Andric           }
138006c3fb27SDimitry Andric         } else {
138106c3fb27SDimitry Andric           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
138206c3fb27SDimitry Andric                                          nullptr, nullptr);
138306c3fb27SDimitry Andric           TmpNode->Rotation = Rotation;
138406c3fb27SDimitry Andric         }
138506c3fb27SDimitry Andric 
138606c3fb27SDimitry Andric         TmpNode->addOperand(Result);
138706c3fb27SDimitry Andric         TmpNode->addOperand(AddNode);
138806c3fb27SDimitry Andric         submitCompositeNode(TmpNode);
138906c3fb27SDimitry Andric         Result = TmpNode;
139006c3fb27SDimitry Andric         RealAddends.erase(ItR);
139106c3fb27SDimitry Andric         ImagAddends.erase(ItI);
139206c3fb27SDimitry Andric         FoundImag = true;
139306c3fb27SDimitry Andric         break;
139406c3fb27SDimitry Andric       }
139506c3fb27SDimitry Andric     }
139606c3fb27SDimitry Andric     if (!FoundImag)
139706c3fb27SDimitry Andric       return nullptr;
139806c3fb27SDimitry Andric   }
139906c3fb27SDimitry Andric   return Result;
140006c3fb27SDimitry Andric }
140106c3fb27SDimitry Andric 
140206c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
140306c3fb27SDimitry Andric ComplexDeinterleavingGraph::extractPositiveAddend(
140406c3fb27SDimitry Andric     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
140506c3fb27SDimitry Andric   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
140606c3fb27SDimitry Andric     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
140706c3fb27SDimitry Andric       auto [R, IsPositiveR] = *ItR;
140806c3fb27SDimitry Andric       auto [I, IsPositiveI] = *ItI;
140906c3fb27SDimitry Andric       if (IsPositiveR && IsPositiveI) {
141006c3fb27SDimitry Andric         auto Result = identifyNode(R, I);
141106c3fb27SDimitry Andric         if (Result) {
141206c3fb27SDimitry Andric           RealAddends.erase(ItR);
141306c3fb27SDimitry Andric           ImagAddends.erase(ItI);
141406c3fb27SDimitry Andric           return Result;
141506c3fb27SDimitry Andric         }
141606c3fb27SDimitry Andric       }
141706c3fb27SDimitry Andric     }
141806c3fb27SDimitry Andric   }
141906c3fb27SDimitry Andric   return nullptr;
142006c3fb27SDimitry Andric }
142106c3fb27SDimitry Andric 
142206c3fb27SDimitry Andric bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
142306c3fb27SDimitry Andric   // This potential root instruction might already have been recognized as
142406c3fb27SDimitry Andric   // reduction. Because RootToNode maps both Real and Imaginary parts to
142506c3fb27SDimitry Andric   // CompositeNode we should choose only one either Real or Imag instruction to
142606c3fb27SDimitry Andric   // use as an anchor for generating complex instruction.
142706c3fb27SDimitry Andric   auto It = RootToNode.find(RootI);
14288a4dda33SDimitry Andric   if (It != RootToNode.end()) {
14298a4dda33SDimitry Andric     auto RootNode = It->second;
14308a4dda33SDimitry Andric     assert(RootNode->Operation ==
14318a4dda33SDimitry Andric            ComplexDeinterleavingOperation::ReductionOperation);
14328a4dda33SDimitry Andric     // Find out which part, Real or Imag, comes later, and only if we come to
14338a4dda33SDimitry Andric     // the latest part, add it to OrderedRoots.
14348a4dda33SDimitry Andric     auto *R = cast<Instruction>(RootNode->Real);
14358a4dda33SDimitry Andric     auto *I = cast<Instruction>(RootNode->Imag);
14368a4dda33SDimitry Andric     auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
14378a4dda33SDimitry Andric     if (ReplacementAnchor != RootI)
14388a4dda33SDimitry Andric       return false;
143906c3fb27SDimitry Andric     OrderedRoots.push_back(RootI);
144006c3fb27SDimitry Andric     return true;
144106c3fb27SDimitry Andric   }
144206c3fb27SDimitry Andric 
144306c3fb27SDimitry Andric   auto RootNode = identifyRoot(RootI);
144406c3fb27SDimitry Andric   if (!RootNode)
144506c3fb27SDimitry Andric     return false;
144606c3fb27SDimitry Andric 
144706c3fb27SDimitry Andric   LLVM_DEBUG({
144806c3fb27SDimitry Andric     Function *F = RootI->getFunction();
144906c3fb27SDimitry Andric     BasicBlock *B = RootI->getParent();
145006c3fb27SDimitry Andric     dbgs() << "Complex deinterleaving graph for " << F->getName()
145106c3fb27SDimitry Andric            << "::" << B->getName() << ".\n";
145206c3fb27SDimitry Andric     dump(dbgs());
145306c3fb27SDimitry Andric     dbgs() << "\n";
145406c3fb27SDimitry Andric   });
145506c3fb27SDimitry Andric   RootToNode[RootI] = RootNode;
145606c3fb27SDimitry Andric   OrderedRoots.push_back(RootI);
145706c3fb27SDimitry Andric   return true;
145806c3fb27SDimitry Andric }
145906c3fb27SDimitry Andric 
146006c3fb27SDimitry Andric bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
146106c3fb27SDimitry Andric   bool FoundPotentialReduction = false;
146206c3fb27SDimitry Andric 
146306c3fb27SDimitry Andric   auto *Br = dyn_cast<BranchInst>(B->getTerminator());
146406c3fb27SDimitry Andric   if (!Br || Br->getNumSuccessors() != 2)
146506c3fb27SDimitry Andric     return false;
146606c3fb27SDimitry Andric 
146706c3fb27SDimitry Andric   // Identify simple one-block loop
146806c3fb27SDimitry Andric   if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
146906c3fb27SDimitry Andric     return false;
147006c3fb27SDimitry Andric 
147106c3fb27SDimitry Andric   SmallVector<PHINode *> PHIs;
147206c3fb27SDimitry Andric   for (auto &PHI : B->phis()) {
147306c3fb27SDimitry Andric     if (PHI.getNumIncomingValues() != 2)
147406c3fb27SDimitry Andric       continue;
147506c3fb27SDimitry Andric 
147606c3fb27SDimitry Andric     if (!PHI.getType()->isVectorTy())
147706c3fb27SDimitry Andric       continue;
147806c3fb27SDimitry Andric 
147906c3fb27SDimitry Andric     auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
148006c3fb27SDimitry Andric     if (!ReductionOp)
148106c3fb27SDimitry Andric       continue;
148206c3fb27SDimitry Andric 
148306c3fb27SDimitry Andric     // Check if final instruction is reduced outside of current block
148406c3fb27SDimitry Andric     Instruction *FinalReduction = nullptr;
148506c3fb27SDimitry Andric     auto NumUsers = 0u;
148606c3fb27SDimitry Andric     for (auto *U : ReductionOp->users()) {
148706c3fb27SDimitry Andric       ++NumUsers;
148806c3fb27SDimitry Andric       if (U == &PHI)
148906c3fb27SDimitry Andric         continue;
149006c3fb27SDimitry Andric       FinalReduction = dyn_cast<Instruction>(U);
149106c3fb27SDimitry Andric     }
149206c3fb27SDimitry Andric 
149306c3fb27SDimitry Andric     if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
149406c3fb27SDimitry Andric         isa<PHINode>(FinalReduction))
149506c3fb27SDimitry Andric       continue;
149606c3fb27SDimitry Andric 
149706c3fb27SDimitry Andric     ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
149806c3fb27SDimitry Andric     BackEdge = B;
149906c3fb27SDimitry Andric     auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
150006c3fb27SDimitry Andric     auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
150106c3fb27SDimitry Andric     Incoming = PHI.getIncomingBlock(IncomingIdx);
150206c3fb27SDimitry Andric     FoundPotentialReduction = true;
150306c3fb27SDimitry Andric 
150406c3fb27SDimitry Andric     // If the initial value of PHINode is an Instruction, consider it a leaf
150506c3fb27SDimitry Andric     // value of a complex deinterleaving graph.
150606c3fb27SDimitry Andric     if (auto *InitPHI =
150706c3fb27SDimitry Andric             dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
150806c3fb27SDimitry Andric       FinalInstructions.insert(InitPHI);
150906c3fb27SDimitry Andric   }
151006c3fb27SDimitry Andric   return FoundPotentialReduction;
151106c3fb27SDimitry Andric }
151206c3fb27SDimitry Andric 
151306c3fb27SDimitry Andric void ComplexDeinterleavingGraph::identifyReductionNodes() {
151406c3fb27SDimitry Andric   SmallVector<bool> Processed(ReductionInfo.size(), false);
151506c3fb27SDimitry Andric   SmallVector<Instruction *> OperationInstruction;
151606c3fb27SDimitry Andric   for (auto &P : ReductionInfo)
151706c3fb27SDimitry Andric     OperationInstruction.push_back(P.first);
151806c3fb27SDimitry Andric 
151906c3fb27SDimitry Andric   // Identify a complex computation by evaluating two reduction operations that
152006c3fb27SDimitry Andric   // potentially could be involved
152106c3fb27SDimitry Andric   for (size_t i = 0; i < OperationInstruction.size(); ++i) {
152206c3fb27SDimitry Andric     if (Processed[i])
152306c3fb27SDimitry Andric       continue;
152406c3fb27SDimitry Andric     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
152506c3fb27SDimitry Andric       if (Processed[j])
152606c3fb27SDimitry Andric         continue;
152706c3fb27SDimitry Andric 
152806c3fb27SDimitry Andric       auto *Real = OperationInstruction[i];
152906c3fb27SDimitry Andric       auto *Imag = OperationInstruction[j];
153006c3fb27SDimitry Andric       if (Real->getType() != Imag->getType())
153106c3fb27SDimitry Andric         continue;
153206c3fb27SDimitry Andric 
153306c3fb27SDimitry Andric       RealPHI = ReductionInfo[Real].first;
153406c3fb27SDimitry Andric       ImagPHI = ReductionInfo[Imag].first;
153506c3fb27SDimitry Andric       PHIsFound = false;
153606c3fb27SDimitry Andric       auto Node = identifyNode(Real, Imag);
153706c3fb27SDimitry Andric       if (!Node) {
153806c3fb27SDimitry Andric         std::swap(Real, Imag);
153906c3fb27SDimitry Andric         std::swap(RealPHI, ImagPHI);
154006c3fb27SDimitry Andric         Node = identifyNode(Real, Imag);
154106c3fb27SDimitry Andric       }
154206c3fb27SDimitry Andric 
154306c3fb27SDimitry Andric       // If a node is identified and reduction PHINode is used in the chain of
154406c3fb27SDimitry Andric       // operations, mark its operation instructions as used to prevent
154506c3fb27SDimitry Andric       // re-identification and attach the node to the real part
154606c3fb27SDimitry Andric       if (Node && PHIsFound) {
154706c3fb27SDimitry Andric         LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
154806c3fb27SDimitry Andric                           << *Real << " / " << *Imag << "\n");
154906c3fb27SDimitry Andric         Processed[i] = true;
155006c3fb27SDimitry Andric         Processed[j] = true;
155106c3fb27SDimitry Andric         auto RootNode = prepareCompositeNode(
155206c3fb27SDimitry Andric             ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
155306c3fb27SDimitry Andric         RootNode->addOperand(Node);
155406c3fb27SDimitry Andric         RootToNode[Real] = RootNode;
155506c3fb27SDimitry Andric         RootToNode[Imag] = RootNode;
155606c3fb27SDimitry Andric         submitCompositeNode(RootNode);
155706c3fb27SDimitry Andric         break;
155806c3fb27SDimitry Andric       }
155906c3fb27SDimitry Andric     }
156006c3fb27SDimitry Andric   }
156106c3fb27SDimitry Andric 
156206c3fb27SDimitry Andric   RealPHI = nullptr;
156306c3fb27SDimitry Andric   ImagPHI = nullptr;
156406c3fb27SDimitry Andric }
156506c3fb27SDimitry Andric 
156606c3fb27SDimitry Andric bool ComplexDeinterleavingGraph::checkNodes() {
156706c3fb27SDimitry Andric   // Collect all instructions from roots to leaves
156806c3fb27SDimitry Andric   SmallPtrSet<Instruction *, 16> AllInstructions;
156906c3fb27SDimitry Andric   SmallVector<Instruction *, 8> Worklist;
157006c3fb27SDimitry Andric   for (auto &Pair : RootToNode)
157106c3fb27SDimitry Andric     Worklist.push_back(Pair.first);
157206c3fb27SDimitry Andric 
157306c3fb27SDimitry Andric   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
157406c3fb27SDimitry Andric   // chains
157506c3fb27SDimitry Andric   while (!Worklist.empty()) {
157606c3fb27SDimitry Andric     auto *I = Worklist.back();
157706c3fb27SDimitry Andric     Worklist.pop_back();
157806c3fb27SDimitry Andric 
157906c3fb27SDimitry Andric     if (!AllInstructions.insert(I).second)
158006c3fb27SDimitry Andric       continue;
158106c3fb27SDimitry Andric 
158206c3fb27SDimitry Andric     for (Value *Op : I->operands()) {
158306c3fb27SDimitry Andric       if (auto *OpI = dyn_cast<Instruction>(Op)) {
158406c3fb27SDimitry Andric         if (!FinalInstructions.count(I))
158506c3fb27SDimitry Andric           Worklist.emplace_back(OpI);
158606c3fb27SDimitry Andric       }
158706c3fb27SDimitry Andric     }
158806c3fb27SDimitry Andric   }
158906c3fb27SDimitry Andric 
159006c3fb27SDimitry Andric   // Find instructions that have users outside of chain
159106c3fb27SDimitry Andric   SmallVector<Instruction *, 2> OuterInstructions;
159206c3fb27SDimitry Andric   for (auto *I : AllInstructions) {
159306c3fb27SDimitry Andric     // Skip root nodes
159406c3fb27SDimitry Andric     if (RootToNode.count(I))
159506c3fb27SDimitry Andric       continue;
159606c3fb27SDimitry Andric 
159706c3fb27SDimitry Andric     for (User *U : I->users()) {
159806c3fb27SDimitry Andric       if (AllInstructions.count(cast<Instruction>(U)))
159906c3fb27SDimitry Andric         continue;
160006c3fb27SDimitry Andric 
160106c3fb27SDimitry Andric       // Found an instruction that is not used by XCMLA/XCADD chain
160206c3fb27SDimitry Andric       Worklist.emplace_back(I);
160306c3fb27SDimitry Andric       break;
160406c3fb27SDimitry Andric     }
160506c3fb27SDimitry Andric   }
160606c3fb27SDimitry Andric 
160706c3fb27SDimitry Andric   // If any instructions are found to be used outside, find and remove roots
160806c3fb27SDimitry Andric   // that somehow connect to those instructions.
160906c3fb27SDimitry Andric   SmallPtrSet<Instruction *, 16> Visited;
161006c3fb27SDimitry Andric   while (!Worklist.empty()) {
161106c3fb27SDimitry Andric     auto *I = Worklist.back();
161206c3fb27SDimitry Andric     Worklist.pop_back();
161306c3fb27SDimitry Andric     if (!Visited.insert(I).second)
161406c3fb27SDimitry Andric       continue;
161506c3fb27SDimitry Andric 
161606c3fb27SDimitry Andric     // Found an impacted root node. Removing it from the nodes to be
161706c3fb27SDimitry Andric     // deinterleaved
161806c3fb27SDimitry Andric     if (RootToNode.count(I)) {
161906c3fb27SDimitry Andric       LLVM_DEBUG(dbgs() << "Instruction " << *I
162006c3fb27SDimitry Andric                         << " could be deinterleaved but its chain of complex "
162106c3fb27SDimitry Andric                            "operations have an outside user\n");
162206c3fb27SDimitry Andric       RootToNode.erase(I);
162306c3fb27SDimitry Andric     }
162406c3fb27SDimitry Andric 
162506c3fb27SDimitry Andric     if (!AllInstructions.count(I) || FinalInstructions.count(I))
162606c3fb27SDimitry Andric       continue;
162706c3fb27SDimitry Andric 
162806c3fb27SDimitry Andric     for (User *U : I->users())
162906c3fb27SDimitry Andric       Worklist.emplace_back(cast<Instruction>(U));
163006c3fb27SDimitry Andric 
163106c3fb27SDimitry Andric     for (Value *Op : I->operands()) {
163206c3fb27SDimitry Andric       if (auto *OpI = dyn_cast<Instruction>(Op))
163306c3fb27SDimitry Andric         Worklist.emplace_back(OpI);
163406c3fb27SDimitry Andric     }
163506c3fb27SDimitry Andric   }
163606c3fb27SDimitry Andric   return !RootToNode.empty();
163706c3fb27SDimitry Andric }
163806c3fb27SDimitry Andric 
163906c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
164006c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
164106c3fb27SDimitry Andric   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1642*0fca6ea1SDimitry Andric     if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)
164306c3fb27SDimitry Andric       return nullptr;
164406c3fb27SDimitry Andric 
164506c3fb27SDimitry Andric     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
164606c3fb27SDimitry Andric     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
164706c3fb27SDimitry Andric     if (!Real || !Imag)
164806c3fb27SDimitry Andric       return nullptr;
164906c3fb27SDimitry Andric 
165006c3fb27SDimitry Andric     return identifyNode(Real, Imag);
165106c3fb27SDimitry Andric   }
165206c3fb27SDimitry Andric 
165306c3fb27SDimitry Andric   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
165406c3fb27SDimitry Andric   if (!SVI)
165506c3fb27SDimitry Andric     return nullptr;
165606c3fb27SDimitry Andric 
165706c3fb27SDimitry Andric   // Look for a shufflevector that takes separate vectors of the real and
165806c3fb27SDimitry Andric   // imaginary components and recombines them into a single vector.
165906c3fb27SDimitry Andric   if (!isInterleavingMask(SVI->getShuffleMask()))
166006c3fb27SDimitry Andric     return nullptr;
166106c3fb27SDimitry Andric 
166206c3fb27SDimitry Andric   Instruction *Real;
166306c3fb27SDimitry Andric   Instruction *Imag;
166406c3fb27SDimitry Andric   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
166506c3fb27SDimitry Andric     return nullptr;
166606c3fb27SDimitry Andric 
166706c3fb27SDimitry Andric   return identifyNode(Real, Imag);
166806c3fb27SDimitry Andric }
166906c3fb27SDimitry Andric 
167006c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
167106c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
167206c3fb27SDimitry Andric                                                  Instruction *Imag) {
167306c3fb27SDimitry Andric   Instruction *I = nullptr;
167406c3fb27SDimitry Andric   Value *FinalValue = nullptr;
167506c3fb27SDimitry Andric   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
167606c3fb27SDimitry Andric       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1677*0fca6ea1SDimitry Andric       match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>(
167806c3fb27SDimitry Andric                    m_Value(FinalValue)))) {
167906c3fb27SDimitry Andric     NodePtr PlaceholderNode = prepareCompositeNode(
168006c3fb27SDimitry Andric         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
168106c3fb27SDimitry Andric     PlaceholderNode->ReplacementNode = FinalValue;
168206c3fb27SDimitry Andric     FinalInstructions.insert(Real);
168306c3fb27SDimitry Andric     FinalInstructions.insert(Imag);
168406c3fb27SDimitry Andric     return submitCompositeNode(PlaceholderNode);
168506c3fb27SDimitry Andric   }
168606c3fb27SDimitry Andric 
1687bdd1243dSDimitry Andric   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1688bdd1243dSDimitry Andric   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
168906c3fb27SDimitry Andric   if (!RealShuffle || !ImagShuffle) {
169006c3fb27SDimitry Andric     if (RealShuffle || ImagShuffle)
169106c3fb27SDimitry Andric       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
169206c3fb27SDimitry Andric     return nullptr;
169306c3fb27SDimitry Andric   }
169406c3fb27SDimitry Andric 
1695bdd1243dSDimitry Andric   Value *RealOp1 = RealShuffle->getOperand(1);
1696bdd1243dSDimitry Andric   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1697bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1698bdd1243dSDimitry Andric     return nullptr;
1699bdd1243dSDimitry Andric   }
1700bdd1243dSDimitry Andric   Value *ImagOp1 = ImagShuffle->getOperand(1);
1701bdd1243dSDimitry Andric   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1702bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1703bdd1243dSDimitry Andric     return nullptr;
1704bdd1243dSDimitry Andric   }
1705bdd1243dSDimitry Andric 
1706bdd1243dSDimitry Andric   Value *RealOp0 = RealShuffle->getOperand(0);
1707bdd1243dSDimitry Andric   Value *ImagOp0 = ImagShuffle->getOperand(0);
1708bdd1243dSDimitry Andric 
1709bdd1243dSDimitry Andric   if (RealOp0 != ImagOp0) {
1710bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1711bdd1243dSDimitry Andric     return nullptr;
1712bdd1243dSDimitry Andric   }
1713bdd1243dSDimitry Andric 
1714bdd1243dSDimitry Andric   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1715bdd1243dSDimitry Andric   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1716bdd1243dSDimitry Andric   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1717bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1718bdd1243dSDimitry Andric     return nullptr;
1719bdd1243dSDimitry Andric   }
1720bdd1243dSDimitry Andric 
1721bdd1243dSDimitry Andric   if (RealMask[0] != 0 || ImagMask[0] != 1) {
1722bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1723bdd1243dSDimitry Andric     return nullptr;
1724bdd1243dSDimitry Andric   }
1725bdd1243dSDimitry Andric 
1726bdd1243dSDimitry Andric   // Type checking, the shuffle type should be a vector type of the same
1727bdd1243dSDimitry Andric   // scalar type, but half the size
1728bdd1243dSDimitry Andric   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1729bdd1243dSDimitry Andric     Value *Op = Shuffle->getOperand(0);
1730bdd1243dSDimitry Andric     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1731bdd1243dSDimitry Andric     auto *OpTy = cast<FixedVectorType>(Op->getType());
1732bdd1243dSDimitry Andric 
1733bdd1243dSDimitry Andric     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1734bdd1243dSDimitry Andric       return false;
1735bdd1243dSDimitry Andric     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1736bdd1243dSDimitry Andric       return false;
1737bdd1243dSDimitry Andric 
1738bdd1243dSDimitry Andric     return true;
1739bdd1243dSDimitry Andric   };
1740bdd1243dSDimitry Andric 
1741bdd1243dSDimitry Andric   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1742bdd1243dSDimitry Andric     if (!CheckType(Shuffle))
1743bdd1243dSDimitry Andric       return false;
1744bdd1243dSDimitry Andric 
1745bdd1243dSDimitry Andric     ArrayRef<int> Mask = Shuffle->getShuffleMask();
1746bdd1243dSDimitry Andric     int Last = *Mask.rbegin();
1747bdd1243dSDimitry Andric 
1748bdd1243dSDimitry Andric     Value *Op = Shuffle->getOperand(0);
1749bdd1243dSDimitry Andric     auto *OpTy = cast<FixedVectorType>(Op->getType());
1750bdd1243dSDimitry Andric     int NumElements = OpTy->getNumElements();
1751bdd1243dSDimitry Andric 
1752bdd1243dSDimitry Andric     // Ensure that the deinterleaving shuffle only pulls from the first
1753bdd1243dSDimitry Andric     // shuffle operand.
1754bdd1243dSDimitry Andric     return Last < NumElements;
1755bdd1243dSDimitry Andric   };
1756bdd1243dSDimitry Andric 
1757bdd1243dSDimitry Andric   if (RealShuffle->getType() != ImagShuffle->getType()) {
1758bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1759bdd1243dSDimitry Andric     return nullptr;
1760bdd1243dSDimitry Andric   }
1761bdd1243dSDimitry Andric   if (!CheckDeinterleavingShuffle(RealShuffle)) {
1762bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1763bdd1243dSDimitry Andric     return nullptr;
1764bdd1243dSDimitry Andric   }
1765bdd1243dSDimitry Andric   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1766bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1767bdd1243dSDimitry Andric     return nullptr;
1768bdd1243dSDimitry Andric   }
1769bdd1243dSDimitry Andric 
1770bdd1243dSDimitry Andric   NodePtr PlaceholderNode =
177106c3fb27SDimitry Andric       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1772bdd1243dSDimitry Andric                            RealShuffle, ImagShuffle);
1773bdd1243dSDimitry Andric   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
177406c3fb27SDimitry Andric   FinalInstructions.insert(RealShuffle);
177506c3fb27SDimitry Andric   FinalInstructions.insert(ImagShuffle);
1776bdd1243dSDimitry Andric   return submitCompositeNode(PlaceholderNode);
1777bdd1243dSDimitry Andric }
1778bdd1243dSDimitry Andric 
177906c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
178006c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
178106c3fb27SDimitry Andric   auto IsSplat = [](Value *V) -> bool {
178206c3fb27SDimitry Andric     // Fixed-width vector with constants
178306c3fb27SDimitry Andric     if (isa<ConstantDataVector>(V))
178406c3fb27SDimitry Andric       return true;
1785bdd1243dSDimitry Andric 
178606c3fb27SDimitry Andric     VectorType *VTy;
178706c3fb27SDimitry Andric     ArrayRef<int> Mask;
178806c3fb27SDimitry Andric     // Splats are represented differently depending on whether the repeated
178906c3fb27SDimitry Andric     // value is a constant or an Instruction
179006c3fb27SDimitry Andric     if (auto *Const = dyn_cast<ConstantExpr>(V)) {
179106c3fb27SDimitry Andric       if (Const->getOpcode() != Instruction::ShuffleVector)
1792bdd1243dSDimitry Andric         return false;
179306c3fb27SDimitry Andric       VTy = cast<VectorType>(Const->getType());
179406c3fb27SDimitry Andric       Mask = Const->getShuffleMask();
179506c3fb27SDimitry Andric     } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
179606c3fb27SDimitry Andric       VTy = Shuf->getType();
179706c3fb27SDimitry Andric       Mask = Shuf->getShuffleMask();
179806c3fb27SDimitry Andric     } else {
1799bdd1243dSDimitry Andric       return false;
1800bdd1243dSDimitry Andric     }
180106c3fb27SDimitry Andric 
180206c3fb27SDimitry Andric     // When the data type is <1 x Type>, it's not possible to differentiate
180306c3fb27SDimitry Andric     // between the ComplexDeinterleaving::Deinterleave and
180406c3fb27SDimitry Andric     // ComplexDeinterleaving::Splat operations.
180506c3fb27SDimitry Andric     if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
180606c3fb27SDimitry Andric       return false;
180706c3fb27SDimitry Andric 
180806c3fb27SDimitry Andric     return all_equal(Mask) && Mask[0] == 0;
180906c3fb27SDimitry Andric   };
181006c3fb27SDimitry Andric 
181106c3fb27SDimitry Andric   if (!IsSplat(R) || !IsSplat(I))
181206c3fb27SDimitry Andric     return nullptr;
181306c3fb27SDimitry Andric 
181406c3fb27SDimitry Andric   auto *Real = dyn_cast<Instruction>(R);
181506c3fb27SDimitry Andric   auto *Imag = dyn_cast<Instruction>(I);
181606c3fb27SDimitry Andric   if ((!Real && Imag) || (Real && !Imag))
181706c3fb27SDimitry Andric     return nullptr;
181806c3fb27SDimitry Andric 
181906c3fb27SDimitry Andric   if (Real && Imag) {
182006c3fb27SDimitry Andric     // Non-constant splats should be in the same basic block
182106c3fb27SDimitry Andric     if (Real->getParent() != Imag->getParent())
182206c3fb27SDimitry Andric       return nullptr;
182306c3fb27SDimitry Andric 
182406c3fb27SDimitry Andric     FinalInstructions.insert(Real);
182506c3fb27SDimitry Andric     FinalInstructions.insert(Imag);
1826bdd1243dSDimitry Andric   }
182706c3fb27SDimitry Andric   NodePtr PlaceholderNode =
182806c3fb27SDimitry Andric       prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
182906c3fb27SDimitry Andric   return submitCompositeNode(PlaceholderNode);
1830bdd1243dSDimitry Andric }
1831bdd1243dSDimitry Andric 
183206c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
183306c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
183406c3fb27SDimitry Andric                                             Instruction *Imag) {
183506c3fb27SDimitry Andric   if (Real != RealPHI || Imag != ImagPHI)
183606c3fb27SDimitry Andric     return nullptr;
183706c3fb27SDimitry Andric 
183806c3fb27SDimitry Andric   PHIsFound = true;
183906c3fb27SDimitry Andric   NodePtr PlaceholderNode = prepareCompositeNode(
184006c3fb27SDimitry Andric       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
184106c3fb27SDimitry Andric   return submitCompositeNode(PlaceholderNode);
184206c3fb27SDimitry Andric }
184306c3fb27SDimitry Andric 
184406c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
184506c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
184606c3fb27SDimitry Andric                                                Instruction *Imag) {
184706c3fb27SDimitry Andric   auto *SelectReal = dyn_cast<SelectInst>(Real);
184806c3fb27SDimitry Andric   auto *SelectImag = dyn_cast<SelectInst>(Imag);
184906c3fb27SDimitry Andric   if (!SelectReal || !SelectImag)
185006c3fb27SDimitry Andric     return nullptr;
185106c3fb27SDimitry Andric 
185206c3fb27SDimitry Andric   Instruction *MaskA, *MaskB;
185306c3fb27SDimitry Andric   Instruction *AR, *AI, *RA, *BI;
185406c3fb27SDimitry Andric   if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
185506c3fb27SDimitry Andric                             m_Instruction(RA))) ||
185606c3fb27SDimitry Andric       !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
185706c3fb27SDimitry Andric                             m_Instruction(BI))))
185806c3fb27SDimitry Andric     return nullptr;
185906c3fb27SDimitry Andric 
186006c3fb27SDimitry Andric   if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
186106c3fb27SDimitry Andric     return nullptr;
186206c3fb27SDimitry Andric 
186306c3fb27SDimitry Andric   if (!MaskA->getType()->isVectorTy())
186406c3fb27SDimitry Andric     return nullptr;
186506c3fb27SDimitry Andric 
186606c3fb27SDimitry Andric   auto NodeA = identifyNode(AR, AI);
186706c3fb27SDimitry Andric   if (!NodeA)
186806c3fb27SDimitry Andric     return nullptr;
186906c3fb27SDimitry Andric 
187006c3fb27SDimitry Andric   auto NodeB = identifyNode(RA, BI);
187106c3fb27SDimitry Andric   if (!NodeB)
187206c3fb27SDimitry Andric     return nullptr;
187306c3fb27SDimitry Andric 
187406c3fb27SDimitry Andric   NodePtr PlaceholderNode = prepareCompositeNode(
187506c3fb27SDimitry Andric       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
187606c3fb27SDimitry Andric   PlaceholderNode->addOperand(NodeA);
187706c3fb27SDimitry Andric   PlaceholderNode->addOperand(NodeB);
187806c3fb27SDimitry Andric   FinalInstructions.insert(MaskA);
187906c3fb27SDimitry Andric   FinalInstructions.insert(MaskB);
188006c3fb27SDimitry Andric   return submitCompositeNode(PlaceholderNode);
188106c3fb27SDimitry Andric }
188206c3fb27SDimitry Andric 
188306c3fb27SDimitry Andric static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
188406c3fb27SDimitry Andric                                    std::optional<FastMathFlags> Flags,
188506c3fb27SDimitry Andric                                    Value *InputA, Value *InputB) {
188606c3fb27SDimitry Andric   Value *I;
188706c3fb27SDimitry Andric   switch (Opcode) {
188806c3fb27SDimitry Andric   case Instruction::FNeg:
188906c3fb27SDimitry Andric     I = B.CreateFNeg(InputA);
189006c3fb27SDimitry Andric     break;
189106c3fb27SDimitry Andric   case Instruction::FAdd:
189206c3fb27SDimitry Andric     I = B.CreateFAdd(InputA, InputB);
189306c3fb27SDimitry Andric     break;
189406c3fb27SDimitry Andric   case Instruction::Add:
189506c3fb27SDimitry Andric     I = B.CreateAdd(InputA, InputB);
189606c3fb27SDimitry Andric     break;
189706c3fb27SDimitry Andric   case Instruction::FSub:
189806c3fb27SDimitry Andric     I = B.CreateFSub(InputA, InputB);
189906c3fb27SDimitry Andric     break;
190006c3fb27SDimitry Andric   case Instruction::Sub:
190106c3fb27SDimitry Andric     I = B.CreateSub(InputA, InputB);
190206c3fb27SDimitry Andric     break;
190306c3fb27SDimitry Andric   case Instruction::FMul:
190406c3fb27SDimitry Andric     I = B.CreateFMul(InputA, InputB);
190506c3fb27SDimitry Andric     break;
190606c3fb27SDimitry Andric   case Instruction::Mul:
190706c3fb27SDimitry Andric     I = B.CreateMul(InputA, InputB);
190806c3fb27SDimitry Andric     break;
190906c3fb27SDimitry Andric   default:
191006c3fb27SDimitry Andric     llvm_unreachable("Incorrect symmetric opcode");
191106c3fb27SDimitry Andric   }
191206c3fb27SDimitry Andric   if (Flags)
191306c3fb27SDimitry Andric     cast<Instruction>(I)->setFastMathFlags(*Flags);
191406c3fb27SDimitry Andric   return I;
191506c3fb27SDimitry Andric }
191606c3fb27SDimitry Andric 
191706c3fb27SDimitry Andric Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
191806c3fb27SDimitry Andric                                                RawNodePtr Node) {
1919bdd1243dSDimitry Andric   if (Node->ReplacementNode)
1920bdd1243dSDimitry Andric     return Node->ReplacementNode;
1921bdd1243dSDimitry Andric 
192206c3fb27SDimitry Andric   auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
192306c3fb27SDimitry Andric     return Node->Operands.size() > Idx
192406c3fb27SDimitry Andric                ? replaceNode(Builder, Node->Operands[Idx])
192506c3fb27SDimitry Andric                : nullptr;
192606c3fb27SDimitry Andric   };
1927bdd1243dSDimitry Andric 
192806c3fb27SDimitry Andric   Value *ReplacementNode;
192906c3fb27SDimitry Andric   switch (Node->Operation) {
193006c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::CAdd:
193106c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::CMulPartial:
193206c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::Symmetric: {
193306c3fb27SDimitry Andric     Value *Input0 = ReplaceOperandIfExist(Node, 0);
193406c3fb27SDimitry Andric     Value *Input1 = ReplaceOperandIfExist(Node, 1);
193506c3fb27SDimitry Andric     Value *Accumulator = ReplaceOperandIfExist(Node, 2);
193606c3fb27SDimitry Andric     assert(!Input1 || (Input0->getType() == Input1->getType() &&
193706c3fb27SDimitry Andric                        "Node inputs need to be of the same type"));
193806c3fb27SDimitry Andric     assert(!Accumulator ||
193906c3fb27SDimitry Andric            (Input0->getType() == Accumulator->getType() &&
194006c3fb27SDimitry Andric             "Accumulator and input need to be of the same type"));
194106c3fb27SDimitry Andric     if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
194206c3fb27SDimitry Andric       ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
194306c3fb27SDimitry Andric                                              Input0, Input1);
194406c3fb27SDimitry Andric     else
194506c3fb27SDimitry Andric       ReplacementNode = TL->createComplexDeinterleavingIR(
194606c3fb27SDimitry Andric           Builder, Node->Operation, Node->Rotation, Input0, Input1,
194706c3fb27SDimitry Andric           Accumulator);
194806c3fb27SDimitry Andric     break;
194906c3fb27SDimitry Andric   }
195006c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::Deinterleave:
195106c3fb27SDimitry Andric     llvm_unreachable("Deinterleave node should already have ReplacementNode");
195206c3fb27SDimitry Andric     break;
195306c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::Splat: {
195406c3fb27SDimitry Andric     auto *NewTy = VectorType::getDoubleElementsVectorType(
195506c3fb27SDimitry Andric         cast<VectorType>(Node->Real->getType()));
195606c3fb27SDimitry Andric     auto *R = dyn_cast<Instruction>(Node->Real);
195706c3fb27SDimitry Andric     auto *I = dyn_cast<Instruction>(Node->Imag);
195806c3fb27SDimitry Andric     if (R && I) {
195906c3fb27SDimitry Andric       // Splats that are not constant are interleaved where they are located
196006c3fb27SDimitry Andric       Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
196106c3fb27SDimitry Andric       IRBuilder<> IRB(InsertPoint);
1962*0fca6ea1SDimitry Andric       ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,
196306c3fb27SDimitry Andric                                             NewTy, {Node->Real, Node->Imag});
1964*0fca6ea1SDimitry Andric     } else {
1965*0fca6ea1SDimitry Andric       ReplacementNode = Builder.CreateIntrinsic(
1966*0fca6ea1SDimitry Andric           Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag});
196706c3fb27SDimitry Andric     }
196806c3fb27SDimitry Andric     break;
196906c3fb27SDimitry Andric   }
197006c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::ReductionPHI: {
197106c3fb27SDimitry Andric     // If Operation is ReductionPHI, a new empty PHINode is created.
197206c3fb27SDimitry Andric     // It is filled later when the ReductionOperation is processed.
197306c3fb27SDimitry Andric     auto *VTy = cast<VectorType>(Node->Real->getType());
197406c3fb27SDimitry Andric     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1975*0fca6ea1SDimitry Andric     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
197606c3fb27SDimitry Andric     OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
197706c3fb27SDimitry Andric     ReplacementNode = NewPHI;
197806c3fb27SDimitry Andric     break;
197906c3fb27SDimitry Andric   }
198006c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::ReductionOperation:
198106c3fb27SDimitry Andric     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
198206c3fb27SDimitry Andric     processReductionOperation(ReplacementNode, Node);
198306c3fb27SDimitry Andric     break;
198406c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::ReductionSelect: {
198506c3fb27SDimitry Andric     auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
198606c3fb27SDimitry Andric     auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
198706c3fb27SDimitry Andric     auto *A = replaceNode(Builder, Node->Operands[0]);
198806c3fb27SDimitry Andric     auto *B = replaceNode(Builder, Node->Operands[1]);
198906c3fb27SDimitry Andric     auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
199006c3fb27SDimitry Andric         cast<VectorType>(MaskReal->getType()));
1991*0fca6ea1SDimitry Andric     auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2,
199206c3fb27SDimitry Andric                                             NewMaskTy, {MaskReal, MaskImag});
199306c3fb27SDimitry Andric     ReplacementNode = Builder.CreateSelect(NewMask, A, B);
199406c3fb27SDimitry Andric     break;
199506c3fb27SDimitry Andric   }
199606c3fb27SDimitry Andric   }
1997bdd1243dSDimitry Andric 
199806c3fb27SDimitry Andric   assert(ReplacementNode && "Target failed to create Intrinsic call.");
1999bdd1243dSDimitry Andric   NumComplexTransformations += 1;
200006c3fb27SDimitry Andric   Node->ReplacementNode = ReplacementNode;
200106c3fb27SDimitry Andric   return ReplacementNode;
200206c3fb27SDimitry Andric }
200306c3fb27SDimitry Andric 
200406c3fb27SDimitry Andric void ComplexDeinterleavingGraph::processReductionOperation(
200506c3fb27SDimitry Andric     Value *OperationReplacement, RawNodePtr Node) {
200606c3fb27SDimitry Andric   auto *Real = cast<Instruction>(Node->Real);
200706c3fb27SDimitry Andric   auto *Imag = cast<Instruction>(Node->Imag);
200806c3fb27SDimitry Andric   auto *OldPHIReal = ReductionInfo[Real].first;
200906c3fb27SDimitry Andric   auto *OldPHIImag = ReductionInfo[Imag].first;
201006c3fb27SDimitry Andric   auto *NewPHI = OldToNewPHI[OldPHIReal];
201106c3fb27SDimitry Andric 
201206c3fb27SDimitry Andric   auto *VTy = cast<VectorType>(Real->getType());
201306c3fb27SDimitry Andric   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
201406c3fb27SDimitry Andric 
201506c3fb27SDimitry Andric   // We have to interleave initial origin values coming from IncomingBlock
201606c3fb27SDimitry Andric   Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
201706c3fb27SDimitry Andric   Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
201806c3fb27SDimitry Andric 
201906c3fb27SDimitry Andric   IRBuilder<> Builder(Incoming->getTerminator());
2020*0fca6ea1SDimitry Andric   auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2021*0fca6ea1SDimitry Andric                                           {InitReal, InitImag});
202206c3fb27SDimitry Andric 
202306c3fb27SDimitry Andric   NewPHI->addIncoming(NewInit, Incoming);
202406c3fb27SDimitry Andric   NewPHI->addIncoming(OperationReplacement, BackEdge);
202506c3fb27SDimitry Andric 
202606c3fb27SDimitry Andric   // Deinterleave complex vector outside of loop so that it can be finally
202706c3fb27SDimitry Andric   // reduced
202806c3fb27SDimitry Andric   auto *FinalReductionReal = ReductionInfo[Real].second;
202906c3fb27SDimitry Andric   auto *FinalReductionImag = ReductionInfo[Imag].second;
203006c3fb27SDimitry Andric 
203106c3fb27SDimitry Andric   Builder.SetInsertPoint(
203206c3fb27SDimitry Andric       &*FinalReductionReal->getParent()->getFirstInsertionPt());
2033*0fca6ea1SDimitry Andric   auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2034*0fca6ea1SDimitry Andric                                                OperationReplacement->getType(),
2035*0fca6ea1SDimitry Andric                                                OperationReplacement);
203606c3fb27SDimitry Andric 
203706c3fb27SDimitry Andric   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
203806c3fb27SDimitry Andric   FinalReductionReal->replaceUsesOfWith(Real, NewReal);
203906c3fb27SDimitry Andric 
204006c3fb27SDimitry Andric   Builder.SetInsertPoint(FinalReductionImag);
204106c3fb27SDimitry Andric   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
204206c3fb27SDimitry Andric   FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2043bdd1243dSDimitry Andric }
2044bdd1243dSDimitry Andric 
2045bdd1243dSDimitry Andric void ComplexDeinterleavingGraph::replaceNodes() {
204606c3fb27SDimitry Andric   SmallVector<Instruction *, 16> DeadInstrRoots;
204706c3fb27SDimitry Andric   for (auto *RootInstruction : OrderedRoots) {
204806c3fb27SDimitry Andric     // Check if this potential root went through check process and we can
204906c3fb27SDimitry Andric     // deinterleave it
205006c3fb27SDimitry Andric     if (!RootToNode.count(RootInstruction))
205106c3fb27SDimitry Andric       continue;
205206c3fb27SDimitry Andric 
205306c3fb27SDimitry Andric     IRBuilder<> Builder(RootInstruction);
205406c3fb27SDimitry Andric     auto RootNode = RootToNode[RootInstruction];
205506c3fb27SDimitry Andric     Value *R = replaceNode(Builder, RootNode.get());
205606c3fb27SDimitry Andric 
205706c3fb27SDimitry Andric     if (RootNode->Operation ==
205806c3fb27SDimitry Andric         ComplexDeinterleavingOperation::ReductionOperation) {
205906c3fb27SDimitry Andric       auto *RootReal = cast<Instruction>(RootNode->Real);
206006c3fb27SDimitry Andric       auto *RootImag = cast<Instruction>(RootNode->Imag);
206106c3fb27SDimitry Andric       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
206206c3fb27SDimitry Andric       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
206306c3fb27SDimitry Andric       DeadInstrRoots.push_back(cast<Instruction>(RootReal));
206406c3fb27SDimitry Andric       DeadInstrRoots.push_back(cast<Instruction>(RootImag));
206506c3fb27SDimitry Andric     } else {
206606c3fb27SDimitry Andric       assert(R && "Unable to find replacement for RootInstruction");
206706c3fb27SDimitry Andric       DeadInstrRoots.push_back(RootInstruction);
206806c3fb27SDimitry Andric       RootInstruction->replaceAllUsesWith(R);
206906c3fb27SDimitry Andric     }
2070bdd1243dSDimitry Andric   }
2071bdd1243dSDimitry Andric 
207206c3fb27SDimitry Andric   for (auto *I : DeadInstrRoots)
207306c3fb27SDimitry Andric     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2074bdd1243dSDimitry Andric }
2075