xref: /llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (revision b4f9c3a933e80de40ee7860db75bb1088cf9bfa7)
1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
18 // pairs. Validity of each node is expected to be done upon creation, and any
19 // validation errors should halt traversal and prevent further graph
20 // construction.
21 // Instead of relying on Shuffle operations, vector interleaving and
22 // deinterleaving can be represented by vector.interleave2 and
23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24 // these intrinsics, whereas, fixed-width vectors are recognized for both
25 // shufflevector instruction and intrinsics.
26 //
27 // Replacement:
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
33 //
34 //
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
40 // replace it.
41 //
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
48 // Each node maintains a reference  to its Real and Imaginary instructions,
49 // as well as any additional instructions that make up the identified operation
50 // (Internal instructions should only have uses within their containing node).
51 // A Node also contains the rotation and operation type that it represents.
52 // Operands contains pointers to other CompositeNodes, acting as the edges in
53 // the graph. ReplacementValue is the transformed Value* that has been emitted
54 // to the IR.
55 //
56 // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57 // ReplacementValue fields of that Node are relevant, where the ReplacementValue
58 // should be pre-populated.
59 //
60 //===----------------------------------------------------------------------===//
61 
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/Statistic.h"
64 #include "llvm/Analysis/TargetLibraryInfo.h"
65 #include "llvm/Analysis/TargetTransformInfo.h"
66 #include "llvm/CodeGen/TargetLowering.h"
67 #include "llvm/CodeGen/TargetPassConfig.h"
68 #include "llvm/CodeGen/TargetSubtargetInfo.h"
69 #include "llvm/IR/IRBuilder.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/InitializePasses.h"
72 #include "llvm/Target/TargetMachine.h"
73 #include "llvm/Transforms/Utils/Local.h"
74 #include <algorithm>
75 
76 using namespace llvm;
77 using namespace PatternMatch;
78 
79 #define DEBUG_TYPE "complex-deinterleaving"
80 
81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
82 
// Hidden command-line switch "enable-complex-deinterleaving" (defaults to
// true). ComplexDeinterleaving::runOnFunction returns early without touching
// the function when this is false.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);
87 
88 /// Checks the given mask, and determines whether said mask is interleaving.
89 ///
90 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
93 static bool isInterleavingMask(ArrayRef<int> Mask);
94 
95 /// Checks the given mask, and determines whether said mask is deinterleaving.
96 ///
97 /// To be deinterleaving, a mask must increment in steps of 2, and either start
98 /// with 0 or 1.
99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
100 /// <1, 3, 5, 7>).
101 static bool isDeinterleavingMask(ArrayRef<int> Mask);
102 
103 namespace {
104 
/// Legacy pass-manager entry point for complex deinterleaving. It only holds
/// the TargetMachine; the actual work is delegated to ComplexDeinterleaving
/// from runOnFunction.
class ComplexDeinterleavingLegacyPass : public FunctionPass {
public:
  static char ID;

  /// Stores \p TM and runs the pass's registry initializer.
  /// NOTE(review): \p TM defaults to nullptr, but runOnFunction dereferences
  /// it unconditionally -- confirm every creator supplies a TargetMachine.
  ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {
    initializeComplexDeinterleavingLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  /// Human-readable name of the pass.
  StringRef getPassName() const override {
    return "Complex Deinterleaving Pass";
  }

  bool runOnFunction(Function &F) override;

  /// Requires TargetLibraryInfoWrapperPass and declares the CFG preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.setPreservesCFG();
  }

private:
  // Target machine used to look up the per-function TargetLowering.
  const TargetMachine *TM;
};
128 
129 class ComplexDeinterleavingGraph;
130 struct ComplexDeinterleavingCompositeNode {
131 
132   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
133                                      Value *R, Value *I)
134       : Operation(Op), Real(R), Imag(I) {}
135 
136 private:
137   friend class ComplexDeinterleavingGraph;
138   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
139   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
140 
141 public:
142   ComplexDeinterleavingOperation Operation;
143   Value *Real;
144   Value *Imag;
145 
146   // This two members are required exclusively for generating
147   // ComplexDeinterleavingOperation::Symmetric operations.
148   unsigned Opcode;
149   FastMathFlags Flags;
150 
151   ComplexDeinterleavingRotation Rotation =
152       ComplexDeinterleavingRotation::Rotation_0;
153   SmallVector<RawNodePtr> Operands;
154   Value *ReplacementNode = nullptr;
155 
156   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
157 
158   void dump() { dump(dbgs()); }
159   void dump(raw_ostream &OS) {
160     auto PrintValue = [&](Value *V) {
161       if (V) {
162         OS << "\"";
163         V->print(OS, true);
164         OS << "\"\n";
165       } else
166         OS << "nullptr\n";
167     };
168     auto PrintNodeRef = [&](RawNodePtr Ptr) {
169       if (Ptr)
170         OS << Ptr << "\n";
171       else
172         OS << "nullptr\n";
173     };
174 
175     OS << "- CompositeNode: " << this << "\n";
176     OS << "  Real: ";
177     PrintValue(Real);
178     OS << "  Imag: ";
179     PrintValue(Imag);
180     OS << "  ReplacementNode: ";
181     PrintValue(ReplacementNode);
182     OS << "  Operation: " << (int)Operation << "\n";
183     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
184     OS << "  Operands: \n";
185     for (const auto &Op : Operands) {
186       OS << "    - ";
187       PrintNodeRef(Op);
188     }
189   }
190 };
191 
192 class ComplexDeinterleavingGraph {
193 public:
194   struct Product {
195     Value *Multiplier;
196     Value *Multiplicand;
197     bool IsPositive;
198   };
199 
200   using Addend = std::pair<Value *, bool>;
201   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
202   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
203 
204   // Helper struct for holding info about potential partial multiplication
205   // candidates
206   struct PartialMulCandidate {
207     Value *Common;
208     NodePtr Node;
209     unsigned RealIdx;
210     unsigned ImagIdx;
211     bool IsNodeInverted;
212   };
213 
214   explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
215                                       const TargetLibraryInfo *TLI)
216       : TL(TL), TLI(TLI) {}
217 
218 private:
219   const TargetLowering *TL = nullptr;
220   const TargetLibraryInfo *TLI = nullptr;
221   SmallVector<NodePtr> CompositeNodes;
222 
223   SmallPtrSet<Instruction *, 16> FinalInstructions;
224 
225   /// Root instructions are instructions from which complex computation starts
226   std::map<Instruction *, NodePtr> RootToNode;
227 
228   /// Topologically sorted root instructions
229   SmallVector<Instruction *, 1> OrderedRoots;
230 
231   /// When examining a basic block for complex deinterleaving, if it is a simple
232   /// one-block loop, then the only incoming block is 'Incoming' and the
233   /// 'BackEdge' block is the block itself."
234   BasicBlock *BackEdge = nullptr;
235   BasicBlock *Incoming = nullptr;
236 
237   /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
238   /// %OutsideUser as it is shown in the IR:
239   ///
240   /// vector.body:
241   ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
242   ///                                [ %ReductionOp, %vector.body ]
243   ///   ...
244   ///   %ReductionOp = fadd i64 ...
245   ///   ...
246   ///   br i1 %condition, label %vector.body, %middle.block
247   ///
248   /// middle.block:
249   ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
250   ///
251   /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
252   /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
253   std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
254 
255   /// In the process of detecting a reduction, we consider a pair of
256   /// %ReductionOP, which we refer to as real and imag (or vice versa), and
257   /// traverse the use-tree to detect complex operations. As this is a reduction
258   /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
259   /// to the %ReductionOPs that we suspect to be complex.
260   /// RealPHI and ImagPHI are used by the identifyPHINode method.
261   PHINode *RealPHI = nullptr;
262   PHINode *ImagPHI = nullptr;
263 
264   /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
265   /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
266   /// This mapping is populated during
267   /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
268   /// used in the ComplexDeinterleavingOperation::ReductionOperation node
269   /// replacement process.
270   std::map<PHINode *, PHINode *> OldToNewPHI;
271 
272   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
273                                Value *R, Value *I) {
274     assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
275              Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
276             (R && I)) &&
277            "Reduction related nodes must have Real and Imaginary parts");
278     return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
279                                                                 I);
280   }
281 
282   NodePtr submitCompositeNode(NodePtr Node) {
283     CompositeNodes.push_back(Node);
284     return Node;
285   }
286 
287   NodePtr getContainingComposite(Value *R, Value *I) {
288     for (const auto &CN : CompositeNodes) {
289       if (CN->Real == R && CN->Imag == I)
290         return CN;
291     }
292     return nullptr;
293   }
294 
295   /// Identifies a complex partial multiply pattern and its rotation, based on
296   /// the following patterns
297   ///
298   ///  0:  r: cr + ar * br
299   ///      i: ci + ar * bi
300   /// 90:  r: cr - ai * bi
301   ///      i: ci + ai * br
302   /// 180: r: cr - ar * br
303   ///      i: ci - ar * bi
304   /// 270: r: cr + ai * bi
305   ///      i: ci - ai * br
306   NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
307 
308   /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
309   /// is partially known from identifyPartialMul, filling in the other half of
310   /// the complex pair.
311   NodePtr
312   identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
313                               std::pair<Value *, Value *> &CommonOperandI);
314 
315   /// Identifies a complex add pattern and its rotation, based on the following
316   /// patterns.
317   ///
318   /// 90:  r: ar - bi
319   ///      i: ai + br
320   /// 270: r: ar + bi
321   ///      i: ai - br
322   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
323   NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
324 
325   NodePtr identifyNode(Value *R, Value *I);
326 
327   /// Determine if a sum of complex numbers can be formed from \p RealAddends
328   /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
329   /// Return nullptr if it is not possible to construct a complex number.
330   /// \p Flags are needed to generate symmetric Add and Sub operations.
331   NodePtr identifyAdditions(std::list<Addend> &RealAddends,
332                             std::list<Addend> &ImagAddends, FastMathFlags Flags,
333                             NodePtr Accumulator);
334 
335   /// Extract one addend that have both real and imaginary parts positive.
336   NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
337                                 std::list<Addend> &ImagAddends);
338 
339   /// Determine if sum of multiplications of complex numbers can be formed from
340   /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
341   /// to it. Return nullptr if it is not possible to construct a complex number.
342   NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
343                                   std::vector<Product> &ImagMuls,
344                                   NodePtr Accumulator);
345 
346   /// Go through pairs of multiplication (one Real and one Imag) and find all
347   /// possible candidates for partial multiplication and put them into \p
348   /// Candidates. Returns true if all Product has pair with common operand
349   bool collectPartialMuls(const std::vector<Product> &RealMuls,
350                           const std::vector<Product> &ImagMuls,
351                           std::vector<PartialMulCandidate> &Candidates);
352 
353   /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
354   /// the order of complex computation operations may be significantly altered,
355   /// and the real and imaginary parts may not be executed in parallel. This
356   /// function takes this into consideration and employs a more general approach
357   /// to identify complex computations. Initially, it gathers all the addends
358   /// and multiplicands and then constructs a complex expression from them.
359   NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
360 
361   NodePtr identifyRoot(Instruction *I);
362 
363   /// Identifies the Deinterleave operation applied to a vector containing
364   /// complex numbers. There are two ways to represent the Deinterleave
365   /// operation:
366   /// * Using two shufflevectors with even indices for /pReal instruction and
367   /// odd indices for /pImag instructions (only for fixed-width vectors)
368   /// * Using two extractvalue instructions applied to `vector.deinterleave2`
369   /// intrinsic (for both fixed and scalable vectors)
370   NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
371 
372   NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
373 
374   /// Identifies SelectInsts in a loop that has reduction with predication masks
375   /// and/or predicated tail folding
376   NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
377 
378   Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
379 
380   /// Complete IR modifications after producing new reduction operation:
381   /// * Populate the PHINode generated for
382   /// ComplexDeinterleavingOperation::ReductionPHI
383   /// * Deinterleave the final value outside of the loop and repurpose original
384   /// reduction users
385   void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
386 
387 public:
388   void dump() { dump(dbgs()); }
389   void dump(raw_ostream &OS) {
390     for (const auto &Node : CompositeNodes)
391       Node->dump(OS);
392   }
393 
394   /// Returns false if the deinterleaving operation should be cancelled for the
395   /// current graph.
396   bool identifyNodes(Instruction *RootI);
397 
398   /// In case \pB is one-block loop, this function seeks potential reductions
399   /// and populates ReductionInfo. Returns true if any reductions were
400   /// identified.
401   bool collectPotentialReductions(BasicBlock *B);
402 
403   void identifyReductionNodes();
404 
405   /// Check that every instruction, from the roots to the leaves, has internal
406   /// uses.
407   bool checkNodes();
408 
409   /// Perform the actual replacement of the underlying instruction graph.
410   void replaceNodes();
411 };
412 
/// Pass-manager independent driver: both the legacy pass and the new
/// pass-manager pass construct one of these and call runOnFunction.
class ComplexDeinterleaving {
public:
  /// Stores the lowering and library info pointers; neither is owned.
  ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
      : TL(tl), TLI(tli) {}

