xref: /llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (revision c78acc92759cda8fc6ecb2690df39c7ee65355bf)
1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
18 // pairs. Validity of each node is expected to be done upon creation, and any
19 // validation errors should halt traversal and prevent further graph
20 // construction.
21 // Instead of relying on Shuffle operations, vector interleaving and
22 // deinterleaving can be represented by vector.interleave2 and
23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24 // these intrinsics, whereas, fixed-width vectors are recognized for both
25 // shufflevector instruction and intrinsics.
26 //
27 // Replacement:
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
33 //
34 //
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
40 // replace it.
41 //
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
49 // as well as any additional instructions that make up the identified operation
50 // (Internal instructions should only have uses within their containing node).
51 // A Node also contains the rotation and operation type that it represents.
52 // Operands contains pointers to other CompositeNodes, acting as the edges in
53 // the graph. ReplacementValue is the transformed Value* that has been emitted
54 // to the IR.
55 //
56 // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57 // ReplacementValue fields of that Node are relevant, where the ReplacementValue
58 // should be pre-populated.
59 //
60 //===----------------------------------------------------------------------===//
61 
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/Statistic.h"
64 #include "llvm/Analysis/TargetLibraryInfo.h"
65 #include "llvm/Analysis/TargetTransformInfo.h"
66 #include "llvm/CodeGen/TargetLowering.h"
67 #include "llvm/CodeGen/TargetPassConfig.h"
68 #include "llvm/CodeGen/TargetSubtargetInfo.h"
69 #include "llvm/IR/IRBuilder.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/InitializePasses.h"
72 #include "llvm/Target/TargetMachine.h"
73 #include "llvm/Transforms/Utils/Local.h"
74 #include <algorithm>
75 
76 using namespace llvm;
77 using namespace PatternMatch;
78 
79 #define DEBUG_TYPE "complex-deinterleaving"
80 
81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
82 
// Hidden command-line switch, enabled by default, that allows the whole pass
// to be turned off. It is consulted at the start of
// ComplexDeinterleaving::runOnFunction.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);
87 
/// Checks the given mask, and determines whether said mask is interleaving.
///
/// To be interleaving, a mask must alternate between `i` and `i + (Length /
/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
/// 4x vector interleaving mask would be <0, 2, 1, 3>). A mask with an odd
/// number of elements is never considered interleaving.
///
/// \param Mask the mask to inspect.
/// \returns true if \p Mask has the interleaving shape described above.
static bool isInterleavingMask(ArrayRef<int> Mask);
94 
/// Checks the given mask, and determines whether said mask is deinterleaving.
///
/// To be deinterleaving, a mask must increment in steps of 2, and either start
/// with 0 or 1.
/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
/// <1, 3, 5, 7>).
///
/// \param Mask the mask to inspect.
/// \returns true if \p Mask is recognised as a deinterleaving mask.
static bool isDeinterleavingMask(ArrayRef<int> Mask);
102 
103 namespace {
104 
105 class ComplexDeinterleavingLegacyPass : public FunctionPass {
106 public:
107   static char ID;
108 
109   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
110       : FunctionPass(ID), TM(TM) {
111     initializeComplexDeinterleavingLegacyPassPass(
112         *PassRegistry::getPassRegistry());
113   }
114 
115   StringRef getPassName() const override {
116     return "Complex Deinterleaving Pass";
117   }
118 
119   bool runOnFunction(Function &F) override;
120   void getAnalysisUsage(AnalysisUsage &AU) const override {
121     AU.addRequired<TargetLibraryInfoWrapperPass>();
122     AU.setPreservesCFG();
123   }
124 
125 private:
126   const TargetMachine *TM;
127 };
128 
129 class ComplexDeinterleavingGraph;
130 struct ComplexDeinterleavingCompositeNode {
131 
132   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
133                                      Instruction *R, Instruction *I)
134       : Operation(Op), Real(R), Imag(I) {}
135 
136 private:
137   friend class ComplexDeinterleavingGraph;
138   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
139   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
140 
141 public:
142   ComplexDeinterleavingOperation Operation;
143   Instruction *Real;
144   Instruction *Imag;
145 
146   // This two members are required exclusively for generating
147   // ComplexDeinterleavingOperation::Symmetric operations.
148   unsigned Opcode;
149   FastMathFlags Flags;
150 
151   ComplexDeinterleavingRotation Rotation =
152       ComplexDeinterleavingRotation::Rotation_0;
153   SmallVector<RawNodePtr> Operands;
154   Value *ReplacementNode = nullptr;
155 
156   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
157 
158   void dump() { dump(dbgs()); }
159   void dump(raw_ostream &OS) {
160     auto PrintValue = [&](Value *V) {
161       if (V) {
162         OS << "\"";
163         V->print(OS, true);
164         OS << "\"\n";
165       } else
166         OS << "nullptr\n";
167     };
168     auto PrintNodeRef = [&](RawNodePtr Ptr) {
169       if (Ptr)
170         OS << Ptr << "\n";
171       else
172         OS << "nullptr\n";
173     };
174 
175     OS << "- CompositeNode: " << this << "\n";
176     OS << "  Real: ";
177     PrintValue(Real);
178     OS << "  Imag: ";
179     PrintValue(Imag);
180     OS << "  ReplacementNode: ";
181     PrintValue(ReplacementNode);
182     OS << "  Operation: " << (int)Operation << "\n";
183     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
184     OS << "  Operands: \n";
185     for (const auto &Op : Operands) {
186       OS << "    - ";
187       PrintNodeRef(Op);
188     }
189   }
190 };
191 
/// Holds the composite nodes identified for one basic block together with the
/// reduction bookkeeping, and exposes the identify / check / replace steps
/// that ComplexDeinterleaving::evaluateBasicBlock drives.
class ComplexDeinterleavingGraph {
public:
  /// One multiplication term: the two factors and the sign of the product.
  struct Product {
    Instruction *Multiplier;
    Instruction *Multiplicand;
    bool IsPositive;
  };

  /// An addend paired with a boolean flag.
  /// NOTE(review): the flag presumably means "is positive", mirroring
  /// Product::IsPositive -- confirm against the users of Addend.
  using Addend = std::pair<Instruction *, bool>;
  using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
  using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;

  // Helper struct for holding info about potential partial multiplication
  // candidates
  struct PartialMulCandidate {
    Instruction *Common;
    NodePtr Node;
    unsigned RealIdx;
    unsigned ImagIdx;
    bool IsNodeInverted;
  };

  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
                                      const TargetLibraryInfo *TLI)
      : TL(TL), TLI(TLI) {}

