1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Identification: 10 // This step is responsible for finding the patterns that can be lowered to 11 // complex instructions, and building a graph to represent the complex 12 // structures. Starting from the "Converging Shuffle" (a shuffle that 13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the 14 // operands are evaluated and identified as "Composite Nodes" (collections of 15 // instructions that can potentially be lowered to a single complex 16 // instruction). This is performed by checking the real and imaginary components 17 // and tracking the data flow for each component while following the operand 18 // pairs. Validity of each node is expected to be done upon creation, and any 19 // validation errors should halt traversal and prevent further graph 20 // construction. 21 // Instead of relying on Shuffle operations, vector interleaving and 22 // deinterleaving can be represented by vector.interleave2 and 23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by 24 // these intrinsics, whereas, fixed-width vectors are recognized for both 25 // shufflevector instruction and intrinsics. 26 // 27 // Replacement: 28 // This step traverses the graph built up by identification, delegating to the 29 // target to validate and generate the correct intrinsics, and plumbs them 30 // together connecting each end of the new intrinsics graph to the existing 31 // use-def chain. This step is assumed to finish successfully, as all 32 // information is expected to be correct by this point. 
33 // 34 // 35 // Internal data structure: 36 // ComplexDeinterleavingGraph: 37 // Keeps references to all the valid CompositeNodes formed as part of the 38 // transformation, and every Instruction contained within said nodes. It also 39 // holds onto a reference to the root Instruction, and the root node that should 40 // replace it. 41 // 42 // ComplexDeinterleavingCompositeNode: 43 // A CompositeNode represents a single transformation point; each node should 44 // transform into a single complex instruction (ignoring vector splitting, which 45 // would generate more instructions per node). They are identified in a 46 // depth-first manner, traversing and identifying the operands of each 47 // instruction in the order they appear in the IR. 48 // Each node maintains a reference to its Real and Imaginary instructions, 49 // as well as any additional instructions that make up the identified operation 50 // (Internal instructions should only have uses within their containing node). 51 // A Node also contains the rotation and operation type that it represents. 52 // Operands contains pointers to other CompositeNodes, acting as the edges in 53 // the graph. ReplacementValue is the transformed Value* that has been emitted 54 // to the IR. 55 // 56 // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and 57 // ReplacementValue fields of that Node are relevant, where the ReplacementValue 58 // should be pre-populated. 
59 // 60 //===----------------------------------------------------------------------===// 61 62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h" 63 #include "llvm/ADT/MapVector.h" 64 #include "llvm/ADT/Statistic.h" 65 #include "llvm/Analysis/TargetLibraryInfo.h" 66 #include "llvm/Analysis/TargetTransformInfo.h" 67 #include "llvm/CodeGen/TargetLowering.h" 68 #include "llvm/CodeGen/TargetSubtargetInfo.h" 69 #include "llvm/IR/IRBuilder.h" 70 #include "llvm/IR/PatternMatch.h" 71 #include "llvm/InitializePasses.h" 72 #include "llvm/Target/TargetMachine.h" 73 #include "llvm/Transforms/Utils/Local.h" 74 #include <algorithm> 75 76 using namespace llvm; 77 using namespace PatternMatch; 78 79 #define DEBUG_TYPE "complex-deinterleaving" 80 81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); 82 83 static cl::opt<bool> ComplexDeinterleavingEnabled( 84 "enable-complex-deinterleaving", 85 cl::desc("Enable generation of complex instructions"), cl::init(true), 86 cl::Hidden); 87 88 /// Checks the given mask, and determines whether said mask is interleaving. 89 /// 90 /// To be interleaving, a mask must alternate between `i` and `i + (Length / 91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a 92 /// 4x vector interleaving mask would be <0, 2, 1, 3>). 93 static bool isInterleavingMask(ArrayRef<int> Mask); 94 95 /// Checks the given mask, and determines whether said mask is deinterleaving. 96 /// 97 /// To be deinterleaving, a mask must increment in steps of 2, and either start 98 /// with 0 or 1. 99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or 100 /// <1, 3, 5, 7>). 101 static bool isDeinterleavingMask(ArrayRef<int> Mask); 102 103 /// Returns true if the operation is a negation of V, and it works for both 104 /// integers and floats. 105 static bool isNeg(Value *V); 106 107 /// Returns the operand for negation operation. 
108 static Value *getNegOperand(Value *V); 109 110 namespace { 111 template <typename T, typename IterT> 112 std::optional<T> findCommonBetweenCollections(IterT A, IterT B) { 113 auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); }); 114 if (Common != A.end()) 115 return std::make_optional(*Common); 116 return std::nullopt; 117 } 118 119 class ComplexDeinterleavingLegacyPass : public FunctionPass { 120 public: 121 static char ID; 122 123 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) 124 : FunctionPass(ID), TM(TM) { 125 initializeComplexDeinterleavingLegacyPassPass( 126 *PassRegistry::getPassRegistry()); 127 } 128 129 StringRef getPassName() const override { 130 return "Complex Deinterleaving Pass"; 131 } 132 133 bool runOnFunction(Function &F) override; 134 void getAnalysisUsage(AnalysisUsage &AU) const override { 135 AU.addRequired<TargetLibraryInfoWrapperPass>(); 136 AU.setPreservesCFG(); 137 } 138 139 private: 140 const TargetMachine *TM; 141 }; 142 143 class ComplexDeinterleavingGraph; 144 struct ComplexDeinterleavingCompositeNode { 145 146 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, 147 Value *R, Value *I) 148 : Operation(Op), Real(R), Imag(I) {} 149 150 private: 151 friend class ComplexDeinterleavingGraph; 152 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; 153 using RawNodePtr = ComplexDeinterleavingCompositeNode *; 154 bool OperandsValid = true; 155 156 public: 157 ComplexDeinterleavingOperation Operation; 158 Value *Real; 159 Value *Imag; 160 161 // This two members are required exclusively for generating 162 // ComplexDeinterleavingOperation::Symmetric operations. 
163 unsigned Opcode; 164 std::optional<FastMathFlags> Flags; 165 166 ComplexDeinterleavingRotation Rotation = 167 ComplexDeinterleavingRotation::Rotation_0; 168 SmallVector<RawNodePtr> Operands; 169 Value *ReplacementNode = nullptr; 170 171 void addOperand(NodePtr Node) { 172 if (!Node || !Node.get()) 173 OperandsValid = false; 174 Operands.push_back(Node.get()); 175 } 176 177 void dump() { dump(dbgs()); } 178 void dump(raw_ostream &OS) { 179 auto PrintValue = [&](Value *V) { 180 if (V) { 181 OS << "\""; 182 V->print(OS, true); 183 OS << "\"\n"; 184 } else 185 OS << "nullptr\n"; 186 }; 187 auto PrintNodeRef = [&](RawNodePtr Ptr) { 188 if (Ptr) 189 OS << Ptr << "\n"; 190 else 191 OS << "nullptr\n"; 192 }; 193 194 OS << "- CompositeNode: " << this << "\n"; 195 OS << " Real: "; 196 PrintValue(Real); 197 OS << " Imag: "; 198 PrintValue(Imag); 199 OS << " ReplacementNode: "; 200 PrintValue(ReplacementNode); 201 OS << " Operation: " << (int)Operation << "\n"; 202 OS << " Rotation: " << ((int)Rotation * 90) << "\n"; 203 OS << " Operands: \n"; 204 for (const auto &Op : Operands) { 205 OS << " - "; 206 PrintNodeRef(Op); 207 } 208 } 209 210 bool areOperandsValid() { return OperandsValid; } 211 }; 212 213 class ComplexDeinterleavingGraph { 214 public: 215 struct Product { 216 Value *Multiplier; 217 Value *Multiplicand; 218 bool IsPositive; 219 }; 220 221 using Addend = std::pair<Value *, bool>; 222 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; 223 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; 224 225 // Helper struct for holding info about potential partial multiplication 226 // candidates 227 struct PartialMulCandidate { 228 Value *Common; 229 NodePtr Node; 230 unsigned RealIdx; 231 unsigned ImagIdx; 232 bool IsNodeInverted; 233 }; 234 235 explicit ComplexDeinterleavingGraph(const TargetLowering *TL, 236 const TargetLibraryInfo *TLI) 237 : TL(TL), TLI(TLI) {} 238 239 private: 240 const TargetLowering *TL = nullptr; 241 const 
TargetLibraryInfo *TLI = nullptr; 242 SmallVector<NodePtr> CompositeNodes; 243 DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult; 244 245 SmallPtrSet<Instruction *, 16> FinalInstructions; 246 247 /// Root instructions are instructions from which complex computation starts 248 std::map<Instruction *, NodePtr> RootToNode; 249 250 /// Topologically sorted root instructions 251 SmallVector<Instruction *, 1> OrderedRoots; 252 253 /// When examining a basic block for complex deinterleaving, if it is a simple 254 /// one-block loop, then the only incoming block is 'Incoming' and the 255 /// 'BackEdge' block is the block itself." 256 BasicBlock *BackEdge = nullptr; 257 BasicBlock *Incoming = nullptr; 258 259 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction 260 /// %OutsideUser as it is shown in the IR: 261 /// 262 /// vector.body: 263 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ], 264 /// [ %ReductionOp, %vector.body ] 265 /// ... 266 /// %ReductionOp = fadd i64 ... 267 /// ... 268 /// br i1 %condition, label %vector.body, %middle.block 269 /// 270 /// middle.block: 271 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) 272 /// 273 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding 274 /// `llvm.vector.reduce.fadd` when unroll factor isn't one. 275 MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo; 276 277 /// In the process of detecting a reduction, we consider a pair of 278 /// %ReductionOP, which we refer to as real and imag (or vice versa), and 279 /// traverse the use-tree to detect complex operations. As this is a reduction 280 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds 281 /// to the %ReductionOPs that we suspect to be complex. 282 /// RealPHI and ImagPHI are used by the identifyPHINode method. 
283 PHINode *RealPHI = nullptr; 284 PHINode *ImagPHI = nullptr; 285 286 /// Set this flag to true if RealPHI and ImagPHI were reached during reduction 287 /// detection. 288 bool PHIsFound = false; 289 290 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. 291 /// The new PHINode corresponds to a vector of deinterleaved complex numbers. 292 /// This mapping is populated during 293 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then 294 /// used in the ComplexDeinterleavingOperation::ReductionOperation node 295 /// replacement process. 296 std::map<PHINode *, PHINode *> OldToNewPHI; 297 298 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, 299 Value *R, Value *I) { 300 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && 301 Operation != ComplexDeinterleavingOperation::ReductionOperation) || 302 (R && I)) && 303 "Reduction related nodes must have Real and Imaginary parts"); 304 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, 305 I); 306 } 307 308 NodePtr submitCompositeNode(NodePtr Node) { 309 CompositeNodes.push_back(Node); 310 if (Node->Real) 311 CachedResult[{Node->Real, Node->Imag}] = Node; 312 return Node; 313 } 314 315 /// Identifies a complex partial multiply pattern and its rotation, based on 316 /// the following patterns 317 /// 318 /// 0: r: cr + ar * br 319 /// i: ci + ar * bi 320 /// 90: r: cr - ai * bi 321 /// i: ci + ai * br 322 /// 180: r: cr - ar * br 323 /// i: ci - ar * bi 324 /// 270: r: cr + ai * bi 325 /// i: ci - ai * br 326 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); 327 328 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that 329 /// is partially known from identifyPartialMul, filling in the other half of 330 /// the complex pair. 
331 NodePtr 332 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, 333 std::pair<Value *, Value *> &CommonOperandI); 334 335 /// Identifies a complex add pattern and its rotation, based on the following 336 /// patterns. 337 /// 338 /// 90: r: ar - bi 339 /// i: ai + br 340 /// 270: r: ar + bi 341 /// i: ai - br 342 NodePtr identifyAdd(Instruction *Real, Instruction *Imag); 343 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); 344 NodePtr identifyPartialReduction(Value *R, Value *I); 345 NodePtr identifyDotProduct(Value *Inst); 346 347 NodePtr identifyNode(Value *R, Value *I); 348 349 /// Determine if a sum of complex numbers can be formed from \p RealAddends 350 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. 351 /// Return nullptr if it is not possible to construct a complex number. 352 /// \p Flags are needed to generate symmetric Add and Sub operations. 353 NodePtr identifyAdditions(std::list<Addend> &RealAddends, 354 std::list<Addend> &ImagAddends, 355 std::optional<FastMathFlags> Flags, 356 NodePtr Accumulator); 357 358 /// Extract one addend that have both real and imaginary parts positive. 359 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends, 360 std::list<Addend> &ImagAddends); 361 362 /// Determine if sum of multiplications of complex numbers can be formed from 363 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result 364 /// to it. Return nullptr if it is not possible to construct a complex number. 365 NodePtr identifyMultiplications(std::vector<Product> &RealMuls, 366 std::vector<Product> &ImagMuls, 367 NodePtr Accumulator); 368 369 /// Go through pairs of multiplication (one Real and one Imag) and find all 370 /// possible candidates for partial multiplication and put them into \p 371 /// Candidates. 
Returns true if all Product has pair with common operand 372 bool collectPartialMuls(const std::vector<Product> &RealMuls, 373 const std::vector<Product> &ImagMuls, 374 std::vector<PartialMulCandidate> &Candidates); 375 376 /// If the code is compiled with -Ofast or expressions have `reassoc` flag, 377 /// the order of complex computation operations may be significantly altered, 378 /// and the real and imaginary parts may not be executed in parallel. This 379 /// function takes this into consideration and employs a more general approach 380 /// to identify complex computations. Initially, it gathers all the addends 381 /// and multiplicands and then constructs a complex expression from them. 382 NodePtr identifyReassocNodes(Instruction *I, Instruction *J); 383 384 NodePtr identifyRoot(Instruction *I); 385 386 /// Identifies the Deinterleave operation applied to a vector containing 387 /// complex numbers. There are two ways to represent the Deinterleave 388 /// operation: 389 /// * Using two shufflevectors with even indices for /pReal instruction and 390 /// odd indices for /pImag instructions (only for fixed-width vectors) 391 /// * Using two extractvalue instructions applied to `vector.deinterleave2` 392 /// intrinsic (for both fixed and scalable vectors) 393 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); 394 395 /// identifying the operation that represents a complex number repeated in a 396 /// Splat vector. There are two possible types of splats: ConstantExpr with 397 /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an 398 /// initialization mask with all values set to zero. 
399 NodePtr identifySplat(Value *Real, Value *Imag); 400 401 NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); 402 403 /// Identifies SelectInsts in a loop that has reduction with predication masks 404 /// and/or predicated tail folding 405 NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); 406 407 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); 408 409 /// Complete IR modifications after producing new reduction operation: 410 /// * Populate the PHINode generated for 411 /// ComplexDeinterleavingOperation::ReductionPHI 412 /// * Deinterleave the final value outside of the loop and repurpose original 413 /// reduction users 414 void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); 415 void processReductionSingle(Value *OperationReplacement, RawNodePtr Node); 416 417 public: 418 void dump() { dump(dbgs()); } 419 void dump(raw_ostream &OS) { 420 for (const auto &Node : CompositeNodes) 421 Node->dump(OS); 422 } 423 424 /// Returns false if the deinterleaving operation should be cancelled for the 425 /// current graph. 426 bool identifyNodes(Instruction *RootI); 427 428 /// In case \pB is one-block loop, this function seeks potential reductions 429 /// and populates ReductionInfo. Returns true if any reductions were 430 /// identified. 431 bool collectPotentialReductions(BasicBlock *B); 432 433 void identifyReductionNodes(); 434 435 /// Check that every instruction, from the roots to the leaves, has internal 436 /// uses. 437 bool checkNodes(); 438 439 /// Perform the actual replacement of the underlying instruction graph. 
440 void replaceNodes(); 441 }; 442 443 class ComplexDeinterleaving { 444 public: 445 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 446 : TL(tl), TLI(tli) {} 447 bool runOnFunction(Function &F); 448 449 private: 450 bool evaluateBasicBlock(BasicBlock *B); 451 452 const TargetLowering *TL = nullptr; 453 const TargetLibraryInfo *TLI = nullptr; 454 }; 455 456 } // namespace 457 458 char ComplexDeinterleavingLegacyPass::ID = 0; 459 460 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 461 "Complex Deinterleaving", false, false) 462 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 463 "Complex Deinterleaving", false, false) 464 465 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 466 FunctionAnalysisManager &AM) { 467 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 468 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 469 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 470 return PreservedAnalyses::all(); 471 472 PreservedAnalyses PA; 473 PA.preserve<FunctionAnalysisManagerModuleProxy>(); 474 return PA; 475 } 476 477 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 478 return new ComplexDeinterleavingLegacyPass(TM); 479 } 480 481 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 482 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 483 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 484 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 485 } 486 487 bool ComplexDeinterleaving::runOnFunction(Function &F) { 488 if (!ComplexDeinterleavingEnabled) { 489 LLVM_DEBUG( 490 dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); 491 return false; 492 } 493 494 if (!TL->isComplexDeinterleavingSupported()) { 495 LLVM_DEBUG( 496 dbgs() << "Complex deinterleaving has been disabled, target does " 497 "not support lowering of complex number operations.\n"); 498 return 
false; 499 } 500 501 bool Changed = false; 502 for (auto &B : F) 503 Changed |= evaluateBasicBlock(&B); 504 505 return Changed; 506 } 507 508 static bool isInterleavingMask(ArrayRef<int> Mask) { 509 // If the size is not even, it's not an interleaving mask 510 if ((Mask.size() & 1)) 511 return false; 512 513 int HalfNumElements = Mask.size() / 2; 514 for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 515 int MaskIdx = Idx * 2; 516 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 517 return false; 518 } 519 520 return true; 521 } 522 523 static bool isDeinterleavingMask(ArrayRef<int> Mask) { 524 int Offset = Mask[0]; 525 int HalfNumElements = Mask.size() / 2; 526 527 for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 528 if (Mask[Idx] != (Idx * 2) + Offset) 529 return false; 530 } 531 532 return true; 533 } 534 535 bool isNeg(Value *V) { 536 return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value())); 537 } 538 539 Value *getNegOperand(Value *V) { 540 assert(isNeg(V)); 541 auto *I = cast<Instruction>(V); 542 if (I->getOpcode() == Instruction::FNeg) 543 return I->getOperand(0); 544 545 return I->getOperand(1); 546 } 547 548 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 549 ComplexDeinterleavingGraph Graph(TL, TLI); 550 if (Graph.collectPotentialReductions(B)) 551 Graph.identifyReductionNodes(); 552 553 for (auto &I : *B) 554 Graph.identifyNodes(&I); 555 556 if (Graph.checkNodes()) { 557 Graph.replaceNodes(); 558 return true; 559 } 560 561 return false; 562 } 563 564 ComplexDeinterleavingGraph::NodePtr 565 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 566 Instruction *Real, Instruction *Imag, 567 std::pair<Value *, Value *> &PartialMatch) { 568 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 569 << "\n"); 570 571 if (!Real->hasOneUse() || !Imag->hasOneUse()) { 572 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 573 return nullptr; 574 } 575 576 if ((Real->getOpcode() 
!= Instruction::FMul && 577 Real->getOpcode() != Instruction::Mul) || 578 (Imag->getOpcode() != Instruction::FMul && 579 Imag->getOpcode() != Instruction::Mul)) { 580 LLVM_DEBUG( 581 dbgs() << " - Real or imaginary instruction is not fmul or mul\n"); 582 return nullptr; 583 } 584 585 Value *R0 = Real->getOperand(0); 586 Value *R1 = Real->getOperand(1); 587 Value *I0 = Imag->getOperand(0); 588 Value *I1 = Imag->getOperand(1); 589 590 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 591 // rotations and use the operand. 592 unsigned Negs = 0; 593 Value *Op; 594 if (match(R0, m_Neg(m_Value(Op)))) { 595 Negs |= 1; 596 R0 = Op; 597 } else if (match(R1, m_Neg(m_Value(Op)))) { 598 Negs |= 1; 599 R1 = Op; 600 } 601 602 if (isNeg(I0)) { 603 Negs |= 2; 604 Negs ^= 1; 605 I0 = Op; 606 } else if (match(I1, m_Neg(m_Value(Op)))) { 607 Negs |= 2; 608 Negs ^= 1; 609 I1 = Op; 610 } 611 612 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 613 614 Value *CommonOperand; 615 Value *UncommonRealOp; 616 Value *UncommonImagOp; 617 618 if (R0 == I0 || R0 == I1) { 619 CommonOperand = R0; 620 UncommonRealOp = R1; 621 } else if (R1 == I0 || R1 == I1) { 622 CommonOperand = R1; 623 UncommonRealOp = R0; 624 } else { 625 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 626 return nullptr; 627 } 628 629 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 630 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 631 Rotation == ComplexDeinterleavingRotation::Rotation_270) 632 std::swap(UncommonRealOp, UncommonImagOp); 633 634 // Between identifyPartialMul and here we need to have found a complete valid 635 // pair from the CommonOperand of each part. 
636 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 637 Rotation == ComplexDeinterleavingRotation::Rotation_180) 638 PartialMatch.first = CommonOperand; 639 else 640 PartialMatch.second = CommonOperand; 641 642 if (!PartialMatch.first || !PartialMatch.second) { 643 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); 644 return nullptr; 645 } 646 647 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); 648 if (!CommonNode) { 649 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); 650 return nullptr; 651 } 652 653 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); 654 if (!UncommonNode) { 655 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); 656 return nullptr; 657 } 658 659 NodePtr Node = prepareCompositeNode( 660 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 661 Node->Rotation = Rotation; 662 Node->addOperand(CommonNode); 663 Node->addOperand(UncommonNode); 664 return submitCompositeNode(Node); 665 } 666 667 ComplexDeinterleavingGraph::NodePtr 668 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, 669 Instruction *Imag) { 670 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag 671 << "\n"); 672 // Determine rotation 673 auto IsAdd = [](unsigned Op) { 674 return Op == Instruction::FAdd || Op == Instruction::Add; 675 }; 676 auto IsSub = [](unsigned Op) { 677 return Op == Instruction::FSub || Op == Instruction::Sub; 678 }; 679 ComplexDeinterleavingRotation Rotation; 680 if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 681 Rotation = ComplexDeinterleavingRotation::Rotation_0; 682 else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode())) 683 Rotation = ComplexDeinterleavingRotation::Rotation_90; 684 else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode())) 685 Rotation = ComplexDeinterleavingRotation::Rotation_180; 686 else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode())) 687 Rotation = ComplexDeinterleavingRotation::Rotation_270; 
688 else { 689 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); 690 return nullptr; 691 } 692 693 if (isa<FPMathOperator>(Real) && 694 (!Real->getFastMathFlags().allowContract() || 695 !Imag->getFastMathFlags().allowContract())) { 696 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); 697 return nullptr; 698 } 699 700 Value *CR = Real->getOperand(0); 701 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); 702 if (!RealMulI) 703 return nullptr; 704 Value *CI = Imag->getOperand(0); 705 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); 706 if (!ImagMulI) 707 return nullptr; 708 709 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { 710 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); 711 return nullptr; 712 } 713 714 Value *R0 = RealMulI->getOperand(0); 715 Value *R1 = RealMulI->getOperand(1); 716 Value *I0 = ImagMulI->getOperand(0); 717 Value *I1 = ImagMulI->getOperand(1); 718 719 Value *CommonOperand; 720 Value *UncommonRealOp; 721 Value *UncommonImagOp; 722 723 if (R0 == I0 || R0 == I1) { 724 CommonOperand = R0; 725 UncommonRealOp = R1; 726 } else if (R1 == I0 || R1 == I1) { 727 CommonOperand = R1; 728 UncommonRealOp = R0; 729 } else { 730 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 731 return nullptr; 732 } 733 734 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 735 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 736 Rotation == ComplexDeinterleavingRotation::Rotation_270) 737 std::swap(UncommonRealOp, UncommonImagOp); 738 739 std::pair<Value *, Value *> PartialMatch( 740 (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 741 Rotation == ComplexDeinterleavingRotation::Rotation_180) 742 ? CommonOperand 743 : nullptr, 744 (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 745 Rotation == ComplexDeinterleavingRotation::Rotation_270) 746 ? 
CommonOperand 747 : nullptr); 748 749 auto *CRInst = dyn_cast<Instruction>(CR); 750 auto *CIInst = dyn_cast<Instruction>(CI); 751 752 if (!CRInst || !CIInst) { 753 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); 754 return nullptr; 755 } 756 757 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); 758 if (!CNode) { 759 LLVM_DEBUG(dbgs() << " - No cnode identified\n"); 760 return nullptr; 761 } 762 763 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); 764 if (!UncommonRes) { 765 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); 766 return nullptr; 767 } 768 769 assert(PartialMatch.first && PartialMatch.second); 770 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); 771 if (!CommonRes) { 772 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); 773 return nullptr; 774 } 775 776 NodePtr Node = prepareCompositeNode( 777 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 778 Node->Rotation = Rotation; 779 Node->addOperand(CommonRes); 780 Node->addOperand(UncommonRes); 781 Node->addOperand(CNode); 782 return submitCompositeNode(Node); 783 } 784 785 ComplexDeinterleavingGraph::NodePtr 786 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { 787 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); 788 789 // Determine rotation 790 ComplexDeinterleavingRotation Rotation; 791 if ((Real->getOpcode() == Instruction::FSub && 792 Imag->getOpcode() == Instruction::FAdd) || 793 (Real->getOpcode() == Instruction::Sub && 794 Imag->getOpcode() == Instruction::Add)) 795 Rotation = ComplexDeinterleavingRotation::Rotation_90; 796 else if ((Real->getOpcode() == Instruction::FAdd && 797 Imag->getOpcode() == Instruction::FSub) || 798 (Real->getOpcode() == Instruction::Add && 799 Imag->getOpcode() == Instruction::Sub)) 800 Rotation = ComplexDeinterleavingRotation::Rotation_270; 801 else { 802 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not 
assigned.\n"); 803 return nullptr; 804 } 805 806 auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); 807 auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); 808 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); 809 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); 810 811 if (!AR || !AI || !BR || !BI) { 812 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); 813 return nullptr; 814 } 815 816 NodePtr ResA = identifyNode(AR, AI); 817 if (!ResA) { 818 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); 819 return nullptr; 820 } 821 NodePtr ResB = identifyNode(BR, BI); 822 if (!ResB) { 823 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); 824 return nullptr; 825 } 826 827 NodePtr Node = 828 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); 829 Node->Rotation = Rotation; 830 Node->addOperand(ResA); 831 Node->addOperand(ResB); 832 return submitCompositeNode(Node); 833 } 834 835 static bool isInstructionPairAdd(Instruction *A, Instruction *B) { 836 unsigned OpcA = A->getOpcode(); 837 unsigned OpcB = B->getOpcode(); 838 839 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || 840 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || 841 (OpcA == Instruction::Sub && OpcB == Instruction::Add) || 842 (OpcA == Instruction::Add && OpcB == Instruction::Sub); 843 } 844 845 static bool isInstructionPairMul(Instruction *A, Instruction *B) { 846 auto Pattern = 847 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); 848 849 return match(A, Pattern) && match(B, Pattern); 850 } 851 852 static bool isInstructionPotentiallySymmetric(Instruction *I) { 853 switch (I->getOpcode()) { 854 case Instruction::FAdd: 855 case Instruction::FSub: 856 case Instruction::FMul: 857 case Instruction::FNeg: 858 case Instruction::Add: 859 case Instruction::Sub: 860 case Instruction::Mul: 861 return true; 862 default: 863 return false; 864 } 865 } 866 867 
ComplexDeinterleavingGraph::NodePtr 868 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, 869 Instruction *Imag) { 870 if (Real->getOpcode() != Imag->getOpcode()) 871 return nullptr; 872 873 if (!isInstructionPotentiallySymmetric(Real) || 874 !isInstructionPotentiallySymmetric(Imag)) 875 return nullptr; 876 877 auto *R0 = Real->getOperand(0); 878 auto *I0 = Imag->getOperand(0); 879 880 NodePtr Op0 = identifyNode(R0, I0); 881 NodePtr Op1 = nullptr; 882 if (Op0 == nullptr) 883 return nullptr; 884 885 if (Real->isBinaryOp()) { 886 auto *R1 = Real->getOperand(1); 887 auto *I1 = Imag->getOperand(1); 888 Op1 = identifyNode(R1, I1); 889 if (Op1 == nullptr) 890 return nullptr; 891 } 892 893 if (isa<FPMathOperator>(Real) && 894 Real->getFastMathFlags() != Imag->getFastMathFlags()) 895 return nullptr; 896 897 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, 898 Real, Imag); 899 Node->Opcode = Real->getOpcode(); 900 if (isa<FPMathOperator>(Real)) 901 Node->Flags = Real->getFastMathFlags(); 902 903 Node->addOperand(Op0); 904 if (Real->isBinaryOp()) 905 Node->addOperand(Op1); 906 907 return submitCompositeNode(Node); 908 } 909 910 ComplexDeinterleavingGraph::NodePtr 911 ComplexDeinterleavingGraph::identifyDotProduct(Value *V) { 912 913 if (!TL->isComplexDeinterleavingOperationSupported( 914 ComplexDeinterleavingOperation::CDot, V->getType())) { 915 LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving " 916 "operation CDot with the type " 917 << *V->getType() << "\n"); 918 return nullptr; 919 } 920 921 auto *Inst = cast<Instruction>(V); 922 auto *RealUser = cast<Instruction>(*Inst->user_begin()); 923 924 NodePtr CN = 925 prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr); 926 927 NodePtr ANode; 928 929 const Intrinsic::ID PartialReduceInt = 930 Intrinsic::experimental_vector_partial_reduce_add; 931 932 Value *AReal = nullptr; 933 Value *AImag = nullptr; 934 Value *BReal = nullptr; 935 Value 
*BImag = nullptr; 936 Value *Phi = nullptr; 937 938 auto UnwrapCast = [](Value *V) -> Value * { 939 if (auto *CI = dyn_cast<CastInst>(V)) 940 return CI->getOperand(0); 941 return V; 942 }; 943 944 auto PatternRot0 = m_Intrinsic<PartialReduceInt>( 945 m_Intrinsic<PartialReduceInt>(m_Value(Phi), 946 m_Mul(m_Value(BReal), m_Value(AReal))), 947 m_Neg(m_Mul(m_Value(BImag), m_Value(AImag)))); 948 949 auto PatternRot270 = m_Intrinsic<PartialReduceInt>( 950 m_Intrinsic<PartialReduceInt>( 951 m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))), 952 m_Mul(m_Value(BImag), m_Value(AReal))); 953 954 if (match(Inst, PatternRot0)) { 955 CN->Rotation = ComplexDeinterleavingRotation::Rotation_0; 956 } else if (match(Inst, PatternRot270)) { 957 CN->Rotation = ComplexDeinterleavingRotation::Rotation_270; 958 } else { 959 Value *A0, *A1; 960 // The rotations 90 and 180 share the same operation pattern, so inspect the 961 // order of the operands, identifying where the real and imaginary 962 // components of A go, to discern between the aforementioned rotations. 
963 auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>( 964 m_Intrinsic<PartialReduceInt>(m_Value(Phi), 965 m_Mul(m_Value(BReal), m_Value(A0))), 966 m_Mul(m_Value(BImag), m_Value(A1))); 967 968 if (!match(Inst, PatternRot90Rot180)) 969 return nullptr; 970 971 A0 = UnwrapCast(A0); 972 A1 = UnwrapCast(A1); 973 974 // Test if A0 is real/A1 is imag 975 ANode = identifyNode(A0, A1); 976 if (!ANode) { 977 // Test if A0 is imag/A1 is real 978 ANode = identifyNode(A1, A0); 979 // Unable to identify operand components, thus unable to identify rotation 980 if (!ANode) 981 return nullptr; 982 CN->Rotation = ComplexDeinterleavingRotation::Rotation_90; 983 AReal = A1; 984 AImag = A0; 985 } else { 986 AReal = A0; 987 AImag = A1; 988 CN->Rotation = ComplexDeinterleavingRotation::Rotation_180; 989 } 990 } 991 992 AReal = UnwrapCast(AReal); 993 AImag = UnwrapCast(AImag); 994 BReal = UnwrapCast(BReal); 995 BImag = UnwrapCast(BImag); 996 997 VectorType *VTy = cast<VectorType>(V->getType()); 998 Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2); 999 if (AReal->getType() != ExpectedOperandTy) 1000 return nullptr; 1001 if (AImag->getType() != ExpectedOperandTy) 1002 return nullptr; 1003 if (BReal->getType() != ExpectedOperandTy) 1004 return nullptr; 1005 if (BImag->getType() != ExpectedOperandTy) 1006 return nullptr; 1007 1008 if (Phi->getType() != VTy && RealUser->getType() != VTy) 1009 return nullptr; 1010 1011 NodePtr Node = identifyNode(AReal, AImag); 1012 1013 // In the case that a node was identified to figure out the rotation, ensure 1014 // that trying to identify a node with AReal and AImag post-unwrap results in 1015 // the same node 1016 if (ANode && Node != ANode) { 1017 LLVM_DEBUG( 1018 dbgs() 1019 << "Identified node is different from previously identified node. 
" 1020 "Unable to confidently generate a complex operation node\n"); 1021 return nullptr; 1022 } 1023 1024 CN->addOperand(Node); 1025 CN->addOperand(identifyNode(BReal, BImag)); 1026 CN->addOperand(identifyNode(Phi, RealUser)); 1027 1028 return submitCompositeNode(CN); 1029 } 1030 1031 ComplexDeinterleavingGraph::NodePtr 1032 ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) { 1033 // Partial reductions don't support non-vector types, so check these first 1034 if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType())) 1035 return nullptr; 1036 1037 auto CommonUser = 1038 findCommonBetweenCollections<Value *>(R->users(), I->users()); 1039 if (!CommonUser) 1040 return nullptr; 1041 1042 auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser); 1043 if (!IInst || IInst->getIntrinsicID() != 1044 Intrinsic::experimental_vector_partial_reduce_add) 1045 return nullptr; 1046 1047 if (NodePtr CN = identifyDotProduct(IInst)) 1048 return CN; 1049 1050 return nullptr; 1051 } 1052 1053 ComplexDeinterleavingGraph::NodePtr 1054 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { 1055 auto It = CachedResult.find({R, I}); 1056 if (It != CachedResult.end()) { 1057 LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); 1058 return It->second; 1059 } 1060 1061 if (NodePtr CN = identifyPartialReduction(R, I)) 1062 return CN; 1063 1064 bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I); 1065 if (!IsReduction && R->getType() != I->getType()) 1066 return nullptr; 1067 1068 if (NodePtr CN = identifySplat(R, I)) 1069 return CN; 1070 1071 auto *Real = dyn_cast<Instruction>(R); 1072 auto *Imag = dyn_cast<Instruction>(I); 1073 if (!Real || !Imag) 1074 return nullptr; 1075 1076 if (NodePtr CN = identifyDeinterleave(Real, Imag)) 1077 return CN; 1078 1079 if (NodePtr CN = identifyPHINode(Real, Imag)) 1080 return CN; 1081 1082 if (NodePtr CN = identifySelectNode(Real, Imag)) 1083 return CN; 1084 1085 auto *VTy = cast<VectorType>(Real->getType()); 
1086 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 1087 1088 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( 1089 ComplexDeinterleavingOperation::CMulPartial, NewVTy); 1090 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( 1091 ComplexDeinterleavingOperation::CAdd, NewVTy); 1092 1093 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { 1094 if (NodePtr CN = identifyPartialMul(Real, Imag)) 1095 return CN; 1096 } 1097 1098 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { 1099 if (NodePtr CN = identifyAdd(Real, Imag)) 1100 return CN; 1101 } 1102 1103 if (HasCMulSupport && HasCAddSupport) { 1104 if (NodePtr CN = identifyReassocNodes(Real, Imag)) 1105 return CN; 1106 } 1107 1108 if (NodePtr CN = identifySymmetricOperation(Real, Imag)) 1109 return CN; 1110 1111 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); 1112 CachedResult[{R, I}] = nullptr; 1113 return nullptr; 1114 } 1115 1116 ComplexDeinterleavingGraph::NodePtr 1117 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, 1118 Instruction *Imag) { 1119 auto IsOperationSupported = [](unsigned Opcode) -> bool { 1120 return Opcode == Instruction::FAdd || Opcode == Instruction::FSub || 1121 Opcode == Instruction::FNeg || Opcode == Instruction::Add || 1122 Opcode == Instruction::Sub; 1123 }; 1124 1125 if (!IsOperationSupported(Real->getOpcode()) || 1126 !IsOperationSupported(Imag->getOpcode())) 1127 return nullptr; 1128 1129 std::optional<FastMathFlags> Flags; 1130 if (isa<FPMathOperator>(Real)) { 1131 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { 1132 LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are " 1133 "not identical\n"); 1134 return nullptr; 1135 } 1136 1137 Flags = Real->getFastMathFlags(); 1138 if (!Flags->allowReassoc()) { 1139 LLVM_DEBUG( 1140 dbgs() 1141 << "the 'Reassoc' attribute is missing in the FastMath flags\n"); 1142 return nullptr; 1143 } 1144 } 1145 1146 // Collect 
multiplications and addend instructions from the given instruction 1147 // while traversing it operands. Additionally, verify that all instructions 1148 // have the same fast math flags. 1149 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, 1150 std::list<Addend> &Addends) -> bool { 1151 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; 1152 SmallPtrSet<Value *, 8> Visited; 1153 while (!Worklist.empty()) { 1154 auto [V, IsPositive] = Worklist.back(); 1155 Worklist.pop_back(); 1156 if (!Visited.insert(V).second) 1157 continue; 1158 1159 Instruction *I = dyn_cast<Instruction>(V); 1160 if (!I) { 1161 Addends.emplace_back(V, IsPositive); 1162 continue; 1163 } 1164 1165 // If an instruction has more than one user, it indicates that it either 1166 // has an external user, which will be later checked by the checkNodes 1167 // function, or it is a subexpression utilized by multiple expressions. In 1168 // the latter case, we will attempt to separately identify the complex 1169 // operation from here in order to create a shared 1170 // ComplexDeinterleavingCompositeNode. 
1171 if (I != Insn && I->getNumUses() > 1) { 1172 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); 1173 Addends.emplace_back(I, IsPositive); 1174 continue; 1175 } 1176 switch (I->getOpcode()) { 1177 case Instruction::FAdd: 1178 case Instruction::Add: 1179 Worklist.emplace_back(I->getOperand(1), IsPositive); 1180 Worklist.emplace_back(I->getOperand(0), IsPositive); 1181 break; 1182 case Instruction::FSub: 1183 Worklist.emplace_back(I->getOperand(1), !IsPositive); 1184 Worklist.emplace_back(I->getOperand(0), IsPositive); 1185 break; 1186 case Instruction::Sub: 1187 if (isNeg(I)) { 1188 Worklist.emplace_back(getNegOperand(I), !IsPositive); 1189 } else { 1190 Worklist.emplace_back(I->getOperand(1), !IsPositive); 1191 Worklist.emplace_back(I->getOperand(0), IsPositive); 1192 } 1193 break; 1194 case Instruction::FMul: 1195 case Instruction::Mul: { 1196 Value *A, *B; 1197 if (isNeg(I->getOperand(0))) { 1198 A = getNegOperand(I->getOperand(0)); 1199 IsPositive = !IsPositive; 1200 } else { 1201 A = I->getOperand(0); 1202 } 1203 1204 if (isNeg(I->getOperand(1))) { 1205 B = getNegOperand(I->getOperand(1)); 1206 IsPositive = !IsPositive; 1207 } else { 1208 B = I->getOperand(1); 1209 } 1210 Muls.push_back(Product{A, B, IsPositive}); 1211 break; 1212 } 1213 case Instruction::FNeg: 1214 Worklist.emplace_back(I->getOperand(0), !IsPositive); 1215 break; 1216 default: 1217 Addends.emplace_back(I, IsPositive); 1218 continue; 1219 } 1220 1221 if (Flags && I->getFastMathFlags() != *Flags) { 1222 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " 1223 "inconsistent with the root instructions' flags: " 1224 << *I << "\n"); 1225 return false; 1226 } 1227 } 1228 return true; 1229 }; 1230 1231 std::vector<Product> RealMuls, ImagMuls; 1232 std::list<Addend> RealAddends, ImagAddends; 1233 if (!Collect(Real, RealMuls, RealAddends) || 1234 !Collect(Imag, ImagMuls, ImagAddends)) 1235 return nullptr; 1236 1237 if (RealAddends.size() != ImagAddends.size()) 1238 
return nullptr; 1239 1240 NodePtr FinalNode; 1241 if (!RealMuls.empty() || !ImagMuls.empty()) { 1242 // If there are multiplicands, extract positive addend and use it as an 1243 // accumulator 1244 FinalNode = extractPositiveAddend(RealAddends, ImagAddends); 1245 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); 1246 if (!FinalNode) 1247 return nullptr; 1248 } 1249 1250 // Identify and process remaining additions 1251 if (!RealAddends.empty() || !ImagAddends.empty()) { 1252 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); 1253 if (!FinalNode) 1254 return nullptr; 1255 } 1256 assert(FinalNode && "FinalNode can not be nullptr here"); 1257 // Set the Real and Imag fields of the final node and submit it 1258 FinalNode->Real = Real; 1259 FinalNode->Imag = Imag; 1260 submitCompositeNode(FinalNode); 1261 return FinalNode; 1262 } 1263 1264 bool ComplexDeinterleavingGraph::collectPartialMuls( 1265 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, 1266 std::vector<PartialMulCandidate> &PartialMulCandidates) { 1267 // Helper function to extract a common operand from two products 1268 auto FindCommonInstruction = [](const Product &Real, 1269 const Product &Imag) -> Value * { 1270 if (Real.Multiplicand == Imag.Multiplicand || 1271 Real.Multiplicand == Imag.Multiplier) 1272 return Real.Multiplicand; 1273 1274 if (Real.Multiplier == Imag.Multiplicand || 1275 Real.Multiplier == Imag.Multiplier) 1276 return Real.Multiplier; 1277 1278 return nullptr; 1279 }; 1280 1281 // Iterating over real and imaginary multiplications to find common operands 1282 // If a common operand is found, a partial multiplication candidate is created 1283 // and added to the candidates vector The function returns false if no common 1284 // operands are found for any product 1285 for (unsigned i = 0; i < RealMuls.size(); ++i) { 1286 bool FoundCommon = false; 1287 for (unsigned j = 0; j < ImagMuls.size(); ++j) { 1288 auto *Common = 
FindCommonInstruction(RealMuls[i], ImagMuls[j]); 1289 if (!Common) 1290 continue; 1291 1292 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier 1293 : RealMuls[i].Multiplicand; 1294 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier 1295 : ImagMuls[j].Multiplicand; 1296 1297 auto Node = identifyNode(A, B); 1298 if (Node) { 1299 FoundCommon = true; 1300 PartialMulCandidates.push_back({Common, Node, i, j, false}); 1301 } 1302 1303 Node = identifyNode(B, A); 1304 if (Node) { 1305 FoundCommon = true; 1306 PartialMulCandidates.push_back({Common, Node, i, j, true}); 1307 } 1308 } 1309 if (!FoundCommon) 1310 return false; 1311 } 1312 return true; 1313 } 1314 1315 ComplexDeinterleavingGraph::NodePtr 1316 ComplexDeinterleavingGraph::identifyMultiplications( 1317 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, 1318 NodePtr Accumulator = nullptr) { 1319 if (RealMuls.size() != ImagMuls.size()) 1320 return nullptr; 1321 1322 std::vector<PartialMulCandidate> Info; 1323 if (!collectPartialMuls(RealMuls, ImagMuls, Info)) 1324 return nullptr; 1325 1326 // Map to store common instruction to node pointers 1327 std::map<Value *, NodePtr> CommonToNode; 1328 std::vector<bool> Processed(Info.size(), false); 1329 for (unsigned I = 0; I < Info.size(); ++I) { 1330 if (Processed[I]) 1331 continue; 1332 1333 PartialMulCandidate &InfoA = Info[I]; 1334 for (unsigned J = I + 1; J < Info.size(); ++J) { 1335 if (Processed[J]) 1336 continue; 1337 1338 PartialMulCandidate &InfoB = Info[J]; 1339 auto *InfoReal = &InfoA; 1340 auto *InfoImag = &InfoB; 1341 1342 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1343 if (!NodeFromCommon) { 1344 std::swap(InfoReal, InfoImag); 1345 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1346 } 1347 if (!NodeFromCommon) 1348 continue; 1349 1350 CommonToNode[InfoReal->Common] = NodeFromCommon; 1351 CommonToNode[InfoImag->Common] = NodeFromCommon; 1352 Processed[I] = true; 
1353 Processed[J] = true; 1354 } 1355 } 1356 1357 std::vector<bool> ProcessedReal(RealMuls.size(), false); 1358 std::vector<bool> ProcessedImag(ImagMuls.size(), false); 1359 NodePtr Result = Accumulator; 1360 for (auto &PMI : Info) { 1361 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) 1362 continue; 1363 1364 auto It = CommonToNode.find(PMI.Common); 1365 // TODO: Process independent complex multiplications. Cases like this: 1366 // A.real() * B where both A and B are complex numbers. 1367 if (It == CommonToNode.end()) { 1368 LLVM_DEBUG({ 1369 dbgs() << "Unprocessed independent partial multiplication:\n"; 1370 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) 1371 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier 1372 << " multiplied by " << *Mul->Multiplicand << "\n"; 1373 }); 1374 return nullptr; 1375 } 1376 1377 auto &RealMul = RealMuls[PMI.RealIdx]; 1378 auto &ImagMul = ImagMuls[PMI.ImagIdx]; 1379 1380 auto NodeA = It->second; 1381 auto NodeB = PMI.Node; 1382 auto IsMultiplicandReal = PMI.Common == NodeA->Real; 1383 // The following table illustrates the relationship between multiplications 1384 // and rotations. 
If we consider the multiplication (X + iY) * (U + iV), we 1385 // can see: 1386 // 1387 // Rotation | Real | Imag | 1388 // ---------+--------+--------+ 1389 // 0 | x * u | x * v | 1390 // 90 | -y * v | y * u | 1391 // 180 | -x * u | -x * v | 1392 // 270 | y * v | -y * u | 1393 // 1394 // Check if the candidate can indeed be represented by partial 1395 // multiplication 1396 // TODO: Add support for multiplication by complex one 1397 if ((IsMultiplicandReal && PMI.IsNodeInverted) || 1398 (!IsMultiplicandReal && !PMI.IsNodeInverted)) 1399 continue; 1400 1401 // Determine the rotation based on the multiplications 1402 ComplexDeinterleavingRotation Rotation; 1403 if (IsMultiplicandReal) { 1404 // Detect 0 and 180 degrees rotation 1405 if (RealMul.IsPositive && ImagMul.IsPositive) 1406 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; 1407 else if (!RealMul.IsPositive && !ImagMul.IsPositive) 1408 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; 1409 else 1410 continue; 1411 1412 } else { 1413 // Detect 90 and 270 degrees rotation 1414 if (!RealMul.IsPositive && ImagMul.IsPositive) 1415 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; 1416 else if (RealMul.IsPositive && !ImagMul.IsPositive) 1417 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; 1418 else 1419 continue; 1420 } 1421 1422 LLVM_DEBUG({ 1423 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; 1424 dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; 1425 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; 1426 dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; 1427 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; 1428 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1429 }); 1430 1431 NodePtr NodeMul = prepareCompositeNode( 1432 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); 1433 NodeMul->Rotation = Rotation; 1434 NodeMul->addOperand(NodeA); 1435 NodeMul->addOperand(NodeB); 1436 if (Result) 1437 
NodeMul->addOperand(Result); 1438 submitCompositeNode(NodeMul); 1439 Result = NodeMul; 1440 ProcessedReal[PMI.RealIdx] = true; 1441 ProcessedImag[PMI.ImagIdx] = true; 1442 } 1443 1444 // Ensure all products have been processed, if not return nullptr. 1445 if (!all_of(ProcessedReal, [](bool V) { return V; }) || 1446 !all_of(ProcessedImag, [](bool V) { return V; })) { 1447 1448 // Dump debug information about which partial multiplications are not 1449 // processed. 1450 LLVM_DEBUG({ 1451 dbgs() << "Unprocessed products (Real):\n"; 1452 for (size_t i = 0; i < ProcessedReal.size(); ++i) { 1453 if (!ProcessedReal[i]) 1454 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") 1455 << *RealMuls[i].Multiplier << " multiplied by " 1456 << *RealMuls[i].Multiplicand << "\n"; 1457 } 1458 dbgs() << "Unprocessed products (Imag):\n"; 1459 for (size_t i = 0; i < ProcessedImag.size(); ++i) { 1460 if (!ProcessedImag[i]) 1461 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") 1462 << *ImagMuls[i].Multiplier << " multiplied by " 1463 << *ImagMuls[i].Multiplicand << "\n"; 1464 } 1465 }); 1466 return nullptr; 1467 } 1468 1469 return Result; 1470 } 1471 1472 ComplexDeinterleavingGraph::NodePtr 1473 ComplexDeinterleavingGraph::identifyAdditions( 1474 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends, 1475 std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) { 1476 if (RealAddends.size() != ImagAddends.size()) 1477 return nullptr; 1478 1479 NodePtr Result; 1480 // If we have accumulator use it as first addend 1481 if (Accumulator) 1482 Result = Accumulator; 1483 // Otherwise find an element with both positive real and imaginary parts. 
1484 else 1485 Result = extractPositiveAddend(RealAddends, ImagAddends); 1486 1487 if (!Result) 1488 return nullptr; 1489 1490 while (!RealAddends.empty()) { 1491 auto ItR = RealAddends.begin(); 1492 auto [R, IsPositiveR] = *ItR; 1493 1494 bool FoundImag = false; 1495 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1496 auto [I, IsPositiveI] = *ItI; 1497 ComplexDeinterleavingRotation Rotation; 1498 if (IsPositiveR && IsPositiveI) 1499 Rotation = ComplexDeinterleavingRotation::Rotation_0; 1500 else if (!IsPositiveR && IsPositiveI) 1501 Rotation = ComplexDeinterleavingRotation::Rotation_90; 1502 else if (!IsPositiveR && !IsPositiveI) 1503 Rotation = ComplexDeinterleavingRotation::Rotation_180; 1504 else 1505 Rotation = ComplexDeinterleavingRotation::Rotation_270; 1506 1507 NodePtr AddNode; 1508 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 1509 Rotation == ComplexDeinterleavingRotation::Rotation_180) { 1510 AddNode = identifyNode(R, I); 1511 } else { 1512 AddNode = identifyNode(I, R); 1513 } 1514 if (AddNode) { 1515 LLVM_DEBUG({ 1516 dbgs() << "Identified addition:\n"; 1517 dbgs().indent(4) << "X: " << *R << "\n"; 1518 dbgs().indent(4) << "Y: " << *I << "\n"; 1519 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1520 }); 1521 1522 NodePtr TmpNode; 1523 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { 1524 TmpNode = prepareCompositeNode( 1525 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1526 if (Flags) { 1527 TmpNode->Opcode = Instruction::FAdd; 1528 TmpNode->Flags = *Flags; 1529 } else { 1530 TmpNode->Opcode = Instruction::Add; 1531 } 1532 } else if (Rotation == 1533 llvm::ComplexDeinterleavingRotation::Rotation_180) { 1534 TmpNode = prepareCompositeNode( 1535 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1536 if (Flags) { 1537 TmpNode->Opcode = Instruction::FSub; 1538 TmpNode->Flags = *Flags; 1539 } else { 1540 TmpNode->Opcode = Instruction::Sub; 1541 } 1542 } 
else { 1543 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, 1544 nullptr, nullptr); 1545 TmpNode->Rotation = Rotation; 1546 } 1547 1548 TmpNode->addOperand(Result); 1549 TmpNode->addOperand(AddNode); 1550 submitCompositeNode(TmpNode); 1551 Result = TmpNode; 1552 RealAddends.erase(ItR); 1553 ImagAddends.erase(ItI); 1554 FoundImag = true; 1555 break; 1556 } 1557 } 1558 if (!FoundImag) 1559 return nullptr; 1560 } 1561 return Result; 1562 } 1563 1564 ComplexDeinterleavingGraph::NodePtr 1565 ComplexDeinterleavingGraph::extractPositiveAddend( 1566 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { 1567 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { 1568 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1569 auto [R, IsPositiveR] = *ItR; 1570 auto [I, IsPositiveI] = *ItI; 1571 if (IsPositiveR && IsPositiveI) { 1572 auto Result = identifyNode(R, I); 1573 if (Result) { 1574 RealAddends.erase(ItR); 1575 ImagAddends.erase(ItI); 1576 return Result; 1577 } 1578 } 1579 } 1580 } 1581 return nullptr; 1582 } 1583 1584 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 1585 // This potential root instruction might already have been recognized as 1586 // reduction. Because RootToNode maps both Real and Imaginary parts to 1587 // CompositeNode we should choose only one either Real or Imag instruction to 1588 // use as an anchor for generating complex instruction. 1589 auto It = RootToNode.find(RootI); 1590 if (It != RootToNode.end()) { 1591 auto RootNode = It->second; 1592 assert(RootNode->Operation == 1593 ComplexDeinterleavingOperation::ReductionOperation || 1594 RootNode->Operation == 1595 ComplexDeinterleavingOperation::ReductionSingle); 1596 // Find out which part, Real or Imag, comes later, and only if we come to 1597 // the latest part, add it to OrderedRoots. 1598 auto *R = cast<Instruction>(RootNode->Real); 1599 auto *I = RootNode->Imag ? 
cast<Instruction>(RootNode->Imag) : nullptr; 1600 1601 Instruction *ReplacementAnchor; 1602 if (I) 1603 ReplacementAnchor = R->comesBefore(I) ? I : R; 1604 else 1605 ReplacementAnchor = R; 1606 1607 if (ReplacementAnchor != RootI) 1608 return false; 1609 OrderedRoots.push_back(RootI); 1610 return true; 1611 } 1612 1613 auto RootNode = identifyRoot(RootI); 1614 if (!RootNode) 1615 return false; 1616 1617 LLVM_DEBUG({ 1618 Function *F = RootI->getFunction(); 1619 BasicBlock *B = RootI->getParent(); 1620 dbgs() << "Complex deinterleaving graph for " << F->getName() 1621 << "::" << B->getName() << ".\n"; 1622 dump(dbgs()); 1623 dbgs() << "\n"; 1624 }); 1625 RootToNode[RootI] = RootNode; 1626 OrderedRoots.push_back(RootI); 1627 return true; 1628 } 1629 1630 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { 1631 bool FoundPotentialReduction = false; 1632 1633 auto *Br = dyn_cast<BranchInst>(B->getTerminator()); 1634 if (!Br || Br->getNumSuccessors() != 2) 1635 return false; 1636 1637 // Identify simple one-block loop 1638 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) 1639 return false; 1640 1641 SmallVector<PHINode *> PHIs; 1642 for (auto &PHI : B->phis()) { 1643 if (PHI.getNumIncomingValues() != 2) 1644 continue; 1645 1646 if (!PHI.getType()->isVectorTy()) 1647 continue; 1648 1649 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); 1650 if (!ReductionOp) 1651 continue; 1652 1653 // Check if final instruction is reduced outside of current block 1654 Instruction *FinalReduction = nullptr; 1655 auto NumUsers = 0u; 1656 for (auto *U : ReductionOp->users()) { 1657 ++NumUsers; 1658 if (U == &PHI) 1659 continue; 1660 FinalReduction = dyn_cast<Instruction>(U); 1661 } 1662 1663 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B || 1664 isa<PHINode>(FinalReduction)) 1665 continue; 1666 1667 ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; 1668 BackEdge = B; 1669 auto BackEdgeIdx = 
PHI.getBasicBlockIndex(B); 1670 auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0; 1671 Incoming = PHI.getIncomingBlock(IncomingIdx); 1672 FoundPotentialReduction = true; 1673 1674 // If the initial value of PHINode is an Instruction, consider it a leaf 1675 // value of a complex deinterleaving graph. 1676 if (auto *InitPHI = 1677 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) 1678 FinalInstructions.insert(InitPHI); 1679 } 1680 return FoundPotentialReduction; 1681 } 1682 1683 void ComplexDeinterleavingGraph::identifyReductionNodes() { 1684 SmallVector<bool> Processed(ReductionInfo.size(), false); 1685 SmallVector<Instruction *> OperationInstruction; 1686 for (auto &P : ReductionInfo) 1687 OperationInstruction.push_back(P.first); 1688 1689 // Identify a complex computation by evaluating two reduction operations that 1690 // potentially could be involved 1691 for (size_t i = 0; i < OperationInstruction.size(); ++i) { 1692 if (Processed[i]) 1693 continue; 1694 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { 1695 if (Processed[j]) 1696 continue; 1697 auto *Real = OperationInstruction[i]; 1698 auto *Imag = OperationInstruction[j]; 1699 if (Real->getType() != Imag->getType()) 1700 continue; 1701 1702 RealPHI = ReductionInfo[Real].first; 1703 ImagPHI = ReductionInfo[Imag].first; 1704 PHIsFound = false; 1705 auto Node = identifyNode(Real, Imag); 1706 if (!Node) { 1707 std::swap(Real, Imag); 1708 std::swap(RealPHI, ImagPHI); 1709 Node = identifyNode(Real, Imag); 1710 } 1711 1712 // If a node is identified and reduction PHINode is used in the chain of 1713 // operations, mark its operation instructions as used to prevent 1714 // re-identification and attach the node to the real part 1715 if (Node && PHIsFound) { 1716 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " 1717 << *Real << " / " << *Imag << "\n"); 1718 Processed[i] = true; 1719 Processed[j] = true; 1720 auto RootNode = prepareCompositeNode( 1721 
ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); 1722 RootNode->addOperand(Node); 1723 RootToNode[Real] = RootNode; 1724 RootToNode[Imag] = RootNode; 1725 submitCompositeNode(RootNode); 1726 break; 1727 } 1728 } 1729 1730 auto *Real = OperationInstruction[i]; 1731 // We want to check that we have 2 operands, but the function attributes 1732 // being counted as operands bloats this value. 1733 if (Processed[i] || Real->getNumOperands() < 2) 1734 continue; 1735 1736 RealPHI = ReductionInfo[Real].first; 1737 ImagPHI = nullptr; 1738 PHIsFound = false; 1739 auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1)); 1740 if (Node && PHIsFound) { 1741 LLVM_DEBUG( 1742 dbgs() << "Identified single reduction starting from instruction: " 1743 << *Real << "/" << *ReductionInfo[Real].second << "\n"); 1744 Processed[i] = true; 1745 auto RootNode = prepareCompositeNode( 1746 ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr); 1747 RootNode->addOperand(Node); 1748 RootToNode[Real] = RootNode; 1749 submitCompositeNode(RootNode); 1750 } 1751 } 1752 1753 RealPHI = nullptr; 1754 ImagPHI = nullptr; 1755 } 1756 1757 bool ComplexDeinterleavingGraph::checkNodes() { 1758 1759 bool FoundDeinterleaveNode = false; 1760 for (NodePtr N : CompositeNodes) { 1761 if (!N->areOperandsValid()) 1762 return false; 1763 if (N->Operation == ComplexDeinterleavingOperation::Deinterleave) 1764 FoundDeinterleaveNode = true; 1765 } 1766 1767 // We need a deinterleave node in order to guarantee that we're working with 1768 // complex numbers. 
1769 if (!FoundDeinterleaveNode) { 1770 LLVM_DEBUG( 1771 dbgs() << "Couldn't find a deinterleave node within the graph, cannot " 1772 "guarantee safety during graph transformation.\n"); 1773 return false; 1774 } 1775 1776 // Collect all instructions from roots to leaves 1777 SmallPtrSet<Instruction *, 16> AllInstructions; 1778 SmallVector<Instruction *, 8> Worklist; 1779 for (auto &Pair : RootToNode) 1780 Worklist.push_back(Pair.first); 1781 1782 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG 1783 // chains 1784 while (!Worklist.empty()) { 1785 auto *I = Worklist.back(); 1786 Worklist.pop_back(); 1787 1788 if (!AllInstructions.insert(I).second) 1789 continue; 1790 1791 for (Value *Op : I->operands()) { 1792 if (auto *OpI = dyn_cast<Instruction>(Op)) { 1793 if (!FinalInstructions.count(I)) 1794 Worklist.emplace_back(OpI); 1795 } 1796 } 1797 } 1798 1799 // Find instructions that have users outside of chain 1800 SmallVector<Instruction *, 2> OuterInstructions; 1801 for (auto *I : AllInstructions) { 1802 // Skip root nodes 1803 if (RootToNode.count(I)) 1804 continue; 1805 1806 for (User *U : I->users()) { 1807 if (AllInstructions.count(cast<Instruction>(U))) 1808 continue; 1809 1810 // Found an instruction that is not used by XCMLA/XCADD chain 1811 Worklist.emplace_back(I); 1812 break; 1813 } 1814 } 1815 1816 // If any instructions are found to be used outside, find and remove roots 1817 // that somehow connect to those instructions. 1818 SmallPtrSet<Instruction *, 16> Visited; 1819 while (!Worklist.empty()) { 1820 auto *I = Worklist.back(); 1821 Worklist.pop_back(); 1822 if (!Visited.insert(I).second) 1823 continue; 1824 1825 // Found an impacted root node. 
Removing it from the nodes to be 1826 // deinterleaved 1827 if (RootToNode.count(I)) { 1828 LLVM_DEBUG(dbgs() << "Instruction " << *I 1829 << " could be deinterleaved but its chain of complex " 1830 "operations have an outside user\n"); 1831 RootToNode.erase(I); 1832 } 1833 1834 if (!AllInstructions.count(I) || FinalInstructions.count(I)) 1835 continue; 1836 1837 for (User *U : I->users()) 1838 Worklist.emplace_back(cast<Instruction>(U)); 1839 1840 for (Value *Op : I->operands()) { 1841 if (auto *OpI = dyn_cast<Instruction>(Op)) 1842 Worklist.emplace_back(OpI); 1843 } 1844 } 1845 return !RootToNode.empty(); 1846 } 1847 1848 ComplexDeinterleavingGraph::NodePtr 1849 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { 1850 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { 1851 if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2) 1852 return nullptr; 1853 1854 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); 1855 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); 1856 if (!Real || !Imag) 1857 return nullptr; 1858 1859 return identifyNode(Real, Imag); 1860 } 1861 1862 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); 1863 if (!SVI) 1864 return nullptr; 1865 1866 // Look for a shufflevector that takes separate vectors of the real and 1867 // imaginary components and recombines them into a single vector. 
if (!isInterleavingMask(SVI->getShuffleMask()))
    return nullptr;

  Instruction *Real;
  Instruction *Imag;
  if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
    return nullptr;

  return identifyNode(Real, Imag);
}

