xref: /freebsd-src/contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (revision 06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e)
1bdd1243dSDimitry Andric //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2bdd1243dSDimitry Andric //
3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bdd1243dSDimitry Andric //
7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
8bdd1243dSDimitry Andric //
9bdd1243dSDimitry Andric // Identification:
10bdd1243dSDimitry Andric // This step is responsible for finding the patterns that can be lowered to
11bdd1243dSDimitry Andric // complex instructions, and building a graph to represent the complex
12bdd1243dSDimitry Andric // structures. Starting from the "Converging Shuffle" (a shuffle that
13bdd1243dSDimitry Andric // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14bdd1243dSDimitry Andric // operands are evaluated and identified as "Composite Nodes" (collections of
15bdd1243dSDimitry Andric // instructions that can potentially be lowered to a single complex
16bdd1243dSDimitry Andric // instruction). This is performed by checking the real and imaginary components
17bdd1243dSDimitry Andric // and tracking the data flow for each component while following the operand
18bdd1243dSDimitry Andric // pairs. Validity of each node is expected to be done upon creation, and any
19bdd1243dSDimitry Andric // validation errors should halt traversal and prevent further graph
20bdd1243dSDimitry Andric // construction.
// Instead of relying on Shuffle operations, vector interleaving and
// deinterleaving can be represented by vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas fixed-width vectors are recognized in both forms:
// shufflevector instructions and intrinsics.
26bdd1243dSDimitry Andric //
27bdd1243dSDimitry Andric // Replacement:
28bdd1243dSDimitry Andric // This step traverses the graph built up by identification, delegating to the
29bdd1243dSDimitry Andric // target to validate and generate the correct intrinsics, and plumbs them
30bdd1243dSDimitry Andric // together connecting each end of the new intrinsics graph to the existing
31bdd1243dSDimitry Andric // use-def chain. This step is assumed to finish successfully, as all
32bdd1243dSDimitry Andric // information is expected to be correct by this point.
33bdd1243dSDimitry Andric //
34bdd1243dSDimitry Andric //
35bdd1243dSDimitry Andric // Internal data structure:
36bdd1243dSDimitry Andric // ComplexDeinterleavingGraph:
37bdd1243dSDimitry Andric // Keeps references to all the valid CompositeNodes formed as part of the
38bdd1243dSDimitry Andric // transformation, and every Instruction contained within said nodes. It also
39bdd1243dSDimitry Andric // holds onto a reference to the root Instruction, and the root node that should
40bdd1243dSDimitry Andric // replace it.
41bdd1243dSDimitry Andric //
42bdd1243dSDimitry Andric // ComplexDeinterleavingCompositeNode:
43bdd1243dSDimitry Andric // A CompositeNode represents a single transformation point; each node should
44bdd1243dSDimitry Andric // transform into a single complex instruction (ignoring vector splitting, which
45bdd1243dSDimitry Andric // would generate more instructions per node). They are identified in a
46bdd1243dSDimitry Andric // depth-first manner, traversing and identifying the operands of each
47bdd1243dSDimitry Andric // instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
49bdd1243dSDimitry Andric // as well as any additional instructions that make up the identified operation
50bdd1243dSDimitry Andric // (Internal instructions should only have uses within their containing node).
51bdd1243dSDimitry Andric // A Node also contains the rotation and operation type that it represents.
52bdd1243dSDimitry Andric // Operands contains pointers to other CompositeNodes, acting as the edges in
53bdd1243dSDimitry Andric // the graph. ReplacementValue is the transformed Value* that has been emitted
54bdd1243dSDimitry Andric // to the IR.
55bdd1243dSDimitry Andric //
56bdd1243dSDimitry Andric // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57bdd1243dSDimitry Andric // ReplacementValue fields of that Node are relevant, where the ReplacementValue
58bdd1243dSDimitry Andric // should be pre-populated.
59bdd1243dSDimitry Andric //
60bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
61bdd1243dSDimitry Andric 
62bdd1243dSDimitry Andric #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63bdd1243dSDimitry Andric #include "llvm/ADT/Statistic.h"
64bdd1243dSDimitry Andric #include "llvm/Analysis/TargetLibraryInfo.h"
65bdd1243dSDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
66bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetLowering.h"
67bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
68bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h"
69bdd1243dSDimitry Andric #include "llvm/IR/IRBuilder.h"
70*06c3fb27SDimitry Andric #include "llvm/IR/PatternMatch.h"
71bdd1243dSDimitry Andric #include "llvm/InitializePasses.h"
72bdd1243dSDimitry Andric #include "llvm/Target/TargetMachine.h"
73bdd1243dSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
74bdd1243dSDimitry Andric #include <algorithm>
75bdd1243dSDimitry Andric 
76bdd1243dSDimitry Andric using namespace llvm;
77bdd1243dSDimitry Andric using namespace PatternMatch;
78bdd1243dSDimitry Andric 
79bdd1243dSDimitry Andric #define DEBUG_TYPE "complex-deinterleaving"
80bdd1243dSDimitry Andric 
81bdd1243dSDimitry Andric STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
82bdd1243dSDimitry Andric 
83bdd1243dSDimitry Andric static cl::opt<bool> ComplexDeinterleavingEnabled(
84bdd1243dSDimitry Andric     "enable-complex-deinterleaving",
85bdd1243dSDimitry Andric     cl::desc("Enable generation of complex instructions"), cl::init(true),
86bdd1243dSDimitry Andric     cl::Hidden);
87bdd1243dSDimitry Andric 
/// Checks the given mask, and determines whether said mask is interleaving.
///
/// To be interleaving, a mask must have an even length, must alternate
/// between `i` and `i + (Length / 2)`, and must therefore contain all numbers
/// within the range of `[0..Length)` (e.g. a 4x vector interleaving mask would
/// be <0, 2, 1, 3>).
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks the given mask, and determines whether said mask is deinterleaving.
///
/// To be deinterleaving, a mask must increment in steps of 2 starting from its
/// first element (e.g. an 8x vector deinterleaving mask would be either
/// <0, 2, 4, 6> or <1, 3, 5, 7>).
/// Only the stride pattern is verified: the starting value Mask[0] is not
/// compared against 0 or 1 here, so callers that need a specific start must
/// check it themselves.
static bool isDeinterleavingMask(ArrayRef<int> Mask);

/// Returns true if \p V is itself a negation, i.e. it matches either a
/// floating-point negation (m_FNeg) or an integer negation (m_Neg), so it works
/// for both integers and floats.
static bool isNeg(Value *V);

/// Returns the operand being negated by the negation \p V.
/// NOTE(review): presumably only meaningful when isNeg(V) is true — confirm
/// against the definition.
static Value *getNegOperand(Value *V);
109*06c3fb27SDimitry Andric 
110bdd1243dSDimitry Andric namespace {
111bdd1243dSDimitry Andric 
112bdd1243dSDimitry Andric class ComplexDeinterleavingLegacyPass : public FunctionPass {
113bdd1243dSDimitry Andric public:
114bdd1243dSDimitry Andric   static char ID;
115bdd1243dSDimitry Andric 
116bdd1243dSDimitry Andric   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
117bdd1243dSDimitry Andric       : FunctionPass(ID), TM(TM) {
118bdd1243dSDimitry Andric     initializeComplexDeinterleavingLegacyPassPass(
119bdd1243dSDimitry Andric         *PassRegistry::getPassRegistry());
120bdd1243dSDimitry Andric   }
121bdd1243dSDimitry Andric 
122bdd1243dSDimitry Andric   StringRef getPassName() const override {
123bdd1243dSDimitry Andric     return "Complex Deinterleaving Pass";
124bdd1243dSDimitry Andric   }
125bdd1243dSDimitry Andric 
126bdd1243dSDimitry Andric   bool runOnFunction(Function &F) override;
127bdd1243dSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
128bdd1243dSDimitry Andric     AU.addRequired<TargetLibraryInfoWrapperPass>();
129bdd1243dSDimitry Andric     AU.setPreservesCFG();
130bdd1243dSDimitry Andric   }
131bdd1243dSDimitry Andric 
132bdd1243dSDimitry Andric private:
133bdd1243dSDimitry Andric   const TargetMachine *TM;
134bdd1243dSDimitry Andric };
135bdd1243dSDimitry Andric 
136bdd1243dSDimitry Andric class ComplexDeinterleavingGraph;
137bdd1243dSDimitry Andric struct ComplexDeinterleavingCompositeNode {
138bdd1243dSDimitry Andric 
139bdd1243dSDimitry Andric   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
140*06c3fb27SDimitry Andric                                      Value *R, Value *I)
141bdd1243dSDimitry Andric       : Operation(Op), Real(R), Imag(I) {}
142bdd1243dSDimitry Andric 
143bdd1243dSDimitry Andric private:
144bdd1243dSDimitry Andric   friend class ComplexDeinterleavingGraph;
145bdd1243dSDimitry Andric   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
146bdd1243dSDimitry Andric   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
147bdd1243dSDimitry Andric 
148bdd1243dSDimitry Andric public:
149bdd1243dSDimitry Andric   ComplexDeinterleavingOperation Operation;
150*06c3fb27SDimitry Andric   Value *Real;
151*06c3fb27SDimitry Andric   Value *Imag;
152bdd1243dSDimitry Andric 
153*06c3fb27SDimitry Andric   // This two members are required exclusively for generating
154*06c3fb27SDimitry Andric   // ComplexDeinterleavingOperation::Symmetric operations.
155*06c3fb27SDimitry Andric   unsigned Opcode;
156*06c3fb27SDimitry Andric   std::optional<FastMathFlags> Flags;
157*06c3fb27SDimitry Andric 
158*06c3fb27SDimitry Andric   ComplexDeinterleavingRotation Rotation =
159*06c3fb27SDimitry Andric       ComplexDeinterleavingRotation::Rotation_0;
160bdd1243dSDimitry Andric   SmallVector<RawNodePtr> Operands;
161bdd1243dSDimitry Andric   Value *ReplacementNode = nullptr;
162bdd1243dSDimitry Andric 
163bdd1243dSDimitry Andric   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
164bdd1243dSDimitry Andric 
165bdd1243dSDimitry Andric   void dump() { dump(dbgs()); }
166bdd1243dSDimitry Andric   void dump(raw_ostream &OS) {
167bdd1243dSDimitry Andric     auto PrintValue = [&](Value *V) {
168bdd1243dSDimitry Andric       if (V) {
169bdd1243dSDimitry Andric         OS << "\"";
170bdd1243dSDimitry Andric         V->print(OS, true);
171bdd1243dSDimitry Andric         OS << "\"\n";
172bdd1243dSDimitry Andric       } else
173bdd1243dSDimitry Andric         OS << "nullptr\n";
174bdd1243dSDimitry Andric     };
175bdd1243dSDimitry Andric     auto PrintNodeRef = [&](RawNodePtr Ptr) {
176bdd1243dSDimitry Andric       if (Ptr)
177bdd1243dSDimitry Andric         OS << Ptr << "\n";
178bdd1243dSDimitry Andric       else
179bdd1243dSDimitry Andric         OS << "nullptr\n";
180bdd1243dSDimitry Andric     };
181bdd1243dSDimitry Andric 
182bdd1243dSDimitry Andric     OS << "- CompositeNode: " << this << "\n";
183bdd1243dSDimitry Andric     OS << "  Real: ";
184bdd1243dSDimitry Andric     PrintValue(Real);
185bdd1243dSDimitry Andric     OS << "  Imag: ";
186bdd1243dSDimitry Andric     PrintValue(Imag);
187bdd1243dSDimitry Andric     OS << "  ReplacementNode: ";
188bdd1243dSDimitry Andric     PrintValue(ReplacementNode);
189bdd1243dSDimitry Andric     OS << "  Operation: " << (int)Operation << "\n";
190bdd1243dSDimitry Andric     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
191bdd1243dSDimitry Andric     OS << "  Operands: \n";
192bdd1243dSDimitry Andric     for (const auto &Op : Operands) {
193bdd1243dSDimitry Andric       OS << "    - ";
194bdd1243dSDimitry Andric       PrintNodeRef(Op);
195bdd1243dSDimitry Andric     }
196bdd1243dSDimitry Andric   }
197bdd1243dSDimitry Andric };
198bdd1243dSDimitry Andric 
/// Graph of CompositeNodes built while identifying complex computations in a
/// basic block, together with the bookkeeping (roots, reductions, PHIs) needed
/// to replace them with target complex operations.
class ComplexDeinterleavingGraph {
public:
  /// One multiplication term (Multiplier * Multiplicand) collected while
  /// flattening an expression.
  /// NOTE(review): IsPositive presumably records whether the term is added
  /// (true) or subtracted (false) — confirm against identifyReassocNodes.
  struct Product {
    Value *Multiplier;
    Value *Multiplicand;
    bool IsPositive;
  };

  /// A value paired with a flag.
  /// NOTE(review): the flag is presumably the addend's sign (true = positive),
  /// mirroring Product::IsPositive — TODO confirm.
  using Addend = std::pair<Value *, bool>;
  using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
  using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;

  // Helper struct for holding info about potential partial multiplication
  // candidates.
  // NOTE(review): RealIdx/ImagIdx presumably index into the RealMuls/ImagMuls
  // vectors passed to collectPartialMuls — confirm.
  struct PartialMulCandidate {
    Value *Common;
    NodePtr Node;
    unsigned RealIdx;
    unsigned ImagIdx;
    bool IsNodeInverted;
  };

  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
                                      const TargetLibraryInfo *TLI)
      : TL(TL), TLI(TLI) {}