  /// Processes every basic block of \p F and returns true if any block was
  /// changed.
  bool runOnFunction(Function &F);

private:
  /// Builds a graph for \p B and, when the graph validates, performs the
  /// replacement. Returns true if the block was changed.
  bool evaluateBasicBlock(BasicBlock *B);

  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;
};
425 
426 } // namespace
427 
// Storage for the legacy pass's static ID member, which the constructor hands
// to the FunctionPass base class.
char ComplexDeinterleavingLegacyPass::ID = 0;

// Registers the legacy pass under the DEBUG_TYPE name with the description
// "Complex Deinterleaving". No dependencies are listed between BEGIN and END.
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags -- confirm against the macro definition.
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                      "Complex Deinterleaving", false, false)
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                    "Complex Deinterleaving", false, false)
434 
435 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
436                                                  FunctionAnalysisManager &AM) {
437   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
438   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
439   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
440     return PreservedAnalyses::all();
441 
442   PreservedAnalyses PA;
443   PA.preserve<FunctionAnalysisManagerModuleProxy>();
444   return PA;
445 }
446 
/// Factory for the legacy pass-manager version of the pass. Returns a newly
/// allocated ComplexDeinterleavingLegacyPass constructed with \p TM.
FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
  return new ComplexDeinterleavingLegacyPass(TM);
}
450 
451 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
452   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
453   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
454   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
455 }
456 
457 bool ComplexDeinterleaving::runOnFunction(Function &F) {
458   if (!ComplexDeinterleavingEnabled) {
459     LLVM_DEBUG(
460         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
461     return false;
462   }
463 
464   if (!TL->isComplexDeinterleavingSupported()) {
465     LLVM_DEBUG(
466         dbgs() << "Complex deinterleaving has been disabled, target does "
467                   "not support lowering of complex number operations.\n");
468     return false;
469   }
470 
471   bool Changed = false;
472   for (auto &B : F)
473     Changed |= evaluateBasicBlock(&B);
474 
475   return Changed;
476 }
477 
478 static bool isInterleavingMask(ArrayRef<int> Mask) {
479   // If the size is not even, it's not an interleaving mask
480   if ((Mask.size() & 1))
481     return false;
482 
483   int HalfNumElements = Mask.size() / 2;
484   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
485     int MaskIdx = Idx * 2;
486     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
487       return false;
488   }
489 
490   return true;
491 }
492 
493 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
494   int Offset = Mask[0];
495   int HalfNumElements = Mask.size() / 2;
496 
497   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
498     if (Mask[Idx] != (Idx * 2) + Offset)
499       return false;
500   }
501 
502   return true;
503 }
504 
505 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
506   ComplexDeinterleavingGraph Graph(TL, TLI);
507   if (Graph.collectPotentialReductions(B))
508     Graph.identifyReductionNodes();
509 
510   for (auto &I : *B)
511     Graph.identifyNodes(&I);
512 
513   if (Graph.checkNodes()) {
514     Graph.replaceNodes();
515     return true;
516   }
517 
518   return false;
519 }
520 
521 ComplexDeinterleavingGraph::NodePtr
522 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
523     Instruction *Real, Instruction *Imag,
524     std::pair<Value *, Value *> &PartialMatch) {
525   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
526                     << "\n");
527 
528   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
529     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
530     return nullptr;
531   }
532 
533   if (Real->getOpcode() != Instruction::FMul ||
534       Imag->getOpcode() != Instruction::FMul) {
535     LLVM_DEBUG(dbgs() << "  - Real or imaginary instruction is not fmul\n");
536     return nullptr;
537   }
538 
539   Value *R0 = Real->getOperand(0);
540   Value *R1 = Real->getOperand(1);
541   Value *I0 = Imag->getOperand(0);
542   Value *I1 = Imag->getOperand(1);
543 
544   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
545   // rotations and use the operand.
546   unsigned Negs = 0;
547   Value *Op;
548   if (match(R0, m_Neg(m_Value(Op)))) {
549     Negs |= 1;
550     R0 = Op;
551   } else if (match(R1, m_Neg(m_Value(Op)))) {
552     Negs |= 1;
553     R1 = Op;
554   }
555 
556   if (match(I0, m_Neg(m_Value(Op)))) {
557     Negs |= 2;
558     Negs ^= 1;
559     I0 = Op;
560   } else if (match(I1, m_Neg(m_Value(Op)))) {
561     Negs |= 2;
562     Negs ^= 1;
563     I1 = Op;
564   }
565 
566   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
567 
568   Value *CommonOperand;
569   Value *UncommonRealOp;
570   Value *UncommonImagOp;
571 
572   if (R0 == I0 || R0 == I1) {
573     CommonOperand = R0;
574     UncommonRealOp = R1;
575   } else if (R1 == I0 || R1 == I1) {
576     CommonOperand = R1;
577     UncommonRealOp = R0;
578   } else {
579     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
580     return nullptr;
581   }
582 
583   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
584   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
585       Rotation == ComplexDeinterleavingRotation::Rotation_270)
586     std::swap(UncommonRealOp, UncommonImagOp);
587 
588   // Between identifyPartialMul and here we need to have found a complete valid
589   // pair from the CommonOperand of each part.
590   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
591       Rotation == ComplexDeinterleavingRotation::Rotation_180)
592     PartialMatch.first = CommonOperand;
593   else
594     PartialMatch.second = CommonOperand;
595 
596   if (!PartialMatch.first || !PartialMatch.second) {
597     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
598     return nullptr;
599   }
600 
601   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
602   if (!CommonNode) {
603     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
604     return nullptr;
605   }
606 
607   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
608   if (!UncommonNode) {
609     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
610     return nullptr;
611   }
612 
613   NodePtr Node = prepareCompositeNode(
614       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
615   Node->Rotation = Rotation;
616   Node->addOperand(CommonNode);
617   Node->addOperand(UncommonNode);
618   return submitCompositeNode(Node);
619 }
620 
621 ComplexDeinterleavingGraph::NodePtr
622 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
623                                                Instruction *Imag) {
624   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
625                     << "\n");
626   // Determine rotation
627   ComplexDeinterleavingRotation Rotation;
628   if (Real->getOpcode() == Instruction::FAdd &&
629       Imag->getOpcode() == Instruction::FAdd)
630     Rotation = ComplexDeinterleavingRotation::Rotation_0;
631   else if (Real->getOpcode() == Instruction::FSub &&
632            Imag->getOpcode() == Instruction::FAdd)
633     Rotation = ComplexDeinterleavingRotation::Rotation_90;
634   else if (Real->getOpcode() == Instruction::FSub &&
635            Imag->getOpcode() == Instruction::FSub)
636     Rotation = ComplexDeinterleavingRotation::Rotation_180;
637   else if (Real->getOpcode() == Instruction::FAdd &&
638            Imag->getOpcode() == Instruction::FSub)
639     Rotation = ComplexDeinterleavingRotation::Rotation_270;
640   else {
641     LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
642     return nullptr;
643   }
644 
645   if (!Real->getFastMathFlags().allowContract() ||
646       !Imag->getFastMathFlags().allowContract()) {
647     LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
648     return nullptr;
649   }
650 
651   Value *CR = Real->getOperand(0);
652   Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
653   if (!RealMulI)
654     return nullptr;
655   Value *CI = Imag->getOperand(0);
656   Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
657   if (!ImagMulI)
658     return nullptr;
659 
660   if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
661     LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
662     return nullptr;
663   }
664 
665   Value *R0 = RealMulI->getOperand(0);
666   Value *R1 = RealMulI->getOperand(1);
667   Value *I0 = ImagMulI->getOperand(0);
668   Value *I1 = ImagMulI->getOperand(1);
669 
670   Value *CommonOperand;
671   Value *UncommonRealOp;
672   Value *UncommonImagOp;
673 
674   if (R0 == I0 || R0 == I1) {
675     CommonOperand = R0;
676     UncommonRealOp = R1;
677   } else if (R1 == I0 || R1 == I1) {
678     CommonOperand = R1;
679     UncommonRealOp = R0;
680   } else {
681     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
682     return nullptr;
683   }
684 
685   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
686   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
687       Rotation == ComplexDeinterleavingRotation::Rotation_270)
688     std::swap(UncommonRealOp, UncommonImagOp);
689 
690   std::pair<Value *, Value *> PartialMatch(
691       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
692        Rotation == ComplexDeinterleavingRotation::Rotation_180)
693           ? CommonOperand
694           : nullptr,
695       (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
696        Rotation == ComplexDeinterleavingRotation::Rotation_270)
697           ? CommonOperand
698           : nullptr);
699 
700   auto *CRInst = dyn_cast<Instruction>(CR);
701   auto *CIInst = dyn_cast<Instruction>(CI);
702 
703   if (!CRInst || !CIInst) {
704     LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
705     return nullptr;
706   }
707 
708   NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
709   if (!CNode) {
710     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
711     return nullptr;
712   }
713 
714   NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
715   if (!UncommonRes) {
716     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
717     return nullptr;
718   }
719 
720   assert(PartialMatch.first && PartialMatch.second);
721   NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
722   if (!CommonRes) {
723     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
724     return nullptr;
725   }
726 
727   NodePtr Node = prepareCompositeNode(
728       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
729   Node->Rotation = Rotation;
730   Node->addOperand(CommonRes);
731   Node->addOperand(UncommonRes);
732   Node->addOperand(CNode);
733   return submitCompositeNode(Node);
734 }
735 
736 ComplexDeinterleavingGraph::NodePtr
737 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
738   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
739 
740   // Determine rotation
741   ComplexDeinterleavingRotation Rotation;
742   if ((Real->getOpcode() == Instruction::FSub &&
743        Imag->getOpcode() == Instruction::FAdd) ||
744       (Real->getOpcode() == Instruction::Sub &&
745        Imag->getOpcode() == Instruction::Add))
746     Rotation = ComplexDeinterleavingRotation::Rotation_90;
747   else if ((Real->getOpcode() == Instruction::FAdd &&
748             Imag->getOpcode() == Instruction::FSub) ||
749            (Real->getOpcode() == Instruction::Add &&
750             Imag->getOpcode() == Instruction::Sub))
751     Rotation = ComplexDeinterleavingRotation::Rotation_270;
752   else {
753     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
754     return nullptr;
755   }
756 
757   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
758   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
759   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
760   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
761 
762   if (!AR || !AI || !BR || !BI) {
763     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
764     return nullptr;
765   }
766 
767   NodePtr ResA = identifyNode(AR, AI);
768   if (!ResA) {
769     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
770     return nullptr;
771   }
772   NodePtr ResB = identifyNode(BR, BI);
773   if (!ResB) {
774     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
775     return nullptr;
776   }
777 
778   NodePtr Node =
779       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
780   Node->Rotation = Rotation;
781   Node->addOperand(ResA);
782   Node->addOperand(ResB);
783   return submitCompositeNode(Node);
784 }
785 
786 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
787   unsigned OpcA = A->getOpcode();
788   unsigned OpcB = B->getOpcode();
789 
790   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
791          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
792          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
793          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
794 }
795 
796 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
797   auto Pattern =
798       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
799 
800   return match(A, Pattern) && match(B, Pattern);
801 }
802 
803 static bool isInstructionPotentiallySymmetric(Instruction *I) {
804   switch (I->getOpcode()) {
805   case Instruction::FAdd:
806   case Instruction::FSub:
807   case Instruction::FMul:
808   case Instruction::FNeg:
809     return true;
810   default:
811     return false;
812   }
813 }
814 
815 ComplexDeinterleavingGraph::NodePtr
816 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
817                                                        Instruction *Imag) {
818   if (Real->getOpcode() != Imag->getOpcode())
819     return nullptr;
820 
821   if (!isInstructionPotentiallySymmetric(Real) ||
822       !isInstructionPotentiallySymmetric(Imag))
823     return nullptr;
824 
825   auto *R0 = Real->getOperand(0);
826   auto *I0 = Imag->getOperand(0);
827 
828   NodePtr Op0 = identifyNode(R0, I0);
829   NodePtr Op1 = nullptr;
830   if (Op0 == nullptr)
831     return nullptr;
832 
833   if (Real->isBinaryOp()) {
834     auto *R1 = Real->getOperand(1);
835     auto *I1 = Imag->getOperand(1);
836     Op1 = identifyNode(R1, I1);
837     if (Op1 == nullptr)
838       return nullptr;
839   }
840 
841   if (isa<FPMathOperator>(Real) &&
842       Real->getFastMathFlags() != Imag->getFastMathFlags())
843     return nullptr;
844 
845   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
846                                    Real, Imag);
847   Node->Opcode = Real->getOpcode();
848   if (isa<FPMathOperator>(Real))
849     Node->Flags = Real->getFastMathFlags();
850 
851   Node->addOperand(Op0);
852   if (Real->isBinaryOp())
853     Node->addOperand(Op1);
854 
855   return submitCompositeNode(Node);
856 }
857 
858 ComplexDeinterleavingGraph::NodePtr
859 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
860   LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
861   assert(R->getType() == I->getType() &&
862          "Real and imaginary parts should not have different types");
863   if (NodePtr CN = getContainingComposite(R, I)) {
864     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
865     return CN;
866   }
867 
868   auto *Real = dyn_cast<Instruction>(R);
869   auto *Imag = dyn_cast<Instruction>(I);
870   if (!Real || !Imag)
871     return nullptr;
872 
873   if (NodePtr CN = identifyDeinterleave(Real, Imag))
874     return CN;
875 
876   if (NodePtr CN = identifyPHINode(Real, Imag))
877     return CN;
878 
879   if (NodePtr CN = identifySelectNode(Real, Imag))
880     return CN;
881 
882   auto *VTy = cast<VectorType>(Real->getType());
883   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
884 
885   bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
886       ComplexDeinterleavingOperation::CMulPartial, NewVTy);
887   bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
888       ComplexDeinterleavingOperation::CAdd, NewVTy);
889 
890   if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
891     if (NodePtr CN = identifyPartialMul(Real, Imag))
892       return CN;
893   }
894 
895   if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
896     if (NodePtr CN = identifyAdd(Real, Imag))
897       return CN;
898   }
899 
900   if (HasCMulSupport && HasCAddSupport) {
901     if (NodePtr CN = identifyReassocNodes(Real, Imag))
902       return CN;
903   }
904 
905   if (NodePtr CN = identifySymmetricOperation(Real, Imag))
906     return CN;
907 
908   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
909   return nullptr;
910 }
911 
912 ComplexDeinterleavingGraph::NodePtr
913 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
914                                                  Instruction *Imag) {
915 
916   if ((Real->getOpcode() != Instruction::FAdd &&
917        Real->getOpcode() != Instruction::FSub &&
918        Real->getOpcode() != Instruction::FNeg) ||
919       (Imag->getOpcode() != Instruction::FAdd &&
920        Imag->getOpcode() != Instruction::FSub &&
921        Imag->getOpcode() != Instruction::FNeg))
922     return nullptr;
923 
924   if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
925     LLVM_DEBUG(
926         dbgs()
927         << "The flags in Real and Imaginary instructions are not identical\n");
928     return nullptr;
929   }
930 
931   FastMathFlags Flags = Real->getFastMathFlags();
932   if (!Flags.allowReassoc()) {
933     LLVM_DEBUG(
934         dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n");
935     return nullptr;
936   }
937 
938   // Collect multiplications and addend instructions from the given instruction
939   // while traversing it operands. Additionally, verify that all instructions
940   // have the same fast math flags.
941   auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
942                           std::list<Addend> &Addends) -> bool {
943     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
944     SmallPtrSet<Value *, 8> Visited;
945     while (!Worklist.empty()) {
946       auto [V, IsPositive] = Worklist.back();
947       Worklist.pop_back();
948       if (!Visited.insert(V).second)
949         continue;
950 
951       Instruction *I = dyn_cast<Instruction>(V);
952       if (!I) {
953         Addends.emplace_back(V, IsPositive);
954         continue;
955       }
956 
957       // If an instruction has more than one user, it indicates that it either
958       // has an external user, which will be later checked by the checkNodes
959       // function, or it is a subexpression utilized by multiple expressions. In
960       // the latter case, we will attempt to separately identify the complex
961       // operation from here in order to create a shared
962       // ComplexDeinterleavingCompositeNode.
963       if (I != Insn && I->getNumUses() > 1) {
964         LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
965         Addends.emplace_back(I, IsPositive);
966         continue;
967       }
968 
969       if (I->getOpcode() == Instruction::FAdd) {
970         Worklist.emplace_back(I->getOperand(1), IsPositive);
971         Worklist.emplace_back(I->getOperand(0), IsPositive);
972       } else if (I->getOpcode() == Instruction::FSub) {
973         Worklist.emplace_back(I->getOperand(1), !IsPositive);
974         Worklist.emplace_back(I->getOperand(0), IsPositive);
975       } else if (I->getOpcode() == Instruction::FMul) {
976         Value *A, *B;
977         if (match(I->getOperand(0), m_FNeg(m_Value(A)))) {
978           IsPositive = !IsPositive;
979         } else {
980           A = I->getOperand(0);
981         }
982 
983         if (match(I->getOperand(1), m_FNeg(m_Value(B)))) {
984           IsPositive = !IsPositive;
985         } else {
986           B = I->getOperand(1);
987         }
988         Muls.push_back(Product{A, B, IsPositive});
989       } else if (I->getOpcode() == Instruction::FNeg) {
990         Worklist.emplace_back(I->getOperand(0), !IsPositive);
991       } else {
992         Addends.emplace_back(I, IsPositive);
993         continue;
994       }
995 
996       if (I->getFastMathFlags() != Flags) {
997         LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
998                              "inconsistent with the root instructions' flags: "
999                           << *I << "\n");
1000         return false;
1001       }
1002     }
1003     return true;
1004   };
1005 
1006   std::vector<Product> RealMuls, ImagMuls;
1007   std::list<Addend> RealAddends, ImagAddends;
1008   if (!Collect(Real, RealMuls, RealAddends) ||
1009       !Collect(Imag, ImagMuls, ImagAddends))
1010     return nullptr;
1011 
1012   if (RealAddends.size() != ImagAddends.size())
1013     return nullptr;
1014 
1015   NodePtr FinalNode;
1016   if (!RealMuls.empty() || !ImagMuls.empty()) {
1017     // If there are multiplicands, extract positive addend and use it as an
1018     // accumulator
1019     FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1020     FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1021     if (!FinalNode)
1022       return nullptr;
1023   }
1024 
1025   // Identify and process remaining additions
1026   if (!RealAddends.empty() || !ImagAddends.empty()) {
1027     FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1028     if (!FinalNode)
1029       return nullptr;
1030   }
1031   assert(FinalNode && "FinalNode can not be nullptr here");
1032   // Set the Real and Imag fields of the final node and submit it
1033   FinalNode->Real = Real;
1034   FinalNode->Imag = Imag;
1035   submitCompositeNode(FinalNode);
1036   return FinalNode;
1037 }
1038 
1039 bool ComplexDeinterleavingGraph::collectPartialMuls(
1040     const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1041     std::vector<PartialMulCandidate> &PartialMulCandidates) {
1042   // Helper function to extract a common operand from two products
1043   auto FindCommonInstruction = [](const Product &Real,
1044                                   const Product &Imag) -> Value * {
1045     if (Real.Multiplicand == Imag.Multiplicand ||
1046         Real.Multiplicand == Imag.Multiplier)
1047       return Real.Multiplicand;
1048 
1049     if (Real.Multiplier == Imag.Multiplicand ||
1050         Real.Multiplier == Imag.Multiplier)
1051       return Real.Multiplier;
1052 
1053     return nullptr;
1054   };
1055 
1056   // Iterating over real and imaginary multiplications to find common operands
1057   // If a common operand is found, a partial multiplication candidate is created
1058   // and added to the candidates vector The function returns false if no common
1059   // operands are found for any product
1060   for (unsigned i = 0; i < RealMuls.size(); ++i) {
1061     bool FoundCommon = false;
1062     for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1063       auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1064       if (!Common)
1065         continue;
1066 
1067       auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1068                                                    : RealMuls[i].Multiplicand;
1069       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1070                                                    : ImagMuls[j].Multiplicand;
1071 
1072       auto Node = identifyNode(A, B);
1073       if (Node) {
1074         FoundCommon = true;
1075         PartialMulCandidates.push_back({Common, Node, i, j, false});
1076       }
1077 
1078       Node = identifyNode(B, A);
1079       if (Node) {
1080         FoundCommon = true;
1081         PartialMulCandidates.push_back({Common, Node, i, j, true});
1082       }
1083     }
1084     if (!FoundCommon)
1085       return false;
1086   }
1087   return true;
1088 }
1089 
1090 ComplexDeinterleavingGraph::NodePtr
1091 ComplexDeinterleavingGraph::identifyMultiplications(
1092     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1093     NodePtr Accumulator = nullptr) {
1094   if (RealMuls.size() != ImagMuls.size())
1095     return nullptr;
1096 
1097   std::vector<PartialMulCandidate> Info;
1098   if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1099     return nullptr;
1100 
1101   // Map to store common instruction to node pointers
1102   std::map<Value *, NodePtr> CommonToNode;
1103   std::vector<bool> Processed(Info.size(), false);
1104   for (unsigned I = 0; I < Info.size(); ++I) {
1105     if (Processed[I])
1106       continue;
1107 
1108     PartialMulCandidate &InfoA = Info[I];
1109     for (unsigned J = I + 1; J < Info.size(); ++J) {
1110       if (Processed[J])
1111         continue;
1112 
1113       PartialMulCandidate &InfoB = Info[J];
1114       auto *InfoReal = &InfoA;
1115       auto *InfoImag = &InfoB;
1116 
1117       auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1118       if (!NodeFromCommon) {
1119         std::swap(InfoReal, InfoImag);
1120         NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1121       }
1122       if (!NodeFromCommon)
1123         continue;
1124 
1125       CommonToNode[InfoReal->Common] = NodeFromCommon;
1126       CommonToNode[InfoImag->Common] = NodeFromCommon;
1127       Processed[I] = true;
1128       Processed[J] = true;
1129     }
1130   }
1131 
1132   std::vector<bool> ProcessedReal(RealMuls.size(), false);
1133   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1134   NodePtr Result = Accumulator;
1135   for (auto &PMI : Info) {
1136     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1137       continue;
1138 
1139     auto It = CommonToNode.find(PMI.Common);
1140     // TODO: Process independent complex multiplications. Cases like this:
1141     //  A.real() * B where both A and B are complex numbers.
1142     if (It == CommonToNode.end()) {
1143       LLVM_DEBUG({
1144         dbgs() << "Unprocessed independent partial multiplication:\n";
1145         for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1146           dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1147                            << " multiplied by " << *Mul->Multiplicand << "\n";
1148       });
1149       return nullptr;
1150     }
1151 
1152     auto &RealMul = RealMuls[PMI.RealIdx];
1153     auto &ImagMul = ImagMuls[PMI.ImagIdx];
1154 
1155     auto NodeA = It->second;
1156     auto NodeB = PMI.Node;
1157     auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1158     // The following table illustrates the relationship between multiplications
1159     // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1160     // can see:
1161     //
1162     // Rotation |   Real |   Imag |
1163     // ---------+--------+--------+
1164     //        0 |  x * u |  x * v |
1165     //       90 | -y * v |  y * u |
1166     //      180 | -x * u | -x * v |
1167     //      270 |  y * v | -y * u |
1168     //
1169     // Check if the candidate can indeed be represented by partial
1170     // multiplication
1171     // TODO: Add support for multiplication by complex one
1172     if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1173         (!IsMultiplicandReal && !PMI.IsNodeInverted))
1174       continue;
1175 
1176     // Determine the rotation based on the multiplications
1177     ComplexDeinterleavingRotation Rotation;
1178     if (IsMultiplicandReal) {
1179       // Detect 0 and 180 degrees rotation
1180       if (RealMul.IsPositive && ImagMul.IsPositive)
1181         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1182       else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1183         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1184       else
1185         continue;
1186 
1187     } else {
1188       // Detect 90 and 270 degrees rotation
1189       if (!RealMul.IsPositive && ImagMul.IsPositive)
1190         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1191       else if (RealMul.IsPositive && !ImagMul.IsPositive)
1192         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1193       else
1194         continue;
1195     }
1196 
1197     LLVM_DEBUG({
1198       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1199       dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1200       dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1201       dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1202       dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1203       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1204     });
1205 
1206     NodePtr NodeMul = prepareCompositeNode(
1207         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1208     NodeMul->Rotation = Rotation;
1209     NodeMul->addOperand(NodeA);
1210     NodeMul->addOperand(NodeB);
1211     if (Result)
1212       NodeMul->addOperand(Result);
1213     submitCompositeNode(NodeMul);
1214     Result = NodeMul;
1215     ProcessedReal[PMI.RealIdx] = true;
1216     ProcessedImag[PMI.ImagIdx] = true;
1217   }
1218 
1219   // Ensure all products have been processed, if not return nullptr.
1220   if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1221       !all_of(ProcessedImag, [](bool V) { return V; })) {
1222 
1223     // Dump debug information about which partial multiplications are not
1224     // processed.
1225     LLVM_DEBUG({
1226       dbgs() << "Unprocessed products (Real):\n";
1227       for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1228         if (!ProcessedReal[i])
1229           dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1230                            << *RealMuls[i].Multiplier << " multiplied by "
1231                            << *RealMuls[i].Multiplicand << "\n";
1232       }
1233       dbgs() << "Unprocessed products (Imag):\n";
1234       for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1235         if (!ProcessedImag[i])
1236           dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1237                            << *ImagMuls[i].Multiplier << " multiplied by "
1238                            << *ImagMuls[i].Multiplicand << "\n";
1239       }
1240     });
1241     return nullptr;
1242   }
1243 
1244   return Result;
1245 }
1246 
1247 ComplexDeinterleavingGraph::NodePtr
1248 ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends,
1249                                               std::list<Addend> &ImagAddends,
1250                                               FastMathFlags Flags,
1251                                               NodePtr Accumulator = nullptr) {
1252   if (RealAddends.size() != ImagAddends.size())
1253     return nullptr;
1254 
1255   NodePtr Result;
1256   // If we have accumulator use it as first addend
1257   if (Accumulator)
1258     Result = Accumulator;
1259   // Otherwise find an element with both positive real and imaginary parts.
1260   else
1261     Result = extractPositiveAddend(RealAddends, ImagAddends);
1262 
1263   if (!Result)
1264     return nullptr;
1265 
1266   while (!RealAddends.empty()) {
1267     auto ItR = RealAddends.begin();
1268     auto [R, IsPositiveR] = *ItR;
1269 
1270     bool FoundImag = false;
1271     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1272       auto [I, IsPositiveI] = *ItI;
1273       ComplexDeinterleavingRotation Rotation;
1274       if (IsPositiveR && IsPositiveI)
1275         Rotation = ComplexDeinterleavingRotation::Rotation_0;
1276       else if (!IsPositiveR && IsPositiveI)
1277         Rotation = ComplexDeinterleavingRotation::Rotation_90;
1278       else if (!IsPositiveR && !IsPositiveI)
1279         Rotation = ComplexDeinterleavingRotation::Rotation_180;
1280       else
1281         Rotation = ComplexDeinterleavingRotation::Rotation_270;
1282 
1283       NodePtr AddNode;
1284       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1285           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1286         AddNode = identifyNode(R, I);
1287       } else {
1288         AddNode = identifyNode(I, R);
1289       }
1290       if (AddNode) {
1291         LLVM_DEBUG({
1292           dbgs() << "Identified addition:\n";
1293           dbgs().indent(4) << "X: " << *R << "\n";
1294           dbgs().indent(4) << "Y: " << *I << "\n";
1295           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1296         });
1297 
1298         NodePtr TmpNode;
1299         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1300           TmpNode = prepareCompositeNode(
1301               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1302           TmpNode->Opcode = Instruction::FAdd;
1303           TmpNode->Flags = Flags;
1304         } else if (Rotation ==
1305                    llvm::ComplexDeinterleavingRotation::Rotation_180) {
1306           TmpNode = prepareCompositeNode(
1307               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1308           TmpNode->Opcode = Instruction::FSub;
1309           TmpNode->Flags = Flags;
1310         } else {
1311           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1312                                          nullptr, nullptr);
1313           TmpNode->Rotation = Rotation;
1314         }
1315 
1316         TmpNode->addOperand(Result);
1317         TmpNode->addOperand(AddNode);
1318         submitCompositeNode(TmpNode);
1319         Result = TmpNode;
1320         RealAddends.erase(ItR);
1321         ImagAddends.erase(ItI);
1322         FoundImag = true;
1323         break;
1324       }
1325     }
1326     if (!FoundImag)
1327       return nullptr;
1328   }
1329   return Result;
1330 }
1331 
1332 ComplexDeinterleavingGraph::NodePtr
1333 ComplexDeinterleavingGraph::extractPositiveAddend(
1334     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1335   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1336     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1337       auto [R, IsPositiveR] = *ItR;
1338       auto [I, IsPositiveI] = *ItI;
1339       if (IsPositiveR && IsPositiveI) {
1340         auto Result = identifyNode(R, I);
1341         if (Result) {
1342           RealAddends.erase(ItR);
1343           ImagAddends.erase(ItI);
1344           return Result;
1345         }
1346       }
1347     }
1348   }
1349   return nullptr;
1350 }
1351 
1352 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1353   // This potential root instruction might already have been recognized as
1354   // reduction. Because RootToNode maps both Real and Imaginary parts to
1355   // CompositeNode we should choose only one either Real or Imag instruction to
1356   // use as an anchor for generating complex instruction.
1357   auto It = RootToNode.find(RootI);
1358   if (It != RootToNode.end() && It->second->Real == RootI) {
1359     OrderedRoots.push_back(RootI);
1360     return true;
1361   }
1362 
1363   auto RootNode = identifyRoot(RootI);
1364   if (!RootNode)
1365     return false;
1366 
1367   LLVM_DEBUG({
1368     Function *F = RootI->getFunction();
1369     BasicBlock *B = RootI->getParent();
1370     dbgs() << "Complex deinterleaving graph for " << F->getName()
1371            << "::" << B->getName() << ".\n";
1372     dump(dbgs());
1373     dbgs() << "\n";
1374   });
1375   RootToNode[RootI] = RootNode;
1376   OrderedRoots.push_back(RootI);
1377   return true;
1378 }
1379 
1380 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1381   bool FoundPotentialReduction = false;
1382 
1383   auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1384   if (!Br || Br->getNumSuccessors() != 2)
1385     return false;
1386 
1387   // Identify simple one-block loop
1388   if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1389     return false;
1390 
1391   SmallVector<PHINode *> PHIs;
1392   for (auto &PHI : B->phis()) {
1393     if (PHI.getNumIncomingValues() != 2)
1394       continue;
1395 
1396     if (!PHI.getType()->isVectorTy())
1397       continue;
1398 
1399     auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1400     if (!ReductionOp)
1401       continue;
1402 
1403     // Check if final instruction is reduced outside of current block
1404     Instruction *FinalReduction = nullptr;
1405     auto NumUsers = 0u;
1406     for (auto *U : ReductionOp->users()) {
1407       ++NumUsers;
1408       if (U == &PHI)
1409         continue;
1410       FinalReduction = dyn_cast<Instruction>(U);
1411     }
1412 
1413     if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B)
1414       continue;
1415 
1416     ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1417     BackEdge = B;
1418     auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1419     auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1420     Incoming = PHI.getIncomingBlock(IncomingIdx);
1421     FoundPotentialReduction = true;
1422 
1423     // If the initial value of PHINode is an Instruction, consider it a leaf
1424     // value of a complex deinterleaving graph.
1425     if (auto *InitPHI =
1426             dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1427       FinalInstructions.insert(InitPHI);
1428   }
1429   return FoundPotentialReduction;
1430 }
1431 
1432 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1433   SmallVector<bool> Processed(ReductionInfo.size(), false);
1434   SmallVector<Instruction *> OperationInstruction;
1435   for (auto &P : ReductionInfo)
1436     OperationInstruction.push_back(P.first);
1437 
1438   // Identify a complex computation by evaluating two reduction operations that
1439   // potentially could be involved
1440   for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1441     if (Processed[i])
1442       continue;
1443     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1444       if (Processed[j])
1445         continue;
1446 
1447       auto *Real = OperationInstruction[i];
1448       auto *Imag = OperationInstruction[j];
1449       if (Real->getType() != Imag->getType())
1450         continue;
1451 
1452       RealPHI = ReductionInfo[Real].first;
1453       ImagPHI = ReductionInfo[Imag].first;
1454       auto Node = identifyNode(Real, Imag);
1455       if (!Node) {
1456         std::swap(Real, Imag);
1457         std::swap(RealPHI, ImagPHI);
1458         Node = identifyNode(Real, Imag);
1459       }
1460 
1461       // If a node is identified, mark its operation instructions as used to
1462       // prevent re-identification and attach the node to the real part
1463       if (Node) {
1464         LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1465                           << *Real << " / " << *Imag << "\n");
1466         Processed[i] = true;
1467         Processed[j] = true;
1468         auto RootNode = prepareCompositeNode(
1469             ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1470         RootNode->addOperand(Node);
1471         RootToNode[Real] = RootNode;
1472         RootToNode[Imag] = RootNode;
1473         submitCompositeNode(RootNode);
1474         break;
1475       }
1476     }
1477   }
1478 
1479   RealPHI = nullptr;
1480   ImagPHI = nullptr;
1481 }
1482 
1483 bool ComplexDeinterleavingGraph::checkNodes() {
1484   // Collect all instructions from roots to leaves
1485   SmallPtrSet<Instruction *, 16> AllInstructions;
1486   SmallVector<Instruction *, 8> Worklist;
1487   for (auto &Pair : RootToNode)
1488     Worklist.push_back(Pair.first);
1489 
1490   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1491   // chains
1492   while (!Worklist.empty()) {
1493     auto *I = Worklist.back();
1494     Worklist.pop_back();
1495 
1496     if (!AllInstructions.insert(I).second)
1497       continue;
1498 
1499     for (Value *Op : I->operands()) {
1500       if (auto *OpI = dyn_cast<Instruction>(Op)) {
1501         if (!FinalInstructions.count(I))
1502           Worklist.emplace_back(OpI);
1503       }
1504     }
1505   }
1506 
1507   // Find instructions that have users outside of chain
1508   SmallVector<Instruction *, 2> OuterInstructions;
1509   for (auto *I : AllInstructions) {
1510     // Skip root nodes
1511     if (RootToNode.count(I))
1512       continue;
1513 
1514     for (User *U : I->users()) {
1515       if (AllInstructions.count(cast<Instruction>(U)))
1516         continue;
1517 
1518       // Found an instruction that is not used by XCMLA/XCADD chain
1519       Worklist.emplace_back(I);
1520       break;
1521     }
1522   }
1523 
1524   // If any instructions are found to be used outside, find and remove roots
1525   // that somehow connect to those instructions.
1526   SmallPtrSet<Instruction *, 16> Visited;
1527   while (!Worklist.empty()) {
1528     auto *I = Worklist.back();
1529     Worklist.pop_back();
1530     if (!Visited.insert(I).second)
1531       continue;
1532 
1533     // Found an impacted root node. Removing it from the nodes to be
1534     // deinterleaved
1535     if (RootToNode.count(I)) {
1536       LLVM_DEBUG(dbgs() << "Instruction " << *I
1537                         << " could be deinterleaved but its chain of complex "
1538                            "operations have an outside user\n");
1539       RootToNode.erase(I);
1540     }
1541 
1542     if (!AllInstructions.count(I) || FinalInstructions.count(I))
1543       continue;
1544 
1545     for (User *U : I->users())
1546       Worklist.emplace_back(cast<Instruction>(U));
1547 
1548     for (Value *Op : I->operands()) {
1549       if (auto *OpI = dyn_cast<Instruction>(Op))
1550         Worklist.emplace_back(OpI);
1551     }
1552   }
1553   return !RootToNode.empty();
1554 }
1555 
1556 ComplexDeinterleavingGraph::NodePtr
1557 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1558   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1559     if (Intrinsic->getIntrinsicID() !=
1560         Intrinsic::experimental_vector_interleave2)
1561       return nullptr;
1562 
1563     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1564     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1565     if (!Real || !Imag)
1566       return nullptr;
1567 
1568     return identifyNode(Real, Imag);
1569   }
1570 
1571   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1572   if (!SVI)
1573     return nullptr;
1574 
1575   // Look for a shufflevector that takes separate vectors of the real and
1576   // imaginary components and recombines them into a single vector.
1577   if (!isInterleavingMask(SVI->getShuffleMask()))
1578     return nullptr;
1579 
1580   Instruction *Real;
1581   Instruction *Imag;
1582   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1583     return nullptr;
1584 
1585   return identifyNode(Real, Imag);
1586 }
1587 
1588 ComplexDeinterleavingGraph::NodePtr
1589 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1590                                                  Instruction *Imag) {
1591   Instruction *I = nullptr;
1592   Value *FinalValue = nullptr;
1593   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1594       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1595       match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1596                    m_Value(FinalValue)))) {
1597     NodePtr PlaceholderNode = prepareCompositeNode(
1598         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1599     PlaceholderNode->ReplacementNode = FinalValue;
1600     FinalInstructions.insert(Real);
1601     FinalInstructions.insert(Imag);
1602     return submitCompositeNode(PlaceholderNode);
1603   }
1604 
1605   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1606   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1607   if (!RealShuffle || !ImagShuffle) {
1608     if (RealShuffle || ImagShuffle)
1609       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1610     return nullptr;
1611   }
1612 
1613   Value *RealOp1 = RealShuffle->getOperand(1);
1614   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1615     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1616     return nullptr;
1617   }
1618   Value *ImagOp1 = ImagShuffle->getOperand(1);
1619   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1620     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1621     return nullptr;
1622   }
1623 
1624   Value *RealOp0 = RealShuffle->getOperand(0);
1625   Value *ImagOp0 = ImagShuffle->getOperand(0);
1626 
1627   if (RealOp0 != ImagOp0) {
1628     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1629     return nullptr;
1630   }
1631 
1632   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1633   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1634   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1635     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1636     return nullptr;
1637   }
1638 
1639   if (RealMask[0] != 0 || ImagMask[0] != 1) {
1640     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1641     return nullptr;
1642   }
1643 
1644   // Type checking, the shuffle type should be a vector type of the same
1645   // scalar type, but half the size
1646   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1647     Value *Op = Shuffle->getOperand(0);
1648     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1649     auto *OpTy = cast<FixedVectorType>(Op->getType());
1650 
1651     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1652       return false;
1653     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1654       return false;
1655 
1656     return true;
1657   };
1658 
1659   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1660     if (!CheckType(Shuffle))
1661       return false;
1662 
1663     ArrayRef<int> Mask = Shuffle->getShuffleMask();
1664     int Last = *Mask.rbegin();
1665 
1666     Value *Op = Shuffle->getOperand(0);
1667     auto *OpTy = cast<FixedVectorType>(Op->getType());
1668     int NumElements = OpTy->getNumElements();
1669 
1670     // Ensure that the deinterleaving shuffle only pulls from the first
1671     // shuffle operand.
1672     return Last < NumElements;
1673   };
1674 
1675   if (RealShuffle->getType() != ImagShuffle->getType()) {
1676     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1677     return nullptr;
1678   }
1679   if (!CheckDeinterleavingShuffle(RealShuffle)) {
1680     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1681     return nullptr;
1682   }
1683   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1684     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1685     return nullptr;
1686   }
1687 
1688   NodePtr PlaceholderNode =
1689       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1690                            RealShuffle, ImagShuffle);
1691   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1692   FinalInstructions.insert(RealShuffle);
1693   FinalInstructions.insert(ImagShuffle);
1694   return submitCompositeNode(PlaceholderNode);
1695 }
1696 
1697 ComplexDeinterleavingGraph::NodePtr
1698 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1699                                             Instruction *Imag) {
1700   if (Real != RealPHI || Imag != ImagPHI)
1701     return nullptr;
1702 
1703   NodePtr PlaceholderNode = prepareCompositeNode(
1704       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1705   return submitCompositeNode(PlaceholderNode);
1706 }
1707 
1708 ComplexDeinterleavingGraph::NodePtr
1709 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1710                                                Instruction *Imag) {
1711   auto *SelectReal = dyn_cast<SelectInst>(Real);
1712   auto *SelectImag = dyn_cast<SelectInst>(Imag);
1713   if (!SelectReal || !SelectImag)
1714     return nullptr;
1715 
1716   Instruction *MaskA, *MaskB;
1717   Instruction *AR, *AI, *RA, *BI;
1718   if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1719                             m_Instruction(RA))) ||
1720       !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1721                             m_Instruction(BI))))
1722     return nullptr;
1723 
1724   if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1725     return nullptr;
1726 
1727   if (!MaskA->getType()->isVectorTy())
1728     return nullptr;
1729 
1730   auto NodeA = identifyNode(AR, AI);
1731   if (!NodeA)
1732     return nullptr;
1733 
1734   auto NodeB = identifyNode(RA, BI);
1735   if (!NodeB)
1736     return nullptr;
1737 
1738   NodePtr PlaceholderNode = prepareCompositeNode(
1739       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1740   PlaceholderNode->addOperand(NodeA);
1741   PlaceholderNode->addOperand(NodeB);
1742   FinalInstructions.insert(MaskA);
1743   FinalInstructions.insert(MaskB);
1744   return submitCompositeNode(PlaceholderNode);
1745 }
1746 
1747 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1748                                    FastMathFlags Flags, Value *InputA,
1749                                    Value *InputB) {
1750   Value *I;
1751   switch (Opcode) {
1752   case Instruction::FNeg:
1753     I = B.CreateFNeg(InputA);
1754     break;
1755   case Instruction::FAdd:
1756     I = B.CreateFAdd(InputA, InputB);
1757     break;
1758   case Instruction::FSub:
1759     I = B.CreateFSub(InputA, InputB);
1760     break;
1761   case Instruction::FMul:
1762     I = B.CreateFMul(InputA, InputB);
1763     break;
1764   default:
1765     llvm_unreachable("Incorrect symmetric opcode");
1766   }
1767   cast<Instruction>(I)->setFastMathFlags(Flags);
1768   return I;
1769 }
1770 
1771 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1772                                                RawNodePtr Node) {
1773   if (Node->ReplacementNode)
1774     return Node->ReplacementNode;
1775 
1776   auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1777     return Node->Operands.size() > Idx
1778                ? replaceNode(Builder, Node->Operands[Idx])
1779                : nullptr;
1780   };
1781 
1782   Value *ReplacementNode;
1783   switch (Node->Operation) {
1784   case ComplexDeinterleavingOperation::CAdd:
1785   case ComplexDeinterleavingOperation::CMulPartial:
1786   case ComplexDeinterleavingOperation::Symmetric: {
1787     Value *Input0 = ReplaceOperandIfExist(Node, 0);
1788     Value *Input1 = ReplaceOperandIfExist(Node, 1);
1789     Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1790     assert(!Input1 || (Input0->getType() == Input1->getType() &&
1791                        "Node inputs need to be of the same type"));
1792     assert(!Accumulator ||
1793            (Input0->getType() == Accumulator->getType() &&
1794             "Accumulator and input need to be of the same type"));
1795     if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1796       ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1797                                              Input0, Input1);
1798     else
1799       ReplacementNode = TL->createComplexDeinterleavingIR(
1800           Builder, Node->Operation, Node->Rotation, Input0, Input1,
1801           Accumulator);
1802     break;
1803   }
1804   case ComplexDeinterleavingOperation::Deinterleave:
1805     llvm_unreachable("Deinterleave node should already have ReplacementNode");
1806     break;
1807   case ComplexDeinterleavingOperation::ReductionPHI: {
1808     // If Operation is ReductionPHI, a new empty PHINode is created.
1809     // It is filled later when the ReductionOperation is processed.
1810     auto *VTy = cast<VectorType>(Node->Real->getType());
1811     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1812     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
1813     OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1814     ReplacementNode = NewPHI;
1815     break;
1816   }
1817   case ComplexDeinterleavingOperation::ReductionOperation:
1818     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1819     processReductionOperation(ReplacementNode, Node);
1820     break;
1821   case ComplexDeinterleavingOperation::ReductionSelect: {
1822     auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1823     auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1824     auto *A = replaceNode(Builder, Node->Operands[0]);
1825     auto *B = replaceNode(Builder, Node->Operands[1]);
1826     auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1827         cast<VectorType>(MaskReal->getType()));
1828     auto *NewMask =
1829         Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1830                                 NewMaskTy, {MaskReal, MaskImag});
1831     ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1832     break;
1833   }
1834   }
1835 
1836   assert(ReplacementNode && "Target failed to create Intrinsic call.");
1837   NumComplexTransformations += 1;
1838   Node->ReplacementNode = ReplacementNode;
1839   return ReplacementNode;
1840 }
1841 
1842 void ComplexDeinterleavingGraph::processReductionOperation(
1843     Value *OperationReplacement, RawNodePtr Node) {
1844   auto *Real = cast<Instruction>(Node->Real);
1845   auto *Imag = cast<Instruction>(Node->Imag);
1846   auto *OldPHIReal = ReductionInfo[Real].first;
1847   auto *OldPHIImag = ReductionInfo[Imag].first;
1848   auto *NewPHI = OldToNewPHI[OldPHIReal];
1849 
1850   auto *VTy = cast<VectorType>(Real->getType());
1851   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1852 
1853   // We have to interleave initial origin values coming from IncomingBlock
1854   Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
1855   Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
1856 
1857   IRBuilder<> Builder(Incoming->getTerminator());
1858   auto *NewInit = Builder.CreateIntrinsic(
1859       Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
1860 
1861   NewPHI->addIncoming(NewInit, Incoming);
1862   NewPHI->addIncoming(OperationReplacement, BackEdge);
1863 
1864   // Deinterleave complex vector outside of loop so that it can be finally
1865   // reduced
1866   auto *FinalReductionReal = ReductionInfo[Real].second;
1867   auto *FinalReductionImag = ReductionInfo[Imag].second;
1868 
1869   Builder.SetInsertPoint(
1870       &*FinalReductionReal->getParent()->getFirstInsertionPt());
1871   auto *Deinterleave = Builder.CreateIntrinsic(
1872       Intrinsic::experimental_vector_deinterleave2,
1873       OperationReplacement->getType(), OperationReplacement);
1874 
1875   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
1876   FinalReductionReal->replaceUsesOfWith(Real, NewReal);
1877 
1878   Builder.SetInsertPoint(FinalReductionImag);
1879   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
1880   FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
1881 }
1882 
1883 void ComplexDeinterleavingGraph::replaceNodes() {
1884   SmallVector<Instruction *, 16> DeadInstrRoots;
1885   for (auto *RootInstruction : OrderedRoots) {
1886     // Check if this potential root went through check process and we can
1887     // deinterleave it
1888     if (!RootToNode.count(RootInstruction))
1889       continue;
1890 
1891     IRBuilder<> Builder(RootInstruction);
1892     auto RootNode = RootToNode[RootInstruction];
1893     Value *R = replaceNode(Builder, RootNode.get());
1894 
1895     if (RootNode->Operation ==
1896         ComplexDeinterleavingOperation::ReductionOperation) {
1897       auto *RootReal = cast<Instruction>(RootNode->Real);
1898       auto *RootImag = cast<Instruction>(RootNode->Imag);
1899       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
1900       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
1901       DeadInstrRoots.push_back(cast<Instruction>(RootReal));
1902       DeadInstrRoots.push_back(cast<Instruction>(RootImag));
1903     } else {
1904       assert(R && "Unable to find replacement for RootInstruction");
1905       DeadInstrRoots.push_back(RootInstruction);
1906       RootInstruction->replaceAllUsesWith(R);
1907     }
1908   }
1909 
1910   for (auto *I : DeadInstrRoots)
1911     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
1912 }
1913