/// Try to recognise \p Real and \p Imag as the two results of deinterleaving
/// a single vector value.
///
/// Two forms are accepted:
///  1. `extractvalue 0` / `extractvalue 1` of the same
///     `vector.deinterleave2` intrinsic call.
///  2. Two shufflevector instructions that read the same first operand, whose
///     second operand is undef or zeroinitializer, whose masks are accepted
///     by isDeinterleavingMask and start at 0 (Real) and 1 (Imag), and whose
///     result type has the same scalar type and half the element count of
///     the source vector.
///
/// On success a Deinterleave node is created with its ReplacementNode already
/// set to the original (interleaved) value, and both instructions are added
/// to FinalInstructions.
///
/// \returns the submitted node, or nullptr if neither form matches.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
                                                 Instruction *Imag) {
  // Form 1: intrinsic-based deinterleave. FinalValue is bound to the
  // intrinsic's operand, i.e. the value that was deinterleaved.
  Instruction *I = nullptr;
  Value *FinalValue = nullptr;
  if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
      match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
      match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>(
                   m_Value(FinalValue)))) {
    NodePtr PlaceholderNode = prepareCompositeNode(
        llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
    PlaceholderNode->ReplacementNode = FinalValue;
    FinalInstructions.insert(Real);
    FinalInstructions.insert(Imag);
    return submitCompositeNode(PlaceholderNode);
  }

  // Form 2: shuffle-based deinterleave. Both sides must be shuffles; the
  // debug message is only emitted when exactly one of them is.
  auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
  auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
  if (!RealShuffle || !ImagShuffle) {
    if (RealShuffle || ImagShuffle)
      LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
    return nullptr;
  }

  // The second shuffle operand must not contribute data.
  Value *RealOp1 = RealShuffle->getOperand(1);
  if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
    LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
    return nullptr;
  }
  Value *ImagOp1 = ImagShuffle->getOperand(1);
  if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
    LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
    return nullptr;
  }

  // Both shuffles must read from the very same source value.
  Value *RealOp0 = RealShuffle->getOperand(0);
  Value *ImagOp0 = ImagShuffle->getOperand(0);

  if (RealOp0 != ImagOp0) {
    LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
    return nullptr;
  }

  ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
  ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
  if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
    LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
    return nullptr;
  }

  // The real mask must begin at lane 0 and the imaginary mask at lane 1.
  if (RealMask[0] != 0 || ImagMask[0] != 1) {
    LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
    return nullptr;
  }

  // Type checking, the shuffle type should be a vector type of the same
  // scalar type, but half the size
  auto CheckType = [&](ShuffleVectorInst *Shuffle) {
    Value *Op = Shuffle->getOperand(0);
    auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
    auto *OpTy = cast<FixedVectorType>(Op->getType());

    if (OpTy->getScalarType() != ShuffleTy->getScalarType())
      return false;
    if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
      return false;

    return true;
  };

  // Combines the type check above with a bound check on the last mask index.
  auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
    if (!CheckType(Shuffle))
      return false;

    ArrayRef<int> Mask = Shuffle->getShuffleMask();
    int Last = *Mask.rbegin();

    Value *Op = Shuffle->getOperand(0);
    auto *OpTy = cast<FixedVectorType>(Op->getType());
    int NumElements = OpTy->getNumElements();

    // Ensure that the deinterleaving shuffle only pulls from the first
    // shuffle operand.
    return Last < NumElements;
  };

  if (RealShuffle->getType() != ImagShuffle->getType()) {
    LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
    return nullptr;
  }
  if (!CheckDeinterleavingShuffle(RealShuffle)) {
    LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
    return nullptr;
  }
  if (!CheckDeinterleavingShuffle(ImagShuffle)) {
    LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
    return nullptr;
  }

  // The replacement for this node is simply the shuffles' common source.
  NodePtr PlaceholderNode =
      prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
                           RealShuffle, ImagShuffle);
  PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
  FinalInstructions.insert(RealShuffle);
  FinalInstructions.insert(ImagShuffle);
  return submitCompositeNode(PlaceholderNode);
}