private:
  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;
  /// Holds a shared_ptr to every node registered via submitCompositeNode.
  SmallVector<NodePtr> CompositeNodes;

  SmallPtrSet<Instruction *, 16> FinalInstructions;

  /// Root instructions are instructions from which complex computation starts.
  std::map<Instruction *, NodePtr> RootToNode;

  /// Topologically sorted root instructions.
  SmallVector<Instruction *, 1> OrderedRoots;

  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
  BasicBlock *BackEdge = nullptr;
  BasicBlock *Incoming = nullptr;

  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
  /// %OutsideUser as it is shown in the IR:
  ///
  /// vector.body:
  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
  ///                                [ %ReductionOp, %vector.body ]
  ///   ...
  ///   %ReductionOp = fadd <vector type> ...
  ///   ...
  ///   br i1 %condition, label %vector.body, %middle.block
  ///
  /// middle.block:
  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
  ///
  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
  std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;

  /// In the process of detecting a reduction, we consider a pair of
  /// %ReductionOps, which we refer to as real and imag (or vice versa), and
  /// traverse the use-tree to detect complex operations. As this is a reduction
  /// operation, it will eventually reach RealPHI and ImagPHI, which correspond
  /// to the %ReductionOps that we suspect to be complex.
  /// RealPHI and ImagPHI are used by the identifyPHINode method.
  PHINode *RealPHI = nullptr;
  PHINode *ImagPHI = nullptr;

  /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
  /// detection.
  bool PHIsFound = false;

  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
  /// This mapping is populated during
  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
  /// replacement process.
  std::map<PHINode *, PHINode *> OldToNewPHI;

  /// Creates (but does not register) a node for \p Operation with real part
  /// \p R and imaginary part \p I. Reduction-related nodes must have both
  /// parts.
  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                               Value *R, Value *I) {
    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
            (R && I)) &&
           "Reduction related nodes must have Real and Imaginary parts");
    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
                                                                I);
  }

  /// Appends \p Node to CompositeNodes and returns it.
  NodePtr submitCompositeNode(NodePtr Node) {
    CompositeNodes.push_back(Node);
    return Node;
  }

  /// Linear search of CompositeNodes for a node whose Real and Imag are
  /// exactly \p R and \p I; returns nullptr if there is none.
  NodePtr getContainingComposite(Value *R, Value *I) {
    for (const auto &CN : CompositeNodes) {
      if (CN->Real == R && CN->Imag == I)
        return CN;
    }
    return nullptr;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns:
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  NodePtr
  identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
                              std::pair<Value *, Value *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
  /// NOTE(review): presumably identifies the same operation applied to both
  /// parts (a ComplexDeinterleavingOperation::Symmetric node, which uses the
  /// node's Opcode/Flags) — confirm against the definition.
  NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);

  /// Identifies the node computing the complex value whose parts are \p R and
  /// \p I. NOTE(review): presumably returns nullptr on failure, like the other
  /// identify* helpers — confirm.
  NodePtr identifyNode(Value *R, Value *I);

  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
  NodePtr identifyAdditions(std::list<Addend> &RealAddends,
                            std::list<Addend> &ImagAddends,
                            std::optional<FastMathFlags> Flags,
                            NodePtr Accumulator);

  /// Extract one addend that has both real and imaginary parts positive.
  NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
                                std::list<Addend> &ImagAddends);

  /// Determine if sum of multiplications of complex numbers can be formed from
  /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
  /// to it. Return nullptr if it is not possible to construct a complex number.
  NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
                                  std::vector<Product> &ImagMuls,
                                  NodePtr Accumulator);

  /// Go through pairs of multiplications (one Real and one Imag), find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a partner with a common
  /// operand.
  bool collectPartialMuls(const std::vector<Product> &RealMuls,
                          const std::vector<Product> &ImagMuls,
                          std::vector<PartialMulCandidate> &Candidates);

  /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
  /// the order of complex computation operations may be significantly altered,
  /// and the real and imaginary parts may not be executed in parallel. This
  /// function takes this into consideration and employs a more general approach
  /// to identify complex computations. Initially, it gathers all the addends
  /// and multiplicands and then constructs a complex expression from them.
  NodePtr identifyReassocNodes(Instruction *I, Instruction *J);

  /// NOTE(review): presumably identifies the graph rooted at \p I (an
  /// interleaving root instruction) — confirm against the definition.
  NodePtr identifyRoot(Instruction *I);

  /// Identifies the Deinterleave operation applied to a vector containing
  /// complex numbers. There are two ways to represent the Deinterleave
  /// operation:
  /// * Using two shufflevectors with even indices for \p Real instruction and
  /// odd indices for \p Imag instructions (only for fixed-width vectors)
  /// * Using two extractvalue instructions applied to `vector.deinterleave2`
  /// intrinsic (for both fixed and scalable vectors)
  NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);

  /// Identifies the operation that represents a complex number repeated in a
  /// Splat vector. There are two possible types of splats: ConstantExpr with
  /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
  /// initialization mask with all values set to zero.
  NodePtr identifySplat(Value *Real, Value *Imag);

  /// Identifies \p Real and \p Imag as the reduction PHIs tracked in RealPHI
  /// and ImagPHI (see their documentation above).
  NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);

  /// Identifies SelectInsts in a loop that has reduction with predication masks
  /// and/or predicated tail folding
  NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);

  /// Emits the IR for \p Node using \p Builder and returns the replacement
  /// value. NOTE(review): presumably recurses into Node's Operands — confirm.
  Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);

  /// Complete IR modifications after producing new reduction operation:
  /// * Populate the PHINode generated for
  /// ComplexDeinterleavingOperation::ReductionPHI
  /// * Deinterleave the final value outside of the loop and repurpose original
  /// reduction users
  void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);

