xref: /llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (revision 1b2943534fa20a61c07592a9bd90203e682ae0f4)
1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
18 // pairs. Validity of each node is expected to be done upon creation, and any
19 // validation errors should halt traversal and prevent further graph
20 // construction.
21 // Instead of relying on Shuffle operations, vector interleaving and
22 // deinterleaving can be represented by vector.interleave2 and
23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24 // these intrinsics, whereas, fixed-width vectors are recognized for both
25 // shufflevector instruction and intrinsics.
26 //
27 // Replacement:
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
33 //
34 //
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
40 // replace it.
41 //
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
48 // Each node maintains a reference  to its Real and Imaginary instructions,
49 // as well as any additional instructions that make up the identified operation
50 // (Internal instructions should only have uses within their containing node).
51 // A Node also contains the rotation and operation type that it represents.
52 // Operands contains pointers to other CompositeNodes, acting as the edges in
53 // the graph. ReplacementValue is the transformed Value* that has been emitted
54 // to the IR.
55 //
56 // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57 // ReplacementValue fields of that Node are relevant, where the ReplacementValue
58 // should be pre-populated.
59 //
60 //===----------------------------------------------------------------------===//
61 
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/MapVector.h"
64 #include "llvm/ADT/Statistic.h"
65 #include "llvm/Analysis/TargetLibraryInfo.h"
66 #include "llvm/Analysis/TargetTransformInfo.h"
67 #include "llvm/CodeGen/TargetLowering.h"
68 #include "llvm/CodeGen/TargetSubtargetInfo.h"
69 #include "llvm/IR/IRBuilder.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/InitializePasses.h"
72 #include "llvm/Target/TargetMachine.h"
73 #include "llvm/Transforms/Utils/Local.h"
74 #include <algorithm>
75 
76 using namespace llvm;
77 using namespace PatternMatch;
78 
79 #define DEBUG_TYPE "complex-deinterleaving"
80 
// Statistic counter, described as the number of complex patterns transformed.
STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

// Hidden command-line switch, on by default. When it is off,
// ComplexDeinterleaving::runOnFunction returns false before looking at any
// basic block.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);
87 
88 /// Checks the given mask, and determines whether said mask is interleaving.
89 ///
90 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
93 static bool isInterleavingMask(ArrayRef<int> Mask);
94 
95 /// Checks the given mask, and determines whether said mask is deinterleaving.
96 ///
97 /// To be deinterleaving, a mask must increment in steps of 2, and either start
98 /// with 0 or 1.
99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
100 /// <1, 3, 5, 7>).
101 static bool isDeinterleavingMask(ArrayRef<int> Mask);
102 
103 /// Returns true if the operation is a negation of V, and it works for both
104 /// integers and floats.
105 static bool isNeg(Value *V);
106 
107 /// Returns the operand for negation operation.
108 static Value *getNegOperand(Value *V);
109 
110 namespace {
111 template <typename T, typename IterT>
112 std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
113   auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
114   if (Common != A.end())
115     return std::make_optional(*Common);
116   return std::nullopt;
117 }
118 
/// Legacy pass manager wrapper around ComplexDeinterleaving. It keeps the
/// TargetMachine pointer it was constructed with and uses it in
/// runOnFunction to obtain the TargetLowering for the function.
class ComplexDeinterleavingLegacyPass : public FunctionPass {
public:
  static char ID;

  /// Stores \p TM and registers the pass with the global PassRegistry.
  /// \p TM defaults to nullptr, but runOnFunction dereferences it
  /// unconditionally.
  ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {
    initializeComplexDeinterleavingLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  StringRef getPassName() const override {
    return "Complex Deinterleaving Pass";
  }

  bool runOnFunction(Function &F) override;
  /// Requires TargetLibraryInfoWrapperPass and declares that the CFG is
  /// preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.setPreservesCFG();
  }

private:
  const TargetMachine *TM;
};
142 
143 class ComplexDeinterleavingGraph;
144 struct ComplexDeinterleavingCompositeNode {
145 
146   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
147                                      Value *R, Value *I)
148       : Operation(Op), Real(R), Imag(I) {}
149 
150 private:
151   friend class ComplexDeinterleavingGraph;
152   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
153   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
154   bool OperandsValid = true;
155 
156 public:
157   ComplexDeinterleavingOperation Operation;
158   Value *Real;
159   Value *Imag;
160 
161   // This two members are required exclusively for generating
162   // ComplexDeinterleavingOperation::Symmetric operations.
163   unsigned Opcode;
164   std::optional<FastMathFlags> Flags;
165 
166   ComplexDeinterleavingRotation Rotation =
167       ComplexDeinterleavingRotation::Rotation_0;
168   SmallVector<RawNodePtr> Operands;
169   Value *ReplacementNode = nullptr;
170 
171   void addOperand(NodePtr Node) {
172     if (!Node || !Node.get())
173       OperandsValid = false;
174     Operands.push_back(Node.get());
175   }
176 
177   void dump() { dump(dbgs()); }
178   void dump(raw_ostream &OS) {
179     auto PrintValue = [&](Value *V) {
180       if (V) {
181         OS << "\"";
182         V->print(OS, true);
183         OS << "\"\n";
184       } else
185         OS << "nullptr\n";
186     };
187     auto PrintNodeRef = [&](RawNodePtr Ptr) {
188       if (Ptr)
189         OS << Ptr << "\n";
190       else
191         OS << "nullptr\n";
192     };
193 
194     OS << "- CompositeNode: " << this << "\n";
195     OS << "  Real: ";
196     PrintValue(Real);
197     OS << "  Imag: ";
198     PrintValue(Imag);
199     OS << "  ReplacementNode: ";
200     PrintValue(ReplacementNode);
201     OS << "  Operation: " << (int)Operation << "\n";
202     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
203     OS << "  Operands: \n";
204     for (const auto &Op : Operands) {
205       OS << "    - ";
206       PrintNodeRef(Op);
207     }
208   }
209 
210   bool areOperandsValid() { return OperandsValid; }
211 };
212 
/// Holds the composite nodes discovered for one basic block together with the
/// bookkeeping (roots, reduction info, PHI mapping) used while identifying and
/// replacing them.
class ComplexDeinterleavingGraph {
public:
  /// One multiplication term together with its sign.
  struct Product {
    Value *Multiplier;
    Value *Multiplicand;
    bool IsPositive;
  };

  using Addend = std::pair<Value *, bool>;
  using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
  using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;

  // Helper struct for holding info about potential partial multiplication
  // candidates
  struct PartialMulCandidate {
    Value *Common;
    NodePtr Node;
    unsigned RealIdx;
    unsigned ImagIdx;
    bool IsNodeInverted;
  };

  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
                                      const TargetLibraryInfo *TLI)
      : TL(TL), TLI(TLI) {}

private:
  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;
  /// Every node passed to submitCompositeNode, in submission order.
  SmallVector<NodePtr> CompositeNodes;
  /// Nodes keyed by their (Real, Imag) pair; filled by submitCompositeNode
  /// for nodes whose Real value is non-null.
  DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;

  SmallPtrSet<Instruction *, 16> FinalInstructions;

  /// Root instructions are instructions from which complex computation starts
  std::map<Instruction *, NodePtr> RootToNode;

  /// Topologically sorted root instructions
  SmallVector<Instruction *, 1> OrderedRoots;

  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
  BasicBlock *BackEdge = nullptr;
  BasicBlock *Incoming = nullptr;

  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
  /// %OutsideUser as it is shown in the IR:
  ///
  /// vector.body:
  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
  ///                                [ %ReductionOp, %vector.body ]
  ///   ...
  ///   %ReductionOp = fadd i64 ...
  ///   ...
  ///   br i1 %condition, label %vector.body, %middle.block
  ///
  /// middle.block:
  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
  ///
  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
  MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;

  /// In the process of detecting a reduction, we consider a pair of
  /// %ReductionOP, which we refer to as real and imag (or vice versa), and
  /// traverse the use-tree to detect complex operations. As this is a reduction
  /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
  /// to the %ReductionOPs that we suspect to be complex.
  /// RealPHI and ImagPHI are used by the identifyPHINode method.
  PHINode *RealPHI = nullptr;
  PHINode *ImagPHI = nullptr;

  /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
  /// detection.
  bool PHIsFound = false;

  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
  /// This mapping is populated during
  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
  /// replacement process.
  std::map<PHINode *, PHINode *> OldToNewPHI;

  /// Allocates a new node for \p Operation over the pair (\p R, \p I) without
  /// registering it in the graph. Reduction-related operations assert that
  /// both \p R and \p I are non-null.
  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                               Value *R, Value *I) {
    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
            (R && I)) &&
           "Reduction related nodes must have Real and Imaginary parts");
    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
                                                                I);
  }

  /// Registers \p Node in CompositeNodes and, when its Real value is non-null,
  /// caches it under its (Real, Imag) pair. Returns \p Node.
  NodePtr submitCompositeNode(NodePtr Node) {
    CompositeNodes.push_back(Node);
    if (Node->Real)
      CachedResult[{Node->Real, Node->Imag}] = Node;
    return Node;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  NodePtr
  identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
                              std::pair<Value *, Value *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
  NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
  NodePtr identifyPartialReduction(Value *R, Value *I);
  NodePtr identifyDotProduct(Value *Inst);

  NodePtr identifyNode(Value *R, Value *I);

  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
  NodePtr identifyAdditions(std::list<Addend> &RealAddends,
                            std::list<Addend> &ImagAddends,
                            std::optional<FastMathFlags> Flags,
                            NodePtr Accumulator);

  /// Extract one addend that has both real and imaginary parts positive.
  NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
                                std::list<Addend> &ImagAddends);

  /// Determine if sum of multiplications of complex numbers can be formed from
  /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
  /// to it. Return nullptr if it is not possible to construct a complex number.
  NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
                                  std::vector<Product> &ImagMuls,
                                  NodePtr Accumulator);

  /// Go through pairs of multiplication (one Real and one Imag) and find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a pair with a common
  /// operand.
  bool collectPartialMuls(const std::vector<Product> &RealMuls,
                          const std::vector<Product> &ImagMuls,
                          std::vector<PartialMulCandidate> &Candidates);

  /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
  /// the order of complex computation operations may be significantly altered,
  /// and the real and imaginary parts may not be executed in parallel. This
  /// function takes this into consideration and employs a more general approach
  /// to identify complex computations. Initially, it gathers all the addends
  /// and multiplicands and then constructs a complex expression from them.
  NodePtr identifyReassocNodes(Instruction *I, Instruction *J);

  NodePtr identifyRoot(Instruction *I);

  /// Identifies the Deinterleave operation applied to a vector containing
  /// complex numbers. There are two ways to represent the Deinterleave
  /// operation:
  /// * Using two shufflevectors with even indices for the \p Real instruction
  /// and odd indices for the \p Imag instruction (only for fixed-width vectors)
  /// * Using two extractvalue instructions applied to `vector.deinterleave2`
  /// intrinsic (for both fixed and scalable vectors)
  NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);

  /// Identifies the operation that represents a complex number repeated in a
  /// Splat vector. There are two possible types of splats: ConstantExpr with
  /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
  /// initialization mask with all values set to zero.
  NodePtr identifySplat(Value *Real, Value *Imag);

  NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);

  /// Identifies SelectInsts in a loop that has reduction with predication masks
  /// and/or predicated tail folding
  NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);

  Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);

  /// Complete IR modifications after producing new reduction operation:
  /// * Populate the PHINode generated for
  /// ComplexDeinterleavingOperation::ReductionPHI
  /// * Deinterleave the final value outside of the loop and repurpose original
  /// reduction users
  void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
  void processReductionSingle(Value *OperationReplacement, RawNodePtr Node);

