xref: /freebsd-src/contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (revision bdd1243df58e60e85101c09001d9812a789b6bc4)
1*bdd1243dSDimitry Andric //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2*bdd1243dSDimitry Andric //
3*bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*bdd1243dSDimitry Andric //
7*bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
8*bdd1243dSDimitry Andric //
9*bdd1243dSDimitry Andric // Identification:
10*bdd1243dSDimitry Andric // This step is responsible for finding the patterns that can be lowered to
11*bdd1243dSDimitry Andric // complex instructions, and building a graph to represent the complex
12*bdd1243dSDimitry Andric // structures. Starting from the "Converging Shuffle" (a shuffle that
13*bdd1243dSDimitry Andric // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14*bdd1243dSDimitry Andric // operands are evaluated and identified as "Composite Nodes" (collections of
15*bdd1243dSDimitry Andric // instructions that can potentially be lowered to a single complex
16*bdd1243dSDimitry Andric // instruction). This is performed by checking the real and imaginary components
17*bdd1243dSDimitry Andric // and tracking the data flow for each component while following the operand
18*bdd1243dSDimitry Andric // pairs. Validity of each node is expected to be done upon creation, and any
19*bdd1243dSDimitry Andric // validation errors should halt traversal and prevent further graph
20*bdd1243dSDimitry Andric // construction.
21*bdd1243dSDimitry Andric //
22*bdd1243dSDimitry Andric // Replacement:
23*bdd1243dSDimitry Andric // This step traverses the graph built up by identification, delegating to the
24*bdd1243dSDimitry Andric // target to validate and generate the correct intrinsics, and plumbs them
25*bdd1243dSDimitry Andric // together connecting each end of the new intrinsics graph to the existing
26*bdd1243dSDimitry Andric // use-def chain. This step is assumed to finish successfully, as all
27*bdd1243dSDimitry Andric // information is expected to be correct by this point.
28*bdd1243dSDimitry Andric //
29*bdd1243dSDimitry Andric //
30*bdd1243dSDimitry Andric // Internal data structure:
31*bdd1243dSDimitry Andric // ComplexDeinterleavingGraph:
32*bdd1243dSDimitry Andric // Keeps references to all the valid CompositeNodes formed as part of the
33*bdd1243dSDimitry Andric // transformation, and every Instruction contained within said nodes. It also
34*bdd1243dSDimitry Andric // holds onto a reference to the root Instruction, and the root node that should
35*bdd1243dSDimitry Andric // replace it.
36*bdd1243dSDimitry Andric //
37*bdd1243dSDimitry Andric // ComplexDeinterleavingCompositeNode:
38*bdd1243dSDimitry Andric // A CompositeNode represents a single transformation point; each node should
39*bdd1243dSDimitry Andric // transform into a single complex instruction (ignoring vector splitting, which
40*bdd1243dSDimitry Andric // would generate more instructions per node). They are identified in a
41*bdd1243dSDimitry Andric // depth-first manner, traversing and identifying the operands of each
42*bdd1243dSDimitry Andric // instruction in the order they appear in the IR.
43*bdd1243dSDimitry Andric // Each node maintains a reference to its Real and Imaginary instructions,
44*bdd1243dSDimitry Andric // as well as any additional instructions that make up the identified operation
45*bdd1243dSDimitry Andric // (Internal instructions should only have uses within their containing node).
46*bdd1243dSDimitry Andric // A Node also contains the rotation and operation type that it represents.
47*bdd1243dSDimitry Andric // Operands contains pointers to other CompositeNodes, acting as the edges in
48*bdd1243dSDimitry Andric // the graph. ReplacementNode is the transformed Value* that has been emitted
49*bdd1243dSDimitry Andric // to the IR.
50*bdd1243dSDimitry Andric //
51*bdd1243dSDimitry Andric // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
52*bdd1243dSDimitry Andric // ReplacementNode fields of that Node are relevant, where the ReplacementNode
53*bdd1243dSDimitry Andric // should be pre-populated.
54*bdd1243dSDimitry Andric //
55*bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
56*bdd1243dSDimitry Andric 
57*bdd1243dSDimitry Andric #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
58*bdd1243dSDimitry Andric #include "llvm/ADT/Statistic.h"
59*bdd1243dSDimitry Andric #include "llvm/Analysis/TargetLibraryInfo.h"
60*bdd1243dSDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
61*bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetLowering.h"
62*bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
63*bdd1243dSDimitry Andric #include "llvm/CodeGen/TargetSubtargetInfo.h"
64*bdd1243dSDimitry Andric #include "llvm/IR/IRBuilder.h"
65*bdd1243dSDimitry Andric #include "llvm/InitializePasses.h"
66*bdd1243dSDimitry Andric #include "llvm/Target/TargetMachine.h"
67*bdd1243dSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
68*bdd1243dSDimitry Andric #include <algorithm>
69*bdd1243dSDimitry Andric 
70*bdd1243dSDimitry Andric using namespace llvm;
71*bdd1243dSDimitry Andric using namespace PatternMatch;
72*bdd1243dSDimitry Andric 
73*bdd1243dSDimitry Andric #define DEBUG_TYPE "complex-deinterleaving"
74*bdd1243dSDimitry Andric 
75*bdd1243dSDimitry Andric STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
76*bdd1243dSDimitry Andric 
77*bdd1243dSDimitry Andric static cl::opt<bool> ComplexDeinterleavingEnabled(
78*bdd1243dSDimitry Andric     "enable-complex-deinterleaving",
79*bdd1243dSDimitry Andric     cl::desc("Enable generation of complex instructions"), cl::init(true),
80*bdd1243dSDimitry Andric     cl::Hidden);
81*bdd1243dSDimitry Andric 
82*bdd1243dSDimitry Andric /// Checks the given mask, and determines whether said mask is interleaving.
83*bdd1243dSDimitry Andric ///
84*bdd1243dSDimitry Andric /// To be interleaving, a mask must alternate between `i` and `i + (Length /
85*bdd1243dSDimitry Andric /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
86*bdd1243dSDimitry Andric /// 4x vector interleaving mask would be <0, 2, 1, 3>).
87*bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask);
88*bdd1243dSDimitry Andric 
89*bdd1243dSDimitry Andric /// Checks the given mask, and determines whether said mask is deinterleaving.
90*bdd1243dSDimitry Andric ///
91*bdd1243dSDimitry Andric /// To be deinterleaving, a mask must increment in steps of 2, and either start
92*bdd1243dSDimitry Andric /// with 0 or 1.
93*bdd1243dSDimitry Andric /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
94*bdd1243dSDimitry Andric /// <1, 3, 5, 7>).
95*bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask);
96*bdd1243dSDimitry Andric 
97*bdd1243dSDimitry Andric namespace {
98*bdd1243dSDimitry Andric 
99*bdd1243dSDimitry Andric class ComplexDeinterleavingLegacyPass : public FunctionPass {
100*bdd1243dSDimitry Andric public:
101*bdd1243dSDimitry Andric   static char ID;
102*bdd1243dSDimitry Andric 
103*bdd1243dSDimitry Andric   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
104*bdd1243dSDimitry Andric       : FunctionPass(ID), TM(TM) {
105*bdd1243dSDimitry Andric     initializeComplexDeinterleavingLegacyPassPass(
106*bdd1243dSDimitry Andric         *PassRegistry::getPassRegistry());
107*bdd1243dSDimitry Andric   }
108*bdd1243dSDimitry Andric 
109*bdd1243dSDimitry Andric   StringRef getPassName() const override {
110*bdd1243dSDimitry Andric     return "Complex Deinterleaving Pass";
111*bdd1243dSDimitry Andric   }
112*bdd1243dSDimitry Andric 
113*bdd1243dSDimitry Andric   bool runOnFunction(Function &F) override;
114*bdd1243dSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
115*bdd1243dSDimitry Andric     AU.addRequired<TargetLibraryInfoWrapperPass>();
116*bdd1243dSDimitry Andric     AU.setPreservesCFG();
117*bdd1243dSDimitry Andric   }
118*bdd1243dSDimitry Andric 
119*bdd1243dSDimitry Andric private:
120*bdd1243dSDimitry Andric   const TargetMachine *TM;
121*bdd1243dSDimitry Andric };
122*bdd1243dSDimitry Andric 
123*bdd1243dSDimitry Andric class ComplexDeinterleavingGraph;
124*bdd1243dSDimitry Andric struct ComplexDeinterleavingCompositeNode {
125*bdd1243dSDimitry Andric 
126*bdd1243dSDimitry Andric   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
127*bdd1243dSDimitry Andric                                      Instruction *R, Instruction *I)
128*bdd1243dSDimitry Andric       : Operation(Op), Real(R), Imag(I) {}
129*bdd1243dSDimitry Andric 
130*bdd1243dSDimitry Andric private:
131*bdd1243dSDimitry Andric   friend class ComplexDeinterleavingGraph;
132*bdd1243dSDimitry Andric   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
133*bdd1243dSDimitry Andric   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
134*bdd1243dSDimitry Andric 
135*bdd1243dSDimitry Andric public:
136*bdd1243dSDimitry Andric   ComplexDeinterleavingOperation Operation;
137*bdd1243dSDimitry Andric   Instruction *Real;
138*bdd1243dSDimitry Andric   Instruction *Imag;
139*bdd1243dSDimitry Andric 
140*bdd1243dSDimitry Andric   // Instructions that should only exist within this node, there should be no
141*bdd1243dSDimitry Andric   // users of these instructions outside the node. An example of these would be
142*bdd1243dSDimitry Andric   // the multiply instructions of a partial multiply operation.
143*bdd1243dSDimitry Andric   SmallVector<Instruction *> InternalInstructions;
144*bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation;
145*bdd1243dSDimitry Andric   SmallVector<RawNodePtr> Operands;
146*bdd1243dSDimitry Andric   Value *ReplacementNode = nullptr;
147*bdd1243dSDimitry Andric 
148*bdd1243dSDimitry Andric   void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }
149*bdd1243dSDimitry Andric   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
150*bdd1243dSDimitry Andric 
151*bdd1243dSDimitry Andric   bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);
152*bdd1243dSDimitry Andric 
153*bdd1243dSDimitry Andric   void dump() { dump(dbgs()); }
154*bdd1243dSDimitry Andric   void dump(raw_ostream &OS) {
155*bdd1243dSDimitry Andric     auto PrintValue = [&](Value *V) {
156*bdd1243dSDimitry Andric       if (V) {
157*bdd1243dSDimitry Andric         OS << "\"";
158*bdd1243dSDimitry Andric         V->print(OS, true);
159*bdd1243dSDimitry Andric         OS << "\"\n";
160*bdd1243dSDimitry Andric       } else
161*bdd1243dSDimitry Andric         OS << "nullptr\n";
162*bdd1243dSDimitry Andric     };
163*bdd1243dSDimitry Andric     auto PrintNodeRef = [&](RawNodePtr Ptr) {
164*bdd1243dSDimitry Andric       if (Ptr)
165*bdd1243dSDimitry Andric         OS << Ptr << "\n";
166*bdd1243dSDimitry Andric       else
167*bdd1243dSDimitry Andric         OS << "nullptr\n";
168*bdd1243dSDimitry Andric     };
169*bdd1243dSDimitry Andric 
170*bdd1243dSDimitry Andric     OS << "- CompositeNode: " << this << "\n";
171*bdd1243dSDimitry Andric     OS << "  Real: ";
172*bdd1243dSDimitry Andric     PrintValue(Real);
173*bdd1243dSDimitry Andric     OS << "  Imag: ";
174*bdd1243dSDimitry Andric     PrintValue(Imag);
175*bdd1243dSDimitry Andric     OS << "  ReplacementNode: ";
176*bdd1243dSDimitry Andric     PrintValue(ReplacementNode);
177*bdd1243dSDimitry Andric     OS << "  Operation: " << (int)Operation << "\n";
178*bdd1243dSDimitry Andric     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
179*bdd1243dSDimitry Andric     OS << "  Operands: \n";
180*bdd1243dSDimitry Andric     for (const auto &Op : Operands) {
181*bdd1243dSDimitry Andric       OS << "    - ";
182*bdd1243dSDimitry Andric       PrintNodeRef(Op);
183*bdd1243dSDimitry Andric     }
184*bdd1243dSDimitry Andric     OS << "  InternalInstructions:\n";
185*bdd1243dSDimitry Andric     for (const auto &I : InternalInstructions) {
186*bdd1243dSDimitry Andric       OS << "    - \"";
187*bdd1243dSDimitry Andric       I->print(OS, true);
188*bdd1243dSDimitry Andric       OS << "\"\n";
189*bdd1243dSDimitry Andric     }
190*bdd1243dSDimitry Andric   }
191*bdd1243dSDimitry Andric };
192*bdd1243dSDimitry Andric 
193*bdd1243dSDimitry Andric class ComplexDeinterleavingGraph {
194*bdd1243dSDimitry Andric public:
195*bdd1243dSDimitry Andric   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
196*bdd1243dSDimitry Andric   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
197*bdd1243dSDimitry Andric   explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}
198*bdd1243dSDimitry Andric 
199*bdd1243dSDimitry Andric private:
200*bdd1243dSDimitry Andric   const TargetLowering *TL;
201*bdd1243dSDimitry Andric   Instruction *RootValue;
202*bdd1243dSDimitry Andric   NodePtr RootNode;
203*bdd1243dSDimitry Andric   SmallVector<NodePtr> CompositeNodes;
204*bdd1243dSDimitry Andric   SmallPtrSet<Instruction *, 16> AllInstructions;
205*bdd1243dSDimitry Andric 
206*bdd1243dSDimitry Andric   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
207*bdd1243dSDimitry Andric                                Instruction *R, Instruction *I) {
208*bdd1243dSDimitry Andric     return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
209*bdd1243dSDimitry Andric                                                                 I);
210*bdd1243dSDimitry Andric   }
211*bdd1243dSDimitry Andric 
212*bdd1243dSDimitry Andric   NodePtr submitCompositeNode(NodePtr Node) {
213*bdd1243dSDimitry Andric     CompositeNodes.push_back(Node);
214*bdd1243dSDimitry Andric     AllInstructions.insert(Node->Real);
215*bdd1243dSDimitry Andric     AllInstructions.insert(Node->Imag);
216*bdd1243dSDimitry Andric     for (auto *I : Node->InternalInstructions)
217*bdd1243dSDimitry Andric       AllInstructions.insert(I);
218*bdd1243dSDimitry Andric     return Node;
219*bdd1243dSDimitry Andric   }
220*bdd1243dSDimitry Andric 
221*bdd1243dSDimitry Andric   NodePtr getContainingComposite(Value *R, Value *I) {
222*bdd1243dSDimitry Andric     for (const auto &CN : CompositeNodes) {
223*bdd1243dSDimitry Andric       if (CN->Real == R && CN->Imag == I)
224*bdd1243dSDimitry Andric         return CN;
225*bdd1243dSDimitry Andric     }
226*bdd1243dSDimitry Andric     return nullptr;
227*bdd1243dSDimitry Andric   }
228*bdd1243dSDimitry Andric 
229*bdd1243dSDimitry Andric   /// Identifies a complex partial multiply pattern and its rotation, based on
230*bdd1243dSDimitry Andric   /// the following patterns
231*bdd1243dSDimitry Andric   ///
232*bdd1243dSDimitry Andric   ///  0:  r: cr + ar * br
233*bdd1243dSDimitry Andric   ///      i: ci + ar * bi
234*bdd1243dSDimitry Andric   /// 90:  r: cr - ai * bi
235*bdd1243dSDimitry Andric   ///      i: ci + ai * br
236*bdd1243dSDimitry Andric   /// 180: r: cr - ar * br
237*bdd1243dSDimitry Andric   ///      i: ci - ar * bi
238*bdd1243dSDimitry Andric   /// 270: r: cr + ai * bi
239*bdd1243dSDimitry Andric   ///      i: ci - ai * br
240*bdd1243dSDimitry Andric   NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
241*bdd1243dSDimitry Andric 
242*bdd1243dSDimitry Andric   /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
243*bdd1243dSDimitry Andric   /// is partially known from identifyPartialMul, filling in the other half of
244*bdd1243dSDimitry Andric   /// the complex pair.
245*bdd1243dSDimitry Andric   NodePtr identifyNodeWithImplicitAdd(
246*bdd1243dSDimitry Andric       Instruction *I, Instruction *J,
247*bdd1243dSDimitry Andric       std::pair<Instruction *, Instruction *> &CommonOperandI);
248*bdd1243dSDimitry Andric 
249*bdd1243dSDimitry Andric   /// Identifies a complex add pattern and its rotation, based on the following
250*bdd1243dSDimitry Andric   /// patterns.
251*bdd1243dSDimitry Andric   ///
252*bdd1243dSDimitry Andric   /// 90:  r: ar - bi
253*bdd1243dSDimitry Andric   ///      i: ai + br
254*bdd1243dSDimitry Andric   /// 270: r: ar + bi
255*bdd1243dSDimitry Andric   ///      i: ai - br
256*bdd1243dSDimitry Andric   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
257*bdd1243dSDimitry Andric 
258*bdd1243dSDimitry Andric   NodePtr identifyNode(Instruction *I, Instruction *J);
259*bdd1243dSDimitry Andric 
260*bdd1243dSDimitry Andric   Value *replaceNode(RawNodePtr Node);
261*bdd1243dSDimitry Andric 
262*bdd1243dSDimitry Andric public:
263*bdd1243dSDimitry Andric   void dump() { dump(dbgs()); }
264*bdd1243dSDimitry Andric   void dump(raw_ostream &OS) {
265*bdd1243dSDimitry Andric     for (const auto &Node : CompositeNodes)
266*bdd1243dSDimitry Andric       Node->dump(OS);
267*bdd1243dSDimitry Andric   }
268*bdd1243dSDimitry Andric 
269*bdd1243dSDimitry Andric   /// Returns false if the deinterleaving operation should be cancelled for the
270*bdd1243dSDimitry Andric   /// current graph.
271*bdd1243dSDimitry Andric   bool identifyNodes(Instruction *RootI);
272*bdd1243dSDimitry Andric 
273*bdd1243dSDimitry Andric   /// Perform the actual replacement of the underlying instruction graph.
274*bdd1243dSDimitry Andric   /// Returns false if the deinterleaving operation should be cancelled for the
275*bdd1243dSDimitry Andric   /// current graph.
276*bdd1243dSDimitry Andric   void replaceNodes();
277*bdd1243dSDimitry Andric };
278*bdd1243dSDimitry Andric 
279*bdd1243dSDimitry Andric class ComplexDeinterleaving {
280*bdd1243dSDimitry Andric public:
281*bdd1243dSDimitry Andric   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
282*bdd1243dSDimitry Andric       : TL(tl), TLI(tli) {}
283*bdd1243dSDimitry Andric   bool runOnFunction(Function &F);
284*bdd1243dSDimitry Andric 
285*bdd1243dSDimitry Andric private:
286*bdd1243dSDimitry Andric   bool evaluateBasicBlock(BasicBlock *B);
287*bdd1243dSDimitry Andric 
288*bdd1243dSDimitry Andric   const TargetLowering *TL = nullptr;
289*bdd1243dSDimitry Andric   const TargetLibraryInfo *TLI = nullptr;
290*bdd1243dSDimitry Andric };
291*bdd1243dSDimitry Andric 
292*bdd1243dSDimitry Andric } // namespace
293*bdd1243dSDimitry Andric 
294*bdd1243dSDimitry Andric char ComplexDeinterleavingLegacyPass::ID = 0;
295*bdd1243dSDimitry Andric 
296*bdd1243dSDimitry Andric INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
297*bdd1243dSDimitry Andric                       "Complex Deinterleaving", false, false)
298*bdd1243dSDimitry Andric INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
299*bdd1243dSDimitry Andric                     "Complex Deinterleaving", false, false)
300*bdd1243dSDimitry Andric 
301*bdd1243dSDimitry Andric PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
302*bdd1243dSDimitry Andric                                                  FunctionAnalysisManager &AM) {
303*bdd1243dSDimitry Andric   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
304*bdd1243dSDimitry Andric   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
305*bdd1243dSDimitry Andric   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
306*bdd1243dSDimitry Andric     return PreservedAnalyses::all();
307*bdd1243dSDimitry Andric 
308*bdd1243dSDimitry Andric   PreservedAnalyses PA;
309*bdd1243dSDimitry Andric   PA.preserve<FunctionAnalysisManagerModuleProxy>();
310*bdd1243dSDimitry Andric   return PA;
311*bdd1243dSDimitry Andric }
312*bdd1243dSDimitry Andric 
313*bdd1243dSDimitry Andric FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
314*bdd1243dSDimitry Andric   return new ComplexDeinterleavingLegacyPass(TM);
315*bdd1243dSDimitry Andric }
316*bdd1243dSDimitry Andric 
317*bdd1243dSDimitry Andric bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
318*bdd1243dSDimitry Andric   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
319*bdd1243dSDimitry Andric   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
320*bdd1243dSDimitry Andric   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
321*bdd1243dSDimitry Andric }
322*bdd1243dSDimitry Andric 
323*bdd1243dSDimitry Andric bool ComplexDeinterleaving::runOnFunction(Function &F) {
324*bdd1243dSDimitry Andric   if (!ComplexDeinterleavingEnabled) {
325*bdd1243dSDimitry Andric     LLVM_DEBUG(
326*bdd1243dSDimitry Andric         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
327*bdd1243dSDimitry Andric     return false;
328*bdd1243dSDimitry Andric   }
329*bdd1243dSDimitry Andric 
330*bdd1243dSDimitry Andric   if (!TL->isComplexDeinterleavingSupported()) {
331*bdd1243dSDimitry Andric     LLVM_DEBUG(
332*bdd1243dSDimitry Andric         dbgs() << "Complex deinterleaving has been disabled, target does "
333*bdd1243dSDimitry Andric                   "not support lowering of complex number operations.\n");
334*bdd1243dSDimitry Andric     return false;
335*bdd1243dSDimitry Andric   }
336*bdd1243dSDimitry Andric 
337*bdd1243dSDimitry Andric   bool Changed = false;
338*bdd1243dSDimitry Andric   for (auto &B : F)
339*bdd1243dSDimitry Andric     Changed |= evaluateBasicBlock(&B);
340*bdd1243dSDimitry Andric 
341*bdd1243dSDimitry Andric   return Changed;
342*bdd1243dSDimitry Andric }
343*bdd1243dSDimitry Andric 
344*bdd1243dSDimitry Andric static bool isInterleavingMask(ArrayRef<int> Mask) {
345*bdd1243dSDimitry Andric   // If the size is not even, it's not an interleaving mask
346*bdd1243dSDimitry Andric   if ((Mask.size() & 1))
347*bdd1243dSDimitry Andric     return false;
348*bdd1243dSDimitry Andric 
349*bdd1243dSDimitry Andric   int HalfNumElements = Mask.size() / 2;
350*bdd1243dSDimitry Andric   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
351*bdd1243dSDimitry Andric     int MaskIdx = Idx * 2;
352*bdd1243dSDimitry Andric     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
353*bdd1243dSDimitry Andric       return false;
354*bdd1243dSDimitry Andric   }
355*bdd1243dSDimitry Andric 
356*bdd1243dSDimitry Andric   return true;
357*bdd1243dSDimitry Andric }
358*bdd1243dSDimitry Andric 
359*bdd1243dSDimitry Andric static bool isDeinterleavingMask(ArrayRef<int> Mask) {
360*bdd1243dSDimitry Andric   int Offset = Mask[0];
361*bdd1243dSDimitry Andric   int HalfNumElements = Mask.size() / 2;
362*bdd1243dSDimitry Andric 
363*bdd1243dSDimitry Andric   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
364*bdd1243dSDimitry Andric     if (Mask[Idx] != (Idx * 2) + Offset)
365*bdd1243dSDimitry Andric       return false;
366*bdd1243dSDimitry Andric   }
367*bdd1243dSDimitry Andric 
368*bdd1243dSDimitry Andric   return true;
369*bdd1243dSDimitry Andric }
370*bdd1243dSDimitry Andric 
371*bdd1243dSDimitry Andric bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
372*bdd1243dSDimitry Andric   bool Changed = false;
373*bdd1243dSDimitry Andric 
374*bdd1243dSDimitry Andric   SmallVector<Instruction *> DeadInstrRoots;
375*bdd1243dSDimitry Andric 
376*bdd1243dSDimitry Andric   for (auto &I : *B) {
377*bdd1243dSDimitry Andric     auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
378*bdd1243dSDimitry Andric     if (!SVI)
379*bdd1243dSDimitry Andric       continue;
380*bdd1243dSDimitry Andric 
381*bdd1243dSDimitry Andric     // Look for a shufflevector that takes separate vectors of the real and
382*bdd1243dSDimitry Andric     // imaginary components and recombines them into a single vector.
383*bdd1243dSDimitry Andric     if (!isInterleavingMask(SVI->getShuffleMask()))
384*bdd1243dSDimitry Andric       continue;
385*bdd1243dSDimitry Andric 
386*bdd1243dSDimitry Andric     ComplexDeinterleavingGraph Graph(TL);
387*bdd1243dSDimitry Andric     if (!Graph.identifyNodes(SVI))
388*bdd1243dSDimitry Andric       continue;
389*bdd1243dSDimitry Andric 
390*bdd1243dSDimitry Andric     Graph.replaceNodes();
391*bdd1243dSDimitry Andric     DeadInstrRoots.push_back(SVI);
392*bdd1243dSDimitry Andric     Changed = true;
393*bdd1243dSDimitry Andric   }
394*bdd1243dSDimitry Andric 
395*bdd1243dSDimitry Andric   for (const auto &I : DeadInstrRoots) {
396*bdd1243dSDimitry Andric     if (!I || I->getParent() == nullptr)
397*bdd1243dSDimitry Andric       continue;
398*bdd1243dSDimitry Andric     llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
399*bdd1243dSDimitry Andric   }
400*bdd1243dSDimitry Andric 
401*bdd1243dSDimitry Andric   return Changed;
402*bdd1243dSDimitry Andric }
403*bdd1243dSDimitry Andric 
404*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
405*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
406*bdd1243dSDimitry Andric     Instruction *Real, Instruction *Imag,
407*bdd1243dSDimitry Andric     std::pair<Instruction *, Instruction *> &PartialMatch) {
408*bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
409*bdd1243dSDimitry Andric                     << "\n");
410*bdd1243dSDimitry Andric 
411*bdd1243dSDimitry Andric   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
412*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
413*bdd1243dSDimitry Andric     return nullptr;
414*bdd1243dSDimitry Andric   }
415*bdd1243dSDimitry Andric 
416*bdd1243dSDimitry Andric   if (Real->getOpcode() != Instruction::FMul ||
417*bdd1243dSDimitry Andric       Imag->getOpcode() != Instruction::FMul) {
418*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Real or imaginary instruction is not fmul\n");
419*bdd1243dSDimitry Andric     return nullptr;
420*bdd1243dSDimitry Andric   }
421*bdd1243dSDimitry Andric 
422*bdd1243dSDimitry Andric   Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
423*bdd1243dSDimitry Andric   Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
424*bdd1243dSDimitry Andric   Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
425*bdd1243dSDimitry Andric   Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
426*bdd1243dSDimitry Andric   if (!R0 || !R1 || !I0 || !I1) {
427*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
428*bdd1243dSDimitry Andric     return nullptr;
429*bdd1243dSDimitry Andric   }
430*bdd1243dSDimitry Andric 
431*bdd1243dSDimitry Andric   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
432*bdd1243dSDimitry Andric   // rotations and use the operand.
433*bdd1243dSDimitry Andric   unsigned Negs = 0;
434*bdd1243dSDimitry Andric   SmallVector<Instruction *> FNegs;
435*bdd1243dSDimitry Andric   if (R0->getOpcode() == Instruction::FNeg ||
436*bdd1243dSDimitry Andric       R1->getOpcode() == Instruction::FNeg) {
437*bdd1243dSDimitry Andric     Negs |= 1;
438*bdd1243dSDimitry Andric     if (R0->getOpcode() == Instruction::FNeg) {
439*bdd1243dSDimitry Andric       FNegs.push_back(R0);
440*bdd1243dSDimitry Andric       R0 = dyn_cast<Instruction>(R0->getOperand(0));
441*bdd1243dSDimitry Andric     } else {
442*bdd1243dSDimitry Andric       FNegs.push_back(R1);
443*bdd1243dSDimitry Andric       R1 = dyn_cast<Instruction>(R1->getOperand(0));
444*bdd1243dSDimitry Andric     }
445*bdd1243dSDimitry Andric     if (!R0 || !R1)
446*bdd1243dSDimitry Andric       return nullptr;
447*bdd1243dSDimitry Andric   }
448*bdd1243dSDimitry Andric   if (I0->getOpcode() == Instruction::FNeg ||
449*bdd1243dSDimitry Andric       I1->getOpcode() == Instruction::FNeg) {
450*bdd1243dSDimitry Andric     Negs |= 2;
451*bdd1243dSDimitry Andric     Negs ^= 1;
452*bdd1243dSDimitry Andric     if (I0->getOpcode() == Instruction::FNeg) {
453*bdd1243dSDimitry Andric       FNegs.push_back(I0);
454*bdd1243dSDimitry Andric       I0 = dyn_cast<Instruction>(I0->getOperand(0));
455*bdd1243dSDimitry Andric     } else {
456*bdd1243dSDimitry Andric       FNegs.push_back(I1);
457*bdd1243dSDimitry Andric       I1 = dyn_cast<Instruction>(I1->getOperand(0));
458*bdd1243dSDimitry Andric     }
459*bdd1243dSDimitry Andric     if (!I0 || !I1)
460*bdd1243dSDimitry Andric       return nullptr;
461*bdd1243dSDimitry Andric   }
462*bdd1243dSDimitry Andric 
463*bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
464*bdd1243dSDimitry Andric 
465*bdd1243dSDimitry Andric   Instruction *CommonOperand;
466*bdd1243dSDimitry Andric   Instruction *UncommonRealOp;
467*bdd1243dSDimitry Andric   Instruction *UncommonImagOp;
468*bdd1243dSDimitry Andric 
469*bdd1243dSDimitry Andric   if (R0 == I0 || R0 == I1) {
470*bdd1243dSDimitry Andric     CommonOperand = R0;
471*bdd1243dSDimitry Andric     UncommonRealOp = R1;
472*bdd1243dSDimitry Andric   } else if (R1 == I0 || R1 == I1) {
473*bdd1243dSDimitry Andric     CommonOperand = R1;
474*bdd1243dSDimitry Andric     UncommonRealOp = R0;
475*bdd1243dSDimitry Andric   } else {
476*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
477*bdd1243dSDimitry Andric     return nullptr;
478*bdd1243dSDimitry Andric   }
479*bdd1243dSDimitry Andric 
480*bdd1243dSDimitry Andric   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
481*bdd1243dSDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
482*bdd1243dSDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_270)
483*bdd1243dSDimitry Andric     std::swap(UncommonRealOp, UncommonImagOp);
484*bdd1243dSDimitry Andric 
485*bdd1243dSDimitry Andric   // Between identifyPartialMul and here we need to have found a complete valid
486*bdd1243dSDimitry Andric   // pair from the CommonOperand of each part.
487*bdd1243dSDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
488*bdd1243dSDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_180)
489*bdd1243dSDimitry Andric     PartialMatch.first = CommonOperand;
490*bdd1243dSDimitry Andric   else
491*bdd1243dSDimitry Andric     PartialMatch.second = CommonOperand;
492*bdd1243dSDimitry Andric 
493*bdd1243dSDimitry Andric   if (!PartialMatch.first || !PartialMatch.second) {
494*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
495*bdd1243dSDimitry Andric     return nullptr;
496*bdd1243dSDimitry Andric   }
497*bdd1243dSDimitry Andric 
498*bdd1243dSDimitry Andric   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
499*bdd1243dSDimitry Andric   if (!CommonNode) {
500*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
501*bdd1243dSDimitry Andric     return nullptr;
502*bdd1243dSDimitry Andric   }
503*bdd1243dSDimitry Andric 
504*bdd1243dSDimitry Andric   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
505*bdd1243dSDimitry Andric   if (!UncommonNode) {
506*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
507*bdd1243dSDimitry Andric     return nullptr;
508*bdd1243dSDimitry Andric   }
509*bdd1243dSDimitry Andric 
510*bdd1243dSDimitry Andric   NodePtr Node = prepareCompositeNode(
511*bdd1243dSDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
512*bdd1243dSDimitry Andric   Node->Rotation = Rotation;
513*bdd1243dSDimitry Andric   Node->addOperand(CommonNode);
514*bdd1243dSDimitry Andric   Node->addOperand(UncommonNode);
515*bdd1243dSDimitry Andric   Node->InternalInstructions.append(FNegs);
516*bdd1243dSDimitry Andric   return submitCompositeNode(Node);
517*bdd1243dSDimitry Andric }
518*bdd1243dSDimitry Andric 
519*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
520*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
521*bdd1243dSDimitry Andric                                                Instruction *Imag) {
522*bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
523*bdd1243dSDimitry Andric                     << "\n");
524*bdd1243dSDimitry Andric   // Determine rotation
525*bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation;
526*bdd1243dSDimitry Andric   if (Real->getOpcode() == Instruction::FAdd &&
527*bdd1243dSDimitry Andric       Imag->getOpcode() == Instruction::FAdd)
528*bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_0;
529*bdd1243dSDimitry Andric   else if (Real->getOpcode() == Instruction::FSub &&
530*bdd1243dSDimitry Andric            Imag->getOpcode() == Instruction::FAdd)
531*bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_90;
532*bdd1243dSDimitry Andric   else if (Real->getOpcode() == Instruction::FSub &&
533*bdd1243dSDimitry Andric            Imag->getOpcode() == Instruction::FSub)
534*bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_180;
535*bdd1243dSDimitry Andric   else if (Real->getOpcode() == Instruction::FAdd &&
536*bdd1243dSDimitry Andric            Imag->getOpcode() == Instruction::FSub)
537*bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_270;
538*bdd1243dSDimitry Andric   else {
539*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
540*bdd1243dSDimitry Andric     return nullptr;
541*bdd1243dSDimitry Andric   }
542*bdd1243dSDimitry Andric 
543*bdd1243dSDimitry Andric   if (!Real->getFastMathFlags().allowContract() ||
544*bdd1243dSDimitry Andric       !Imag->getFastMathFlags().allowContract()) {
545*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
546*bdd1243dSDimitry Andric     return nullptr;
547*bdd1243dSDimitry Andric   }
548*bdd1243dSDimitry Andric 
549*bdd1243dSDimitry Andric   Value *CR = Real->getOperand(0);
550*bdd1243dSDimitry Andric   Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
551*bdd1243dSDimitry Andric   if (!RealMulI)
552*bdd1243dSDimitry Andric     return nullptr;
553*bdd1243dSDimitry Andric   Value *CI = Imag->getOperand(0);
554*bdd1243dSDimitry Andric   Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
555*bdd1243dSDimitry Andric   if (!ImagMulI)
556*bdd1243dSDimitry Andric     return nullptr;
557*bdd1243dSDimitry Andric 
558*bdd1243dSDimitry Andric   if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
559*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
560*bdd1243dSDimitry Andric     return nullptr;
561*bdd1243dSDimitry Andric   }
562*bdd1243dSDimitry Andric 
563*bdd1243dSDimitry Andric   Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
564*bdd1243dSDimitry Andric   Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
565*bdd1243dSDimitry Andric   Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
566*bdd1243dSDimitry Andric   Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
567*bdd1243dSDimitry Andric   if (!R0 || !R1 || !I0 || !I1) {
568*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
569*bdd1243dSDimitry Andric     return nullptr;
570*bdd1243dSDimitry Andric   }
571*bdd1243dSDimitry Andric 
572*bdd1243dSDimitry Andric   Instruction *CommonOperand;
573*bdd1243dSDimitry Andric   Instruction *UncommonRealOp;
574*bdd1243dSDimitry Andric   Instruction *UncommonImagOp;
575*bdd1243dSDimitry Andric 
576*bdd1243dSDimitry Andric   if (R0 == I0 || R0 == I1) {
577*bdd1243dSDimitry Andric     CommonOperand = R0;
578*bdd1243dSDimitry Andric     UncommonRealOp = R1;
579*bdd1243dSDimitry Andric   } else if (R1 == I0 || R1 == I1) {
580*bdd1243dSDimitry Andric     CommonOperand = R1;
581*bdd1243dSDimitry Andric     UncommonRealOp = R0;
582*bdd1243dSDimitry Andric   } else {
583*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
584*bdd1243dSDimitry Andric     return nullptr;
585*bdd1243dSDimitry Andric   }
586*bdd1243dSDimitry Andric 
587*bdd1243dSDimitry Andric   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
588*bdd1243dSDimitry Andric   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
589*bdd1243dSDimitry Andric       Rotation == ComplexDeinterleavingRotation::Rotation_270)
590*bdd1243dSDimitry Andric     std::swap(UncommonRealOp, UncommonImagOp);
591*bdd1243dSDimitry Andric 
592*bdd1243dSDimitry Andric   std::pair<Instruction *, Instruction *> PartialMatch(
593*bdd1243dSDimitry Andric       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
594*bdd1243dSDimitry Andric        Rotation == ComplexDeinterleavingRotation::Rotation_180)
595*bdd1243dSDimitry Andric           ? CommonOperand
596*bdd1243dSDimitry Andric           : nullptr,
597*bdd1243dSDimitry Andric       (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
598*bdd1243dSDimitry Andric        Rotation == ComplexDeinterleavingRotation::Rotation_270)
599*bdd1243dSDimitry Andric           ? CommonOperand
600*bdd1243dSDimitry Andric           : nullptr);
601*bdd1243dSDimitry Andric   NodePtr CNode = identifyNodeWithImplicitAdd(
602*bdd1243dSDimitry Andric       cast<Instruction>(CR), cast<Instruction>(CI), PartialMatch);
603*bdd1243dSDimitry Andric   if (!CNode) {
604*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
605*bdd1243dSDimitry Andric     return nullptr;
606*bdd1243dSDimitry Andric   }
607*bdd1243dSDimitry Andric 
608*bdd1243dSDimitry Andric   NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
609*bdd1243dSDimitry Andric   if (!UncommonRes) {
610*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
611*bdd1243dSDimitry Andric     return nullptr;
612*bdd1243dSDimitry Andric   }
613*bdd1243dSDimitry Andric 
614*bdd1243dSDimitry Andric   assert(PartialMatch.first && PartialMatch.second);
615*bdd1243dSDimitry Andric   NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
616*bdd1243dSDimitry Andric   if (!CommonRes) {
617*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
618*bdd1243dSDimitry Andric     return nullptr;
619*bdd1243dSDimitry Andric   }
620*bdd1243dSDimitry Andric 
621*bdd1243dSDimitry Andric   NodePtr Node = prepareCompositeNode(
622*bdd1243dSDimitry Andric       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
623*bdd1243dSDimitry Andric   Node->addInstruction(RealMulI);
624*bdd1243dSDimitry Andric   Node->addInstruction(ImagMulI);
625*bdd1243dSDimitry Andric   Node->Rotation = Rotation;
626*bdd1243dSDimitry Andric   Node->addOperand(CommonRes);
627*bdd1243dSDimitry Andric   Node->addOperand(UncommonRes);
628*bdd1243dSDimitry Andric   Node->addOperand(CNode);
629*bdd1243dSDimitry Andric   return submitCompositeNode(Node);
630*bdd1243dSDimitry Andric }
631*bdd1243dSDimitry Andric 
632*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
633*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
634*bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
635*bdd1243dSDimitry Andric 
636*bdd1243dSDimitry Andric   // Determine rotation
637*bdd1243dSDimitry Andric   ComplexDeinterleavingRotation Rotation;
638*bdd1243dSDimitry Andric   if ((Real->getOpcode() == Instruction::FSub &&
639*bdd1243dSDimitry Andric        Imag->getOpcode() == Instruction::FAdd) ||
640*bdd1243dSDimitry Andric       (Real->getOpcode() == Instruction::Sub &&
641*bdd1243dSDimitry Andric        Imag->getOpcode() == Instruction::Add))
642*bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_90;
643*bdd1243dSDimitry Andric   else if ((Real->getOpcode() == Instruction::FAdd &&
644*bdd1243dSDimitry Andric             Imag->getOpcode() == Instruction::FSub) ||
645*bdd1243dSDimitry Andric            (Real->getOpcode() == Instruction::Add &&
646*bdd1243dSDimitry Andric             Imag->getOpcode() == Instruction::Sub))
647*bdd1243dSDimitry Andric     Rotation = ComplexDeinterleavingRotation::Rotation_270;
648*bdd1243dSDimitry Andric   else {
649*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
650*bdd1243dSDimitry Andric     return nullptr;
651*bdd1243dSDimitry Andric   }
652*bdd1243dSDimitry Andric 
653*bdd1243dSDimitry Andric   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
654*bdd1243dSDimitry Andric   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
655*bdd1243dSDimitry Andric   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
656*bdd1243dSDimitry Andric   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
657*bdd1243dSDimitry Andric 
658*bdd1243dSDimitry Andric   if (!AR || !AI || !BR || !BI) {
659*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
660*bdd1243dSDimitry Andric     return nullptr;
661*bdd1243dSDimitry Andric   }
662*bdd1243dSDimitry Andric 
663*bdd1243dSDimitry Andric   NodePtr ResA = identifyNode(AR, AI);
664*bdd1243dSDimitry Andric   if (!ResA) {
665*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
666*bdd1243dSDimitry Andric     return nullptr;
667*bdd1243dSDimitry Andric   }
668*bdd1243dSDimitry Andric   NodePtr ResB = identifyNode(BR, BI);
669*bdd1243dSDimitry Andric   if (!ResB) {
670*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
671*bdd1243dSDimitry Andric     return nullptr;
672*bdd1243dSDimitry Andric   }
673*bdd1243dSDimitry Andric 
674*bdd1243dSDimitry Andric   NodePtr Node =
675*bdd1243dSDimitry Andric       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
676*bdd1243dSDimitry Andric   Node->Rotation = Rotation;
677*bdd1243dSDimitry Andric   Node->addOperand(ResA);
678*bdd1243dSDimitry Andric   Node->addOperand(ResB);
679*bdd1243dSDimitry Andric   return submitCompositeNode(Node);
680*bdd1243dSDimitry Andric }
681*bdd1243dSDimitry Andric 
682*bdd1243dSDimitry Andric static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
683*bdd1243dSDimitry Andric   unsigned OpcA = A->getOpcode();
684*bdd1243dSDimitry Andric   unsigned OpcB = B->getOpcode();
685*bdd1243dSDimitry Andric 
686*bdd1243dSDimitry Andric   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
687*bdd1243dSDimitry Andric          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
688*bdd1243dSDimitry Andric          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
689*bdd1243dSDimitry Andric          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
690*bdd1243dSDimitry Andric }
691*bdd1243dSDimitry Andric 
692*bdd1243dSDimitry Andric static bool isInstructionPairMul(Instruction *A, Instruction *B) {
693*bdd1243dSDimitry Andric   auto Pattern =
694*bdd1243dSDimitry Andric       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
695*bdd1243dSDimitry Andric 
696*bdd1243dSDimitry Andric   return match(A, Pattern) && match(B, Pattern);
697*bdd1243dSDimitry Andric }
698*bdd1243dSDimitry Andric 
699*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::NodePtr
700*bdd1243dSDimitry Andric ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
701*bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
702*bdd1243dSDimitry Andric   if (NodePtr CN = getContainingComposite(Real, Imag)) {
703*bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
704*bdd1243dSDimitry Andric     return CN;
705*bdd1243dSDimitry Andric   }
706*bdd1243dSDimitry Andric 
707*bdd1243dSDimitry Andric   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
708*bdd1243dSDimitry Andric   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
709*bdd1243dSDimitry Andric   if (RealShuffle && ImagShuffle) {
710*bdd1243dSDimitry Andric     Value *RealOp1 = RealShuffle->getOperand(1);
711*bdd1243dSDimitry Andric     if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
712*bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
713*bdd1243dSDimitry Andric       return nullptr;
714*bdd1243dSDimitry Andric     }
715*bdd1243dSDimitry Andric     Value *ImagOp1 = ImagShuffle->getOperand(1);
716*bdd1243dSDimitry Andric     if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
717*bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
718*bdd1243dSDimitry Andric       return nullptr;
719*bdd1243dSDimitry Andric     }
720*bdd1243dSDimitry Andric 
721*bdd1243dSDimitry Andric     Value *RealOp0 = RealShuffle->getOperand(0);
722*bdd1243dSDimitry Andric     Value *ImagOp0 = ImagShuffle->getOperand(0);
723*bdd1243dSDimitry Andric 
724*bdd1243dSDimitry Andric     if (RealOp0 != ImagOp0) {
725*bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
726*bdd1243dSDimitry Andric       return nullptr;
727*bdd1243dSDimitry Andric     }
728*bdd1243dSDimitry Andric 
729*bdd1243dSDimitry Andric     ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
730*bdd1243dSDimitry Andric     ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
731*bdd1243dSDimitry Andric     if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
732*bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
733*bdd1243dSDimitry Andric       return nullptr;
734*bdd1243dSDimitry Andric     }
735*bdd1243dSDimitry Andric 
736*bdd1243dSDimitry Andric     if (RealMask[0] != 0 || ImagMask[0] != 1) {
737*bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
738*bdd1243dSDimitry Andric       return nullptr;
739*bdd1243dSDimitry Andric     }
740*bdd1243dSDimitry Andric 
741*bdd1243dSDimitry Andric     // Type checking, the shuffle type should be a vector type of the same
742*bdd1243dSDimitry Andric     // scalar type, but half the size
743*bdd1243dSDimitry Andric     auto CheckType = [&](ShuffleVectorInst *Shuffle) {
744*bdd1243dSDimitry Andric       Value *Op = Shuffle->getOperand(0);
745*bdd1243dSDimitry Andric       auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
746*bdd1243dSDimitry Andric       auto *OpTy = cast<FixedVectorType>(Op->getType());
747*bdd1243dSDimitry Andric 
748*bdd1243dSDimitry Andric       if (OpTy->getScalarType() != ShuffleTy->getScalarType())
749*bdd1243dSDimitry Andric         return false;
750*bdd1243dSDimitry Andric       if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
751*bdd1243dSDimitry Andric         return false;
752*bdd1243dSDimitry Andric 
753*bdd1243dSDimitry Andric       return true;
754*bdd1243dSDimitry Andric     };
755*bdd1243dSDimitry Andric 
756*bdd1243dSDimitry Andric     auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
757*bdd1243dSDimitry Andric       if (!CheckType(Shuffle))
758*bdd1243dSDimitry Andric         return false;
759*bdd1243dSDimitry Andric 
760*bdd1243dSDimitry Andric       ArrayRef<int> Mask = Shuffle->getShuffleMask();
761*bdd1243dSDimitry Andric       int Last = *Mask.rbegin();
762*bdd1243dSDimitry Andric 
763*bdd1243dSDimitry Andric       Value *Op = Shuffle->getOperand(0);
764*bdd1243dSDimitry Andric       auto *OpTy = cast<FixedVectorType>(Op->getType());
765*bdd1243dSDimitry Andric       int NumElements = OpTy->getNumElements();
766*bdd1243dSDimitry Andric 
767*bdd1243dSDimitry Andric       // Ensure that the deinterleaving shuffle only pulls from the first
768*bdd1243dSDimitry Andric       // shuffle operand.
769*bdd1243dSDimitry Andric       return Last < NumElements;
770*bdd1243dSDimitry Andric     };
771*bdd1243dSDimitry Andric 
772*bdd1243dSDimitry Andric     if (RealShuffle->getType() != ImagShuffle->getType()) {
773*bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
774*bdd1243dSDimitry Andric       return nullptr;
775*bdd1243dSDimitry Andric     }
776*bdd1243dSDimitry Andric     if (!CheckDeinterleavingShuffle(RealShuffle)) {
777*bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
778*bdd1243dSDimitry Andric       return nullptr;
779*bdd1243dSDimitry Andric     }
780*bdd1243dSDimitry Andric     if (!CheckDeinterleavingShuffle(ImagShuffle)) {
781*bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
782*bdd1243dSDimitry Andric       return nullptr;
783*bdd1243dSDimitry Andric     }
784*bdd1243dSDimitry Andric 
785*bdd1243dSDimitry Andric     NodePtr PlaceholderNode =
786*bdd1243dSDimitry Andric         prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle,
787*bdd1243dSDimitry Andric                              RealShuffle, ImagShuffle);
788*bdd1243dSDimitry Andric     PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
789*bdd1243dSDimitry Andric     return submitCompositeNode(PlaceholderNode);
790*bdd1243dSDimitry Andric   }
791*bdd1243dSDimitry Andric   if (RealShuffle || ImagShuffle)
792*bdd1243dSDimitry Andric     return nullptr;
793*bdd1243dSDimitry Andric 
794*bdd1243dSDimitry Andric   auto *VTy = cast<FixedVectorType>(Real->getType());
795*bdd1243dSDimitry Andric   auto *NewVTy =
796*bdd1243dSDimitry Andric       FixedVectorType::get(VTy->getScalarType(), VTy->getNumElements() * 2);
797*bdd1243dSDimitry Andric 
798*bdd1243dSDimitry Andric   if (TL->isComplexDeinterleavingOperationSupported(
799*bdd1243dSDimitry Andric           ComplexDeinterleavingOperation::CMulPartial, NewVTy) &&
800*bdd1243dSDimitry Andric       isInstructionPairMul(Real, Imag)) {
801*bdd1243dSDimitry Andric     return identifyPartialMul(Real, Imag);
802*bdd1243dSDimitry Andric   }
803*bdd1243dSDimitry Andric 
804*bdd1243dSDimitry Andric   if (TL->isComplexDeinterleavingOperationSupported(
805*bdd1243dSDimitry Andric           ComplexDeinterleavingOperation::CAdd, NewVTy) &&
806*bdd1243dSDimitry Andric       isInstructionPairAdd(Real, Imag)) {
807*bdd1243dSDimitry Andric     return identifyAdd(Real, Imag);
808*bdd1243dSDimitry Andric   }
809*bdd1243dSDimitry Andric 
810*bdd1243dSDimitry Andric   return nullptr;
811*bdd1243dSDimitry Andric }
812*bdd1243dSDimitry Andric 
813*bdd1243dSDimitry Andric bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
814*bdd1243dSDimitry Andric   Instruction *Real;
815*bdd1243dSDimitry Andric   Instruction *Imag;
816*bdd1243dSDimitry Andric   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
817*bdd1243dSDimitry Andric     return false;
818*bdd1243dSDimitry Andric 
819*bdd1243dSDimitry Andric   RootValue = RootI;
820*bdd1243dSDimitry Andric   AllInstructions.insert(RootI);
821*bdd1243dSDimitry Andric   RootNode = identifyNode(Real, Imag);
822*bdd1243dSDimitry Andric 
823*bdd1243dSDimitry Andric   LLVM_DEBUG({
824*bdd1243dSDimitry Andric     Function *F = RootI->getFunction();
825*bdd1243dSDimitry Andric     BasicBlock *B = RootI->getParent();
826*bdd1243dSDimitry Andric     dbgs() << "Complex deinterleaving graph for " << F->getName()
827*bdd1243dSDimitry Andric            << "::" << B->getName() << ".\n";
828*bdd1243dSDimitry Andric     dump(dbgs());
829*bdd1243dSDimitry Andric     dbgs() << "\n";
830*bdd1243dSDimitry Andric   });
831*bdd1243dSDimitry Andric 
832*bdd1243dSDimitry Andric   // Check all instructions have internal uses
833*bdd1243dSDimitry Andric   for (const auto &Node : CompositeNodes) {
834*bdd1243dSDimitry Andric     if (!Node->hasAllInternalUses(AllInstructions)) {
835*bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << "  - Invalid internal uses\n");
836*bdd1243dSDimitry Andric       return false;
837*bdd1243dSDimitry Andric     }
838*bdd1243dSDimitry Andric   }
839*bdd1243dSDimitry Andric   return RootNode != nullptr;
840*bdd1243dSDimitry Andric }
841*bdd1243dSDimitry Andric 
842*bdd1243dSDimitry Andric Value *ComplexDeinterleavingGraph::replaceNode(
843*bdd1243dSDimitry Andric     ComplexDeinterleavingGraph::RawNodePtr Node) {
844*bdd1243dSDimitry Andric   if (Node->ReplacementNode)
845*bdd1243dSDimitry Andric     return Node->ReplacementNode;
846*bdd1243dSDimitry Andric 
847*bdd1243dSDimitry Andric   Value *Input0 = replaceNode(Node->Operands[0]);
848*bdd1243dSDimitry Andric   Value *Input1 = replaceNode(Node->Operands[1]);
849*bdd1243dSDimitry Andric   Value *Accumulator =
850*bdd1243dSDimitry Andric       Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr;
851*bdd1243dSDimitry Andric 
852*bdd1243dSDimitry Andric   assert(Input0->getType() == Input1->getType() &&
853*bdd1243dSDimitry Andric          "Node inputs need to be of the same type");
854*bdd1243dSDimitry Andric 
855*bdd1243dSDimitry Andric   Node->ReplacementNode = TL->createComplexDeinterleavingIR(
856*bdd1243dSDimitry Andric       Node->Real, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
857*bdd1243dSDimitry Andric 
858*bdd1243dSDimitry Andric   assert(Node->ReplacementNode && "Target failed to create Intrinsic call.");
859*bdd1243dSDimitry Andric   NumComplexTransformations += 1;
860*bdd1243dSDimitry Andric   return Node->ReplacementNode;
861*bdd1243dSDimitry Andric }
862*bdd1243dSDimitry Andric 
863*bdd1243dSDimitry Andric void ComplexDeinterleavingGraph::replaceNodes() {
864*bdd1243dSDimitry Andric   Value *R = replaceNode(RootNode.get());
865*bdd1243dSDimitry Andric   assert(R && "Unable to find replacement for RootValue");
866*bdd1243dSDimitry Andric   RootValue->replaceAllUsesWith(R);
867*bdd1243dSDimitry Andric }
868*bdd1243dSDimitry Andric 
869*bdd1243dSDimitry Andric bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
870*bdd1243dSDimitry Andric     SmallPtrSet<Instruction *, 16> &AllInstructions) {
871*bdd1243dSDimitry Andric   if (Operation == ComplexDeinterleavingOperation::Shuffle)
872*bdd1243dSDimitry Andric     return true;
873*bdd1243dSDimitry Andric 
874*bdd1243dSDimitry Andric   for (auto *User : Real->users()) {
875*bdd1243dSDimitry Andric     if (!AllInstructions.contains(cast<Instruction>(User)))
876*bdd1243dSDimitry Andric       return false;
877*bdd1243dSDimitry Andric   }
878*bdd1243dSDimitry Andric   for (auto *User : Imag->users()) {
879*bdd1243dSDimitry Andric     if (!AllInstructions.contains(cast<Instruction>(User)))
880*bdd1243dSDimitry Andric       return false;
881*bdd1243dSDimitry Andric   }
882*bdd1243dSDimitry Andric   for (auto *I : InternalInstructions) {
883*bdd1243dSDimitry Andric     for (auto *User : I->users()) {
884*bdd1243dSDimitry Andric       if (!AllInstructions.contains(cast<Instruction>(User)))
885*bdd1243dSDimitry Andric         return false;
886*bdd1243dSDimitry Andric     }
887*bdd1243dSDimitry Andric   }
888*bdd1243dSDimitry Andric   return true;
889*bdd1243dSDimitry Andric }
890