xref: /llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (revision 40a81d3100b416393557f015efc971497c0bea46)
1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
18 // pairs. Validity of each node is expected to be done upon creation, and any
19 // validation errors should halt traversal and prevent further graph
20 // construction.
21 // Instead of relying on Shuffle operations, vector interleaving and
22 // deinterleaving can be represented by vector.interleave2 and
23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24 // these intrinsics, whereas, fixed-width vectors are recognized for both
25 // shufflevector instruction and intrinsics.
26 //
27 // Replacement:
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
33 //
34 //
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
40 // replace it.
41 //
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
48 // Each node maintains a reference to its Real and Imaginary instructions,
49 // as well as any additional instructions that make up the identified operation
50 // (Internal instructions should only have uses within their containing node).
51 // A Node also contains the rotation and operation type that it represents.
52 // Operands contains pointers to other CompositeNodes, acting as the edges in
53 // the graph. ReplacementNode is the transformed Value* that has been emitted
54 // to the IR.
55 //
56 // Note: If the operation of a Node is Shuffle, only the Real, Imag, and
57 // ReplacementNode fields of that Node are relevant, where the ReplacementNode
58 // should be pre-populated.
59 //
60 //===----------------------------------------------------------------------===//
61 
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/Statistic.h"
64 #include "llvm/Analysis/TargetLibraryInfo.h"
65 #include "llvm/Analysis/TargetTransformInfo.h"
66 #include "llvm/CodeGen/TargetLowering.h"
67 #include "llvm/CodeGen/TargetPassConfig.h"
68 #include "llvm/CodeGen/TargetSubtargetInfo.h"
69 #include "llvm/IR/IRBuilder.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/InitializePasses.h"
72 #include "llvm/Target/TargetMachine.h"
73 #include "llvm/Transforms/Utils/Local.h"
74 #include <algorithm>
75 
76 using namespace llvm;
77 using namespace PatternMatch;
78 
79 #define DEBUG_TYPE "complex-deinterleaving"
80 
81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
82 
83 static cl::opt<bool> ComplexDeinterleavingEnabled(
84     "enable-complex-deinterleaving",
85     cl::desc("Enable generation of complex instructions"), cl::init(true),
86     cl::Hidden);
87 
88 /// Checks the given mask, and determines whether said mask is interleaving.
89 ///
90 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
93 static bool isInterleavingMask(ArrayRef<int> Mask);
94 
95 /// Checks the given mask, and determines whether said mask is deinterleaving.
96 ///
97 /// To be deinterleaving, a mask must increment in steps of 2, and either start
98 /// with 0 or 1.
99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
100 /// <1, 3, 5, 7>).
101 static bool isDeinterleavingMask(ArrayRef<int> Mask);
102 
103 namespace {
104 
105 class ComplexDeinterleavingLegacyPass : public FunctionPass {
106 public:
107   static char ID;
108 
109   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
110       : FunctionPass(ID), TM(TM) {
111     initializeComplexDeinterleavingLegacyPassPass(
112         *PassRegistry::getPassRegistry());
113   }
114 
115   StringRef getPassName() const override {
116     return "Complex Deinterleaving Pass";
117   }
118 
119   bool runOnFunction(Function &F) override;
120   void getAnalysisUsage(AnalysisUsage &AU) const override {
121     AU.addRequired<TargetLibraryInfoWrapperPass>();
122     AU.setPreservesCFG();
123   }
124 
125 private:
126   const TargetMachine *TM;
127 };
128 
129 class ComplexDeinterleavingGraph;
130 struct ComplexDeinterleavingCompositeNode {
131 
132   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
133                                      Instruction *R, Instruction *I)
134       : Operation(Op), Real(R), Imag(I) {}
135 
136 private:
137   friend class ComplexDeinterleavingGraph;
138   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
139   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
140 
141 public:
142   ComplexDeinterleavingOperation Operation;
143   Instruction *Real;
144   Instruction *Imag;
145 
146   ComplexDeinterleavingRotation Rotation =
147       ComplexDeinterleavingRotation::Rotation_0;
148   SmallVector<RawNodePtr> Operands;
149   Value *ReplacementNode = nullptr;
150 
151   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
152 
153   void dump() { dump(dbgs()); }
154   void dump(raw_ostream &OS) {
155     auto PrintValue = [&](Value *V) {
156       if (V) {
157         OS << "\"";
158         V->print(OS, true);
159         OS << "\"\n";
160       } else
161         OS << "nullptr\n";
162     };
163     auto PrintNodeRef = [&](RawNodePtr Ptr) {
164       if (Ptr)
165         OS << Ptr << "\n";
166       else
167         OS << "nullptr\n";
168     };
169 
170     OS << "- CompositeNode: " << this << "\n";
171     OS << "  Real: ";
172     PrintValue(Real);
173     OS << "  Imag: ";
174     PrintValue(Imag);
175     OS << "  ReplacementNode: ";
176     PrintValue(ReplacementNode);
177     OS << "  Operation: " << (int)Operation << "\n";
178     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
179     OS << "  Operands: \n";
180     for (const auto &Op : Operands) {
181       OS << "    - ";
182       PrintNodeRef(Op);
183     }
184   }
185 };
186 
187 class ComplexDeinterleavingGraph {
188 public:
189   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
190   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
191   explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
192                                       const TargetLibraryInfo *TLI)
193       : TL(TL), TLI(TLI) {}
194 
195 private:
196   const TargetLowering *TL = nullptr;
197   const TargetLibraryInfo *TLI = nullptr;
198   SmallVector<NodePtr> CompositeNodes;
199 
200   SmallPtrSet<Instruction *, 16> FinalInstructions;
201 
202   /// Root instructions are instructions from which complex computation starts
203   std::map<Instruction *, NodePtr> RootToNode;
204 
205   /// Topologically sorted root instructions
206   SmallVector<Instruction *, 1> OrderedRoots;
207 
208   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
209                                Instruction *R, Instruction *I) {
210     return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
211                                                                 I);
212   }
213 
214   NodePtr submitCompositeNode(NodePtr Node) {
215     CompositeNodes.push_back(Node);
216     return Node;
217   }
218 
219   NodePtr getContainingComposite(Value *R, Value *I) {
220     for (const auto &CN : CompositeNodes) {
221       if (CN->Real == R && CN->Imag == I)
222         return CN;
223     }
224     return nullptr;
225   }
226 
227   /// Identifies a complex partial multiply pattern and its rotation, based on
228   /// the following patterns
229   ///
230   ///  0:  r: cr + ar * br
231   ///      i: ci + ar * bi
232   /// 90:  r: cr - ai * bi
233   ///      i: ci + ai * br
234   /// 180: r: cr - ar * br
235   ///      i: ci - ar * bi
236   /// 270: r: cr + ai * bi
237   ///      i: ci - ai * br
238   NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
239 
240   /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
241   /// is partially known from identifyPartialMul, filling in the other half of
242   /// the complex pair.
243   NodePtr identifyNodeWithImplicitAdd(
244       Instruction *I, Instruction *J,
245       std::pair<Instruction *, Instruction *> &CommonOperandI);
246 
247   /// Identifies a complex add pattern and its rotation, based on the following
248   /// patterns.
249   ///
250   /// 90:  r: ar - bi
251   ///      i: ai + br
252   /// 270: r: ar + bi
253   ///      i: ai - br
254   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
255   NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
256 
257   NodePtr identifyNode(Instruction *I, Instruction *J);
258 
259   NodePtr identifyRoot(Instruction *I);
260 
261   /// Identifies the Deinterleave operation applied to a vector containing
262   /// complex numbers. There are two ways to represent the Deinterleave
263   /// operation:
264   /// * Using two shufflevectors with even indices for /pReal instruction and
265   /// odd indices for /pImag instructions (only for fixed-width vectors)
266   /// * Using two extractvalue instructions applied to `vector.deinterleave2`
267   /// intrinsic (for both fixed and scalable vectors)
268   NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
269 
270   Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
271 
272 public:
273   void dump() { dump(dbgs()); }
274   void dump(raw_ostream &OS) {
275     for (const auto &Node : CompositeNodes)
276       Node->dump(OS);
277   }
278 
279   /// Returns false if the deinterleaving operation should be cancelled for the
280   /// current graph.
281   bool identifyNodes(Instruction *RootI);
282 
283   /// Check that every instruction, from the roots to the leaves, has internal
284   /// uses.
285   bool checkNodes();
286 
287   /// Perform the actual replacement of the underlying instruction graph.
288   void replaceNodes();
289 };
290 
291 class ComplexDeinterleaving {
292 public:
293   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
294       : TL(tl), TLI(tli) {}
295   bool runOnFunction(Function &F);
296 
297 private:
298   bool evaluateBasicBlock(BasicBlock *B);
299 
300   const TargetLowering *TL = nullptr;
301   const TargetLibraryInfo *TLI = nullptr;
302 };
303 
304 } // namespace
305 
306 char ComplexDeinterleavingLegacyPass::ID = 0;
307 
308 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
309                       "Complex Deinterleaving", false, false)
310 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
311                     "Complex Deinterleaving", false, false)
312 
313 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
314                                                  FunctionAnalysisManager &AM) {
315   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
316   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
317   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
318     return PreservedAnalyses::all();
319 
320   PreservedAnalyses PA;
321   PA.preserve<FunctionAnalysisManagerModuleProxy>();
322   return PA;
323 }
324 
325 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
326   return new ComplexDeinterleavingLegacyPass(TM);
327 }
328 
329 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
330   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
331   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
332   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
333 }
334 
335 bool ComplexDeinterleaving::runOnFunction(Function &F) {
336   if (!ComplexDeinterleavingEnabled) {
337     LLVM_DEBUG(
338         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
339     return false;
340   }
341 
342   if (!TL->isComplexDeinterleavingSupported()) {
343     LLVM_DEBUG(
344         dbgs() << "Complex deinterleaving has been disabled, target does "
345                   "not support lowering of complex number operations.\n");
346     return false;
347   }
348 
349   bool Changed = false;
350   for (auto &B : F)
351     Changed |= evaluateBasicBlock(&B);
352 
353   return Changed;
354 }
355 
356 static bool isInterleavingMask(ArrayRef<int> Mask) {
357   // If the size is not even, it's not an interleaving mask
358   if ((Mask.size() & 1))
359     return false;
360 
361   int HalfNumElements = Mask.size() / 2;
362   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
363     int MaskIdx = Idx * 2;
364     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
365       return false;
366   }
367 
368   return true;
369 }
370 
371 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
372   int Offset = Mask[0];
373   int HalfNumElements = Mask.size() / 2;
374 
375   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
376     if (Mask[Idx] != (Idx * 2) + Offset)
377       return false;
378   }
379 
380   return true;
381 }
382 
383 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
384   ComplexDeinterleavingGraph Graph(TL, TLI);
385   for (auto &I : *B)
386     Graph.identifyNodes(&I);
387 
388   if (Graph.checkNodes()) {
389     Graph.replaceNodes();
390     return true;
391   }
392 
393   return false;
394 }
395 
396 ComplexDeinterleavingGraph::NodePtr
397 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
398     Instruction *Real, Instruction *Imag,
399     std::pair<Instruction *, Instruction *> &PartialMatch) {
400   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
401                     << "\n");
402 
403   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
404     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
405     return nullptr;
406   }
407 
408   if (Real->getOpcode() != Instruction::FMul ||
409       Imag->getOpcode() != Instruction::FMul) {
410     LLVM_DEBUG(dbgs() << "  - Real or imaginary instruction is not fmul\n");
411     return nullptr;
412   }
413 
414   Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
415   Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
416   Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
417   Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
418   if (!R0 || !R1 || !I0 || !I1) {
419     LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
420     return nullptr;
421   }
422 
423   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
424   // rotations and use the operand.
425   unsigned Negs = 0;
426   SmallVector<Instruction *> FNegs;
427   if (R0->getOpcode() == Instruction::FNeg ||
428       R1->getOpcode() == Instruction::FNeg) {
429     Negs |= 1;
430     if (R0->getOpcode() == Instruction::FNeg) {
431       FNegs.push_back(R0);
432       R0 = dyn_cast<Instruction>(R0->getOperand(0));
433     } else {
434       FNegs.push_back(R1);
435       R1 = dyn_cast<Instruction>(R1->getOperand(0));
436     }
437     if (!R0 || !R1)
438       return nullptr;
439   }
440   if (I0->getOpcode() == Instruction::FNeg ||
441       I1->getOpcode() == Instruction::FNeg) {
442     Negs |= 2;
443     Negs ^= 1;
444     if (I0->getOpcode() == Instruction::FNeg) {
445       FNegs.push_back(I0);
446       I0 = dyn_cast<Instruction>(I0->getOperand(0));
447     } else {
448       FNegs.push_back(I1);
449       I1 = dyn_cast<Instruction>(I1->getOperand(0));
450     }
451     if (!I0 || !I1)
452       return nullptr;
453   }
454 
455   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
456 
457   Instruction *CommonOperand;
458   Instruction *UncommonRealOp;
459   Instruction *UncommonImagOp;
460 
461   if (R0 == I0 || R0 == I1) {
462     CommonOperand = R0;
463     UncommonRealOp = R1;
464   } else if (R1 == I0 || R1 == I1) {
465     CommonOperand = R1;
466     UncommonRealOp = R0;
467   } else {
468     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
469     return nullptr;
470   }
471 
472   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
473   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
474       Rotation == ComplexDeinterleavingRotation::Rotation_270)
475     std::swap(UncommonRealOp, UncommonImagOp);
476 
477   // Between identifyPartialMul and here we need to have found a complete valid
478   // pair from the CommonOperand of each part.
479   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
480       Rotation == ComplexDeinterleavingRotation::Rotation_180)
481     PartialMatch.first = CommonOperand;
482   else
483     PartialMatch.second = CommonOperand;
484 
485   if (!PartialMatch.first || !PartialMatch.second) {
486     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
487     return nullptr;
488   }
489 
490   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
491   if (!CommonNode) {
492     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
493     return nullptr;
494   }
495 
496   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
497   if (!UncommonNode) {
498     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
499     return nullptr;
500   }
501 
502   NodePtr Node = prepareCompositeNode(
503       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
504   Node->Rotation = Rotation;
505   Node->addOperand(CommonNode);
506   Node->addOperand(UncommonNode);
507   return submitCompositeNode(Node);
508 }
509 
510 ComplexDeinterleavingGraph::NodePtr
511 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
512                                                Instruction *Imag) {
513   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
514                     << "\n");
515   // Determine rotation
516   ComplexDeinterleavingRotation Rotation;
517   if (Real->getOpcode() == Instruction::FAdd &&
518       Imag->getOpcode() == Instruction::FAdd)
519     Rotation = ComplexDeinterleavingRotation::Rotation_0;
520   else if (Real->getOpcode() == Instruction::FSub &&
521            Imag->getOpcode() == Instruction::FAdd)
522     Rotation = ComplexDeinterleavingRotation::Rotation_90;
523   else if (Real->getOpcode() == Instruction::FSub &&
524            Imag->getOpcode() == Instruction::FSub)
525     Rotation = ComplexDeinterleavingRotation::Rotation_180;
526   else if (Real->getOpcode() == Instruction::FAdd &&
527            Imag->getOpcode() == Instruction::FSub)
528     Rotation = ComplexDeinterleavingRotation::Rotation_270;
529   else {
530     LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
531     return nullptr;
532   }
533 
534   if (!Real->getFastMathFlags().allowContract() ||
535       !Imag->getFastMathFlags().allowContract()) {
536     LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
537     return nullptr;
538   }
539 
540   Value *CR = Real->getOperand(0);
541   Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
542   if (!RealMulI)
543     return nullptr;
544   Value *CI = Imag->getOperand(0);
545   Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
546   if (!ImagMulI)
547     return nullptr;
548 
549   if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
550     LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
551     return nullptr;
552   }
553 
554   Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
555   Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
556   Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
557   Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
558   if (!R0 || !R1 || !I0 || !I1) {
559     LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
560     return nullptr;
561   }
562 
563   Instruction *CommonOperand;
564   Instruction *UncommonRealOp;
565   Instruction *UncommonImagOp;
566 
567   if (R0 == I0 || R0 == I1) {
568     CommonOperand = R0;
569     UncommonRealOp = R1;
570   } else if (R1 == I0 || R1 == I1) {
571     CommonOperand = R1;
572     UncommonRealOp = R0;
573   } else {
574     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
575     return nullptr;
576   }
577 
578   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
579   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
580       Rotation == ComplexDeinterleavingRotation::Rotation_270)
581     std::swap(UncommonRealOp, UncommonImagOp);
582 
583   std::pair<Instruction *, Instruction *> PartialMatch(
584       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
585        Rotation == ComplexDeinterleavingRotation::Rotation_180)
586           ? CommonOperand
587           : nullptr,
588       (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
589        Rotation == ComplexDeinterleavingRotation::Rotation_270)
590           ? CommonOperand
591           : nullptr);
592 
593   auto *CRInst = dyn_cast<Instruction>(CR);
594   auto *CIInst = dyn_cast<Instruction>(CI);
595 
596   if (!CRInst || !CIInst) {
597     LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
598     return nullptr;
599   }
600 
601   NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
602   if (!CNode) {
603     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
604     return nullptr;
605   }
606 
607   NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
608   if (!UncommonRes) {
609     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
610     return nullptr;
611   }
612 
613   assert(PartialMatch.first && PartialMatch.second);
614   NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
615   if (!CommonRes) {
616     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
617     return nullptr;
618   }
619 
620   NodePtr Node = prepareCompositeNode(
621       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
622   Node->Rotation = Rotation;
623   Node->addOperand(CommonRes);
624   Node->addOperand(UncommonRes);
625   Node->addOperand(CNode);
626   return submitCompositeNode(Node);
627 }
628 
629 ComplexDeinterleavingGraph::NodePtr
630 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
631   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
632 
633   // Determine rotation
634   ComplexDeinterleavingRotation Rotation;
635   if ((Real->getOpcode() == Instruction::FSub &&
636        Imag->getOpcode() == Instruction::FAdd) ||
637       (Real->getOpcode() == Instruction::Sub &&
638        Imag->getOpcode() == Instruction::Add))
639     Rotation = ComplexDeinterleavingRotation::Rotation_90;
640   else if ((Real->getOpcode() == Instruction::FAdd &&
641             Imag->getOpcode() == Instruction::FSub) ||
642            (Real->getOpcode() == Instruction::Add &&
643             Imag->getOpcode() == Instruction::Sub))
644     Rotation = ComplexDeinterleavingRotation::Rotation_270;
645   else {
646     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
647     return nullptr;
648   }
649 
650   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
651   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
652   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
653   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
654 
655   if (!AR || !AI || !BR || !BI) {
656     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
657     return nullptr;
658   }
659 
660   NodePtr ResA = identifyNode(AR, AI);
661   if (!ResA) {
662     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
663     return nullptr;
664   }
665   NodePtr ResB = identifyNode(BR, BI);
666   if (!ResB) {
667     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
668     return nullptr;
669   }
670 
671   NodePtr Node =
672       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
673   Node->Rotation = Rotation;
674   Node->addOperand(ResA);
675   Node->addOperand(ResB);
676   return submitCompositeNode(Node);
677 }
678 
679 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
680   unsigned OpcA = A->getOpcode();
681   unsigned OpcB = B->getOpcode();
682 
683   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
684          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
685          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
686          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
687 }
688 
689 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
690   auto Pattern =
691       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
692 
693   return match(A, Pattern) && match(B, Pattern);
694 }
695 
696 static bool isInstructionPotentiallySymmetric(Instruction *I) {
697   switch (I->getOpcode()) {
698   case Instruction::FAdd:
699   case Instruction::FSub:
700   case Instruction::FMul:
701   case Instruction::FNeg:
702     return true;
703   default:
704     return false;
705   }
706 }
707 
708 ComplexDeinterleavingGraph::NodePtr
709 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
710                                                        Instruction *Imag) {
711   if (Real->getOpcode() != Imag->getOpcode())
712     return nullptr;
713 
714   if (!isInstructionPotentiallySymmetric(Real) ||
715       !isInstructionPotentiallySymmetric(Imag))
716     return nullptr;
717 
718   auto *R0 = dyn_cast<Instruction>(Real->getOperand(0));
719   auto *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
720 
721   if (!R0 || !I0)
722     return nullptr;
723 
724   NodePtr Op0 = identifyNode(R0, I0);
725   NodePtr Op1 = nullptr;
726   if (Op0 == nullptr)
727     return nullptr;
728 
729   if (Real->isBinaryOp()) {
730     auto *R1 = dyn_cast<Instruction>(Real->getOperand(1));
731     auto *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
732     if (!R1 || !I1)
733       return nullptr;
734 
735     Op1 = identifyNode(R1, I1);
736     if (Op1 == nullptr)
737       return nullptr;
738   }
739 
740   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
741                                    Real, Imag);
742   Node->addOperand(Op0);
743   if (Real->isBinaryOp())
744     Node->addOperand(Op1);
745 
746   return submitCompositeNode(Node);
747 }
748 
749 ComplexDeinterleavingGraph::NodePtr
750 ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
751   LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
752   if (NodePtr CN = getContainingComposite(Real, Imag)) {
753     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
754     return CN;
755   }
756 
757   NodePtr Node = identifyDeinterleave(Real, Imag);
758   if (Node)
759     return Node;
760 
761   auto *VTy = cast<VectorType>(Real->getType());
762   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
763 
764   if (TL->isComplexDeinterleavingOperationSupported(
765           ComplexDeinterleavingOperation::CMulPartial, NewVTy) &&
766       isInstructionPairMul(Real, Imag)) {
767     return identifyPartialMul(Real, Imag);
768   }
769 
770   if (TL->isComplexDeinterleavingOperationSupported(
771           ComplexDeinterleavingOperation::CAdd, NewVTy) &&
772       isInstructionPairAdd(Real, Imag)) {
773     return identifyAdd(Real, Imag);
774   }
775 
776   auto Symmetric = identifySymmetricOperation(Real, Imag);
777   LLVM_DEBUG(if (Symmetric == nullptr) dbgs()
778              << "  - Not recognised as a valid pattern.\n");
779   return Symmetric;
780 }
781 
782 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
783   auto RootNode = identifyRoot(RootI);
784   if (!RootNode)
785     return false;
786 
787   LLVM_DEBUG({
788     Function *F = RootI->getFunction();
789     BasicBlock *B = RootI->getParent();
790     dbgs() << "Complex deinterleaving graph for " << F->getName()
791            << "::" << B->getName() << ".\n";
792     dump(dbgs());
793     dbgs() << "\n";
794   });
795   RootToNode[RootI] = RootNode;
796   OrderedRoots.push_back(RootI);
797   return true;
798 }
799 
800 bool ComplexDeinterleavingGraph::checkNodes() {
801   // Collect all instructions from roots to leaves
802   SmallPtrSet<Instruction *, 16> AllInstructions;
803   SmallVector<Instruction *, 8> Worklist;
804   for (auto *I : OrderedRoots)
805     Worklist.push_back(I);
806 
807   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
808   // chains
809   while (!Worklist.empty()) {
810     auto *I = Worklist.back();
811     Worklist.pop_back();
812 
813     if (!AllInstructions.insert(I).second)
814       continue;
815 
816     for (Value *Op : I->operands()) {
817       if (auto *OpI = dyn_cast<Instruction>(Op)) {
818         if (!FinalInstructions.count(I))
819           Worklist.emplace_back(OpI);
820       }
821     }
822   }
823 
824   // Find instructions that have users outside of chain
825   SmallVector<Instruction *, 2> OuterInstructions;
826   for (auto *I : AllInstructions) {
827     // Skip root nodes
828     if (RootToNode.count(I))
829       continue;
830 
831     for (User *U : I->users()) {
832       if (AllInstructions.count(cast<Instruction>(U)))
833         continue;
834 
835       // Found an instruction that is not used by XCMLA/XCADD chain
836       Worklist.emplace_back(I);
837       break;
838     }
839   }
840 
841   // If any instructions are found to be used outside, find and remove roots
842   // that somehow connect to those instructions.
843   SmallPtrSet<Instruction *, 16> Visited;
844   while (!Worklist.empty()) {
845     auto *I = Worklist.back();
846     Worklist.pop_back();
847     if (!Visited.insert(I).second)
848       continue;
849 
850     // Found an impacted root node. Removing it from the nodes to be
851     // deinterleaved
852     if (RootToNode.count(I)) {
853       LLVM_DEBUG(dbgs() << "Instruction " << *I
854                         << " could be deinterleaved but its chain of complex "
855                            "operations have an outside user\n");
856       RootToNode.erase(I);
857     }
858 
859     if (!AllInstructions.count(I) || FinalInstructions.count(I))
860       continue;
861 
862     for (User *U : I->users())
863       Worklist.emplace_back(cast<Instruction>(U));
864 
865     for (Value *Op : I->operands()) {
866       if (auto *OpI = dyn_cast<Instruction>(Op))
867         Worklist.emplace_back(OpI);
868     }
869   }
870   return !RootToNode.empty();
871 }
872 
873 ComplexDeinterleavingGraph::NodePtr
874 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
875   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
876     if (Intrinsic->getIntrinsicID() !=
877         Intrinsic::experimental_vector_interleave2)
878       return nullptr;
879 
880     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
881     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
882     if (!Real || !Imag)
883       return nullptr;
884 
885     return identifyNode(Real, Imag);
886   }
887 
888   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
889   if (!SVI)
890     return nullptr;
891 
892   // Look for a shufflevector that takes separate vectors of the real and
893   // imaginary components and recombines them into a single vector.
894   if (!isInterleavingMask(SVI->getShuffleMask()))
895     return nullptr;
896 
897   Instruction *Real;
898   Instruction *Imag;
899   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
900     return nullptr;
901 
902   return identifyNode(Real, Imag);
903 }
904 
905 ComplexDeinterleavingGraph::NodePtr
906 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
907                                                  Instruction *Imag) {
908   Instruction *I = nullptr;
909   Value *FinalValue = nullptr;
910   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
911       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
912       match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
913                    m_Value(FinalValue)))) {
914     NodePtr PlaceholderNode = prepareCompositeNode(
915         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
916     PlaceholderNode->ReplacementNode = FinalValue;
917     FinalInstructions.insert(Real);
918     FinalInstructions.insert(Imag);
919     return submitCompositeNode(PlaceholderNode);
920   }
921 
922   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
923   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
924   if (!RealShuffle || !ImagShuffle) {
925     if (RealShuffle || ImagShuffle)
926       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
927     return nullptr;
928   }
929 
930   Value *RealOp1 = RealShuffle->getOperand(1);
931   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
932     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
933     return nullptr;
934   }
935   Value *ImagOp1 = ImagShuffle->getOperand(1);
936   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
937     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
938     return nullptr;
939   }
940 
941   Value *RealOp0 = RealShuffle->getOperand(0);
942   Value *ImagOp0 = ImagShuffle->getOperand(0);
943 
944   if (RealOp0 != ImagOp0) {
945     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
946     return nullptr;
947   }
948 
949   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
950   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
951   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
952     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
953     return nullptr;
954   }
955 
956   if (RealMask[0] != 0 || ImagMask[0] != 1) {
957     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
958     return nullptr;
959   }
960 
961   // Type checking, the shuffle type should be a vector type of the same
962   // scalar type, but half the size
963   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
964     Value *Op = Shuffle->getOperand(0);
965     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
966     auto *OpTy = cast<FixedVectorType>(Op->getType());
967 
968     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
969       return false;
970     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
971       return false;
972 
973     return true;
974   };
975 
976   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
977     if (!CheckType(Shuffle))
978       return false;
979 
980     ArrayRef<int> Mask = Shuffle->getShuffleMask();
981     int Last = *Mask.rbegin();
982 
983     Value *Op = Shuffle->getOperand(0);
984     auto *OpTy = cast<FixedVectorType>(Op->getType());
985     int NumElements = OpTy->getNumElements();
986 
987     // Ensure that the deinterleaving shuffle only pulls from the first
988     // shuffle operand.
989     return Last < NumElements;
990   };
991 
992   if (RealShuffle->getType() != ImagShuffle->getType()) {
993     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
994     return nullptr;
995   }
996   if (!CheckDeinterleavingShuffle(RealShuffle)) {
997     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
998     return nullptr;
999   }
1000   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1001     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1002     return nullptr;
1003   }
1004 
1005   NodePtr PlaceholderNode =
1006       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1007                            RealShuffle, ImagShuffle);
1008   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1009   FinalInstructions.insert(RealShuffle);
1010   FinalInstructions.insert(ImagShuffle);
1011   return submitCompositeNode(PlaceholderNode);
1012 }
1013 
1014 static Value *replaceSymmetricNode(IRBuilderBase &B,
1015                                    ComplexDeinterleavingGraph::RawNodePtr Node,
1016                                    Value *InputA, Value *InputB) {
1017   Instruction *I = Node->Real;
1018   if (I->isUnaryOp())
1019     assert(!InputB &&
1020            "Unary symmetric operations need one input, but two were provided.");
1021   else if (I->isBinaryOp())
1022     assert(InputB && "Binary symmetric operations need two inputs, only one "
1023                      "was provided.");
1024 
1025   switch (I->getOpcode()) {
1026   case Instruction::FNeg:
1027     return B.CreateFNegFMF(InputA, I);
1028   case Instruction::FAdd:
1029     return B.CreateFAddFMF(InputA, InputB, I);
1030   case Instruction::FSub:
1031     return B.CreateFSubFMF(InputA, InputB, I);
1032   case Instruction::FMul:
1033     return B.CreateFMulFMF(InputA, InputB, I);
1034   }
1035 
1036   return nullptr;
1037 }
1038 
1039 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1040                                                RawNodePtr Node) {
1041   if (Node->ReplacementNode)
1042     return Node->ReplacementNode;
1043 
1044   Value *Input0 = replaceNode(Builder, Node->Operands[0]);
1045   Value *Input1 = Node->Operands.size() > 1
1046                       ? replaceNode(Builder, Node->Operands[1])
1047                       : nullptr;
1048   Value *Accumulator = Node->Operands.size() > 2
1049                            ? replaceNode(Builder, Node->Operands[2])
1050                            : nullptr;
1051 
1052   if (Input1)
1053     assert(Input0->getType() == Input1->getType() &&
1054            "Node inputs need to be of the same type");
1055 
1056   if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1057     Node->ReplacementNode = replaceSymmetricNode(Builder, Node, Input0, Input1);
1058   else
1059     Node->ReplacementNode = TL->createComplexDeinterleavingIR(
1060         Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
1061 
1062   assert(Node->ReplacementNode && "Target failed to create Intrinsic call.");
1063   NumComplexTransformations += 1;
1064   return Node->ReplacementNode;
1065 }
1066 
1067 void ComplexDeinterleavingGraph::replaceNodes() {
1068   SmallVector<Instruction *, 16> DeadInstrRoots;
1069   for (auto *RootInstruction : OrderedRoots) {
1070     // Check if this potential root went through check process and we can
1071     // deinterleave it
1072     if (!RootToNode.count(RootInstruction))
1073       continue;
1074 
1075     IRBuilder<> Builder(RootInstruction);
1076     auto RootNode = RootToNode[RootInstruction];
1077     Value *R = replaceNode(Builder, RootNode.get());
1078     assert(R && "Unable to find replacement for RootInstruction");
1079     DeadInstrRoots.push_back(RootInstruction);
1080     RootInstruction->replaceAllUsesWith(R);
1081   }
1082 
1083   for (auto *I : DeadInstrRoots)
1084     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
1085 }
1086