//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
// Instead of relying on shuffle operations, vector interleaving and
// deinterleaving can also be represented by the vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can only be represented by
// these intrinsics, whereas fixed-width vectors are recognized in both forms
// (shufflevector instructions and the intrinsics).
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together, connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// maps each root Instruction to the root node that should replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (Internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
59 // 60 //===----------------------------------------------------------------------===// 61 62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h" 63 #include "llvm/ADT/Statistic.h" 64 #include "llvm/Analysis/TargetLibraryInfo.h" 65 #include "llvm/Analysis/TargetTransformInfo.h" 66 #include "llvm/CodeGen/TargetLowering.h" 67 #include "llvm/CodeGen/TargetPassConfig.h" 68 #include "llvm/CodeGen/TargetSubtargetInfo.h" 69 #include "llvm/IR/IRBuilder.h" 70 #include "llvm/IR/PatternMatch.h" 71 #include "llvm/InitializePasses.h" 72 #include "llvm/Target/TargetMachine.h" 73 #include "llvm/Transforms/Utils/Local.h" 74 #include <algorithm> 75 76 using namespace llvm; 77 using namespace PatternMatch; 78 79 #define DEBUG_TYPE "complex-deinterleaving" 80 81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); 82 83 static cl::opt<bool> ComplexDeinterleavingEnabled( 84 "enable-complex-deinterleaving", 85 cl::desc("Enable generation of complex instructions"), cl::init(true), 86 cl::Hidden); 87 88 /// Checks the given mask, and determines whether said mask is interleaving. 89 /// 90 /// To be interleaving, a mask must alternate between `i` and `i + (Length / 91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a 92 /// 4x vector interleaving mask would be <0, 2, 1, 3>). 93 static bool isInterleavingMask(ArrayRef<int> Mask); 94 95 /// Checks the given mask, and determines whether said mask is deinterleaving. 96 /// 97 /// To be deinterleaving, a mask must increment in steps of 2, and either start 98 /// with 0 or 1. 99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or 100 /// <1, 3, 5, 7>). 
101 static bool isDeinterleavingMask(ArrayRef<int> Mask); 102 103 namespace { 104 105 class ComplexDeinterleavingLegacyPass : public FunctionPass { 106 public: 107 static char ID; 108 109 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) 110 : FunctionPass(ID), TM(TM) { 111 initializeComplexDeinterleavingLegacyPassPass( 112 *PassRegistry::getPassRegistry()); 113 } 114 115 StringRef getPassName() const override { 116 return "Complex Deinterleaving Pass"; 117 } 118 119 bool runOnFunction(Function &F) override; 120 void getAnalysisUsage(AnalysisUsage &AU) const override { 121 AU.addRequired<TargetLibraryInfoWrapperPass>(); 122 AU.setPreservesCFG(); 123 } 124 125 private: 126 const TargetMachine *TM; 127 }; 128 129 class ComplexDeinterleavingGraph; 130 struct ComplexDeinterleavingCompositeNode { 131 132 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, 133 Value *R, Value *I) 134 : Operation(Op), Real(R), Imag(I) {} 135 136 private: 137 friend class ComplexDeinterleavingGraph; 138 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; 139 using RawNodePtr = ComplexDeinterleavingCompositeNode *; 140 141 public: 142 ComplexDeinterleavingOperation Operation; 143 Value *Real; 144 Value *Imag; 145 146 // This two members are required exclusively for generating 147 // ComplexDeinterleavingOperation::Symmetric operations. 
148 unsigned Opcode; 149 FastMathFlags Flags; 150 151 ComplexDeinterleavingRotation Rotation = 152 ComplexDeinterleavingRotation::Rotation_0; 153 SmallVector<RawNodePtr> Operands; 154 Value *ReplacementNode = nullptr; 155 156 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } 157 158 void dump() { dump(dbgs()); } 159 void dump(raw_ostream &OS) { 160 auto PrintValue = [&](Value *V) { 161 if (V) { 162 OS << "\""; 163 V->print(OS, true); 164 OS << "\"\n"; 165 } else 166 OS << "nullptr\n"; 167 }; 168 auto PrintNodeRef = [&](RawNodePtr Ptr) { 169 if (Ptr) 170 OS << Ptr << "\n"; 171 else 172 OS << "nullptr\n"; 173 }; 174 175 OS << "- CompositeNode: " << this << "\n"; 176 OS << " Real: "; 177 PrintValue(Real); 178 OS << " Imag: "; 179 PrintValue(Imag); 180 OS << " ReplacementNode: "; 181 PrintValue(ReplacementNode); 182 OS << " Operation: " << (int)Operation << "\n"; 183 OS << " Rotation: " << ((int)Rotation * 90) << "\n"; 184 OS << " Operands: \n"; 185 for (const auto &Op : Operands) { 186 OS << " - "; 187 PrintNodeRef(Op); 188 } 189 } 190 }; 191 192 class ComplexDeinterleavingGraph { 193 public: 194 struct Product { 195 Value *Multiplier; 196 Value *Multiplicand; 197 bool IsPositive; 198 }; 199 200 using Addend = std::pair<Value *, bool>; 201 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; 202 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; 203 204 // Helper struct for holding info about potential partial multiplication 205 // candidates 206 struct PartialMulCandidate { 207 Value *Common; 208 NodePtr Node; 209 unsigned RealIdx; 210 unsigned ImagIdx; 211 bool IsNodeInverted; 212 }; 213 214 explicit ComplexDeinterleavingGraph(const TargetLowering *TL, 215 const TargetLibraryInfo *TLI) 216 : TL(TL), TLI(TLI) {} 217 218 private: 219 const TargetLowering *TL = nullptr; 220 const TargetLibraryInfo *TLI = nullptr; 221 SmallVector<NodePtr> CompositeNodes; 222 223 SmallPtrSet<Instruction *, 16> FinalInstructions; 224 225 /// Root 
instructions are instructions from which complex computation starts 226 std::map<Instruction *, NodePtr> RootToNode; 227 228 /// Topologically sorted root instructions 229 SmallVector<Instruction *, 1> OrderedRoots; 230 231 /// When examining a basic block for complex deinterleaving, if it is a simple 232 /// one-block loop, then the only incoming block is 'Incoming' and the 233 /// 'BackEdge' block is the block itself." 234 BasicBlock *BackEdge = nullptr; 235 BasicBlock *Incoming = nullptr; 236 237 /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction 238 /// %OutsideUser as it is shown in the IR: 239 /// 240 /// vector.body: 241 /// %PHInode = phi <vector type> [ zeroinitializer, %entry ], 242 /// [ %ReductionOp, %vector.body ] 243 /// ... 244 /// %ReductionOp = fadd i64 ... 245 /// ... 246 /// br i1 %condition, label %vector.body, %middle.block 247 /// 248 /// middle.block: 249 /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) 250 /// 251 /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding 252 /// `llvm.vector.reduce.fadd` when unroll factor isn't one. 253 std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo; 254 255 /// In the process of detecting a reduction, we consider a pair of 256 /// %ReductionOP, which we refer to as real and imag (or vice versa), and 257 /// traverse the use-tree to detect complex operations. As this is a reduction 258 /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds 259 /// to the %ReductionOPs that we suspect to be complex. 260 /// RealPHI and ImagPHI are used by the identifyPHINode method. 261 PHINode *RealPHI = nullptr; 262 PHINode *ImagPHI = nullptr; 263 264 /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. 265 /// The new PHINode corresponds to a vector of deinterleaved complex numbers. 266 /// This mapping is populated during 267 /// ComplexDeinterleavingOperation::ReductionPHI node replacement. 
It is then 268 /// used in the ComplexDeinterleavingOperation::ReductionOperation node 269 /// replacement process. 270 std::map<PHINode *, PHINode *> OldToNewPHI; 271 272 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, 273 Value *R, Value *I) { 274 assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && 275 Operation != ComplexDeinterleavingOperation::ReductionOperation) || 276 (R && I)) && 277 "Reduction related nodes must have Real and Imaginary parts"); 278 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, 279 I); 280 } 281 282 NodePtr submitCompositeNode(NodePtr Node) { 283 CompositeNodes.push_back(Node); 284 return Node; 285 } 286 287 NodePtr getContainingComposite(Value *R, Value *I) { 288 for (const auto &CN : CompositeNodes) { 289 if (CN->Real == R && CN->Imag == I) 290 return CN; 291 } 292 return nullptr; 293 } 294 295 /// Identifies a complex partial multiply pattern and its rotation, based on 296 /// the following patterns 297 /// 298 /// 0: r: cr + ar * br 299 /// i: ci + ar * bi 300 /// 90: r: cr - ai * bi 301 /// i: ci + ai * br 302 /// 180: r: cr - ar * br 303 /// i: ci - ar * bi 304 /// 270: r: cr + ai * bi 305 /// i: ci - ai * br 306 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); 307 308 /// Identify the other branch of a Partial Mul, taking the CommonOperandI that 309 /// is partially known from identifyPartialMul, filling in the other half of 310 /// the complex pair. 311 NodePtr 312 identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, 313 std::pair<Value *, Value *> &CommonOperandI); 314 315 /// Identifies a complex add pattern and its rotation, based on the following 316 /// patterns. 
317 /// 318 /// 90: r: ar - bi 319 /// i: ai + br 320 /// 270: r: ar + bi 321 /// i: ai - br 322 NodePtr identifyAdd(Instruction *Real, Instruction *Imag); 323 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); 324 325 NodePtr identifyNode(Value *R, Value *I); 326 327 /// Determine if a sum of complex numbers can be formed from \p RealAddends 328 /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. 329 /// Return nullptr if it is not possible to construct a complex number. 330 /// \p Flags are needed to generate symmetric Add and Sub operations. 331 NodePtr identifyAdditions(std::list<Addend> &RealAddends, 332 std::list<Addend> &ImagAddends, FastMathFlags Flags, 333 NodePtr Accumulator); 334 335 /// Extract one addend that have both real and imaginary parts positive. 336 NodePtr extractPositiveAddend(std::list<Addend> &RealAddends, 337 std::list<Addend> &ImagAddends); 338 339 /// Determine if sum of multiplications of complex numbers can be formed from 340 /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result 341 /// to it. Return nullptr if it is not possible to construct a complex number. 342 NodePtr identifyMultiplications(std::vector<Product> &RealMuls, 343 std::vector<Product> &ImagMuls, 344 NodePtr Accumulator); 345 346 /// Go through pairs of multiplication (one Real and one Imag) and find all 347 /// possible candidates for partial multiplication and put them into \p 348 /// Candidates. Returns true if all Product has pair with common operand 349 bool collectPartialMuls(const std::vector<Product> &RealMuls, 350 const std::vector<Product> &ImagMuls, 351 std::vector<PartialMulCandidate> &Candidates); 352 353 /// If the code is compiled with -Ofast or expressions have `reassoc` flag, 354 /// the order of complex computation operations may be significantly altered, 355 /// and the real and imaginary parts may not be executed in parallel. 
This 356 /// function takes this into consideration and employs a more general approach 357 /// to identify complex computations. Initially, it gathers all the addends 358 /// and multiplicands and then constructs a complex expression from them. 359 NodePtr identifyReassocNodes(Instruction *I, Instruction *J); 360 361 NodePtr identifyRoot(Instruction *I); 362 363 /// Identifies the Deinterleave operation applied to a vector containing 364 /// complex numbers. There are two ways to represent the Deinterleave 365 /// operation: 366 /// * Using two shufflevectors with even indices for /pReal instruction and 367 /// odd indices for /pImag instructions (only for fixed-width vectors) 368 /// * Using two extractvalue instructions applied to `vector.deinterleave2` 369 /// intrinsic (for both fixed and scalable vectors) 370 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); 371 372 NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); 373 374 /// Identifies SelectInsts in a loop that has reduction with predication masks 375 /// and/or predicated tail folding 376 NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); 377 378 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); 379 380 /// Complete IR modifications after producing new reduction operation: 381 /// * Populate the PHINode generated for 382 /// ComplexDeinterleavingOperation::ReductionPHI 383 /// * Deinterleave the final value outside of the loop and repurpose original 384 /// reduction users 385 void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); 386 387 public: 388 void dump() { dump(dbgs()); } 389 void dump(raw_ostream &OS) { 390 for (const auto &Node : CompositeNodes) 391 Node->dump(OS); 392 } 393 394 /// Returns false if the deinterleaving operation should be cancelled for the 395 /// current graph. 
396 bool identifyNodes(Instruction *RootI); 397 398 /// In case \pB is one-block loop, this function seeks potential reductions 399 /// and populates ReductionInfo. Returns true if any reductions were 400 /// identified. 401 bool collectPotentialReductions(BasicBlock *B); 402 403 void identifyReductionNodes(); 404 405 /// Check that every instruction, from the roots to the leaves, has internal 406 /// uses. 407 bool checkNodes(); 408 409 /// Perform the actual replacement of the underlying instruction graph. 410 void replaceNodes(); 411 }; 412 413 class ComplexDeinterleaving { 414 public: 415 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 416 : TL(tl), TLI(tli) {} 417 bool runOnFunction(Function &F); 418 419 private: 420 bool evaluateBasicBlock(BasicBlock *B); 421 422 const TargetLowering *TL = nullptr; 423 const TargetLibraryInfo *TLI = nullptr; 424 }; 425 426 } // namespace 427 428 char ComplexDeinterleavingLegacyPass::ID = 0; 429 430 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 431 "Complex Deinterleaving", false, false) 432 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 433 "Complex Deinterleaving", false, false) 434 435 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 436 FunctionAnalysisManager &AM) { 437 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 438 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 439 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 440 return PreservedAnalyses::all(); 441 442 PreservedAnalyses PA; 443 PA.preserve<FunctionAnalysisManagerModuleProxy>(); 444 return PA; 445 } 446 447 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 448 return new ComplexDeinterleavingLegacyPass(TM); 449 } 450 451 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 452 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 453 auto TLI = 
getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 454 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 455 } 456 457 bool ComplexDeinterleaving::runOnFunction(Function &F) { 458 if (!ComplexDeinterleavingEnabled) { 459 LLVM_DEBUG( 460 dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); 461 return false; 462 } 463 464 if (!TL->isComplexDeinterleavingSupported()) { 465 LLVM_DEBUG( 466 dbgs() << "Complex deinterleaving has been disabled, target does " 467 "not support lowering of complex number operations.\n"); 468 return false; 469 } 470 471 bool Changed = false; 472 for (auto &B : F) 473 Changed |= evaluateBasicBlock(&B); 474 475 return Changed; 476 } 477 478 static bool isInterleavingMask(ArrayRef<int> Mask) { 479 // If the size is not even, it's not an interleaving mask 480 if ((Mask.size() & 1)) 481 return false; 482 483 int HalfNumElements = Mask.size() / 2; 484 for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 485 int MaskIdx = Idx * 2; 486 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 487 return false; 488 } 489 490 return true; 491 } 492 493 static bool isDeinterleavingMask(ArrayRef<int> Mask) { 494 int Offset = Mask[0]; 495 int HalfNumElements = Mask.size() / 2; 496 497 for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 498 if (Mask[Idx] != (Idx * 2) + Offset) 499 return false; 500 } 501 502 return true; 503 } 504 505 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 506 ComplexDeinterleavingGraph Graph(TL, TLI); 507 if (Graph.collectPotentialReductions(B)) 508 Graph.identifyReductionNodes(); 509 510 for (auto &I : *B) 511 Graph.identifyNodes(&I); 512 513 if (Graph.checkNodes()) { 514 Graph.replaceNodes(); 515 return true; 516 } 517 518 return false; 519 } 520 521 ComplexDeinterleavingGraph::NodePtr 522 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 523 Instruction *Real, Instruction *Imag, 524 std::pair<Value *, Value *> &PartialMatch) { 525 LLVM_DEBUG(dbgs() << 
"identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 526 << "\n"); 527 528 if (!Real->hasOneUse() || !Imag->hasOneUse()) { 529 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 530 return nullptr; 531 } 532 533 if (Real->getOpcode() != Instruction::FMul || 534 Imag->getOpcode() != Instruction::FMul) { 535 LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n"); 536 return nullptr; 537 } 538 539 Value *R0 = Real->getOperand(0); 540 Value *R1 = Real->getOperand(1); 541 Value *I0 = Imag->getOperand(0); 542 Value *I1 = Imag->getOperand(1); 543 544 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 545 // rotations and use the operand. 546 unsigned Negs = 0; 547 Value *Op; 548 if (match(R0, m_Neg(m_Value(Op)))) { 549 Negs |= 1; 550 R0 = Op; 551 } else if (match(R1, m_Neg(m_Value(Op)))) { 552 Negs |= 1; 553 R1 = Op; 554 } 555 556 if (match(I0, m_Neg(m_Value(Op)))) { 557 Negs |= 2; 558 Negs ^= 1; 559 I0 = Op; 560 } else if (match(I1, m_Neg(m_Value(Op)))) { 561 Negs |= 2; 562 Negs ^= 1; 563 I1 = Op; 564 } 565 566 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 567 568 Value *CommonOperand; 569 Value *UncommonRealOp; 570 Value *UncommonImagOp; 571 572 if (R0 == I0 || R0 == I1) { 573 CommonOperand = R0; 574 UncommonRealOp = R1; 575 } else if (R1 == I0 || R1 == I1) { 576 CommonOperand = R1; 577 UncommonRealOp = R0; 578 } else { 579 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 580 return nullptr; 581 } 582 583 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 584 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 585 Rotation == ComplexDeinterleavingRotation::Rotation_270) 586 std::swap(UncommonRealOp, UncommonImagOp); 587 588 // Between identifyPartialMul and here we need to have found a complete valid 589 // pair from the CommonOperand of each part. 
590 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 591 Rotation == ComplexDeinterleavingRotation::Rotation_180) 592 PartialMatch.first = CommonOperand; 593 else 594 PartialMatch.second = CommonOperand; 595 596 if (!PartialMatch.first || !PartialMatch.second) { 597 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); 598 return nullptr; 599 } 600 601 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); 602 if (!CommonNode) { 603 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); 604 return nullptr; 605 } 606 607 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); 608 if (!UncommonNode) { 609 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); 610 return nullptr; 611 } 612 613 NodePtr Node = prepareCompositeNode( 614 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 615 Node->Rotation = Rotation; 616 Node->addOperand(CommonNode); 617 Node->addOperand(UncommonNode); 618 return submitCompositeNode(Node); 619 } 620 621 ComplexDeinterleavingGraph::NodePtr 622 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, 623 Instruction *Imag) { 624 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag 625 << "\n"); 626 // Determine rotation 627 ComplexDeinterleavingRotation Rotation; 628 if (Real->getOpcode() == Instruction::FAdd && 629 Imag->getOpcode() == Instruction::FAdd) 630 Rotation = ComplexDeinterleavingRotation::Rotation_0; 631 else if (Real->getOpcode() == Instruction::FSub && 632 Imag->getOpcode() == Instruction::FAdd) 633 Rotation = ComplexDeinterleavingRotation::Rotation_90; 634 else if (Real->getOpcode() == Instruction::FSub && 635 Imag->getOpcode() == Instruction::FSub) 636 Rotation = ComplexDeinterleavingRotation::Rotation_180; 637 else if (Real->getOpcode() == Instruction::FAdd && 638 Imag->getOpcode() == Instruction::FSub) 639 Rotation = ComplexDeinterleavingRotation::Rotation_270; 640 else { 641 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); 642 return 
nullptr; 643 } 644 645 if (!Real->getFastMathFlags().allowContract() || 646 !Imag->getFastMathFlags().allowContract()) { 647 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); 648 return nullptr; 649 } 650 651 Value *CR = Real->getOperand(0); 652 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); 653 if (!RealMulI) 654 return nullptr; 655 Value *CI = Imag->getOperand(0); 656 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); 657 if (!ImagMulI) 658 return nullptr; 659 660 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { 661 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); 662 return nullptr; 663 } 664 665 Value *R0 = RealMulI->getOperand(0); 666 Value *R1 = RealMulI->getOperand(1); 667 Value *I0 = ImagMulI->getOperand(0); 668 Value *I1 = ImagMulI->getOperand(1); 669 670 Value *CommonOperand; 671 Value *UncommonRealOp; 672 Value *UncommonImagOp; 673 674 if (R0 == I0 || R0 == I1) { 675 CommonOperand = R0; 676 UncommonRealOp = R1; 677 } else if (R1 == I0 || R1 == I1) { 678 CommonOperand = R1; 679 UncommonRealOp = R0; 680 } else { 681 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 682 return nullptr; 683 } 684 685 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 686 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 687 Rotation == ComplexDeinterleavingRotation::Rotation_270) 688 std::swap(UncommonRealOp, UncommonImagOp); 689 690 std::pair<Value *, Value *> PartialMatch( 691 (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 692 Rotation == ComplexDeinterleavingRotation::Rotation_180) 693 ? CommonOperand 694 : nullptr, 695 (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 696 Rotation == ComplexDeinterleavingRotation::Rotation_270) 697 ? 
CommonOperand 698 : nullptr); 699 700 auto *CRInst = dyn_cast<Instruction>(CR); 701 auto *CIInst = dyn_cast<Instruction>(CI); 702 703 if (!CRInst || !CIInst) { 704 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); 705 return nullptr; 706 } 707 708 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); 709 if (!CNode) { 710 LLVM_DEBUG(dbgs() << " - No cnode identified\n"); 711 return nullptr; 712 } 713 714 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); 715 if (!UncommonRes) { 716 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); 717 return nullptr; 718 } 719 720 assert(PartialMatch.first && PartialMatch.second); 721 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); 722 if (!CommonRes) { 723 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); 724 return nullptr; 725 } 726 727 NodePtr Node = prepareCompositeNode( 728 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 729 Node->Rotation = Rotation; 730 Node->addOperand(CommonRes); 731 Node->addOperand(UncommonRes); 732 Node->addOperand(CNode); 733 return submitCompositeNode(Node); 734 } 735 736 ComplexDeinterleavingGraph::NodePtr 737 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { 738 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); 739 740 // Determine rotation 741 ComplexDeinterleavingRotation Rotation; 742 if ((Real->getOpcode() == Instruction::FSub && 743 Imag->getOpcode() == Instruction::FAdd) || 744 (Real->getOpcode() == Instruction::Sub && 745 Imag->getOpcode() == Instruction::Add)) 746 Rotation = ComplexDeinterleavingRotation::Rotation_90; 747 else if ((Real->getOpcode() == Instruction::FAdd && 748 Imag->getOpcode() == Instruction::FSub) || 749 (Real->getOpcode() == Instruction::Add && 750 Imag->getOpcode() == Instruction::Sub)) 751 Rotation = ComplexDeinterleavingRotation::Rotation_270; 752 else { 753 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not 
assigned.\n"); 754 return nullptr; 755 } 756 757 auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); 758 auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); 759 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); 760 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); 761 762 if (!AR || !AI || !BR || !BI) { 763 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); 764 return nullptr; 765 } 766 767 NodePtr ResA = identifyNode(AR, AI); 768 if (!ResA) { 769 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); 770 return nullptr; 771 } 772 NodePtr ResB = identifyNode(BR, BI); 773 if (!ResB) { 774 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); 775 return nullptr; 776 } 777 778 NodePtr Node = 779 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); 780 Node->Rotation = Rotation; 781 Node->addOperand(ResA); 782 Node->addOperand(ResB); 783 return submitCompositeNode(Node); 784 } 785 786 static bool isInstructionPairAdd(Instruction *A, Instruction *B) { 787 unsigned OpcA = A->getOpcode(); 788 unsigned OpcB = B->getOpcode(); 789 790 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || 791 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || 792 (OpcA == Instruction::Sub && OpcB == Instruction::Add) || 793 (OpcA == Instruction::Add && OpcB == Instruction::Sub); 794 } 795 796 static bool isInstructionPairMul(Instruction *A, Instruction *B) { 797 auto Pattern = 798 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); 799 800 return match(A, Pattern) && match(B, Pattern); 801 } 802 803 static bool isInstructionPotentiallySymmetric(Instruction *I) { 804 switch (I->getOpcode()) { 805 case Instruction::FAdd: 806 case Instruction::FSub: 807 case Instruction::FMul: 808 case Instruction::FNeg: 809 return true; 810 default: 811 return false; 812 } 813 } 814 815 ComplexDeinterleavingGraph::NodePtr 816 
ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, 817 Instruction *Imag) { 818 if (Real->getOpcode() != Imag->getOpcode()) 819 return nullptr; 820 821 if (!isInstructionPotentiallySymmetric(Real) || 822 !isInstructionPotentiallySymmetric(Imag)) 823 return nullptr; 824 825 auto *R0 = Real->getOperand(0); 826 auto *I0 = Imag->getOperand(0); 827 828 NodePtr Op0 = identifyNode(R0, I0); 829 NodePtr Op1 = nullptr; 830 if (Op0 == nullptr) 831 return nullptr; 832 833 if (Real->isBinaryOp()) { 834 auto *R1 = Real->getOperand(1); 835 auto *I1 = Imag->getOperand(1); 836 Op1 = identifyNode(R1, I1); 837 if (Op1 == nullptr) 838 return nullptr; 839 } 840 841 if (isa<FPMathOperator>(Real) && 842 Real->getFastMathFlags() != Imag->getFastMathFlags()) 843 return nullptr; 844 845 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, 846 Real, Imag); 847 Node->Opcode = Real->getOpcode(); 848 if (isa<FPMathOperator>(Real)) 849 Node->Flags = Real->getFastMathFlags(); 850 851 Node->addOperand(Op0); 852 if (Real->isBinaryOp()) 853 Node->addOperand(Op1); 854 855 return submitCompositeNode(Node); 856 } 857 858 ComplexDeinterleavingGraph::NodePtr 859 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { 860 LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n"); 861 assert(R->getType() == I->getType() && 862 "Real and imaginary parts should not have different types"); 863 if (NodePtr CN = getContainingComposite(R, I)) { 864 LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); 865 return CN; 866 } 867 868 auto *Real = dyn_cast<Instruction>(R); 869 auto *Imag = dyn_cast<Instruction>(I); 870 if (!Real || !Imag) 871 return nullptr; 872 873 if (NodePtr CN = identifyDeinterleave(Real, Imag)) 874 return CN; 875 876 if (NodePtr CN = identifyPHINode(Real, Imag)) 877 return CN; 878 879 if (NodePtr CN = identifySelectNode(Real, Imag)) 880 return CN; 881 882 auto *VTy = cast<VectorType>(Real->getType()); 883 auto *NewVTy = 
VectorType::getDoubleElementsVectorType(VTy); 884 885 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( 886 ComplexDeinterleavingOperation::CMulPartial, NewVTy); 887 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( 888 ComplexDeinterleavingOperation::CAdd, NewVTy); 889 890 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { 891 if (NodePtr CN = identifyPartialMul(Real, Imag)) 892 return CN; 893 } 894 895 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { 896 if (NodePtr CN = identifyAdd(Real, Imag)) 897 return CN; 898 } 899 900 if (HasCMulSupport && HasCAddSupport) { 901 if (NodePtr CN = identifyReassocNodes(Real, Imag)) 902 return CN; 903 } 904 905 if (NodePtr CN = identifySymmetricOperation(Real, Imag)) 906 return CN; 907 908 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); 909 return nullptr; 910 } 911 912 ComplexDeinterleavingGraph::NodePtr 913 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, 914 Instruction *Imag) { 915 916 if ((Real->getOpcode() != Instruction::FAdd && 917 Real->getOpcode() != Instruction::FSub && 918 Real->getOpcode() != Instruction::FNeg) || 919 (Imag->getOpcode() != Instruction::FAdd && 920 Imag->getOpcode() != Instruction::FSub && 921 Imag->getOpcode() != Instruction::FNeg)) 922 return nullptr; 923 924 if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { 925 LLVM_DEBUG( 926 dbgs() 927 << "The flags in Real and Imaginary instructions are not identical\n"); 928 return nullptr; 929 } 930 931 FastMathFlags Flags = Real->getFastMathFlags(); 932 if (!Flags.allowReassoc()) { 933 LLVM_DEBUG( 934 dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n"); 935 return nullptr; 936 } 937 938 // Collect multiplications and addend instructions from the given instruction 939 // while traversing it operands. Additionally, verify that all instructions 940 // have the same fast math flags. 
941 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, 942 std::list<Addend> &Addends) -> bool { 943 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; 944 SmallPtrSet<Value *, 8> Visited; 945 while (!Worklist.empty()) { 946 auto [V, IsPositive] = Worklist.back(); 947 Worklist.pop_back(); 948 if (!Visited.insert(V).second) 949 continue; 950 951 Instruction *I = dyn_cast<Instruction>(V); 952 if (!I) { 953 Addends.emplace_back(V, IsPositive); 954 continue; 955 } 956 957 // If an instruction has more than one user, it indicates that it either 958 // has an external user, which will be later checked by the checkNodes 959 // function, or it is a subexpression utilized by multiple expressions. In 960 // the latter case, we will attempt to separately identify the complex 961 // operation from here in order to create a shared 962 // ComplexDeinterleavingCompositeNode. 963 if (I != Insn && I->getNumUses() > 1) { 964 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); 965 Addends.emplace_back(I, IsPositive); 966 continue; 967 } 968 969 if (I->getOpcode() == Instruction::FAdd) { 970 Worklist.emplace_back(I->getOperand(1), IsPositive); 971 Worklist.emplace_back(I->getOperand(0), IsPositive); 972 } else if (I->getOpcode() == Instruction::FSub) { 973 Worklist.emplace_back(I->getOperand(1), !IsPositive); 974 Worklist.emplace_back(I->getOperand(0), IsPositive); 975 } else if (I->getOpcode() == Instruction::FMul) { 976 Value *A, *B; 977 if (match(I->getOperand(0), m_FNeg(m_Value(A)))) { 978 IsPositive = !IsPositive; 979 } else { 980 A = I->getOperand(0); 981 } 982 983 if (match(I->getOperand(1), m_FNeg(m_Value(B)))) { 984 IsPositive = !IsPositive; 985 } else { 986 B = I->getOperand(1); 987 } 988 Muls.push_back(Product{A, B, IsPositive}); 989 } else if (I->getOpcode() == Instruction::FNeg) { 990 Worklist.emplace_back(I->getOperand(0), !IsPositive); 991 } else { 992 Addends.emplace_back(I, IsPositive); 993 continue; 994 } 
995 996 if (I->getFastMathFlags() != Flags) { 997 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " 998 "inconsistent with the root instructions' flags: " 999 << *I << "\n"); 1000 return false; 1001 } 1002 } 1003 return true; 1004 }; 1005 1006 std::vector<Product> RealMuls, ImagMuls; 1007 std::list<Addend> RealAddends, ImagAddends; 1008 if (!Collect(Real, RealMuls, RealAddends) || 1009 !Collect(Imag, ImagMuls, ImagAddends)) 1010 return nullptr; 1011 1012 if (RealAddends.size() != ImagAddends.size()) 1013 return nullptr; 1014 1015 NodePtr FinalNode; 1016 if (!RealMuls.empty() || !ImagMuls.empty()) { 1017 // If there are multiplicands, extract positive addend and use it as an 1018 // accumulator 1019 FinalNode = extractPositiveAddend(RealAddends, ImagAddends); 1020 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); 1021 if (!FinalNode) 1022 return nullptr; 1023 } 1024 1025 // Identify and process remaining additions 1026 if (!RealAddends.empty() || !ImagAddends.empty()) { 1027 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); 1028 if (!FinalNode) 1029 return nullptr; 1030 } 1031 assert(FinalNode && "FinalNode can not be nullptr here"); 1032 // Set the Real and Imag fields of the final node and submit it 1033 FinalNode->Real = Real; 1034 FinalNode->Imag = Imag; 1035 submitCompositeNode(FinalNode); 1036 return FinalNode; 1037 } 1038 1039 bool ComplexDeinterleavingGraph::collectPartialMuls( 1040 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, 1041 std::vector<PartialMulCandidate> &PartialMulCandidates) { 1042 // Helper function to extract a common operand from two products 1043 auto FindCommonInstruction = [](const Product &Real, 1044 const Product &Imag) -> Value * { 1045 if (Real.Multiplicand == Imag.Multiplicand || 1046 Real.Multiplicand == Imag.Multiplier) 1047 return Real.Multiplicand; 1048 1049 if (Real.Multiplier == Imag.Multiplicand || 1050 Real.Multiplier == Imag.Multiplier) 1051 
return Real.Multiplier; 1052 1053 return nullptr; 1054 }; 1055 1056 // Iterating over real and imaginary multiplications to find common operands 1057 // If a common operand is found, a partial multiplication candidate is created 1058 // and added to the candidates vector The function returns false if no common 1059 // operands are found for any product 1060 for (unsigned i = 0; i < RealMuls.size(); ++i) { 1061 bool FoundCommon = false; 1062 for (unsigned j = 0; j < ImagMuls.size(); ++j) { 1063 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]); 1064 if (!Common) 1065 continue; 1066 1067 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier 1068 : RealMuls[i].Multiplicand; 1069 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier 1070 : ImagMuls[j].Multiplicand; 1071 1072 auto Node = identifyNode(A, B); 1073 if (Node) { 1074 FoundCommon = true; 1075 PartialMulCandidates.push_back({Common, Node, i, j, false}); 1076 } 1077 1078 Node = identifyNode(B, A); 1079 if (Node) { 1080 FoundCommon = true; 1081 PartialMulCandidates.push_back({Common, Node, i, j, true}); 1082 } 1083 } 1084 if (!FoundCommon) 1085 return false; 1086 } 1087 return true; 1088 } 1089 1090 ComplexDeinterleavingGraph::NodePtr 1091 ComplexDeinterleavingGraph::identifyMultiplications( 1092 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, 1093 NodePtr Accumulator = nullptr) { 1094 if (RealMuls.size() != ImagMuls.size()) 1095 return nullptr; 1096 1097 std::vector<PartialMulCandidate> Info; 1098 if (!collectPartialMuls(RealMuls, ImagMuls, Info)) 1099 return nullptr; 1100 1101 // Map to store common instruction to node pointers 1102 std::map<Value *, NodePtr> CommonToNode; 1103 std::vector<bool> Processed(Info.size(), false); 1104 for (unsigned I = 0; I < Info.size(); ++I) { 1105 if (Processed[I]) 1106 continue; 1107 1108 PartialMulCandidate &InfoA = Info[I]; 1109 for (unsigned J = I + 1; J < Info.size(); ++J) { 1110 if (Processed[J]) 1111 continue; 
1112 1113 PartialMulCandidate &InfoB = Info[J]; 1114 auto *InfoReal = &InfoA; 1115 auto *InfoImag = &InfoB; 1116 1117 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1118 if (!NodeFromCommon) { 1119 std::swap(InfoReal, InfoImag); 1120 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1121 } 1122 if (!NodeFromCommon) 1123 continue; 1124 1125 CommonToNode[InfoReal->Common] = NodeFromCommon; 1126 CommonToNode[InfoImag->Common] = NodeFromCommon; 1127 Processed[I] = true; 1128 Processed[J] = true; 1129 } 1130 } 1131 1132 std::vector<bool> ProcessedReal(RealMuls.size(), false); 1133 std::vector<bool> ProcessedImag(ImagMuls.size(), false); 1134 NodePtr Result = Accumulator; 1135 for (auto &PMI : Info) { 1136 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) 1137 continue; 1138 1139 auto It = CommonToNode.find(PMI.Common); 1140 // TODO: Process independent complex multiplications. Cases like this: 1141 // A.real() * B where both A and B are complex numbers. 1142 if (It == CommonToNode.end()) { 1143 LLVM_DEBUG({ 1144 dbgs() << "Unprocessed independent partial multiplication:\n"; 1145 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) 1146 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier 1147 << " multiplied by " << *Mul->Multiplicand << "\n"; 1148 }); 1149 return nullptr; 1150 } 1151 1152 auto &RealMul = RealMuls[PMI.RealIdx]; 1153 auto &ImagMul = ImagMuls[PMI.ImagIdx]; 1154 1155 auto NodeA = It->second; 1156 auto NodeB = PMI.Node; 1157 auto IsMultiplicandReal = PMI.Common == NodeA->Real; 1158 // The following table illustrates the relationship between multiplications 1159 // and rotations. 
If we consider the multiplication (X + iY) * (U + iV), we 1160 // can see: 1161 // 1162 // Rotation | Real | Imag | 1163 // ---------+--------+--------+ 1164 // 0 | x * u | x * v | 1165 // 90 | -y * v | y * u | 1166 // 180 | -x * u | -x * v | 1167 // 270 | y * v | -y * u | 1168 // 1169 // Check if the candidate can indeed be represented by partial 1170 // multiplication 1171 // TODO: Add support for multiplication by complex one 1172 if ((IsMultiplicandReal && PMI.IsNodeInverted) || 1173 (!IsMultiplicandReal && !PMI.IsNodeInverted)) 1174 continue; 1175 1176 // Determine the rotation based on the multiplications 1177 ComplexDeinterleavingRotation Rotation; 1178 if (IsMultiplicandReal) { 1179 // Detect 0 and 180 degrees rotation 1180 if (RealMul.IsPositive && ImagMul.IsPositive) 1181 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; 1182 else if (!RealMul.IsPositive && !ImagMul.IsPositive) 1183 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; 1184 else 1185 continue; 1186 1187 } else { 1188 // Detect 90 and 270 degrees rotation 1189 if (!RealMul.IsPositive && ImagMul.IsPositive) 1190 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; 1191 else if (RealMul.IsPositive && !ImagMul.IsPositive) 1192 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; 1193 else 1194 continue; 1195 } 1196 1197 LLVM_DEBUG({ 1198 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; 1199 dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; 1200 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; 1201 dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; 1202 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; 1203 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1204 }); 1205 1206 NodePtr NodeMul = prepareCompositeNode( 1207 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); 1208 NodeMul->Rotation = Rotation; 1209 NodeMul->addOperand(NodeA); 1210 NodeMul->addOperand(NodeB); 1211 if (Result) 1212 
NodeMul->addOperand(Result); 1213 submitCompositeNode(NodeMul); 1214 Result = NodeMul; 1215 ProcessedReal[PMI.RealIdx] = true; 1216 ProcessedImag[PMI.ImagIdx] = true; 1217 } 1218 1219 // Ensure all products have been processed, if not return nullptr. 1220 if (!all_of(ProcessedReal, [](bool V) { return V; }) || 1221 !all_of(ProcessedImag, [](bool V) { return V; })) { 1222 1223 // Dump debug information about which partial multiplications are not 1224 // processed. 1225 LLVM_DEBUG({ 1226 dbgs() << "Unprocessed products (Real):\n"; 1227 for (size_t i = 0; i < ProcessedReal.size(); ++i) { 1228 if (!ProcessedReal[i]) 1229 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") 1230 << *RealMuls[i].Multiplier << " multiplied by " 1231 << *RealMuls[i].Multiplicand << "\n"; 1232 } 1233 dbgs() << "Unprocessed products (Imag):\n"; 1234 for (size_t i = 0; i < ProcessedImag.size(); ++i) { 1235 if (!ProcessedImag[i]) 1236 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") 1237 << *ImagMuls[i].Multiplier << " multiplied by " 1238 << *ImagMuls[i].Multiplicand << "\n"; 1239 } 1240 }); 1241 return nullptr; 1242 } 1243 1244 return Result; 1245 } 1246 1247 ComplexDeinterleavingGraph::NodePtr 1248 ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends, 1249 std::list<Addend> &ImagAddends, 1250 FastMathFlags Flags, 1251 NodePtr Accumulator = nullptr) { 1252 if (RealAddends.size() != ImagAddends.size()) 1253 return nullptr; 1254 1255 NodePtr Result; 1256 // If we have accumulator use it as first addend 1257 if (Accumulator) 1258 Result = Accumulator; 1259 // Otherwise find an element with both positive real and imaginary parts. 
1260 else 1261 Result = extractPositiveAddend(RealAddends, ImagAddends); 1262 1263 if (!Result) 1264 return nullptr; 1265 1266 while (!RealAddends.empty()) { 1267 auto ItR = RealAddends.begin(); 1268 auto [R, IsPositiveR] = *ItR; 1269 1270 bool FoundImag = false; 1271 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1272 auto [I, IsPositiveI] = *ItI; 1273 ComplexDeinterleavingRotation Rotation; 1274 if (IsPositiveR && IsPositiveI) 1275 Rotation = ComplexDeinterleavingRotation::Rotation_0; 1276 else if (!IsPositiveR && IsPositiveI) 1277 Rotation = ComplexDeinterleavingRotation::Rotation_90; 1278 else if (!IsPositiveR && !IsPositiveI) 1279 Rotation = ComplexDeinterleavingRotation::Rotation_180; 1280 else 1281 Rotation = ComplexDeinterleavingRotation::Rotation_270; 1282 1283 NodePtr AddNode; 1284 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 1285 Rotation == ComplexDeinterleavingRotation::Rotation_180) { 1286 AddNode = identifyNode(R, I); 1287 } else { 1288 AddNode = identifyNode(I, R); 1289 } 1290 if (AddNode) { 1291 LLVM_DEBUG({ 1292 dbgs() << "Identified addition:\n"; 1293 dbgs().indent(4) << "X: " << *R << "\n"; 1294 dbgs().indent(4) << "Y: " << *I << "\n"; 1295 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1296 }); 1297 1298 NodePtr TmpNode; 1299 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { 1300 TmpNode = prepareCompositeNode( 1301 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1302 TmpNode->Opcode = Instruction::FAdd; 1303 TmpNode->Flags = Flags; 1304 } else if (Rotation == 1305 llvm::ComplexDeinterleavingRotation::Rotation_180) { 1306 TmpNode = prepareCompositeNode( 1307 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1308 TmpNode->Opcode = Instruction::FSub; 1309 TmpNode->Flags = Flags; 1310 } else { 1311 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, 1312 nullptr, nullptr); 1313 TmpNode->Rotation = Rotation; 1314 } 1315 1316 
TmpNode->addOperand(Result); 1317 TmpNode->addOperand(AddNode); 1318 submitCompositeNode(TmpNode); 1319 Result = TmpNode; 1320 RealAddends.erase(ItR); 1321 ImagAddends.erase(ItI); 1322 FoundImag = true; 1323 break; 1324 } 1325 } 1326 if (!FoundImag) 1327 return nullptr; 1328 } 1329 return Result; 1330 } 1331 1332 ComplexDeinterleavingGraph::NodePtr 1333 ComplexDeinterleavingGraph::extractPositiveAddend( 1334 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { 1335 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { 1336 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1337 auto [R, IsPositiveR] = *ItR; 1338 auto [I, IsPositiveI] = *ItI; 1339 if (IsPositiveR && IsPositiveI) { 1340 auto Result = identifyNode(R, I); 1341 if (Result) { 1342 RealAddends.erase(ItR); 1343 ImagAddends.erase(ItI); 1344 return Result; 1345 } 1346 } 1347 } 1348 } 1349 return nullptr; 1350 } 1351 1352 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 1353 // This potential root instruction might already have been recognized as 1354 // reduction. Because RootToNode maps both Real and Imaginary parts to 1355 // CompositeNode we should choose only one either Real or Imag instruction to 1356 // use as an anchor for generating complex instruction. 
1357 auto It = RootToNode.find(RootI); 1358 if (It != RootToNode.end() && It->second->Real == RootI) { 1359 OrderedRoots.push_back(RootI); 1360 return true; 1361 } 1362 1363 auto RootNode = identifyRoot(RootI); 1364 if (!RootNode) 1365 return false; 1366 1367 LLVM_DEBUG({ 1368 Function *F = RootI->getFunction(); 1369 BasicBlock *B = RootI->getParent(); 1370 dbgs() << "Complex deinterleaving graph for " << F->getName() 1371 << "::" << B->getName() << ".\n"; 1372 dump(dbgs()); 1373 dbgs() << "\n"; 1374 }); 1375 RootToNode[RootI] = RootNode; 1376 OrderedRoots.push_back(RootI); 1377 return true; 1378 } 1379 1380 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { 1381 bool FoundPotentialReduction = false; 1382 1383 auto *Br = dyn_cast<BranchInst>(B->getTerminator()); 1384 if (!Br || Br->getNumSuccessors() != 2) 1385 return false; 1386 1387 // Identify simple one-block loop 1388 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) 1389 return false; 1390 1391 SmallVector<PHINode *> PHIs; 1392 for (auto &PHI : B->phis()) { 1393 if (PHI.getNumIncomingValues() != 2) 1394 continue; 1395 1396 if (!PHI.getType()->isVectorTy()) 1397 continue; 1398 1399 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); 1400 if (!ReductionOp) 1401 continue; 1402 1403 // Check if final instruction is reduced outside of current block 1404 Instruction *FinalReduction = nullptr; 1405 auto NumUsers = 0u; 1406 for (auto *U : ReductionOp->users()) { 1407 ++NumUsers; 1408 if (U == &PHI) 1409 continue; 1410 FinalReduction = dyn_cast<Instruction>(U); 1411 } 1412 1413 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B) 1414 continue; 1415 1416 ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; 1417 BackEdge = B; 1418 auto BackEdgeIdx = PHI.getBasicBlockIndex(B); 1419 auto IncomingIdx = BackEdgeIdx == 0 ? 
1 : 0; 1420 Incoming = PHI.getIncomingBlock(IncomingIdx); 1421 FoundPotentialReduction = true; 1422 1423 // If the initial value of PHINode is an Instruction, consider it a leaf 1424 // value of a complex deinterleaving graph. 1425 if (auto *InitPHI = 1426 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) 1427 FinalInstructions.insert(InitPHI); 1428 } 1429 return FoundPotentialReduction; 1430 } 1431 1432 void ComplexDeinterleavingGraph::identifyReductionNodes() { 1433 SmallVector<bool> Processed(ReductionInfo.size(), false); 1434 SmallVector<Instruction *> OperationInstruction; 1435 for (auto &P : ReductionInfo) 1436 OperationInstruction.push_back(P.first); 1437 1438 // Identify a complex computation by evaluating two reduction operations that 1439 // potentially could be involved 1440 for (size_t i = 0; i < OperationInstruction.size(); ++i) { 1441 if (Processed[i]) 1442 continue; 1443 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { 1444 if (Processed[j]) 1445 continue; 1446 1447 auto *Real = OperationInstruction[i]; 1448 auto *Imag = OperationInstruction[j]; 1449 if (Real->getType() != Imag->getType()) 1450 continue; 1451 1452 RealPHI = ReductionInfo[Real].first; 1453 ImagPHI = ReductionInfo[Imag].first; 1454 auto Node = identifyNode(Real, Imag); 1455 if (!Node) { 1456 std::swap(Real, Imag); 1457 std::swap(RealPHI, ImagPHI); 1458 Node = identifyNode(Real, Imag); 1459 } 1460 1461 // If a node is identified, mark its operation instructions as used to 1462 // prevent re-identification and attach the node to the real part 1463 if (Node) { 1464 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " 1465 << *Real << " / " << *Imag << "\n"); 1466 Processed[i] = true; 1467 Processed[j] = true; 1468 auto RootNode = prepareCompositeNode( 1469 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); 1470 RootNode->addOperand(Node); 1471 RootToNode[Real] = RootNode; 1472 RootToNode[Imag] = RootNode; 1473 
submitCompositeNode(RootNode); 1474 break; 1475 } 1476 } 1477 } 1478 1479 RealPHI = nullptr; 1480 ImagPHI = nullptr; 1481 } 1482 1483 bool ComplexDeinterleavingGraph::checkNodes() { 1484 // Collect all instructions from roots to leaves 1485 SmallPtrSet<Instruction *, 16> AllInstructions; 1486 SmallVector<Instruction *, 8> Worklist; 1487 for (auto &Pair : RootToNode) 1488 Worklist.push_back(Pair.first); 1489 1490 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG 1491 // chains 1492 while (!Worklist.empty()) { 1493 auto *I = Worklist.back(); 1494 Worklist.pop_back(); 1495 1496 if (!AllInstructions.insert(I).second) 1497 continue; 1498 1499 for (Value *Op : I->operands()) { 1500 if (auto *OpI = dyn_cast<Instruction>(Op)) { 1501 if (!FinalInstructions.count(I)) 1502 Worklist.emplace_back(OpI); 1503 } 1504 } 1505 } 1506 1507 // Find instructions that have users outside of chain 1508 SmallVector<Instruction *, 2> OuterInstructions; 1509 for (auto *I : AllInstructions) { 1510 // Skip root nodes 1511 if (RootToNode.count(I)) 1512 continue; 1513 1514 for (User *U : I->users()) { 1515 if (AllInstructions.count(cast<Instruction>(U))) 1516 continue; 1517 1518 // Found an instruction that is not used by XCMLA/XCADD chain 1519 Worklist.emplace_back(I); 1520 break; 1521 } 1522 } 1523 1524 // If any instructions are found to be used outside, find and remove roots 1525 // that somehow connect to those instructions. 1526 SmallPtrSet<Instruction *, 16> Visited; 1527 while (!Worklist.empty()) { 1528 auto *I = Worklist.back(); 1529 Worklist.pop_back(); 1530 if (!Visited.insert(I).second) 1531 continue; 1532 1533 // Found an impacted root node. 
Removing it from the nodes to be 1534 // deinterleaved 1535 if (RootToNode.count(I)) { 1536 LLVM_DEBUG(dbgs() << "Instruction " << *I 1537 << " could be deinterleaved but its chain of complex " 1538 "operations have an outside user\n"); 1539 RootToNode.erase(I); 1540 } 1541 1542 if (!AllInstructions.count(I) || FinalInstructions.count(I)) 1543 continue; 1544 1545 for (User *U : I->users()) 1546 Worklist.emplace_back(cast<Instruction>(U)); 1547 1548 for (Value *Op : I->operands()) { 1549 if (auto *OpI = dyn_cast<Instruction>(Op)) 1550 Worklist.emplace_back(OpI); 1551 } 1552 } 1553 return !RootToNode.empty(); 1554 } 1555 1556 ComplexDeinterleavingGraph::NodePtr 1557 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { 1558 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { 1559 if (Intrinsic->getIntrinsicID() != 1560 Intrinsic::experimental_vector_interleave2) 1561 return nullptr; 1562 1563 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); 1564 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); 1565 if (!Real || !Imag) 1566 return nullptr; 1567 1568 return identifyNode(Real, Imag); 1569 } 1570 1571 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); 1572 if (!SVI) 1573 return nullptr; 1574 1575 // Look for a shufflevector that takes separate vectors of the real and 1576 // imaginary components and recombines them into a single vector. 
1577 if (!isInterleavingMask(SVI->getShuffleMask())) 1578 return nullptr; 1579 1580 Instruction *Real; 1581 Instruction *Imag; 1582 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) 1583 return nullptr; 1584 1585 return identifyNode(Real, Imag); 1586 } 1587 1588 ComplexDeinterleavingGraph::NodePtr 1589 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, 1590 Instruction *Imag) { 1591 Instruction *I = nullptr; 1592 Value *FinalValue = nullptr; 1593 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && 1594 match(Imag, m_ExtractValue<1>(m_Specific(I))) && 1595 match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>( 1596 m_Value(FinalValue)))) { 1597 NodePtr PlaceholderNode = prepareCompositeNode( 1598 llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); 1599 PlaceholderNode->ReplacementNode = FinalValue; 1600 FinalInstructions.insert(Real); 1601 FinalInstructions.insert(Imag); 1602 return submitCompositeNode(PlaceholderNode); 1603 } 1604 1605 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); 1606 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); 1607 if (!RealShuffle || !ImagShuffle) { 1608 if (RealShuffle || ImagShuffle) 1609 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); 1610 return nullptr; 1611 } 1612 1613 Value *RealOp1 = RealShuffle->getOperand(1); 1614 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { 1615 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); 1616 return nullptr; 1617 } 1618 Value *ImagOp1 = ImagShuffle->getOperand(1); 1619 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { 1620 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); 1621 return nullptr; 1622 } 1623 1624 Value *RealOp0 = RealShuffle->getOperand(0); 1625 Value *ImagOp0 = ImagShuffle->getOperand(0); 1626 1627 if (RealOp0 != ImagOp0) { 1628 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); 1629 return nullptr; 1630 
} 1631 1632 ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); 1633 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); 1634 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { 1635 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); 1636 return nullptr; 1637 } 1638 1639 if (RealMask[0] != 0 || ImagMask[0] != 1) { 1640 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); 1641 return nullptr; 1642 } 1643 1644 // Type checking, the shuffle type should be a vector type of the same 1645 // scalar type, but half the size 1646 auto CheckType = [&](ShuffleVectorInst *Shuffle) { 1647 Value *Op = Shuffle->getOperand(0); 1648 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); 1649 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1650 1651 if (OpTy->getScalarType() != ShuffleTy->getScalarType()) 1652 return false; 1653 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) 1654 return false; 1655 1656 return true; 1657 }; 1658 1659 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { 1660 if (!CheckType(Shuffle)) 1661 return false; 1662 1663 ArrayRef<int> Mask = Shuffle->getShuffleMask(); 1664 int Last = *Mask.rbegin(); 1665 1666 Value *Op = Shuffle->getOperand(0); 1667 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1668 int NumElements = OpTy->getNumElements(); 1669 1670 // Ensure that the deinterleaving shuffle only pulls from the first 1671 // shuffle operand. 
1672 return Last < NumElements; 1673 }; 1674 1675 if (RealShuffle->getType() != ImagShuffle->getType()) { 1676 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); 1677 return nullptr; 1678 } 1679 if (!CheckDeinterleavingShuffle(RealShuffle)) { 1680 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); 1681 return nullptr; 1682 } 1683 if (!CheckDeinterleavingShuffle(ImagShuffle)) { 1684 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); 1685 return nullptr; 1686 } 1687 1688 NodePtr PlaceholderNode = 1689 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, 1690 RealShuffle, ImagShuffle); 1691 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); 1692 FinalInstructions.insert(RealShuffle); 1693 FinalInstructions.insert(ImagShuffle); 1694 return submitCompositeNode(PlaceholderNode); 1695 } 1696 1697 ComplexDeinterleavingGraph::NodePtr 1698 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, 1699 Instruction *Imag) { 1700 if (Real != RealPHI || Imag != ImagPHI) 1701 return nullptr; 1702 1703 NodePtr PlaceholderNode = prepareCompositeNode( 1704 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); 1705 return submitCompositeNode(PlaceholderNode); 1706 } 1707 1708 ComplexDeinterleavingGraph::NodePtr 1709 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, 1710 Instruction *Imag) { 1711 auto *SelectReal = dyn_cast<SelectInst>(Real); 1712 auto *SelectImag = dyn_cast<SelectInst>(Imag); 1713 if (!SelectReal || !SelectImag) 1714 return nullptr; 1715 1716 Instruction *MaskA, *MaskB; 1717 Instruction *AR, *AI, *RA, *BI; 1718 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR), 1719 m_Instruction(RA))) || 1720 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI), 1721 m_Instruction(BI)))) 1722 return nullptr; 1723 1724 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB)) 1725 return nullptr; 1726 1727 if (!MaskA->getType()->isVectorTy()) 1728 return nullptr; 1729 1730 auto NodeA = 
identifyNode(AR, AI); 1731 if (!NodeA) 1732 return nullptr; 1733 1734 auto NodeB = identifyNode(RA, BI); 1735 if (!NodeB) 1736 return nullptr; 1737 1738 NodePtr PlaceholderNode = prepareCompositeNode( 1739 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); 1740 PlaceholderNode->addOperand(NodeA); 1741 PlaceholderNode->addOperand(NodeB); 1742 FinalInstructions.insert(MaskA); 1743 FinalInstructions.insert(MaskB); 1744 return submitCompositeNode(PlaceholderNode); 1745 } 1746 1747 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, 1748 FastMathFlags Flags, Value *InputA, 1749 Value *InputB) { 1750 Value *I; 1751 switch (Opcode) { 1752 case Instruction::FNeg: 1753 I = B.CreateFNeg(InputA); 1754 break; 1755 case Instruction::FAdd: 1756 I = B.CreateFAdd(InputA, InputB); 1757 break; 1758 case Instruction::FSub: 1759 I = B.CreateFSub(InputA, InputB); 1760 break; 1761 case Instruction::FMul: 1762 I = B.CreateFMul(InputA, InputB); 1763 break; 1764 default: 1765 llvm_unreachable("Incorrect symmetric opcode"); 1766 } 1767 cast<Instruction>(I)->setFastMathFlags(Flags); 1768 return I; 1769 } 1770 1771 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, 1772 RawNodePtr Node) { 1773 if (Node->ReplacementNode) 1774 return Node->ReplacementNode; 1775 1776 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { 1777 return Node->Operands.size() > Idx 1778 ? 
replaceNode(Builder, Node->Operands[Idx]) 1779 : nullptr; 1780 }; 1781 1782 Value *ReplacementNode; 1783 switch (Node->Operation) { 1784 case ComplexDeinterleavingOperation::CAdd: 1785 case ComplexDeinterleavingOperation::CMulPartial: 1786 case ComplexDeinterleavingOperation::Symmetric: { 1787 Value *Input0 = ReplaceOperandIfExist(Node, 0); 1788 Value *Input1 = ReplaceOperandIfExist(Node, 1); 1789 Value *Accumulator = ReplaceOperandIfExist(Node, 2); 1790 assert(!Input1 || (Input0->getType() == Input1->getType() && 1791 "Node inputs need to be of the same type")); 1792 assert(!Accumulator || 1793 (Input0->getType() == Accumulator->getType() && 1794 "Accumulator and input need to be of the same type")); 1795 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) 1796 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, 1797 Input0, Input1); 1798 else 1799 ReplacementNode = TL->createComplexDeinterleavingIR( 1800 Builder, Node->Operation, Node->Rotation, Input0, Input1, 1801 Accumulator); 1802 break; 1803 } 1804 case ComplexDeinterleavingOperation::Deinterleave: 1805 llvm_unreachable("Deinterleave node should already have ReplacementNode"); 1806 break; 1807 case ComplexDeinterleavingOperation::ReductionPHI: { 1808 // If Operation is ReductionPHI, a new empty PHINode is created. 1809 // It is filled later when the ReductionOperation is processed. 
1810 auto *VTy = cast<VectorType>(Node->Real->getType()); 1811 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 1812 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI()); 1813 OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI; 1814 ReplacementNode = NewPHI; 1815 break; 1816 } 1817 case ComplexDeinterleavingOperation::ReductionOperation: 1818 ReplacementNode = replaceNode(Builder, Node->Operands[0]); 1819 processReductionOperation(ReplacementNode, Node); 1820 break; 1821 case ComplexDeinterleavingOperation::ReductionSelect: { 1822 auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0); 1823 auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0); 1824 auto *A = replaceNode(Builder, Node->Operands[0]); 1825 auto *B = replaceNode(Builder, Node->Operands[1]); 1826 auto *NewMaskTy = VectorType::getDoubleElementsVectorType( 1827 cast<VectorType>(MaskReal->getType())); 1828 auto *NewMask = 1829 Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, 1830 NewMaskTy, {MaskReal, MaskImag}); 1831 ReplacementNode = Builder.CreateSelect(NewMask, A, B); 1832 break; 1833 } 1834 } 1835 1836 assert(ReplacementNode && "Target failed to create Intrinsic call."); 1837 NumComplexTransformations += 1; 1838 Node->ReplacementNode = ReplacementNode; 1839 return ReplacementNode; 1840 } 1841 1842 void ComplexDeinterleavingGraph::processReductionOperation( 1843 Value *OperationReplacement, RawNodePtr Node) { 1844 auto *Real = cast<Instruction>(Node->Real); 1845 auto *Imag = cast<Instruction>(Node->Imag); 1846 auto *OldPHIReal = ReductionInfo[Real].first; 1847 auto *OldPHIImag = ReductionInfo[Imag].first; 1848 auto *NewPHI = OldToNewPHI[OldPHIReal]; 1849 1850 auto *VTy = cast<VectorType>(Real->getType()); 1851 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 1852 1853 // We have to interleave initial origin values coming from IncomingBlock 1854 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); 1855 Value 
*InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); 1856 1857 IRBuilder<> Builder(Incoming->getTerminator()); 1858 auto *NewInit = Builder.CreateIntrinsic( 1859 Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag}); 1860 1861 NewPHI->addIncoming(NewInit, Incoming); 1862 NewPHI->addIncoming(OperationReplacement, BackEdge); 1863 1864 // Deinterleave complex vector outside of loop so that it can be finally 1865 // reduced 1866 auto *FinalReductionReal = ReductionInfo[Real].second; 1867 auto *FinalReductionImag = ReductionInfo[Imag].second; 1868 1869 Builder.SetInsertPoint( 1870 &*FinalReductionReal->getParent()->getFirstInsertionPt()); 1871 auto *Deinterleave = Builder.CreateIntrinsic( 1872 Intrinsic::experimental_vector_deinterleave2, 1873 OperationReplacement->getType(), OperationReplacement); 1874 1875 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); 1876 FinalReductionReal->replaceUsesOfWith(Real, NewReal); 1877 1878 Builder.SetInsertPoint(FinalReductionImag); 1879 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); 1880 FinalReductionImag->replaceUsesOfWith(Imag, NewImag); 1881 } 1882 1883 void ComplexDeinterleavingGraph::replaceNodes() { 1884 SmallVector<Instruction *, 16> DeadInstrRoots; 1885 for (auto *RootInstruction : OrderedRoots) { 1886 // Check if this potential root went through check process and we can 1887 // deinterleave it 1888 if (!RootToNode.count(RootInstruction)) 1889 continue; 1890 1891 IRBuilder<> Builder(RootInstruction); 1892 auto RootNode = RootToNode[RootInstruction]; 1893 Value *R = replaceNode(Builder, RootNode.get()); 1894 1895 if (RootNode->Operation == 1896 ComplexDeinterleavingOperation::ReductionOperation) { 1897 auto *RootReal = cast<Instruction>(RootNode->Real); 1898 auto *RootImag = cast<Instruction>(RootNode->Imag); 1899 ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); 1900 ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); 1901 
DeadInstrRoots.push_back(cast<Instruction>(RootReal)); 1902 DeadInstrRoots.push_back(cast<Instruction>(RootImag)); 1903 } else { 1904 assert(R && "Unable to find replacement for RootInstruction"); 1905 DeadInstrRoots.push_back(RootInstruction); 1906 RootInstruction->replaceAllUsesWith(R); 1907 } 1908 } 1909 1910 for (auto *I : DeadInstrRoots) 1911 RecursivelyDeleteTriviallyDeadInstructions(I, TLI); 1912 } 1913