/// Try to recognise \p R and \p I as a pair of splat-like values and wrap
/// them in a Splat node.
///
/// A value qualifies if it is a ConstantDataVector (accepted without
/// inspecting its elements), or a shufflevector (instruction or constant
/// expression) whose mask elements are all equal and whose first mask element
/// is 0. Fixed-width vectors with a single element are rejected.
///
/// \p R and \p I must either both be instructions or both be
/// non-instructions. When they are instructions they must live in the same
/// basic block, and both are added to FinalInstructions.
///
/// \returns the submitted node, or nullptr if the pair does not qualify.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
  auto IsSplat = [](Value *V) -> bool {
    // Fixed-width vector with constants
    if (isa<ConstantDataVector>(V))
      return true;

    VectorType *VTy;
    ArrayRef<int> Mask;
    // Splats are represented differently depending on whether the repeated
    // value is a constant or an Instruction
    if (auto *Const = dyn_cast<ConstantExpr>(V)) {
      if (Const->getOpcode() != Instruction::ShuffleVector)
        return false;
      VTy = cast<VectorType>(Const->getType());
      Mask = Const->getShuffleMask();
    } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
      VTy = Shuf->getType();
      Mask = Shuf->getShuffleMask();
    } else {
      return false;
    }

    // When the data type is <1 x Type>, it's not possible to differentiate
    // between the ComplexDeinterleaving::Deinterleave and
    // ComplexDeinterleaving::Splat operations.
    if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
      return false;

    // Every lane must select element 0 of the shuffle input.
    return all_equal(Mask) && Mask[0] == 0;
  };

  if (!IsSplat(R) || !IsSplat(I))
    return nullptr;

  // Reject mixed pairs where exactly one side is an Instruction.
  auto *Real = dyn_cast<Instruction>(R);
  auto *Imag = dyn_cast<Instruction>(I);
  if ((!Real && Imag) || (Real && !Imag))
    return nullptr;

  if (Real && Imag) {
    // Non-constant splats should be in the same basic block
    if (Real->getParent() != Imag->getParent())
      return nullptr;

    FinalInstructions.insert(Real);
    FinalInstructions.insert(Imag);
  }
  NodePtr PlaceholderNode =
      prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
  return submitCompositeNode(PlaceholderNode);
}