private:
  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;

  /// Every node handed to submitCompositeNode, in submission order.
  SmallVector<NodePtr> CompositeNodes;

  SmallPtrSet<Instruction *, 16> FinalInstructions;

  /// Root instructions are instructions from which complex computation starts
  std::map<Instruction *, NodePtr> RootToNode;

  /// Topologically sorted root instructions
  SmallVector<Instruction *, 1> OrderedRoots;

  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
  BasicBlock *BackEdge = nullptr;
  BasicBlock *Incoming = nullptr;

  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
  /// %OutsideUser as it is shown in the IR:
  ///
  /// vector.body:
  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
  ///                                [ %ReductionOp, %vector.body ]
  ///   ...
  ///   %ReductionOp = fadd i64 ...
  ///   ...
  ///   br i1 %condition, label %vector.body, %middle.block
  ///
  /// middle.block:
  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
  ///
  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
  std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;

  /// In the process of detecting a reduction, we consider a pair of
  /// %ReductionOP, which we refer to as real and imag (or vice versa), and
  /// traverse the use-tree to detect complex operations. As this is a reduction
  /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
  /// to the %ReductionOPs that we suspect to be complex.
  /// RealPHI and ImagPHI are used by the identifyPHINode method.
  PHINode *RealPHI = nullptr;
  PHINode *ImagPHI = nullptr;

  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
  /// This mapping is populated during
  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
  /// replacement process.
  std::map<PHINode *, PHINode *> OldToNewPHI;

  /// Allocate a new node for (\p R, \p I) with the given \p Operation. The
  /// node is not yet registered in CompositeNodes; see submitCompositeNode.
  /// Reduction-related operations assert that both parts are non-null.
  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                               Instruction *R, Instruction *I) {
    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
            (R && I)) &&
           "Reduction related nodes must have Real and Imaginary parts");
    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
                                                                I);
  }

  /// Register \p Node in CompositeNodes and return it unchanged.
  NodePtr submitCompositeNode(NodePtr Node) {
    CompositeNodes.push_back(Node);
    return Node;
  }

  /// Linear search of CompositeNodes for a node whose Real and Imag are
  /// exactly \p R and \p I; returns nullptr when there is none.
  NodePtr getContainingComposite(Value *R, Value *I) {
    for (const auto &CN : CompositeNodes) {
      if (CN->Real == R && CN->Imag == I)
        return CN;
    }
    return nullptr;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  NodePtr identifyNodeWithImplicitAdd(
      Instruction *I, Instruction *J,
      std::pair<Instruction *, Instruction *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
  NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);

  NodePtr identifyNode(Instruction *I, Instruction *J);

  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
  NodePtr identifyAdditions(std::list<Addend> &RealAddends,
                            std::list<Addend> &ImagAddends, FastMathFlags Flags,
                            NodePtr Accumulator);

  /// Extract one addend that has both real and imaginary parts positive.
  NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
                                std::list<Addend> &ImagAddends);

  /// Determine if sum of multiplications of complex numbers can be formed from
  /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
  /// to it. Return nullptr if it is not possible to construct a complex number.
  NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
                                  std::vector<Product> &ImagMuls,
                                  NodePtr Accumulator);

  /// Go through pairs of multiplication (one Real and one Imag) and find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a pair with a common
  /// operand.
  bool collectPartialMuls(const std::vector<Product> &RealMuls,
                          const std::vector<Product> &ImagMuls,
                          std::vector<PartialMulCandidate> &Candidates);

  /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
  /// the order of complex computation operations may be significantly altered,
  /// and the real and imaginary parts may not be executed in parallel. This
  /// function takes this into consideration and employs a more general approach
  /// to identify complex computations. Initially, it gathers all the addends
  /// and multiplicands and then constructs a complex expression from them.
  NodePtr identifyReassocNodes(Instruction *I, Instruction *J);

  NodePtr identifyRoot(Instruction *I);

  /// Identifies the Deinterleave operation applied to a vector containing
  /// complex numbers. There are two ways to represent the Deinterleave
  /// operation:
  /// * Using two shufflevectors with even indices for \p Real instruction and
  /// odd indices for \p Imag instructions (only for fixed-width vectors)
  /// * Using two extractvalue instructions applied to `vector.deinterleave2`
  /// intrinsic (for both fixed and scalable vectors)
  NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);

  NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);

  /// Identifies SelectInsts in a loop that has reduction with predication masks
  /// and/or predicated tail folding
  NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);

  Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);

  /// Complete IR modifications after producing new reduction operation:
  /// * Populate the PHINode generated for
  /// ComplexDeinterleavingOperation::ReductionPHI
  /// * Deinterleave the final value outside of the loop and repurpose original
  /// reduction users
  void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);

public:
  /// Dump every submitted composite node to the debug stream.
  void dump() { dump(dbgs()); }

  /// Dump every submitted composite node to \p OS, in submission order.
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
  bool identifyNodes(Instruction *RootI);

  /// In case \p B is one-block loop, this function seeks potential reductions
  /// and populates ReductionInfo. Returns true if any reductions were
  /// identified.
  bool collectPotentialReductions(BasicBlock *B);

  void identifyReductionNodes();

  /// Check that every instruction, from the roots to the leaves, has internal
  /// uses.
  bool checkNodes();

  /// Perform the actual replacement of the underlying instruction graph.
  void replaceNodes();
};
412 
413 class ComplexDeinterleaving {
414 public:
415   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
416       : TL(tl), TLI(tli) {}
417   bool runOnFunction(Function &F);
418 
419 private:
420   bool evaluateBasicBlock(BasicBlock *B);
421 
422   const TargetLowering *TL = nullptr;
423   const TargetLibraryInfo *TLI = nullptr;
424 };
425 
426 } // namespace
427 
// Definition of the static ID member that the legacy pass hands to the
// FunctionPass constructor.
char ComplexDeinterleavingLegacyPass::ID = 0;