public:
  /// Prints every node in CompositeNodes to the debug stream / \p OS.
  void dump() { dump(dbgs()); }
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
  bool identifyNodes(Instruction *RootI);

  /// In case \p B is one-block loop, this function seeks potential reductions
  /// and populates ReductionInfo. Returns true if any reductions were
  /// identified.
  bool collectPotentialReductions(BasicBlock *B);

  /// NOTE(review): presumably builds graph nodes for the reductions recorded
  /// in ReductionInfo by collectPotentialReductions — confirm.
  void identifyReductionNodes();

  /// Check that every instruction, from the roots to the leaves, has internal
  /// uses.
  bool checkNodes();

  /// Perform the actual replacement of the underlying instruction graph.
  void replaceNodes();
};
430bdd1243dSDimitry Andric 
431bdd1243dSDimitry Andric class ComplexDeinterleaving {
432bdd1243dSDimitry Andric public:
433bdd1243dSDimitry Andric   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
434bdd1243dSDimitry Andric       : TL(tl), TLI(tli) {}
435bdd1243dSDimitry Andric   bool runOnFunction(Function &F);
436bdd1243dSDimitry Andric 
437bdd1243dSDimitry Andric private:
438bdd1243dSDimitry Andric   bool evaluateBasicBlock(BasicBlock *B);
439bdd1243dSDimitry Andric 
440bdd1243dSDimitry Andric   const TargetLowering *TL = nullptr;
441bdd1243dSDimitry Andric   const TargetLibraryInfo *TLI = nullptr;
442bdd1243dSDimitry Andric };
443bdd1243dSDimitry Andric 
444bdd1243dSDimitry Andric } // namespace
445bdd1243dSDimitry Andric 
446bdd1243dSDimitry Andric char ComplexDeinterleavingLegacyPass::ID = 0;
447bdd1243dSDimitry Andric 
448bdd1243dSDimitry Andric INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
449bdd1243dSDimitry Andric                       "Complex Deinterleaving", false, false)
450bdd1243dSDimitry Andric INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
451bdd1243dSDimitry Andric                     "Complex Deinterleaving", false, false)
452bdd1243dSDimitry Andric 
453bdd1243dSDimitry Andric PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
454bdd1243dSDimitry Andric                                                  FunctionAnalysisManager &AM) {
455bdd1243dSDimitry Andric   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
456bdd1243dSDimitry Andric   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
457bdd1243dSDimitry Andric   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
458bdd1243dSDimitry Andric     return PreservedAnalyses::all();
459bdd1243dSDimitry Andric 
460bdd1243dSDimitry Andric   PreservedAnalyses PA;
461bdd1243dSDimitry Andric   PA.preserve<FunctionAnalysisManagerModuleProxy>();
462bdd1243dSDimitry Andric   return PA;
463bdd1243dSDimitry Andric }
464bdd1243dSDimitry Andric 
465bdd1243dSDimitry Andric FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
466bdd1243dSDimitry Andric   return new ComplexDeinterleavingLegacyPass(TM);
467bdd1243dSDimitry Andric }
468bdd1243dSDimitry Andric 
469bdd1243dSDimitry Andric bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
470bdd1243dSDimitry Andric   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
471bdd1243dSDimitry Andric   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
472bdd1243dSDimitry Andric   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
473bdd1243dSDimitry Andric }
474bdd1243dSDimitry Andric 
475bdd1243dSDimitry Andric bool ComplexDeinterleaving::runOnFunction(Function &F) {
476bdd1243dSDimitry Andric   if (!ComplexDeinterleavingEnabled) {
477bdd1243dSDimitry Andric     LLVM_DEBUG(
478bdd1243dSDimitry Andric         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
479bdd1243dSDimitry Andric     return false;
480bdd1243dSDimitry Andric   }
481bdd1243dSDimitry Andric 
482bdd1243dSDimitry Andric   if (!TL->isComplexDeinterleavingSupported()) {
483bdd1243dSDimitry Andric     LLVM_DEBUG(
484bdd1243dSDimitry Andric         dbgs() << "Complex deinterleaving has been disabled, target does "
485bdd1243dSDimitry Andric                   "not support lowering of complex number operations.\n");
486bdd1243dSDimitry Andric     return false;
487bdd1243dSDimitry Andric   }
488bdd1243dSDimitry Andric 
489bdd1243dSDimitry Andric   bool Changed = false;
490bdd1243dSDimitry Andric   for (auto &B : F)
491bdd1243dSDimitry Andric     Changed |= evaluateBasicBlock(&B);
492bdd1243dSDimitry Andric 
493bdd1243dSDimitry Andric   return Changed;
494bdd1243dSDimitry Andric }
495bdd1243dSDimitry Andric 
496bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask) {
497bdd1243dSDimitry Andric   // If the size is not even, it's not an interleaving mask
498bdd1243dSDimitry Andric   if ((Mask.size() & 1))
499bdd1243dSDimitry Andric     return false;
500bdd1243dSDimitry Andric 
501bdd1243dSDimitry Andric   int HalfNumElements = Mask.size() / 2;
502bdd1243dSDimitry Andric   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
503bdd1243dSDimitry Andric     int MaskIdx = Idx * 2;
504bdd1243dSDimitry Andric     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
505bdd1243dSDimitry Andric       return false;
506bdd1243dSDimitry Andric   }
507bdd1243dSDimitry Andric 
508bdd1243dSDimitry Andric   return true;
509bdd1243dSDimitry Andric }
510bdd1243dSDimitry Andric 
511bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask) {
512bdd1243dSDimitry Andric   int Offset = Mask[0];
513bdd1243dSDimitry Andric   int HalfNumElements = Mask.size() / 2;
514bdd1243dSDimitry Andric 
515bdd1243dSDimitry Andric   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
516bdd1243dSDimitry Andric     if (Mask[Idx] != (Idx * 2) + Offset)
517bdd1243dSDimitry Andric       return false;
518bdd1243dSDimitry Andric   }
519bdd1243dSDimitry Andric 
520bdd1243dSDimitry Andric   return true;
521bdd1243dSDimitry Andric }
522bdd1243dSDimitry Andric 
523*06c3fb27SDimitry Andric bool isNeg(Value *V) {
524*06c3fb27SDimitry Andric   return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
525*06c3fb27SDimitry Andric }
526*06c3fb27SDimitry Andric 
527*06c3fb27SDimitry Andric Value *getNegOperand(Value *V) {
528*06c3fb27SDimitry Andric   assert(isNeg(V));
529*06c3fb27SDimitry Andric   auto *I = cast<Instruction>(V);
530*06c3fb27SDimitry Andric   if (I->getOpcode() == Instruction::FNeg)
531*06c3fb27SDimitry Andric     return I->getOperand(0);
532*06c3fb27SDimitry Andric 
533*06c3fb27SDimitry Andric   return I->getOperand(1);
534*06c3fb27SDimitry Andric }
535*06c3fb27SDimitry Andric 
536bdd1243dSDimitry Andric bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
537*06c3fb27SDimitry Andric   ComplexDeinterleavingGraph Graph(TL, TLI);
538*06c3fb27SDimitry Andric   if (Graph.collectPotentialReductions(B))
539*06c3fb27SDimitry Andric     Graph.identifyReductionNodes();
540bdd1243dSDimitry Andric 
541*06c3fb27SDimitry Andric   for (auto &I : *B)
542*06c3fb27SDimitry Andric     Graph.identifyNodes(&I);
543bdd1243dSDimitry Andric 
544*06c3fb27SDimitry Andric   if (Graph.checkNodes()) {
545bdd1243dSDimitry Andric     Graph.replaceNodes();
546*06c3fb27SDimitry Andric     return true;
547bdd1243dSDimitry Andric   }
548bdd1243dSDimitry Andric 
549*06c3fb27SDimitry Andric   return false;
550bdd1243dSDimitry Andric }
551bdd1243dSDimitry Andric 
552bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
553bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
554bdd1243dSDimitry Andric     Instruction *Real, Instruction *Imag,
555*06c3fb27SDimitry Andric     std::pair<Value *, Value *> &PartialMatch) {
556bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
557bdd1243dSDimitry Andric                     << "\n");
558bdd1243dSDimitry Andric 
559bdd1243dSDimitry Andric   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
560bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
561bdd1243dSDimitry Andric     return nullptr;
562bdd1243dSDimitry Andric   }
563bdd1243dSDimitry Andric 
564*06c3fb27SDimitry Andric   if ((Real->getOpcode() != Instruction::FMul &&
565*06c3fb27SDimitry Andric        Real->getOpcode() != Instruction::Mul) ||
566*06c3fb27SDimitry Andric       (Imag->getOpcode() != Instruction::FMul &&
567*06c3fb27SDimitry Andric        Imag->getOpcode() != Instruction::Mul)) {
568*06c3fb27SDimitry Andric     LLVM_DEBUG(
569*06c3fb27SDimitry Andric         dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
570bdd1243dSDimitry Andric     return nullptr;
571bdd1243dSDimitry Andric   }
572bdd1243dSDimitry Andric 
573*06c3fb27SDimitry Andric   Value *R0 = Real->getOperand(0);
574*06c3fb27SDimitry Andric   Value *R1 = Real->getOperand(1);
575*06c3fb27SDimitry Andric   Value *I0 = Imag->getOperand(0);
576*06c3fb27SDimitry Andric   Value *I1 = Imag->getOperand(1);
577bdd1243dSDimitry Andric 
578bdd1243dSDimitry Andric   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
579bdd1243dSDimitry Andric   // rotations and use the operand.
580bdd1243dSDimitry Andric   unsigned Negs = 0;
581*06c3fb27SDimitry Andric   Value *Op;
582*06c3fb27SDimitry Andric   if (match(R0, m_Neg(m_Value(Op)))) {
583bdd1243dSDimitry Andric     Negs |= 1;
584*06c3fb27SDimitry Andric     R0 = Op;
585*06c3fb27SDimitry Andric   } else if (match(R1, m_Neg(m_Value(Op)))) {
586*06c3fb27SDimitry Andric     Negs |= 1;
587*06c3fb27SDimitry Andric     R1 = Op;
588bdd1243dSDimitry Andric   }
589*06c3fb27SDimitry Andric 
590*06c3fb27SDimitry Andric   if (isNeg(I0)) {
591bdd1243dSDimitry Andric     Negs |= 2;
592bdd1243dSDimitry Andric     Negs ^= 1;
593*06c3fb27SDimitry Andric     I0 = Op;
594*06c3fb27SDimitry Andric   } else if (match(I1, m_Neg(m_Value(Op)))) {
595*06c3fb27SDimitry Andric     Negs |= 2;
596*06c3fb27SDimitry Andric     Negs ^= 1;
597*06c3fb27SDimitry Andric     I1 = Op;
598bdd1243dSDimitry Andric   }
599bdd1243dSDimitry Andric 
600bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
601bdd1243dSDimitry Andric 
602*06c3fb27SDimitry Andric   Value *CommonOperand;
603*06c3fb27SDimitry Andric   Value *UncommonRealOp;
604*06c3fb27SDimitry Andric   Value *UncommonImagOp;
605bdd1243dSDimitry Andric 
606bdd1243dSDimitry Andric   if (R0 == I0 || R0 == I1) {
607bdd1243dSDimitry Andric     CommonOperand = R0;
608bdd1243dSDimitry Andric     UncommonRealOp = R1;
609bdd1243dSDimitry Andric   } else if (R1 == I0 || R1 == I1) {
610bdd1243dSDimitry Andric     CommonOperand = R1;
611bdd1243dSDimitry Andric     UncommonRealOp = R0;
612bdd1243dSDimitry Andric   } else {
613bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
614bdd1243dSDimitry Andric     return nullptr;
615bdd1243dSDimitry Andric   }
616bdd1243dSDimitry Andric 
617bdd1243dSDimitry Andric   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
618bdd1243dSDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
619bdd1243dSDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_270)
620bdd1243dSDimitry Andric     std::swap(UncommonRealOp, UncommonImagOp);
621bdd1243dSDimitry Andric 
622bdd1243dSDimitry Andric   // Between identifyPartialMul and here we need to have found a complete valid
623bdd1243dSDimitry Andric   // pair from the CommonOperand of each part.
624bdd1243dSDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
625bdd1243dSDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_180)
626bdd1243dSDimitry Andric     PartialMatch.first = CommonOperand;
627bdd1243dSDimitry Andric   else
628bdd1243dSDimitry Andric     PartialMatch.second = CommonOperand;
629bdd1243dSDimitry Andric 
630bdd1243dSDimitry Andric   if (!PartialMatch.first || !PartialMatch.second) {
631bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
632bdd1243dSDimitry Andric     return nullptr;
633bdd1243dSDimitry Andric   }
634bdd1243dSDimitry Andric 
635bdd1243dSDimitry Andric   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
636bdd1243dSDimitry Andric   if (!CommonNode) {
637bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
638bdd1243dSDimitry Andric     return nullptr;
639bdd1243dSDimitry Andric   }
640bdd1243dSDimitry Andric 
641bdd1243dSDimitry Andric   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
642bdd1243dSDimitry Andric   if (!UncommonNode) {
643bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
644bdd1243dSDimitry Andric     return nullptr;
645bdd1243dSDimitry Andric   }
646bdd1243dSDimitry Andric 
647bdd1243dSDimitry Andric   NodePtr Node = prepareCompositeNode(
648bdd1243dSDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
649bdd1243dSDimitry Andric   Node->Rotation = Rotation;
650bdd1243dSDimitry Andric   Node->addOperand(CommonNode);
651bdd1243dSDimitry Andric   Node->addOperand(UncommonNode);
652bdd1243dSDimitry Andric   return submitCompositeNode(Node);
653bdd1243dSDimitry Andric }
654bdd1243dSDimitry Andric 
655bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
656bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
657bdd1243dSDimitry Andric                                                Instruction *Imag) {
658bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
659bdd1243dSDimitry Andric                     << "\n");
660bdd1243dSDimitry Andric   // Determine rotation
661*06c3fb27SDimitry Andric   auto IsAdd = [](unsigned Op) {
662*06c3fb27SDimitry Andric     return Op == Instruction::FAdd || Op == Instruction::Add;
663*06c3fb27SDimitry Andric   };
664*06c3fb27SDimitry Andric   auto IsSub = [](unsigned Op) {
665*06c3fb27SDimitry Andric     return Op == Instruction::FSub || Op == Instruction::Sub;
666*06c3fb27SDimitry Andric   };
667bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation;
668*06c3fb27SDimitry Andric   if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
669bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_0;
670*06c3fb27SDimitry Andric   else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
671bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_90;
672*06c3fb27SDimitry Andric   else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
673bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_180;
674*06c3fb27SDimitry Andric   else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
675bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_270;
676bdd1243dSDimitry Andric   else {
677bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
678bdd1243dSDimitry Andric     return nullptr;
679bdd1243dSDimitry Andric   }
680bdd1243dSDimitry Andric 
681*06c3fb27SDimitry Andric   if (isa<FPMathOperator>(Real) &&
682*06c3fb27SDimitry Andric       (!Real->getFastMathFlags().allowContract() ||
683*06c3fb27SDimitry Andric        !Imag->getFastMathFlags().allowContract())) {
684bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
685bdd1243dSDimitry Andric     return nullptr;
686bdd1243dSDimitry Andric   }
687bdd1243dSDimitry Andric 
688bdd1243dSDimitry Andric   Value *CR = Real->getOperand(0);
689bdd1243dSDimitry Andric   Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
690bdd1243dSDimitry Andric   if (!RealMulI)
691bdd1243dSDimitry Andric     return nullptr;
692bdd1243dSDimitry Andric   Value *CI = Imag->getOperand(0);
693bdd1243dSDimitry Andric   Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
694bdd1243dSDimitry Andric   if (!ImagMulI)
695bdd1243dSDimitry Andric     return nullptr;
696bdd1243dSDimitry Andric 
697bdd1243dSDimitry Andric   if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
698bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
699bdd1243dSDimitry Andric     return nullptr;
700bdd1243dSDimitry Andric   }
701bdd1243dSDimitry Andric 
702*06c3fb27SDimitry Andric   Value *R0 = RealMulI->getOperand(0);
703*06c3fb27SDimitry Andric   Value *R1 = RealMulI->getOperand(1);
704*06c3fb27SDimitry Andric   Value *I0 = ImagMulI->getOperand(0);
705*06c3fb27SDimitry Andric   Value *I1 = ImagMulI->getOperand(1);
706bdd1243dSDimitry Andric 
707*06c3fb27SDimitry Andric   Value *CommonOperand;
708*06c3fb27SDimitry Andric   Value *UncommonRealOp;
709*06c3fb27SDimitry Andric   Value *UncommonImagOp;
710bdd1243dSDimitry Andric 
711bdd1243dSDimitry Andric   if (R0 == I0 || R0 == I1) {
712bdd1243dSDimitry Andric     CommonOperand = R0;
713bdd1243dSDimitry Andric     UncommonRealOp = R1;
714bdd1243dSDimitry Andric   } else if (R1 == I0 || R1 == I1) {
715bdd1243dSDimitry Andric     CommonOperand = R1;
716bdd1243dSDimitry Andric     UncommonRealOp = R0;
717bdd1243dSDimitry Andric   } else {
718bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
719bdd1243dSDimitry Andric     return nullptr;
720bdd1243dSDimitry Andric   }
721bdd1243dSDimitry Andric 
722bdd1243dSDimitry Andric   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
723bdd1243dSDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
724bdd1243dSDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_270)
725bdd1243dSDimitry Andric     std::swap(UncommonRealOp, UncommonImagOp);
726bdd1243dSDimitry Andric 
727*06c3fb27SDimitry Andric   std::pair<Value *, Value *> PartialMatch(
728bdd1243dSDimitry Andric       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
729bdd1243dSDimitry Andric        Rotation == ComplexDeinterleavingRotation::Rotation_180)
730bdd1243dSDimitry Andric           ? CommonOperand
731bdd1243dSDimitry Andric           : nullptr,
732bdd1243dSDimitry Andric       (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
733bdd1243dSDimitry Andric        Rotation == ComplexDeinterleavingRotation::Rotation_270)
734bdd1243dSDimitry Andric           ? CommonOperand
735bdd1243dSDimitry Andric           : nullptr);
736*06c3fb27SDimitry Andric 
737*06c3fb27SDimitry Andric   auto *CRInst = dyn_cast<Instruction>(CR);
738*06c3fb27SDimitry Andric   auto *CIInst = dyn_cast<Instruction>(CI);
739*06c3fb27SDimitry Andric 
740*06c3fb27SDimitry Andric   if (!CRInst || !CIInst) {
741*06c3fb27SDimitry Andric     LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
742*06c3fb27SDimitry Andric     return nullptr;
743*06c3fb27SDimitry Andric   }
744*06c3fb27SDimitry Andric 
745*06c3fb27SDimitry Andric   NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
746bdd1243dSDimitry Andric   if (!CNode) {
747bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
748bdd1243dSDimitry Andric     return nullptr;
749bdd1243dSDimitry Andric   }
750bdd1243dSDimitry Andric 
751bdd1243dSDimitry Andric   NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
752bdd1243dSDimitry Andric   if (!UncommonRes) {
753bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
754bdd1243dSDimitry Andric     return nullptr;
755bdd1243dSDimitry Andric   }
756bdd1243dSDimitry Andric 
757bdd1243dSDimitry Andric   assert(PartialMatch.first && PartialMatch.second);
758bdd1243dSDimitry Andric   NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
759bdd1243dSDimitry Andric   if (!CommonRes) {
760bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
761bdd1243dSDimitry Andric     return nullptr;
762bdd1243dSDimitry Andric   }
763bdd1243dSDimitry Andric 
764bdd1243dSDimitry Andric   NodePtr Node = prepareCompositeNode(
765bdd1243dSDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
766bdd1243dSDimitry Andric   Node->Rotation = Rotation;
767bdd1243dSDimitry Andric   Node->addOperand(CommonRes);
768bdd1243dSDimitry Andric   Node->addOperand(UncommonRes);
769bdd1243dSDimitry Andric   Node->addOperand(CNode);
770bdd1243dSDimitry Andric   return submitCompositeNode(Node);
771bdd1243dSDimitry Andric }
772bdd1243dSDimitry Andric 
773bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
774bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
775bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
776bdd1243dSDimitry Andric 
777bdd1243dSDimitry Andric   // Determine rotation
778bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation;
779bdd1243dSDimitry Andric   if ((Real->getOpcode() == Instruction::FSub &&
780bdd1243dSDimitry Andric        Imag->getOpcode() == Instruction::FAdd) ||
781bdd1243dSDimitry Andric       (Real->getOpcode() == Instruction::Sub &&
782bdd1243dSDimitry Andric        Imag->getOpcode() == Instruction::Add))
783bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_90;
784bdd1243dSDimitry Andric   else if ((Real->getOpcode() == Instruction::FAdd &&
785bdd1243dSDimitry Andric             Imag->getOpcode() == Instruction::FSub) ||
786bdd1243dSDimitry Andric            (Real->getOpcode() == Instruction::Add &&
787bdd1243dSDimitry Andric             Imag->getOpcode() == Instruction::Sub))
788bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_270;
789bdd1243dSDimitry Andric   else {
790bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
791bdd1243dSDimitry Andric     return nullptr;
792bdd1243dSDimitry Andric   }
793bdd1243dSDimitry Andric 
794bdd1243dSDimitry Andric   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
795bdd1243dSDimitry Andric   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
796bdd1243dSDimitry Andric   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
797bdd1243dSDimitry Andric   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
798bdd1243dSDimitry Andric 
799bdd1243dSDimitry Andric   if (!AR || !AI || !BR || !BI) {
800bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
801bdd1243dSDimitry Andric     return nullptr;
802bdd1243dSDimitry Andric   }
803bdd1243dSDimitry Andric 
804bdd1243dSDimitry Andric   NodePtr ResA = identifyNode(AR, AI);
805bdd1243dSDimitry Andric   if (!ResA) {
806bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
807bdd1243dSDimitry Andric     return nullptr;
808bdd1243dSDimitry Andric   }
809bdd1243dSDimitry Andric   NodePtr ResB = identifyNode(BR, BI);
810bdd1243dSDimitry Andric   if (!ResB) {
811bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
812bdd1243dSDimitry Andric     return nullptr;
813bdd1243dSDimitry Andric   }
814bdd1243dSDimitry Andric 
815bdd1243dSDimitry Andric   NodePtr Node =
816bdd1243dSDimitry Andric       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
817bdd1243dSDimitry Andric   Node->Rotation = Rotation;
818bdd1243dSDimitry Andric   Node->addOperand(ResA);
819bdd1243dSDimitry Andric   Node->addOperand(ResB);
820bdd1243dSDimitry Andric   return submitCompositeNode(Node);
821bdd1243dSDimitry Andric }
822bdd1243dSDimitry Andric 
823bdd1243dSDimitry Andric static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
824bdd1243dSDimitry Andric   unsigned OpcA = A->getOpcode();
825bdd1243dSDimitry Andric   unsigned OpcB = B->getOpcode();
826bdd1243dSDimitry Andric 
827bdd1243dSDimitry Andric   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
828bdd1243dSDimitry Andric          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
829bdd1243dSDimitry Andric          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
830bdd1243dSDimitry Andric          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
831bdd1243dSDimitry Andric }
832bdd1243dSDimitry Andric 
833bdd1243dSDimitry Andric static bool isInstructionPairMul(Instruction *A, Instruction *B) {
834bdd1243dSDimitry Andric   auto Pattern =
835bdd1243dSDimitry Andric       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
836bdd1243dSDimitry Andric 
837bdd1243dSDimitry Andric   return match(A, Pattern) && match(B, Pattern);
838bdd1243dSDimitry Andric }
839bdd1243dSDimitry Andric 
840*06c3fb27SDimitry Andric static bool isInstructionPotentiallySymmetric(Instruction *I) {
841*06c3fb27SDimitry Andric   switch (I->getOpcode()) {
842*06c3fb27SDimitry Andric   case Instruction::FAdd:
843*06c3fb27SDimitry Andric   case Instruction::FSub:
844*06c3fb27SDimitry Andric   case Instruction::FMul:
845*06c3fb27SDimitry Andric   case Instruction::FNeg:
846*06c3fb27SDimitry Andric   case Instruction::Add:
847*06c3fb27SDimitry Andric   case Instruction::Sub:
848*06c3fb27SDimitry Andric   case Instruction::Mul:
849*06c3fb27SDimitry Andric     return true;
850*06c3fb27SDimitry Andric   default:
851*06c3fb27SDimitry Andric     return false;
852*06c3fb27SDimitry Andric   }
853*06c3fb27SDimitry Andric }
854*06c3fb27SDimitry Andric 
855bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
856*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
857*06c3fb27SDimitry Andric                                                        Instruction *Imag) {
858*06c3fb27SDimitry Andric   if (Real->getOpcode() != Imag->getOpcode())
859*06c3fb27SDimitry Andric     return nullptr;
860*06c3fb27SDimitry Andric 
861*06c3fb27SDimitry Andric   if (!isInstructionPotentiallySymmetric(Real) ||
862*06c3fb27SDimitry Andric       !isInstructionPotentiallySymmetric(Imag))
863*06c3fb27SDimitry Andric     return nullptr;
864*06c3fb27SDimitry Andric 
865*06c3fb27SDimitry Andric   auto *R0 = Real->getOperand(0);
866*06c3fb27SDimitry Andric   auto *I0 = Imag->getOperand(0);
867*06c3fb27SDimitry Andric 
868*06c3fb27SDimitry Andric   NodePtr Op0 = identifyNode(R0, I0);
869*06c3fb27SDimitry Andric   NodePtr Op1 = nullptr;
870*06c3fb27SDimitry Andric   if (Op0 == nullptr)
871*06c3fb27SDimitry Andric     return nullptr;
872*06c3fb27SDimitry Andric 
873*06c3fb27SDimitry Andric   if (Real->isBinaryOp()) {
874*06c3fb27SDimitry Andric     auto *R1 = Real->getOperand(1);
875*06c3fb27SDimitry Andric     auto *I1 = Imag->getOperand(1);
876*06c3fb27SDimitry Andric     Op1 = identifyNode(R1, I1);
877*06c3fb27SDimitry Andric     if (Op1 == nullptr)
878*06c3fb27SDimitry Andric       return nullptr;
879*06c3fb27SDimitry Andric   }
880*06c3fb27SDimitry Andric 
881*06c3fb27SDimitry Andric   if (isa<FPMathOperator>(Real) &&
882*06c3fb27SDimitry Andric       Real->getFastMathFlags() != Imag->getFastMathFlags())
883*06c3fb27SDimitry Andric     return nullptr;
884*06c3fb27SDimitry Andric 
885*06c3fb27SDimitry Andric   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
886*06c3fb27SDimitry Andric                                    Real, Imag);
887*06c3fb27SDimitry Andric   Node->Opcode = Real->getOpcode();
888*06c3fb27SDimitry Andric   if (isa<FPMathOperator>(Real))
889*06c3fb27SDimitry Andric     Node->Flags = Real->getFastMathFlags();
890*06c3fb27SDimitry Andric 
891*06c3fb27SDimitry Andric   Node->addOperand(Op0);
892*06c3fb27SDimitry Andric   if (Real->isBinaryOp())
893*06c3fb27SDimitry Andric     Node->addOperand(Op1);
894*06c3fb27SDimitry Andric 
895*06c3fb27SDimitry Andric   return submitCompositeNode(Node);
896*06c3fb27SDimitry Andric }
897*06c3fb27SDimitry Andric 
898*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
899*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
900*06c3fb27SDimitry Andric   LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
901*06c3fb27SDimitry Andric   assert(R->getType() == I->getType() &&
902*06c3fb27SDimitry Andric          "Real and imaginary parts should not have different types");
903*06c3fb27SDimitry Andric   if (NodePtr CN = getContainingComposite(R, I)) {
904bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
905bdd1243dSDimitry Andric     return CN;
906bdd1243dSDimitry Andric   }
907bdd1243dSDimitry Andric 
908*06c3fb27SDimitry Andric   if (NodePtr CN = identifySplat(R, I))
909*06c3fb27SDimitry Andric     return CN;
910*06c3fb27SDimitry Andric 
911*06c3fb27SDimitry Andric   auto *Real = dyn_cast<Instruction>(R);
912*06c3fb27SDimitry Andric   auto *Imag = dyn_cast<Instruction>(I);
913*06c3fb27SDimitry Andric   if (!Real || !Imag)
914*06c3fb27SDimitry Andric     return nullptr;
915*06c3fb27SDimitry Andric 
916*06c3fb27SDimitry Andric   if (NodePtr CN = identifyDeinterleave(Real, Imag))
917*06c3fb27SDimitry Andric     return CN;
918*06c3fb27SDimitry Andric 
919*06c3fb27SDimitry Andric   if (NodePtr CN = identifyPHINode(Real, Imag))
920*06c3fb27SDimitry Andric     return CN;
921*06c3fb27SDimitry Andric 
922*06c3fb27SDimitry Andric   if (NodePtr CN = identifySelectNode(Real, Imag))
923*06c3fb27SDimitry Andric     return CN;
924*06c3fb27SDimitry Andric 
925*06c3fb27SDimitry Andric   auto *VTy = cast<VectorType>(Real->getType());
926*06c3fb27SDimitry Andric   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
927*06c3fb27SDimitry Andric 
928*06c3fb27SDimitry Andric   bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
929*06c3fb27SDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, NewVTy);
930*06c3fb27SDimitry Andric   bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
931*06c3fb27SDimitry Andric       ComplexDeinterleavingOperation::CAdd, NewVTy);
932*06c3fb27SDimitry Andric 
933*06c3fb27SDimitry Andric   if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
934*06c3fb27SDimitry Andric     if (NodePtr CN = identifyPartialMul(Real, Imag))
935*06c3fb27SDimitry Andric       return CN;
936*06c3fb27SDimitry Andric   }
937*06c3fb27SDimitry Andric 
938*06c3fb27SDimitry Andric   if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
939*06c3fb27SDimitry Andric     if (NodePtr CN = identifyAdd(Real, Imag))
940*06c3fb27SDimitry Andric       return CN;
941*06c3fb27SDimitry Andric   }
942*06c3fb27SDimitry Andric 
943*06c3fb27SDimitry Andric   if (HasCMulSupport && HasCAddSupport) {
944*06c3fb27SDimitry Andric     if (NodePtr CN = identifyReassocNodes(Real, Imag))
945*06c3fb27SDimitry Andric       return CN;
946*06c3fb27SDimitry Andric   }
947*06c3fb27SDimitry Andric 
948*06c3fb27SDimitry Andric   if (NodePtr CN = identifySymmetricOperation(Real, Imag))
949*06c3fb27SDimitry Andric     return CN;
950*06c3fb27SDimitry Andric 
951*06c3fb27SDimitry Andric   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
952*06c3fb27SDimitry Andric   return nullptr;
953*06c3fb27SDimitry Andric }
954*06c3fb27SDimitry Andric 
955*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
956*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
957*06c3fb27SDimitry Andric                                                  Instruction *Imag) {
958*06c3fb27SDimitry Andric   auto IsOperationSupported = [](unsigned Opcode) -> bool {
959*06c3fb27SDimitry Andric     return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
960*06c3fb27SDimitry Andric            Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
961*06c3fb27SDimitry Andric            Opcode == Instruction::Sub;
962*06c3fb27SDimitry Andric   };
963*06c3fb27SDimitry Andric 
964*06c3fb27SDimitry Andric   if (!IsOperationSupported(Real->getOpcode()) ||
965*06c3fb27SDimitry Andric       !IsOperationSupported(Imag->getOpcode()))
966*06c3fb27SDimitry Andric     return nullptr;
967*06c3fb27SDimitry Andric 
968*06c3fb27SDimitry Andric   std::optional<FastMathFlags> Flags;
969*06c3fb27SDimitry Andric   if (isa<FPMathOperator>(Real)) {
970*06c3fb27SDimitry Andric     if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
971*06c3fb27SDimitry Andric       LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
972*06c3fb27SDimitry Andric                            "not identical\n");
973*06c3fb27SDimitry Andric       return nullptr;
974*06c3fb27SDimitry Andric     }
975*06c3fb27SDimitry Andric 
976*06c3fb27SDimitry Andric     Flags = Real->getFastMathFlags();
977*06c3fb27SDimitry Andric     if (!Flags->allowReassoc()) {
978*06c3fb27SDimitry Andric       LLVM_DEBUG(
979*06c3fb27SDimitry Andric           dbgs()
980*06c3fb27SDimitry Andric           << "the 'Reassoc' attribute is missing in the FastMath flags\n");
981*06c3fb27SDimitry Andric       return nullptr;
982*06c3fb27SDimitry Andric     }
983*06c3fb27SDimitry Andric   }
984*06c3fb27SDimitry Andric 
985*06c3fb27SDimitry Andric   // Collect multiplications and addend instructions from the given instruction
986*06c3fb27SDimitry Andric   // while traversing it operands. Additionally, verify that all instructions
987*06c3fb27SDimitry Andric   // have the same fast math flags.
988*06c3fb27SDimitry Andric   auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
989*06c3fb27SDimitry Andric                           std::list<Addend> &Addends) -> bool {
990*06c3fb27SDimitry Andric     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
991*06c3fb27SDimitry Andric     SmallPtrSet<Value *, 8> Visited;
992*06c3fb27SDimitry Andric     while (!Worklist.empty()) {
993*06c3fb27SDimitry Andric       auto [V, IsPositive] = Worklist.back();
994*06c3fb27SDimitry Andric       Worklist.pop_back();
995*06c3fb27SDimitry Andric       if (!Visited.insert(V).second)
996*06c3fb27SDimitry Andric         continue;
997*06c3fb27SDimitry Andric 
998*06c3fb27SDimitry Andric       Instruction *I = dyn_cast<Instruction>(V);
999*06c3fb27SDimitry Andric       if (!I) {
1000*06c3fb27SDimitry Andric         Addends.emplace_back(V, IsPositive);
1001*06c3fb27SDimitry Andric         continue;
1002*06c3fb27SDimitry Andric       }
1003*06c3fb27SDimitry Andric 
1004*06c3fb27SDimitry Andric       // If an instruction has more than one user, it indicates that it either
1005*06c3fb27SDimitry Andric       // has an external user, which will be later checked by the checkNodes
1006*06c3fb27SDimitry Andric       // function, or it is a subexpression utilized by multiple expressions. In
1007*06c3fb27SDimitry Andric       // the latter case, we will attempt to separately identify the complex
1008*06c3fb27SDimitry Andric       // operation from here in order to create a shared
1009*06c3fb27SDimitry Andric       // ComplexDeinterleavingCompositeNode.
1010*06c3fb27SDimitry Andric       if (I != Insn && I->getNumUses() > 1) {
1011*06c3fb27SDimitry Andric         LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1012*06c3fb27SDimitry Andric         Addends.emplace_back(I, IsPositive);
1013*06c3fb27SDimitry Andric         continue;
1014*06c3fb27SDimitry Andric       }
1015*06c3fb27SDimitry Andric       switch (I->getOpcode()) {
1016*06c3fb27SDimitry Andric       case Instruction::FAdd:
1017*06c3fb27SDimitry Andric       case Instruction::Add:
1018*06c3fb27SDimitry Andric         Worklist.emplace_back(I->getOperand(1), IsPositive);
1019*06c3fb27SDimitry Andric         Worklist.emplace_back(I->getOperand(0), IsPositive);
1020*06c3fb27SDimitry Andric         break;
1021*06c3fb27SDimitry Andric       case Instruction::FSub:
1022*06c3fb27SDimitry Andric         Worklist.emplace_back(I->getOperand(1), !IsPositive);
1023*06c3fb27SDimitry Andric         Worklist.emplace_back(I->getOperand(0), IsPositive);
1024*06c3fb27SDimitry Andric         break;
1025*06c3fb27SDimitry Andric       case Instruction::Sub:
1026*06c3fb27SDimitry Andric         if (isNeg(I)) {
1027*06c3fb27SDimitry Andric           Worklist.emplace_back(getNegOperand(I), !IsPositive);
1028*06c3fb27SDimitry Andric         } else {
1029*06c3fb27SDimitry Andric           Worklist.emplace_back(I->getOperand(1), !IsPositive);
1030*06c3fb27SDimitry Andric           Worklist.emplace_back(I->getOperand(0), IsPositive);
1031*06c3fb27SDimitry Andric         }
1032*06c3fb27SDimitry Andric         break;
1033*06c3fb27SDimitry Andric       case Instruction::FMul:
1034*06c3fb27SDimitry Andric       case Instruction::Mul: {
1035*06c3fb27SDimitry Andric         Value *A, *B;
1036*06c3fb27SDimitry Andric         if (isNeg(I->getOperand(0))) {
1037*06c3fb27SDimitry Andric           A = getNegOperand(I->getOperand(0));
1038*06c3fb27SDimitry Andric           IsPositive = !IsPositive;
1039*06c3fb27SDimitry Andric         } else {
1040*06c3fb27SDimitry Andric           A = I->getOperand(0);
1041*06c3fb27SDimitry Andric         }
1042*06c3fb27SDimitry Andric 
1043*06c3fb27SDimitry Andric         if (isNeg(I->getOperand(1))) {
1044*06c3fb27SDimitry Andric           B = getNegOperand(I->getOperand(1));
1045*06c3fb27SDimitry Andric           IsPositive = !IsPositive;
1046*06c3fb27SDimitry Andric         } else {
1047*06c3fb27SDimitry Andric           B = I->getOperand(1);
1048*06c3fb27SDimitry Andric         }
1049*06c3fb27SDimitry Andric         Muls.push_back(Product{A, B, IsPositive});
1050*06c3fb27SDimitry Andric         break;
1051*06c3fb27SDimitry Andric       }
1052*06c3fb27SDimitry Andric       case Instruction::FNeg:
1053*06c3fb27SDimitry Andric         Worklist.emplace_back(I->getOperand(0), !IsPositive);
1054*06c3fb27SDimitry Andric         break;
1055*06c3fb27SDimitry Andric       default:
1056*06c3fb27SDimitry Andric         Addends.emplace_back(I, IsPositive);
1057*06c3fb27SDimitry Andric         continue;
1058*06c3fb27SDimitry Andric       }
1059*06c3fb27SDimitry Andric 
1060*06c3fb27SDimitry Andric       if (Flags && I->getFastMathFlags() != *Flags) {
1061*06c3fb27SDimitry Andric         LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1062*06c3fb27SDimitry Andric                              "inconsistent with the root instructions' flags: "
1063*06c3fb27SDimitry Andric                           << *I << "\n");
1064*06c3fb27SDimitry Andric         return false;
1065*06c3fb27SDimitry Andric       }
1066*06c3fb27SDimitry Andric     }
1067*06c3fb27SDimitry Andric     return true;
1068*06c3fb27SDimitry Andric   };
1069*06c3fb27SDimitry Andric 
1070*06c3fb27SDimitry Andric   std::vector<Product> RealMuls, ImagMuls;
1071*06c3fb27SDimitry Andric   std::list<Addend> RealAddends, ImagAddends;
1072*06c3fb27SDimitry Andric   if (!Collect(Real, RealMuls, RealAddends) ||
1073*06c3fb27SDimitry Andric       !Collect(Imag, ImagMuls, ImagAddends))
1074*06c3fb27SDimitry Andric     return nullptr;
1075*06c3fb27SDimitry Andric 
1076*06c3fb27SDimitry Andric   if (RealAddends.size() != ImagAddends.size())
1077*06c3fb27SDimitry Andric     return nullptr;
1078*06c3fb27SDimitry Andric 
1079*06c3fb27SDimitry Andric   NodePtr FinalNode;
1080*06c3fb27SDimitry Andric   if (!RealMuls.empty() || !ImagMuls.empty()) {
1081*06c3fb27SDimitry Andric     // If there are multiplicands, extract positive addend and use it as an
1082*06c3fb27SDimitry Andric     // accumulator
1083*06c3fb27SDimitry Andric     FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1084*06c3fb27SDimitry Andric     FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1085*06c3fb27SDimitry Andric     if (!FinalNode)
1086*06c3fb27SDimitry Andric       return nullptr;
1087*06c3fb27SDimitry Andric   }
1088*06c3fb27SDimitry Andric 
1089*06c3fb27SDimitry Andric   // Identify and process remaining additions
1090*06c3fb27SDimitry Andric   if (!RealAddends.empty() || !ImagAddends.empty()) {
1091*06c3fb27SDimitry Andric     FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1092*06c3fb27SDimitry Andric     if (!FinalNode)
1093*06c3fb27SDimitry Andric       return nullptr;
1094*06c3fb27SDimitry Andric   }
1095*06c3fb27SDimitry Andric   assert(FinalNode && "FinalNode can not be nullptr here");
1096*06c3fb27SDimitry Andric   // Set the Real and Imag fields of the final node and submit it
1097*06c3fb27SDimitry Andric   FinalNode->Real = Real;
1098*06c3fb27SDimitry Andric   FinalNode->Imag = Imag;
1099*06c3fb27SDimitry Andric   submitCompositeNode(FinalNode);
1100*06c3fb27SDimitry Andric   return FinalNode;
1101*06c3fb27SDimitry Andric }
1102*06c3fb27SDimitry Andric 
1103*06c3fb27SDimitry Andric bool ComplexDeinterleavingGraph::collectPartialMuls(
1104*06c3fb27SDimitry Andric     const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1105*06c3fb27SDimitry Andric     std::vector<PartialMulCandidate> &PartialMulCandidates) {
1106*06c3fb27SDimitry Andric   // Helper function to extract a common operand from two products
1107*06c3fb27SDimitry Andric   auto FindCommonInstruction = [](const Product &Real,
1108*06c3fb27SDimitry Andric                                   const Product &Imag) -> Value * {
1109*06c3fb27SDimitry Andric     if (Real.Multiplicand == Imag.Multiplicand ||
1110*06c3fb27SDimitry Andric         Real.Multiplicand == Imag.Multiplier)
1111*06c3fb27SDimitry Andric       return Real.Multiplicand;
1112*06c3fb27SDimitry Andric 
1113*06c3fb27SDimitry Andric     if (Real.Multiplier == Imag.Multiplicand ||
1114*06c3fb27SDimitry Andric         Real.Multiplier == Imag.Multiplier)
1115*06c3fb27SDimitry Andric       return Real.Multiplier;
1116*06c3fb27SDimitry Andric 
1117*06c3fb27SDimitry Andric     return nullptr;
1118*06c3fb27SDimitry Andric   };
1119*06c3fb27SDimitry Andric 
1120*06c3fb27SDimitry Andric   // Iterating over real and imaginary multiplications to find common operands
1121*06c3fb27SDimitry Andric   // If a common operand is found, a partial multiplication candidate is created
1122*06c3fb27SDimitry Andric   // and added to the candidates vector The function returns false if no common
1123*06c3fb27SDimitry Andric   // operands are found for any product
1124*06c3fb27SDimitry Andric   for (unsigned i = 0; i < RealMuls.size(); ++i) {
1125*06c3fb27SDimitry Andric     bool FoundCommon = false;
1126*06c3fb27SDimitry Andric     for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1127*06c3fb27SDimitry Andric       auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1128*06c3fb27SDimitry Andric       if (!Common)
1129*06c3fb27SDimitry Andric         continue;
1130*06c3fb27SDimitry Andric 
1131*06c3fb27SDimitry Andric       auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1132*06c3fb27SDimitry Andric                                                    : RealMuls[i].Multiplicand;
1133*06c3fb27SDimitry Andric       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1134*06c3fb27SDimitry Andric                                                    : ImagMuls[j].Multiplicand;
1135*06c3fb27SDimitry Andric 
1136*06c3fb27SDimitry Andric       auto Node = identifyNode(A, B);
1137*06c3fb27SDimitry Andric       if (Node) {
1138*06c3fb27SDimitry Andric         FoundCommon = true;
1139*06c3fb27SDimitry Andric         PartialMulCandidates.push_back({Common, Node, i, j, false});
1140*06c3fb27SDimitry Andric       }
1141*06c3fb27SDimitry Andric 
1142*06c3fb27SDimitry Andric       Node = identifyNode(B, A);
1143*06c3fb27SDimitry Andric       if (Node) {
1144*06c3fb27SDimitry Andric         FoundCommon = true;
1145*06c3fb27SDimitry Andric         PartialMulCandidates.push_back({Common, Node, i, j, true});
1146*06c3fb27SDimitry Andric       }
1147*06c3fb27SDimitry Andric     }
1148*06c3fb27SDimitry Andric     if (!FoundCommon)
1149*06c3fb27SDimitry Andric       return false;
1150*06c3fb27SDimitry Andric   }
1151*06c3fb27SDimitry Andric   return true;
1152*06c3fb27SDimitry Andric }
1153*06c3fb27SDimitry Andric 
1154*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
1155*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyMultiplications(
1156*06c3fb27SDimitry Andric     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1157*06c3fb27SDimitry Andric     NodePtr Accumulator = nullptr) {
1158*06c3fb27SDimitry Andric   if (RealMuls.size() != ImagMuls.size())
1159*06c3fb27SDimitry Andric     return nullptr;
1160*06c3fb27SDimitry Andric 
1161*06c3fb27SDimitry Andric   std::vector<PartialMulCandidate> Info;
1162*06c3fb27SDimitry Andric   if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1163*06c3fb27SDimitry Andric     return nullptr;
1164*06c3fb27SDimitry Andric 
1165*06c3fb27SDimitry Andric   // Map to store common instruction to node pointers
1166*06c3fb27SDimitry Andric   std::map<Value *, NodePtr> CommonToNode;
1167*06c3fb27SDimitry Andric   std::vector<bool> Processed(Info.size(), false);
1168*06c3fb27SDimitry Andric   for (unsigned I = 0; I < Info.size(); ++I) {
1169*06c3fb27SDimitry Andric     if (Processed[I])
1170*06c3fb27SDimitry Andric       continue;
1171*06c3fb27SDimitry Andric 
1172*06c3fb27SDimitry Andric     PartialMulCandidate &InfoA = Info[I];
1173*06c3fb27SDimitry Andric     for (unsigned J = I + 1; J < Info.size(); ++J) {
1174*06c3fb27SDimitry Andric       if (Processed[J])
1175*06c3fb27SDimitry Andric         continue;
1176*06c3fb27SDimitry Andric 
1177*06c3fb27SDimitry Andric       PartialMulCandidate &InfoB = Info[J];
1178*06c3fb27SDimitry Andric       auto *InfoReal = &InfoA;
1179*06c3fb27SDimitry Andric       auto *InfoImag = &InfoB;
1180*06c3fb27SDimitry Andric 
1181*06c3fb27SDimitry Andric       auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1182*06c3fb27SDimitry Andric       if (!NodeFromCommon) {
1183*06c3fb27SDimitry Andric         std::swap(InfoReal, InfoImag);
1184*06c3fb27SDimitry Andric         NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1185*06c3fb27SDimitry Andric       }
1186*06c3fb27SDimitry Andric       if (!NodeFromCommon)
1187*06c3fb27SDimitry Andric         continue;
1188*06c3fb27SDimitry Andric 
1189*06c3fb27SDimitry Andric       CommonToNode[InfoReal->Common] = NodeFromCommon;
1190*06c3fb27SDimitry Andric       CommonToNode[InfoImag->Common] = NodeFromCommon;
1191*06c3fb27SDimitry Andric       Processed[I] = true;
1192*06c3fb27SDimitry Andric       Processed[J] = true;
1193*06c3fb27SDimitry Andric     }
1194*06c3fb27SDimitry Andric   }
1195*06c3fb27SDimitry Andric 
1196*06c3fb27SDimitry Andric   std::vector<bool> ProcessedReal(RealMuls.size(), false);
1197*06c3fb27SDimitry Andric   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1198*06c3fb27SDimitry Andric   NodePtr Result = Accumulator;
1199*06c3fb27SDimitry Andric   for (auto &PMI : Info) {
1200*06c3fb27SDimitry Andric     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1201*06c3fb27SDimitry Andric       continue;
1202*06c3fb27SDimitry Andric 
1203*06c3fb27SDimitry Andric     auto It = CommonToNode.find(PMI.Common);
1204*06c3fb27SDimitry Andric     // TODO: Process independent complex multiplications. Cases like this:
1205*06c3fb27SDimitry Andric     //  A.real() * B where both A and B are complex numbers.
1206*06c3fb27SDimitry Andric     if (It == CommonToNode.end()) {
1207*06c3fb27SDimitry Andric       LLVM_DEBUG({
1208*06c3fb27SDimitry Andric         dbgs() << "Unprocessed independent partial multiplication:\n";
1209*06c3fb27SDimitry Andric         for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1210*06c3fb27SDimitry Andric           dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1211*06c3fb27SDimitry Andric                            << " multiplied by " << *Mul->Multiplicand << "\n";
1212*06c3fb27SDimitry Andric       });
1213*06c3fb27SDimitry Andric       return nullptr;
1214*06c3fb27SDimitry Andric     }
1215*06c3fb27SDimitry Andric 
1216*06c3fb27SDimitry Andric     auto &RealMul = RealMuls[PMI.RealIdx];
1217*06c3fb27SDimitry Andric     auto &ImagMul = ImagMuls[PMI.ImagIdx];
1218*06c3fb27SDimitry Andric 
1219*06c3fb27SDimitry Andric     auto NodeA = It->second;
1220*06c3fb27SDimitry Andric     auto NodeB = PMI.Node;
1221*06c3fb27SDimitry Andric     auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1222*06c3fb27SDimitry Andric     // The following table illustrates the relationship between multiplications
1223*06c3fb27SDimitry Andric     // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1224*06c3fb27SDimitry Andric     // can see:
1225*06c3fb27SDimitry Andric     //
1226*06c3fb27SDimitry Andric     // Rotation |   Real |   Imag |
1227*06c3fb27SDimitry Andric     // ---------+--------+--------+
1228*06c3fb27SDimitry Andric     //        0 |  x * u |  x * v |
1229*06c3fb27SDimitry Andric     //       90 | -y * v |  y * u |
1230*06c3fb27SDimitry Andric     //      180 | -x * u | -x * v |
1231*06c3fb27SDimitry Andric     //      270 |  y * v | -y * u |
1232*06c3fb27SDimitry Andric     //
1233*06c3fb27SDimitry Andric     // Check if the candidate can indeed be represented by partial
1234*06c3fb27SDimitry Andric     // multiplication
1235*06c3fb27SDimitry Andric     // TODO: Add support for multiplication by complex one
1236*06c3fb27SDimitry Andric     if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1237*06c3fb27SDimitry Andric         (!IsMultiplicandReal && !PMI.IsNodeInverted))
1238*06c3fb27SDimitry Andric       continue;
1239*06c3fb27SDimitry Andric 
1240*06c3fb27SDimitry Andric     // Determine the rotation based on the multiplications
1241*06c3fb27SDimitry Andric     ComplexDeinterleavingRotation Rotation;
1242*06c3fb27SDimitry Andric     if (IsMultiplicandReal) {
1243*06c3fb27SDimitry Andric       // Detect 0 and 180 degrees rotation
1244*06c3fb27SDimitry Andric       if (RealMul.IsPositive && ImagMul.IsPositive)
1245*06c3fb27SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1246*06c3fb27SDimitry Andric       else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1247*06c3fb27SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1248*06c3fb27SDimitry Andric       else
1249*06c3fb27SDimitry Andric         continue;
1250*06c3fb27SDimitry Andric 
1251*06c3fb27SDimitry Andric     } else {
1252*06c3fb27SDimitry Andric       // Detect 90 and 270 degrees rotation
1253*06c3fb27SDimitry Andric       if (!RealMul.IsPositive && ImagMul.IsPositive)
1254*06c3fb27SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1255*06c3fb27SDimitry Andric       else if (RealMul.IsPositive && !ImagMul.IsPositive)
1256*06c3fb27SDimitry Andric         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1257*06c3fb27SDimitry Andric       else
1258*06c3fb27SDimitry Andric         continue;
1259*06c3fb27SDimitry Andric     }
1260*06c3fb27SDimitry Andric 
1261*06c3fb27SDimitry Andric     LLVM_DEBUG({
1262*06c3fb27SDimitry Andric       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1263*06c3fb27SDimitry Andric       dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1264*06c3fb27SDimitry Andric       dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1265*06c3fb27SDimitry Andric       dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1266*06c3fb27SDimitry Andric       dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1267*06c3fb27SDimitry Andric       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1268*06c3fb27SDimitry Andric     });
1269*06c3fb27SDimitry Andric 
1270*06c3fb27SDimitry Andric     NodePtr NodeMul = prepareCompositeNode(
1271*06c3fb27SDimitry Andric         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1272*06c3fb27SDimitry Andric     NodeMul->Rotation = Rotation;
1273*06c3fb27SDimitry Andric     NodeMul->addOperand(NodeA);
1274*06c3fb27SDimitry Andric     NodeMul->addOperand(NodeB);
1275*06c3fb27SDimitry Andric     if (Result)
1276*06c3fb27SDimitry Andric       NodeMul->addOperand(Result);
1277*06c3fb27SDimitry Andric     submitCompositeNode(NodeMul);
1278*06c3fb27SDimitry Andric     Result = NodeMul;
1279*06c3fb27SDimitry Andric     ProcessedReal[PMI.RealIdx] = true;
1280*06c3fb27SDimitry Andric     ProcessedImag[PMI.ImagIdx] = true;
1281*06c3fb27SDimitry Andric   }
1282*06c3fb27SDimitry Andric 
1283*06c3fb27SDimitry Andric   // Ensure all products have been processed, if not return nullptr.
1284*06c3fb27SDimitry Andric   if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1285*06c3fb27SDimitry Andric       !all_of(ProcessedImag, [](bool V) { return V; })) {
1286*06c3fb27SDimitry Andric 
1287*06c3fb27SDimitry Andric     // Dump debug information about which partial multiplications are not
1288*06c3fb27SDimitry Andric     // processed.
1289*06c3fb27SDimitry Andric     LLVM_DEBUG({
1290*06c3fb27SDimitry Andric       dbgs() << "Unprocessed products (Real):\n";
1291*06c3fb27SDimitry Andric       for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1292*06c3fb27SDimitry Andric         if (!ProcessedReal[i])
1293*06c3fb27SDimitry Andric           dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1294*06c3fb27SDimitry Andric                            << *RealMuls[i].Multiplier << " multiplied by "
1295*06c3fb27SDimitry Andric                            << *RealMuls[i].Multiplicand << "\n";
1296*06c3fb27SDimitry Andric       }
1297*06c3fb27SDimitry Andric       dbgs() << "Unprocessed products (Imag):\n";
1298*06c3fb27SDimitry Andric       for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1299*06c3fb27SDimitry Andric         if (!ProcessedImag[i])
1300*06c3fb27SDimitry Andric           dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1301*06c3fb27SDimitry Andric                            << *ImagMuls[i].Multiplier << " multiplied by "
1302*06c3fb27SDimitry Andric                            << *ImagMuls[i].Multiplicand << "\n";
1303*06c3fb27SDimitry Andric       }
1304*06c3fb27SDimitry Andric     });
1305*06c3fb27SDimitry Andric     return nullptr;
1306*06c3fb27SDimitry Andric   }
1307*06c3fb27SDimitry Andric 
1308*06c3fb27SDimitry Andric   return Result;
1309*06c3fb27SDimitry Andric }
1310*06c3fb27SDimitry Andric 
1311*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
1312*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyAdditions(
1313*06c3fb27SDimitry Andric     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1314*06c3fb27SDimitry Andric     std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1315*06c3fb27SDimitry Andric   if (RealAddends.size() != ImagAddends.size())
1316*06c3fb27SDimitry Andric     return nullptr;
1317*06c3fb27SDimitry Andric 
1318*06c3fb27SDimitry Andric   NodePtr Result;
1319*06c3fb27SDimitry Andric   // If we have accumulator use it as first addend
1320*06c3fb27SDimitry Andric   if (Accumulator)
1321*06c3fb27SDimitry Andric     Result = Accumulator;
1322*06c3fb27SDimitry Andric   // Otherwise find an element with both positive real and imaginary parts.
1323*06c3fb27SDimitry Andric   else
1324*06c3fb27SDimitry Andric     Result = extractPositiveAddend(RealAddends, ImagAddends);
1325*06c3fb27SDimitry Andric 
1326*06c3fb27SDimitry Andric   if (!Result)
1327*06c3fb27SDimitry Andric     return nullptr;
1328*06c3fb27SDimitry Andric 
1329*06c3fb27SDimitry Andric   while (!RealAddends.empty()) {
1330*06c3fb27SDimitry Andric     auto ItR = RealAddends.begin();
1331*06c3fb27SDimitry Andric     auto [R, IsPositiveR] = *ItR;
1332*06c3fb27SDimitry Andric 
1333*06c3fb27SDimitry Andric     bool FoundImag = false;
1334*06c3fb27SDimitry Andric     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1335*06c3fb27SDimitry Andric       auto [I, IsPositiveI] = *ItI;
1336*06c3fb27SDimitry Andric       ComplexDeinterleavingRotation Rotation;
1337*06c3fb27SDimitry Andric       if (IsPositiveR && IsPositiveI)
1338*06c3fb27SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_0;
1339*06c3fb27SDimitry Andric       else if (!IsPositiveR && IsPositiveI)
1340*06c3fb27SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_90;
1341*06c3fb27SDimitry Andric       else if (!IsPositiveR && !IsPositiveI)
1342*06c3fb27SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_180;
1343*06c3fb27SDimitry Andric       else
1344*06c3fb27SDimitry Andric         Rotation = ComplexDeinterleavingRotation::Rotation_270;
1345*06c3fb27SDimitry Andric 
1346*06c3fb27SDimitry Andric       NodePtr AddNode;
1347*06c3fb27SDimitry Andric       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1348*06c3fb27SDimitry Andric           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1349*06c3fb27SDimitry Andric         AddNode = identifyNode(R, I);
1350*06c3fb27SDimitry Andric       } else {
1351*06c3fb27SDimitry Andric         AddNode = identifyNode(I, R);
1352*06c3fb27SDimitry Andric       }
1353*06c3fb27SDimitry Andric       if (AddNode) {
1354*06c3fb27SDimitry Andric         LLVM_DEBUG({
1355*06c3fb27SDimitry Andric           dbgs() << "Identified addition:\n";
1356*06c3fb27SDimitry Andric           dbgs().indent(4) << "X: " << *R << "\n";
1357*06c3fb27SDimitry Andric           dbgs().indent(4) << "Y: " << *I << "\n";
1358*06c3fb27SDimitry Andric           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1359*06c3fb27SDimitry Andric         });
1360*06c3fb27SDimitry Andric 
1361*06c3fb27SDimitry Andric         NodePtr TmpNode;
1362*06c3fb27SDimitry Andric         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1363*06c3fb27SDimitry Andric           TmpNode = prepareCompositeNode(
1364*06c3fb27SDimitry Andric               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1365*06c3fb27SDimitry Andric           if (Flags) {
1366*06c3fb27SDimitry Andric             TmpNode->Opcode = Instruction::FAdd;
1367*06c3fb27SDimitry Andric             TmpNode->Flags = *Flags;
1368*06c3fb27SDimitry Andric           } else {
1369*06c3fb27SDimitry Andric             TmpNode->Opcode = Instruction::Add;
1370*06c3fb27SDimitry Andric           }
1371*06c3fb27SDimitry Andric         } else if (Rotation ==
1372*06c3fb27SDimitry Andric                    llvm::ComplexDeinterleavingRotation::Rotation_180) {
1373*06c3fb27SDimitry Andric           TmpNode = prepareCompositeNode(
1374*06c3fb27SDimitry Andric               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1375*06c3fb27SDimitry Andric           if (Flags) {
1376*06c3fb27SDimitry Andric             TmpNode->Opcode = Instruction::FSub;
1377*06c3fb27SDimitry Andric             TmpNode->Flags = *Flags;
1378*06c3fb27SDimitry Andric           } else {
1379*06c3fb27SDimitry Andric             TmpNode->Opcode = Instruction::Sub;
1380*06c3fb27SDimitry Andric           }
1381*06c3fb27SDimitry Andric         } else {
1382*06c3fb27SDimitry Andric           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1383*06c3fb27SDimitry Andric                                          nullptr, nullptr);
1384*06c3fb27SDimitry Andric           TmpNode->Rotation = Rotation;
1385*06c3fb27SDimitry Andric         }
1386*06c3fb27SDimitry Andric 
1387*06c3fb27SDimitry Andric         TmpNode->addOperand(Result);
1388*06c3fb27SDimitry Andric         TmpNode->addOperand(AddNode);
1389*06c3fb27SDimitry Andric         submitCompositeNode(TmpNode);
1390*06c3fb27SDimitry Andric         Result = TmpNode;
1391*06c3fb27SDimitry Andric         RealAddends.erase(ItR);
1392*06c3fb27SDimitry Andric         ImagAddends.erase(ItI);
1393*06c3fb27SDimitry Andric         FoundImag = true;
1394*06c3fb27SDimitry Andric         break;
1395*06c3fb27SDimitry Andric       }
1396*06c3fb27SDimitry Andric     }
1397*06c3fb27SDimitry Andric     if (!FoundImag)
1398*06c3fb27SDimitry Andric       return nullptr;
1399*06c3fb27SDimitry Andric   }
1400*06c3fb27SDimitry Andric   return Result;
1401*06c3fb27SDimitry Andric }
1402*06c3fb27SDimitry Andric 
1403*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
1404*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::extractPositiveAddend(
1405*06c3fb27SDimitry Andric     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1406*06c3fb27SDimitry Andric   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1407*06c3fb27SDimitry Andric     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1408*06c3fb27SDimitry Andric       auto [R, IsPositiveR] = *ItR;
1409*06c3fb27SDimitry Andric       auto [I, IsPositiveI] = *ItI;
1410*06c3fb27SDimitry Andric       if (IsPositiveR && IsPositiveI) {
1411*06c3fb27SDimitry Andric         auto Result = identifyNode(R, I);
1412*06c3fb27SDimitry Andric         if (Result) {
1413*06c3fb27SDimitry Andric           RealAddends.erase(ItR);
1414*06c3fb27SDimitry Andric           ImagAddends.erase(ItI);
1415*06c3fb27SDimitry Andric           return Result;
1416*06c3fb27SDimitry Andric         }
1417*06c3fb27SDimitry Andric       }
1418*06c3fb27SDimitry Andric     }
1419*06c3fb27SDimitry Andric   }
1420*06c3fb27SDimitry Andric   return nullptr;
1421*06c3fb27SDimitry Andric }
1422*06c3fb27SDimitry Andric 
1423*06c3fb27SDimitry Andric bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1424*06c3fb27SDimitry Andric   // This potential root instruction might already have been recognized as
1425*06c3fb27SDimitry Andric   // reduction. Because RootToNode maps both Real and Imaginary parts to
1426*06c3fb27SDimitry Andric   // CompositeNode we should choose only one either Real or Imag instruction to
1427*06c3fb27SDimitry Andric   // use as an anchor for generating complex instruction.
1428*06c3fb27SDimitry Andric   auto It = RootToNode.find(RootI);
1429*06c3fb27SDimitry Andric   if (It != RootToNode.end() && It->second->Real == RootI) {
1430*06c3fb27SDimitry Andric     OrderedRoots.push_back(RootI);
1431*06c3fb27SDimitry Andric     return true;
1432*06c3fb27SDimitry Andric   }
1433*06c3fb27SDimitry Andric 
1434*06c3fb27SDimitry Andric   auto RootNode = identifyRoot(RootI);
1435*06c3fb27SDimitry Andric   if (!RootNode)
1436*06c3fb27SDimitry Andric     return false;
1437*06c3fb27SDimitry Andric 
1438*06c3fb27SDimitry Andric   LLVM_DEBUG({
1439*06c3fb27SDimitry Andric     Function *F = RootI->getFunction();
1440*06c3fb27SDimitry Andric     BasicBlock *B = RootI->getParent();
1441*06c3fb27SDimitry Andric     dbgs() << "Complex deinterleaving graph for " << F->getName()
1442*06c3fb27SDimitry Andric            << "::" << B->getName() << ".\n";
1443*06c3fb27SDimitry Andric     dump(dbgs());
1444*06c3fb27SDimitry Andric     dbgs() << "\n";
1445*06c3fb27SDimitry Andric   });
1446*06c3fb27SDimitry Andric   RootToNode[RootI] = RootNode;
1447*06c3fb27SDimitry Andric   OrderedRoots.push_back(RootI);
1448*06c3fb27SDimitry Andric   return true;
1449*06c3fb27SDimitry Andric }
1450*06c3fb27SDimitry Andric 
1451*06c3fb27SDimitry Andric bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1452*06c3fb27SDimitry Andric   bool FoundPotentialReduction = false;
1453*06c3fb27SDimitry Andric 
1454*06c3fb27SDimitry Andric   auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1455*06c3fb27SDimitry Andric   if (!Br || Br->getNumSuccessors() != 2)
1456*06c3fb27SDimitry Andric     return false;
1457*06c3fb27SDimitry Andric 
1458*06c3fb27SDimitry Andric   // Identify simple one-block loop
1459*06c3fb27SDimitry Andric   if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1460*06c3fb27SDimitry Andric     return false;
1461*06c3fb27SDimitry Andric 
1462*06c3fb27SDimitry Andric   SmallVector<PHINode *> PHIs;
1463*06c3fb27SDimitry Andric   for (auto &PHI : B->phis()) {
1464*06c3fb27SDimitry Andric     if (PHI.getNumIncomingValues() != 2)
1465*06c3fb27SDimitry Andric       continue;
1466*06c3fb27SDimitry Andric 
1467*06c3fb27SDimitry Andric     if (!PHI.getType()->isVectorTy())
1468*06c3fb27SDimitry Andric       continue;
1469*06c3fb27SDimitry Andric 
1470*06c3fb27SDimitry Andric     auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1471*06c3fb27SDimitry Andric     if (!ReductionOp)
1472*06c3fb27SDimitry Andric       continue;
1473*06c3fb27SDimitry Andric 
1474*06c3fb27SDimitry Andric     // Check if final instruction is reduced outside of current block
1475*06c3fb27SDimitry Andric     Instruction *FinalReduction = nullptr;
1476*06c3fb27SDimitry Andric     auto NumUsers = 0u;
1477*06c3fb27SDimitry Andric     for (auto *U : ReductionOp->users()) {
1478*06c3fb27SDimitry Andric       ++NumUsers;
1479*06c3fb27SDimitry Andric       if (U == &PHI)
1480*06c3fb27SDimitry Andric         continue;
1481*06c3fb27SDimitry Andric       FinalReduction = dyn_cast<Instruction>(U);
1482*06c3fb27SDimitry Andric     }
1483*06c3fb27SDimitry Andric 
1484*06c3fb27SDimitry Andric     if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1485*06c3fb27SDimitry Andric         isa<PHINode>(FinalReduction))
1486*06c3fb27SDimitry Andric       continue;
1487*06c3fb27SDimitry Andric 
1488*06c3fb27SDimitry Andric     ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1489*06c3fb27SDimitry Andric     BackEdge = B;
1490*06c3fb27SDimitry Andric     auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1491*06c3fb27SDimitry Andric     auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1492*06c3fb27SDimitry Andric     Incoming = PHI.getIncomingBlock(IncomingIdx);
1493*06c3fb27SDimitry Andric     FoundPotentialReduction = true;
1494*06c3fb27SDimitry Andric 
1495*06c3fb27SDimitry Andric     // If the initial value of PHINode is an Instruction, consider it a leaf
1496*06c3fb27SDimitry Andric     // value of a complex deinterleaving graph.
1497*06c3fb27SDimitry Andric     if (auto *InitPHI =
1498*06c3fb27SDimitry Andric             dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1499*06c3fb27SDimitry Andric       FinalInstructions.insert(InitPHI);
1500*06c3fb27SDimitry Andric   }
1501*06c3fb27SDimitry Andric   return FoundPotentialReduction;
1502*06c3fb27SDimitry Andric }
1503*06c3fb27SDimitry Andric 
1504*06c3fb27SDimitry Andric void ComplexDeinterleavingGraph::identifyReductionNodes() {
1505*06c3fb27SDimitry Andric   SmallVector<bool> Processed(ReductionInfo.size(), false);
1506*06c3fb27SDimitry Andric   SmallVector<Instruction *> OperationInstruction;
1507*06c3fb27SDimitry Andric   for (auto &P : ReductionInfo)
1508*06c3fb27SDimitry Andric     OperationInstruction.push_back(P.first);
1509*06c3fb27SDimitry Andric 
1510*06c3fb27SDimitry Andric   // Identify a complex computation by evaluating two reduction operations that
1511*06c3fb27SDimitry Andric   // potentially could be involved
1512*06c3fb27SDimitry Andric   for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1513*06c3fb27SDimitry Andric     if (Processed[i])
1514*06c3fb27SDimitry Andric       continue;
1515*06c3fb27SDimitry Andric     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1516*06c3fb27SDimitry Andric       if (Processed[j])
1517*06c3fb27SDimitry Andric         continue;
1518*06c3fb27SDimitry Andric 
1519*06c3fb27SDimitry Andric       auto *Real = OperationInstruction[i];
1520*06c3fb27SDimitry Andric       auto *Imag = OperationInstruction[j];
1521*06c3fb27SDimitry Andric       if (Real->getType() != Imag->getType())
1522*06c3fb27SDimitry Andric         continue;
1523*06c3fb27SDimitry Andric 
1524*06c3fb27SDimitry Andric       RealPHI = ReductionInfo[Real].first;
1525*06c3fb27SDimitry Andric       ImagPHI = ReductionInfo[Imag].first;
1526*06c3fb27SDimitry Andric       PHIsFound = false;
1527*06c3fb27SDimitry Andric       auto Node = identifyNode(Real, Imag);
1528*06c3fb27SDimitry Andric       if (!Node) {
1529*06c3fb27SDimitry Andric         std::swap(Real, Imag);
1530*06c3fb27SDimitry Andric         std::swap(RealPHI, ImagPHI);
1531*06c3fb27SDimitry Andric         Node = identifyNode(Real, Imag);
1532*06c3fb27SDimitry Andric       }
1533*06c3fb27SDimitry Andric 
1534*06c3fb27SDimitry Andric       // If a node is identified and reduction PHINode is used in the chain of
1535*06c3fb27SDimitry Andric       // operations, mark its operation instructions as used to prevent
1536*06c3fb27SDimitry Andric       // re-identification and attach the node to the real part
1537*06c3fb27SDimitry Andric       if (Node && PHIsFound) {
1538*06c3fb27SDimitry Andric         LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1539*06c3fb27SDimitry Andric                           << *Real << " / " << *Imag << "\n");
1540*06c3fb27SDimitry Andric         Processed[i] = true;
1541*06c3fb27SDimitry Andric         Processed[j] = true;
1542*06c3fb27SDimitry Andric         auto RootNode = prepareCompositeNode(
1543*06c3fb27SDimitry Andric             ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1544*06c3fb27SDimitry Andric         RootNode->addOperand(Node);
1545*06c3fb27SDimitry Andric         RootToNode[Real] = RootNode;
1546*06c3fb27SDimitry Andric         RootToNode[Imag] = RootNode;
1547*06c3fb27SDimitry Andric         submitCompositeNode(RootNode);
1548*06c3fb27SDimitry Andric         break;
1549*06c3fb27SDimitry Andric       }
1550*06c3fb27SDimitry Andric     }
1551*06c3fb27SDimitry Andric   }
1552*06c3fb27SDimitry Andric 
1553*06c3fb27SDimitry Andric   RealPHI = nullptr;
1554*06c3fb27SDimitry Andric   ImagPHI = nullptr;
1555*06c3fb27SDimitry Andric }
1556*06c3fb27SDimitry Andric 
1557*06c3fb27SDimitry Andric bool ComplexDeinterleavingGraph::checkNodes() {
1558*06c3fb27SDimitry Andric   // Collect all instructions from roots to leaves
1559*06c3fb27SDimitry Andric   SmallPtrSet<Instruction *, 16> AllInstructions;
1560*06c3fb27SDimitry Andric   SmallVector<Instruction *, 8> Worklist;
1561*06c3fb27SDimitry Andric   for (auto &Pair : RootToNode)
1562*06c3fb27SDimitry Andric     Worklist.push_back(Pair.first);
1563*06c3fb27SDimitry Andric 
1564*06c3fb27SDimitry Andric   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1565*06c3fb27SDimitry Andric   // chains
1566*06c3fb27SDimitry Andric   while (!Worklist.empty()) {
1567*06c3fb27SDimitry Andric     auto *I = Worklist.back();
1568*06c3fb27SDimitry Andric     Worklist.pop_back();
1569*06c3fb27SDimitry Andric 
1570*06c3fb27SDimitry Andric     if (!AllInstructions.insert(I).second)
1571*06c3fb27SDimitry Andric       continue;
1572*06c3fb27SDimitry Andric 
1573*06c3fb27SDimitry Andric     for (Value *Op : I->operands()) {
1574*06c3fb27SDimitry Andric       if (auto *OpI = dyn_cast<Instruction>(Op)) {
1575*06c3fb27SDimitry Andric         if (!FinalInstructions.count(I))
1576*06c3fb27SDimitry Andric           Worklist.emplace_back(OpI);
1577*06c3fb27SDimitry Andric       }
1578*06c3fb27SDimitry Andric     }
1579*06c3fb27SDimitry Andric   }
1580*06c3fb27SDimitry Andric 
1581*06c3fb27SDimitry Andric   // Find instructions that have users outside of chain
1582*06c3fb27SDimitry Andric   SmallVector<Instruction *, 2> OuterInstructions;
1583*06c3fb27SDimitry Andric   for (auto *I : AllInstructions) {
1584*06c3fb27SDimitry Andric     // Skip root nodes
1585*06c3fb27SDimitry Andric     if (RootToNode.count(I))
1586*06c3fb27SDimitry Andric       continue;
1587*06c3fb27SDimitry Andric 
1588*06c3fb27SDimitry Andric     for (User *U : I->users()) {
1589*06c3fb27SDimitry Andric       if (AllInstructions.count(cast<Instruction>(U)))
1590*06c3fb27SDimitry Andric         continue;
1591*06c3fb27SDimitry Andric 
1592*06c3fb27SDimitry Andric       // Found an instruction that is not used by XCMLA/XCADD chain
1593*06c3fb27SDimitry Andric       Worklist.emplace_back(I);
1594*06c3fb27SDimitry Andric       break;
1595*06c3fb27SDimitry Andric     }
1596*06c3fb27SDimitry Andric   }
1597*06c3fb27SDimitry Andric 
1598*06c3fb27SDimitry Andric   // If any instructions are found to be used outside, find and remove roots
1599*06c3fb27SDimitry Andric   // that somehow connect to those instructions.
1600*06c3fb27SDimitry Andric   SmallPtrSet<Instruction *, 16> Visited;
1601*06c3fb27SDimitry Andric   while (!Worklist.empty()) {
1602*06c3fb27SDimitry Andric     auto *I = Worklist.back();
1603*06c3fb27SDimitry Andric     Worklist.pop_back();
1604*06c3fb27SDimitry Andric     if (!Visited.insert(I).second)
1605*06c3fb27SDimitry Andric       continue;
1606*06c3fb27SDimitry Andric 
1607*06c3fb27SDimitry Andric     // Found an impacted root node. Removing it from the nodes to be
1608*06c3fb27SDimitry Andric     // deinterleaved
1609*06c3fb27SDimitry Andric     if (RootToNode.count(I)) {
1610*06c3fb27SDimitry Andric       LLVM_DEBUG(dbgs() << "Instruction " << *I
1611*06c3fb27SDimitry Andric                         << " could be deinterleaved but its chain of complex "
1612*06c3fb27SDimitry Andric                            "operations have an outside user\n");
1613*06c3fb27SDimitry Andric       RootToNode.erase(I);
1614*06c3fb27SDimitry Andric     }
1615*06c3fb27SDimitry Andric 
1616*06c3fb27SDimitry Andric     if (!AllInstructions.count(I) || FinalInstructions.count(I))
1617*06c3fb27SDimitry Andric       continue;
1618*06c3fb27SDimitry Andric 
1619*06c3fb27SDimitry Andric     for (User *U : I->users())
1620*06c3fb27SDimitry Andric       Worklist.emplace_back(cast<Instruction>(U));
1621*06c3fb27SDimitry Andric 
1622*06c3fb27SDimitry Andric     for (Value *Op : I->operands()) {
1623*06c3fb27SDimitry Andric       if (auto *OpI = dyn_cast<Instruction>(Op))
1624*06c3fb27SDimitry Andric         Worklist.emplace_back(OpI);
1625*06c3fb27SDimitry Andric     }
1626*06c3fb27SDimitry Andric   }
1627*06c3fb27SDimitry Andric   return !RootToNode.empty();
1628*06c3fb27SDimitry Andric }
1629*06c3fb27SDimitry Andric 
1630*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
1631*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1632*06c3fb27SDimitry Andric   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1633*06c3fb27SDimitry Andric     if (Intrinsic->getIntrinsicID() !=
1634*06c3fb27SDimitry Andric         Intrinsic::experimental_vector_interleave2)
1635*06c3fb27SDimitry Andric       return nullptr;
1636*06c3fb27SDimitry Andric 
1637*06c3fb27SDimitry Andric     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1638*06c3fb27SDimitry Andric     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1639*06c3fb27SDimitry Andric     if (!Real || !Imag)
1640*06c3fb27SDimitry Andric       return nullptr;
1641*06c3fb27SDimitry Andric 
1642*06c3fb27SDimitry Andric     return identifyNode(Real, Imag);
1643*06c3fb27SDimitry Andric   }
1644*06c3fb27SDimitry Andric 
1645*06c3fb27SDimitry Andric   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1646*06c3fb27SDimitry Andric   if (!SVI)
1647*06c3fb27SDimitry Andric     return nullptr;
1648*06c3fb27SDimitry Andric 
1649*06c3fb27SDimitry Andric   // Look for a shufflevector that takes separate vectors of the real and
1650*06c3fb27SDimitry Andric   // imaginary components and recombines them into a single vector.
1651*06c3fb27SDimitry Andric   if (!isInterleavingMask(SVI->getShuffleMask()))
1652*06c3fb27SDimitry Andric     return nullptr;
1653*06c3fb27SDimitry Andric 
1654*06c3fb27SDimitry Andric   Instruction *Real;
1655*06c3fb27SDimitry Andric   Instruction *Imag;
1656*06c3fb27SDimitry Andric   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1657*06c3fb27SDimitry Andric     return nullptr;
1658*06c3fb27SDimitry Andric 
1659*06c3fb27SDimitry Andric   return identifyNode(Real, Imag);
1660*06c3fb27SDimitry Andric }
1661*06c3fb27SDimitry Andric 
1662*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
1663*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1664*06c3fb27SDimitry Andric                                                  Instruction *Imag) {
1665*06c3fb27SDimitry Andric   Instruction *I = nullptr;
1666*06c3fb27SDimitry Andric   Value *FinalValue = nullptr;
1667*06c3fb27SDimitry Andric   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1668*06c3fb27SDimitry Andric       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1669*06c3fb27SDimitry Andric       match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1670*06c3fb27SDimitry Andric                    m_Value(FinalValue)))) {
1671*06c3fb27SDimitry Andric     NodePtr PlaceholderNode = prepareCompositeNode(
1672*06c3fb27SDimitry Andric         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1673*06c3fb27SDimitry Andric     PlaceholderNode->ReplacementNode = FinalValue;
1674*06c3fb27SDimitry Andric     FinalInstructions.insert(Real);
1675*06c3fb27SDimitry Andric     FinalInstructions.insert(Imag);
1676*06c3fb27SDimitry Andric     return submitCompositeNode(PlaceholderNode);
1677*06c3fb27SDimitry Andric   }
1678*06c3fb27SDimitry Andric 
1679bdd1243dSDimitry Andric   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1680bdd1243dSDimitry Andric   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1681*06c3fb27SDimitry Andric   if (!RealShuffle || !ImagShuffle) {
1682*06c3fb27SDimitry Andric     if (RealShuffle || ImagShuffle)
1683*06c3fb27SDimitry Andric       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1684*06c3fb27SDimitry Andric     return nullptr;
1685*06c3fb27SDimitry Andric   }
1686*06c3fb27SDimitry Andric 
1687bdd1243dSDimitry Andric   Value *RealOp1 = RealShuffle->getOperand(1);
1688bdd1243dSDimitry Andric   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1689bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1690bdd1243dSDimitry Andric     return nullptr;
1691bdd1243dSDimitry Andric   }
1692bdd1243dSDimitry Andric   Value *ImagOp1 = ImagShuffle->getOperand(1);
1693bdd1243dSDimitry Andric   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1694bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1695bdd1243dSDimitry Andric     return nullptr;
1696bdd1243dSDimitry Andric   }
1697bdd1243dSDimitry Andric 
1698bdd1243dSDimitry Andric   Value *RealOp0 = RealShuffle->getOperand(0);
1699bdd1243dSDimitry Andric   Value *ImagOp0 = ImagShuffle->getOperand(0);
1700bdd1243dSDimitry Andric 
1701bdd1243dSDimitry Andric   if (RealOp0 != ImagOp0) {
1702bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1703bdd1243dSDimitry Andric     return nullptr;
1704bdd1243dSDimitry Andric   }
1705bdd1243dSDimitry Andric 
1706bdd1243dSDimitry Andric   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1707bdd1243dSDimitry Andric   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1708bdd1243dSDimitry Andric   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1709bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1710bdd1243dSDimitry Andric     return nullptr;
1711bdd1243dSDimitry Andric   }
1712bdd1243dSDimitry Andric 
1713bdd1243dSDimitry Andric   if (RealMask[0] != 0 || ImagMask[0] != 1) {
1714bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1715bdd1243dSDimitry Andric     return nullptr;
1716bdd1243dSDimitry Andric   }
1717bdd1243dSDimitry Andric 
1718bdd1243dSDimitry Andric   // Type checking, the shuffle type should be a vector type of the same
1719bdd1243dSDimitry Andric   // scalar type, but half the size
1720bdd1243dSDimitry Andric   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1721bdd1243dSDimitry Andric     Value *Op = Shuffle->getOperand(0);
1722bdd1243dSDimitry Andric     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1723bdd1243dSDimitry Andric     auto *OpTy = cast<FixedVectorType>(Op->getType());
1724bdd1243dSDimitry Andric 
1725bdd1243dSDimitry Andric     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1726bdd1243dSDimitry Andric       return false;
1727bdd1243dSDimitry Andric     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1728bdd1243dSDimitry Andric       return false;
1729bdd1243dSDimitry Andric 
1730bdd1243dSDimitry Andric     return true;
1731bdd1243dSDimitry Andric   };
1732bdd1243dSDimitry Andric 
1733bdd1243dSDimitry Andric   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1734bdd1243dSDimitry Andric     if (!CheckType(Shuffle))
1735bdd1243dSDimitry Andric       return false;
1736bdd1243dSDimitry Andric 
1737bdd1243dSDimitry Andric     ArrayRef<int> Mask = Shuffle->getShuffleMask();
1738bdd1243dSDimitry Andric     int Last = *Mask.rbegin();
1739bdd1243dSDimitry Andric 
1740bdd1243dSDimitry Andric     Value *Op = Shuffle->getOperand(0);
1741bdd1243dSDimitry Andric     auto *OpTy = cast<FixedVectorType>(Op->getType());
1742bdd1243dSDimitry Andric     int NumElements = OpTy->getNumElements();
1743bdd1243dSDimitry Andric 
1744bdd1243dSDimitry Andric     // Ensure that the deinterleaving shuffle only pulls from the first
1745bdd1243dSDimitry Andric     // shuffle operand.
1746bdd1243dSDimitry Andric     return Last < NumElements;
1747bdd1243dSDimitry Andric   };
1748bdd1243dSDimitry Andric 
1749bdd1243dSDimitry Andric   if (RealShuffle->getType() != ImagShuffle->getType()) {
1750bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1751bdd1243dSDimitry Andric     return nullptr;
1752bdd1243dSDimitry Andric   }
1753bdd1243dSDimitry Andric   if (!CheckDeinterleavingShuffle(RealShuffle)) {
1754bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1755bdd1243dSDimitry Andric     return nullptr;
1756bdd1243dSDimitry Andric   }
1757bdd1243dSDimitry Andric   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1758bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1759bdd1243dSDimitry Andric     return nullptr;
1760bdd1243dSDimitry Andric   }
1761bdd1243dSDimitry Andric 
1762bdd1243dSDimitry Andric   NodePtr PlaceholderNode =
1763*06c3fb27SDimitry Andric       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1764bdd1243dSDimitry Andric                            RealShuffle, ImagShuffle);
1765bdd1243dSDimitry Andric   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1766*06c3fb27SDimitry Andric   FinalInstructions.insert(RealShuffle);
1767*06c3fb27SDimitry Andric   FinalInstructions.insert(ImagShuffle);
1768bdd1243dSDimitry Andric   return submitCompositeNode(PlaceholderNode);
1769bdd1243dSDimitry Andric }
1770bdd1243dSDimitry Andric 
1771*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
1772*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1773*06c3fb27SDimitry Andric   auto IsSplat = [](Value *V) -> bool {
1774*06c3fb27SDimitry Andric     // Fixed-width vector with constants
1775*06c3fb27SDimitry Andric     if (isa<ConstantDataVector>(V))
1776*06c3fb27SDimitry Andric       return true;
1777bdd1243dSDimitry Andric 
1778*06c3fb27SDimitry Andric     VectorType *VTy;
1779*06c3fb27SDimitry Andric     ArrayRef<int> Mask;
1780*06c3fb27SDimitry Andric     // Splats are represented differently depending on whether the repeated
1781*06c3fb27SDimitry Andric     // value is a constant or an Instruction
1782*06c3fb27SDimitry Andric     if (auto *Const = dyn_cast<ConstantExpr>(V)) {
1783*06c3fb27SDimitry Andric       if (Const->getOpcode() != Instruction::ShuffleVector)
1784bdd1243dSDimitry Andric         return false;
1785*06c3fb27SDimitry Andric       VTy = cast<VectorType>(Const->getType());
1786*06c3fb27SDimitry Andric       Mask = Const->getShuffleMask();
1787*06c3fb27SDimitry Andric     } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
1788*06c3fb27SDimitry Andric       VTy = Shuf->getType();
1789*06c3fb27SDimitry Andric       Mask = Shuf->getShuffleMask();
1790*06c3fb27SDimitry Andric     } else {
1791bdd1243dSDimitry Andric       return false;
1792bdd1243dSDimitry Andric     }
1793*06c3fb27SDimitry Andric 
1794*06c3fb27SDimitry Andric     // When the data type is <1 x Type>, it's not possible to differentiate
1795*06c3fb27SDimitry Andric     // between the ComplexDeinterleaving::Deinterleave and
1796*06c3fb27SDimitry Andric     // ComplexDeinterleaving::Splat operations.
1797*06c3fb27SDimitry Andric     if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
1798*06c3fb27SDimitry Andric       return false;
1799*06c3fb27SDimitry Andric 
1800*06c3fb27SDimitry Andric     return all_equal(Mask) && Mask[0] == 0;
1801*06c3fb27SDimitry Andric   };
1802*06c3fb27SDimitry Andric 
1803*06c3fb27SDimitry Andric   if (!IsSplat(R) || !IsSplat(I))
1804*06c3fb27SDimitry Andric     return nullptr;
1805*06c3fb27SDimitry Andric 
1806*06c3fb27SDimitry Andric   auto *Real = dyn_cast<Instruction>(R);
1807*06c3fb27SDimitry Andric   auto *Imag = dyn_cast<Instruction>(I);
1808*06c3fb27SDimitry Andric   if ((!Real && Imag) || (Real && !Imag))
1809*06c3fb27SDimitry Andric     return nullptr;
1810*06c3fb27SDimitry Andric 
1811*06c3fb27SDimitry Andric   if (Real && Imag) {
1812*06c3fb27SDimitry Andric     // Non-constant splats should be in the same basic block
1813*06c3fb27SDimitry Andric     if (Real->getParent() != Imag->getParent())
1814*06c3fb27SDimitry Andric       return nullptr;
1815*06c3fb27SDimitry Andric 
1816*06c3fb27SDimitry Andric     FinalInstructions.insert(Real);
1817*06c3fb27SDimitry Andric     FinalInstructions.insert(Imag);
1818bdd1243dSDimitry Andric   }
1819*06c3fb27SDimitry Andric   NodePtr PlaceholderNode =
1820*06c3fb27SDimitry Andric       prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
1821*06c3fb27SDimitry Andric   return submitCompositeNode(PlaceholderNode);
1822bdd1243dSDimitry Andric }
1823bdd1243dSDimitry Andric 
1824*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
1825*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1826*06c3fb27SDimitry Andric                                             Instruction *Imag) {
1827*06c3fb27SDimitry Andric   if (Real != RealPHI || Imag != ImagPHI)
1828*06c3fb27SDimitry Andric     return nullptr;
1829*06c3fb27SDimitry Andric 
1830*06c3fb27SDimitry Andric   PHIsFound = true;
1831*06c3fb27SDimitry Andric   NodePtr PlaceholderNode = prepareCompositeNode(
1832*06c3fb27SDimitry Andric       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1833*06c3fb27SDimitry Andric   return submitCompositeNode(PlaceholderNode);
1834*06c3fb27SDimitry Andric }
1835*06c3fb27SDimitry Andric 
1836*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::NodePtr
1837*06c3fb27SDimitry Andric ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1838*06c3fb27SDimitry Andric                                                Instruction *Imag) {
1839*06c3fb27SDimitry Andric   auto *SelectReal = dyn_cast<SelectInst>(Real);
1840*06c3fb27SDimitry Andric   auto *SelectImag = dyn_cast<SelectInst>(Imag);
1841*06c3fb27SDimitry Andric   if (!SelectReal || !SelectImag)
1842*06c3fb27SDimitry Andric     return nullptr;
1843*06c3fb27SDimitry Andric 
1844*06c3fb27SDimitry Andric   Instruction *MaskA, *MaskB;
1845*06c3fb27SDimitry Andric   Instruction *AR, *AI, *RA, *BI;
1846*06c3fb27SDimitry Andric   if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1847*06c3fb27SDimitry Andric                             m_Instruction(RA))) ||
1848*06c3fb27SDimitry Andric       !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1849*06c3fb27SDimitry Andric                             m_Instruction(BI))))
1850*06c3fb27SDimitry Andric     return nullptr;
1851*06c3fb27SDimitry Andric 
1852*06c3fb27SDimitry Andric   if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1853*06c3fb27SDimitry Andric     return nullptr;
1854*06c3fb27SDimitry Andric 
1855*06c3fb27SDimitry Andric   if (!MaskA->getType()->isVectorTy())
1856*06c3fb27SDimitry Andric     return nullptr;
1857*06c3fb27SDimitry Andric 
1858*06c3fb27SDimitry Andric   auto NodeA = identifyNode(AR, AI);
1859*06c3fb27SDimitry Andric   if (!NodeA)
1860*06c3fb27SDimitry Andric     return nullptr;
1861*06c3fb27SDimitry Andric 
1862*06c3fb27SDimitry Andric   auto NodeB = identifyNode(RA, BI);
1863*06c3fb27SDimitry Andric   if (!NodeB)
1864*06c3fb27SDimitry Andric     return nullptr;
1865*06c3fb27SDimitry Andric 
1866*06c3fb27SDimitry Andric   NodePtr PlaceholderNode = prepareCompositeNode(
1867*06c3fb27SDimitry Andric       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1868*06c3fb27SDimitry Andric   PlaceholderNode->addOperand(NodeA);
1869*06c3fb27SDimitry Andric   PlaceholderNode->addOperand(NodeB);
1870*06c3fb27SDimitry Andric   FinalInstructions.insert(MaskA);
1871*06c3fb27SDimitry Andric   FinalInstructions.insert(MaskB);
1872*06c3fb27SDimitry Andric   return submitCompositeNode(PlaceholderNode);
1873*06c3fb27SDimitry Andric }
1874*06c3fb27SDimitry Andric 
1875*06c3fb27SDimitry Andric static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1876*06c3fb27SDimitry Andric                                    std::optional<FastMathFlags> Flags,
1877*06c3fb27SDimitry Andric                                    Value *InputA, Value *InputB) {
1878*06c3fb27SDimitry Andric   Value *I;
1879*06c3fb27SDimitry Andric   switch (Opcode) {
1880*06c3fb27SDimitry Andric   case Instruction::FNeg:
1881*06c3fb27SDimitry Andric     I = B.CreateFNeg(InputA);
1882*06c3fb27SDimitry Andric     break;
1883*06c3fb27SDimitry Andric   case Instruction::FAdd:
1884*06c3fb27SDimitry Andric     I = B.CreateFAdd(InputA, InputB);
1885*06c3fb27SDimitry Andric     break;
1886*06c3fb27SDimitry Andric   case Instruction::Add:
1887*06c3fb27SDimitry Andric     I = B.CreateAdd(InputA, InputB);
1888*06c3fb27SDimitry Andric     break;
1889*06c3fb27SDimitry Andric   case Instruction::FSub:
1890*06c3fb27SDimitry Andric     I = B.CreateFSub(InputA, InputB);
1891*06c3fb27SDimitry Andric     break;
1892*06c3fb27SDimitry Andric   case Instruction::Sub:
1893*06c3fb27SDimitry Andric     I = B.CreateSub(InputA, InputB);
1894*06c3fb27SDimitry Andric     break;
1895*06c3fb27SDimitry Andric   case Instruction::FMul:
1896*06c3fb27SDimitry Andric     I = B.CreateFMul(InputA, InputB);
1897*06c3fb27SDimitry Andric     break;
1898*06c3fb27SDimitry Andric   case Instruction::Mul:
1899*06c3fb27SDimitry Andric     I = B.CreateMul(InputA, InputB);
1900*06c3fb27SDimitry Andric     break;
1901*06c3fb27SDimitry Andric   default:
1902*06c3fb27SDimitry Andric     llvm_unreachable("Incorrect symmetric opcode");
1903*06c3fb27SDimitry Andric   }
1904*06c3fb27SDimitry Andric   if (Flags)
1905*06c3fb27SDimitry Andric     cast<Instruction>(I)->setFastMathFlags(*Flags);
1906*06c3fb27SDimitry Andric   return I;
1907*06c3fb27SDimitry Andric }
1908*06c3fb27SDimitry Andric 
1909*06c3fb27SDimitry Andric Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1910*06c3fb27SDimitry Andric                                                RawNodePtr Node) {
1911bdd1243dSDimitry Andric   if (Node->ReplacementNode)
1912bdd1243dSDimitry Andric     return Node->ReplacementNode;
1913bdd1243dSDimitry Andric 
1914*06c3fb27SDimitry Andric   auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1915*06c3fb27SDimitry Andric     return Node->Operands.size() > Idx
1916*06c3fb27SDimitry Andric                ? replaceNode(Builder, Node->Operands[Idx])
1917*06c3fb27SDimitry Andric                : nullptr;
1918*06c3fb27SDimitry Andric   };
1919bdd1243dSDimitry Andric 
1920*06c3fb27SDimitry Andric   Value *ReplacementNode;
1921*06c3fb27SDimitry Andric   switch (Node->Operation) {
1922*06c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::CAdd:
1923*06c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::CMulPartial:
1924*06c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::Symmetric: {
1925*06c3fb27SDimitry Andric     Value *Input0 = ReplaceOperandIfExist(Node, 0);
1926*06c3fb27SDimitry Andric     Value *Input1 = ReplaceOperandIfExist(Node, 1);
1927*06c3fb27SDimitry Andric     Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1928*06c3fb27SDimitry Andric     assert(!Input1 || (Input0->getType() == Input1->getType() &&
1929*06c3fb27SDimitry Andric                        "Node inputs need to be of the same type"));
1930*06c3fb27SDimitry Andric     assert(!Accumulator ||
1931*06c3fb27SDimitry Andric            (Input0->getType() == Accumulator->getType() &&
1932*06c3fb27SDimitry Andric             "Accumulator and input need to be of the same type"));
1933*06c3fb27SDimitry Andric     if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1934*06c3fb27SDimitry Andric       ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1935*06c3fb27SDimitry Andric                                              Input0, Input1);
1936*06c3fb27SDimitry Andric     else
1937*06c3fb27SDimitry Andric       ReplacementNode = TL->createComplexDeinterleavingIR(
1938*06c3fb27SDimitry Andric           Builder, Node->Operation, Node->Rotation, Input0, Input1,
1939*06c3fb27SDimitry Andric           Accumulator);
1940*06c3fb27SDimitry Andric     break;
1941*06c3fb27SDimitry Andric   }
1942*06c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::Deinterleave:
1943*06c3fb27SDimitry Andric     llvm_unreachable("Deinterleave node should already have ReplacementNode");
1944*06c3fb27SDimitry Andric     break;
1945*06c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::Splat: {
1946*06c3fb27SDimitry Andric     auto *NewTy = VectorType::getDoubleElementsVectorType(
1947*06c3fb27SDimitry Andric         cast<VectorType>(Node->Real->getType()));
1948*06c3fb27SDimitry Andric     auto *R = dyn_cast<Instruction>(Node->Real);
1949*06c3fb27SDimitry Andric     auto *I = dyn_cast<Instruction>(Node->Imag);
1950*06c3fb27SDimitry Andric     if (R && I) {
1951*06c3fb27SDimitry Andric       // Splats that are not constant are interleaved where they are located
1952*06c3fb27SDimitry Andric       Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
1953*06c3fb27SDimitry Andric       IRBuilder<> IRB(InsertPoint);
1954*06c3fb27SDimitry Andric       ReplacementNode =
1955*06c3fb27SDimitry Andric           IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy,
1956*06c3fb27SDimitry Andric                               {Node->Real, Node->Imag});
1957*06c3fb27SDimitry Andric     } else {
1958*06c3fb27SDimitry Andric       ReplacementNode =
1959*06c3fb27SDimitry Andric           Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1960*06c3fb27SDimitry Andric                                   NewTy, {Node->Real, Node->Imag});
1961*06c3fb27SDimitry Andric     }
1962*06c3fb27SDimitry Andric     break;
1963*06c3fb27SDimitry Andric   }
1964*06c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::ReductionPHI: {
1965*06c3fb27SDimitry Andric     // If Operation is ReductionPHI, a new empty PHINode is created.
1966*06c3fb27SDimitry Andric     // It is filled later when the ReductionOperation is processed.
1967*06c3fb27SDimitry Andric     auto *VTy = cast<VectorType>(Node->Real->getType());
1968*06c3fb27SDimitry Andric     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1969*06c3fb27SDimitry Andric     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
1970*06c3fb27SDimitry Andric     OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1971*06c3fb27SDimitry Andric     ReplacementNode = NewPHI;
1972*06c3fb27SDimitry Andric     break;
1973*06c3fb27SDimitry Andric   }
1974*06c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::ReductionOperation:
1975*06c3fb27SDimitry Andric     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1976*06c3fb27SDimitry Andric     processReductionOperation(ReplacementNode, Node);
1977*06c3fb27SDimitry Andric     break;
1978*06c3fb27SDimitry Andric   case ComplexDeinterleavingOperation::ReductionSelect: {
1979*06c3fb27SDimitry Andric     auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1980*06c3fb27SDimitry Andric     auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1981*06c3fb27SDimitry Andric     auto *A = replaceNode(Builder, Node->Operands[0]);
1982*06c3fb27SDimitry Andric     auto *B = replaceNode(Builder, Node->Operands[1]);
1983*06c3fb27SDimitry Andric     auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1984*06c3fb27SDimitry Andric         cast<VectorType>(MaskReal->getType()));
1985*06c3fb27SDimitry Andric     auto *NewMask =
1986*06c3fb27SDimitry Andric         Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1987*06c3fb27SDimitry Andric                                 NewMaskTy, {MaskReal, MaskImag});
1988*06c3fb27SDimitry Andric     ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1989*06c3fb27SDimitry Andric     break;
1990*06c3fb27SDimitry Andric   }
1991*06c3fb27SDimitry Andric   }
1992bdd1243dSDimitry Andric 
1993*06c3fb27SDimitry Andric   assert(ReplacementNode && "Target failed to create Intrinsic call.");
1994bdd1243dSDimitry Andric   NumComplexTransformations += 1;
1995*06c3fb27SDimitry Andric   Node->ReplacementNode = ReplacementNode;
1996*06c3fb27SDimitry Andric   return ReplacementNode;
1997*06c3fb27SDimitry Andric }
1998*06c3fb27SDimitry Andric 
1999*06c3fb27SDimitry Andric void ComplexDeinterleavingGraph::processReductionOperation(
2000*06c3fb27SDimitry Andric     Value *OperationReplacement, RawNodePtr Node) {
2001*06c3fb27SDimitry Andric   auto *Real = cast<Instruction>(Node->Real);
2002*06c3fb27SDimitry Andric   auto *Imag = cast<Instruction>(Node->Imag);
2003*06c3fb27SDimitry Andric   auto *OldPHIReal = ReductionInfo[Real].first;
2004*06c3fb27SDimitry Andric   auto *OldPHIImag = ReductionInfo[Imag].first;
2005*06c3fb27SDimitry Andric   auto *NewPHI = OldToNewPHI[OldPHIReal];
2006*06c3fb27SDimitry Andric 
2007*06c3fb27SDimitry Andric   auto *VTy = cast<VectorType>(Real->getType());
2008*06c3fb27SDimitry Andric   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2009*06c3fb27SDimitry Andric 
2010*06c3fb27SDimitry Andric   // We have to interleave initial origin values coming from IncomingBlock
2011*06c3fb27SDimitry Andric   Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2012*06c3fb27SDimitry Andric   Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2013*06c3fb27SDimitry Andric 
2014*06c3fb27SDimitry Andric   IRBuilder<> Builder(Incoming->getTerminator());
2015*06c3fb27SDimitry Andric   auto *NewInit = Builder.CreateIntrinsic(
2016*06c3fb27SDimitry Andric       Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
2017*06c3fb27SDimitry Andric 
2018*06c3fb27SDimitry Andric   NewPHI->addIncoming(NewInit, Incoming);
2019*06c3fb27SDimitry Andric   NewPHI->addIncoming(OperationReplacement, BackEdge);
2020*06c3fb27SDimitry Andric 
2021*06c3fb27SDimitry Andric   // Deinterleave complex vector outside of loop so that it can be finally
2022*06c3fb27SDimitry Andric   // reduced
2023*06c3fb27SDimitry Andric   auto *FinalReductionReal = ReductionInfo[Real].second;
2024*06c3fb27SDimitry Andric   auto *FinalReductionImag = ReductionInfo[Imag].second;
2025*06c3fb27SDimitry Andric 
2026*06c3fb27SDimitry Andric   Builder.SetInsertPoint(
2027*06c3fb27SDimitry Andric       &*FinalReductionReal->getParent()->getFirstInsertionPt());
2028*06c3fb27SDimitry Andric   auto *Deinterleave = Builder.CreateIntrinsic(
2029*06c3fb27SDimitry Andric       Intrinsic::experimental_vector_deinterleave2,
2030*06c3fb27SDimitry Andric       OperationReplacement->getType(), OperationReplacement);
2031*06c3fb27SDimitry Andric 
2032*06c3fb27SDimitry Andric   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2033*06c3fb27SDimitry Andric   FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2034*06c3fb27SDimitry Andric 
2035*06c3fb27SDimitry Andric   Builder.SetInsertPoint(FinalReductionImag);
2036*06c3fb27SDimitry Andric   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2037*06c3fb27SDimitry Andric   FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2038bdd1243dSDimitry Andric }
2039bdd1243dSDimitry Andric 
2040bdd1243dSDimitry Andric void ComplexDeinterleavingGraph::replaceNodes() {
2041*06c3fb27SDimitry Andric   SmallVector<Instruction *, 16> DeadInstrRoots;
2042*06c3fb27SDimitry Andric   for (auto *RootInstruction : OrderedRoots) {
2043*06c3fb27SDimitry Andric     // Check if this potential root went through check process and we can
2044*06c3fb27SDimitry Andric     // deinterleave it
2045*06c3fb27SDimitry Andric     if (!RootToNode.count(RootInstruction))
2046*06c3fb27SDimitry Andric       continue;
2047*06c3fb27SDimitry Andric 
2048*06c3fb27SDimitry Andric     IRBuilder<> Builder(RootInstruction);
2049*06c3fb27SDimitry Andric     auto RootNode = RootToNode[RootInstruction];
2050*06c3fb27SDimitry Andric     Value *R = replaceNode(Builder, RootNode.get());
2051*06c3fb27SDimitry Andric 
2052*06c3fb27SDimitry Andric     if (RootNode->Operation ==
2053*06c3fb27SDimitry Andric         ComplexDeinterleavingOperation::ReductionOperation) {
2054*06c3fb27SDimitry Andric       auto *RootReal = cast<Instruction>(RootNode->Real);
2055*06c3fb27SDimitry Andric       auto *RootImag = cast<Instruction>(RootNode->Imag);
2056*06c3fb27SDimitry Andric       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2057*06c3fb27SDimitry Andric       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2058*06c3fb27SDimitry Andric       DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2059*06c3fb27SDimitry Andric       DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2060*06c3fb27SDimitry Andric     } else {
2061*06c3fb27SDimitry Andric       assert(R && "Unable to find replacement for RootInstruction");
2062*06c3fb27SDimitry Andric       DeadInstrRoots.push_back(RootInstruction);
2063*06c3fb27SDimitry Andric       RootInstruction->replaceAllUsesWith(R);
2064*06c3fb27SDimitry Andric     }
2065bdd1243dSDimitry Andric   }
2066bdd1243dSDimitry Andric 
2067*06c3fb27SDimitry Andric   for (auto *I : DeadInstrRoots)
2068*06c3fb27SDimitry Andric     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2069bdd1243dSDimitry Andric }
2070