/// Recognise the pair (\p Real, \p Imag) as the reduction PHIs recorded in
/// RealPHI / ImagPHI.
///
/// \p Real must be RealPHI. \p Imag is compared against ImagPHI only when
/// ImagPHI is non-null. On a match, PHIsFound is set and a ReductionPHI node
/// is created.
///
/// \returns the submitted node, or nullptr if the pair does not match.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
                                            Instruction *Imag) {
  if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
    return nullptr;

  PHIsFound = true;
  NodePtr PlaceholderNode = prepareCompositeNode(
      ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
  return submitCompositeNode(PlaceholderNode);
}

/// Recognise \p Real and \p Imag as a pair of select instructions that share
/// a vector mask, and build a ReductionSelect node from them.
///
/// Both instructions must be SelectInst whose condition and both arms are
/// Instructions. The two conditions must be the same instruction or compare
/// equal via isIdenticalTo, and must have vector type. The true arms and the
/// false arms are each identified recursively through identifyNode and become
/// operands 0 and 1 of the new node. Both mask instructions are added to
/// FinalInstructions.
///
/// \returns the submitted node, or nullptr if any requirement fails.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
                                               Instruction *Imag) {
  auto *SelectReal = dyn_cast<SelectInst>(Real);
  auto *SelectImag = dyn_cast<SelectInst>(Imag);
  if (!SelectReal || !SelectImag)
    return nullptr;

  // AR/AI are the true arms of the real/imaginary selects; RA/BI are the
  // false arms.
  // NOTE(review): RA holds the real half of the "B" value; the name breaks
  // the AR/AI/../BI pattern (BR would be the consistent spelling).
  Instruction *MaskA, *MaskB;
  Instruction *AR, *AI, *RA, *BI;
  if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
                            m_Instruction(RA))) ||
      !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
                            m_Instruction(BI))))
    return nullptr;

  // The two selects must be driven by equivalent masks.
  if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
    return nullptr;

  if (!MaskA->getType()->isVectorTy())
    return nullptr;

  auto NodeA = identifyNode(AR, AI);
  if (!NodeA)
    return nullptr;

  auto NodeB = identifyNode(RA, BI);
  if (!NodeB)
    return nullptr;

  NodePtr PlaceholderNode = prepareCompositeNode(
      ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
  PlaceholderNode->addOperand(NodeA);
  PlaceholderNode->addOperand(NodeB);
  FinalInstructions.insert(MaskA);
  FinalInstructions.insert(MaskB);
  return submitCompositeNode(PlaceholderNode);
}