// Registration of the legacy pass under the DEBUG_TYPE name
// ("complex-deinterleaving").
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                      "Complex Deinterleaving", false, false)
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                    "Complex Deinterleaving", false, false)
434 
435 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
436                                                  FunctionAnalysisManager &AM) {
437   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
438   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
439   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
440     return PreservedAnalyses::all();
441 
442   PreservedAnalyses PA;
443   PA.preserve<FunctionAnalysisManagerModuleProxy>();
444   return PA;
445 }
446 
/// Factory for the legacy pass-manager version of the pass.
///
/// \param TM target machine forwarded to the pass constructor.
/// \returns a newly allocated ComplexDeinterleavingLegacyPass.
FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
  return new ComplexDeinterleavingLegacyPass(TM);
}
450 
451 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
452   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
453   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
454   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
455 }
456 
457 bool ComplexDeinterleaving::runOnFunction(Function &F) {
458   if (!ComplexDeinterleavingEnabled) {
459     LLVM_DEBUG(
460         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
461     return false;
462   }
463 
464   if (!TL->isComplexDeinterleavingSupported()) {
465     LLVM_DEBUG(
466         dbgs() << "Complex deinterleaving has been disabled, target does "
467                   "not support lowering of complex number operations.\n");
468     return false;
469   }
470 
471   bool Changed = false;
472   for (auto &B : F)
473     Changed |= evaluateBasicBlock(&B);
474 
475   return Changed;
476 }
477 
478 static bool isInterleavingMask(ArrayRef<int> Mask) {
479   // If the size is not even, it's not an interleaving mask
480   if ((Mask.size() & 1))
481     return false;
482 
483   int HalfNumElements = Mask.size() / 2;
484   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
485     int MaskIdx = Idx * 2;
486     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
487       return false;
488   }
489 
490   return true;
491 }
492 
493 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
494   int Offset = Mask[0];
495   int HalfNumElements = Mask.size() / 2;
496 
497   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
498     if (Mask[Idx] != (Idx * 2) + Offset)
499       return false;
500   }
501 
502   return true;
503 }
504 
505 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
506   ComplexDeinterleavingGraph Graph(TL, TLI);
507   if (Graph.collectPotentialReductions(B))
508     Graph.identifyReductionNodes();
509 
510   for (auto &I : *B)
511     Graph.identifyNodes(&I);
512 
513   if (Graph.checkNodes()) {
514     Graph.replaceNodes();
515     return true;
516   }
517 
518   return false;
519 }
520 
521 ComplexDeinterleavingGraph::NodePtr
522 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
523     Instruction *Real, Instruction *Imag,
524     std::pair<Instruction *, Instruction *> &PartialMatch) {
525   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
526                     << "\n");
527 
528   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
529     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
530     return nullptr;
531   }
532 
533   if (Real->getOpcode() != Instruction::FMul ||
534       Imag->getOpcode() != Instruction::FMul) {
535     LLVM_DEBUG(dbgs() << "  - Real or imaginary instruction is not fmul\n");
536     return nullptr;
537   }
538 
539   Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
540   Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
541   Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
542   Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
543   if (!R0 || !R1 || !I0 || !I1) {
544     LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
545     return nullptr;
546   }
547 
548   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
549   // rotations and use the operand.
550   unsigned Negs = 0;
551   SmallVector<Instruction *> FNegs;
552   if (R0->getOpcode() == Instruction::FNeg ||
553       R1->getOpcode() == Instruction::FNeg) {
554     Negs |= 1;
555     if (R0->getOpcode() == Instruction::FNeg) {
556       FNegs.push_back(R0);
557       R0 = dyn_cast<Instruction>(R0->getOperand(0));
558     } else {
559       FNegs.push_back(R1);
560       R1 = dyn_cast<Instruction>(R1->getOperand(0));
561     }
562     if (!R0 || !R1)
563       return nullptr;
564   }
565   if (I0->getOpcode() == Instruction::FNeg ||
566       I1->getOpcode() == Instruction::FNeg) {
567     Negs |= 2;
568     Negs ^= 1;
569     if (I0->getOpcode() == Instruction::FNeg) {
570       FNegs.push_back(I0);
571       I0 = dyn_cast<Instruction>(I0->getOperand(0));
572     } else {
573       FNegs.push_back(I1);
574       I1 = dyn_cast<Instruction>(I1->getOperand(0));
575     }
576     if (!I0 || !I1)
577       return nullptr;
578   }
579 
580   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
581 
582   Instruction *CommonOperand;
583   Instruction *UncommonRealOp;
584   Instruction *UncommonImagOp;
585 
586   if (R0 == I0 || R0 == I1) {
587     CommonOperand = R0;
588     UncommonRealOp = R1;
589   } else if (R1 == I0 || R1 == I1) {
590     CommonOperand = R1;
591     UncommonRealOp = R0;
592   } else {
593     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
594     return nullptr;
595   }
596 
597   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
598   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
599       Rotation == ComplexDeinterleavingRotation::Rotation_270)
600     std::swap(UncommonRealOp, UncommonImagOp);
601 
602   // Between identifyPartialMul and here we need to have found a complete valid
603   // pair from the CommonOperand of each part.
604   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
605       Rotation == ComplexDeinterleavingRotation::Rotation_180)
606     PartialMatch.first = CommonOperand;
607   else
608     PartialMatch.second = CommonOperand;
609 
610   if (!PartialMatch.first || !PartialMatch.second) {
611     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
612     return nullptr;
613   }
614 
615   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
616   if (!CommonNode) {
617     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
618     return nullptr;
619   }
620 
621   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
622   if (!UncommonNode) {
623     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
624     return nullptr;
625   }
626 
627   NodePtr Node = prepareCompositeNode(
628       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
629   Node->Rotation = Rotation;
630   Node->addOperand(CommonNode);
631   Node->addOperand(UncommonNode);
632   return submitCompositeNode(Node);
633 }
634 
/// Try to match \p Real and \p Imag as the outer fadd/fsub pair of a partial
/// complex multiply with an accumulator.
///
/// The rotation is chosen from the fadd/fsub combination of the two
/// instructions, and both must carry the `contract` fast-math flag. Operand 1
/// of each instruction must be a single-use instruction whose two operands are
/// instructions and which shares one operand with its counterpart. Operand 0
/// of each instruction is handed to identifyNodeWithImplicitAdd, which also
/// completes the pair of common operands. On success a CMulPartial node with
/// the operands (CommonRes, UncommonRes, CNode) is submitted and returned.
///
/// \returns the submitted node, or nullptr if the pattern does not match.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                                               Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
                    << "\n");
  // Determine rotation
  ComplexDeinterleavingRotation Rotation;
  if (Real->getOpcode() == Instruction::FAdd &&
      Imag->getOpcode() == Instruction::FAdd)
    Rotation = ComplexDeinterleavingRotation::Rotation_0;
  else if (Real->getOpcode() == Instruction::FSub &&
           Imag->getOpcode() == Instruction::FAdd)
    Rotation = ComplexDeinterleavingRotation::Rotation_90;
  else if (Real->getOpcode() == Instruction::FSub &&
           Imag->getOpcode() == Instruction::FSub)
    Rotation = ComplexDeinterleavingRotation::Rotation_180;
  else if (Real->getOpcode() == Instruction::FAdd &&
           Imag->getOpcode() == Instruction::FSub)
    Rotation = ComplexDeinterleavingRotation::Rotation_270;
  else {
    LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
    return nullptr;
  }

  // Both halves must allow contraction.
  if (!Real->getFastMathFlags().allowContract() ||
      !Imag->getFastMathFlags().allowContract()) {
    LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
    return nullptr;
  }

  // Operand 0 is treated as the accumulator part (CR / CI), operand 1 as the
  // multiplication.
  Value *CR = Real->getOperand(0);
  Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
  if (!RealMulI)
    return nullptr;
  Value *CI = Imag->getOperand(0);
  Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!ImagMulI)
    return nullptr;

  if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
    return nullptr;
  }

  Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
  Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
  Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
  Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
  if (!R0 || !R1 || !I0 || !I1) {
    LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
    return nullptr;
  }

  Instruction *CommonOperand;
  Instruction *UncommonRealOp;
  Instruction *UncommonImagOp;

  // Find the operand that both multiplications share.
  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  // For the 90 and 270 rotations the remaining operands are exchanged.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Seed the pair of common operands: rotations 0/180 provide the first
  // element, rotations 90/270 the second. The missing element is filled in by
  // identifyNodeWithImplicitAdd below.
  std::pair<Instruction *, Instruction *> PartialMatch(
      (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_180)
          ? CommonOperand
          : nullptr,
      (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_270)
          ? CommonOperand
          : nullptr);

  auto *CRInst = dyn_cast<Instruction>(CR);
  auto *CIInst = dyn_cast<Instruction>(CI);

  if (!CRInst || !CIInst) {
    LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
    return nullptr;
  }

  NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
  if (!CNode) {
    LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
    return nullptr;
  }

  NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonRes) {
    LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
    return nullptr;
  }

  // A non-null CNode implies identifyNodeWithImplicitAdd saw both halves of
  // PartialMatch populated.
  assert(PartialMatch.first && PartialMatch.second);
  NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonRes) {
    LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
    return nullptr;
  }

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonRes);
  Node->addOperand(UncommonRes);
  Node->addOperand(CNode);
  return submitCompositeNode(Node);
}
753 
754 ComplexDeinterleavingGraph::NodePtr
755 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
756   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
757 
758   // Determine rotation
759   ComplexDeinterleavingRotation Rotation;
760   if ((Real->getOpcode() == Instruction::FSub &&
761        Imag->getOpcode() == Instruction::FAdd) ||
762       (Real->getOpcode() == Instruction::Sub &&
763        Imag->getOpcode() == Instruction::Add))
764     Rotation = ComplexDeinterleavingRotation::Rotation_90;
765   else if ((Real->getOpcode() == Instruction::FAdd &&
766             Imag->getOpcode() == Instruction::FSub) ||
767            (Real->getOpcode() == Instruction::Add &&
768             Imag->getOpcode() == Instruction::Sub))
769     Rotation = ComplexDeinterleavingRotation::Rotation_270;
770   else {
771     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
772     return nullptr;
773   }
774 
775   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
776   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
777   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
778   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
779 
780   if (!AR || !AI || !BR || !BI) {
781     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
782     return nullptr;
783   }
784 
785   NodePtr ResA = identifyNode(AR, AI);
786   if (!ResA) {
787     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
788     return nullptr;
789   }
790   NodePtr ResB = identifyNode(BR, BI);
791   if (!ResB) {
792     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
793     return nullptr;
794   }
795 
796   NodePtr Node =
797       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
798   Node->Rotation = Rotation;
799   Node->addOperand(ResA);
800   Node->addOperand(ResB);
801   return submitCompositeNode(Node);
802 }
803 
804 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
805   unsigned OpcA = A->getOpcode();
806   unsigned OpcB = B->getOpcode();
807 
808   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
809          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
810          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
811          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
812 }
813 
814 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
815   auto Pattern =
816       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
817 
818   return match(A, Pattern) && match(B, Pattern);
819 }
820 
821 static bool isInstructionPotentiallySymmetric(Instruction *I) {
822   switch (I->getOpcode()) {
823   case Instruction::FAdd:
824   case Instruction::FSub:
825   case Instruction::FMul:
826   case Instruction::FNeg:
827     return true;
828   default:
829     return false;
830   }
831 }
832 
833 ComplexDeinterleavingGraph::NodePtr
834 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
835                                                        Instruction *Imag) {
836   if (Real->getOpcode() != Imag->getOpcode())
837     return nullptr;
838 
839   if (!isInstructionPotentiallySymmetric(Real) ||
840       !isInstructionPotentiallySymmetric(Imag))
841     return nullptr;
842 
843   auto *R0 = dyn_cast<Instruction>(Real->getOperand(0));
844   auto *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
845 
846   if (!R0 || !I0)
847     return nullptr;
848 
849   NodePtr Op0 = identifyNode(R0, I0);
850   NodePtr Op1 = nullptr;
851   if (Op0 == nullptr)
852     return nullptr;
853 
854   if (Real->isBinaryOp()) {
855     auto *R1 = dyn_cast<Instruction>(Real->getOperand(1));
856     auto *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
857     if (!R1 || !I1)
858       return nullptr;
859 
860     Op1 = identifyNode(R1, I1);
861     if (Op1 == nullptr)
862       return nullptr;
863   }
864 
865   if (isa<FPMathOperator>(Real) &&
866       Real->getFastMathFlags() != Imag->getFastMathFlags())
867     return nullptr;
868 
869   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
870                                    Real, Imag);
871   Node->Opcode = Real->getOpcode();
872   if (isa<FPMathOperator>(Real))
873     Node->Flags = Real->getFastMathFlags();
874 
875   Node->addOperand(Op0);
876   if (Real->isBinaryOp())
877     Node->addOperand(Op1);
878 
879   return submitCompositeNode(Node);
880 }
881 
882 ComplexDeinterleavingGraph::NodePtr
883 ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
884   LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
885   assert(Real->getType() == Imag->getType() &&
886          "Real and imaginary parts should not have different types");
887   if (NodePtr CN = getContainingComposite(Real, Imag)) {
888     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
889     return CN;
890   }
891 
892   if (NodePtr CN = identifyDeinterleave(Real, Imag))
893     return CN;
894 
895   if (NodePtr CN = identifyPHINode(Real, Imag))
896     return CN;
897 
898   if (NodePtr CN = identifySelectNode(Real, Imag))
899     return CN;
900 
901   auto *VTy = cast<VectorType>(Real->getType());
902   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
903 
904   bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
905       ComplexDeinterleavingOperation::CMulPartial, NewVTy);
906   bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
907       ComplexDeinterleavingOperation::CAdd, NewVTy);
908 
909   if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
910     if (NodePtr CN = identifyPartialMul(Real, Imag))
911       return CN;
912   }
913 
914   if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
915     if (NodePtr CN = identifyAdd(Real, Imag))
916       return CN;
917   }
918 
919   if (HasCMulSupport && HasCAddSupport) {
920     if (NodePtr CN = identifyReassocNodes(Real, Imag))
921       return CN;
922   }
923 
924   if (NodePtr CN = identifySymmetricOperation(Real, Imag))
925     return CN;
926 
927   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
928   return nullptr;
929 }
930 
931 ComplexDeinterleavingGraph::NodePtr
932 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
933                                                  Instruction *Imag) {
934   if ((Real->getOpcode() != Instruction::FAdd &&
935        Real->getOpcode() != Instruction::FSub &&
936        Real->getOpcode() != Instruction::FNeg) ||
937       (Imag->getOpcode() != Instruction::FAdd &&
938        Imag->getOpcode() != Instruction::FSub &&
939        Imag->getOpcode() != Instruction::FNeg))
940     return nullptr;
941 
942   if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
943     LLVM_DEBUG(
944         dbgs()
945         << "The flags in Real and Imaginary instructions are not identical\n");
946     return nullptr;
947   }
948 
949   FastMathFlags Flags = Real->getFastMathFlags();
950   if (!Flags.allowReassoc()) {
951     LLVM_DEBUG(
952         dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n");
953     return nullptr;
954   }
955 
956   // Collect multiplications and addend instructions from the given instruction
957   // while traversing it operands. Additionally, verify that all instructions
958   // have the same fast math flags.
959   auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
960                           std::list<Addend> &Addends) -> bool {
961     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
962     SmallPtrSet<Value *, 8> Visited;
963     while (!Worklist.empty()) {
964       auto [V, IsPositive] = Worklist.back();
965       Worklist.pop_back();
966       if (!Visited.insert(V).second)
967         continue;
968 
969       Instruction *I = dyn_cast<Instruction>(V);
970       if (!I)
971         return false;
972 
973       // If an instruction has more than one user, it indicates that it either
974       // has an external user, which will be later checked by the checkNodes
975       // function, or it is a subexpression utilized by multiple expressions. In
976       // the latter case, we will attempt to separately identify the complex
977       // operation from here in order to create a shared
978       // ComplexDeinterleavingCompositeNode.
979       if (I != Insn && I->getNumUses() > 1) {
980         LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
981         Addends.emplace_back(I, IsPositive);
982         continue;
983       }
984 
985       if (I->getOpcode() == Instruction::FAdd) {
986         Worklist.emplace_back(I->getOperand(1), IsPositive);
987         Worklist.emplace_back(I->getOperand(0), IsPositive);
988       } else if (I->getOpcode() == Instruction::FSub) {
989         Worklist.emplace_back(I->getOperand(1), !IsPositive);
990         Worklist.emplace_back(I->getOperand(0), IsPositive);
991       } else if (I->getOpcode() == Instruction::FMul) {
992         auto *A = dyn_cast<Instruction>(I->getOperand(0));
993         if (A && A->getOpcode() == Instruction::FNeg) {
994           A = dyn_cast<Instruction>(A->getOperand(0));
995           IsPositive = !IsPositive;
996         }
997         if (!A)
998           return false;
999         auto *B = dyn_cast<Instruction>(I->getOperand(1));
1000         if (B && B->getOpcode() == Instruction::FNeg) {
1001           B = dyn_cast<Instruction>(B->getOperand(0));
1002           IsPositive = !IsPositive;
1003         }
1004         if (!B)
1005           return false;
1006         Muls.push_back(Product{A, B, IsPositive});
1007       } else if (I->getOpcode() == Instruction::FNeg) {
1008         Worklist.emplace_back(I->getOperand(0), !IsPositive);
1009       } else {
1010         Addends.emplace_back(I, IsPositive);
1011         continue;
1012       }
1013 
1014       if (I->getFastMathFlags() != Flags) {
1015         LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1016                              "inconsistent with the root instructions' flags: "
1017                           << *I << "\n");
1018         return false;
1019       }
1020     }
1021     return true;
1022   };
1023 
1024   std::vector<Product> RealMuls, ImagMuls;
1025   std::list<Addend> RealAddends, ImagAddends;
1026   if (!Collect(Real, RealMuls, RealAddends) ||
1027       !Collect(Imag, ImagMuls, ImagAddends))
1028     return nullptr;
1029 
1030   if (RealAddends.size() != ImagAddends.size())
1031     return nullptr;
1032 
1033   NodePtr FinalNode;
1034   if (!RealMuls.empty() || !ImagMuls.empty()) {
1035     // If there are multiplicands, extract positive addend and use it as an
1036     // accumulator
1037     FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1038     FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1039     if (!FinalNode)
1040       return nullptr;
1041   }
1042 
1043   // Identify and process remaining additions
1044   if (!RealAddends.empty() || !ImagAddends.empty()) {
1045     FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1046     if (!FinalNode)
1047       return nullptr;
1048   }
1049   assert(FinalNode && "FinalNode can not be nullptr here");
1050   // Set the Real and Imag fields of the final node and submit it
1051   FinalNode->Real = Real;
1052   FinalNode->Imag = Imag;
1053   submitCompositeNode(FinalNode);
1054   return FinalNode;
1055 }
1056 
1057 bool ComplexDeinterleavingGraph::collectPartialMuls(
1058     const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1059     std::vector<PartialMulCandidate> &PartialMulCandidates) {
1060   // Helper function to extract a common operand from two products
1061   auto FindCommonInstruction = [](const Product &Real,
1062                                   const Product &Imag) -> Instruction * {
1063     if (Real.Multiplicand == Imag.Multiplicand ||
1064         Real.Multiplicand == Imag.Multiplier)
1065       return Real.Multiplicand;
1066 
1067     if (Real.Multiplier == Imag.Multiplicand ||
1068         Real.Multiplier == Imag.Multiplier)
1069       return Real.Multiplier;
1070 
1071     return nullptr;
1072   };
1073 
1074   // Iterating over real and imaginary multiplications to find common operands
1075   // If a common operand is found, a partial multiplication candidate is created
1076   // and added to the candidates vector The function returns false if no common
1077   // operands are found for any product
1078   for (unsigned i = 0; i < RealMuls.size(); ++i) {
1079     bool FoundCommon = false;
1080     for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1081       auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1082       if (!Common)
1083         continue;
1084 
1085       auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1086                                                    : RealMuls[i].Multiplicand;
1087       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1088                                                    : ImagMuls[j].Multiplicand;
1089 
1090       bool Inverted = false;
1091       auto Node = identifyNode(A, B);
1092       if (!Node) {
1093         std::swap(A, B);
1094         Inverted = true;
1095         Node = identifyNode(A, B);
1096       }
1097       if (!Node)
1098         continue;
1099 
1100       FoundCommon = true;
1101       PartialMulCandidates.push_back({Common, Node, i, j, Inverted});
1102     }
1103     if (!FoundCommon)
1104       return false;
1105   }
1106   return true;
1107 }
1108 
1109 ComplexDeinterleavingGraph::NodePtr
1110 ComplexDeinterleavingGraph::identifyMultiplications(
1111     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1112     NodePtr Accumulator = nullptr) {
1113   if (RealMuls.size() != ImagMuls.size())
1114     return nullptr;
1115 
1116   std::vector<PartialMulCandidate> Info;
1117   if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1118     return nullptr;
1119 
1120   // Map to store common instruction to node pointers
1121   std::map<Instruction *, NodePtr> CommonToNode;
1122   std::vector<bool> Processed(Info.size(), false);
1123   for (unsigned I = 0; I < Info.size(); ++I) {
1124     if (Processed[I])
1125       continue;
1126 
1127     PartialMulCandidate &InfoA = Info[I];
1128     for (unsigned J = I + 1; J < Info.size(); ++J) {
1129       if (Processed[J])
1130         continue;
1131 
1132       PartialMulCandidate &InfoB = Info[J];
1133       auto *InfoReal = &InfoA;
1134       auto *InfoImag = &InfoB;
1135 
1136       auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1137       if (!NodeFromCommon) {
1138         std::swap(InfoReal, InfoImag);
1139         NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1140       }
1141       if (!NodeFromCommon)
1142         continue;
1143 
1144       CommonToNode[InfoReal->Common] = NodeFromCommon;
1145       CommonToNode[InfoImag->Common] = NodeFromCommon;
1146       Processed[I] = true;
1147       Processed[J] = true;
1148     }
1149   }
1150 
1151   std::vector<bool> ProcessedReal(RealMuls.size(), false);
1152   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1153   NodePtr Result = Accumulator;
1154   for (auto &PMI : Info) {
1155     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1156       continue;
1157 
1158     auto It = CommonToNode.find(PMI.Common);
1159     // TODO: Process independent complex multiplications. Cases like this:
1160     //  A.real() * B where both A and B are complex numbers.
1161     if (It == CommonToNode.end()) {
1162       LLVM_DEBUG({
1163         dbgs() << "Unprocessed independent partial multiplication:\n";
1164         for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1165           dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1166                            << " multiplied by " << *Mul->Multiplicand << "\n";
1167       });
1168       return nullptr;
1169     }
1170 
1171     auto &RealMul = RealMuls[PMI.RealIdx];
1172     auto &ImagMul = ImagMuls[PMI.ImagIdx];
1173 
1174     auto NodeA = It->second;
1175     auto NodeB = PMI.Node;
1176     auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1177     // The following table illustrates the relationship between multiplications
1178     // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1179     // can see:
1180     //
1181     // Rotation |   Real |   Imag |
1182     // ---------+--------+--------+
1183     //        0 |  x * u |  x * v |
1184     //       90 | -y * v |  y * u |
1185     //      180 | -x * u | -x * v |
1186     //      270 |  y * v | -y * u |
1187     //
1188     // Check if the candidate can indeed be represented by partial
1189     // multiplication
1190     // TODO: Add support for multiplication by complex one
1191     if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1192         (!IsMultiplicandReal && !PMI.IsNodeInverted))
1193       continue;
1194 
1195     // Determine the rotation based on the multiplications
1196     ComplexDeinterleavingRotation Rotation;
1197     if (IsMultiplicandReal) {
1198       // Detect 0 and 180 degrees rotation
1199       if (RealMul.IsPositive && ImagMul.IsPositive)
1200         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1201       else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1202         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1203       else
1204         continue;
1205 
1206     } else {
1207       // Detect 90 and 270 degrees rotation
1208       if (!RealMul.IsPositive && ImagMul.IsPositive)
1209         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1210       else if (RealMul.IsPositive && !ImagMul.IsPositive)
1211         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1212       else
1213         continue;
1214     }
1215 
1216     LLVM_DEBUG({
1217       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1218       dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1219       dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1220       dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1221       dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1222       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1223     });
1224 
1225     NodePtr NodeMul = prepareCompositeNode(
1226         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1227     NodeMul->Rotation = Rotation;
1228     NodeMul->addOperand(NodeA);
1229     NodeMul->addOperand(NodeB);
1230     if (Result)
1231       NodeMul->addOperand(Result);
1232     submitCompositeNode(NodeMul);
1233     Result = NodeMul;
1234     ProcessedReal[PMI.RealIdx] = true;
1235     ProcessedImag[PMI.ImagIdx] = true;
1236   }
1237 
1238   // Ensure all products have been processed, if not return nullptr.
1239   if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1240       !all_of(ProcessedImag, [](bool V) { return V; })) {
1241 
1242     // Dump debug information about which partial multiplications are not
1243     // processed.
1244     LLVM_DEBUG({
1245       dbgs() << "Unprocessed products (Real):\n";
1246       for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1247         if (!ProcessedReal[i])
1248           dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1249                            << *RealMuls[i].Multiplier << " multiplied by "
1250                            << *RealMuls[i].Multiplicand << "\n";
1251       }
1252       dbgs() << "Unprocessed products (Imag):\n";
1253       for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1254         if (!ProcessedImag[i])
1255           dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1256                            << *ImagMuls[i].Multiplier << " multiplied by "
1257                            << *ImagMuls[i].Multiplicand << "\n";
1258       }
1259     });
1260     return nullptr;
1261   }
1262 
1263   return Result;
1264 }
1265 
1266 ComplexDeinterleavingGraph::NodePtr
1267 ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends,
1268                                               std::list<Addend> &ImagAddends,
1269                                               FastMathFlags Flags,
1270                                               NodePtr Accumulator = nullptr) {
1271   if (RealAddends.size() != ImagAddends.size())
1272     return nullptr;
1273 
1274   NodePtr Result;
1275   // If we have accumulator use it as first addend
1276   if (Accumulator)
1277     Result = Accumulator;
1278   // Otherwise find an element with both positive real and imaginary parts.
1279   else
1280     Result = extractPositiveAddend(RealAddends, ImagAddends);
1281 
1282   if (!Result)
1283     return nullptr;
1284 
1285   while (!RealAddends.empty()) {
1286     auto ItR = RealAddends.begin();
1287     auto [R, IsPositiveR] = *ItR;
1288 
1289     bool FoundImag = false;
1290     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1291       auto [I, IsPositiveI] = *ItI;
1292       ComplexDeinterleavingRotation Rotation;
1293       if (IsPositiveR && IsPositiveI)
1294         Rotation = ComplexDeinterleavingRotation::Rotation_0;
1295       else if (!IsPositiveR && IsPositiveI)
1296         Rotation = ComplexDeinterleavingRotation::Rotation_90;
1297       else if (!IsPositiveR && !IsPositiveI)
1298         Rotation = ComplexDeinterleavingRotation::Rotation_180;
1299       else
1300         Rotation = ComplexDeinterleavingRotation::Rotation_270;
1301 
1302       NodePtr AddNode;
1303       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1304           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1305         AddNode = identifyNode(R, I);
1306       } else {
1307         AddNode = identifyNode(I, R);
1308       }
1309       if (AddNode) {
1310         LLVM_DEBUG({
1311           dbgs() << "Identified addition:\n";
1312           dbgs().indent(4) << "X: " << *R << "\n";
1313           dbgs().indent(4) << "Y: " << *I << "\n";
1314           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1315         });
1316 
1317         NodePtr TmpNode;
1318         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1319           TmpNode = prepareCompositeNode(
1320               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1321           TmpNode->Opcode = Instruction::FAdd;
1322           TmpNode->Flags = Flags;
1323         } else if (Rotation ==
1324                    llvm::ComplexDeinterleavingRotation::Rotation_180) {
1325           TmpNode = prepareCompositeNode(
1326               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1327           TmpNode->Opcode = Instruction::FSub;
1328           TmpNode->Flags = Flags;
1329         } else {
1330           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1331                                          nullptr, nullptr);
1332           TmpNode->Rotation = Rotation;
1333         }
1334 
1335         TmpNode->addOperand(Result);
1336         TmpNode->addOperand(AddNode);
1337         submitCompositeNode(TmpNode);
1338         Result = TmpNode;
1339         RealAddends.erase(ItR);
1340         ImagAddends.erase(ItI);
1341         FoundImag = true;
1342         break;
1343       }
1344     }
1345     if (!FoundImag)
1346       return nullptr;
1347   }
1348   return Result;
1349 }
1350 
1351 ComplexDeinterleavingGraph::NodePtr
1352 ComplexDeinterleavingGraph::extractPositiveAddend(
1353     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1354   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1355     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1356       auto [R, IsPositiveR] = *ItR;
1357       auto [I, IsPositiveI] = *ItI;
1358       if (IsPositiveR && IsPositiveI) {
1359         auto Result = identifyNode(R, I);
1360         if (Result) {
1361           RealAddends.erase(ItR);
1362           ImagAddends.erase(ItI);
1363           return Result;
1364         }
1365       }
1366     }
1367   }
1368   return nullptr;
1369 }
1370 
1371 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1372   // This potential root instruction might already have been recognized as
1373   // reduction. Because RootToNode maps both Real and Imaginary parts to
1374   // CompositeNode we should choose only one either Real or Imag instruction to
1375   // use as an anchor for generating complex instruction.
1376   auto It = RootToNode.find(RootI);
1377   if (It != RootToNode.end() && It->second->Real == RootI) {
1378     OrderedRoots.push_back(RootI);
1379     return true;
1380   }
1381 
1382   auto RootNode = identifyRoot(RootI);
1383   if (!RootNode)
1384     return false;
1385 
1386   LLVM_DEBUG({
1387     Function *F = RootI->getFunction();
1388     BasicBlock *B = RootI->getParent();
1389     dbgs() << "Complex deinterleaving graph for " << F->getName()
1390            << "::" << B->getName() << ".\n";
1391     dump(dbgs());
1392     dbgs() << "\n";
1393   });
1394   RootToNode[RootI] = RootNode;
1395   OrderedRoots.push_back(RootI);
1396   return true;
1397 }
1398 
1399 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1400   bool FoundPotentialReduction = false;
1401 
1402   auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1403   if (!Br || Br->getNumSuccessors() != 2)
1404     return false;
1405 
1406   // Identify simple one-block loop
1407   if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1408     return false;
1409 
1410   SmallVector<PHINode *> PHIs;
1411   for (auto &PHI : B->phis()) {
1412     if (PHI.getNumIncomingValues() != 2)
1413       continue;
1414 
1415     if (!PHI.getType()->isVectorTy())
1416       continue;
1417 
1418     auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1419     if (!ReductionOp)
1420       continue;
1421 
1422     // Check if final instruction is reduced outside of current block
1423     Instruction *FinalReduction = nullptr;
1424     auto NumUsers = 0u;
1425     for (auto *U : ReductionOp->users()) {
1426       ++NumUsers;
1427       if (U == &PHI)
1428         continue;
1429       FinalReduction = dyn_cast<Instruction>(U);
1430     }
1431 
1432     if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B)
1433       continue;
1434 
1435     ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1436     BackEdge = B;
1437     auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1438     auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1439     Incoming = PHI.getIncomingBlock(IncomingIdx);
1440     FoundPotentialReduction = true;
1441 
1442     // If the initial value of PHINode is an Instruction, consider it a leaf
1443     // value of a complex deinterleaving graph.
1444     if (auto *InitPHI =
1445             dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1446       FinalInstructions.insert(InitPHI);
1447   }
1448   return FoundPotentialReduction;
1449 }
1450 
1451 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1452   SmallVector<bool> Processed(ReductionInfo.size(), false);
1453   SmallVector<Instruction *> OperationInstruction;
1454   for (auto &P : ReductionInfo)
1455     OperationInstruction.push_back(P.first);
1456 
1457   // Identify a complex computation by evaluating two reduction operations that
1458   // potentially could be involved
1459   for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1460     if (Processed[i])
1461       continue;
1462     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1463       if (Processed[j])
1464         continue;
1465 
1466       auto *Real = OperationInstruction[i];
1467       auto *Imag = OperationInstruction[j];
1468       if (Real->getType() != Imag->getType())
1469         continue;
1470 
1471       RealPHI = ReductionInfo[Real].first;
1472       ImagPHI = ReductionInfo[Imag].first;
1473       auto Node = identifyNode(Real, Imag);
1474       if (!Node) {
1475         std::swap(Real, Imag);
1476         std::swap(RealPHI, ImagPHI);
1477         Node = identifyNode(Real, Imag);
1478       }
1479 
1480       // If a node is identified, mark its operation instructions as used to
1481       // prevent re-identification and attach the node to the real part
1482       if (Node) {
1483         LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1484                           << *Real << " / " << *Imag << "\n");
1485         Processed[i] = true;
1486         Processed[j] = true;
1487         auto RootNode = prepareCompositeNode(
1488             ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1489         RootNode->addOperand(Node);
1490         RootToNode[Real] = RootNode;
1491         RootToNode[Imag] = RootNode;
1492         submitCompositeNode(RootNode);
1493         break;
1494       }
1495     }
1496   }
1497 
1498   RealPHI = nullptr;
1499   ImagPHI = nullptr;
1500 }
1501 
1502 bool ComplexDeinterleavingGraph::checkNodes() {
1503   // Collect all instructions from roots to leaves
1504   SmallPtrSet<Instruction *, 16> AllInstructions;
1505   SmallVector<Instruction *, 8> Worklist;
1506   for (auto &Pair : RootToNode)
1507     Worklist.push_back(Pair.first);
1508 
1509   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1510   // chains
1511   while (!Worklist.empty()) {
1512     auto *I = Worklist.back();
1513     Worklist.pop_back();
1514 
1515     if (!AllInstructions.insert(I).second)
1516       continue;
1517 
1518     for (Value *Op : I->operands()) {
1519       if (auto *OpI = dyn_cast<Instruction>(Op)) {
1520         if (!FinalInstructions.count(I))
1521           Worklist.emplace_back(OpI);
1522       }
1523     }
1524   }
1525 
1526   // Find instructions that have users outside of chain
1527   SmallVector<Instruction *, 2> OuterInstructions;
1528   for (auto *I : AllInstructions) {
1529     // Skip root nodes
1530     if (RootToNode.count(I))
1531       continue;
1532 
1533     for (User *U : I->users()) {
1534       if (AllInstructions.count(cast<Instruction>(U)))
1535         continue;
1536 
1537       // Found an instruction that is not used by XCMLA/XCADD chain
1538       Worklist.emplace_back(I);
1539       break;
1540     }
1541   }
1542 
1543   // If any instructions are found to be used outside, find and remove roots
1544   // that somehow connect to those instructions.
1545   SmallPtrSet<Instruction *, 16> Visited;
1546   while (!Worklist.empty()) {
1547     auto *I = Worklist.back();
1548     Worklist.pop_back();
1549     if (!Visited.insert(I).second)
1550       continue;
1551 
1552     // Found an impacted root node. Removing it from the nodes to be
1553     // deinterleaved
1554     if (RootToNode.count(I)) {
1555       LLVM_DEBUG(dbgs() << "Instruction " << *I
1556                         << " could be deinterleaved but its chain of complex "
1557                            "operations have an outside user\n");
1558       RootToNode.erase(I);
1559     }
1560 
1561     if (!AllInstructions.count(I) || FinalInstructions.count(I))
1562       continue;
1563 
1564     for (User *U : I->users())
1565       Worklist.emplace_back(cast<Instruction>(U));
1566 
1567     for (Value *Op : I->operands()) {
1568       if (auto *OpI = dyn_cast<Instruction>(Op))
1569         Worklist.emplace_back(OpI);
1570     }
1571   }
1572   return !RootToNode.empty();
1573 }
1574 
1575 ComplexDeinterleavingGraph::NodePtr
1576 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1577   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1578     if (Intrinsic->getIntrinsicID() !=
1579         Intrinsic::experimental_vector_interleave2)
1580       return nullptr;
1581 
1582     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1583     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1584     if (!Real || !Imag)
1585       return nullptr;
1586 
1587     return identifyNode(Real, Imag);
1588   }
1589 
1590   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1591   if (!SVI)
1592     return nullptr;
1593 
1594   // Look for a shufflevector that takes separate vectors of the real and
1595   // imaginary components and recombines them into a single vector.
1596   if (!isInterleavingMask(SVI->getShuffleMask()))
1597     return nullptr;
1598 
1599   Instruction *Real;
1600   Instruction *Imag;
1601   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1602     return nullptr;
1603 
1604   return identifyNode(Real, Imag);
1605 }
1606 
1607 ComplexDeinterleavingGraph::NodePtr
1608 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1609                                                  Instruction *Imag) {
1610   Instruction *I = nullptr;
1611   Value *FinalValue = nullptr;
1612   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1613       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1614       match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1615                    m_Value(FinalValue)))) {
1616     NodePtr PlaceholderNode = prepareCompositeNode(
1617         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1618     PlaceholderNode->ReplacementNode = FinalValue;
1619     FinalInstructions.insert(Real);
1620     FinalInstructions.insert(Imag);
1621     return submitCompositeNode(PlaceholderNode);
1622   }
1623 
1624   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1625   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1626   if (!RealShuffle || !ImagShuffle) {
1627     if (RealShuffle || ImagShuffle)
1628       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1629     return nullptr;
1630   }
1631 
1632   Value *RealOp1 = RealShuffle->getOperand(1);
1633   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1634     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1635     return nullptr;
1636   }
1637   Value *ImagOp1 = ImagShuffle->getOperand(1);
1638   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1639     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1640     return nullptr;
1641   }
1642 
1643   Value *RealOp0 = RealShuffle->getOperand(0);
1644   Value *ImagOp0 = ImagShuffle->getOperand(0);
1645 
1646   if (RealOp0 != ImagOp0) {
1647     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1648     return nullptr;
1649   }
1650 
1651   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1652   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1653   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1654     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1655     return nullptr;
1656   }
1657 
1658   if (RealMask[0] != 0 || ImagMask[0] != 1) {
1659     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1660     return nullptr;
1661   }
1662 
1663   // Type checking, the shuffle type should be a vector type of the same
1664   // scalar type, but half the size
1665   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1666     Value *Op = Shuffle->getOperand(0);
1667     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1668     auto *OpTy = cast<FixedVectorType>(Op->getType());
1669 
1670     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1671       return false;
1672     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1673       return false;
1674 
1675     return true;
1676   };
1677 
1678   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1679     if (!CheckType(Shuffle))
1680       return false;
1681 
1682     ArrayRef<int> Mask = Shuffle->getShuffleMask();
1683     int Last = *Mask.rbegin();
1684 
1685     Value *Op = Shuffle->getOperand(0);
1686     auto *OpTy = cast<FixedVectorType>(Op->getType());
1687     int NumElements = OpTy->getNumElements();
1688 
1689     // Ensure that the deinterleaving shuffle only pulls from the first
1690     // shuffle operand.
1691     return Last < NumElements;
1692   };
1693 
1694   if (RealShuffle->getType() != ImagShuffle->getType()) {
1695     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1696     return nullptr;
1697   }
1698   if (!CheckDeinterleavingShuffle(RealShuffle)) {
1699     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1700     return nullptr;
1701   }
1702   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1703     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1704     return nullptr;
1705   }
1706 
1707   NodePtr PlaceholderNode =
1708       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1709                            RealShuffle, ImagShuffle);
1710   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1711   FinalInstructions.insert(RealShuffle);
1712   FinalInstructions.insert(ImagShuffle);
1713   return submitCompositeNode(PlaceholderNode);
1714 }
1715 
1716 ComplexDeinterleavingGraph::NodePtr
1717 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1718                                             Instruction *Imag) {
1719   if (Real != RealPHI || Imag != ImagPHI)
1720     return nullptr;
1721 
1722   NodePtr PlaceholderNode = prepareCompositeNode(
1723       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1724   return submitCompositeNode(PlaceholderNode);
1725 }
1726 
1727 ComplexDeinterleavingGraph::NodePtr
1728 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1729                                                Instruction *Imag) {
1730   auto *SelectReal = dyn_cast<SelectInst>(Real);
1731   auto *SelectImag = dyn_cast<SelectInst>(Imag);
1732   if (!SelectReal || !SelectImag)
1733     return nullptr;
1734 
1735   Instruction *MaskA, *MaskB;
1736   Instruction *AR, *AI, *RA, *BI;
1737   if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1738                             m_Instruction(RA))) ||
1739       !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1740                             m_Instruction(BI))))
1741     return nullptr;
1742 
1743   if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1744     return nullptr;
1745 
1746   if (!MaskA->getType()->isVectorTy())
1747     return nullptr;
1748 
1749   auto NodeA = identifyNode(AR, AI);
1750   if (!NodeA)
1751     return nullptr;
1752 
1753   auto NodeB = identifyNode(RA, BI);
1754   if (!NodeB)
1755     return nullptr;
1756 
1757   NodePtr PlaceholderNode = prepareCompositeNode(
1758       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1759   PlaceholderNode->addOperand(NodeA);
1760   PlaceholderNode->addOperand(NodeB);
1761   FinalInstructions.insert(MaskA);
1762   FinalInstructions.insert(MaskB);
1763   return submitCompositeNode(PlaceholderNode);
1764 }
1765 
1766 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1767                                    FastMathFlags Flags, Value *InputA,
1768                                    Value *InputB) {
1769   Value *I;
1770   switch (Opcode) {
1771   case Instruction::FNeg:
1772     I = B.CreateFNeg(InputA);
1773     break;
1774   case Instruction::FAdd:
1775     I = B.CreateFAdd(InputA, InputB);
1776     break;
1777   case Instruction::FSub:
1778     I = B.CreateFSub(InputA, InputB);
1779     break;
1780   case Instruction::FMul:
1781     I = B.CreateFMul(InputA, InputB);
1782     break;
1783   default:
1784     llvm_unreachable("Incorrect symmetric opcode");
1785   }
1786   cast<Instruction>(I)->setFastMathFlags(Flags);
1787   return I;
1788 }
1789 
1790 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1791                                                RawNodePtr Node) {
1792   if (Node->ReplacementNode)
1793     return Node->ReplacementNode;
1794 
1795   auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1796     return Node->Operands.size() > Idx
1797                ? replaceNode(Builder, Node->Operands[Idx])
1798                : nullptr;
1799   };
1800 
1801   Value *ReplacementNode;
1802   switch (Node->Operation) {
1803   case ComplexDeinterleavingOperation::CAdd:
1804   case ComplexDeinterleavingOperation::CMulPartial:
1805   case ComplexDeinterleavingOperation::Symmetric: {
1806     Value *Input0 = ReplaceOperandIfExist(Node, 0);
1807     Value *Input1 = ReplaceOperandIfExist(Node, 1);
1808     Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1809     assert(!Input1 || (Input0->getType() == Input1->getType() &&
1810                        "Node inputs need to be of the same type"));
1811     assert(!Accumulator ||
1812            (Input0->getType() == Accumulator->getType() &&
1813             "Accumulator and input need to be of the same type"));
1814     if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1815       ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1816                                              Input0, Input1);
1817     else
1818       ReplacementNode = TL->createComplexDeinterleavingIR(
1819           Builder, Node->Operation, Node->Rotation, Input0, Input1,
1820           Accumulator);
1821     break;
1822   }
1823   case ComplexDeinterleavingOperation::Deinterleave:
1824     llvm_unreachable("Deinterleave node should already have ReplacementNode");
1825     break;
1826   case ComplexDeinterleavingOperation::ReductionPHI: {
1827     // If Operation is ReductionPHI, a new empty PHINode is created.
1828     // It is filled later when the ReductionOperation is processed.
1829     auto *VTy = cast<VectorType>(Node->Real->getType());
1830     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1831     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
1832     OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1833     ReplacementNode = NewPHI;
1834     break;
1835   }
1836   case ComplexDeinterleavingOperation::ReductionOperation:
1837     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1838     processReductionOperation(ReplacementNode, Node);
1839     break;
1840   case ComplexDeinterleavingOperation::ReductionSelect: {
1841     auto *MaskReal = Node->Real->getOperand(0);
1842     auto *MaskImag = Node->Imag->getOperand(0);
1843     auto *A = replaceNode(Builder, Node->Operands[0]);
1844     auto *B = replaceNode(Builder, Node->Operands[1]);
1845     auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1846         cast<VectorType>(MaskReal->getType()));
1847     auto *NewMask =
1848         Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1849                                 NewMaskTy, {MaskReal, MaskImag});
1850     ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1851     break;
1852   }
1853   }
1854 
1855   assert(ReplacementNode && "Target failed to create Intrinsic call.");
1856   NumComplexTransformations += 1;
1857   Node->ReplacementNode = ReplacementNode;
1858   return ReplacementNode;
1859 }
1860 
1861 void ComplexDeinterleavingGraph::processReductionOperation(
1862     Value *OperationReplacement, RawNodePtr Node) {
1863   auto *OldPHIReal = ReductionInfo[Node->Real].first;
1864   auto *OldPHIImag = ReductionInfo[Node->Imag].first;
1865   auto *NewPHI = OldToNewPHI[OldPHIReal];
1866 
1867   auto *VTy = cast<VectorType>(Node->Real->getType());
1868   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1869 
1870   // We have to interleave initial origin values coming from IncomingBlock
1871   Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
1872   Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
1873 
1874   IRBuilder<> Builder(Incoming->getTerminator());
1875   auto *NewInit = Builder.CreateIntrinsic(
1876       Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
1877 
1878   NewPHI->addIncoming(NewInit, Incoming);
1879   NewPHI->addIncoming(OperationReplacement, BackEdge);
1880 
1881   // Deinterleave complex vector outside of loop so that it can be finally
1882   // reduced
1883   auto *FinalReductionReal = ReductionInfo[Node->Real].second;
1884   auto *FinalReductionImag = ReductionInfo[Node->Imag].second;
1885 
1886   Builder.SetInsertPoint(
1887       &*FinalReductionReal->getParent()->getFirstInsertionPt());
1888   auto *Deinterleave = Builder.CreateIntrinsic(
1889       Intrinsic::experimental_vector_deinterleave2,
1890       OperationReplacement->getType(), OperationReplacement);
1891 
1892   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
1893   FinalReductionReal->replaceUsesOfWith(Node->Real, NewReal);
1894 
1895   Builder.SetInsertPoint(FinalReductionImag);
1896   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
1897   FinalReductionImag->replaceUsesOfWith(Node->Imag, NewImag);
1898 }
1899 
1900 void ComplexDeinterleavingGraph::replaceNodes() {
1901   SmallVector<Instruction *, 16> DeadInstrRoots;
1902   for (auto *RootInstruction : OrderedRoots) {
1903     // Check if this potential root went through check process and we can
1904     // deinterleave it
1905     if (!RootToNode.count(RootInstruction))
1906       continue;
1907 
1908     IRBuilder<> Builder(RootInstruction);
1909     auto RootNode = RootToNode[RootInstruction];
1910     Value *R = replaceNode(Builder, RootNode.get());
1911 
1912     if (RootNode->Operation ==
1913         ComplexDeinterleavingOperation::ReductionOperation) {
1914       ReductionInfo[RootNode->Real].first->removeIncomingValue(BackEdge);
1915       ReductionInfo[RootNode->Imag].first->removeIncomingValue(BackEdge);
1916       DeadInstrRoots.push_back(RootNode->Real);
1917       DeadInstrRoots.push_back(RootNode->Imag);
1918     } else {
1919       assert(R && "Unable to find replacement for RootInstruction");
1920       DeadInstrRoots.push_back(RootInstruction);
1921       RootInstruction->replaceAllUsesWith(R);
1922     }
1923   }
1924 
1925   for (auto *I : DeadInstrRoots)
1926     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
1927 }
1928