xref: /llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (revision ab09654832dba5cef8baa6400fdfd3e4d1495624)
1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
18 // pairs. Validity of each node is expected to be done upon creation, and any
19 // validation errors should halt traversal and prevent further graph
20 // construction.
// Instead of relying on shuffle operations, vector interleaving and
// deinterleaving can be represented by the vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas fixed-width vectors are recognized in both forms:
// shufflevector instructions and the intrinsics.
26 //
27 // Replacement:
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
33 //
34 //
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
40 // replace it.
41 //
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
55 //
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
59 //
60 //===----------------------------------------------------------------------===//
61 
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/Statistic.h"
64 #include "llvm/Analysis/TargetLibraryInfo.h"
65 #include "llvm/Analysis/TargetTransformInfo.h"
66 #include "llvm/CodeGen/TargetLowering.h"
67 #include "llvm/CodeGen/TargetPassConfig.h"
68 #include "llvm/CodeGen/TargetSubtargetInfo.h"
69 #include "llvm/IR/IRBuilder.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/InitializePasses.h"
72 #include "llvm/Target/TargetMachine.h"
73 #include "llvm/Transforms/Utils/Local.h"
74 #include <algorithm>
75 
76 using namespace llvm;
77 using namespace PatternMatch;
78 
79 #define DEBUG_TYPE "complex-deinterleaving"
80 
81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
82 
83 static cl::opt<bool> ComplexDeinterleavingEnabled(
84     "enable-complex-deinterleaving",
85     cl::desc("Enable generation of complex instructions"), cl::init(true),
86     cl::Hidden);
87 
88 /// Checks the given mask, and determines whether said mask is interleaving.
89 ///
90 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
93 static bool isInterleavingMask(ArrayRef<int> Mask);
94 
95 /// Checks the given mask, and determines whether said mask is deinterleaving.
96 ///
97 /// To be deinterleaving, a mask must increment in steps of 2, and either start
98 /// with 0 or 1.
99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
100 /// <1, 3, 5, 7>).
101 static bool isDeinterleavingMask(ArrayRef<int> Mask);
102 
103 namespace {
104 
105 class ComplexDeinterleavingLegacyPass : public FunctionPass {
106 public:
107   static char ID;
108 
109   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
110       : FunctionPass(ID), TM(TM) {
111     initializeComplexDeinterleavingLegacyPassPass(
112         *PassRegistry::getPassRegistry());
113   }
114 
115   StringRef getPassName() const override {
116     return "Complex Deinterleaving Pass";
117   }
118 
119   bool runOnFunction(Function &F) override;
120   void getAnalysisUsage(AnalysisUsage &AU) const override {
121     AU.addRequired<TargetLibraryInfoWrapperPass>();
122     AU.setPreservesCFG();
123   }
124 
125 private:
126   const TargetMachine *TM;
127 };
128 
129 class ComplexDeinterleavingGraph;
130 struct ComplexDeinterleavingCompositeNode {
131 
132   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
133                                      Instruction *R, Instruction *I)
134       : Operation(Op), Real(R), Imag(I) {}
135 
136 private:
137   friend class ComplexDeinterleavingGraph;
138   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
139   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
140 
141 public:
142   ComplexDeinterleavingOperation Operation;
143   Instruction *Real;
144   Instruction *Imag;
145 
146   // This two members are required exclusively for generating
147   // ComplexDeinterleavingOperation::Symmetric operations.
148   unsigned Opcode;
149   FastMathFlags Flags;
150 
151   ComplexDeinterleavingRotation Rotation =
152       ComplexDeinterleavingRotation::Rotation_0;
153   SmallVector<RawNodePtr> Operands;
154   Value *ReplacementNode = nullptr;
155 
156   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
157 
158   void dump() { dump(dbgs()); }
159   void dump(raw_ostream &OS) {
160     auto PrintValue = [&](Value *V) {
161       if (V) {
162         OS << "\"";
163         V->print(OS, true);
164         OS << "\"\n";
165       } else
166         OS << "nullptr\n";
167     };
168     auto PrintNodeRef = [&](RawNodePtr Ptr) {
169       if (Ptr)
170         OS << Ptr << "\n";
171       else
172         OS << "nullptr\n";
173     };
174 
175     OS << "- CompositeNode: " << this << "\n";
176     OS << "  Real: ";
177     PrintValue(Real);
178     OS << "  Imag: ";
179     PrintValue(Imag);
180     OS << "  ReplacementNode: ";
181     PrintValue(ReplacementNode);
182     OS << "  Operation: " << (int)Operation << "\n";
183     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
184     OS << "  Operands: \n";
185     for (const auto &Op : Operands) {
186       OS << "    - ";
187       PrintNodeRef(Op);
188     }
189   }
190 };
191 
192 class ComplexDeinterleavingGraph {
193 public:
194   struct Product {
195     Instruction *Multiplier;
196     Instruction *Multiplicand;
197     bool IsPositive;
198   };
199 
200   using Addend = std::pair<Instruction *, bool>;
201   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
202   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
203 
204   // Helper struct for holding info about potential partial multiplication
205   // candidates
206   struct PartialMulCandidate {
207     Instruction *Common;
208     NodePtr Node;
209     unsigned RealIdx;
210     unsigned ImagIdx;
211     bool IsNodeInverted;
212   };
213 
214   explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
215                                       const TargetLibraryInfo *TLI)
216       : TL(TL), TLI(TLI) {}
217 
218 private:
219   const TargetLowering *TL = nullptr;
220   const TargetLibraryInfo *TLI = nullptr;
221   SmallVector<NodePtr> CompositeNodes;
222 
223   SmallPtrSet<Instruction *, 16> FinalInstructions;
224 
225   /// Root instructions are instructions from which complex computation starts
226   std::map<Instruction *, NodePtr> RootToNode;
227 
228   /// Topologically sorted root instructions
229   SmallVector<Instruction *, 1> OrderedRoots;
230 
231   /// When examining a basic block for complex deinterleaving, if it is a simple
232   /// one-block loop, then the only incoming block is 'Incoming' and the
233   /// 'BackEdge' block is the block itself."
234   BasicBlock *BackEdge = nullptr;
235   BasicBlock *Incoming = nullptr;
236 
237   /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
238   /// %OutsideUser as it is shown in the IR:
239   ///
240   /// vector.body:
241   ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
242   ///                                [ %ReductionOp, %vector.body ]
243   ///   ...
244   ///   %ReductionOp = fadd i64 ...
245   ///   ...
246   ///   br i1 %condition, label %vector.body, %middle.block
247   ///
248   /// middle.block:
249   ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
250   ///
251   /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
252   /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
253   std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
254 
255   /// In the process of detecting a reduction, we consider a pair of
256   /// %ReductionOP, which we refer to as real and imag (or vice versa), and
257   /// traverse the use-tree to detect complex operations. As this is a reduction
258   /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
259   /// to the %ReductionOPs that we suspect to be complex.
260   /// RealPHI and ImagPHI are used by the identifyPHINode method.
261   PHINode *RealPHI = nullptr;
262   PHINode *ImagPHI = nullptr;
263 
264   /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
265   /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
266   /// This mapping is populated during
267   /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
268   /// used in the ComplexDeinterleavingOperation::ReductionOperation node
269   /// replacement process.
270   std::map<PHINode *, PHINode *> OldToNewPHI;
271 
272   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
273                                Instruction *R, Instruction *I) {
274     assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
275              Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
276             (R && I)) &&
277            "Reduction related nodes must have Real and Imaginary parts");
278     return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
279                                                                 I);
280   }
281 
282   NodePtr submitCompositeNode(NodePtr Node) {
283     CompositeNodes.push_back(Node);
284     return Node;
285   }
286 
287   NodePtr getContainingComposite(Value *R, Value *I) {
288     for (const auto &CN : CompositeNodes) {
289       if (CN->Real == R && CN->Imag == I)
290         return CN;
291     }
292     return nullptr;
293   }
294 
295   /// Identifies a complex partial multiply pattern and its rotation, based on
296   /// the following patterns
297   ///
298   ///  0:  r: cr + ar * br
299   ///      i: ci + ar * bi
300   /// 90:  r: cr - ai * bi
301   ///      i: ci + ai * br
302   /// 180: r: cr - ar * br
303   ///      i: ci - ar * bi
304   /// 270: r: cr + ai * bi
305   ///      i: ci - ai * br
306   NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
307 
308   /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
309   /// is partially known from identifyPartialMul, filling in the other half of
310   /// the complex pair.
311   NodePtr identifyNodeWithImplicitAdd(
312       Instruction *I, Instruction *J,
313       std::pair<Instruction *, Instruction *> &CommonOperandI);
314 
315   /// Identifies a complex add pattern and its rotation, based on the following
316   /// patterns.
317   ///
318   /// 90:  r: ar - bi
319   ///      i: ai + br
320   /// 270: r: ar + bi
321   ///      i: ai - br
322   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
323   NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
324 
325   NodePtr identifyNode(Instruction *I, Instruction *J);
326 
327   /// Determine if a sum of complex numbers can be formed from \p RealAddends
328   /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
329   /// Return nullptr if it is not possible to construct a complex number.
330   /// \p Flags are needed to generate symmetric Add and Sub operations.
331   NodePtr identifyAdditions(std::list<Addend> &RealAddends,
332                             std::list<Addend> &ImagAddends, FastMathFlags Flags,
333                             NodePtr Accumulator);
334 
335   /// Extract one addend that have both real and imaginary parts positive.
336   NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
337                                 std::list<Addend> &ImagAddends);
338 
339   /// Determine if sum of multiplications of complex numbers can be formed from
340   /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
341   /// to it. Return nullptr if it is not possible to construct a complex number.
342   NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
343                                   std::vector<Product> &ImagMuls,
344                                   NodePtr Accumulator);
345 
346   /// Go through pairs of multiplication (one Real and one Imag) and find all
347   /// possible candidates for partial multiplication and put them into \p
348   /// Candidates. Returns true if all Product has pair with common operand
349   bool collectPartialMuls(const std::vector<Product> &RealMuls,
350                           const std::vector<Product> &ImagMuls,
351                           std::vector<PartialMulCandidate> &Candidates);
352 
353   /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
354   /// the order of complex computation operations may be significantly altered,
355   /// and the real and imaginary parts may not be executed in parallel. This
356   /// function takes this into consideration and employs a more general approach
357   /// to identify complex computations. Initially, it gathers all the addends
358   /// and multiplicands and then constructs a complex expression from them.
359   NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
360 
361   NodePtr identifyRoot(Instruction *I);
362 
363   /// Identifies the Deinterleave operation applied to a vector containing
364   /// complex numbers. There are two ways to represent the Deinterleave
365   /// operation:
366   /// * Using two shufflevectors with even indices for /pReal instruction and
367   /// odd indices for /pImag instructions (only for fixed-width vectors)
368   /// * Using two extractvalue instructions applied to `vector.deinterleave2`
369   /// intrinsic (for both fixed and scalable vectors)
370   NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
371 
372   NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
373 
374   Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
375 
376   /// Complete IR modifications after producing new reduction operation:
377   /// * Populate the PHINode generated for
378   /// ComplexDeinterleavingOperation::ReductionPHI
379   /// * Deinterleave the final value outside of the loop and repurpose original
380   /// reduction users
381   void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
382 
383 public:
384   void dump() { dump(dbgs()); }
385   void dump(raw_ostream &OS) {
386     for (const auto &Node : CompositeNodes)
387       Node->dump(OS);
388   }
389 
390   /// Returns false if the deinterleaving operation should be cancelled for the
391   /// current graph.
392   bool identifyNodes(Instruction *RootI);
393 
394   /// In case \pB is one-block loop, this function seeks potential reductions
395   /// and populates ReductionInfo. Returns true if any reductions were
396   /// identified.
397   bool collectPotentialReductions(BasicBlock *B);
398 
399   void identifyReductionNodes();
400 
401   /// Check that every instruction, from the roots to the leaves, has internal
402   /// uses.
403   bool checkNodes();
404 
405   /// Perform the actual replacement of the underlying instruction graph.
406   void replaceNodes();
407 };
408 
409 class ComplexDeinterleaving {
410 public:
411   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
412       : TL(tl), TLI(tli) {}
413   bool runOnFunction(Function &F);
414 
415 private:
416   bool evaluateBasicBlock(BasicBlock *B);
417 
418   const TargetLowering *TL = nullptr;
419   const TargetLibraryInfo *TLI = nullptr;
420 };
421 
422 } // namespace
423 
424 char ComplexDeinterleavingLegacyPass::ID = 0;
425 
426 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
427                       "Complex Deinterleaving", false, false)
428 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
429                     "Complex Deinterleaving", false, false)
430 
431 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
432                                                  FunctionAnalysisManager &AM) {
433   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
434   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
435   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
436     return PreservedAnalyses::all();
437 
438   PreservedAnalyses PA;
439   PA.preserve<FunctionAnalysisManagerModuleProxy>();
440   return PA;
441 }
442 
443 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
444   return new ComplexDeinterleavingLegacyPass(TM);
445 }
446 
447 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
448   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
449   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
450   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
451 }
452 
453 bool ComplexDeinterleaving::runOnFunction(Function &F) {
454   if (!ComplexDeinterleavingEnabled) {
455     LLVM_DEBUG(
456         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
457     return false;
458   }
459 
460   if (!TL->isComplexDeinterleavingSupported()) {
461     LLVM_DEBUG(
462         dbgs() << "Complex deinterleaving has been disabled, target does "
463                   "not support lowering of complex number operations.\n");
464     return false;
465   }
466 
467   bool Changed = false;
468   for (auto &B : F)
469     Changed |= evaluateBasicBlock(&B);
470 
471   return Changed;
472 }
473 
474 static bool isInterleavingMask(ArrayRef<int> Mask) {
475   // If the size is not even, it's not an interleaving mask
476   if ((Mask.size() & 1))
477     return false;
478 
479   int HalfNumElements = Mask.size() / 2;
480   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
481     int MaskIdx = Idx * 2;
482     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
483       return false;
484   }
485 
486   return true;
487 }
488 
489 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
490   int Offset = Mask[0];
491   int HalfNumElements = Mask.size() / 2;
492 
493   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
494     if (Mask[Idx] != (Idx * 2) + Offset)
495       return false;
496   }
497 
498   return true;
499 }
500 
501 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
502   ComplexDeinterleavingGraph Graph(TL, TLI);
503   if (Graph.collectPotentialReductions(B))
504     Graph.identifyReductionNodes();
505 
506   for (auto &I : *B)
507     Graph.identifyNodes(&I);
508 
509   if (Graph.checkNodes()) {
510     Graph.replaceNodes();
511     return true;
512   }
513 
514   return false;
515 }
516 
517 ComplexDeinterleavingGraph::NodePtr
518 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
519     Instruction *Real, Instruction *Imag,
520     std::pair<Instruction *, Instruction *> &PartialMatch) {
521   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
522                     << "\n");
523 
524   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
525     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
526     return nullptr;
527   }
528 
529   if (Real->getOpcode() != Instruction::FMul ||
530       Imag->getOpcode() != Instruction::FMul) {
531     LLVM_DEBUG(dbgs() << "  - Real or imaginary instruction is not fmul\n");
532     return nullptr;
533   }
534 
535   Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
536   Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
537   Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
538   Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
539   if (!R0 || !R1 || !I0 || !I1) {
540     LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
541     return nullptr;
542   }
543 
544   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
545   // rotations and use the operand.
546   unsigned Negs = 0;
547   SmallVector<Instruction *> FNegs;
548   if (R0->getOpcode() == Instruction::FNeg ||
549       R1->getOpcode() == Instruction::FNeg) {
550     Negs |= 1;
551     if (R0->getOpcode() == Instruction::FNeg) {
552       FNegs.push_back(R0);
553       R0 = dyn_cast<Instruction>(R0->getOperand(0));
554     } else {
555       FNegs.push_back(R1);
556       R1 = dyn_cast<Instruction>(R1->getOperand(0));
557     }
558     if (!R0 || !R1)
559       return nullptr;
560   }
561   if (I0->getOpcode() == Instruction::FNeg ||
562       I1->getOpcode() == Instruction::FNeg) {
563     Negs |= 2;
564     Negs ^= 1;
565     if (I0->getOpcode() == Instruction::FNeg) {
566       FNegs.push_back(I0);
567       I0 = dyn_cast<Instruction>(I0->getOperand(0));
568     } else {
569       FNegs.push_back(I1);
570       I1 = dyn_cast<Instruction>(I1->getOperand(0));
571     }
572     if (!I0 || !I1)
573       return nullptr;
574   }
575 
576   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
577 
578   Instruction *CommonOperand;
579   Instruction *UncommonRealOp;
580   Instruction *UncommonImagOp;
581 
582   if (R0 == I0 || R0 == I1) {
583     CommonOperand = R0;
584     UncommonRealOp = R1;
585   } else if (R1 == I0 || R1 == I1) {
586     CommonOperand = R1;
587     UncommonRealOp = R0;
588   } else {
589     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
590     return nullptr;
591   }
592 
593   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
594   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
595       Rotation == ComplexDeinterleavingRotation::Rotation_270)
596     std::swap(UncommonRealOp, UncommonImagOp);
597 
598   // Between identifyPartialMul and here we need to have found a complete valid
599   // pair from the CommonOperand of each part.
600   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
601       Rotation == ComplexDeinterleavingRotation::Rotation_180)
602     PartialMatch.first = CommonOperand;
603   else
604     PartialMatch.second = CommonOperand;
605 
606   if (!PartialMatch.first || !PartialMatch.second) {
607     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
608     return nullptr;
609   }
610 
611   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
612   if (!CommonNode) {
613     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
614     return nullptr;
615   }
616 
617   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
618   if (!UncommonNode) {
619     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
620     return nullptr;
621   }
622 
623   NodePtr Node = prepareCompositeNode(
624       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
625   Node->Rotation = Rotation;
626   Node->addOperand(CommonNode);
627   Node->addOperand(UncommonNode);
628   return submitCompositeNode(Node);
629 }
630 
631 ComplexDeinterleavingGraph::NodePtr
632 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
633                                                Instruction *Imag) {
634   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
635                     << "\n");
636   // Determine rotation
637   ComplexDeinterleavingRotation Rotation;
638   if (Real->getOpcode() == Instruction::FAdd &&
639       Imag->getOpcode() == Instruction::FAdd)
640     Rotation = ComplexDeinterleavingRotation::Rotation_0;
641   else if (Real->getOpcode() == Instruction::FSub &&
642            Imag->getOpcode() == Instruction::FAdd)
643     Rotation = ComplexDeinterleavingRotation::Rotation_90;
644   else if (Real->getOpcode() == Instruction::FSub &&
645            Imag->getOpcode() == Instruction::FSub)
646     Rotation = ComplexDeinterleavingRotation::Rotation_180;
647   else if (Real->getOpcode() == Instruction::FAdd &&
648            Imag->getOpcode() == Instruction::FSub)
649     Rotation = ComplexDeinterleavingRotation::Rotation_270;
650   else {
651     LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
652     return nullptr;
653   }
654 
655   if (!Real->getFastMathFlags().allowContract() ||
656       !Imag->getFastMathFlags().allowContract()) {
657     LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
658     return nullptr;
659   }
660 
661   Value *CR = Real->getOperand(0);
662   Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
663   if (!RealMulI)
664     return nullptr;
665   Value *CI = Imag->getOperand(0);
666   Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
667   if (!ImagMulI)
668     return nullptr;
669 
670   if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
671     LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
672     return nullptr;
673   }
674 
675   Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
676   Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
677   Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
678   Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
679   if (!R0 || !R1 || !I0 || !I1) {
680     LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
681     return nullptr;
682   }
683 
684   Instruction *CommonOperand;
685   Instruction *UncommonRealOp;
686   Instruction *UncommonImagOp;
687 
688   if (R0 == I0 || R0 == I1) {
689     CommonOperand = R0;
690     UncommonRealOp = R1;
691   } else if (R1 == I0 || R1 == I1) {
692     CommonOperand = R1;
693     UncommonRealOp = R0;
694   } else {
695     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
696     return nullptr;
697   }
698 
699   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
700   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
701       Rotation == ComplexDeinterleavingRotation::Rotation_270)
702     std::swap(UncommonRealOp, UncommonImagOp);
703 
704   std::pair<Instruction *, Instruction *> PartialMatch(
705       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
706        Rotation == ComplexDeinterleavingRotation::Rotation_180)
707           ? CommonOperand
708           : nullptr,
709       (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
710        Rotation == ComplexDeinterleavingRotation::Rotation_270)
711           ? CommonOperand
712           : nullptr);
713 
714   auto *CRInst = dyn_cast<Instruction>(CR);
715   auto *CIInst = dyn_cast<Instruction>(CI);
716 
717   if (!CRInst || !CIInst) {
718     LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
719     return nullptr;
720   }
721 
722   NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
723   if (!CNode) {
724     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
725     return nullptr;
726   }
727 
728   NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
729   if (!UncommonRes) {
730     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
731     return nullptr;
732   }
733 
734   assert(PartialMatch.first && PartialMatch.second);
735   NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
736   if (!CommonRes) {
737     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
738     return nullptr;
739   }
740 
741   NodePtr Node = prepareCompositeNode(
742       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
743   Node->Rotation = Rotation;
744   Node->addOperand(CommonRes);
745   Node->addOperand(UncommonRes);
746   Node->addOperand(CNode);
747   return submitCompositeNode(Node);
748 }
749 
750 ComplexDeinterleavingGraph::NodePtr
751 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
752   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
753 
754   // Determine rotation
755   ComplexDeinterleavingRotation Rotation;
756   if ((Real->getOpcode() == Instruction::FSub &&
757        Imag->getOpcode() == Instruction::FAdd) ||
758       (Real->getOpcode() == Instruction::Sub &&
759        Imag->getOpcode() == Instruction::Add))
760     Rotation = ComplexDeinterleavingRotation::Rotation_90;
761   else if ((Real->getOpcode() == Instruction::FAdd &&
762             Imag->getOpcode() == Instruction::FSub) ||
763            (Real->getOpcode() == Instruction::Add &&
764             Imag->getOpcode() == Instruction::Sub))
765     Rotation = ComplexDeinterleavingRotation::Rotation_270;
766   else {
767     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
768     return nullptr;
769   }
770 
771   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
772   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
773   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
774   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
775 
776   if (!AR || !AI || !BR || !BI) {
777     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
778     return nullptr;
779   }
780 
781   NodePtr ResA = identifyNode(AR, AI);
782   if (!ResA) {
783     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
784     return nullptr;
785   }
786   NodePtr ResB = identifyNode(BR, BI);
787   if (!ResB) {
788     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
789     return nullptr;
790   }
791 
792   NodePtr Node =
793       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
794   Node->Rotation = Rotation;
795   Node->addOperand(ResA);
796   Node->addOperand(ResB);
797   return submitCompositeNode(Node);
798 }
799 
800 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
801   unsigned OpcA = A->getOpcode();
802   unsigned OpcB = B->getOpcode();
803 
804   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
805          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
806          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
807          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
808 }
809 
810 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
811   auto Pattern =
812       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
813 
814   return match(A, Pattern) && match(B, Pattern);
815 }
816 
817 static bool isInstructionPotentiallySymmetric(Instruction *I) {
818   switch (I->getOpcode()) {
819   case Instruction::FAdd:
820   case Instruction::FSub:
821   case Instruction::FMul:
822   case Instruction::FNeg:
823     return true;
824   default:
825     return false;
826   }
827 }
828 
829 ComplexDeinterleavingGraph::NodePtr
830 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
831                                                        Instruction *Imag) {
832   if (Real->getOpcode() != Imag->getOpcode())
833     return nullptr;
834 
835   if (!isInstructionPotentiallySymmetric(Real) ||
836       !isInstructionPotentiallySymmetric(Imag))
837     return nullptr;
838 
839   auto *R0 = dyn_cast<Instruction>(Real->getOperand(0));
840   auto *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
841 
842   if (!R0 || !I0)
843     return nullptr;
844 
845   NodePtr Op0 = identifyNode(R0, I0);
846   NodePtr Op1 = nullptr;
847   if (Op0 == nullptr)
848     return nullptr;
849 
850   if (Real->isBinaryOp()) {
851     auto *R1 = dyn_cast<Instruction>(Real->getOperand(1));
852     auto *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
853     if (!R1 || !I1)
854       return nullptr;
855 
856     Op1 = identifyNode(R1, I1);
857     if (Op1 == nullptr)
858       return nullptr;
859   }
860 
861   if (isa<FPMathOperator>(Real) &&
862       Real->getFastMathFlags() != Imag->getFastMathFlags())
863     return nullptr;
864 
865   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
866                                    Real, Imag);
867   Node->Opcode = Real->getOpcode();
868   if (isa<FPMathOperator>(Real))
869     Node->Flags = Real->getFastMathFlags();
870 
871   Node->addOperand(Op0);
872   if (Real->isBinaryOp())
873     Node->addOperand(Op1);
874 
875   return submitCompositeNode(Node);
876 }
877 
878 ComplexDeinterleavingGraph::NodePtr
879 ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
880   LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
881   if (NodePtr CN = getContainingComposite(Real, Imag)) {
882     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
883     return CN;
884   }
885 
886   if (NodePtr CN = identifyDeinterleave(Real, Imag))
887     return CN;
888 
889   if (NodePtr CN = identifyPHINode(Real, Imag))
890     return CN;
891 
892   auto *VTy = cast<VectorType>(Real->getType());
893   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
894 
895   bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
896       ComplexDeinterleavingOperation::CMulPartial, NewVTy);
897   bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
898       ComplexDeinterleavingOperation::CAdd, NewVTy);
899 
900   if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
901     if (NodePtr CN = identifyPartialMul(Real, Imag))
902       return CN;
903   }
904 
905   if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
906     if (NodePtr CN = identifyAdd(Real, Imag))
907       return CN;
908   }
909 
910   if (HasCMulSupport && HasCAddSupport) {
911     if (NodePtr CN = identifyReassocNodes(Real, Imag))
912       return CN;
913   }
914 
915   if (NodePtr CN = identifySymmetricOperation(Real, Imag))
916     return CN;
917 
918   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
919   return nullptr;
920 }
921 
922 ComplexDeinterleavingGraph::NodePtr
923 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
924                                                  Instruction *Imag) {
925   if ((Real->getOpcode() != Instruction::FAdd &&
926        Real->getOpcode() != Instruction::FSub &&
927        Real->getOpcode() != Instruction::FNeg) ||
928       (Imag->getOpcode() != Instruction::FAdd &&
929        Imag->getOpcode() != Instruction::FSub &&
930        Imag->getOpcode() != Instruction::FNeg))
931     return nullptr;
932 
933   if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
934     LLVM_DEBUG(
935         dbgs()
936         << "The flags in Real and Imaginary instructions are not identical\n");
937     return nullptr;
938   }
939 
940   FastMathFlags Flags = Real->getFastMathFlags();
941   if (!Flags.allowReassoc()) {
942     LLVM_DEBUG(
943         dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n");
944     return nullptr;
945   }
946 
947   // Collect multiplications and addend instructions from the given instruction
948   // while traversing it operands. Additionally, verify that all instructions
949   // have the same fast math flags.
950   auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
951                           std::list<Addend> &Addends) -> bool {
952     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
953     SmallPtrSet<Value *, 8> Visited;
954     while (!Worklist.empty()) {
955       auto [V, IsPositive] = Worklist.back();
956       Worklist.pop_back();
957       if (!Visited.insert(V).second)
958         continue;
959 
960       Instruction *I = dyn_cast<Instruction>(V);
961       if (!I)
962         return false;
963 
964       // If an instruction has more than one user, it indicates that it either
965       // has an external user, which will be later checked by the checkNodes
966       // function, or it is a subexpression utilized by multiple expressions. In
967       // the latter case, we will attempt to separately identify the complex
968       // operation from here in order to create a shared
969       // ComplexDeinterleavingCompositeNode.
970       if (I != Insn && I->getNumUses() > 1) {
971         LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
972         Addends.emplace_back(I, IsPositive);
973         continue;
974       }
975 
976       if (I->getOpcode() == Instruction::FAdd) {
977         Worklist.emplace_back(I->getOperand(1), IsPositive);
978         Worklist.emplace_back(I->getOperand(0), IsPositive);
979       } else if (I->getOpcode() == Instruction::FSub) {
980         Worklist.emplace_back(I->getOperand(1), !IsPositive);
981         Worklist.emplace_back(I->getOperand(0), IsPositive);
982       } else if (I->getOpcode() == Instruction::FMul) {
983         auto *A = dyn_cast<Instruction>(I->getOperand(0));
984         if (A && A->getOpcode() == Instruction::FNeg) {
985           A = dyn_cast<Instruction>(A->getOperand(0));
986           IsPositive = !IsPositive;
987         }
988         if (!A)
989           return false;
990         auto *B = dyn_cast<Instruction>(I->getOperand(1));
991         if (B && B->getOpcode() == Instruction::FNeg) {
992           B = dyn_cast<Instruction>(B->getOperand(0));
993           IsPositive = !IsPositive;
994         }
995         if (!B)
996           return false;
997         Muls.push_back(Product{A, B, IsPositive});
998       } else if (I->getOpcode() == Instruction::FNeg) {
999         Worklist.emplace_back(I->getOperand(0), !IsPositive);
1000       } else {
1001         Addends.emplace_back(I, IsPositive);
1002         continue;
1003       }
1004 
1005       if (I->getFastMathFlags() != Flags) {
1006         LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1007                              "inconsistent with the root instructions' flags: "
1008                           << *I << "\n");
1009         return false;
1010       }
1011     }
1012     return true;
1013   };
1014 
1015   std::vector<Product> RealMuls, ImagMuls;
1016   std::list<Addend> RealAddends, ImagAddends;
1017   if (!Collect(Real, RealMuls, RealAddends) ||
1018       !Collect(Imag, ImagMuls, ImagAddends))
1019     return nullptr;
1020 
1021   if (RealAddends.size() != ImagAddends.size())
1022     return nullptr;
1023 
1024   NodePtr FinalNode;
1025   if (!RealMuls.empty() || !ImagMuls.empty()) {
1026     // If there are multiplicands, extract positive addend and use it as an
1027     // accumulator
1028     FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1029     FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1030     if (!FinalNode)
1031       return nullptr;
1032   }
1033 
1034   // Identify and process remaining additions
1035   if (!RealAddends.empty() || !ImagAddends.empty()) {
1036     FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1037     if (!FinalNode)
1038       return nullptr;
1039   }
1040 
1041   // Set the Real and Imag fields of the final node and submit it
1042   FinalNode->Real = Real;
1043   FinalNode->Imag = Imag;
1044   submitCompositeNode(FinalNode);
1045   return FinalNode;
1046 }
1047 
1048 bool ComplexDeinterleavingGraph::collectPartialMuls(
1049     const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1050     std::vector<PartialMulCandidate> &PartialMulCandidates) {
1051   // Helper function to extract a common operand from two products
1052   auto FindCommonInstruction = [](const Product &Real,
1053                                   const Product &Imag) -> Instruction * {
1054     if (Real.Multiplicand == Imag.Multiplicand ||
1055         Real.Multiplicand == Imag.Multiplier)
1056       return Real.Multiplicand;
1057 
1058     if (Real.Multiplier == Imag.Multiplicand ||
1059         Real.Multiplier == Imag.Multiplier)
1060       return Real.Multiplier;
1061 
1062     return nullptr;
1063   };
1064 
1065   // Iterating over real and imaginary multiplications to find common operands
1066   // If a common operand is found, a partial multiplication candidate is created
1067   // and added to the candidates vector The function returns false if no common
1068   // operands are found for any product
1069   for (unsigned i = 0; i < RealMuls.size(); ++i) {
1070     bool FoundCommon = false;
1071     for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1072       auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1073       if (!Common)
1074         continue;
1075 
1076       auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1077                                                    : RealMuls[i].Multiplicand;
1078       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1079                                                    : ImagMuls[j].Multiplicand;
1080 
1081       bool Inverted = false;
1082       auto Node = identifyNode(A, B);
1083       if (!Node) {
1084         std::swap(A, B);
1085         Inverted = true;
1086         Node = identifyNode(A, B);
1087       }
1088       if (!Node)
1089         continue;
1090 
1091       FoundCommon = true;
1092       PartialMulCandidates.push_back({Common, Node, i, j, Inverted});
1093     }
1094     if (!FoundCommon)
1095       return false;
1096   }
1097   return true;
1098 }
1099 
1100 ComplexDeinterleavingGraph::NodePtr
1101 ComplexDeinterleavingGraph::identifyMultiplications(
1102     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1103     NodePtr Accumulator = nullptr) {
1104   if (RealMuls.size() != ImagMuls.size())
1105     return nullptr;
1106 
1107   std::vector<PartialMulCandidate> Info;
1108   if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1109     return nullptr;
1110 
1111   // Map to store common instruction to node pointers
1112   std::map<Instruction *, NodePtr> CommonToNode;
1113   std::vector<bool> Processed(Info.size(), false);
1114   for (unsigned I = 0; I < Info.size(); ++I) {
1115     if (Processed[I])
1116       continue;
1117 
1118     PartialMulCandidate &InfoA = Info[I];
1119     for (unsigned J = I + 1; J < Info.size(); ++J) {
1120       if (Processed[J])
1121         continue;
1122 
1123       PartialMulCandidate &InfoB = Info[J];
1124       auto *InfoReal = &InfoA;
1125       auto *InfoImag = &InfoB;
1126 
1127       auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1128       if (!NodeFromCommon) {
1129         std::swap(InfoReal, InfoImag);
1130         NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1131       }
1132       if (!NodeFromCommon)
1133         continue;
1134 
1135       CommonToNode[InfoReal->Common] = NodeFromCommon;
1136       CommonToNode[InfoImag->Common] = NodeFromCommon;
1137       Processed[I] = true;
1138       Processed[J] = true;
1139     }
1140   }
1141 
1142   std::vector<bool> ProcessedReal(RealMuls.size(), false);
1143   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1144   NodePtr Result = Accumulator;
1145   for (auto &PMI : Info) {
1146     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1147       continue;
1148 
1149     auto It = CommonToNode.find(PMI.Common);
1150     // TODO: Process independent complex multiplications. Cases like this:
1151     //  A.real() * B where both A and B are complex numbers.
1152     if (It == CommonToNode.end()) {
1153       LLVM_DEBUG({
1154         dbgs() << "Unprocessed independent partial multiplication:\n";
1155         for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1156           dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1157                            << " multiplied by " << *Mul->Multiplicand << "\n";
1158       });
1159       return nullptr;
1160     }
1161 
1162     auto &RealMul = RealMuls[PMI.RealIdx];
1163     auto &ImagMul = ImagMuls[PMI.ImagIdx];
1164 
1165     auto NodeA = It->second;
1166     auto NodeB = PMI.Node;
1167     auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1168     // The following table illustrates the relationship between multiplications
1169     // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1170     // can see:
1171     //
1172     // Rotation |   Real |   Imag |
1173     // ---------+--------+--------+
1174     //        0 |  x * u |  x * v |
1175     //       90 | -y * v |  y * u |
1176     //      180 | -x * u | -x * v |
1177     //      270 |  y * v | -y * u |
1178     //
1179     // Check if the candidate can indeed be represented by partial
1180     // multiplication
1181     // TODO: Add support for multiplication by complex one
1182     if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1183         (!IsMultiplicandReal && !PMI.IsNodeInverted))
1184       continue;
1185 
1186     // Determine the rotation based on the multiplications
1187     ComplexDeinterleavingRotation Rotation;
1188     if (IsMultiplicandReal) {
1189       // Detect 0 and 180 degrees rotation
1190       if (RealMul.IsPositive && ImagMul.IsPositive)
1191         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1192       else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1193         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1194       else
1195         continue;
1196 
1197     } else {
1198       // Detect 90 and 270 degrees rotation
1199       if (!RealMul.IsPositive && ImagMul.IsPositive)
1200         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1201       else if (RealMul.IsPositive && !ImagMul.IsPositive)
1202         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1203       else
1204         continue;
1205     }
1206 
1207     LLVM_DEBUG({
1208       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1209       dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1210       dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1211       dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1212       dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1213       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1214     });
1215 
1216     NodePtr NodeMul = prepareCompositeNode(
1217         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1218     NodeMul->Rotation = Rotation;
1219     NodeMul->addOperand(NodeA);
1220     NodeMul->addOperand(NodeB);
1221     if (Result)
1222       NodeMul->addOperand(Result);
1223     submitCompositeNode(NodeMul);
1224     Result = NodeMul;
1225     ProcessedReal[PMI.RealIdx] = true;
1226     ProcessedImag[PMI.ImagIdx] = true;
1227   }
1228 
1229   // Ensure all products have been processed, if not return nullptr.
1230   if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1231       !all_of(ProcessedImag, [](bool V) { return V; })) {
1232 
1233     // Dump debug information about which partial multiplications are not
1234     // processed.
1235     LLVM_DEBUG({
1236       dbgs() << "Unprocessed products (Real):\n";
1237       for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1238         if (!ProcessedReal[i])
1239           dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1240                            << *RealMuls[i].Multiplier << " multiplied by "
1241                            << *RealMuls[i].Multiplicand << "\n";
1242       }
1243       dbgs() << "Unprocessed products (Imag):\n";
1244       for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1245         if (!ProcessedImag[i])
1246           dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1247                            << *ImagMuls[i].Multiplier << " multiplied by "
1248                            << *ImagMuls[i].Multiplicand << "\n";
1249       }
1250     });
1251     return nullptr;
1252   }
1253 
1254   return Result;
1255 }
1256 
1257 ComplexDeinterleavingGraph::NodePtr
1258 ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends,
1259                                               std::list<Addend> &ImagAddends,
1260                                               FastMathFlags Flags,
1261                                               NodePtr Accumulator = nullptr) {
1262   if (RealAddends.size() != ImagAddends.size())
1263     return nullptr;
1264 
1265   NodePtr Result;
1266   // If we have accumulator use it as first addend
1267   if (Accumulator)
1268     Result = Accumulator;
1269   // Otherwise find an element with both positive real and imaginary parts.
1270   else
1271     Result = extractPositiveAddend(RealAddends, ImagAddends);
1272 
1273   if (!Result)
1274     return nullptr;
1275 
1276   while (!RealAddends.empty()) {
1277     auto ItR = RealAddends.begin();
1278     auto [R, IsPositiveR] = *ItR;
1279 
1280     bool FoundImag = false;
1281     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1282       auto [I, IsPositiveI] = *ItI;
1283       ComplexDeinterleavingRotation Rotation;
1284       if (IsPositiveR && IsPositiveI)
1285         Rotation = ComplexDeinterleavingRotation::Rotation_0;
1286       else if (!IsPositiveR && IsPositiveI)
1287         Rotation = ComplexDeinterleavingRotation::Rotation_90;
1288       else if (!IsPositiveR && !IsPositiveI)
1289         Rotation = ComplexDeinterleavingRotation::Rotation_180;
1290       else
1291         Rotation = ComplexDeinterleavingRotation::Rotation_270;
1292 
1293       NodePtr AddNode;
1294       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1295           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1296         AddNode = identifyNode(R, I);
1297       } else {
1298         AddNode = identifyNode(I, R);
1299       }
1300       if (AddNode) {
1301         LLVM_DEBUG({
1302           dbgs() << "Identified addition:\n";
1303           dbgs().indent(4) << "X: " << *R << "\n";
1304           dbgs().indent(4) << "Y: " << *I << "\n";
1305           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1306         });
1307 
1308         NodePtr TmpNode;
1309         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1310           TmpNode = prepareCompositeNode(
1311               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1312           TmpNode->Opcode = Instruction::FAdd;
1313           TmpNode->Flags = Flags;
1314         } else if (Rotation ==
1315                    llvm::ComplexDeinterleavingRotation::Rotation_180) {
1316           TmpNode = prepareCompositeNode(
1317               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1318           TmpNode->Opcode = Instruction::FSub;
1319           TmpNode->Flags = Flags;
1320         } else {
1321           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1322                                          nullptr, nullptr);
1323           TmpNode->Rotation = Rotation;
1324         }
1325 
1326         TmpNode->addOperand(Result);
1327         TmpNode->addOperand(AddNode);
1328         submitCompositeNode(TmpNode);
1329         Result = TmpNode;
1330         RealAddends.erase(ItR);
1331         ImagAddends.erase(ItI);
1332         FoundImag = true;
1333         break;
1334       }
1335     }
1336     if (!FoundImag)
1337       return nullptr;
1338   }
1339   return Result;
1340 }
1341 
1342 ComplexDeinterleavingGraph::NodePtr
1343 ComplexDeinterleavingGraph::extractPositiveAddend(
1344     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1345   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1346     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1347       auto [R, IsPositiveR] = *ItR;
1348       auto [I, IsPositiveI] = *ItI;
1349       if (IsPositiveR && IsPositiveI) {
1350         auto Result = identifyNode(R, I);
1351         if (Result) {
1352           RealAddends.erase(ItR);
1353           ImagAddends.erase(ItI);
1354           return Result;
1355         }
1356       }
1357     }
1358   }
1359   return nullptr;
1360 }
1361 
1362 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1363   // This potential root instruction might already have been recognized as
1364   // reduction. Because RootToNode maps both Real and Imaginary parts to
1365   // CompositeNode we should choose only one either Real or Imag instruction to
1366   // use as an anchor for generating complex instruction.
1367   auto It = RootToNode.find(RootI);
1368   if (It != RootToNode.end() && It->second->Real == RootI) {
1369     OrderedRoots.push_back(RootI);
1370     return true;
1371   }
1372 
1373   auto RootNode = identifyRoot(RootI);
1374   if (!RootNode)
1375     return false;
1376 
1377   LLVM_DEBUG({
1378     Function *F = RootI->getFunction();
1379     BasicBlock *B = RootI->getParent();
1380     dbgs() << "Complex deinterleaving graph for " << F->getName()
1381            << "::" << B->getName() << ".\n";
1382     dump(dbgs());
1383     dbgs() << "\n";
1384   });
1385   RootToNode[RootI] = RootNode;
1386   OrderedRoots.push_back(RootI);
1387   return true;
1388 }
1389 
1390 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1391   bool FoundPotentialReduction = false;
1392 
1393   auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1394   if (!Br || Br->getNumSuccessors() != 2)
1395     return false;
1396 
1397   // Identify simple one-block loop
1398   if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1399     return false;
1400 
1401   SmallVector<PHINode *> PHIs;
1402   for (auto &PHI : B->phis()) {
1403     if (PHI.getNumIncomingValues() != 2)
1404       continue;
1405 
1406     if (!PHI.getType()->isVectorTy())
1407       continue;
1408 
1409     auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1410     if (!ReductionOp)
1411       continue;
1412 
1413     // Check if final instruction is reduced outside of current block
1414     Instruction *FinalReduction = nullptr;
1415     auto NumUsers = 0u;
1416     for (auto *U : ReductionOp->users()) {
1417       ++NumUsers;
1418       if (U == &PHI)
1419         continue;
1420       FinalReduction = dyn_cast<Instruction>(U);
1421     }
1422 
1423     if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B)
1424       continue;
1425 
1426     ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1427     BackEdge = B;
1428     auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1429     auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1430     Incoming = PHI.getIncomingBlock(IncomingIdx);
1431     FoundPotentialReduction = true;
1432 
1433     // If the initial value of PHINode is an Instruction, consider it a leaf
1434     // value of a complex deinterleaving graph.
1435     if (auto *InitPHI =
1436             dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1437       FinalInstructions.insert(InitPHI);
1438   }
1439   return FoundPotentialReduction;
1440 }
1441 
1442 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1443   SmallVector<bool> Processed(ReductionInfo.size(), false);
1444   SmallVector<Instruction *> OperationInstruction;
1445   for (auto &P : ReductionInfo)
1446     OperationInstruction.push_back(P.first);
1447 
1448   // Identify a complex computation by evaluating two reduction operations that
1449   // potentially could be involved
1450   for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1451     if (Processed[i])
1452       continue;
1453     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1454       if (Processed[j])
1455         continue;
1456 
1457       auto *Real = OperationInstruction[i];
1458       auto *Imag = OperationInstruction[j];
1459 
1460       RealPHI = ReductionInfo[Real].first;
1461       ImagPHI = ReductionInfo[Imag].first;
1462       auto Node = identifyNode(Real, Imag);
1463       if (!Node) {
1464         std::swap(Real, Imag);
1465         std::swap(RealPHI, ImagPHI);
1466         Node = identifyNode(Real, Imag);
1467       }
1468 
1469       // If a node is identified, mark its operation instructions as used to
1470       // prevent re-identification and attach the node to the real part
1471       if (Node) {
1472         LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1473                           << *Real << " / " << *Imag << "\n");
1474         Processed[i] = true;
1475         Processed[j] = true;
1476         auto RootNode = prepareCompositeNode(
1477             ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1478         RootNode->addOperand(Node);
1479         RootToNode[Real] = RootNode;
1480         RootToNode[Imag] = RootNode;
1481         submitCompositeNode(RootNode);
1482         break;
1483       }
1484     }
1485   }
1486 
1487   RealPHI = nullptr;
1488   ImagPHI = nullptr;
1489 }
1490 
1491 bool ComplexDeinterleavingGraph::checkNodes() {
1492   // Collect all instructions from roots to leaves
1493   SmallPtrSet<Instruction *, 16> AllInstructions;
1494   SmallVector<Instruction *, 8> Worklist;
1495   for (auto &Pair : RootToNode)
1496     Worklist.push_back(Pair.first);
1497 
1498   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1499   // chains
1500   while (!Worklist.empty()) {
1501     auto *I = Worklist.back();
1502     Worklist.pop_back();
1503 
1504     if (!AllInstructions.insert(I).second)
1505       continue;
1506 
1507     for (Value *Op : I->operands()) {
1508       if (auto *OpI = dyn_cast<Instruction>(Op)) {
1509         if (!FinalInstructions.count(I))
1510           Worklist.emplace_back(OpI);
1511       }
1512     }
1513   }
1514 
1515   // Find instructions that have users outside of chain
1516   SmallVector<Instruction *, 2> OuterInstructions;
1517   for (auto *I : AllInstructions) {
1518     // Skip root nodes
1519     if (RootToNode.count(I))
1520       continue;
1521 
1522     for (User *U : I->users()) {
1523       if (AllInstructions.count(cast<Instruction>(U)))
1524         continue;
1525 
1526       // Found an instruction that is not used by XCMLA/XCADD chain
1527       Worklist.emplace_back(I);
1528       break;
1529     }
1530   }
1531 
1532   // If any instructions are found to be used outside, find and remove roots
1533   // that somehow connect to those instructions.
1534   SmallPtrSet<Instruction *, 16> Visited;
1535   while (!Worklist.empty()) {
1536     auto *I = Worklist.back();
1537     Worklist.pop_back();
1538     if (!Visited.insert(I).second)
1539       continue;
1540 
1541     // Found an impacted root node. Removing it from the nodes to be
1542     // deinterleaved
1543     if (RootToNode.count(I)) {
1544       LLVM_DEBUG(dbgs() << "Instruction " << *I
1545                         << " could be deinterleaved but its chain of complex "
1546                            "operations have an outside user\n");
1547       RootToNode.erase(I);
1548     }
1549 
1550     if (!AllInstructions.count(I) || FinalInstructions.count(I))
1551       continue;
1552 
1553     for (User *U : I->users())
1554       Worklist.emplace_back(cast<Instruction>(U));
1555 
1556     for (Value *Op : I->operands()) {
1557       if (auto *OpI = dyn_cast<Instruction>(Op))
1558         Worklist.emplace_back(OpI);
1559     }
1560   }
1561   return !RootToNode.empty();
1562 }
1563 
1564 ComplexDeinterleavingGraph::NodePtr
1565 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1566   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1567     if (Intrinsic->getIntrinsicID() !=
1568         Intrinsic::experimental_vector_interleave2)
1569       return nullptr;
1570 
1571     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1572     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1573     if (!Real || !Imag)
1574       return nullptr;
1575 
1576     return identifyNode(Real, Imag);
1577   }
1578 
1579   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1580   if (!SVI)
1581     return nullptr;
1582 
1583   // Look for a shufflevector that takes separate vectors of the real and
1584   // imaginary components and recombines them into a single vector.
1585   if (!isInterleavingMask(SVI->getShuffleMask()))
1586     return nullptr;
1587 
1588   Instruction *Real;
1589   Instruction *Imag;
1590   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1591     return nullptr;
1592 
1593   return identifyNode(Real, Imag);
1594 }
1595 
1596 ComplexDeinterleavingGraph::NodePtr
1597 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1598                                                  Instruction *Imag) {
1599   Instruction *I = nullptr;
1600   Value *FinalValue = nullptr;
1601   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1602       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1603       match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1604                    m_Value(FinalValue)))) {
1605     NodePtr PlaceholderNode = prepareCompositeNode(
1606         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1607     PlaceholderNode->ReplacementNode = FinalValue;
1608     FinalInstructions.insert(Real);
1609     FinalInstructions.insert(Imag);
1610     return submitCompositeNode(PlaceholderNode);
1611   }
1612 
1613   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1614   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1615   if (!RealShuffle || !ImagShuffle) {
1616     if (RealShuffle || ImagShuffle)
1617       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1618     return nullptr;
1619   }
1620 
1621   Value *RealOp1 = RealShuffle->getOperand(1);
1622   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1623     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1624     return nullptr;
1625   }
1626   Value *ImagOp1 = ImagShuffle->getOperand(1);
1627   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1628     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1629     return nullptr;
1630   }
1631 
1632   Value *RealOp0 = RealShuffle->getOperand(0);
1633   Value *ImagOp0 = ImagShuffle->getOperand(0);
1634 
1635   if (RealOp0 != ImagOp0) {
1636     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1637     return nullptr;
1638   }
1639 
1640   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1641   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1642   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1643     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1644     return nullptr;
1645   }
1646 
1647   if (RealMask[0] != 0 || ImagMask[0] != 1) {
1648     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1649     return nullptr;
1650   }
1651 
1652   // Type checking, the shuffle type should be a vector type of the same
1653   // scalar type, but half the size
1654   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1655     Value *Op = Shuffle->getOperand(0);
1656     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1657     auto *OpTy = cast<FixedVectorType>(Op->getType());
1658 
1659     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1660       return false;
1661     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1662       return false;
1663 
1664     return true;
1665   };
1666 
1667   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1668     if (!CheckType(Shuffle))
1669       return false;
1670 
1671     ArrayRef<int> Mask = Shuffle->getShuffleMask();
1672     int Last = *Mask.rbegin();
1673 
1674     Value *Op = Shuffle->getOperand(0);
1675     auto *OpTy = cast<FixedVectorType>(Op->getType());
1676     int NumElements = OpTy->getNumElements();
1677 
1678     // Ensure that the deinterleaving shuffle only pulls from the first
1679     // shuffle operand.
1680     return Last < NumElements;
1681   };
1682 
1683   if (RealShuffle->getType() != ImagShuffle->getType()) {
1684     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1685     return nullptr;
1686   }
1687   if (!CheckDeinterleavingShuffle(RealShuffle)) {
1688     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1689     return nullptr;
1690   }
1691   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1692     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1693     return nullptr;
1694   }
1695 
1696   NodePtr PlaceholderNode =
1697       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1698                            RealShuffle, ImagShuffle);
1699   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1700   FinalInstructions.insert(RealShuffle);
1701   FinalInstructions.insert(ImagShuffle);
1702   return submitCompositeNode(PlaceholderNode);
1703 }
1704 
1705 ComplexDeinterleavingGraph::NodePtr
1706 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1707                                             Instruction *Imag) {
1708   if (Real != RealPHI || Imag != ImagPHI)
1709     return nullptr;
1710 
1711   NodePtr PlaceholderNode = prepareCompositeNode(
1712       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1713   return submitCompositeNode(PlaceholderNode);
1714 }
1715 
1716 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1717                                    FastMathFlags Flags, Value *InputA,
1718                                    Value *InputB) {
1719   Value *I;
1720   switch (Opcode) {
1721   case Instruction::FNeg:
1722     I = B.CreateFNeg(InputA);
1723     break;
1724   case Instruction::FAdd:
1725     I = B.CreateFAdd(InputA, InputB);
1726     break;
1727   case Instruction::FSub:
1728     I = B.CreateFSub(InputA, InputB);
1729     break;
1730   case Instruction::FMul:
1731     I = B.CreateFMul(InputA, InputB);
1732     break;
1733   default:
1734     llvm_unreachable("Incorrect symmetric opcode");
1735   }
1736   cast<Instruction>(I)->setFastMathFlags(Flags);
1737   return I;
1738 }
1739 
1740 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1741                                                RawNodePtr Node) {
1742   if (Node->ReplacementNode)
1743     return Node->ReplacementNode;
1744 
1745   auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1746     return Node->Operands.size() > Idx
1747                ? replaceNode(Builder, Node->Operands[Idx])
1748                : nullptr;
1749   };
1750 
1751   Value *ReplacementNode;
1752   switch (Node->Operation) {
1753   case ComplexDeinterleavingOperation::CAdd:
1754   case ComplexDeinterleavingOperation::CMulPartial:
1755   case ComplexDeinterleavingOperation::Symmetric: {
1756     Value *Input0 = ReplaceOperandIfExist(Node, 0);
1757     Value *Input1 = ReplaceOperandIfExist(Node, 1);
1758     Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1759     assert(!Input1 || (Input0->getType() == Input1->getType() &&
1760                        "Node inputs need to be of the same type"));
1761     assert(!Accumulator ||
1762            (Input0->getType() == Accumulator->getType() &&
1763             "Accumulator and input need to be of the same type"));
1764     if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1765       ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1766                                              Input0, Input1);
1767     else
1768       ReplacementNode = TL->createComplexDeinterleavingIR(
1769           Builder, Node->Operation, Node->Rotation, Input0, Input1,
1770           Accumulator);
1771     break;
1772   }
1773   case ComplexDeinterleavingOperation::Deinterleave:
1774     llvm_unreachable("Deinterleave node should already have ReplacementNode");
1775     break;
1776   case ComplexDeinterleavingOperation::ReductionPHI: {
1777     // If Operation is ReductionPHI, a new empty PHINode is created.
1778     // It is filled later when the ReductionOperation is processed.
1779     auto *VTy = cast<VectorType>(Node->Real->getType());
1780     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1781     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
1782     OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1783     ReplacementNode = NewPHI;
1784     break;
1785   }
1786   case ComplexDeinterleavingOperation::ReductionOperation:
1787     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1788     processReductionOperation(ReplacementNode, Node);
1789     break;
1790   }
1791 
1792   assert(ReplacementNode && "Target failed to create Intrinsic call.");
1793   NumComplexTransformations += 1;
1794   Node->ReplacementNode = ReplacementNode;
1795   return ReplacementNode;
1796 }
1797 
1798 void ComplexDeinterleavingGraph::processReductionOperation(
1799     Value *OperationReplacement, RawNodePtr Node) {
1800   auto *OldPHIReal = ReductionInfo[Node->Real].first;
1801   auto *OldPHIImag = ReductionInfo[Node->Imag].first;
1802   auto *NewPHI = OldToNewPHI[OldPHIReal];
1803 
1804   auto *VTy = cast<VectorType>(Node->Real->getType());
1805   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1806 
1807   // We have to interleave initial origin values coming from IncomingBlock
1808   Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
1809   Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
1810 
1811   IRBuilder<> Builder(Incoming->getTerminator());
1812   auto *NewInit = Builder.CreateIntrinsic(
1813       Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
1814 
1815   NewPHI->addIncoming(NewInit, Incoming);
1816   NewPHI->addIncoming(OperationReplacement, BackEdge);
1817 
1818   // Deinterleave complex vector outside of loop so that it can be finally
1819   // reduced
1820   auto *FinalReductionReal = ReductionInfo[Node->Real].second;
1821   auto *FinalReductionImag = ReductionInfo[Node->Imag].second;
1822 
1823   Builder.SetInsertPoint(
1824       &*FinalReductionReal->getParent()->getFirstInsertionPt());
1825   auto *Deinterleave = Builder.CreateIntrinsic(
1826       Intrinsic::experimental_vector_deinterleave2,
1827       OperationReplacement->getType(), OperationReplacement);
1828 
1829   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
1830   FinalReductionReal->replaceUsesOfWith(Node->Real, NewReal);
1831 
1832   Builder.SetInsertPoint(FinalReductionImag);
1833   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
1834   FinalReductionImag->replaceUsesOfWith(Node->Imag, NewImag);
1835 }
1836 
1837 void ComplexDeinterleavingGraph::replaceNodes() {
1838   SmallVector<Instruction *, 16> DeadInstrRoots;
1839   for (auto *RootInstruction : OrderedRoots) {
1840     // Check if this potential root went through check process and we can
1841     // deinterleave it
1842     if (!RootToNode.count(RootInstruction))
1843       continue;
1844 
1845     IRBuilder<> Builder(RootInstruction);
1846     auto RootNode = RootToNode[RootInstruction];
1847     Value *R = replaceNode(Builder, RootNode.get());
1848 
1849     if (RootNode->Operation ==
1850         ComplexDeinterleavingOperation::ReductionOperation) {
1851       ReductionInfo[RootNode->Real].first->removeIncomingValue(BackEdge);
1852       ReductionInfo[RootNode->Imag].first->removeIncomingValue(BackEdge);
1853       DeadInstrRoots.push_back(RootNode->Real);
1854       DeadInstrRoots.push_back(RootNode->Imag);
1855     } else {
1856       assert(R && "Unable to find replacement for RootInstruction");
1857       DeadInstrRoots.push_back(RootInstruction);
1858       RootInstruction->replaceAllUsesWith(R);
1859     }
1860   }
1861 
1862   for (auto *I : DeadInstrRoots)
1863     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
1864 }
1865