/// Emit, through \p B, the operation named by \p Opcode on the given inputs.
///
/// Supported opcodes are FNeg (uses only \p InputA), FAdd, Add, FSub, Sub,
/// FMul and Mul; any other opcode reaches llvm_unreachable. When \p Flags
/// holds a value it is applied to the created instruction.
///
/// \returns the value produced by the builder.
static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
                                   std::optional<FastMathFlags> Flags,
                                   Value *InputA, Value *InputB) {
  Value *I;
  switch (Opcode) {
  case Instruction::FNeg:
    I = B.CreateFNeg(InputA);
    break;
  case Instruction::FAdd:
    I = B.CreateFAdd(InputA, InputB);
    break;
  case Instruction::Add:
    I = B.CreateAdd(InputA, InputB);
    break;
  case Instruction::FSub:
    I = B.CreateFSub(InputA, InputB);
    break;
  case Instruction::Sub:
    I = B.CreateSub(InputA, InputB);
    break;
  case Instruction::FMul:
    I = B.CreateFMul(InputA, InputB);
    break;
  case Instruction::Mul:
    I = B.CreateMul(InputA, InputB);
    break;
  default:
    llvm_unreachable("Incorrect symmetric opcode");
  }
  // NOTE(review): the cast<> presumes the builder returned an Instruction
  // rather than a folded constant whenever Flags is set; confirm.
  if (Flags)
    cast<Instruction>(I)->setFastMathFlags(*Flags);
  return I;
}