public:
  void dump() { dump(dbgs()); }
  /// Dumps every submitted composite node to \p OS, in submission order.
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
  bool identifyNodes(Instruction *RootI);

  /// In case \p B is a one-block loop, this function seeks potential
  /// reductions and populates ReductionInfo. Returns true if any reductions
  /// were identified.
  bool collectPotentialReductions(BasicBlock *B);

  void identifyReductionNodes();

  /// Check that every instruction, from the roots to the leaves, has internal
  /// uses.
  bool checkNodes();

  /// Perform the actual replacement of the underlying instruction graph.
  void replaceNodes();
};
442 
/// Function-level driver used by both pass entry points in this file
/// (ComplexDeinterleavingPass::run and
/// ComplexDeinterleavingLegacyPass::runOnFunction).
class ComplexDeinterleaving {
public:
  /// Stores the given pointers; they are not owned by this object.
  ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
      : TL(tl), TLI(tli) {}
  /// Evaluates every basic block of \p F and returns true if any of them was
  /// changed.
  bool runOnFunction(Function &F);

private:
  /// Builds a graph for \p B and performs the replacement when the graph
  /// passes checkNodes(); returns true in that case.
  bool evaluateBasicBlock(BasicBlock *B);

  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;
};
455 
456 } // namespace
457 
// Out-of-class definition of the static ID member declared in
// ComplexDeinterleavingLegacyPass.
char ComplexDeinterleavingLegacyPass::ID = 0;

