//===- ComplexDeinterleavingPass.cpp --------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Identification: // This step is responsible for finding the patterns that can be lowered to // complex instructions, and building a graph to represent the complex // structures. Starting from the "Converging Shuffle" (a shuffle that // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the // operands are evaluated and identified as "Composite Nodes" (collections of // instructions that can potentially be lowered to a single complex // instruction). This is performed by checking the real and imaginary components // and tracking the data flow for each component while following the operand // pairs. Validity of each node is expected to be done upon creation, and any // validation errors should halt traversal and prevent further graph // construction. // Instead of relying on Shuffle operations, vector interleaving and // deinterleaving can be represented by vector.interleave2 and // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by // these intrinsics, whereas, fixed-width vectors are recognized for both // shufflevector instruction and intrinsics. // // Replacement: // This step traverses the graph built up by identification, delegating to the // target to validate and generate the correct intrinsics, and plumbs them // together connecting each end of the new intrinsics graph to the existing // use-def chain. This step is assumed to finish successfully, as all // information is expected to be correct by this point. // // // Internal data structure: // ComplexDeinterleavingGraph: // Keeps references to all the valid CompositeNodes formed as part of the // transformation, and every Instruction contained within said nodes. It also // holds onto a reference to the root Instruction, and the root node that should // replace it. // // ComplexDeinterleavingCompositeNode: // A CompositeNode represents a single transformation point; each node should // transform into a single complex instruction (ignoring vector splitting, which // would generate more instructions per node). They are identified in a // depth-first manner, traversing and identifying the operands of each // instruction in the order they appear in the IR. // Each node maintains a reference to its Real and Imaginary instructions, // as well as any additional instructions that make up the identified operation // (Internal instructions should only have uses within their containing node). // A Node also contains the rotation and operation type that it represents. // Operands contains pointers to other CompositeNodes, acting as the edges in // the graph. ReplacementValue is the transformed Value* that has been emitted // to the IR. // // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and // ReplacementValue fields of that Node are relevant, where the ReplacementValue // should be pre-populated. // //===----------------------------------------------------------------------===// #include "llvm/CodeGen/ComplexDeinterleavingPass.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/Utils/Local.h" #include using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "complex-deinterleaving" STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); static cl::opt ComplexDeinterleavingEnabled( "enable-complex-deinterleaving", cl::desc("Enable generation of complex instructions"), cl::init(true), cl::Hidden); /// Checks the given mask, and determines whether said mask is interleaving. /// /// To be interleaving, a mask must alternate between `i` and `i + (Length / /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a /// 4x vector interleaving mask would be <0, 2, 1, 3>). static bool isInterleavingMask(ArrayRef Mask); /// Checks the given mask, and determines whether said mask is deinterleaving. /// /// To be deinterleaving, a mask must increment in steps of 2, and either start /// with 0 or 1. /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or /// <1, 3, 5, 7>). static bool isDeinterleavingMask(ArrayRef Mask); namespace { class ComplexDeinterleavingLegacyPass : public FunctionPass { public: static char ID; ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) : FunctionPass(ID), TM(TM) { initializeComplexDeinterleavingLegacyPassPass( *PassRegistry::getPassRegistry()); } StringRef getPassName() const override { return "Complex Deinterleaving Pass"; } bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired(); AU.setPreservesCFG(); } private: const TargetMachine *TM; }; class ComplexDeinterleavingGraph; struct ComplexDeinterleavingCompositeNode { ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, Instruction *R, Instruction *I) : Operation(Op), Real(R), Imag(I) {} private: friend class ComplexDeinterleavingGraph; using NodePtr = std::shared_ptr; using RawNodePtr = ComplexDeinterleavingCompositeNode *; public: ComplexDeinterleavingOperation Operation; Instruction *Real; Instruction *Imag; // This two members are required exclusively for generating // ComplexDeinterleavingOperation::Symmetric operations. unsigned Opcode; FastMathFlags Flags; ComplexDeinterleavingRotation Rotation = ComplexDeinterleavingRotation::Rotation_0; SmallVector Operands; Value *ReplacementNode = nullptr; void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } void dump() { dump(dbgs()); } void dump(raw_ostream &OS) { auto PrintValue = [&](Value *V) { if (V) { OS << "\""; V->print(OS, true); OS << "\"\n"; } else OS << "nullptr\n"; }; auto PrintNodeRef = [&](RawNodePtr Ptr) { if (Ptr) OS << Ptr << "\n"; else OS << "nullptr\n"; }; OS << "- CompositeNode: " << this << "\n"; OS << " Real: "; PrintValue(Real); OS << " Imag: "; PrintValue(Imag); OS << " ReplacementNode: "; PrintValue(ReplacementNode); OS << " Operation: " << (int)Operation << "\n"; OS << " Rotation: " << ((int)Rotation * 90) << "\n"; OS << " Operands: \n"; for (const auto &Op : Operands) { OS << " - "; PrintNodeRef(Op); } } }; class ComplexDeinterleavingGraph { public: struct Product { Instruction *Multiplier; Instruction *Multiplicand; bool IsPositive; }; using Addend = std::pair; using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; // Helper struct for holding info about potential partial multiplication // candidates struct PartialMulCandidate { Instruction *Common; NodePtr Node; unsigned RealIdx; unsigned ImagIdx; bool IsNodeInverted; }; explicit ComplexDeinterleavingGraph(const TargetLowering *TL, const TargetLibraryInfo *TLI) : TL(TL), TLI(TLI) {} private: const TargetLowering *TL = nullptr; const TargetLibraryInfo *TLI = nullptr; SmallVector CompositeNodes; SmallPtrSet FinalInstructions; /// Root instructions are instructions from which complex computation starts std::map RootToNode; /// Topologically sorted root instructions SmallVector OrderedRoots; /// When examining a basic block for complex deinterleaving, if it is a simple /// one-block loop, then the only incoming block is 'Incoming' and the /// 'BackEdge' block is the block itself." BasicBlock *BackEdge = nullptr; BasicBlock *Incoming = nullptr; /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction /// %OutsideUser as it is shown in the IR: /// /// vector.body: /// %PHInode = phi [ zeroinitializer, %entry ], /// [ %ReductionOp, %vector.body ] /// ... /// %ReductionOp = fadd i64 ... /// ... /// br i1 %condition, label %vector.body, %middle.block /// /// middle.block: /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) /// /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding /// `llvm.vector.reduce.fadd` when unroll factor isn't one. std::map> ReductionInfo; /// In the process of detecting a reduction, we consider a pair of /// %ReductionOP, which we refer to as real and imag (or vice versa), and /// traverse the use-tree to detect complex operations. As this is a reduction /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds /// to the %ReductionOPs that we suspect to be complex. /// RealPHI and ImagPHI are used by the identifyPHINode method. PHINode *RealPHI = nullptr; PHINode *ImagPHI = nullptr; /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. /// The new PHINode corresponds to a vector of deinterleaved complex numbers. /// This mapping is populated during /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then /// used in the ComplexDeinterleavingOperation::ReductionOperation node /// replacement process. std::map OldToNewPHI; NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, Instruction *R, Instruction *I) { assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && Operation != ComplexDeinterleavingOperation::ReductionOperation) || (R && I)) && "Reduction related nodes must have Real and Imaginary parts"); return std::make_shared(Operation, R, I); } NodePtr submitCompositeNode(NodePtr Node) { CompositeNodes.push_back(Node); return Node; } NodePtr getContainingComposite(Value *R, Value *I) { for (const auto &CN : CompositeNodes) { if (CN->Real == R && CN->Imag == I) return CN; } return nullptr; } /// Identifies a complex partial multiply pattern and its rotation, based on /// the following patterns /// /// 0: r: cr + ar * br /// i: ci + ar * bi /// 90: r: cr - ai * bi /// i: ci + ai * br /// 180: r: cr - ar * br /// i: ci - ar * bi /// 270: r: cr + ai * bi /// i: ci - ai * br NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); /// Identify the other branch of a Partial Mul, taking the CommonOperandI that /// is partially known from identifyPartialMul, filling in the other half of /// the complex pair. NodePtr identifyNodeWithImplicitAdd( Instruction *I, Instruction *J, std::pair &CommonOperandI); /// Identifies a complex add pattern and its rotation, based on the following /// patterns. /// /// 90: r: ar - bi /// i: ai + br /// 270: r: ar + bi /// i: ai - br NodePtr identifyAdd(Instruction *Real, Instruction *Imag); NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); NodePtr identifyNode(Instruction *I, Instruction *J); /// Determine if a sum of complex numbers can be formed from \p RealAddends /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. /// Return nullptr if it is not possible to construct a complex number. /// \p Flags are needed to generate symmetric Add and Sub operations. NodePtr identifyAdditions(std::list &RealAddends, std::list &ImagAddends, FastMathFlags Flags, NodePtr Accumulator); /// Extract one addend that have both real and imaginary parts positive. NodePtr extractPositiveAddend(std::list &RealAddends, std::list &ImagAddends); /// Determine if sum of multiplications of complex numbers can be formed from /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result /// to it. Return nullptr if it is not possible to construct a complex number. NodePtr identifyMultiplications(std::vector &RealMuls, std::vector &ImagMuls, NodePtr Accumulator); /// Go through pairs of multiplication (one Real and one Imag) and find all /// possible candidates for partial multiplication and put them into \p /// Candidates. Returns true if all Product has pair with common operand bool collectPartialMuls(const std::vector &RealMuls, const std::vector &ImagMuls, std::vector &Candidates); /// If the code is compiled with -Ofast or expressions have `reassoc` flag, /// the order of complex computation operations may be significantly altered, /// and the real and imaginary parts may not be executed in parallel. This /// function takes this into consideration and employs a more general approach /// to identify complex computations. Initially, it gathers all the addends /// and multiplicands and then constructs a complex expression from them. NodePtr identifyReassocNodes(Instruction *I, Instruction *J); NodePtr identifyRoot(Instruction *I); /// Identifies the Deinterleave operation applied to a vector containing /// complex numbers. There are two ways to represent the Deinterleave /// operation: /// * Using two shufflevectors with even indices for /pReal instruction and /// odd indices for /pImag instructions (only for fixed-width vectors) /// * Using two extractvalue instructions applied to `vector.deinterleave2` /// intrinsic (for both fixed and scalable vectors) NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); /// Complete IR modifications after producing new reduction operation: /// * Populate the PHINode generated for /// ComplexDeinterleavingOperation::ReductionPHI /// * Deinterleave the final value outside of the loop and repurpose original /// reduction users void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); public: void dump() { dump(dbgs()); } void dump(raw_ostream &OS) { for (const auto &Node : CompositeNodes) Node->dump(OS); } /// Returns false if the deinterleaving operation should be cancelled for the /// current graph. bool identifyNodes(Instruction *RootI); /// In case \pB is one-block loop, this function seeks potential reductions /// and populates ReductionInfo. Returns true if any reductions were /// identified. bool collectPotentialReductions(BasicBlock *B); void identifyReductionNodes(); /// Check that every instruction, from the roots to the leaves, has internal /// uses. bool checkNodes(); /// Perform the actual replacement of the underlying instruction graph. void replaceNodes(); }; class ComplexDeinterleaving { public: ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) : TL(tl), TLI(tli) {} bool runOnFunction(Function &F); private: bool evaluateBasicBlock(BasicBlock *B); const TargetLowering *TL = nullptr; const TargetLibraryInfo *TLI = nullptr; }; } // namespace char ComplexDeinterleavingLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, "Complex Deinterleaving", false, false) INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, "Complex Deinterleaving", false, false) PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, FunctionAnalysisManager &AM) { const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); auto &TLI = AM.getResult(F); if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve(); return PA; } FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { return new ComplexDeinterleavingLegacyPass(TM); } bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); auto TLI = getAnalysis().getTLI(F); return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); } bool ComplexDeinterleaving::runOnFunction(Function &F) { if (!ComplexDeinterleavingEnabled) { LLVM_DEBUG( dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); return false; } if (!TL->isComplexDeinterleavingSupported()) { LLVM_DEBUG( dbgs() << "Complex deinterleaving has been disabled, target does " "not support lowering of complex number operations.\n"); return false; } bool Changed = false; for (auto &B : F) Changed |= evaluateBasicBlock(&B); return Changed; } static bool isInterleavingMask(ArrayRef Mask) { // If the size is not even, it's not an interleaving mask if ((Mask.size() & 1)) return false; int HalfNumElements = Mask.size() / 2; for (int Idx = 0; Idx < HalfNumElements; ++Idx) { int MaskIdx = Idx * 2; if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) return false; } return true; } static bool isDeinterleavingMask(ArrayRef Mask) { int Offset = Mask[0]; int HalfNumElements = Mask.size() / 2; for (int Idx = 1; Idx < HalfNumElements; ++Idx) { if (Mask[Idx] != (Idx * 2) + Offset) return false; } return true; } bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { ComplexDeinterleavingGraph Graph(TL, TLI); if (Graph.collectPotentialReductions(B)) Graph.identifyReductionNodes(); for (auto &I : *B) Graph.identifyNodes(&I); if (Graph.checkNodes()) { Graph.replaceNodes(); return true; } return false; } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( Instruction *Real, Instruction *Imag, std::pair &PartialMatch) { LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag << "\n"); if (!Real->hasOneUse() || !Imag->hasOneUse()) { LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); return nullptr; } if (Real->getOpcode() != Instruction::FMul || Imag->getOpcode() != Instruction::FMul) { LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n"); return nullptr; } Instruction *R0 = dyn_cast(Real->getOperand(0)); Instruction *R1 = dyn_cast(Real->getOperand(1)); Instruction *I0 = dyn_cast(Imag->getOperand(0)); Instruction *I1 = dyn_cast(Imag->getOperand(1)); if (!R0 || !R1 || !I0 || !I1) { LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); return nullptr; } // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the // rotations and use the operand. unsigned Negs = 0; SmallVector FNegs; if (R0->getOpcode() == Instruction::FNeg || R1->getOpcode() == Instruction::FNeg) { Negs |= 1; if (R0->getOpcode() == Instruction::FNeg) { FNegs.push_back(R0); R0 = dyn_cast(R0->getOperand(0)); } else { FNegs.push_back(R1); R1 = dyn_cast(R1->getOperand(0)); } if (!R0 || !R1) return nullptr; } if (I0->getOpcode() == Instruction::FNeg || I1->getOpcode() == Instruction::FNeg) { Negs |= 2; Negs ^= 1; if (I0->getOpcode() == Instruction::FNeg) { FNegs.push_back(I0); I0 = dyn_cast(I0->getOperand(0)); } else { FNegs.push_back(I1); I1 = dyn_cast(I1->getOperand(0)); } if (!I0 || !I1) return nullptr; } ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; Instruction *CommonOperand; Instruction *UncommonRealOp; Instruction *UncommonImagOp; if (R0 == I0 || R0 == I1) { CommonOperand = R0; UncommonRealOp = R1; } else if (R1 == I0 || R1 == I1) { CommonOperand = R1; UncommonRealOp = R0; } else { LLVM_DEBUG(dbgs() << " - No equal operand\n"); return nullptr; } UncommonImagOp = (CommonOperand == I0) ? I1 : I0; if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || Rotation == ComplexDeinterleavingRotation::Rotation_270) std::swap(UncommonRealOp, UncommonImagOp); // Between identifyPartialMul and here we need to have found a complete valid // pair from the CommonOperand of each part. if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180) PartialMatch.first = CommonOperand; else PartialMatch.second = CommonOperand; if (!PartialMatch.first || !PartialMatch.second) { LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); return nullptr; } NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); if (!CommonNode) { LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); return nullptr; } NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); if (!UncommonNode) { LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); return nullptr; } NodePtr Node = prepareCompositeNode( ComplexDeinterleavingOperation::CMulPartial, Real, Imag); Node->Rotation = Rotation; Node->addOperand(CommonNode); Node->addOperand(UncommonNode); return submitCompositeNode(Node); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, Instruction *Imag) { LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag << "\n"); // Determine rotation ComplexDeinterleavingRotation Rotation; if (Real->getOpcode() == Instruction::FAdd && Imag->getOpcode() == Instruction::FAdd) Rotation = ComplexDeinterleavingRotation::Rotation_0; else if (Real->getOpcode() == Instruction::FSub && Imag->getOpcode() == Instruction::FAdd) Rotation = ComplexDeinterleavingRotation::Rotation_90; else if (Real->getOpcode() == Instruction::FSub && Imag->getOpcode() == Instruction::FSub) Rotation = ComplexDeinterleavingRotation::Rotation_180; else if (Real->getOpcode() == Instruction::FAdd && Imag->getOpcode() == Instruction::FSub) Rotation = ComplexDeinterleavingRotation::Rotation_270; else { LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); return nullptr; } if (!Real->getFastMathFlags().allowContract() || !Imag->getFastMathFlags().allowContract()) { LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); return nullptr; } Value *CR = Real->getOperand(0); Instruction *RealMulI = dyn_cast(Real->getOperand(1)); if (!RealMulI) return nullptr; Value *CI = Imag->getOperand(0); Instruction *ImagMulI = dyn_cast(Imag->getOperand(1)); if (!ImagMulI) return nullptr; if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); return nullptr; } Instruction *R0 = dyn_cast(RealMulI->getOperand(0)); Instruction *R1 = dyn_cast(RealMulI->getOperand(1)); Instruction *I0 = dyn_cast(ImagMulI->getOperand(0)); Instruction *I1 = dyn_cast(ImagMulI->getOperand(1)); if (!R0 || !R1 || !I0 || !I1) { LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); return nullptr; } Instruction *CommonOperand; Instruction *UncommonRealOp; Instruction *UncommonImagOp; if (R0 == I0 || R0 == I1) { CommonOperand = R0; UncommonRealOp = R1; } else if (R1 == I0 || R1 == I1) { CommonOperand = R1; UncommonRealOp = R0; } else { LLVM_DEBUG(dbgs() << " - No equal operand\n"); return nullptr; } UncommonImagOp = (CommonOperand == I0) ? I1 : I0; if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || Rotation == ComplexDeinterleavingRotation::Rotation_270) std::swap(UncommonRealOp, UncommonImagOp); std::pair PartialMatch( (Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180) ? CommonOperand : nullptr, (Rotation == ComplexDeinterleavingRotation::Rotation_90 || Rotation == ComplexDeinterleavingRotation::Rotation_270) ? CommonOperand : nullptr); auto *CRInst = dyn_cast(CR); auto *CIInst = dyn_cast(CI); if (!CRInst || !CIInst) { LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); return nullptr; } NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); if (!CNode) { LLVM_DEBUG(dbgs() << " - No cnode identified\n"); return nullptr; } NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); if (!UncommonRes) { LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); return nullptr; } assert(PartialMatch.first && PartialMatch.second); NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); if (!CommonRes) { LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); return nullptr; } NodePtr Node = prepareCompositeNode( ComplexDeinterleavingOperation::CMulPartial, Real, Imag); Node->Rotation = Rotation; Node->addOperand(CommonRes); Node->addOperand(UncommonRes); Node->addOperand(CNode); return submitCompositeNode(Node); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); // Determine rotation ComplexDeinterleavingRotation Rotation; if ((Real->getOpcode() == Instruction::FSub && Imag->getOpcode() == Instruction::FAdd) || (Real->getOpcode() == Instruction::Sub && Imag->getOpcode() == Instruction::Add)) Rotation = ComplexDeinterleavingRotation::Rotation_90; else if ((Real->getOpcode() == Instruction::FAdd && Imag->getOpcode() == Instruction::FSub) || (Real->getOpcode() == Instruction::Add && Imag->getOpcode() == Instruction::Sub)) Rotation = ComplexDeinterleavingRotation::Rotation_270; else { LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n"); return nullptr; } auto *AR = dyn_cast(Real->getOperand(0)); auto *BI = dyn_cast(Real->getOperand(1)); auto *AI = dyn_cast(Imag->getOperand(0)); auto *BR = dyn_cast(Imag->getOperand(1)); if (!AR || !AI || !BR || !BI) { LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); return nullptr; } NodePtr ResA = identifyNode(AR, AI); if (!ResA) { LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); return nullptr; } NodePtr ResB = identifyNode(BR, BI); if (!ResB) { LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); return nullptr; } NodePtr Node = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); Node->Rotation = Rotation; Node->addOperand(ResA); Node->addOperand(ResB); return submitCompositeNode(Node); } static bool isInstructionPairAdd(Instruction *A, Instruction *B) { unsigned OpcA = A->getOpcode(); unsigned OpcB = B->getOpcode(); return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || (OpcA == Instruction::Sub && OpcB == Instruction::Add) || (OpcA == Instruction::Add && OpcB == Instruction::Sub); } static bool isInstructionPairMul(Instruction *A, Instruction *B) { auto Pattern = m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); return match(A, Pattern) && match(B, Pattern); } static bool isInstructionPotentiallySymmetric(Instruction *I) { switch (I->getOpcode()) { case Instruction::FAdd: case Instruction::FSub: case Instruction::FMul: case Instruction::FNeg: return true; default: return false; } } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, Instruction *Imag) { if (Real->getOpcode() != Imag->getOpcode()) return nullptr; if (!isInstructionPotentiallySymmetric(Real) || !isInstructionPotentiallySymmetric(Imag)) return nullptr; auto *R0 = dyn_cast(Real->getOperand(0)); auto *I0 = dyn_cast(Imag->getOperand(0)); if (!R0 || !I0) return nullptr; NodePtr Op0 = identifyNode(R0, I0); NodePtr Op1 = nullptr; if (Op0 == nullptr) return nullptr; if (Real->isBinaryOp()) { auto *R1 = dyn_cast(Real->getOperand(1)); auto *I1 = dyn_cast(Imag->getOperand(1)); if (!R1 || !I1) return nullptr; Op1 = identifyNode(R1, I1); if (Op1 == nullptr) return nullptr; } if (isa(Real) && Real->getFastMathFlags() != Imag->getFastMathFlags()) return nullptr; auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Real, Imag); Node->Opcode = Real->getOpcode(); if (isa(Real)) Node->Flags = Real->getFastMathFlags(); Node->addOperand(Op0); if (Real->isBinaryOp()) Node->addOperand(Op1); return submitCompositeNode(Node); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n"); if (NodePtr CN = getContainingComposite(Real, Imag)) { LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); return CN; } if (NodePtr CN = identifyDeinterleave(Real, Imag)) return CN; if (NodePtr CN = identifyPHINode(Real, Imag)) return CN; auto *VTy = cast(Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( ComplexDeinterleavingOperation::CMulPartial, NewVTy); bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( ComplexDeinterleavingOperation::CAdd, NewVTy); if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { if (NodePtr CN = identifyPartialMul(Real, Imag)) return CN; } if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { if (NodePtr CN = identifyAdd(Real, Imag)) return CN; } if (HasCMulSupport && HasCAddSupport) { if (NodePtr CN = identifyReassocNodes(Real, Imag)) return CN; } if (NodePtr CN = identifySymmetricOperation(Real, Imag)) return CN; LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); return nullptr; } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, Instruction *Imag) { if ((Real->getOpcode() != Instruction::FAdd && Real->getOpcode() != Instruction::FSub && Real->getOpcode() != Instruction::FNeg) || (Imag->getOpcode() != Instruction::FAdd && Imag->getOpcode() != Instruction::FSub && Imag->getOpcode() != Instruction::FNeg)) return nullptr; if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { LLVM_DEBUG( dbgs() << "The flags in Real and Imaginary instructions are not identical\n"); return nullptr; } FastMathFlags Flags = Real->getFastMathFlags(); if (!Flags.allowReassoc()) { LLVM_DEBUG( dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n"); return nullptr; } // Collect multiplications and addend instructions from the given instruction // while traversing it operands. Additionally, verify that all instructions // have the same fast math flags. auto Collect = [&Flags](Instruction *Insn, std::vector &Muls, std::list &Addends) -> bool { SmallVector> Worklist = {{Insn, true}}; SmallPtrSet Visited; while (!Worklist.empty()) { auto [V, IsPositive] = Worklist.back(); Worklist.pop_back(); if (!Visited.insert(V).second) continue; Instruction *I = dyn_cast(V); if (!I) return false; // If an instruction has more than one user, it indicates that it either // has an external user, which will be later checked by the checkNodes // function, or it is a subexpression utilized by multiple expressions. In // the latter case, we will attempt to separately identify the complex // operation from here in order to create a shared // ComplexDeinterleavingCompositeNode. if (I != Insn && I->getNumUses() > 1) { LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); Addends.emplace_back(I, IsPositive); continue; } if (I->getOpcode() == Instruction::FAdd) { Worklist.emplace_back(I->getOperand(1), IsPositive); Worklist.emplace_back(I->getOperand(0), IsPositive); } else if (I->getOpcode() == Instruction::FSub) { Worklist.emplace_back(I->getOperand(1), !IsPositive); Worklist.emplace_back(I->getOperand(0), IsPositive); } else if (I->getOpcode() == Instruction::FMul) { auto *A = dyn_cast(I->getOperand(0)); if (A && A->getOpcode() == Instruction::FNeg) { A = dyn_cast(A->getOperand(0)); IsPositive = !IsPositive; } if (!A) return false; auto *B = dyn_cast(I->getOperand(1)); if (B && B->getOpcode() == Instruction::FNeg) { B = dyn_cast(B->getOperand(0)); IsPositive = !IsPositive; } if (!B) return false; Muls.push_back(Product{A, B, IsPositive}); } else if (I->getOpcode() == Instruction::FNeg) { Worklist.emplace_back(I->getOperand(0), !IsPositive); } else { Addends.emplace_back(I, IsPositive); continue; } if (I->getFastMathFlags() != Flags) { LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " "inconsistent with the root instructions' flags: " << *I << "\n"); return false; } } return true; }; std::vector RealMuls, ImagMuls; std::list RealAddends, ImagAddends; if (!Collect(Real, RealMuls, RealAddends) || !Collect(Imag, ImagMuls, ImagAddends)) return nullptr; if (RealAddends.size() != ImagAddends.size()) return nullptr; NodePtr FinalNode; if (!RealMuls.empty() || !ImagMuls.empty()) { // If there are multiplicands, extract positive addend and use it as an // accumulator FinalNode = extractPositiveAddend(RealAddends, ImagAddends); FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); if (!FinalNode) return nullptr; } // Identify and process remaining additions if (!RealAddends.empty() || !ImagAddends.empty()) { FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); if (!FinalNode) return nullptr; } // Set the Real and Imag fields of the final node and submit it FinalNode->Real = Real; FinalNode->Imag = Imag; submitCompositeNode(FinalNode); return FinalNode; } bool ComplexDeinterleavingGraph::collectPartialMuls( const std::vector &RealMuls, const std::vector &ImagMuls, std::vector &PartialMulCandidates) { // Helper function to extract a common operand from two products auto FindCommonInstruction = [](const Product &Real, const Product &Imag) -> Instruction * { if (Real.Multiplicand == Imag.Multiplicand || Real.Multiplicand == Imag.Multiplier) return Real.Multiplicand; if (Real.Multiplier == Imag.Multiplicand || Real.Multiplier == Imag.Multiplier) return Real.Multiplier; return nullptr; }; // Iterating over real and imaginary multiplications to find common operands // If a common operand is found, a partial multiplication candidate is created // and added to the candidates vector The function returns false if no common // operands are found for any product for (unsigned i = 0; i < RealMuls.size(); ++i) { bool FoundCommon = false; for (unsigned j = 0; j < ImagMuls.size(); ++j) { auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]); if (!Common) continue; auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier : RealMuls[i].Multiplicand; auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier : ImagMuls[j].Multiplicand; bool Inverted = false; auto Node = identifyNode(A, B); if (!Node) { std::swap(A, B); Inverted = true; Node = identifyNode(A, B); } if (!Node) continue; FoundCommon = true; PartialMulCandidates.push_back({Common, Node, i, j, Inverted}); } if (!FoundCommon) return false; } return true; } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyMultiplications( std::vector &RealMuls, std::vector &ImagMuls, NodePtr Accumulator = nullptr) { if (RealMuls.size() != ImagMuls.size()) return nullptr; std::vector Info; if (!collectPartialMuls(RealMuls, ImagMuls, Info)) return nullptr; // Map to store common instruction to node pointers std::map CommonToNode; std::vector Processed(Info.size(), false); for (unsigned I = 0; I < Info.size(); ++I) { if (Processed[I]) continue; PartialMulCandidate &InfoA = Info[I]; for (unsigned J = I + 1; J < Info.size(); ++J) { if (Processed[J]) continue; PartialMulCandidate &InfoB = Info[J]; auto *InfoReal = &InfoA; auto *InfoImag = &InfoB; auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); if (!NodeFromCommon) { std::swap(InfoReal, InfoImag); NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); } if (!NodeFromCommon) continue; CommonToNode[InfoReal->Common] = NodeFromCommon; CommonToNode[InfoImag->Common] = NodeFromCommon; Processed[I] = true; Processed[J] = true; } } std::vector ProcessedReal(RealMuls.size(), false); std::vector ProcessedImag(ImagMuls.size(), false); NodePtr Result = Accumulator; for (auto &PMI : Info) { if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) continue; auto It = CommonToNode.find(PMI.Common); // TODO: Process independent complex multiplications. Cases like this: // A.real() * B where both A and B are complex numbers. if (It == CommonToNode.end()) { LLVM_DEBUG({ dbgs() << "Unprocessed independent partial multiplication:\n"; for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier << " multiplied by " << *Mul->Multiplicand << "\n"; }); return nullptr; } auto &RealMul = RealMuls[PMI.RealIdx]; auto &ImagMul = ImagMuls[PMI.ImagIdx]; auto NodeA = It->second; auto NodeB = PMI.Node; auto IsMultiplicandReal = PMI.Common == NodeA->Real; // The following table illustrates the relationship between multiplications // and rotations. If we consider the multiplication (X + iY) * (U + iV), we // can see: // // Rotation | Real | Imag | // ---------+--------+--------+ // 0 | x * u | x * v | // 90 | -y * v | y * u | // 180 | -x * u | -x * v | // 270 | y * v | -y * u | // // Check if the candidate can indeed be represented by partial // multiplication // TODO: Add support for multiplication by complex one if ((IsMultiplicandReal && PMI.IsNodeInverted) || (!IsMultiplicandReal && !PMI.IsNodeInverted)) continue; // Determine the rotation based on the multiplications ComplexDeinterleavingRotation Rotation; if (IsMultiplicandReal) { // Detect 0 and 180 degrees rotation if (RealMul.IsPositive && ImagMul.IsPositive) Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; else if (!RealMul.IsPositive && !ImagMul.IsPositive) Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; else continue; } else { // Detect 90 and 270 degrees rotation if (!RealMul.IsPositive && ImagMul.IsPositive) Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; else if (RealMul.IsPositive && !ImagMul.IsPositive) Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; else continue; } LLVM_DEBUG({ dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; }); NodePtr NodeMul = prepareCompositeNode( ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); NodeMul->Rotation = Rotation; NodeMul->addOperand(NodeA); NodeMul->addOperand(NodeB); if (Result) NodeMul->addOperand(Result); submitCompositeNode(NodeMul); Result = NodeMul; ProcessedReal[PMI.RealIdx] = true; ProcessedImag[PMI.ImagIdx] = true; } // Ensure all products have been processed, if not return nullptr. if (!all_of(ProcessedReal, [](bool V) { return V; }) || !all_of(ProcessedImag, [](bool V) { return V; })) { // Dump debug information about which partial multiplications are not // processed. LLVM_DEBUG({ dbgs() << "Unprocessed products (Real):\n"; for (size_t i = 0; i < ProcessedReal.size(); ++i) { if (!ProcessedReal[i]) dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") << *RealMuls[i].Multiplier << " multiplied by " << *RealMuls[i].Multiplicand << "\n"; } dbgs() << "Unprocessed products (Imag):\n"; for (size_t i = 0; i < ProcessedImag.size(); ++i) { if (!ProcessedImag[i]) dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") << *ImagMuls[i].Multiplier << " multiplied by " << *ImagMuls[i].Multiplicand << "\n"; } }); return nullptr; } return Result; } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyAdditions(std::list &RealAddends, std::list &ImagAddends, FastMathFlags Flags, NodePtr Accumulator = nullptr) { if (RealAddends.size() != ImagAddends.size()) return nullptr; NodePtr Result; // If we have accumulator use it as first addend if (Accumulator) Result = Accumulator; // Otherwise find an element with both positive real and imaginary parts. else Result = extractPositiveAddend(RealAddends, ImagAddends); if (!Result) return nullptr; while (!RealAddends.empty()) { auto ItR = RealAddends.begin(); auto [R, IsPositiveR] = *ItR; bool FoundImag = false; for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { auto [I, IsPositiveI] = *ItI; ComplexDeinterleavingRotation Rotation; if (IsPositiveR && IsPositiveI) Rotation = ComplexDeinterleavingRotation::Rotation_0; else if (!IsPositiveR && IsPositiveI) Rotation = ComplexDeinterleavingRotation::Rotation_90; else if (!IsPositiveR && !IsPositiveI) Rotation = ComplexDeinterleavingRotation::Rotation_180; else Rotation = ComplexDeinterleavingRotation::Rotation_270; NodePtr AddNode; if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180) { AddNode = identifyNode(R, I); } else { AddNode = identifyNode(I, R); } if (AddNode) { LLVM_DEBUG({ dbgs() << "Identified addition:\n"; dbgs().indent(4) << "X: " << *R << "\n"; dbgs().indent(4) << "Y: " << *I << "\n"; dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; }); NodePtr TmpNode; if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { TmpNode = prepareCompositeNode( ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); TmpNode->Opcode = Instruction::FAdd; TmpNode->Flags = Flags; } else if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_180) { TmpNode = prepareCompositeNode( ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); TmpNode->Opcode = Instruction::FSub; TmpNode->Flags = Flags; } else { TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, nullptr, nullptr); TmpNode->Rotation = Rotation; } TmpNode->addOperand(Result); TmpNode->addOperand(AddNode); submitCompositeNode(TmpNode); Result = TmpNode; RealAddends.erase(ItR); ImagAddends.erase(ItI); FoundImag = true; break; } } if (!FoundImag) return nullptr; } return Result; } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::extractPositiveAddend( std::list &RealAddends, std::list &ImagAddends) { for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { auto [R, IsPositiveR] = *ItR; auto [I, IsPositiveI] = *ItI; if (IsPositiveR && IsPositiveI) { auto Result = identifyNode(R, I); if (Result) { RealAddends.erase(ItR); ImagAddends.erase(ItI); return Result; } } } } return nullptr; } bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { // This potential root instruction might already have been recognized as // reduction. Because RootToNode maps both Real and Imaginary parts to // CompositeNode we should choose only one either Real or Imag instruction to // use as an anchor for generating complex instruction. auto It = RootToNode.find(RootI); if (It != RootToNode.end() && It->second->Real == RootI) { OrderedRoots.push_back(RootI); return true; } auto RootNode = identifyRoot(RootI); if (!RootNode) return false; LLVM_DEBUG({ Function *F = RootI->getFunction(); BasicBlock *B = RootI->getParent(); dbgs() << "Complex deinterleaving graph for " << F->getName() << "::" << B->getName() << ".\n"; dump(dbgs()); dbgs() << "\n"; }); RootToNode[RootI] = RootNode; OrderedRoots.push_back(RootI); return true; } bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { bool FoundPotentialReduction = false; auto *Br = dyn_cast(B->getTerminator()); if (!Br || Br->getNumSuccessors() != 2) return false; // Identify simple one-block loop if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) return false; SmallVector PHIs; for (auto &PHI : B->phis()) { if (PHI.getNumIncomingValues() != 2) continue; if (!PHI.getType()->isVectorTy()) continue; auto *ReductionOp = dyn_cast(PHI.getIncomingValueForBlock(B)); if (!ReductionOp) continue; // Check if final instruction is reduced outside of current block Instruction *FinalReduction = nullptr; auto NumUsers = 0u; for (auto *U : ReductionOp->users()) { ++NumUsers; if (U == &PHI) continue; FinalReduction = dyn_cast(U); } if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B) continue; ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; BackEdge = B; auto BackEdgeIdx = PHI.getBasicBlockIndex(B); auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0; Incoming = PHI.getIncomingBlock(IncomingIdx); FoundPotentialReduction = true; // If the initial value of PHINode is an Instruction, consider it a leaf // value of a complex deinterleaving graph. if (auto *InitPHI = dyn_cast(PHI.getIncomingValueForBlock(Incoming))) FinalInstructions.insert(InitPHI); } return FoundPotentialReduction; } void ComplexDeinterleavingGraph::identifyReductionNodes() { SmallVector Processed(ReductionInfo.size(), false); SmallVector OperationInstruction; for (auto &P : ReductionInfo) OperationInstruction.push_back(P.first); // Identify a complex computation by evaluating two reduction operations that // potentially could be involved for (size_t i = 0; i < OperationInstruction.size(); ++i) { if (Processed[i]) continue; for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { if (Processed[j]) continue; auto *Real = OperationInstruction[i]; auto *Imag = OperationInstruction[j]; RealPHI = ReductionInfo[Real].first; ImagPHI = ReductionInfo[Imag].first; auto Node = identifyNode(Real, Imag); if (!Node) { std::swap(Real, Imag); std::swap(RealPHI, ImagPHI); Node = identifyNode(Real, Imag); } // If a node is identified, mark its operation instructions as used to // prevent re-identification and attach the node to the real part if (Node) { LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " << *Real << " / " << *Imag << "\n"); Processed[i] = true; Processed[j] = true; auto RootNode = prepareCompositeNode( ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); RootNode->addOperand(Node); RootToNode[Real] = RootNode; RootToNode[Imag] = RootNode; submitCompositeNode(RootNode); break; } } } RealPHI = nullptr; ImagPHI = nullptr; } bool ComplexDeinterleavingGraph::checkNodes() { // Collect all instructions from roots to leaves SmallPtrSet AllInstructions; SmallVector Worklist; for (auto &Pair : RootToNode) Worklist.push_back(Pair.first); // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG // chains while (!Worklist.empty()) { auto *I = Worklist.back(); Worklist.pop_back(); if (!AllInstructions.insert(I).second) continue; for (Value *Op : I->operands()) { if (auto *OpI = dyn_cast(Op)) { if (!FinalInstructions.count(I)) Worklist.emplace_back(OpI); } } } // Find instructions that have users outside of chain SmallVector OuterInstructions; for (auto *I : AllInstructions) { // Skip root nodes if (RootToNode.count(I)) continue; for (User *U : I->users()) { if (AllInstructions.count(cast(U))) continue; // Found an instruction that is not used by XCMLA/XCADD chain Worklist.emplace_back(I); break; } } // If any instructions are found to be used outside, find and remove roots // that somehow connect to those instructions. SmallPtrSet Visited; while (!Worklist.empty()) { auto *I = Worklist.back(); Worklist.pop_back(); if (!Visited.insert(I).second) continue; // Found an impacted root node. Removing it from the nodes to be // deinterleaved if (RootToNode.count(I)) { LLVM_DEBUG(dbgs() << "Instruction " << *I << " could be deinterleaved but its chain of complex " "operations have an outside user\n"); RootToNode.erase(I); } if (!AllInstructions.count(I) || FinalInstructions.count(I)) continue; for (User *U : I->users()) Worklist.emplace_back(cast(U)); for (Value *Op : I->operands()) { if (auto *OpI = dyn_cast(Op)) Worklist.emplace_back(OpI); } } return !RootToNode.empty(); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { if (auto *Intrinsic = dyn_cast(RootI)) { if (Intrinsic->getIntrinsicID() != Intrinsic::experimental_vector_interleave2) return nullptr; auto *Real = dyn_cast(Intrinsic->getOperand(0)); auto *Imag = dyn_cast(Intrinsic->getOperand(1)); if (!Real || !Imag) return nullptr; return identifyNode(Real, Imag); } auto *SVI = dyn_cast(RootI); if (!SVI) return nullptr; // Look for a shufflevector that takes separate vectors of the real and // imaginary components and recombines them into a single vector. if (!isInterleavingMask(SVI->getShuffleMask())) return nullptr; Instruction *Real; Instruction *Imag; if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) return nullptr; return identifyNode(Real, Imag); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, Instruction *Imag) { Instruction *I = nullptr; Value *FinalValue = nullptr; if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && match(Imag, m_ExtractValue<1>(m_Specific(I))) && match(I, m_Intrinsic( m_Value(FinalValue)))) { NodePtr PlaceholderNode = prepareCompositeNode( llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); PlaceholderNode->ReplacementNode = FinalValue; FinalInstructions.insert(Real); FinalInstructions.insert(Imag); return submitCompositeNode(PlaceholderNode); } auto *RealShuffle = dyn_cast(Real); auto *ImagShuffle = dyn_cast(Imag); if (!RealShuffle || !ImagShuffle) { if (RealShuffle || ImagShuffle) LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); return nullptr; } Value *RealOp1 = RealShuffle->getOperand(1); if (!isa(RealOp1) && !isa(RealOp1)) { LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); return nullptr; } Value *ImagOp1 = ImagShuffle->getOperand(1); if (!isa(ImagOp1) && !isa(ImagOp1)) { LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); return nullptr; } Value *RealOp0 = RealShuffle->getOperand(0); Value *ImagOp0 = ImagShuffle->getOperand(0); if (RealOp0 != ImagOp0) { LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); return nullptr; } ArrayRef RealMask = RealShuffle->getShuffleMask(); ArrayRef ImagMask = ImagShuffle->getShuffleMask(); if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); return nullptr; } if (RealMask[0] != 0 || ImagMask[0] != 1) { LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); return nullptr; } // Type checking, the shuffle type should be a vector type of the same // scalar type, but half the size auto CheckType = [&](ShuffleVectorInst *Shuffle) { Value *Op = Shuffle->getOperand(0); auto *ShuffleTy = cast(Shuffle->getType()); auto *OpTy = cast(Op->getType()); if (OpTy->getScalarType() != ShuffleTy->getScalarType()) return false; if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) return false; return true; }; auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { if (!CheckType(Shuffle)) return false; ArrayRef Mask = Shuffle->getShuffleMask(); int Last = *Mask.rbegin(); Value *Op = Shuffle->getOperand(0); auto *OpTy = cast(Op->getType()); int NumElements = OpTy->getNumElements(); // Ensure that the deinterleaving shuffle only pulls from the first // shuffle operand. return Last < NumElements; }; if (RealShuffle->getType() != ImagShuffle->getType()) { LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); return nullptr; } if (!CheckDeinterleavingShuffle(RealShuffle)) { LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); return nullptr; } if (!CheckDeinterleavingShuffle(ImagShuffle)) { LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); return nullptr; } NodePtr PlaceholderNode = prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, RealShuffle, ImagShuffle); PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); FinalInstructions.insert(RealShuffle); FinalInstructions.insert(ImagShuffle); return submitCompositeNode(PlaceholderNode); } ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, Instruction *Imag) { if (Real != RealPHI || Imag != ImagPHI) return nullptr; NodePtr PlaceholderNode = prepareCompositeNode( ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); return submitCompositeNode(PlaceholderNode); } static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, FastMathFlags Flags, Value *InputA, Value *InputB) { Value *I; switch (Opcode) { case Instruction::FNeg: I = B.CreateFNeg(InputA); break; case Instruction::FAdd: I = B.CreateFAdd(InputA, InputB); break; case Instruction::FSub: I = B.CreateFSub(InputA, InputB); break; case Instruction::FMul: I = B.CreateFMul(InputA, InputB); break; default: llvm_unreachable("Incorrect symmetric opcode"); } cast(I)->setFastMathFlags(Flags); return I; } Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, RawNodePtr Node) { if (Node->ReplacementNode) return Node->ReplacementNode; auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { return Node->Operands.size() > Idx ? replaceNode(Builder, Node->Operands[Idx]) : nullptr; }; Value *ReplacementNode; switch (Node->Operation) { case ComplexDeinterleavingOperation::CAdd: case ComplexDeinterleavingOperation::CMulPartial: case ComplexDeinterleavingOperation::Symmetric: { Value *Input0 = ReplaceOperandIfExist(Node, 0); Value *Input1 = ReplaceOperandIfExist(Node, 1); Value *Accumulator = ReplaceOperandIfExist(Node, 2); assert(!Input1 || (Input0->getType() == Input1->getType() && "Node inputs need to be of the same type")); assert(!Accumulator || (Input0->getType() == Accumulator->getType() && "Accumulator and input need to be of the same type")); if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, Input0, Input1); else ReplacementNode = TL->createComplexDeinterleavingIR( Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator); break; } case ComplexDeinterleavingOperation::Deinterleave: llvm_unreachable("Deinterleave node should already have ReplacementNode"); break; case ComplexDeinterleavingOperation::ReductionPHI: { // If Operation is ReductionPHI, a new empty PHINode is created. // It is filled later when the ReductionOperation is processed. auto *VTy = cast(Node->Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI()); OldToNewPHI[dyn_cast(Node->Real)] = NewPHI; ReplacementNode = NewPHI; break; } case ComplexDeinterleavingOperation::ReductionOperation: ReplacementNode = replaceNode(Builder, Node->Operands[0]); processReductionOperation(ReplacementNode, Node); break; } assert(ReplacementNode && "Target failed to create Intrinsic call."); NumComplexTransformations += 1; Node->ReplacementNode = ReplacementNode; return ReplacementNode; } void ComplexDeinterleavingGraph::processReductionOperation( Value *OperationReplacement, RawNodePtr Node) { auto *OldPHIReal = ReductionInfo[Node->Real].first; auto *OldPHIImag = ReductionInfo[Node->Imag].first; auto *NewPHI = OldToNewPHI[OldPHIReal]; auto *VTy = cast(Node->Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); // We have to interleave initial origin values coming from IncomingBlock Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); IRBuilder<> Builder(Incoming->getTerminator()); auto *NewInit = Builder.CreateIntrinsic( Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag}); NewPHI->addIncoming(NewInit, Incoming); NewPHI->addIncoming(OperationReplacement, BackEdge); // Deinterleave complex vector outside of loop so that it can be finally // reduced auto *FinalReductionReal = ReductionInfo[Node->Real].second; auto *FinalReductionImag = ReductionInfo[Node->Imag].second; Builder.SetInsertPoint( &*FinalReductionReal->getParent()->getFirstInsertionPt()); auto *Deinterleave = Builder.CreateIntrinsic( Intrinsic::experimental_vector_deinterleave2, OperationReplacement->getType(), OperationReplacement); auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); FinalReductionReal->replaceUsesOfWith(Node->Real, NewReal); Builder.SetInsertPoint(FinalReductionImag); auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); FinalReductionImag->replaceUsesOfWith(Node->Imag, NewImag); } void ComplexDeinterleavingGraph::replaceNodes() { SmallVector DeadInstrRoots; for (auto *RootInstruction : OrderedRoots) { // Check if this potential root went through check process and we can // deinterleave it if (!RootToNode.count(RootInstruction)) continue; IRBuilder<> Builder(RootInstruction); auto RootNode = RootToNode[RootInstruction]; Value *R = replaceNode(Builder, RootNode.get()); if (RootNode->Operation == ComplexDeinterleavingOperation::ReductionOperation) { ReductionInfo[RootNode->Real].first->removeIncomingValue(BackEdge); ReductionInfo[RootNode->Imag].first->removeIncomingValue(BackEdge); DeadInstrRoots.push_back(RootNode->Real); DeadInstrRoots.push_back(RootNode->Imag); } else { assert(R && "Unable to find replacement for RootInstruction"); DeadInstrRoots.push_back(RootInstruction); RootInstruction->replaceAllUsesWith(R); } } for (auto *I : DeadInstrRoots) RecursivelyDeleteTriviallyDeadInstructions(I, TLI); }