/// Produce the replacement Value for \p Node, creating it on first request.
///
/// If Node->ReplacementNode is already set it is returned unchanged.
/// Otherwise the node's operands are replaced recursively, the IR for the
/// node's Operation is emitted, the result is cached in
/// Node->ReplacementNode, and NumComplexTransformations is incremented.
///
/// \param Builder builder used for most of the emitted IR; some cases use a
///        locally positioned builder instead.
/// \returns the (possibly cached) replacement value.
Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
                                               RawNodePtr Node) {
  if (Node->ReplacementNode)
    return Node->ReplacementNode;

  // Replaces operand Idx if the node has that many operands, else nullptr.
  auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
    return Node->Operands.size() > Idx
               ? replaceNode(Builder, Node->Operands[Idx])
               : nullptr;
  };

  Value *ReplacementNode;
  switch (Node->Operation) {
  case ComplexDeinterleavingOperation::CDot: {
    // Unlike the case below, the accumulator type is not asserted to match
    // the input type here.
    Value *Input0 = ReplaceOperandIfExist(Node, 0);
    Value *Input1 = ReplaceOperandIfExist(Node, 1);
    Value *Accumulator = ReplaceOperandIfExist(Node, 2);
    assert(!Input1 || (Input0->getType() == Input1->getType() &&
                       "Node inputs need to be of the same type"));
    ReplacementNode = TL->createComplexDeinterleavingIR(
        Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
    break;
  }
  case ComplexDeinterleavingOperation::CAdd:
  case ComplexDeinterleavingOperation::CMulPartial:
  case ComplexDeinterleavingOperation::Symmetric: {
    Value *Input0 = ReplaceOperandIfExist(Node, 0);
    Value *Input1 = ReplaceOperandIfExist(Node, 1);
    Value *Accumulator = ReplaceOperandIfExist(Node, 2);
    assert(!Input1 || (Input0->getType() == Input1->getType() &&
                       "Node inputs need to be of the same type"));
    assert(!Accumulator ||
           (Input0->getType() == Accumulator->getType() &&
            "Accumulator and input need to be of the same type"));
    // Symmetric nodes are emitted directly from their opcode; the other
    // operations are delegated to the target through TL.
    if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
      ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
                                             Input0, Input1);
    else
      ReplacementNode = TL->createComplexDeinterleavingIR(
          Builder, Node->Operation, Node->Rotation, Input0, Input1,
          Accumulator);
    break;
  }
  case ComplexDeinterleavingOperation::Deinterleave:
    llvm_unreachable("Deinterleave node should already have ReplacementNode");
    break;
  case ComplexDeinterleavingOperation::Splat: {
    // The interleaved result has twice as many elements as each half.
    auto *NewTy = VectorType::getDoubleElementsVectorType(
        cast<VectorType>(Node->Real->getType()));
    auto *R = dyn_cast<Instruction>(Node->Real);
    auto *I = dyn_cast<Instruction>(Node->Imag);
    if (R && I) {
      // Splats that are not constant are interleaved where they are located
      // (a local builder is positioned at the instruction following
      // whichever of the two comes later).
      Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
      IRBuilder<> IRB(InsertPoint);
      ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,
                                            NewTy, {Node->Real, Node->Imag});
    } else {
      ReplacementNode = Builder.CreateIntrinsic(
          Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag});
    }
    break;
  }
  case ComplexDeinterleavingOperation::ReductionPHI: {
    // If Operation is ReductionPHI, a new empty PHINode is created.
    // It is filled later when the ReductionOperation is processed.
    auto *OldPHI = cast<PHINode>(Node->Real);
    auto *VTy = cast<VectorType>(Node->Real->getType());
    auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
    auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
    // Remember the mapping so the reduction handlers can find the new PHI.
    OldToNewPHI[OldPHI] = NewPHI;
    ReplacementNode = NewPHI;
    break;
  }
  case ComplexDeinterleavingOperation::ReductionSingle:
    ReplacementNode = replaceNode(Builder, Node->Operands[0]);
    processReductionSingle(ReplacementNode, Node);
    break;
  case ComplexDeinterleavingOperation::ReductionOperation:
    ReplacementNode = replaceNode(Builder, Node->Operands[0]);
    processReductionOperation(ReplacementNode, Node);
    break;
  case ComplexDeinterleavingOperation::ReductionSelect: {
    // Interleave the two select conditions into one mask of doubled width
    // and select between the replaced operands with it.
    auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
    auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
    auto *A = replaceNode(Builder, Node->Operands[0]);
    auto *B = replaceNode(Builder, Node->Operands[1]);
    auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
        cast<VectorType>(MaskReal->getType()));
    auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2,
                                            NewMaskTy, {MaskReal, MaskImag});
    ReplacementNode = Builder.CreateSelect(NewMask, A, B);
    break;
  }
  }

  assert(ReplacementNode && "Target failed to create Intrinsic call.");
  NumComplexTransformations += 1;
  Node->ReplacementNode = ReplacementNode;
  return ReplacementNode;
}