// Legacy pass registration under DEBUG_TYPE with the display name
// "Complex Deinterleaving". No INITIALIZE_PASS_DEPENDENCY entries are listed
// between BEGIN and END.
// NOTE(review): getAnalysisUsage requires TargetLibraryInfoWrapperPass;
// confirm whether a dependency entry is needed here.
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                      "Complex Deinterleaving", false, false)
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                    "Complex Deinterleaving", false, false)
464 
465 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
466                                                  FunctionAnalysisManager &AM) {
467   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
468   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
469   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
470     return PreservedAnalyses::all();
471 
472   PreservedAnalyses PA;
473   PA.preserve<FunctionAnalysisManagerModuleProxy>();
474   return PA;
475 }
476 
477 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
478   return new ComplexDeinterleavingLegacyPass(TM);
479 }
480 
481 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
482   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
483   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
484   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
485 }
486 
487 bool ComplexDeinterleaving::runOnFunction(Function &F) {
488   if (!ComplexDeinterleavingEnabled) {
489     LLVM_DEBUG(
490         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
491     return false;
492   }
493 
494   if (!TL->isComplexDeinterleavingSupported()) {
495     LLVM_DEBUG(
496         dbgs() << "Complex deinterleaving has been disabled, target does "
497                   "not support lowering of complex number operations.\n");
498     return false;
499   }
500 
501   bool Changed = false;
502   for (auto &B : F)
503     Changed |= evaluateBasicBlock(&B);
504 
505   return Changed;
506 }
507 
508 static bool isInterleavingMask(ArrayRef<int> Mask) {
509   // If the size is not even, it's not an interleaving mask
510   if ((Mask.size() & 1))
511     return false;
512 
513   int HalfNumElements = Mask.size() / 2;
514   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
515     int MaskIdx = Idx * 2;
516     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
517       return false;
518   }
519 
520   return true;
521 }
522 
523 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
524   int Offset = Mask[0];
525   int HalfNumElements = Mask.size() / 2;
526 
527   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
528     if (Mask[Idx] != (Idx * 2) + Offset)
529       return false;
530   }
531 
532   return true;
533 }
534 
535 bool isNeg(Value *V) {
536   return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
537 }
538 
539 Value *getNegOperand(Value *V) {
540   assert(isNeg(V));
541   auto *I = cast<Instruction>(V);
542   if (I->getOpcode() == Instruction::FNeg)
543     return I->getOperand(0);
544 
545   return I->getOperand(1);
546 }
547 
548 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
549   ComplexDeinterleavingGraph Graph(TL, TLI);
550   if (Graph.collectPotentialReductions(B))
551     Graph.identifyReductionNodes();
552 
553   for (auto &I : *B)
554     Graph.identifyNodes(&I);
555 
556   if (Graph.checkNodes()) {
557     Graph.replaceNodes();
558     return true;
559   }
560 
561   return false;
562 }
563 
564 ComplexDeinterleavingGraph::NodePtr
565 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
566     Instruction *Real, Instruction *Imag,
567     std::pair<Value *, Value *> &PartialMatch) {
568   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
569                     << "\n");
570 
571   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
572     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
573     return nullptr;
574   }
575 
576   if ((Real->getOpcode() != Instruction::FMul &&
577        Real->getOpcode() != Instruction::Mul) ||
578       (Imag->getOpcode() != Instruction::FMul &&
579        Imag->getOpcode() != Instruction::Mul)) {
580     LLVM_DEBUG(
581         dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
582     return nullptr;
583   }
584 
585   Value *R0 = Real->getOperand(0);
586   Value *R1 = Real->getOperand(1);
587   Value *I0 = Imag->getOperand(0);
588   Value *I1 = Imag->getOperand(1);
589 
590   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
591   // rotations and use the operand.
592   unsigned Negs = 0;
593   Value *Op;
594   if (match(R0, m_Neg(m_Value(Op)))) {
595     Negs |= 1;
596     R0 = Op;
597   } else if (match(R1, m_Neg(m_Value(Op)))) {
598     Negs |= 1;
599     R1 = Op;
600   }
601 
602   if (isNeg(I0)) {
603     Negs |= 2;
604     Negs ^= 1;
605     I0 = Op;
606   } else if (match(I1, m_Neg(m_Value(Op)))) {
607     Negs |= 2;
608     Negs ^= 1;
609     I1 = Op;
610   }
611 
612   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
613 
614   Value *CommonOperand;
615   Value *UncommonRealOp;
616   Value *UncommonImagOp;
617 
618   if (R0 == I0 || R0 == I1) {
619     CommonOperand = R0;
620     UncommonRealOp = R1;
621   } else if (R1 == I0 || R1 == I1) {
622     CommonOperand = R1;
623     UncommonRealOp = R0;
624   } else {
625     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
626     return nullptr;
627   }
628 
629   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
630   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
631       Rotation == ComplexDeinterleavingRotation::Rotation_270)
632     std::swap(UncommonRealOp, UncommonImagOp);
633 
634   // Between identifyPartialMul and here we need to have found a complete valid
635   // pair from the CommonOperand of each part.
636   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
637       Rotation == ComplexDeinterleavingRotation::Rotation_180)
638     PartialMatch.first = CommonOperand;
639   else
640     PartialMatch.second = CommonOperand;
641 
642   if (!PartialMatch.first || !PartialMatch.second) {
643     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
644     return nullptr;
645   }
646 
647   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
648   if (!CommonNode) {
649     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
650     return nullptr;
651   }
652 
653   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
654   if (!UncommonNode) {
655     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
656     return nullptr;
657   }
658 
659   NodePtr Node = prepareCompositeNode(
660       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
661   Node->Rotation = Rotation;
662   Node->addOperand(CommonNode);
663   Node->addOperand(UncommonNode);
664   return submitCompositeNode(Node);
665 }
666 
667 ComplexDeinterleavingGraph::NodePtr
668 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
669                                                Instruction *Imag) {
670   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
671                     << "\n");
672   // Determine rotation
673   auto IsAdd = [](unsigned Op) {
674     return Op == Instruction::FAdd || Op == Instruction::Add;
675   };
676   auto IsSub = [](unsigned Op) {
677     return Op == Instruction::FSub || Op == Instruction::Sub;
678   };
679   ComplexDeinterleavingRotation Rotation;
680   if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
681     Rotation = ComplexDeinterleavingRotation::Rotation_0;
682   else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
683     Rotation = ComplexDeinterleavingRotation::Rotation_90;
684   else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
685     Rotation = ComplexDeinterleavingRotation::Rotation_180;
686   else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
687     Rotation = ComplexDeinterleavingRotation::Rotation_270;
688   else {
689     LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
690     return nullptr;
691   }
692 
693   if (isa<FPMathOperator>(Real) &&
694       (!Real->getFastMathFlags().allowContract() ||
695        !Imag->getFastMathFlags().allowContract())) {
696     LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
697     return nullptr;
698   }
699 
700   Value *CR = Real->getOperand(0);
701   Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
702   if (!RealMulI)
703     return nullptr;
704   Value *CI = Imag->getOperand(0);
705   Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
706   if (!ImagMulI)
707     return nullptr;
708 
709   if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
710     LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
711     return nullptr;
712   }
713 
714   Value *R0 = RealMulI->getOperand(0);
715   Value *R1 = RealMulI->getOperand(1);
716   Value *I0 = ImagMulI->getOperand(0);
717   Value *I1 = ImagMulI->getOperand(1);
718 
719   Value *CommonOperand;
720   Value *UncommonRealOp;
721   Value *UncommonImagOp;
722 
723   if (R0 == I0 || R0 == I1) {
724     CommonOperand = R0;
725     UncommonRealOp = R1;
726   } else if (R1 == I0 || R1 == I1) {
727     CommonOperand = R1;
728     UncommonRealOp = R0;
729   } else {
730     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
731     return nullptr;
732   }
733 
734   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
735   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
736       Rotation == ComplexDeinterleavingRotation::Rotation_270)
737     std::swap(UncommonRealOp, UncommonImagOp);
738 
739   std::pair<Value *, Value *> PartialMatch(
740       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
741        Rotation == ComplexDeinterleavingRotation::Rotation_180)
742           ? CommonOperand
743           : nullptr,
744       (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
745        Rotation == ComplexDeinterleavingRotation::Rotation_270)
746           ? CommonOperand
747           : nullptr);
748 
749   auto *CRInst = dyn_cast<Instruction>(CR);
750   auto *CIInst = dyn_cast<Instruction>(CI);
751 
752   if (!CRInst || !CIInst) {
753     LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
754     return nullptr;
755   }
756 
757   NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
758   if (!CNode) {
759     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
760     return nullptr;
761   }
762 
763   NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
764   if (!UncommonRes) {
765     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
766     return nullptr;
767   }
768 
769   assert(PartialMatch.first && PartialMatch.second);
770   NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
771   if (!CommonRes) {
772     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
773     return nullptr;
774   }
775 
776   NodePtr Node = prepareCompositeNode(
777       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
778   Node->Rotation = Rotation;
779   Node->addOperand(CommonRes);
780   Node->addOperand(UncommonRes);
781   Node->addOperand(CNode);
782   return submitCompositeNode(Node);
783 }
784 
785 ComplexDeinterleavingGraph::NodePtr
786 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
787   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
788 
789   // Determine rotation
790   ComplexDeinterleavingRotation Rotation;
791   if ((Real->getOpcode() == Instruction::FSub &&
792        Imag->getOpcode() == Instruction::FAdd) ||
793       (Real->getOpcode() == Instruction::Sub &&
794        Imag->getOpcode() == Instruction::Add))
795     Rotation = ComplexDeinterleavingRotation::Rotation_90;
796   else if ((Real->getOpcode() == Instruction::FAdd &&
797             Imag->getOpcode() == Instruction::FSub) ||
798            (Real->getOpcode() == Instruction::Add &&
799             Imag->getOpcode() == Instruction::Sub))
800     Rotation = ComplexDeinterleavingRotation::Rotation_270;
801   else {
802     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
803     return nullptr;
804   }
805 
806   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
807   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
808   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
809   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
810 
811   if (!AR || !AI || !BR || !BI) {
812     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
813     return nullptr;
814   }
815 
816   NodePtr ResA = identifyNode(AR, AI);
817   if (!ResA) {
818     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
819     return nullptr;
820   }
821   NodePtr ResB = identifyNode(BR, BI);
822   if (!ResB) {
823     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
824     return nullptr;
825   }
826 
827   NodePtr Node =
828       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
829   Node->Rotation = Rotation;
830   Node->addOperand(ResA);
831   Node->addOperand(ResB);
832   return submitCompositeNode(Node);
833 }
834 
835 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
836   unsigned OpcA = A->getOpcode();
837   unsigned OpcB = B->getOpcode();
838 
839   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
840          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
841          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
842          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
843 }
844 
845 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
846   auto Pattern =
847       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
848 
849   return match(A, Pattern) && match(B, Pattern);
850 }
851 
852 static bool isInstructionPotentiallySymmetric(Instruction *I) {
853   switch (I->getOpcode()) {
854   case Instruction::FAdd:
855   case Instruction::FSub:
856   case Instruction::FMul:
857   case Instruction::FNeg:
858   case Instruction::Add:
859   case Instruction::Sub:
860   case Instruction::Mul:
861     return true;
862   default:
863     return false;
864   }
865 }
866 
867 ComplexDeinterleavingGraph::NodePtr
868 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
869                                                        Instruction *Imag) {
870   if (Real->getOpcode() != Imag->getOpcode())
871     return nullptr;
872 
873   if (!isInstructionPotentiallySymmetric(Real) ||
874       !isInstructionPotentiallySymmetric(Imag))
875     return nullptr;
876 
877   auto *R0 = Real->getOperand(0);
878   auto *I0 = Imag->getOperand(0);
879 
880   NodePtr Op0 = identifyNode(R0, I0);
881   NodePtr Op1 = nullptr;
882   if (Op0 == nullptr)
883     return nullptr;
884 
885   if (Real->isBinaryOp()) {
886     auto *R1 = Real->getOperand(1);
887     auto *I1 = Imag->getOperand(1);
888     Op1 = identifyNode(R1, I1);
889     if (Op1 == nullptr)
890       return nullptr;
891   }
892 
893   if (isa<FPMathOperator>(Real) &&
894       Real->getFastMathFlags() != Imag->getFastMathFlags())
895     return nullptr;
896 
897   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
898                                    Real, Imag);
899   Node->Opcode = Real->getOpcode();
900   if (isa<FPMathOperator>(Real))
901     Node->Flags = Real->getFastMathFlags();
902 
903   Node->addOperand(Op0);
904   if (Real->isBinaryOp())
905     Node->addOperand(Op1);
906 
907   return submitCompositeNode(Node);
908 }
909 
910 ComplexDeinterleavingGraph::NodePtr
911 ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
912 
913   if (!TL->isComplexDeinterleavingOperationSupported(
914           ComplexDeinterleavingOperation::CDot, V->getType())) {
915     LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
916                          "operation CDot with the type "
917                       << *V->getType() << "\n");
918     return nullptr;
919   }
920 
921   auto *Inst = cast<Instruction>(V);
922   auto *RealUser = cast<Instruction>(*Inst->user_begin());
923 
924   NodePtr CN =
925       prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
926 
927   NodePtr ANode;
928 
929   const Intrinsic::ID PartialReduceInt =
930       Intrinsic::experimental_vector_partial_reduce_add;
931 
932   Value *AReal = nullptr;
933   Value *AImag = nullptr;
934   Value *BReal = nullptr;
935   Value *BImag = nullptr;
936   Value *Phi = nullptr;
937 
938   auto UnwrapCast = [](Value *V) -> Value * {
939     if (auto *CI = dyn_cast<CastInst>(V))
940       return CI->getOperand(0);
941     return V;
942   };
943 
944   auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
945       m_Intrinsic<PartialReduceInt>(m_Value(Phi),
946                                     m_Mul(m_Value(BReal), m_Value(AReal))),
947       m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
948 
949   auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
950       m_Intrinsic<PartialReduceInt>(
951           m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
952       m_Mul(m_Value(BImag), m_Value(AReal)));
953 
954   if (match(Inst, PatternRot0)) {
955     CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
956   } else if (match(Inst, PatternRot270)) {
957     CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
958   } else {
959     Value *A0, *A1;
960     // The rotations 90 and 180 share the same operation pattern, so inspect the
961     // order of the operands, identifying where the real and imaginary
962     // components of A go, to discern between the aforementioned rotations.
963     auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
964         m_Intrinsic<PartialReduceInt>(m_Value(Phi),
965                                       m_Mul(m_Value(BReal), m_Value(A0))),
966         m_Mul(m_Value(BImag), m_Value(A1)));
967 
968     if (!match(Inst, PatternRot90Rot180))
969       return nullptr;
970 
971     A0 = UnwrapCast(A0);
972     A1 = UnwrapCast(A1);
973 
974     // Test if A0 is real/A1 is imag
975     ANode = identifyNode(A0, A1);
976     if (!ANode) {
977       // Test if A0 is imag/A1 is real
978       ANode = identifyNode(A1, A0);
979       // Unable to identify operand components, thus unable to identify rotation
980       if (!ANode)
981         return nullptr;
982       CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
983       AReal = A1;
984       AImag = A0;
985     } else {
986       AReal = A0;
987       AImag = A1;
988       CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
989     }
990   }
991 
992   AReal = UnwrapCast(AReal);
993   AImag = UnwrapCast(AImag);
994   BReal = UnwrapCast(BReal);
995   BImag = UnwrapCast(BImag);
996 
997   VectorType *VTy = cast<VectorType>(V->getType());
998   Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
999   if (AReal->getType() != ExpectedOperandTy)
1000     return nullptr;
1001   if (AImag->getType() != ExpectedOperandTy)
1002     return nullptr;
1003   if (BReal->getType() != ExpectedOperandTy)
1004     return nullptr;
1005   if (BImag->getType() != ExpectedOperandTy)
1006     return nullptr;
1007 
1008   if (Phi->getType() != VTy && RealUser->getType() != VTy)
1009     return nullptr;
1010 
1011   NodePtr Node = identifyNode(AReal, AImag);
1012 
1013   // In the case that a node was identified to figure out the rotation, ensure
1014   // that trying to identify a node with AReal and AImag post-unwrap results in
1015   // the same node
1016   if (ANode && Node != ANode) {
1017     LLVM_DEBUG(
1018         dbgs()
1019         << "Identified node is different from previously identified node. "
1020            "Unable to confidently generate a complex operation node\n");
1021     return nullptr;
1022   }
1023 
1024   CN->addOperand(Node);
1025   CN->addOperand(identifyNode(BReal, BImag));
1026   CN->addOperand(identifyNode(Phi, RealUser));
1027 
1028   return submitCompositeNode(CN);
1029 }
1030 
1031 ComplexDeinterleavingGraph::NodePtr
1032 ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
1033   // Partial reductions don't support non-vector types, so check these first
1034   if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
1035     return nullptr;
1036 
1037   auto CommonUser =
1038       findCommonBetweenCollections<Value *>(R->users(), I->users());
1039   if (!CommonUser)
1040     return nullptr;
1041 
1042   auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
1043   if (!IInst || IInst->getIntrinsicID() !=
1044                     Intrinsic::experimental_vector_partial_reduce_add)
1045     return nullptr;
1046 
1047   if (NodePtr CN = identifyDotProduct(IInst))
1048     return CN;
1049 
1050   return nullptr;
1051 }
1052 
1053 ComplexDeinterleavingGraph::NodePtr
1054 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
1055   auto It = CachedResult.find({R, I});
1056   if (It != CachedResult.end()) {
1057     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
1058     return It->second;
1059   }
1060 
1061   if (NodePtr CN = identifyPartialReduction(R, I))
1062     return CN;
1063 
1064   bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
1065   if (!IsReduction && R->getType() != I->getType())
1066     return nullptr;
1067 
1068   if (NodePtr CN = identifySplat(R, I))
1069     return CN;
1070 
1071   auto *Real = dyn_cast<Instruction>(R);
1072   auto *Imag = dyn_cast<Instruction>(I);
1073   if (!Real || !Imag)
1074     return nullptr;
1075 
1076   if (NodePtr CN = identifyDeinterleave(Real, Imag))
1077     return CN;
1078 
1079   if (NodePtr CN = identifyPHINode(Real, Imag))
1080     return CN;
1081 
1082   if (NodePtr CN = identifySelectNode(Real, Imag))
1083     return CN;
1084 
1085   auto *VTy = cast<VectorType>(Real->getType());
1086   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1087 
1088   bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
1089       ComplexDeinterleavingOperation::CMulPartial, NewVTy);
1090   bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
1091       ComplexDeinterleavingOperation::CAdd, NewVTy);
1092 
1093   if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
1094     if (NodePtr CN = identifyPartialMul(Real, Imag))
1095       return CN;
1096   }
1097 
1098   if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
1099     if (NodePtr CN = identifyAdd(Real, Imag))
1100       return CN;
1101   }
1102 
1103   if (HasCMulSupport && HasCAddSupport) {
1104     if (NodePtr CN = identifyReassocNodes(Real, Imag))
1105       return CN;
1106   }
1107 
1108   if (NodePtr CN = identifySymmetricOperation(Real, Imag))
1109     return CN;
1110 
1111   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
1112   CachedResult[{R, I}] = nullptr;
1113   return nullptr;
1114 }
1115 
1116 ComplexDeinterleavingGraph::NodePtr
1117 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
1118                                                  Instruction *Imag) {
1119   auto IsOperationSupported = [](unsigned Opcode) -> bool {
1120     return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
1121            Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
1122            Opcode == Instruction::Sub;
1123   };
1124 
1125   if (!IsOperationSupported(Real->getOpcode()) ||
1126       !IsOperationSupported(Imag->getOpcode()))
1127     return nullptr;
1128 
1129   std::optional<FastMathFlags> Flags;
1130   if (isa<FPMathOperator>(Real)) {
1131     if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
1132       LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
1133                            "not identical\n");
1134       return nullptr;
1135     }
1136 
1137     Flags = Real->getFastMathFlags();
1138     if (!Flags->allowReassoc()) {
1139       LLVM_DEBUG(
1140           dbgs()
1141           << "the 'Reassoc' attribute is missing in the FastMath flags\n");
1142       return nullptr;
1143     }
1144   }
1145 
1146   // Collect multiplications and addend instructions from the given instruction
1147   // while traversing it operands. Additionally, verify that all instructions
1148   // have the same fast math flags.
1149   auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
1150                           std::list<Addend> &Addends) -> bool {
1151     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
1152     SmallPtrSet<Value *, 8> Visited;
1153     while (!Worklist.empty()) {
1154       auto [V, IsPositive] = Worklist.back();
1155       Worklist.pop_back();
1156       if (!Visited.insert(V).second)
1157         continue;
1158 
1159       Instruction *I = dyn_cast<Instruction>(V);
1160       if (!I) {
1161         Addends.emplace_back(V, IsPositive);
1162         continue;
1163       }
1164 
1165       // If an instruction has more than one user, it indicates that it either
1166       // has an external user, which will be later checked by the checkNodes
1167       // function, or it is a subexpression utilized by multiple expressions. In
1168       // the latter case, we will attempt to separately identify the complex
1169       // operation from here in order to create a shared
1170       // ComplexDeinterleavingCompositeNode.
1171       if (I != Insn && I->getNumUses() > 1) {
1172         LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1173         Addends.emplace_back(I, IsPositive);
1174         continue;
1175       }
1176       switch (I->getOpcode()) {
1177       case Instruction::FAdd:
1178       case Instruction::Add:
1179         Worklist.emplace_back(I->getOperand(1), IsPositive);
1180         Worklist.emplace_back(I->getOperand(0), IsPositive);
1181         break;
1182       case Instruction::FSub:
1183         Worklist.emplace_back(I->getOperand(1), !IsPositive);
1184         Worklist.emplace_back(I->getOperand(0), IsPositive);
1185         break;
1186       case Instruction::Sub:
1187         if (isNeg(I)) {
1188           Worklist.emplace_back(getNegOperand(I), !IsPositive);
1189         } else {
1190           Worklist.emplace_back(I->getOperand(1), !IsPositive);
1191           Worklist.emplace_back(I->getOperand(0), IsPositive);
1192         }
1193         break;
1194       case Instruction::FMul:
1195       case Instruction::Mul: {
1196         Value *A, *B;
1197         if (isNeg(I->getOperand(0))) {
1198           A = getNegOperand(I->getOperand(0));
1199           IsPositive = !IsPositive;
1200         } else {
1201           A = I->getOperand(0);
1202         }
1203 
1204         if (isNeg(I->getOperand(1))) {
1205           B = getNegOperand(I->getOperand(1));
1206           IsPositive = !IsPositive;
1207         } else {
1208           B = I->getOperand(1);
1209         }
1210         Muls.push_back(Product{A, B, IsPositive});
1211         break;
1212       }
1213       case Instruction::FNeg:
1214         Worklist.emplace_back(I->getOperand(0), !IsPositive);
1215         break;
1216       default:
1217         Addends.emplace_back(I, IsPositive);
1218         continue;
1219       }
1220 
1221       if (Flags && I->getFastMathFlags() != *Flags) {
1222         LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1223                              "inconsistent with the root instructions' flags: "
1224                           << *I << "\n");
1225         return false;
1226       }
1227     }
1228     return true;
1229   };
1230 
1231   std::vector<Product> RealMuls, ImagMuls;
1232   std::list<Addend> RealAddends, ImagAddends;
1233   if (!Collect(Real, RealMuls, RealAddends) ||
1234       !Collect(Imag, ImagMuls, ImagAddends))
1235     return nullptr;
1236 
1237   if (RealAddends.size() != ImagAddends.size())
1238     return nullptr;
1239 
1240   NodePtr FinalNode;
1241   if (!RealMuls.empty() || !ImagMuls.empty()) {
1242     // If there are multiplicands, extract positive addend and use it as an
1243     // accumulator
1244     FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1245     FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1246     if (!FinalNode)
1247       return nullptr;
1248   }
1249 
1250   // Identify and process remaining additions
1251   if (!RealAddends.empty() || !ImagAddends.empty()) {
1252     FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1253     if (!FinalNode)
1254       return nullptr;
1255   }
1256   assert(FinalNode && "FinalNode can not be nullptr here");
1257   // Set the Real and Imag fields of the final node and submit it
1258   FinalNode->Real = Real;
1259   FinalNode->Imag = Imag;
1260   submitCompositeNode(FinalNode);
1261   return FinalNode;
1262 }
1263 
1264 bool ComplexDeinterleavingGraph::collectPartialMuls(
1265     const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1266     std::vector<PartialMulCandidate> &PartialMulCandidates) {
1267   // Helper function to extract a common operand from two products
1268   auto FindCommonInstruction = [](const Product &Real,
1269                                   const Product &Imag) -> Value * {
1270     if (Real.Multiplicand == Imag.Multiplicand ||
1271         Real.Multiplicand == Imag.Multiplier)
1272       return Real.Multiplicand;
1273 
1274     if (Real.Multiplier == Imag.Multiplicand ||
1275         Real.Multiplier == Imag.Multiplier)
1276       return Real.Multiplier;
1277 
1278     return nullptr;
1279   };
1280 
1281   // Iterating over real and imaginary multiplications to find common operands
1282   // If a common operand is found, a partial multiplication candidate is created
1283   // and added to the candidates vector The function returns false if no common
1284   // operands are found for any product
1285   for (unsigned i = 0; i < RealMuls.size(); ++i) {
1286     bool FoundCommon = false;
1287     for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1288       auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1289       if (!Common)
1290         continue;
1291 
1292       auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1293                                                    : RealMuls[i].Multiplicand;
1294       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1295                                                    : ImagMuls[j].Multiplicand;
1296 
1297       auto Node = identifyNode(A, B);
1298       if (Node) {
1299         FoundCommon = true;
1300         PartialMulCandidates.push_back({Common, Node, i, j, false});
1301       }
1302 
1303       Node = identifyNode(B, A);
1304       if (Node) {
1305         FoundCommon = true;
1306         PartialMulCandidates.push_back({Common, Node, i, j, true});
1307       }
1308     }
1309     if (!FoundCommon)
1310       return false;
1311   }
1312   return true;
1313 }
1314 
1315 ComplexDeinterleavingGraph::NodePtr
1316 ComplexDeinterleavingGraph::identifyMultiplications(
1317     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1318     NodePtr Accumulator = nullptr) {
1319   if (RealMuls.size() != ImagMuls.size())
1320     return nullptr;
1321 
1322   std::vector<PartialMulCandidate> Info;
1323   if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1324     return nullptr;
1325 
1326   // Map to store common instruction to node pointers
1327   std::map<Value *, NodePtr> CommonToNode;
1328   std::vector<bool> Processed(Info.size(), false);
1329   for (unsigned I = 0; I < Info.size(); ++I) {
1330     if (Processed[I])
1331       continue;
1332 
1333     PartialMulCandidate &InfoA = Info[I];
1334     for (unsigned J = I + 1; J < Info.size(); ++J) {
1335       if (Processed[J])
1336         continue;
1337 
1338       PartialMulCandidate &InfoB = Info[J];
1339       auto *InfoReal = &InfoA;
1340       auto *InfoImag = &InfoB;
1341 
1342       auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1343       if (!NodeFromCommon) {
1344         std::swap(InfoReal, InfoImag);
1345         NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1346       }
1347       if (!NodeFromCommon)
1348         continue;
1349 
1350       CommonToNode[InfoReal->Common] = NodeFromCommon;
1351       CommonToNode[InfoImag->Common] = NodeFromCommon;
1352       Processed[I] = true;
1353       Processed[J] = true;
1354     }
1355   }
1356 
1357   std::vector<bool> ProcessedReal(RealMuls.size(), false);
1358   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1359   NodePtr Result = Accumulator;
1360   for (auto &PMI : Info) {
1361     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1362       continue;
1363 
1364     auto It = CommonToNode.find(PMI.Common);
1365     // TODO: Process independent complex multiplications. Cases like this:
1366     //  A.real() * B where both A and B are complex numbers.
1367     if (It == CommonToNode.end()) {
1368       LLVM_DEBUG({
1369         dbgs() << "Unprocessed independent partial multiplication:\n";
1370         for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1371           dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1372                            << " multiplied by " << *Mul->Multiplicand << "\n";
1373       });
1374       return nullptr;
1375     }
1376 
1377     auto &RealMul = RealMuls[PMI.RealIdx];
1378     auto &ImagMul = ImagMuls[PMI.ImagIdx];
1379 
1380     auto NodeA = It->second;
1381     auto NodeB = PMI.Node;
1382     auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1383     // The following table illustrates the relationship between multiplications
1384     // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1385     // can see:
1386     //
1387     // Rotation |   Real |   Imag |
1388     // ---------+--------+--------+
1389     //        0 |  x * u |  x * v |
1390     //       90 | -y * v |  y * u |
1391     //      180 | -x * u | -x * v |
1392     //      270 |  y * v | -y * u |
1393     //
1394     // Check if the candidate can indeed be represented by partial
1395     // multiplication
1396     // TODO: Add support for multiplication by complex one
1397     if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1398         (!IsMultiplicandReal && !PMI.IsNodeInverted))
1399       continue;
1400 
1401     // Determine the rotation based on the multiplications
1402     ComplexDeinterleavingRotation Rotation;
1403     if (IsMultiplicandReal) {
1404       // Detect 0 and 180 degrees rotation
1405       if (RealMul.IsPositive && ImagMul.IsPositive)
1406         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1407       else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1408         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1409       else
1410         continue;
1411 
1412     } else {
1413       // Detect 90 and 270 degrees rotation
1414       if (!RealMul.IsPositive && ImagMul.IsPositive)
1415         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1416       else if (RealMul.IsPositive && !ImagMul.IsPositive)
1417         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1418       else
1419         continue;
1420     }
1421 
1422     LLVM_DEBUG({
1423       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1424       dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1425       dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1426       dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1427       dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1428       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1429     });
1430 
1431     NodePtr NodeMul = prepareCompositeNode(
1432         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1433     NodeMul->Rotation = Rotation;
1434     NodeMul->addOperand(NodeA);
1435     NodeMul->addOperand(NodeB);
1436     if (Result)
1437       NodeMul->addOperand(Result);
1438     submitCompositeNode(NodeMul);
1439     Result = NodeMul;
1440     ProcessedReal[PMI.RealIdx] = true;
1441     ProcessedImag[PMI.ImagIdx] = true;
1442   }
1443 
1444   // Ensure all products have been processed, if not return nullptr.
1445   if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1446       !all_of(ProcessedImag, [](bool V) { return V; })) {
1447 
1448     // Dump debug information about which partial multiplications are not
1449     // processed.
1450     LLVM_DEBUG({
1451       dbgs() << "Unprocessed products (Real):\n";
1452       for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1453         if (!ProcessedReal[i])
1454           dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1455                            << *RealMuls[i].Multiplier << " multiplied by "
1456                            << *RealMuls[i].Multiplicand << "\n";
1457       }
1458       dbgs() << "Unprocessed products (Imag):\n";
1459       for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1460         if (!ProcessedImag[i])
1461           dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1462                            << *ImagMuls[i].Multiplier << " multiplied by "
1463                            << *ImagMuls[i].Multiplicand << "\n";
1464       }
1465     });
1466     return nullptr;
1467   }
1468 
1469   return Result;
1470 }
1471 
1472 ComplexDeinterleavingGraph::NodePtr
1473 ComplexDeinterleavingGraph::identifyAdditions(
1474     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1475     std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1476   if (RealAddends.size() != ImagAddends.size())
1477     return nullptr;
1478 
1479   NodePtr Result;
1480   // If we have accumulator use it as first addend
1481   if (Accumulator)
1482     Result = Accumulator;
1483   // Otherwise find an element with both positive real and imaginary parts.
1484   else
1485     Result = extractPositiveAddend(RealAddends, ImagAddends);
1486 
1487   if (!Result)
1488     return nullptr;
1489 
1490   while (!RealAddends.empty()) {
1491     auto ItR = RealAddends.begin();
1492     auto [R, IsPositiveR] = *ItR;
1493 
1494     bool FoundImag = false;
1495     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1496       auto [I, IsPositiveI] = *ItI;
1497       ComplexDeinterleavingRotation Rotation;
1498       if (IsPositiveR && IsPositiveI)
1499         Rotation = ComplexDeinterleavingRotation::Rotation_0;
1500       else if (!IsPositiveR && IsPositiveI)
1501         Rotation = ComplexDeinterleavingRotation::Rotation_90;
1502       else if (!IsPositiveR && !IsPositiveI)
1503         Rotation = ComplexDeinterleavingRotation::Rotation_180;
1504       else
1505         Rotation = ComplexDeinterleavingRotation::Rotation_270;
1506 
1507       NodePtr AddNode;
1508       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1509           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1510         AddNode = identifyNode(R, I);
1511       } else {
1512         AddNode = identifyNode(I, R);
1513       }
1514       if (AddNode) {
1515         LLVM_DEBUG({
1516           dbgs() << "Identified addition:\n";
1517           dbgs().indent(4) << "X: " << *R << "\n";
1518           dbgs().indent(4) << "Y: " << *I << "\n";
1519           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1520         });
1521 
1522         NodePtr TmpNode;
1523         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1524           TmpNode = prepareCompositeNode(
1525               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1526           if (Flags) {
1527             TmpNode->Opcode = Instruction::FAdd;
1528             TmpNode->Flags = *Flags;
1529           } else {
1530             TmpNode->Opcode = Instruction::Add;
1531           }
1532         } else if (Rotation ==
1533                    llvm::ComplexDeinterleavingRotation::Rotation_180) {
1534           TmpNode = prepareCompositeNode(
1535               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1536           if (Flags) {
1537             TmpNode->Opcode = Instruction::FSub;
1538             TmpNode->Flags = *Flags;
1539           } else {
1540             TmpNode->Opcode = Instruction::Sub;
1541           }
1542         } else {
1543           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1544                                          nullptr, nullptr);
1545           TmpNode->Rotation = Rotation;
1546         }
1547 
1548         TmpNode->addOperand(Result);
1549         TmpNode->addOperand(AddNode);
1550         submitCompositeNode(TmpNode);
1551         Result = TmpNode;
1552         RealAddends.erase(ItR);
1553         ImagAddends.erase(ItI);
1554         FoundImag = true;
1555         break;
1556       }
1557     }
1558     if (!FoundImag)
1559       return nullptr;
1560   }
1561   return Result;
1562 }
1563 
1564 ComplexDeinterleavingGraph::NodePtr
1565 ComplexDeinterleavingGraph::extractPositiveAddend(
1566     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1567   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1568     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1569       auto [R, IsPositiveR] = *ItR;
1570       auto [I, IsPositiveI] = *ItI;
1571       if (IsPositiveR && IsPositiveI) {
1572         auto Result = identifyNode(R, I);
1573         if (Result) {
1574           RealAddends.erase(ItR);
1575           ImagAddends.erase(ItI);
1576           return Result;
1577         }
1578       }
1579     }
1580   }
1581   return nullptr;
1582 }
1583 
1584 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1585   // This potential root instruction might already have been recognized as
1586   // reduction. Because RootToNode maps both Real and Imaginary parts to
1587   // CompositeNode we should choose only one either Real or Imag instruction to
1588   // use as an anchor for generating complex instruction.
1589   auto It = RootToNode.find(RootI);
1590   if (It != RootToNode.end()) {
1591     auto RootNode = It->second;
1592     assert(RootNode->Operation ==
1593                ComplexDeinterleavingOperation::ReductionOperation ||
1594            RootNode->Operation ==
1595                ComplexDeinterleavingOperation::ReductionSingle);
1596     // Find out which part, Real or Imag, comes later, and only if we come to
1597     // the latest part, add it to OrderedRoots.
1598     auto *R = cast<Instruction>(RootNode->Real);
1599     auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr;
1600 
1601     Instruction *ReplacementAnchor;
1602     if (I)
1603       ReplacementAnchor = R->comesBefore(I) ? I : R;
1604     else
1605       ReplacementAnchor = R;
1606 
1607     if (ReplacementAnchor != RootI)
1608       return false;
1609     OrderedRoots.push_back(RootI);
1610     return true;
1611   }
1612 
1613   auto RootNode = identifyRoot(RootI);
1614   if (!RootNode)
1615     return false;
1616 
1617   LLVM_DEBUG({
1618     Function *F = RootI->getFunction();
1619     BasicBlock *B = RootI->getParent();
1620     dbgs() << "Complex deinterleaving graph for " << F->getName()
1621            << "::" << B->getName() << ".\n";
1622     dump(dbgs());
1623     dbgs() << "\n";
1624   });
1625   RootToNode[RootI] = RootNode;
1626   OrderedRoots.push_back(RootI);
1627   return true;
1628 }
1629 
1630 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1631   bool FoundPotentialReduction = false;
1632 
1633   auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1634   if (!Br || Br->getNumSuccessors() != 2)
1635     return false;
1636 
1637   // Identify simple one-block loop
1638   if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1639     return false;
1640 
1641   SmallVector<PHINode *> PHIs;
1642   for (auto &PHI : B->phis()) {
1643     if (PHI.getNumIncomingValues() != 2)
1644       continue;
1645 
1646     if (!PHI.getType()->isVectorTy())
1647       continue;
1648 
1649     auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1650     if (!ReductionOp)
1651       continue;
1652 
1653     // Check if final instruction is reduced outside of current block
1654     Instruction *FinalReduction = nullptr;
1655     auto NumUsers = 0u;
1656     for (auto *U : ReductionOp->users()) {
1657       ++NumUsers;
1658       if (U == &PHI)
1659         continue;
1660       FinalReduction = dyn_cast<Instruction>(U);
1661     }
1662 
1663     if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1664         isa<PHINode>(FinalReduction))
1665       continue;
1666 
1667     ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1668     BackEdge = B;
1669     auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1670     auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1671     Incoming = PHI.getIncomingBlock(IncomingIdx);
1672     FoundPotentialReduction = true;
1673 
1674     // If the initial value of PHINode is an Instruction, consider it a leaf
1675     // value of a complex deinterleaving graph.
1676     if (auto *InitPHI =
1677             dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1678       FinalInstructions.insert(InitPHI);
1679   }
1680   return FoundPotentialReduction;
1681 }
1682 
1683 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1684   SmallVector<bool> Processed(ReductionInfo.size(), false);
1685   SmallVector<Instruction *> OperationInstruction;
1686   for (auto &P : ReductionInfo)
1687     OperationInstruction.push_back(P.first);
1688 
1689   // Identify a complex computation by evaluating two reduction operations that
1690   // potentially could be involved
1691   for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1692     if (Processed[i])
1693       continue;
1694     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1695       if (Processed[j])
1696         continue;
1697       auto *Real = OperationInstruction[i];
1698       auto *Imag = OperationInstruction[j];
1699       if (Real->getType() != Imag->getType())
1700         continue;
1701 
1702       RealPHI = ReductionInfo[Real].first;
1703       ImagPHI = ReductionInfo[Imag].first;
1704       PHIsFound = false;
1705       auto Node = identifyNode(Real, Imag);
1706       if (!Node) {
1707         std::swap(Real, Imag);
1708         std::swap(RealPHI, ImagPHI);
1709         Node = identifyNode(Real, Imag);
1710       }
1711 
1712       // If a node is identified and reduction PHINode is used in the chain of
1713       // operations, mark its operation instructions as used to prevent
1714       // re-identification and attach the node to the real part
1715       if (Node && PHIsFound) {
1716         LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1717                           << *Real << " / " << *Imag << "\n");
1718         Processed[i] = true;
1719         Processed[j] = true;
1720         auto RootNode = prepareCompositeNode(
1721             ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1722         RootNode->addOperand(Node);
1723         RootToNode[Real] = RootNode;
1724         RootToNode[Imag] = RootNode;
1725         submitCompositeNode(RootNode);
1726         break;
1727       }
1728     }
1729 
1730     auto *Real = OperationInstruction[i];
1731     // We want to check that we have 2 operands, but the function attributes
1732     // being counted as operands bloats this value.
1733     if (Processed[i] || Real->getNumOperands() < 2)
1734       continue;
1735 
1736     RealPHI = ReductionInfo[Real].first;
1737     ImagPHI = nullptr;
1738     PHIsFound = false;
1739     auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
1740     if (Node && PHIsFound) {
1741       LLVM_DEBUG(
1742           dbgs() << "Identified single reduction starting from instruction: "
1743                  << *Real << "/" << *ReductionInfo[Real].second << "\n");
1744       Processed[i] = true;
1745       auto RootNode = prepareCompositeNode(
1746           ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
1747       RootNode->addOperand(Node);
1748       RootToNode[Real] = RootNode;
1749       submitCompositeNode(RootNode);
1750     }
1751   }
1752 
1753   RealPHI = nullptr;
1754   ImagPHI = nullptr;
1755 }
1756 
1757 bool ComplexDeinterleavingGraph::checkNodes() {
1758 
1759   bool FoundDeinterleaveNode = false;
1760   for (NodePtr N : CompositeNodes) {
1761     if (!N->areOperandsValid())
1762       return false;
1763     if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
1764       FoundDeinterleaveNode = true;
1765   }
1766 
1767   // We need a deinterleave node in order to guarantee that we're working with
1768   // complex numbers.
1769   if (!FoundDeinterleaveNode) {
1770     LLVM_DEBUG(
1771         dbgs() << "Couldn't find a deinterleave node within the graph, cannot "
1772                   "guarantee safety during graph transformation.\n");
1773     return false;
1774   }
1775 
1776   // Collect all instructions from roots to leaves
1777   SmallPtrSet<Instruction *, 16> AllInstructions;
1778   SmallVector<Instruction *, 8> Worklist;
1779   for (auto &Pair : RootToNode)
1780     Worklist.push_back(Pair.first);
1781 
1782   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1783   // chains
1784   while (!Worklist.empty()) {
1785     auto *I = Worklist.back();
1786     Worklist.pop_back();
1787 
1788     if (!AllInstructions.insert(I).second)
1789       continue;
1790 
1791     for (Value *Op : I->operands()) {
1792       if (auto *OpI = dyn_cast<Instruction>(Op)) {
1793         if (!FinalInstructions.count(I))
1794           Worklist.emplace_back(OpI);
1795       }
1796     }
1797   }
1798 
1799   // Find instructions that have users outside of chain
1800   SmallVector<Instruction *, 2> OuterInstructions;
1801   for (auto *I : AllInstructions) {
1802     // Skip root nodes
1803     if (RootToNode.count(I))
1804       continue;
1805 
1806     for (User *U : I->users()) {
1807       if (AllInstructions.count(cast<Instruction>(U)))
1808         continue;
1809 
1810       // Found an instruction that is not used by XCMLA/XCADD chain
1811       Worklist.emplace_back(I);
1812       break;
1813     }
1814   }
1815 
1816   // If any instructions are found to be used outside, find and remove roots
1817   // that somehow connect to those instructions.
1818   SmallPtrSet<Instruction *, 16> Visited;
1819   while (!Worklist.empty()) {
1820     auto *I = Worklist.back();
1821     Worklist.pop_back();
1822     if (!Visited.insert(I).second)
1823       continue;
1824 
1825     // Found an impacted root node. Removing it from the nodes to be
1826     // deinterleaved
1827     if (RootToNode.count(I)) {
1828       LLVM_DEBUG(dbgs() << "Instruction " << *I
1829                         << " could be deinterleaved but its chain of complex "
1830                            "operations have an outside user\n");
1831       RootToNode.erase(I);
1832     }
1833 
1834     if (!AllInstructions.count(I) || FinalInstructions.count(I))
1835       continue;
1836 
1837     for (User *U : I->users())
1838       Worklist.emplace_back(cast<Instruction>(U));
1839 
1840     for (Value *Op : I->operands()) {
1841       if (auto *OpI = dyn_cast<Instruction>(Op))
1842         Worklist.emplace_back(OpI);
1843     }
1844   }
1845   return !RootToNode.empty();
1846 }
1847 
1848 ComplexDeinterleavingGraph::NodePtr
1849 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1850   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1851     if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)
1852       return nullptr;
1853 
1854     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1855     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1856     if (!Real || !Imag)
1857       return nullptr;
1858 
1859     return identifyNode(Real, Imag);
1860   }
1861 
1862   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1863   if (!SVI)
1864     return nullptr;
1865 
1866   // Look for a shufflevector that takes separate vectors of the real and
1867   // imaginary components and recombines them into a single vector.
1868   if (!isInterleavingMask(SVI->getShuffleMask()))
1869     return nullptr;
1870 
1871   Instruction *Real;
1872   Instruction *Imag;
1873   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1874     return nullptr;
1875 
1876   return identifyNode(Real, Imag);
1877 }
1878 
1879 ComplexDeinterleavingGraph::NodePtr
1880 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1881                                                  Instruction *Imag) {
1882   Instruction *I = nullptr;
1883   Value *FinalValue = nullptr;
1884   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1885       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1886       match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>(
1887                    m_Value(FinalValue)))) {
1888     NodePtr PlaceholderNode = prepareCompositeNode(
1889         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1890     PlaceholderNode->ReplacementNode = FinalValue;
1891     FinalInstructions.insert(Real);
1892     FinalInstructions.insert(Imag);
1893     return submitCompositeNode(PlaceholderNode);
1894   }
1895 
1896   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1897   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1898   if (!RealShuffle || !ImagShuffle) {
1899     if (RealShuffle || ImagShuffle)
1900       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1901     return nullptr;
1902   }
1903 
1904   Value *RealOp1 = RealShuffle->getOperand(1);
1905   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1906     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1907     return nullptr;
1908   }
1909   Value *ImagOp1 = ImagShuffle->getOperand(1);
1910   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1911     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1912     return nullptr;
1913   }
1914 
1915   Value *RealOp0 = RealShuffle->getOperand(0);
1916   Value *ImagOp0 = ImagShuffle->getOperand(0);
1917 
1918   if (RealOp0 != ImagOp0) {
1919     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1920     return nullptr;
1921   }
1922 
1923   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1924   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1925   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1926     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1927     return nullptr;
1928   }
1929 
1930   if (RealMask[0] != 0 || ImagMask[0] != 1) {
1931     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1932     return nullptr;
1933   }
1934 
1935   // Type checking, the shuffle type should be a vector type of the same
1936   // scalar type, but half the size
1937   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1938     Value *Op = Shuffle->getOperand(0);
1939     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1940     auto *OpTy = cast<FixedVectorType>(Op->getType());
1941 
1942     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1943       return false;
1944     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1945       return false;
1946 
1947     return true;
1948   };
1949 
1950   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1951     if (!CheckType(Shuffle))
1952       return false;
1953 
1954     ArrayRef<int> Mask = Shuffle->getShuffleMask();
1955     int Last = *Mask.rbegin();
1956 
1957     Value *Op = Shuffle->getOperand(0);
1958     auto *OpTy = cast<FixedVectorType>(Op->getType());
1959     int NumElements = OpTy->getNumElements();
1960 
1961     // Ensure that the deinterleaving shuffle only pulls from the first
1962     // shuffle operand.
1963     return Last < NumElements;
1964   };
1965 
1966   if (RealShuffle->getType() != ImagShuffle->getType()) {
1967     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1968     return nullptr;
1969   }
1970   if (!CheckDeinterleavingShuffle(RealShuffle)) {
1971     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1972     return nullptr;
1973   }
1974   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1975     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1976     return nullptr;
1977   }
1978 
1979   NodePtr PlaceholderNode =
1980       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1981                            RealShuffle, ImagShuffle);
1982   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1983   FinalInstructions.insert(RealShuffle);
1984   FinalInstructions.insert(ImagShuffle);
1985   return submitCompositeNode(PlaceholderNode);
1986 }
1987 
1988 ComplexDeinterleavingGraph::NodePtr
1989 ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1990   auto IsSplat = [](Value *V) -> bool {
1991     // Fixed-width vector with constants
1992     if (isa<ConstantDataVector>(V))
1993       return true;
1994 
1995     VectorType *VTy;
1996     ArrayRef<int> Mask;
1997     // Splats are represented differently depending on whether the repeated
1998     // value is a constant or an Instruction
1999     if (auto *Const = dyn_cast<ConstantExpr>(V)) {
2000       if (Const->getOpcode() != Instruction::ShuffleVector)
2001         return false;
2002       VTy = cast<VectorType>(Const->getType());
2003       Mask = Const->getShuffleMask();
2004     } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
2005       VTy = Shuf->getType();
2006       Mask = Shuf->getShuffleMask();
2007     } else {
2008       return false;
2009     }
2010 
2011     // When the data type is <1 x Type>, it's not possible to differentiate
2012     // between the ComplexDeinterleaving::Deinterleave and
2013     // ComplexDeinterleaving::Splat operations.
2014     if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
2015       return false;
2016 
2017     return all_equal(Mask) && Mask[0] == 0;
2018   };
2019 
2020   if (!IsSplat(R) || !IsSplat(I))
2021     return nullptr;
2022 
2023   auto *Real = dyn_cast<Instruction>(R);
2024   auto *Imag = dyn_cast<Instruction>(I);
2025   if ((!Real && Imag) || (Real && !Imag))
2026     return nullptr;
2027 
2028   if (Real && Imag) {
2029     // Non-constant splats should be in the same basic block
2030     if (Real->getParent() != Imag->getParent())
2031       return nullptr;
2032 
2033     FinalInstructions.insert(Real);
2034     FinalInstructions.insert(Imag);
2035   }
2036   NodePtr PlaceholderNode =
2037       prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
2038   return submitCompositeNode(PlaceholderNode);
2039 }
2040 
2041 ComplexDeinterleavingGraph::NodePtr
2042 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
2043                                             Instruction *Imag) {
2044   if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
2045     return nullptr;
2046 
2047   PHIsFound = true;
2048   NodePtr PlaceholderNode = prepareCompositeNode(
2049       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
2050   return submitCompositeNode(PlaceholderNode);
2051 }
2052 
2053 ComplexDeinterleavingGraph::NodePtr
2054 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
2055                                                Instruction *Imag) {
2056   auto *SelectReal = dyn_cast<SelectInst>(Real);
2057   auto *SelectImag = dyn_cast<SelectInst>(Imag);
2058   if (!SelectReal || !SelectImag)
2059     return nullptr;
2060 
2061   Instruction *MaskA, *MaskB;
2062   Instruction *AR, *AI, *RA, *BI;
2063   if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
2064                             m_Instruction(RA))) ||
2065       !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
2066                             m_Instruction(BI))))
2067     return nullptr;
2068 
2069   if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
2070     return nullptr;
2071 
2072   if (!MaskA->getType()->isVectorTy())
2073     return nullptr;
2074 
2075   auto NodeA = identifyNode(AR, AI);
2076   if (!NodeA)
2077     return nullptr;
2078 
2079   auto NodeB = identifyNode(RA, BI);
2080   if (!NodeB)
2081     return nullptr;
2082 
2083   NodePtr PlaceholderNode = prepareCompositeNode(
2084       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
2085   PlaceholderNode->addOperand(NodeA);
2086   PlaceholderNode->addOperand(NodeB);
2087   FinalInstructions.insert(MaskA);
2088   FinalInstructions.insert(MaskB);
2089   return submitCompositeNode(PlaceholderNode);
2090 }
2091 
2092 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
2093                                    std::optional<FastMathFlags> Flags,
2094                                    Value *InputA, Value *InputB) {
2095   Value *I;
2096   switch (Opcode) {
2097   case Instruction::FNeg:
2098     I = B.CreateFNeg(InputA);
2099     break;
2100   case Instruction::FAdd:
2101     I = B.CreateFAdd(InputA, InputB);
2102     break;
2103   case Instruction::Add:
2104     I = B.CreateAdd(InputA, InputB);
2105     break;
2106   case Instruction::FSub:
2107     I = B.CreateFSub(InputA, InputB);
2108     break;
2109   case Instruction::Sub:
2110     I = B.CreateSub(InputA, InputB);
2111     break;
2112   case Instruction::FMul:
2113     I = B.CreateFMul(InputA, InputB);
2114     break;
2115   case Instruction::Mul:
2116     I = B.CreateMul(InputA, InputB);
2117     break;
2118   default:
2119     llvm_unreachable("Incorrect symmetric opcode");
2120   }
2121   if (Flags)
2122     cast<Instruction>(I)->setFastMathFlags(*Flags);
2123   return I;
2124 }
2125 
2126 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
2127                                                RawNodePtr Node) {
2128   if (Node->ReplacementNode)
2129     return Node->ReplacementNode;
2130 
2131   auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
2132     return Node->Operands.size() > Idx
2133                ? replaceNode(Builder, Node->Operands[Idx])
2134                : nullptr;
2135   };
2136 
2137   Value *ReplacementNode;
2138   switch (Node->Operation) {
2139   case ComplexDeinterleavingOperation::CDot: {
2140     Value *Input0 = ReplaceOperandIfExist(Node, 0);
2141     Value *Input1 = ReplaceOperandIfExist(Node, 1);
2142     Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2143     assert(!Input1 || (Input0->getType() == Input1->getType() &&
2144                        "Node inputs need to be of the same type"));
2145     ReplacementNode = TL->createComplexDeinterleavingIR(
2146         Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
2147     break;
2148   }
2149   case ComplexDeinterleavingOperation::CAdd:
2150   case ComplexDeinterleavingOperation::CMulPartial:
2151   case ComplexDeinterleavingOperation::Symmetric: {
2152     Value *Input0 = ReplaceOperandIfExist(Node, 0);
2153     Value *Input1 = ReplaceOperandIfExist(Node, 1);
2154     Value *Accumulator = ReplaceOperandIfExist(Node, 2);
2155     assert(!Input1 || (Input0->getType() == Input1->getType() &&
2156                        "Node inputs need to be of the same type"));
2157     assert(!Accumulator ||
2158            (Input0->getType() == Accumulator->getType() &&
2159             "Accumulator and input need to be of the same type"));
2160     if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
2161       ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
2162                                              Input0, Input1);
2163     else
2164       ReplacementNode = TL->createComplexDeinterleavingIR(
2165           Builder, Node->Operation, Node->Rotation, Input0, Input1,
2166           Accumulator);
2167     break;
2168   }
2169   case ComplexDeinterleavingOperation::Deinterleave:
2170     llvm_unreachable("Deinterleave node should already have ReplacementNode");
2171     break;
2172   case ComplexDeinterleavingOperation::Splat: {
2173     auto *NewTy = VectorType::getDoubleElementsVectorType(
2174         cast<VectorType>(Node->Real->getType()));
2175     auto *R = dyn_cast<Instruction>(Node->Real);
2176     auto *I = dyn_cast<Instruction>(Node->Imag);
2177     if (R && I) {
2178       // Splats that are not constant are interleaved where they are located
2179       Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
2180       IRBuilder<> IRB(InsertPoint);
2181       ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,
2182                                             NewTy, {Node->Real, Node->Imag});
2183     } else {
2184       ReplacementNode = Builder.CreateIntrinsic(
2185           Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag});
2186     }
2187     break;
2188   }
2189   case ComplexDeinterleavingOperation::ReductionPHI: {
2190     // If Operation is ReductionPHI, a new empty PHINode is created.
2191     // It is filled later when the ReductionOperation is processed.
2192     auto *OldPHI = cast<PHINode>(Node->Real);
2193     auto *VTy = cast<VectorType>(Node->Real->getType());
2194     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2195     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
2196     OldToNewPHI[OldPHI] = NewPHI;
2197     ReplacementNode = NewPHI;
2198     break;
2199   }
2200   case ComplexDeinterleavingOperation::ReductionSingle:
2201     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2202     processReductionSingle(ReplacementNode, Node);
2203     break;
2204   case ComplexDeinterleavingOperation::ReductionOperation:
2205     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
2206     processReductionOperation(ReplacementNode, Node);
2207     break;
2208   case ComplexDeinterleavingOperation::ReductionSelect: {
2209     auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
2210     auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
2211     auto *A = replaceNode(Builder, Node->Operands[0]);
2212     auto *B = replaceNode(Builder, Node->Operands[1]);
2213     auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
2214         cast<VectorType>(MaskReal->getType()));
2215     auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2,
2216                                             NewMaskTy, {MaskReal, MaskImag});
2217     ReplacementNode = Builder.CreateSelect(NewMask, A, B);
2218     break;
2219   }
2220   }
2221 
2222   assert(ReplacementNode && "Target failed to create Intrinsic call.");
2223   NumComplexTransformations += 1;
2224   Node->ReplacementNode = ReplacementNode;
2225   return ReplacementNode;
2226 }
2227 
2228 void ComplexDeinterleavingGraph::processReductionSingle(
2229     Value *OperationReplacement, RawNodePtr Node) {
2230   auto *Real = cast<Instruction>(Node->Real);
2231   auto *OldPHI = ReductionInfo[Real].first;
2232   auto *NewPHI = OldToNewPHI[OldPHI];
2233   auto *VTy = cast<VectorType>(Real->getType());
2234   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2235 
2236   Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
2237 
2238   IRBuilder<> Builder(Incoming->getTerminator());
2239 
2240   Value *NewInit = nullptr;
2241   if (auto *C = dyn_cast<Constant>(Init)) {
2242     if (C->isZeroValue())
2243       NewInit = Constant::getNullValue(NewVTy);
2244   }
2245 
2246   if (!NewInit)
2247     NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2248                                       {Init, Constant::getNullValue(VTy)});
2249 
2250   NewPHI->addIncoming(NewInit, Incoming);
2251   NewPHI->addIncoming(OperationReplacement, BackEdge);
2252 
2253   auto *FinalReduction = ReductionInfo[Real].second;
2254   Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
2255 
2256   auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2257   FinalReduction->replaceAllUsesWith(AddReduce);
2258 }
2259 
2260 void ComplexDeinterleavingGraph::processReductionOperation(
2261     Value *OperationReplacement, RawNodePtr Node) {
2262   auto *Real = cast<Instruction>(Node->Real);
2263   auto *Imag = cast<Instruction>(Node->Imag);
2264   auto *OldPHIReal = ReductionInfo[Real].first;
2265   auto *OldPHIImag = ReductionInfo[Imag].first;
2266   auto *NewPHI = OldToNewPHI[OldPHIReal];
2267 
2268   auto *VTy = cast<VectorType>(Real->getType());
2269   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2270 
2271   // We have to interleave initial origin values coming from IncomingBlock
2272   Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2273   Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2274 
2275   IRBuilder<> Builder(Incoming->getTerminator());
2276   auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2277                                           {InitReal, InitImag});
2278 
2279   NewPHI->addIncoming(NewInit, Incoming);
2280   NewPHI->addIncoming(OperationReplacement, BackEdge);
2281 
2282   // Deinterleave complex vector outside of loop so that it can be finally
2283   // reduced
2284   auto *FinalReductionReal = ReductionInfo[Real].second;
2285   auto *FinalReductionImag = ReductionInfo[Imag].second;
2286 
2287   Builder.SetInsertPoint(
2288       &*FinalReductionReal->getParent()->getFirstInsertionPt());
2289   auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2290                                                OperationReplacement->getType(),
2291                                                OperationReplacement);
2292 
2293   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2294   FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2295 
2296   Builder.SetInsertPoint(FinalReductionImag);
2297   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2298   FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2299 }
2300 
2301 void ComplexDeinterleavingGraph::replaceNodes() {
2302   SmallVector<Instruction *, 16> DeadInstrRoots;
2303   for (auto *RootInstruction : OrderedRoots) {
2304     // Check if this potential root went through check process and we can
2305     // deinterleave it
2306     if (!RootToNode.count(RootInstruction))
2307       continue;
2308 
2309     IRBuilder<> Builder(RootInstruction);
2310     auto RootNode = RootToNode[RootInstruction];
2311     Value *R = replaceNode(Builder, RootNode.get());
2312 
2313     if (RootNode->Operation ==
2314         ComplexDeinterleavingOperation::ReductionOperation) {
2315       auto *RootReal = cast<Instruction>(RootNode->Real);
2316       auto *RootImag = cast<Instruction>(RootNode->Imag);
2317       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2318       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2319       DeadInstrRoots.push_back(RootReal);
2320       DeadInstrRoots.push_back(RootImag);
2321     } else if (RootNode->Operation ==
2322                ComplexDeinterleavingOperation::ReductionSingle) {
2323       auto *RootInst = cast<Instruction>(RootNode->Real);
2324       ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
2325       DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
2326     } else {
2327       assert(R && "Unable to find replacement for RootInstruction");
2328       DeadInstrRoots.push_back(RootInstruction);
2329       RootInstruction->replaceAllUsesWith(R);
2330     }
2331   }
2332 
2333   for (auto *I : DeadInstrRoots)
2334     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2335 }
2336