/// Finish a ReductionSingle node once its operation has been replaced.
///
/// Only Node->Real is consulted. The new PHI recorded for the old PHI in
/// OldToNewPHI receives two incoming values: an initial value for the
/// Incoming block and \p OperationReplacement for the BackEdge block. The
/// initial value is a null vector of the doubled type when the old initial
/// value is a zero constant; otherwise the old initial value is interleaved
/// with a null vector before the terminator of Incoming. Finally an add
/// reduction of \p OperationReplacement is created at the first insertion
/// point of the final reduction's block and replaces all uses of the final
/// reduction.
void ComplexDeinterleavingGraph::processReductionSingle(
    Value *OperationReplacement, RawNodePtr Node) {
  auto *Real = cast<Instruction>(Node->Real);
  auto *OldPHI = ReductionInfo[Real].first;
  auto *NewPHI = OldToNewPHI[OldPHI];
  auto *VTy = cast<VectorType>(Real->getType());
  auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);

  Value *Init = OldPHI->getIncomingValueForBlock(Incoming);

  IRBuilder<> Builder(Incoming->getTerminator());

  // A zero initial value needs no interleave: use a zero of the wider type.
  Value *NewInit = nullptr;
  if (auto *C = dyn_cast<Constant>(Init)) {
    if (C->isZeroValue())
      NewInit = Constant::getNullValue(NewVTy);
  }

  if (!NewInit)
    NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
                                      {Init, Constant::getNullValue(VTy)});

  NewPHI->addIncoming(NewInit, Incoming);
  NewPHI->addIncoming(OperationReplacement, BackEdge);

  auto *FinalReduction = ReductionInfo[Real].second;
  Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());

  auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
  FinalReduction->replaceAllUsesWith(AddReduce);
}

/// Finish a ReductionOperation node once its operation has been replaced.
///
/// The new PHI recorded for the old real PHI receives the interleaved initial
/// values (emitted before the terminator of Incoming) and
/// \p OperationReplacement for the BackEdge block. \p OperationReplacement is
/// then deinterleaved at the first insertion point of the real final
/// reduction's block, and the uses of the old real/imaginary values inside
/// the two final reduction instructions are redirected to the extracted
/// halves.
void ComplexDeinterleavingGraph::processReductionOperation(
    Value *OperationReplacement, RawNodePtr Node) {
  auto *Real = cast<Instruction>(Node->Real);
  auto *Imag = cast<Instruction>(Node->Imag);
  auto *OldPHIReal = ReductionInfo[Real].first;
  auto *OldPHIImag = ReductionInfo[Imag].first;
  auto *NewPHI = OldToNewPHI[OldPHIReal];

  auto *VTy = cast<VectorType>(Real->getType());
  auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);

  // We have to interleave initial origin values coming from IncomingBlock
  Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
  Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);

  IRBuilder<> Builder(Incoming->getTerminator());
  auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
                                          {InitReal, InitImag});

  NewPHI->addIncoming(NewInit, Incoming);
  NewPHI->addIncoming(OperationReplacement, BackEdge);

  // Deinterleave complex vector outside of loop so that it can be finally
  // reduced
  auto *FinalReductionReal = ReductionInfo[Real].second;
  auto *FinalReductionImag = ReductionInfo[Imag].second;

  Builder.SetInsertPoint(
      &*FinalReductionReal->getParent()->getFirstInsertionPt());
  auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
                                               OperationReplacement->getType(),
                                               OperationReplacement);

  auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
  FinalReductionReal->replaceUsesOfWith(Real, NewReal);

  // The imaginary half is extracted right before its final reduction.
  Builder.SetInsertPoint(FinalReductionImag);
  auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
  FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
}

/// Emit the replacement IR for every root that has an identified node and
/// clean up the instructions that became dead.
///
/// Roots are visited in OrderedRoots order; roots without an entry in
/// RootToNode are skipped. For reduction roots the BackEdge incoming value of
/// the old PHI(s) is removed and the old instructions are queued for
/// deletion; for every other root all uses are redirected to the replacement
/// and the root itself is queued. Queued instructions are finally passed to
/// RecursivelyDeleteTriviallyDeadInstructions.
void ComplexDeinterleavingGraph::replaceNodes() {
  SmallVector<Instruction *, 16> DeadInstrRoots;
  for (auto *RootInstruction : OrderedRoots) {
    // Check if this potential root went through check process and we can
    // deinterleave it
    if (!RootToNode.count(RootInstruction))
      continue;

    IRBuilder<> Builder(RootInstruction);
    auto RootNode = RootToNode[RootInstruction];
    Value *R = replaceNode(Builder, RootNode.get());

    if (RootNode->Operation ==
        ComplexDeinterleavingOperation::ReductionOperation) {
      auto *RootReal = cast<Instruction>(RootNode->Real);
      auto *RootImag = cast<Instruction>(RootNode->Imag);
      ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
      ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
      DeadInstrRoots.push_back(RootReal);
      DeadInstrRoots.push_back(RootImag);
    } else if (RootNode->Operation ==
               ComplexDeinterleavingOperation::ReductionSingle) {
      // Here the old final reduction, not the root, is queued for deletion.
      auto *RootInst = cast<Instruction>(RootNode->Real);
      ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
      DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
    } else {
      assert(R && "Unable to find replacement for RootInstruction");
      DeadInstrRoots.push_back(RootInstruction);
      RootInstruction->replaceAllUsesWith(R);
    }
  }

  for (auto *I : DeadInstrRoots)
    RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
}